Simple autoencoder validation script

To test the autoencoder we need to: load the trained model (with tf.keras.models.load_model); load an original MSLP field and convert it to a tensor (exactly as we did when building the model training data); run the test field through the autoencoder (with the model's predict_on_batch function); convert the autoencoded tensor back into the original units (reversing the normalisation); and then plot the original and encoded fields (with the Meteorographica package).

#!/usr/bin/env python

# Compare a 20CRv2c prmsl field with the same field passed through
#   the autoencoder.

import tensorflow as tf
tf.enable_eager_execution()
import numpy

import IRData.twcr as twcr
import iris
import datetime
import argparse
import os

import Meteorographica as mg

import matplotlib
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
import cartopy
import cartopy.crs as ccrs

import argparse
parser = argparse.ArgumentParser()

# Required integer options that together give the time of the field to test
for flag, text in (("--year", "Year"),
                   ("--month", "Integer month"),
                   ("--day", "Day of month"),
                   ("--hour", "Hour of day (0 to 23)")):
    parser.add_argument(flag, help=text, type=int, required=True)

# Optional selectors for which 20CR data to load
parser.add_argument("--member", help="Ensemble member",
                    type=int, default=1, required=False)
parser.add_argument("--version", help="20CR version",
                    type=str, default='2c', required=False)
parser.add_argument("--variable", help="20CR variable",
                    type=str, default='prmsl', required=False)

# Which saved model (training epoch) to validate
parser.add_argument("--epoch", help="Model at which epoch?",
                    type=int, required=True)

args = parser.parse_args()

# Get the 20CR data for the requested time, reduced to a single ensemble
# member.
ic = twcr.load(args.variable,
               datetime.datetime(args.year, args.month, args.day, args.hour),
               version=args.version)
ic = ic.extract(iris.Constraint(member=args.member))
# Cube.extract returns None (rather than raising) when nothing matches the
# constraint; without this check that would only show up later as an obscure
# AttributeError on ic.copy().
if ic is None:
    raise ValueError("Ensemble member %d not found in 20CR %s %s data"
                     % (args.member, args.version, args.variable))

# Get the autoencoder saved for the requested epoch from under $SCRATCH.
scratch_dir = os.getenv('SCRATCH')
if scratch_dir is None:
    # Without this check the path would silently start with "None/" and
    # load_model would fail with a confusing file-not-found style error.
    raise RuntimeError("SCRATCH environment variable is not set; "
                       "cannot locate the saved autoencoder models")
model_save_file = os.path.join(scratch_dir,
                               "Machine-Learning-experiments",
                               "simple_autoencoder",
                               "saved_models",
                               "Epoch_%04d" % args.epoch)
autoencoder = tf.keras.models.load_model(model_save_file)

# Normalisation - Pa to mean=0, sd=1 - and back
def normalise(x, mean=101325, sd=3000):
    """Scale pressure from Pa to the model's normalised units.

    Computes ``(x - mean) / sd`` and returns it as a new object; *x* is not
    modified. The previous in-place version mutated its argument and raised
    on integer numpy arrays.

    Args:
        x: pressure in Pa (a scalar, numpy array, or anything supporting
            ``-`` and ``/``).
        mean: value mapped to 0 (default 101325 Pa).
        sd: pressure difference mapped to 1 (default 3000 Pa).

    Returns:
        The normalised value(s).
    """
    return (x - mean) / sd

def unnormalise(x, mean=101325, sd=3000):
    """Convert normalised model output back to pressure in Pa.

    This is the inverse of ``normalise``: it computes ``x * sd + mean`` and
    returns it as a new object, leaving *x* unmodified.

    Args:
        x: normalised value(s) (a scalar, numpy array, or anything supporting
            ``*`` and ``+``).
        mean: pressure in Pa that 0 maps back to (default 101325 Pa).
        sd: pressure difference in Pa that 1 maps to (default 3000 Pa).

    Returns:
        The pressure value(s) in Pa.
    """
    return x * sd + mean

# Run the data through the autoencoder and convert back to an iris cube.
# Work on a copy so the original field `ic` is left intact for the plot.
pm = ic.copy()
pm.data = normalise(pm.data)
# Grid shape of the field (the original code hard-coded 91x180).
field_shape = pm.data.shape
ict = tf.convert_to_tensor(pm.data, numpy.float32)
# Feed the model a batch of one flattened field: shape (1, nlat*nlon).
ict = tf.reshape(ict, [1, -1])
result = autoencoder.predict_on_batch(ict)
# Bring the prediction back to a plain numpy array on the original grid, so
# the cube ends up holding numpy data rather than a tf.Tensor.
result = numpy.reshape(numpy.asarray(result), field_shape)
pm.data = unnormalise(result)

# Make a comparison plot: original field on top, autoencoded field below.
# The figure is rendered off-screen through the Agg canvas.
fig = Figure(figsize=(15, 15 * 1.06 / 1.04),  # Width, Height (inches)
             dpi=100,
             facecolor=(0.88, 0.88, 0.88, 1),
             edgecolor=None,
             linewidth=0.0,
             frameon=False,
             subplotpars=None,
             tight_layout=None)
canvas = FigureCanvas(fig)

# Global map, shared by both panels
projection = ccrs.RotatedPole(pole_longitude=180.0, pole_latitude=90.0)
extent = [-180, 180, -90, 90]


def _add_map_panel(bottom):
    """Add one full-width map panel whose lower edge is at *bottom*.

    The panel gets a grey background, a grid, and the 'GreyT' land image.
    Returns the new axes and the land image artist.
    """
    panel = fig.add_axes([0.02, bottom, 0.96, 0.47], projection=projection)
    panel.set_axis_off()
    panel.set_extent(extent, crs=projection)
    panel.background_patch.set_facecolor((0.88, 0.88, 0.88, 1))
    mg.background.add_grid(panel)
    land = panel.background_img(name='GreyT', resolution='low')
    return panel, land


# Top half for the originals, bottom half for the autoencoder output
ax_orig, land_img_orig = _add_map_panel(0.51)
ax_post, land_img_post = _add_map_panel(0.02)

# Contour both fields with identical settings: the original (ic) on the top
# panel and the autoencoded field (pm) on the bottom panel.
# NOTE(review): scale=0.01 presumably converts Pa to hPa for the levels.
for panel, field in ((ax_orig, ic), (ax_post, pm)):
    mg.pressure.plot(panel, field,
                     scale=0.01,
                     resolution=0.25,
                     levels=numpy.arange(870, 1050, 7),
                     colors='blue',
                     label=True,
                     linewidths=2)

# Mark the date/hour of the data on the lower panel; the same stamp is
# reused in the output file name.
date_stamp = '%04d-%02d-%02d:%02d' % (args.year, args.month,
                                      args.day, args.hour)
mg.utils.plot_label(ax_post,
                    date_stamp,
                    facecolor=fig.get_facecolor(),
                    x_fraction=0.98,
                    horizontalalignment='right')

# Render the figure as a png
fig.savefig("comparison_%s.png" % date_stamp)