Make the tf.data.Dataset inputs and outputs for the DCVAE

Takes the inputs specified in the configuration file and creates the tf.data.Dataset inputs and outputs for the DCVAE model.
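
For orientation, here is a minimal sketch of how this module might be driven. The specification keys are the ones read by getDataset below, but the source name and values are hypothetical:

specification = {
    "inputTensors": ["ERA5_2m_temperature"],  # hypothetical source name
    "outputTensors": None,  # None - outputs are the same as the inputs
    "startYear": None,  # no lower bound on the years used
    "endYear": None,  # no upper bound
    "testSplit": 11,  # every 11th month is held out for the Test set
    "maxTrainingMonths": None,  # no cap on training-set size
    "maxTestMonths": None,  # no cap on test-set size
    "trainCache": True,  # cache the training data in RAM
    "testCache": True,  # cache the test data in RAM
}
trainingData = getDataset(specification, purpose="Train")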

# Make tf.data.Datasets from ERA5 monthly averages

# This is a generic script to make a TensorFlow Dataset
# Follow the instructions in autoencoder.py to use it.

import os
import tensorflow as tf
import zarr
import tensorstore as ts


# Find out what months are available from a source
def getDataAvailability(source):
    zfile = "%s/DCVAE-Climate/normalized_datasets/%s_zarr" % (
        os.getenv("SCRATCH"),
        source,
    )
    zarr_array = zarr.open(zfile, mode="r")
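    # From its use in getMonths below, AvailableMonths is a dict mapping
    # month strings (keys starting with the four-digit year) to integer
    # indices along the time axis of the array.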
    AvailableMonths = zarr_array.attrs["AvailableMonths"]
    return AvailableMonths


# Find the months available in all of a set of sources
def getMonths(
    sources,
    purpose,
    firstYr,
    lastYr,
    testSplit,
    maxTrainingMonths,
    maxTestMonths,
):
    avail = {}
    months_in_all = None
    for source in sources:
        avail[source] = getDataAvailability(source)
        if months_in_all is None:
            months_in_all = set(avail[source].keys())
        else:
            months_in_all = months_in_all.intersection(set(avail[source].keys()))

    # Filter by range of years
    filtered = []
    for month in months_in_all:
        year = int(month[:4])
        if (firstYr is None or year >= firstYr) and (lastYr is None or year <= lastYr):
            filtered.append(month)
    months_in_all = filtered

    months_in_all.sort()  # Months in time order (validation plots)

    # Test/Train split - every testSplit'th month is held out for the Test set
    if purpose is not None:
        test_ns = list(range(0, len(months_in_all), testSplit))
        if purpose == "Train":
            months_in_all = [
                months_in_all[x] for x in range(len(months_in_all)) if x not in test_ns
            ]
        elif purpose == "Test":
            months_in_all = [
                months_in_all[x] for x in range(len(months_in_all)) if x in test_ns
            ]
        else:
            raise ValueError("Unsupported purpose " + purpose)

    # Limit maximum data size
    if purpose == "Train" and maxTrainingMonths is not None:
        if len(months_in_all) >= maxTrainingMonths:
            months_in_all = months_in_all[0:maxTrainingMonths]
        else:
            raise ValueError(
                "Only %d months available, can't provide %d"
                % (len(months_in_all), maxTrainingMonths)
            )
    if purpose == "Test" and maxTestMonths is not None:
        if len(months_in_all) >= maxTestMonths:
            months_in_all = months_in_all[0:maxTestMonths]
        else:
            raise ValueError(
                "Only %d months available, can't provide %d"
                % (len(months_in_all), maxTestMonths)
            )

    # Return a list of months
    #  and a dict of lists of indices - one list per source
    indices = {}
    for source in sources:
        indices[source] = []
        for key in months_in_all:
            indices[source].append(avail[source][key])
    return months_in_all, indices
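
# (For illustration: getMonths(["tas"], "Train", 1950, 1960, 11, None, None)
#  would return the sorted month strings available for source "tas" in
#  1950-1960, minus every 11th one (held out for testing), plus a dict
#  giving the time-axis index of each of those months - the source name
#  "tas" is hypothetical.)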


# Get a dataset
def getDataset(specification, purpose):
    # Get a list of months to use - inputs
    inMonths, inIndices = getMonths(
        specification["inputTensors"],
        purpose,
        specification["startYear"],
        specification["endYear"],
        specification["testSplit"],
        specification["maxTrainingMonths"],
        specification["maxTestMonths"],
    )

    # If the outputs are not the same as the inputs, get them too and use only months in both
    if (
        specification["outputTensors"] is not None
    ):  # I.e. input and output are not the same
        outMonths, outIndices = getMonths(
            specification["outputTensors"],
            purpose,
            specification["startYear"],
            specification["endYear"],
            specification["testSplit"],
            specification["maxTrainingMonths"],
            specification["maxTestMonths"],
        )

        outMonths = sorted(
            list(set(inMonths).intersection(set(outMonths)))
        )  # Shared Months
        if len(outMonths) != len(inMonths):
            raise ValueError(
                "Input and output tensors have different months available"
            )  # Deal with this when it becomes a problem
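        # (Equal lengths mean the sorted intersection is exactly inMonths,
        #  so inMonths and outMonths are identical and the output dataset
        #  below can safely be driven from the input month strings.)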

    # Create TensorFlow Dataset object from the date strings
    tnIData = tf.data.Dataset.from_tensor_slices(tf.constant(inMonths))

    # Open all the source tensorstores
    tsa_in = {}
    for source in specification["inputTensors"]:
        zfile = "%s/DCVAE-Climate/normalized_datasets/%s_zarr" % (
            os.getenv("SCRATCH"),
            source,
        )
        tsa_in[source] = ts.open(
            {
                "driver": "zarr",
                "kvstore": "file://" + zfile,
            }
        ).result()
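    # (ts.open reads only the store metadata - the data itself is fetched
    #  lazily by the .read() calls in the map function below.)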

    # Map functions to get tensors from dates and indices
    def load_input_tensors_from_month_py(month):
        mnth = month.numpy().decode("utf-8")
        source = specification["inputTensors"][0]
        tsa = tsa_in[source]
        idx = inIndices[source][inMonths.index(mnth)]
        ima = tf.convert_to_tensor(tsa[:, :, idx].read().result(), tf.float32)
        ima = tf.reshape(ima, [721, 1440, 1])
        for fni in range(1, len(specification["inputTensors"])):
            source = specification["inputTensors"][fni]
            tsa = tsa_in[source]
            idx = inIndices[source][inMonths.index(mnth)]
            imt = tf.convert_to_tensor(tsa[:, :, idx].read().result(), tf.float32)
            imt = tf.reshape(imt, [721, 1440, 1])
            ima = tf.concat([ima, imt], 2)
        return ima

    # Wrap the Python loader in tf.py_function so it can be used in
    # Dataset.map - the loader needs eager mode for its .numpy() and
    # tensorstore calls.
    def load_input_tensor(month):
        result = tf.py_function(
            load_input_tensors_from_month_py,
            [month],
            tf.float32,
        )
        return result

    # Create Dataset from the source file contents
    tsIData = tnIData.map(
        load_input_tensor, num_parallel_calls=tf.data.experimental.AUTOTUNE
    )

    if specification["outputTensors"] is not None:
        tsa_out = {}
        for source in specification["outputTensors"]:
            zfile = "%s/DCVAE-Climate/normalized_datasets/%s_zarr" % (
                os.getenv("SCRATCH"),
                source,
            )
            tsa_out[source] = ts.open(
                {
                    "driver": "zarr",
                    "kvstore": "file://" + zfile,
                }
            ).result()

        def load_output_tensors_from_month_py(month):
            mnth = month.numpy().decode("utf-8")
            source = specification["outputTensors"][0]
            tsa = tsa_out[source]
            idx = outIndices[source][outMonths.index(mnth)]
            ima = tf.convert_to_tensor(tsa[:, :, idx].read().result(), tf.float32)
            ima = tf.reshape(ima, [721, 1440, 1])
            for fni in range(1, len(specification["outputTensors"])):
                source = specification["outputTensors"][fni]
                tsa = tsa_out[source]
                idx = outIndices[source][outMonths.index(mnth)]
                imt = tf.convert_to_tensor(tsa[:, :, idx].read().result(), tf.float32)
                imt = tf.reshape(imt, [721, 1440, 1])
                ima = tf.concat([ima, imt], 2)
            return ima

        # As for the inputs, wrap the eager loader for use in Dataset.map
        def load_output_tensor(month):
            result = tf.py_function(
                load_output_tensors_from_month_py,
                [month],
                tf.float32,
            )
            return result

        # Map over the same month strings as the inputs (checked above to be
        # identical to the output months)
        tsOData = tnIData.map(
            load_output_tensor, num_parallel_calls=tf.data.experimental.AUTOTUNE
        )

    # Zip the data together with the month strings (so we can find the date
    #   of each data tensor if we need it).
    if specification["outputTensors"] is not None:
        tz_data = tf.data.Dataset.zip((tnIData, tsIData, tsOData))
    else:
        tz_data = tf.data.Dataset.zip((tnIData, tsIData))

    # Optimisation
    if (purpose == "Train" and specification["trainCache"]) or (
        purpose == "Test" and specification["testCache"]
    ):
        tz_data = tz_data.cache()  # Great, iff you have enough RAM for it

    tz_data = tz_data.prefetch(tf.data.experimental.AUTOTUNE)

    return tz_data
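
For illustration, the returned dataset might be consumed like this - a minimal sketch assuming the hypothetical specification above, with outputTensors None so that each element is a (month string, input tensor) pair:

for month, x in getDataset(specification, "Train").batch(1):
    # month has shape (1,); x has shape (1, 721, 1440, number-of-input-tensors)
    print(month.numpy()[0].decode("utf-8"), x.shape)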