Source code for the forecast model validation
#!/usr/bin/env python
# Compare an original weather field with the predictor output.
import tensorflow as tf
tf.enable_eager_execution()
import numpy
import IRData.twcr as twcr
import iris
import datetime
import argparse
import sys
import os
import math
import pickle
import Meteorographica as mg
from pandas import qcut
import matplotlib
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from matplotlib.patches import Rectangle
from matplotlib.lines import Line2D
sys.path.append('%s/../../lib/' % os.path.dirname(__file__))
from insolation import load_insolation
from geometry import to_analysis_grid
from normalise import normalise_insolation
from normalise import normalise_t2m
from normalise import unnormalise_t2m
from normalise import normalise_prmsl
from normalise import unnormalise_prmsl
from normalise import normalise_wind
from normalise import unnormalise_wind
from plots import plot_cube
from plots import make_wind_seed
from plots import wind_field
from plots import quantile_normalise_t2m
from plots import draw_lat_lon
import argparse
# Command-line options: --epoch is used further down to build the path
# of the saved predictor model.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--epoch",
    help="Epoch",
    type=int,
    required=False,
    default=25,
)
args = parser.parse_args()

# Validity time of the fields given to the predictor as input
dte = datetime.datetime(2010, 3, 12, 18)

# Land mask cube, read from under the $SCRATCH directory
lsmask_file = ("%s/fixed_fields/land_mask/opfc_global_2019.nc" %
               os.getenv('SCRATCH'))
lsmask = iris.load_cube(lsmask_file)

# Random field for the wind noise (shared by every call to three_plot)
z = make_wind_seed(resolution=0.4)
def three_plot(ax,t2m,u10m,v10m,prmsl):
    """Draw temperature, wind and pressure together as one map on ``ax``.

    The map is built from these layers:

    * a grey patch (zorder 1) and the grid drawn by ``draw_lat_lon``;
    * the land mask as a two-colour overlay (zorder 20);
    * the quantile-normalised temperature, perturbed by a wind-noise
      texture so the flow shows up in the colours (zorder 100);
    * black pressure contours (zorder 200).

    The function reads two module-level objects: ``lsmask`` (the land
    mask cube, loaded once at start-up instead of once per call) and
    ``z`` (the random seed field for the wind noise).

    Args:
        ax: Matplotlib axes to draw on. Its limits are set to
            -180..180 by -90..90 and its axis decorations are hidden.
        t2m: Cube of temperature; regridded to ``plot_cube(0.05)``.
        u10m: Cube of the first wind component; regridded to
            ``plot_cube(0.5)`` after ``rotate_winds``.
        v10m: Cube of the second wind component; treated like ``u10m``.
        prmsl: Cube of pressure; regridded to ``plot_cube(0.25)`` and
            multiplied by 0.01 before contouring (presumably Pa -> hPa;
            confirm against the data source).

    Returns:
        None. Everything is drawn onto ``ax``.
    """
    ax.set_xlim(-180,180)
    ax.set_ylim(-90,90)
    ax.set_aspect('auto')
    ax.set_axis_off() # Don't want surrounding x and y axis
    # NOTE(review): this patch is in data coordinates, so it only covers
    # (0,0)-(1,1) degrees. If it is meant as a full background it
    # presumably needs transform=ax.transAxes - confirm before changing,
    # as that would alter the rendered image.
    ax.add_patch(Rectangle((0,0),1,1,facecolor=(0.6,0.6,0.6,1),fill=True,zorder=1))

    # Draw lines of latitude and longitude
    draw_lat_lon(ax)

    # Add the continents. regrid() returns a new cube, so the shared
    # module-level mask is left untouched and can be reused on every call
    # without going back to disk.
    mask_pc = plot_cube(0.05)
    land_mask = lsmask.regrid(mask_pc,iris.analysis.Linear())
    lats = land_mask.coord('latitude').points
    lons = land_mask.coord('longitude').points
    ax.pcolorfast(lons, lats, land_mask.data,
                  cmap=matplotlib.colors.ListedColormap(
                         ((0.4,0.4,0.4,0),    # first colour: fully transparent
                          (0.4,0.4,0.4,1))),  # second colour: opaque grey
                  vmin=0,
                  vmax=1,
                  alpha=1.0,
                  zorder=20)

    # Calculate the wind noise
    wind_pc=plot_cube(0.5)
    cs=iris.coord_systems.RotatedGeogCS(90,180,0)
    rotated=iris.analysis.cartography.rotate_winds(u10m,v10m,cs)
    u10m = rotated[0].regrid(wind_pc,iris.analysis.Linear())
    v10m = rotated[1].regrid(wind_pc,iris.analysis.Linear())
    wind_noise_field=wind_field(u10m,v10m,z,sequence=None,epsilon=0.01)

    # Put the temperature on the same resolution as the land mask, so the
    # lons/lats taken from the mask above also describe this field.
    t2m_pc=plot_cube(0.05)
    t2m = t2m.regrid(t2m_pc,iris.analysis.Linear())
    t2m=quantile_normalise_t2m(t2m)

    # Adjust to show the wind: replace the noise values by their quantile
    # bin index (up to wscale bins) and centre the result about zero.
    wscale=200
    noise_shape=wind_noise_field.data.shape
    wind_noise_field.data=qcut(wind_noise_field.data.flatten(),wscale,labels=False,
                               duplicates='drop').reshape(noise_shape)-(wscale-1)/2

    # Plot as a colour map: temperature scaled by 1000 plus the wind
    # texture, on a colour scale running from -100 to 1100.
    wnf=wind_noise_field.regrid(t2m,iris.analysis.Linear())
    ax.pcolorfast(lons, lats, t2m.data*1000+wnf.data,
                  cmap='RdYlBu_r',
                  alpha=0.8,
                  vmin=-100,
                  vmax=1100,
                  zorder=100)

    # Plot the prmsl as contours every 10 units from 870 up to 1040
    prmsl_pc=plot_cube(0.25)
    prmsl = prmsl.regrid(prmsl_pc,iris.analysis.Linear())
    lats = prmsl.coord('latitude').points
    lons = prmsl.coord('longitude').points
    lons,lats = numpy.meshgrid(lons,lats)
    ax.contour(lons, lats, prmsl.data*0.01,
               colors='black',
               linewidths=0.5,
               alpha=1.0,
               levels=numpy.arange(870,1050,10),
               zorder=200)
# Grid dimensions used for every tensor reshape below
_GRID = [79, 159]


def _load_member_1(variable, when):
    """Load member 1 of `variable` at time `when` (twcr version '2c')
    and put it on the analysis grid."""
    cube = twcr.load(variable, when, version='2c')
    return to_analysis_grid(cube.extract(iris.Constraint(member=1)))


def _as_channel(values):
    """Convert a normalised 2-d field into a float32 tensor [79,159,1]."""
    return tf.reshape(tf.convert_to_tensor(values, numpy.float32),
                      _GRID + [1])


def _predicted_cube(prediction, template, channel, unnormalise):
    """Copy `template`, fill it with one channel of `prediction` and
    convert the values back with `unnormalise`."""
    cube = template.copy()
    cube.data = tf.reshape(prediction.numpy()[:, :, channel], _GRID).numpy()
    cube.data = unnormalise(cube.data)
    return cube


# Load the source data
prmsl = _load_member_1('prmsl', dte)
t2m = _load_member_1('air.2m', dte)
u10m = _load_member_1('uwnd.10m', dte)
v10m = _load_member_1('vwnd.10m', dte)
insol = to_analysis_grid(
    load_insolation(dte.year, dte.month, dte.day, dte.hour))

# Convert the source data into tensor format
t2m_t = _as_channel(normalise_t2m(t2m.data))
prmsl_t = _as_channel(normalise_prmsl(prmsl.data))
u10m_t = _as_channel(normalise_wind(u10m.data))
v10m_t = _as_channel(normalise_wind(v10m.data))
insol_t = _as_channel(normalise_insolation(insol.data))

# Get predicted versions of the target data
model_save_file = ("%s/ML_GCM/predictor/"
                   "Epoch_%04d/predictor") % (os.getenv('SCRATCH'),
                                              args.epoch)
autoencoder = tf.keras.models.load_model(model_save_file, compile=False)
# Stack the five fields along the last axis - now [79,159,5] - and add a
# leading batch dimension of 1 for the model.
ict = tf.concat([t2m_t, prmsl_t, u10m_t, v10m_t, insol_t], 2)
ict = tf.reshape(ict, [1] + _GRID + [5])
result = autoencoder.predict_on_batch(ict)
result = tf.reshape(result, _GRID + [5])

# Convert the predicted fields back to unnormalised cubes
# (channel order matches the order used in tf.concat above)
t2m_r = _predicted_cube(result, t2m, 0, unnormalise_t2m)
prmsl_r = _predicted_cube(result, prmsl, 1, unnormalise_prmsl)
u10m_r = _predicted_cube(result, u10m, 2, unnormalise_wind)
v10m_r = _predicted_cube(result, v10m, 3, unnormalise_wind)

# Load the actual data for the target time (6 hours after the input);
# from here on prmsl/t2m/u10m/v10m hold the target-time fields.
dte2 = dte + datetime.timedelta(hours=6)
prmsl = _load_member_1('prmsl', dte2)
t2m = _load_member_1('air.2m', dte2)
u10m = _load_member_1('uwnd.10m', dte2)
v10m = _load_member_1('vwnd.10m', dte2)

# Plot the two fields and a scatterplot for each variable
fig = Figure(
    figsize=(9.6*1.2, 10.8),
    dpi=100,
    facecolor=(0.88, 0.88, 0.88, 1),
    edgecolor=None,
    linewidth=0.0,
    frameon=False,
    subplotpars=None,
    tight_layout=None,
)
canvas = FigureCanvas(fig)

# Two maps: actual target-time fields on top, model output underneath
ax_original = fig.add_axes([0.005, 0.525, 0.75, 0.45])
three_plot(ax_original, t2m, u10m, v10m, prmsl)
ax_reconstructed = fig.add_axes([0.005, 0.025, 0.75, 0.45])
three_plot(ax_reconstructed, t2m_r, u10m_r, v10m_r, prmsl_r)
# Scatterplot of encoded v original
def plot_scatter(ax,ic,pm):
    """Scatter the values of ``ic`` against those of ``pm`` on ``ax``.

    Both axes share one range: the combined min..max of the two arrays,
    widened by 2% about its midpoint, so equal values fall on the
    diagonal of a square plot.

    Args:
        ax: Matplotlib axes to draw on.
        ic: Array plotted on the y axis (labelled 'Original').
        pm: Array plotted on the x axis (labelled 'Encoded').
    """
    lowest = min(ic.min(), pm.min())
    highest = max(ic.max(), pm.max())
    centre = (lowest + highest) / 2
    # Stretch each end of the range 2% further away from the midpoint
    limits = (centre - (centre - lowest) * 1.02,
              centre + (highest - centre) * 1.02)
    ax.set_xlim(*limits)
    ax.set_ylim(*limits)

    point_style = dict(c='black', alpha=0.25, marker='.', s=2)
    ax.scatter(x=pm.flatten(), y=ic.flatten(), **point_style)

    ax.set(ylabel='Original', xlabel='Encoded')

    grid_style = dict(color='black', alpha=0.2, linestyle='-', linewidth=0.5)
    ax.grid(**grid_style)
# One scatter panel per variable, stacked down the right-hand side of the
# figure: (bottom edge of the panel, target-time values, predicted values).
# Pressure is multiplied by 0.01 on both axes.
scatter_panels = [
    (0.80, t2m.data, t2m_r.data),
    (0.55, prmsl.data*0.01, prmsl_r.data*0.01),
    (0.30, u10m.data, u10m_r.data),
    (0.05, v10m.data, v10m_r.data),
]
for bottom, actual, predicted in scatter_panels:
    panel = fig.add_axes([0.83, bottom, 0.16, 0.17])
    plot_scatter(panel, actual, predicted)

# Render the figure as a png
fig.savefig("compare_tpuv.png")
Library and utility functions used: