Plot one frame from the 20CRv2c autoencoded video
The video shows temperature, wind, and mean-sea-level pressure. It is the same as the original 20CRv2c video, except that the fields are passed through the autoencoder before plotting, and the 100-dimensional latent-space encoded representation is shown as an overlay in the bottom left.
#!/usr/bin/env python
# Atmospheric state - near-surface temperature, u-wind, v-wind, and prmsl.
# Show the version after compression into a latent space.
import os
import sys
import IRData.twcr as twcr
import datetime
import pickle
import tensorflow as tf
tf.enable_eager_execution()
import iris
import numpy
import math
import matplotlib
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from matplotlib.patches import Rectangle
from matplotlib.lines import Line2D
sys.path.append('%s/../../lib/' % os.path.dirname(__file__))
from insolation import load_insolation
from geometry import to_analysis_grid
from normalise import normalise_insolation
from normalise import normalise_t2m
from normalise import unnormalise_t2m
from normalise import normalise_prmsl
from normalise import unnormalise_prmsl
from normalise import normalise_wind
from normalise import unnormalise_wind
from plots import plot_cube
from plots import wind_field
from plots import quantile_normalise_t2m
from plots import draw_lat_lon
from pandas import qcut
# Fix dask SPICE bug
import dask
dask.config.set(scheduler='single-threaded')
import argparse
# Command-line options, declared as a table in the order they should be
# registered with the parser (and so appear in the --help output).
_option_table = (
    ("--year", dict(help="Year", type=int, required=True)),
    ("--month", dict(help="Integer month", type=int, required=True)),
    ("--day", dict(help="Day of month", type=int, required=True)),
    ("--hour", dict(help="Time of day (0 to 23.99)",
                    type=float, required=True)),
    ("--pole_latitude", dict(help="Latitude of projection pole",
                             default=90, type=float, required=False)),
    ("--pole_longitude", dict(help="Longitude of projection pole",
                              default=180, type=float, required=False)),
    ("--npg_longitude", dict(help="Longitude of view centre",
                             default=0, type=float, required=False)),
    ("--zoom", dict(help="Scale factor for viewport (1=global)",
                    default=1, type=float, required=False)),
    ("--epoch", dict(help="Epoch", type=int, required=True)),
    # Output locations default to a directory under $SCRATCH
    ("--opdir", dict(help="Directory for output files",
                     default="%s/images/20CRv2c_ls_4var" %
                     os.getenv('SCRATCH'),
                     type=str, required=False)),
    ("--zfile", dict(help="Noise pickle file name",
                     default="%s/images/20CRv2c_ls_4var/z.pkl" %
                     os.getenv('SCRATCH'),
                     type=str, required=False)),
)
parser = argparse.ArgumentParser()
for _flag, _spec in _option_table:
    parser.add_argument(_flag, **_spec)
args = parser.parse_args()
# Create the output directory if it does not exist yet.
# Attempt the creation first and inspect the failure afterwards: checking
# for the directory before creating it leaves a window in which another
# process can create it, which would make makedirs raise.
try:
    os.makedirs(args.opdir)
except OSError:
    # An already-existing directory is fine; anything else is a real error
    if not os.path.isdir(args.opdir):
        raise
# Load the global land mask field
lsmask=iris.load_cube("%s/fixed_fields/land_mask/opfc_global_2019.nc" % os.getenv('DATADIR'))
# Random field for the wind noise
# NOTE: unpickling can execute arbitrary code - only point --zfile at a
#  file you trust.
with open(args.zfile, "rb") as z_handle:
    # Context manager so the file handle is closed once the data are read
    z=pickle.load(z_handle)
# Function to do the multivariate plot
def three_plot(ax,t2m,u10m,v10m,prmsl,ls):
    """Draw the multivariate weather map and latent-space overlay on ax.

    Drawn in this order: a grey background patch, the latitude/longitude
    grid, the land mask, the temperature colour map (modulated by a
    wind-noise field), the pressure contours, the latent-space image and
    the date label.

    As well as its arguments, this function reads the module-level names
    `lsmask` (land mask cube), `z` (wind-noise random field), `dte`
    (datetime of the frame) and `args` (parsed command-line options).

    Args:
        ax: matplotlib axes to draw into.
        t2m: temperature cube (must support .regrid).
        u10m: u-wind cube, as accepted by iris rotate_winds.
        v10m: v-wind cube, as accepted by iris rotate_winds.
        prmsl: pressure cube (must support .regrid).
        ls: latent-space vector; must be reshapeable to (10,10).

    Returns:
        None - the axes are modified in place.
    """
    ax.set_xlim(-180,180)
    ax.set_ylim(-90,90)
    ax.set_aspect('auto')
    ax.set_axis_off() # Don't want surrounding x and y axis
    ax.add_patch(Rectangle((0,0),1,1,facecolor=(0.6,0.6,0.6,1),
                           fill=True,zorder=1))
    # Draw lines of latitude and longitude
    draw_lat_lon(ax)
    # Add the continents
    # Reuse the land mask cube already loaded at module level, rather than
    #  re-reading the same file (from a different root directory) per call.
    mask_pc = plot_cube(0.05)
    land_mask = lsmask.regrid(mask_pc,iris.analysis.Linear())
    lats = land_mask.coord('latitude').points
    lons = land_mask.coord('longitude').points
    # Two-colour map: the first colour is fully transparent, the second opaque
    ax.pcolorfast(lons, lats, land_mask.data,
                  cmap=matplotlib.colors.ListedColormap(
                         ((0.4,0.4,0.4,0),
                          (0.4,0.4,0.4,1))),
                  vmin=0,
                  vmax=1,
                  alpha=1.0,
                  zorder=20)
    # Calculate the wind noise
    wind_pc=plot_cube(0.2)
    # NOTE: the pole is fixed here - the --pole_latitude, --pole_longitude
    #  and --npg_longitude options are not used.
    cs=iris.coord_systems.RotatedGeogCS(90,180,0)
    rw=iris.analysis.cartography.rotate_winds(u10m,v10m,cs)
    u10m = rw[0].regrid(wind_pc,iris.analysis.Linear())
    v10m = rw[1].regrid(wind_pc,iris.analysis.Linear())
    # Hours since 2000-01-01, used to pick the wind-noise sequence number
    seq=(dte-datetime.datetime(2000,1,1)).total_seconds()/3600
    wind_noise_field=wind_field(u10m,v10m,z,sequence=int(seq*5),epsilon=0.01)
    # Plot the temperature
    t2m=quantile_normalise_t2m(t2m)
    t2m_pc=plot_cube(0.05)
    t2m = t2m.regrid(t2m_pc,iris.analysis.Linear())
    # Adjust to show the wind
    # Replace the noise values by their quantile-bin index (up to wscale
    #  bins) and shift by (wscale-1)/2.
    wscale=200
    s=wind_noise_field.data.shape
    wind_noise_field.data=qcut(wind_noise_field.data.flatten(),wscale,
                               labels=False,
                               duplicates='drop').reshape(s)-(wscale-1)/2
    # Plot as a colour map
    # lons and lats still come from the land mask, which is on a
    #  plot_cube(0.05) grid like the regridded temperature.
    # The vmin/vmax range presumably assumes the normalised temperature
    #  is in 0-1 - TODO confirm against quantile_normalise_t2m.
    wnf=wind_noise_field.regrid(t2m,iris.analysis.Linear())
    ax.pcolorfast(lons, lats, t2m.data*1000+wnf.data,
                  cmap='RdYlBu_r',
                  alpha=0.8,
                  vmin=-100,
                  vmax=1100,
                  zorder=100)
    # Plot the prmsl
    prmsl_pc=plot_cube(0.25)
    prmsl = prmsl.regrid(prmsl_pc,iris.analysis.Linear())
    lats = prmsl.coord('latitude').points
    lons = prmsl.coord('longitude').points
    lons,lats = numpy.meshgrid(lons,lats)
    # Scaled by 0.01 before contouring (presumably Pa to hPa - TODO confirm)
    ax.contour(lons, lats, prmsl.data*0.01,
               colors='black',
               linewidths=0.5,
               alpha=1.0,
               levels=numpy.arange(870,1050,10),
               zorder=200)
    # Overlay the latent-space representation in the SE Pacific
    x=numpy.linspace(-160,-120,10)
    y=numpy.linspace(-75,-75+(40*16/18),10)
    ax.pcolorfast(x,y,ls.reshape(10,10),
                  cmap='viridis',
                  alpha=1.0,
                  vmin=-3,
                  vmax=3,
                  zorder=1000)
    # Label with the date
    ax.text(180/args.zoom-(360/args.zoom)*0.009,
            90/args.zoom-(180/args.zoom)*0.016,
            "%04d-%02d-%02d" % (args.year,args.month,args.day),
            horizontalalignment='right',
            verticalalignment='top',
            color='black',
            bbox=dict(facecolor=(0.6,0.6,0.6,0.5),
                      edgecolor='black',
                      boxstyle='round',
                      pad=0.5),
            size=14,
            clip_on=True,
            zorder=500)
# Load the trained models saved for the requested epoch.
# Both live in the same epoch directory, so build that path once - this
#  also avoids the stray double slash the encoder path used to contain.
epoch_dir="%s/ML_GCM/autoencoder/Epoch_%04d" % (
                 os.getenv('SCRATCH'),args.epoch)
# The autoencoder gives the reconstructed fields
model_save_file="%s/autoencoder" % epoch_dir
autoencoder=tf.keras.models.load_model(model_save_file,compile=False)
# Also load the encoder (to get the latent state)
model_save_file="%s/encoder" % epoch_dir
encoder=tf.keras.models.load_model(model_save_file,compile=False)
# Load and compress the data - only at timepoint (i.e. hour%6==0)
def get_compressed(year,month,day,hour):
    """Load the 20CRv2c fields for one time and run them through the models.

    Member 1 of prmsl, air.2m, uwnd.10m and vwnd.10m is loaded, put on the
    analysis grid, normalised, and stacked with the normalised insolation
    into a [1,79,159,5] tensor. That tensor is passed to the module-level
    `autoencoder` and `encoder` models.

    Args:
        year, month, day, hour: integers identifying the time to load.

    Returns:
        dict with keys 't2m', 'prmsl', 'u10m' and 'v10m' (copies of the
        input cubes holding the unnormalised autoencoder output) and 'ls'
        (the encoder output for the same input).
    """
    when = datetime.datetime(year, month, day, hour)

    def fetch(variable):
        # Member 1 of the named variable, on the analysis grid
        cube = twcr.load(variable, when, version='2c')
        return to_analysis_grid(cube.extract(iris.Constraint(member=1)))

    def as_channel(values):
        # One normalised field as a float32 tensor of shape [79,159,1]
        tensor = tf.convert_to_tensor(values, numpy.float32)
        return tf.reshape(tensor, [79, 159, 1])

    prmsl = fetch('prmsl')
    t2m = fetch('air.2m')
    u10m = fetch('uwnd.10m')
    v10m = fetch('vwnd.10m')
    insol = to_analysis_grid(load_insolation(year, month, day, hour))

    # Convert the validation data into tensor format
    # Channel order: t2m, prmsl, u10m, v10m, insolation
    channels = [
        as_channel(normalise_t2m(t2m.data)),
        as_channel(normalise_prmsl(prmsl.data)),
        as_channel(normalise_wind(u10m.data)),
        as_channel(normalise_wind(v10m.data)),
        as_channel(normalise_insolation(insol.data)),
    ]
    ict = tf.concat(channels, 2)  # Now [79,159,5]
    ict = tf.reshape(ict, [1, 79, 159, 5])

    result = tf.reshape(autoencoder.predict_on_batch(ict), [79, 159, 5])
    ls = encoder.predict_on_batch(ict)

    # Convert the encoded fields back to unnormalised cubes
    # (the insolation channel of the output is not used)
    decoded = result.numpy()
    outputs = (
        ('t2m', t2m, unnormalise_t2m),
        ('prmsl', prmsl, unnormalise_prmsl),
        ('u10m', u10m, unnormalise_wind),
        ('v10m', v10m, unnormalise_wind),
    )
    fields = {}
    for channel, (name, template, unnormalise) in enumerate(outputs):
        cube = template.copy()
        cube.data = tf.reshape(decoded[:, :, channel], [79, 159]).numpy()
        cube.data = unnormalise(cube.data)
        fields[name] = cube
    fields['ls'] = ls
    return fields
# Get the compressed data at the selected time, by linear interpolation
#  between the 6-hourly timepoints on either side of it.
whole_hour = int(args.hour)
dte = datetime.datetime(args.year, args.month, args.day,
                        whole_hour, int(args.hour%1*60))
# Latest multiple-of-6 hour at or before the requested one
prevt = datetime.datetime(args.year, args.month, args.day,
                          6*(whole_hour//6))
nextt = prevt + datetime.timedelta(hours=6)
s_previous = get_compressed(prevt.year, prevt.month, prevt.day, prevt.hour)
s_next = get_compressed(nextt.year, nextt.month, nextt.day, nextt.hour)
# Cubes: merge the two timepoints and interpolate along time
compressed = {}
for name in ('t2m', 'prmsl', 'u10m', 'v10m'):
    pair = iris.cube.CubeList((s_previous[name], s_next[name]))
    merged = pair.merge_cube()
    compressed[name] = merged.interpolate([('time', dte)],
                                          iris.analysis.Linear())
# Latent state: weight the two timepoints by the fraction of the
#  interval that has elapsed
elapsed = (dte-prevt).total_seconds()
interval = (nextt-prevt).total_seconds()
w = elapsed/interval
ls_before = s_previous['ls']*(1-w)
ls_after = s_next['ls']*w
ls = ls_before+ls_after
# Make the figure: 19.2x10.8 inches at 100 dpi
figure_options = dict(figsize=(19.2, 10.8),
                      dpi=100,
                      facecolor=(0.88, 0.88, 0.88, 1),
                      edgecolor=None,
                      linewidth=0.0,
                      frameon=False,
                      subplotpars=None,
                      tight_layout=None)
fig = Figure(**figure_options)
canvas = FigureCanvas(fig)
# A single map of the autoencoded fields, filling the whole figure
ax_compressed = fig.add_axes([0, 0, 1, 1])
three_plot(ax_compressed,
           compressed['t2m'],
           compressed['u10m'],
           compressed['v10m'],
           compressed['prmsl'],
           ls)
# Render the figure as a png, named by date, hour and minute
frame_stamp = (args.opdir,
               args.year, args.month, args.day,
               int(args.hour), int(args.hour%1*60))
fig.savefig('%s/%04d%02d%02d%02d%02d.png' % frame_stamp)
Library and utility functions used: