Pressure from wind - train the Variational AutoEncoder

#!/usr/bin/env python

# Convolutional Variational Autoencoder for 20CR2c

# This version fits a set of pressure (P) and wind (U & V) fields

import os
import sys
import time
import tensorflow as tf
import pickle
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--epoch", help="Restart from epoch", type=int, required=False, default=0
)
args = parser.parse_args()
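# --epoch N restarts training from the checkpoint saved at epoch N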

# Distribute across all GPUs
# strategy = tf.distribute.MirroredStrategy()
# Doesn't work yet - the training loop and datasets would need to be made
#  strategy-aware first, so fall back to the default (single-device) strategy.
strategy = tf.distribute.get_strategy()

# Load the model specification
sys.path.append("%s/." % os.path.dirname(__file__))
from autoencoderModel import DCVAE
from autoencoderModel import train_step
from autoencoderModel import compute_loss

# Load the data source provider
sys.path.append("%s/../PUV_dataset" % os.path.dirname(__file__))
from makeDataset import getDataset
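# getDataset provides the P, U & V fields as a tf.data.Dataset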


# How many images to use?
nTrainingImages = 10780  # Max is 10780
nTestImages = 1197  # Max is 1197

# How many epochs to train for
nEpochs = 251
# Length of an epoch - if None, same as nTrainingImages
nImagesInEpoch = None

if nImagesInEpoch is None:
    nImagesInEpoch = nTrainingImages

# Dataset parameters
bufferSize = 1000  # Shuffle buffer size - value untested
batchSize = 32  # Arbitrary

# Set up the training data - 5 passes through the data per epoch
trainingData = getDataset(purpose="training", nImages=nTrainingImages).repeat(5)
trainingData = trainingData.shuffle(bufferSize).batch(batchSize)

# Subset of the training data for metrics
validationData = getDataset(purpose="training", nImages=nTestImages).batch(batchSize)

# Set up the test data
testData = getDataset(purpose="test", nImages=nTestImages)
testData = testData.batch(batchSize)
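# The test data is only used for diagnostics - it is never trained on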

# Instantiate the model
with strategy.scope():
    autoencoder = DCVAE()
    optimizer = tf.keras.optimizers.Adam(1e-4)
    # If we are doing a restart, load the weights
    if args.epoch > 0:
        weights_dir = "%s/Proxy_20CR/models/DCVAE_single_PUV/Epoch_%04d" % (
            os.getenv("SCRATCH"),
            args.epoch,
        )
        load_status = autoencoder.load_weights("%s/ckpt" % weights_dir)
        # Check the load worked
        load_status.assert_existing_objects_matched()


# Save the model weights and the training history every 10 epochs
history = {}
history["loss"] = []
history["val_loss"] = []


def save_state(model, epoch, loss):
    save_dir = "%s/Proxy_20CR/models/DCVAE_single_PUV/Epoch_%04d" % (
        os.getenv("SCRATCH"),
        epoch,
    )
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    model.save_weights("%s/ckpt" % save_dir)
    history["loss"].append(float(loss))  # Plain float so the history pickles cleanly
    # history["val_loss"].append(logs["val_loss"])
    history_file = "%s/history.pkl" % save_dir
    with open(history_file, "wb") as f:
        pickle.dump(history, f)


# Main training loop: one optimisation pass over trainingData per epoch,
#  followed by diagnostics on the training subset and the test data
for epoch in range(nEpochs):
    start_time = time.time()
    for train_x in trainingData:
        train_step(autoencoder, train_x, optimizer)
    end_time = time.time()

    # Diagnostics on a subset of the training data
    train_rmse_p = tf.keras.metrics.Mean()
    train_rmse_u = tf.keras.metrics.Mean()
    train_rmse_v = tf.keras.metrics.Mean()
    train_logpz = tf.keras.metrics.Mean()
    train_logqz_x = tf.keras.metrics.Mean()
    for val_x in validationData:
        (rmse_p, rmse_u, rmse_v, logpz, logqz_x) = compute_loss(autoencoder, val_x)
        train_rmse_p(rmse_p)
        train_rmse_u(rmse_u)
        train_rmse_v(rmse_v)
        train_logpz(logpz)
        train_logqz_x(logqz_x)
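    # Same diagnostics on the test data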
    test_rmse_p = tf.keras.metrics.Mean()
    test_rmse_u = tf.keras.metrics.Mean()
    test_rmse_v = tf.keras.metrics.Mean()
    test_logpz = tf.keras.metrics.Mean()
    test_logqz_x = tf.keras.metrics.Mean()
    for test_x in testData:
        (rmse_p, rmse_u, rmse_v, logpz, logqz_x) = compute_loss(autoencoder, test_x)
        test_rmse_p(rmse_p)
        test_rmse_u(rmse_u)
        test_rmse_v(rmse_v)
        test_logpz(logpz)
        test_logqz_x(logqz_x)
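    # Report diagnostics: training-subset value first, test value second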
    print("Epoch: {}".format(epoch))
    print("RMSE P: {}, {}".format(train_rmse_p.result(), test_rmse_p.result()))
    print("RMSE U: {}, {}".format(train_rmse_u.result(), test_rmse_u.result()))
    print("RMSE V: {}, {}".format(train_rmse_v.result(), test_rmse_v.result()))
    print("logpz: {}, {}".format(train_logpz.result(), test_logpz.result()))
    print("logqz_x: {}, {}".format(train_logqz_x.result(), test_logqz_x.result()))
    print("time: {}".format(end_time - start_time))
    if epoch % 10 == 0:
        save_state(autoencoder, epoch, test_rmse_p.result())