ERA5 to HadUK-Grid - train the Variational AutoEncoder

This relies on an already-trained VAE for HadUK-Grid: it loads that model's weights and freezes the weights of the generator (the decoder). This forces the output to be compatible with the HadUK-Grid fields made by assimilating observations (the same generator is used in both cases, so the outputs are statistically compatible). Freezing is not necessary – we could retrain the whole model from scratch and it would still produce a good conversion, just one that is less consistent with the assimilation product.

#!/usr/bin/env python

# Convolutional Variational Autoencoder for haduk-grid Tmax

import os
import sys
import time
import tensorflow as tf
import pickle

# Distribution strategy. Outside any strategy scope, get_strategy() returns
# TensorFlow's default (non-distributed) strategy.
# TODO: use tf.distribute.MirroredStrategy() to distribute across all GPUs.
#  That does not work yet - the training loop and datasets would first need
#  to be made strategy-aware.
strategy = tf.distribute.get_strategy()

# Load the model specification
sys.path.append("%s/." % os.path.dirname(__file__))
from autoencoderModel import DCVAE
from autoencoderModel import train_step
from autoencoderModel import compute_loss

# Load the data source provider
from makeDataset import getDataset

# Number of images drawn from each split
# (maximum available: 13697 training, 1522 test).
nTrainingImages = 13697
nTestImages = 1522

# Number of training epochs.
nEpochs = 201
# Images per epoch; None means "same as nTrainingImages".
# NOTE(review): nImagesInEpoch is not referenced by the training loop in this
#  script - each epoch is one full pass over trainingData.
nImagesInEpoch = None
nImagesInEpoch = nTrainingImages if nImagesInEpoch is None else nImagesInEpoch

# Dataset parameters
bufferSize = 1000  # shuffle buffer size (untested); smaller than the training set
batchSize = 32  # arbitrary

# Training data: a single pass (repeat(1)), shuffled, then batched.
trainingData = (
    getDataset(purpose="training", nImages=nTrainingImages)
    .repeat(1)
    .shuffle(bufferSize)
    .batch(batchSize)
)

# Metrics-only subset of the training split, the same size as the test set
# and not shuffled.
validationData = getDataset(
    purpose="training", nImages=nTestImages
).batch(batchSize)

# Test data, batched and not shuffled.
testData = getDataset(purpose="test", nImages=nTestImages).batch(batchSize)

# Build the model and optimiser, then initialise the model from the trained
# HadUK-Grid -> HadUK-Grid VAE.
with strategy.scope():
    autoencoder = DCVAE()
    optimizer = tf.keras.optimizers.Adam(1e-4)

    # Checkpoint of the HUKG->HUKG model saved at epoch 200.
    weights_dir = "%s/Proxy_20CR/models/DCVAE_single_HUKG_Tmax/Epoch_%04d" % (
        os.getenv("SCRATCH"),
        200,
    )
    # expect_partial() silences warnings about checkpoint values that are not
    # restored; assert_existing_objects_matched() still raises if an object
    # already present in the model has no matching checkpoint entry.
    load_status = autoencoder.load_weights("%s/ckpt" % weights_dir)
    load_status = load_status.expect_partial()
    load_status.assert_existing_objects_matched()

    # Freeze the decoder (the generator): mark the sub-model and every one of
    # its layers as non-trainable.
    autoencoder.decoder.trainable = False
    for decoder_layer in autoencoder.decoder.layers:
        decoder_layer.trainable = False
    # NOTE(review): argument-less compile(); presumably kept so the trainable
    #  change takes effect - confirm it is needed with the custom train_step.
    autoencoder.decoder.compile()


# Training history. save_state() appends to "loss" and pickles the whole dict
# next to each saved checkpoint (checkpoints are written every 10 epochs, not
# every epoch). "val_loss" is not currently filled in.
history = {"loss": [], "val_loss": []}


def save_state(model, epoch, loss):
    """Checkpoint the model and the training history for one epoch.

    Writes the model weights to
    ``$SCRATCH/Proxy_20CR/models/DCVAE_single_ERA5_to_HUKG_Tmax/Epoch_NNNN/ckpt``,
    appends ``loss`` to the module-level ``history["loss"]`` list, and pickles
    the whole ``history`` dict to ``history.pkl`` in the same directory.

    Args:
        model: object with a ``save_weights(path)`` method (the Keras model).
        epoch: epoch number; zero-padded to four digits in the directory name.
        loss: value appended to ``history["loss"]`` (the training loop passes
            the test-set RMSE).
    """
    save_dir = (
        "%s/Proxy_20CR/models/DCVAE_single_ERA5_to_HUKG_Tmax/" + "Epoch_%04d"
    ) % (
        os.getenv("SCRATCH"),
        epoch,
    )
    # exist_ok avoids the check-then-create race of an isdir() pre-test.
    os.makedirs(save_dir, exist_ok=True)
    model.save_weights("%s/ckpt" % save_dir)
    history["loss"].append(loss)
    history_file = "%s/history.pkl" % save_dir
    # Close (and therefore flush) the file deterministically instead of
    # relying on garbage collection of an anonymous file object.
    with open(history_file, "wb") as fh:
        pickle.dump(history, fh)


def _mean_losses(model, dataset):
    """Average the three terms returned by compute_loss over a dataset.

    Args:
        model: the autoencoder, passed straight through to compute_loss.
        dataset: iterable of batches to score.

    Returns:
        (rmse, logpz, logqz_x): the ``result()`` of a fresh
        tf.keras.metrics.Mean for each term, accumulated over every batch.
    """
    rmse_mean = tf.keras.metrics.Mean()
    logpz_mean = tf.keras.metrics.Mean()
    logqz_x_mean = tf.keras.metrics.Mean()
    for batch in dataset:
        rmse, logpz, logqz_x = compute_loss(model, batch)
        rmse_mean(rmse)
        logpz_mean(logpz)
        logqz_x_mean(logqz_x)
    return rmse_mean.result(), logpz_mean.result(), logqz_x_mean.result()


for epoch in range(nEpochs):
    # One training pass; only this part is included in the reported time.
    epoch_start = time.time()
    for batch in trainingData:
        train_step(autoencoder, batch, optimizer)
    epoch_end = time.time()

    # Score the training-split subset first, then the test split.
    train_rmse, train_logpz, train_logqz_x = _mean_losses(
        autoencoder, validationData
    )
    test_rmse, test_logpz, test_logqz_x = _mean_losses(autoencoder, testData)

    print("Epoch: {}".format(epoch))
    print("RMSE: {}, {}".format(train_rmse, test_rmse))
    print("logpz: {}, {}".format(train_logpz, test_logpz))
    print("logqz_x: {}, {}".format(train_logqz_x, test_logqz_x))
    print("time: {}".format(epoch_end - epoch_start))

    # Checkpoint on epochs divisible by 10, recording the test RMSE.
    if epoch % 10 == 0:
        save_state(autoencoder, epoch, test_rmse)