# Source code for ML_Utilities.dataset

# (C) British Crown Copyright 2019, Met Office
#
# This code is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the
# Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This code is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#

# Create a tf.data Dataset to provide 20CR data to tf.keras models.

import os
import tensorflow as tf
from tensorflow.data import Dataset
from glob import glob
import numpy

def dataset(purpose='training', source='20CR2c', variable='prmsl',
            length=None, shuffle=True, buffer_size=10,
            reshuffle_each_iteration=False):
    """Provide a :obj:`tf.data.Dataset` of analysis data, for tf.keras model training or tests.

    Data must be available in directory $SCRATCH/Machine-Learning-experiments,
    previously generated by :func:`prepare_data`.

    Args:
        purpose (:obj:`str`): 'training' (default) or 'test'.
        source (:obj:`str`): Where to get the data from - any string, but
            needs to be supported by :func:`prepare_data`.
        variable (:obj:`str`): Variable to fetch (e.g. 'prmsl').
        length (:obj:`int`): Required length. If more files than this are on
            disc, only the first ``length`` (in sorted file-name order) are
            used. If fewer, the data are repeated the minimum number of times
            needed to deliver at least this many data points. If None
            (default), uses the amount of data on disc as the length (not
            repeated). Must be a positive integer if given.
        shuffle (:obj:`bool`): If True (default), shuffle the data order. If
            False, present the data in sorted file-name order.
        buffer_size (:obj:`int`): Passed to :func:`tf.data.Dataset.shuffle`
            (unused when ``shuffle`` is False).
        reshuffle_each_iteration (:obj:`bool`): Passed to
            :func:`tf.data.Dataset.shuffle` (unused when ``shuffle`` is False).

    Returns:
        tuple of :obj:`tf.data.Dataset`: Dataset suitable for passing to
        tf.keras models, and :obj:`int`: length of that dataset.

    Raises:
        ValueError: Necessary data not on disc, need to run
            :func:`create_dataset` to make it; or ``length`` is not positive.
    """
    # Validate up front: a non-positive length would otherwise empty the
    # file list and cause a ZeroDivisionError when computing repeats.
    if length is not None and length < 1:
        raise ValueError(
            "length must be a positive integer or None, got %r" % (length,))

    # File names for the serialised tensors (made by :func:`create_dataset`)
    input_file_dir = ("%s/Machine-Learning-experiments/datasets/%s/%s/%s" %
                      (os.getenv('SCRATCH'), source, variable, purpose))
    # glob order is filesystem-dependent; sort so the truncated subset and
    # the unshuffled order are reproducible.
    data_files = sorted(glob("%s/*.tfd" % input_file_dir))
    if not data_files:
        raise ValueError('No prepared data on disc (looked in %s)'
                         % input_file_dir)
    if length is not None:
        data_files = data_files[:length]  # no-op when length >= file count
    n_steps = len(data_files)

    # Create TensorFlow Dataset objects from the file names
    tr_data = Dataset.from_tensor_slices(tf.constant(data_files))
    if shuffle:
        tr_data = tr_data.shuffle(
            buffer_size=buffer_size,
            reshuffle_each_iteration=reshuffle_each_iteration)

    # Repeat just enough times to reach at least `length` items
    # (ceiling division: avoids an extra copy on exact multiples).
    nrep = 1
    if length is not None:
        nrep = -(-length // n_steps)
    tr_data = tr_data.repeat(nrep)

    # We don't want the file names, we want their contents, so
    # add a map to convert from names to contents.
    def load_tensor(file_name):
        serialised = tf.io.read_file(file_name)
        return tf.io.parse_tensor(serialised, numpy.float32)

    tr_data = tr_data.map(load_tensor)
    return (tr_data, n_steps * nrep)