ERA5 to HadUK-Grid - validate the trained VAE
Script to make the validation figure
#!/usr/bin/env python
# Plot a validation figure for the autoencoder.
# Four components
# 1) Original field
# 2) Encoded field
# 3) Difference field
# 4) Original:Encoded scatter
#
import tensorflow as tf
import os
import sys
import random
import numpy as np
import iris
import matplotlib
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
import argparse
# Command-line options: the training epoch whose weights are loaded
# (required) and the date to validate against (defaults to 1979-03-12).
parser = argparse.ArgumentParser()
parser.add_argument("--epoch", help="Epoch", type=int, required=True)
# The three date components take identical settings apart from their
# name and default value, so register them in one pass.
for date_option, date_default in (("--year", 1979), ("--month", 3), ("--day", 12)):
    parser.add_argument(date_option, type=int, required=False, default=date_default)
args = parser.parse_args()
sys.path.append("%s/." % os.path.dirname(__file__))
from plot_HUKG_comparison import get_land_mask
from plot_HUKG_comparison import plot_Tmax
from plot_HUKG_comparison import plot_scatter
from plot_HUKG_comparison import plot_colourbar
# Make the HUKG tensor for the specified date
sys.path.append(
"%s/../../../data/prepare_training_tensors_HUKG_Tmax" % os.path.dirname(__file__)
)
from HUKG_load_tmax import HUKG_load_tmax
from HUKG_load_tmax import HUKG_load_tmax_climatology
from HUKG_load_tmax import HUKG_trim
from HUKG_load_tmax import HUKG_load_observations
# Load the HadUK-Grid Tmax field and its climatology for the requested date.
ht = HUKG_load_tmax(args.year, args.month, args.day)
hc = HUKG_load_tmax_climatology(args.year, args.month, args.day)
# Normalise: anomaly from climatology, divided by 10 and offset by 0.5,
# so that a zero anomaly maps to 0.5.
ht = ht - hc
ht /= 10
ht += 0.5
ht = HUKG_trim(ht)
# Overwrite masked points in the underlying array with 0.5 (the normalised
# value of a zero anomaly).
ht.data.data[ht.data.mask] = 0.5
# Keep the mask: it is reused later in the script to blank the same points
# in the ERA5 and generated fields.
msk = ht.data.mask
# Convert to a float32 tensor of shape [1, 1440, 896, 1]
# (presumably batch, y, x, channel - confirm against the model definition).
ht_in = tf.convert_to_tensor(ht.data.data, np.float32)
ht_in = tf.reshape(ht_in, [1, 1440, 896, 1])
# Make the ERA5 tensor for the specified date
sys.path.append(
"%s/../../../data/prepare_training_tensors_ERA5_HKUG_Tmax" % os.path.dirname(__file__)
)
from ERA5_load import ERA5_load_Tmax
from ERA5_load import ERA5_load_Tmax_climatology
# Load the ERA5 Tmax field and its climatology for the requested date.
et = ERA5_load_Tmax(args.year, args.month, args.day)
ec = ERA5_load_Tmax_climatology(args.year, args.month, args.day)
# Same normalisation as the HadUK-Grid field: anomaly / 10 + 0.5.
et = et - ec
et /= 10
et += 0.5
# Convert it to the HadUK-Grid grid (linear regridding onto the grid of ht)
et = et.regrid(ht, iris.analysis.Linear())
# Discard the bottom-left edge so the grid dimensions can be halved repeatedly.
# NOTE(review): ht was already passed through HUKG_trim before this regrid,
# so this second trim is presumably a no-op - confirm against HUKG_trim.
et = HUKG_trim(et)
# Convert to a float32 tensor with the same [1, 1440, 896, 1] shape as ht_in.
et_in = tf.convert_to_tensor(et.data, np.float32)
et_in = tf.reshape(et_in, [1, 1440, 896, 1])
# Set up the model and load the weights at the chosen epoch
sys.path.append("%s/.." % os.path.dirname(__file__))
from autoencoderModel import DCVAE
# Instantiate the model; its weights are filled in from the checkpoint below.
autoencoder = DCVAE()
# Checkpoint directory for the requested epoch, under the directory named by
# the SCRATCH environment variable.
# NOTE(review): if SCRATCH is unset, os.getenv returns None and the path
# starts with the literal text "None" - the load below would then fail.
weights_dir = (
"%s/Proxy_20CR/models/DCVAE_single_ERA5_to_HUKG_Tmax/" + "Epoch_%04d"
) % (
os.getenv("SCRATCH"),
args.epoch,
)
# expect_partial() is called on the load status so that checkpoint values
# not used by this script (presumably optimizer state - confirm) are not
# reported as a problem.
load_status = autoencoder.load_weights("%s/ckpt" % weights_dir).expect_partial()
# Check the load worked
load_status.assert_existing_objects_matched()
# Make encoded version
# Run the ERA5 tensor through the model. et_in was already given this shape
# above, so the reshape here does not change it.
encoded = autoencoder.predict_on_batch(tf.reshape(et_in, [1, 1440, 896, 1]))
# Drop the singleton dimensions so the 2-d HadUK-Grid mask can index the field.
encoded = np.squeeze(encoded)
# Blank the points masked in the HadUK-Grid data with 0.5
# (the normalised value of a zero anomaly).
encoded[msk]=0.5
# Back to a float32 tensor of shape [1, 1440, 896, 1] for plotting.
encoded = tf.convert_to_tensor(encoded, np.float32)
encoded = tf.reshape(encoded, [1, 1440, 896, 1])
# Discard masked components of ERA5
# Same round trip as above: tensor -> 2-d numpy array -> apply mask -> tensor.
et_in = tf.squeeze(et_in).numpy()
et_in[msk]=0.5
et_in = tf.convert_to_tensor(et_in, np.float32)
et_in = tf.reshape(et_in, [1, 1440, 896, 1])
# Make the figure
lm = get_land_mask()
# 15x15 inches at 100 dpi, light-grey face colour, no frame.
figure_options = dict(
    figsize=(15, 15),
    dpi=100,
    facecolor=(0.88, 0.88, 0.88, 1),
    edgecolor=None,
    linewidth=0.0,
    frameon=False,
    subplotpars=None,
    tight_layout=None,
)
fig = Figure(**figure_options)
canvas = FigureCanvas(fig)
matplotlib.rcParams.update({"font.size": 16})
# Background axes covering the whole figure, with axis decorations and
# autoscaling switched off, filled by a white polygon that extends slightly
# beyond the unit square.
ax_global = fig.add_axes([0, 0, 1, 1], facecolor="white")
ax_global.set_axis_off()
ax_global.autoscale(enable=False)
background_x = (-0.1, 1.1, 1.1, -0.1)
background_y = (-0.1, -0.1, 1.1, 1.1)
ax_global.fill(background_x, background_y, "white")
# Five map panels, each with a horizontal colourbar directly beneath it.
# Every panel is drawn the same way, so describe each one as
# (map axes rectangle, colourbar axes rectangle, field to plot, label)
# and draw them in a single loop. Fields are de-normalised
# ((value - 0.5) * 10); difference fields only need the factor of 10.
panel_specs = (
    # Top left - original field
    (
        [0.01, 0.565, 0.323, 0.425],
        [0.0365, 0.505, 0.27, 0.05],
        (ht_in - 0.5) * 10,
        "Original: %04d-%02d-%02d" % (args.year, args.month, args.day),
    ),
    # Top centre - ERA5 field
    (
        [0.34, 0.565, 0.323, 0.425],
        [0.3665, 0.505, 0.27, 0.05],
        (et_in - 0.5) * 10,
        "ERA5",
    ),
    # Bottom centre - ERA5 difference field
    (
        [0.34, 0.065, 0.323, 0.425],
        [0.3665, 0.005, 0.27, 0.05],
        (et_in - ht_in) * 10,
        "ERA5 Difference",
    ),
    # Top right - generated field
    (
        [0.67, 0.565, 0.323, 0.425],
        [0.6965, 0.505, 0.27, 0.05],
        (encoded - 0.5) * 10,
        "Generator",
    ),
    # Bottom right - generated difference field
    (
        [0.67, 0.065, 0.323, 0.425],
        [0.6965, 0.005, 0.27, 0.05],
        (encoded - ht_in) * 10,
        "Generator Difference",
    ),
)
for map_rect, cbar_rect, panel_field, panel_label in panel_specs:
    ax_of = fig.add_axes(map_rect)
    ax_of.set_aspect("auto")
    ax_of.set_axis_off()
    ofp = plot_Tmax(
        ax_of,
        panel_field,
        vMin=-5,
        vMax=5,
        land=lm,
        label=panel_label,
    )
    ax_ocb = fig.add_axes(cbar_rect)
    plot_colourbar(fig, ax_ocb, ofp)
# Bottom left - scatterplots
# Both scatter plots share one axis range: the de-normalised extremes over
# the original, ERA5 and generated fields together.
_all_values = np.concatenate(
    (
        ht_in.numpy().flatten(),
        et_in.numpy().flatten(),
        encoded.numpy().flatten(),
    )
)
xmin = (np.min(_all_values) - 0.5) * 10
xmax = (np.max(_all_values) - 0.5) * 10
# Upper scatter: original against ERA5.
ax_scatter = fig.add_axes([0.07, 0.29, 0.22, 0.22])
plot_scatter(
    ax_scatter,
    ht_in,
    et_in,
    xlab="Original",
    ylab="ERA5",
    d_min=xmin,
    d_max=xmax,
)
# Lower scatter: original against the generator output.
ax_scatter2 = fig.add_axes([0.07, 0.05, 0.22, 0.22])
plot_scatter(
    ax_scatter2,
    ht_in,
    encoded,
    xlab="Original",
    ylab="Generator",
    d_min=xmin,
    d_max=xmax,
)
# Write the finished figure to the current working directory.
fig.savefig("comparison.png")
Utility functions used in the plot
# Functions to plot haduk-grid before and after autoencoding
# Takes data in tensorflow format (no geometry metadata, normalised)
import os
import sys
import iris
import numpy as np
import tensorflow as tf
import matplotlib
import cmocean
sys.path.append(
"%s/../../../data/prepare_training_tensors_HUKG_Tmax/" % os.path.dirname(__file__)
)
from HUKG_load_tmax import HUKG_trim
# It's a spatial map, so want the land mask
def get_land_mask():
    """Load the HadUK-Grid land mask and trim it with HUKG_trim.

    The mask file is read from beneath the directory named by the DATADIR
    environment variable.

    Returns:
        The cube returned by HUKG_trim for the loaded land-mask cube.
    """
    mask_file = (
        "%s/fixed_fields/land_mask/HadUKG_land_from_Copernicus.nc"
        % os.getenv("DATADIR")
    )
    return HUKG_trim(iris.load_cube(mask_file))
def plot_Tmax(
    ax,
    tmx,
    vMin=0,
    vMax=1,
    obs=None,
    o_size=1,
    land=None,
    mask=None,
    label=None,
):
    """Draw a field as a colour map on top of a grey land-mask background.

    Args:
        ax: Matplotlib axes to draw into.
        tmx: Field to plot, as a tensor; singleton dimensions are squeezed
            out before plotting.
        vMin: Lower limit of the colour scale.
        vMax: Upper limit of the colour scale.
        obs: Optional observation positions; column 0 is scaled by 1440 onto
            the y axis and column 1 by 896 onto the x axis (presumably grid
            indices on a 1440x896 grid - confirm with callers).
        o_size: Multiplier applied to the observation marker size.
        land: Land-mask cube; loaded with get_land_mask() when omitted.
        mask: Optional index (e.g. boolean array) of points to set to 0
            before plotting.
        label: Optional text drawn in a rounded box near the lower-left
            corner of the map.

    Returns:
        The image object returned by pcolorfast for the data layer.
    """
    if land is None:
        land = get_land_mask()
    y_pts = land.coord("projection_y_coordinate").points
    x_pts = land.coord("projection_x_coordinate").points

    # Background layer: the land mask in greys, beneath everything else.
    ax.pcolorfast(
        x_pts,
        y_pts,
        land.data,
        cmap="Greys",
        alpha=1.0,
        vmax=1.2,
        vmin=-0.5,
        zorder=10,
    )

    # Data layer: hide every cell where the land mask is 0.
    field = tf.squeeze(tmx).numpy()
    if mask is not None:
        field[mask] = 0
    field = np.ma.masked_where(land.data == 0, field)
    T_img = ax.pcolorfast(
        x_pts,
        y_pts,
        field,
        cmap=cmocean.cm.balance,
        vmin=vMin,
        vmax=vMax,
        alpha=1.0,
        zorder=40,
    )

    # Optional observation markers, rescaled onto the map coordinates.
    if obs is not None:
        obs = tf.squeeze(obs)
        obs_x = (obs[:, 1].numpy() / 896) * (x_pts[-1] - x_pts[0]) + x_pts[0]
        obs_y = (obs[:, 0].numpy() / 1440) * (y_pts[-1] - y_pts[0]) + y_pts[0]
        ax.scatter(
            obs_x,
            obs_y,
            s=3.0 * o_size,
            c="black",
            marker="o",
            alpha=1.0,
            zorder=60,
        )

    # Optional panel label, inset slightly from the first x/y coordinate.
    if label is not None:
        label_box = dict(
            facecolor=(0.8, 0.8, 0.8, 0.8),
            edgecolor="black",
            boxstyle="round",
            pad=0.5,
        )
        ax.text(
            x_pts[0] + (x_pts[-1] - x_pts[0]) * 0.03,
            y_pts[0] + (y_pts[-1] - y_pts[0]) * 0.02,
            label,
            horizontalalignment="left",
            verticalalignment="bottom",
            color="black",
            bbox=label_box,
            size=16,
            clip_on=True,
            zorder=100,
        )

    return T_img
def plot_scatter(
    ax, t_in, t_out, land=None, d_max=5, d_min=-5, xlab="Original", ylab="Encoded"
):
    """Draw a log-scaled hexbin comparison of two normalised fields.

    Both inputs are flattened and de-normalised with (value - 0.5) * 10.
    Points whose de-normalised x value is exactly 0 are dropped from both
    axes before binning.

    Args:
        ax: Matplotlib axes to draw into.
        t_in: Field plotted on the x axis (must provide .numpy()).
        t_out: Field plotted on the y axis (must provide .numpy()).
        land: Accepted but not used.
        d_max: Upper limit of the hexbin extent and diagonal line, both axes.
        d_min: Lower limit of the hexbin extent and diagonal line, both axes.
        xlab: x-axis label.
        ylab: y-axis label.
    """
    x_vals = (t_in.numpy().flatten() - 0.5) * 10
    y_vals = (t_out.numpy().flatten() - 0.5) * 10
    # Select on x only, and apply the same selection to both arrays.
    keep = x_vals != 0
    x_vals, y_vals = x_vals[keep], y_vals[keep]

    ax.hexbin(
        x=x_vals,
        y=y_vals,
        cmap=cmocean.cm.ice_r,
        bins="log",
        mincnt=1,
        extent=(d_min, d_max, d_min, d_max),
        zorder=50,
    )

    # 1:1 reference line drawn above the hexbins.
    diagonal = matplotlib.lines.Line2D(
        xdata=(d_min, d_max),
        ydata=(d_min, d_max),
        linestyle="solid",
        linewidth=1.5,
        color=(0.5, 0.5, 0.5, 1),
        zorder=100,
    )
    ax.add_line(diagonal)

    ax.set(xlabel=xlab, ylabel=ylab)
    ax.grid(color="black", alpha=0.2, linestyle="-", linewidth=0.5)
def plot_colourbar(
    fig,
    ax,
    T_img,
):
    """Attach a horizontal colourbar for T_img to the given axes.

    Args:
        fig: Figure on which colorbar() is called.
        ax: Axes reserved for the colourbar; its own axis decorations are
            switched off first.
        T_img: Image whose colour scale the bar shows (for example the
            value returned by plot_Tmax).
    """
    ax.set_axis_off()
    bar_options = {
        "location": "bottom",
        "orientation": "horizontal",
        "fraction": 1.0,
    }
    fig.colorbar(T_img, ax=ax, **bar_options)