Default model training script

Script (autoencoder.py) to train the model specified in the configuration file.

This script trains the model. It takes plenty of resources (a GPU - or several - is strongly recommended).

#!/usr/bin/env python

# Convolutional Variational Autoencoder.

# This is a generic model that can be used for any set of input and output fields
# To make a specific model, copy this file, specify.py, validate.py, and validate_multi.py
# to a new directory (makeDataset and autoencoderModel are generic - don't copy them).
# Then edit specify.py to choose the input and output fields, and the training parameters.
# Then run this file to train the model, and the validate scripts to test the result.

import os
import sys
import time

# Cut down on the TensorFlow warning messages
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf

import argparse

# Command-line interface: a single optional integer, --epoch (default 1).
# Its value is passed to getModel and is the first epoch of the training loop.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--epoch",
    type=int,
    default=1,
    required=False,
    help="Restart from epoch",
)
args = parser.parse_args()

# Load the data path, data source, and model specification
from specify import specification
from ML_models.default.makeDataset import getDataset
from ML_models.default.autoencoderModel import DCVAE, getModel


# Get Datasets
def getDatasets():
    """Build the three distributed datasets used by the training script.

    Configuration is read from the module-level ``specification`` dictionary
    (its "shuffleBufferSize", "batchSize" and "strategy" entries).

    Returns a 3-tuple ``(trainingData, validationData, testData)``:
      * trainingData   - getDataset(purpose="Train"), shuffled then batched
      * validationData - getDataset(purpose="Train") again, batched only
      * testData       - getDataset(purpose="Test"), shuffled then batched
    Each one is finally passed through the strategy's
    experimental_distribute_dataset.
    """

    def shuffled(dataset):
        # Shuffle with the buffer size given in the specification
        return dataset.shuffle(specification["shuffleBufferSize"])

    def batched(dataset):
        # Group into batches of the size given in the specification
        return dataset.batch(specification["batchSize"])

    def distributed(dataset):
        # Hand the dataset over to the distribution strategy
        return specification["strategy"].experimental_distribute_dataset(dataset)

    # Training data: repeat(1), then shuffle, batch and distribute
    train_ds = getDataset(specification, purpose="Train").repeat(1)
    train_ds = distributed(batched(shuffled(train_ds)))

    # Validation data: same "Train" purpose, but not shuffled
    val_ds = distributed(batched(getDataset(specification, purpose="Train")))

    # Test data: shuffled, batched and distributed like the training data
    test_ds = getDataset(specification, purpose="Test")
    test_ds = distributed(batched(shuffled(test_ds)))

    return (train_ds, val_ds, test_ds)


# Instantiate and run the model under the control of the distribution strategy
with specification["strategy"].scope():
    # Everything below (dataset creation, model construction, training and
    # monitoring) runs inside the scope of the strategy from the specification.
    trainingData, validationData, testData = getDatasets()

    # args.epoch is forwarded to getModel and is also the first epoch trained
    autoencoder = getModel(specification, epoch=args.epoch)

    # logfile to output the metrics
    # NOTE(review): if SCRATCH is unset, os.getenv returns None and the path
    # below starts with "None/" - confirm SCRATCH is always set where this runs.
    log_FN = ("%s/DCVAE-Climate/%s/logs/Training") % (
        os.getenv("SCRATCH"),
        specification["modelName"],
    )
    # exist_ok avoids the check-then-create race when several processes start
    # at the same time and all try to make the same directory.
    os.makedirs(os.path.dirname(log_FN), exist_ok=True)
    logfile_writer = tf.summary.create_file_writer(log_FN)
    with logfile_writer.as_default():
        # Record the output names once, at step 0
        tf.summary.write(
            "OutputNames",
            specification["outputNames"],
            step=0,
        )

    # For each Epoch: train, save state, and report progress
    for epoch in range(args.epoch, specification["nEpochs"] + 1):
        start_time = time.time()

        # Train on all batches in the training data
        for batch in trainingData:
            if specification["trainingMask"] is not None:
                # Keep the last batch element where the mask is non-zero and
                # replace it with 0.0 everywhere else.
                mbatch = tf.where(specification["trainingMask"] != 0, batch[-1], 0.0)
                # NOTE(review): this rebuilds the batch as a 2-tuple
                # (all-but-last elements, masked last element) - confirm that
                # train_on_batch expects this nesting when a mask is used.
                batch = (batch[:-1], mbatch)
            # The value returned by run() is not used, so it is not kept
            specification["strategy"].run(
                autoencoder.train_on_batch, args=(batch, specification["optimizer"])
            )

        end_training_time = time.time()

        # Validation and output only every printInterval epochs
        # (so weights are also only saved on those epochs)
        if epoch % specification["printInterval"] != 0:
            continue

        # Update the model's metrics from the validation and test data
        autoencoder.update_metrics(validationData, testData)

        # Save model state
        save_dir = "%s/DCVAE-Climate/%s/weights/Epoch_%04d" % (
            os.getenv("SCRATCH"),
            specification["modelName"],
            epoch,
        )
        os.makedirs(save_dir, exist_ok=True)
        autoencoder.save_weights("%s/ckpt" % save_dir)

        # Update the log file with current metrics
        autoencoder.updateLogfile(logfile_writer, epoch)

        end_monitoring_time = time.time()

        # Report progress: whole seconds spent training, then (+seconds) spent
        # on metrics, checkpointing and logging
        print("Epoch: {}".format(epoch))
        autoencoder.printState()
        print(
            "time: {} (+{})".format(
                int(end_training_time - start_time),
                int(end_monitoring_time - end_training_time),
            )
        )