Source code for ML_Utilities.dataset
# (C) British Crown Copyright 2019, Met Office
#
# This code is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the
# Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This code is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# Create a tf.data Dataset to provide 20CR data to tf.keras models.
import os
import tensorflow as tf
from tensorflow.data import Dataset
from glob import glob
import numpy
def dataset(purpose='training', source='20CR2c', variable='prmsl', length=None,
            shuffle=True, buffer_size=10, reshuffle_each_iteration=False):
"""Provide a :obj:`tf.data.Dataset` of analysis data, for tf.keras model training or tests.
Data must be available in directory $SCRATCH/Machine-Learning-experiments, previously generated by :func:`prepare_data`.
Args:
purpose (:obj:`str`): 'training' (default) or 'test'.
source (:obj:`str`): Where to get the data from - any string, but needs top be supported by :func:`prepare_data`.
variable (:obj:`str`): Variable to fetch (e.g. 'prmsl').
length (:obj:`int`): Required length - will be repeated enough times to deliver at least this many data points. If None (default) uses the amount of data on disc as the length (not repeated).
shuffle (:obj:`bool`): If True (default), shuffle the data order. If False, present the data in the order of the files on disc.
buffer_size (:obj:`int`): Passed to :func:`tf.data.Dataset.shuffle`.
reshuffle_each_iteration (:obj:`bool`): Passed to :func:`tf.data.Dataset.shuffle`.
Returns:
tuple of :obj:`tf.data.Dataset`: Dataset suitable for passing to tf.keras models, and :obj:`int`: length of that dataset.
Raises:
ValueError: Necessary data not on disc, need to run :func:`create_dataset` to make it.
|
"""
    # File names for the serialised tensors (made by :func:`create_dataset`)
    input_file_dir = ("%s/Machine-Learning-experiments/datasets/%s/%s/%s" %
                      (os.getenv('SCRATCH'), source, variable, purpose))
    data_files = glob("%s/*.tfd" % input_file_dir)
    if len(data_files) == 0:
        raise ValueError('No prepared data on disc')
    if length is not None and length < len(data_files):
        data_files = data_files[:length]
    n_steps = len(data_files)
    data_tfd = tf.constant(data_files)
    # Create a TensorFlow Dataset object from the file names
    tr_data = Dataset.from_tensor_slices(data_tfd)
    if shuffle:
        tr_data = tr_data.shuffle(buffer_size=buffer_size,
                                  reshuffle_each_iteration=reshuffle_each_iteration)
    # Repeat the file list enough times to provide at least 'length' data points
    nrep = 1
    if length is not None:
        nrep = (length // n_steps) + 1
    tr_data = tr_data.repeat(nrep)
    # We don't want the file names, we want their contents, so
    # add a map to convert from names to contents.
    def load_tensor(file_name):
        sict = tf.io.read_file(file_name)              # serialised tensor on disc
        ict = tf.io.parse_tensor(sict, numpy.float32)  # deserialise to a float32 tensor
        return ict
    tr_data = tr_data.map(load_tensor)
    return (tr_data, n_steps * nrep)
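
# Illustrative usage sketch (an editorial addition, not part of the original
# module). It assumes prepared 'prmsl' data already exists under
# $SCRATCH/Machine-Learning-experiments, and that 'model' is a compiled
# tf.keras model whose input shape matches the serialised fields; the batch
# size, length, and epoch count below are arbitrary example values.
#
#     from ML_Utilities import dataset
#
#     (train_ds, n_train) = dataset(purpose='training', source='20CR2c',
#                                   variable='prmsl', length=2560)
#     train_ds = train_ds.batch(32)
#     model.fit(train_ds, epochs=10,
#               steps_per_epoch=(n_train // 32))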