Plotting utilities
Convenience functions for plotting maps, scatter plots, and histograms.
There are three user-callable functions in this file:
plotFieldAxes() - plots an iris cube as a colour-mapped field (with an optional land-mask overlay) in a given matplotlib Axes
plotHistAxes() - plots a histogram of an iris cube's data in a given matplotlib Axes
plotScatterAxes() - plots a hexbin (density) scatter plot of one iris cube against another in a given matplotlib Axes
# Plotting utility functions
import os
import numpy as np
import iris
import iris.util
import iris.analysis
import iris.coord_systems
import iris.exceptions
import matplotlib
from matplotlib.patches import Rectangle
from matplotlib.lines import Line2D
import cmocean
# Opt in to iris's "future" datum handling (iris.FUTURE.datum_support).
# Original author's note: "I don't care about datums" -- the flag is set so
# iris uses its newer behaviour rather than the legacy one, presumably to
# silence the associated FutureWarning.
# NOTE(review): confirm the exact effect against the `iris.Future` docs for
# the installed iris version. This is a process-wide side effect at import time.
iris.FUTURE.datum_support = True
# Get the pole location from a cube
# Assumes an equirectangular projection
def extract_pole(cube):
    """Return the rotated-pole parameters of ``cube``'s ``grid_latitude`` coord.

    Args:
        cube: an iris cube.

    Returns:
        A tuple ``(grid_north_pole_latitude, grid_north_pole_longitude,
        north_pole_grid_longitude)``. If the cube has no ``grid_latitude``
        coordinate, or that coordinate has no coordinate system, the
        unrotated pole ``(90, 180, 0)`` is returned.

    Raises:
        ValueError: if ``grid_latitude`` carries a coordinate system that is
            not ``rotated_latitude_longitude``. (ValueError subclasses
            Exception, so callers that caught the previous bare Exception
            still catch it.)
    """
    default_pole = (90, 180, 0)
    try:
        lat = cube.coord("grid_latitude")
    except iris.exceptions.CoordinateNotFoundError:
        # No rotated grid coordinate: treat the cube as unrotated. Only this
        # specific lookup failure is absorbed; other errors propagate.
        return default_pole
    cs = lat.coord_system
    if cs is None:
        return default_pole
    if cs.grid_mapping_name == "rotated_latitude_longitude":
        return (
            cs.grid_north_pole_latitude,
            cs.grid_north_pole_longitude,
            cs.north_pole_grid_longitude,
        )
    raise ValueError(
        "Unsupported cube for coordinate extraction: coordinate system "
        f"{cs!r} is not rotated_latitude_longitude"
    )
def _axis_points(start, stop, step):
    """Return ``start, start + step, ...`` up to the first point >= ``stop``.

    Same result as ``np.arange(start, stop + step, step)`` in exact
    arithmetic. The number of steps is counted explicitly, with a small
    tolerance, because ``np.arange`` with a non-integer step has a length
    that is sensitive to floating-point rounding and can gain or lose an end
    point.
    """
    n_steps = int(np.ceil((stop - start) / step - 1e-9))
    return start + step * np.arange(n_steps + 1)


# Make a dummy iris Cube for plotting.
# Makes a cube in equirectangular projection.
# Takes resolution, plot range, and pole location
# (all in degrees) as arguments, returns an
# iris cube.
def plot_cube(
    resolution=0.25,
    xmin=-180,
    xmax=180,
    ymin=-90,
    ymax=90,
    pole_latitude=90,
    pole_longitude=180,
    npg_longitude=0,
):
    """Build a zero-filled 2-D iris cube to use as a plotting/regrid target.

    Args:
        resolution: grid spacing in degrees, used for both axes.
        xmin, xmax: longitude range (inclusive of both ends when the range is
            a whole number of ``resolution`` steps).
        ymin, ymax: latitude range (same inclusivity rule).
        pole_latitude, pole_longitude, npg_longitude: rotated-pole
            parameters passed to ``iris.coord_systems.RotatedGeogCS``.

    Returns:
        An iris cube of zeros with dims (latitude, longitude). Both dim coords
        use the standard names "latitude"/"longitude" and share the rotated
        coordinate system.
    """
    cs = iris.coord_systems.RotatedGeogCS(pole_latitude, pole_longitude, npg_longitude)
    lat_values = _axis_points(ymin, ymax, resolution)
    lon_values = _axis_points(xmin, xmax, resolution)
    # Coords are named "latitude"/"longitude" even on a rotated grid; code that
    # consumes these cubes looks them up by those names.
    latitude = iris.coords.DimCoord(
        lat_values, standard_name="latitude", units="degrees_north", coord_system=cs
    )
    longitude = iris.coords.DimCoord(
        lon_values, standard_name="longitude", units="degrees_east", coord_system=cs
    )
    dummy_data = np.zeros((lat_values.size, lon_values.size))
    return iris.cube.Cube(
        dummy_data, dim_coords_and_dims=[(latitude, 0), (longitude, 1)]
    )
# High res land mask for plots
def get_land_mask(grid_cube=None):
    """Load the ERA5 land mask as a 0/1 cube, optionally regridded.

    The mask is read from ``$SCRATCH/ERA5/monthly/reanalysis/land_mask.nc``
    and squeezed. Points that are masked in the file become 0; all other
    points become 1.

    Args:
        grid_cube: optional target cube. If given, the mask is regridded onto
            it with ``iris.analysis.Linear()``.

    Returns:
        An iris cube of 0/1 values (interpolated values if regridded).

    Raises:
        FileNotFoundError: if the ``SCRATCH`` environment variable is unset.
            This is an OSError subclass, like the error iris raises for a
            missing file.
    """
    scratch = os.getenv("SCRATCH")
    if scratch is None:
        raise FileNotFoundError(
            "SCRATCH environment variable is not set; cannot locate "
            "ERA5/monthly/reanalysis/land_mask.nc"
        )
    lm = iris.load_cube(
        os.path.join(scratch, "ERA5", "monthly", "reanalysis", "land_mask.nc")
    )
    lm = iris.util.squeeze(lm)
    # Attach an unrotated pole (90, 180, 0) coordinate system to both
    # horizontal coords, presumably so regrid() can map onto rotated targets.
    unrotated = iris.coord_systems.RotatedGeogCS(90, 180, 0)
    lm.coord("latitude").coord_system = unrotated
    lm.coord("longitude").coord_system = unrotated
    # getmaskarray always returns a full-shape boolean array. It works whether
    # the data is a plain ndarray or a masked array, including one whose mask
    # is the scalar `nomask`.
    lm.data = np.where(np.ma.getmaskarray(lm.data), 0, 1)
    if grid_cube is not None:
        lm = lm.regrid(grid_cube, iris.analysis.Linear())
    return lm
# Plot a map in a supplied axes
def plotFieldAxes(
    ax_map,
    field,
    vMax=None,
    vMin=None,
    lMask=None,
    cMap=cmocean.cm.balance,
    plotCube=None,
    f_alpha=1.0,
    show_land=True,
):
    """Draw ``field`` as a colour-mapped image in the matplotlib axes ``ax_map``.

    Args:
        ax_map: matplotlib Axes to draw into. Its axis decorations are
            turned off and its aspect is set to equal.
        field: 2-D iris cube with either grid_longitude/grid_latitude or
            longitude/latitude coordinates.
        vMax, vMin: colour-scale limits. Default to the max/min of the
            unmasked data.
        lMask: land-mask cube to overlay. If None and ``show_land`` is True,
            one is built with get_land_mask on a 0.1-degree plot_cube grid
            that uses the field's pole (from extract_pole).
        cMap: colormap for the field.
        plotCube: optional cube; if given, ``field`` is first regridded onto
            it with ``iris.analysis.Linear()``.
        f_alpha: opacity of the field image.
        show_land: whether to overlay the land mask.

    Returns:
        The artist returned by ``ax_map.pcolorfast`` for the field.
    """
    if plotCube is not None:
        field = field.regrid(plotCube, iris.analysis.Linear())
    # np.ma.compressed drops masked points and also accepts plain ndarrays,
    # which have no .compressed() method of their own.
    if vMax is None:
        vMax = np.max(np.ma.compressed(field.data))
    if vMin is None:
        vMin = np.min(np.ma.compressed(field.data))
    # Only build the 0.1-degree global land mask when it will be drawn.
    if show_land and lMask is None:
        cs = extract_pole(field)
        lMask = get_land_mask(
            plot_cube(
                resolution=0.1,
                pole_latitude=cs[0],
                pole_longitude=cs[1],
                npg_longitude=cs[2],
            )
        )
    try:
        lons = field.coord("grid_longitude").points
        lats = field.coord("grid_latitude").points
    except iris.exceptions.CoordinateNotFoundError:
        lons = field.coord("longitude").points
        lats = field.coord("latitude").points
    lon_min, lon_max = min(lons), max(lons)
    lat_min, lat_max = min(lats), max(lats)
    ax_map.set_ylim(lat_min, lat_max)
    ax_map.set_xlim(lon_min, lon_max)
    ax_map.set_axis_off()
    ax_map.set_aspect("equal", adjustable="box", anchor="C")
    # Light-grey background (zorder 1) beneath the field (zorder 10),
    # presumably visible where the field is masked or semi-transparent.
    ax_map.add_patch(
        Rectangle(
            (lon_min, lat_min),
            lon_max - lon_min,
            lat_max - lat_min,
            facecolor=(0.9, 0.9, 0.9, 1),
            fill=True,
            zorder=1,
        )
    )
    # Plot the field
    T_img = ax_map.pcolorfast(
        lons,
        lats,
        field.data,
        cmap=cMap,
        vmin=vMin,
        vmax=vMax,
        alpha=f_alpha,
        zorder=10,
    )
    # Overlay the land mask: value 0 maps to a fully transparent colour and
    # value 1 to a 30%-opaque grey, drawn above the field (zorder 100).
    if show_land:
        ax_map.pcolorfast(
            lMask.coord("longitude").points,
            lMask.coord("latitude").points,
            lMask.data,
            cmap=matplotlib.colors.ListedColormap(
                ((0.4, 0.4, 0.4, 0), (0.4, 0.4, 0.4, 0.3))
            ),
            vmin=0,
            vmax=1,
            alpha=1,
            zorder=100,
        )
    return T_img
# Scatter plot in provided axes
def plotScatterAxes(
    ax, var_in, var_out, vMax=None, vMin=None, xlabel="", ylabel="", bins="log"
):
    """Draw a hexbin density plot of ``var_out`` against ``var_in`` in ``ax``.

    Args:
        ax: matplotlib Axes to draw into.
        var_in: iris cube supplying the x values.
        var_out: iris cube supplying the y values. It is assumed to have the
            same shape as ``var_in``, since points are paired element-wise.
        vMax, vMin: shared limits for both axes. Default to the max/min over
            both cubes' data.
        xlabel, ylabel: axis labels.
        bins: passed through to ``ax.hexbin`` (default "log").

    A grey 1:1 line from (vMin, vMin) to (vMax, vMax) is drawn on top. A point
    masked in either cube is dropped from both, so x and y stay paired.
    """
    if vMax is None:
        vMax = max(np.max(var_in.data), np.max(var_out.data))
    if vMin is None:
        vMin = min(np.min(var_in.data), np.min(var_out.data))
    ax.set_xlim(vMin, vMax)
    ax.set_ylim(vMin, vMax)
    # Compressing each array on its own would misalign x and y (or give
    # different lengths) whenever the two masks differ, so apply one joint
    # mask. getmaskarray/getdata also accept plain, unmasked ndarrays.
    x_data = var_in.data
    y_data = var_out.data
    invalid = np.ma.getmaskarray(x_data) | np.ma.getmaskarray(y_data)
    ax.hexbin(
        x=np.ma.getdata(x_data)[~invalid],
        y=np.ma.getdata(y_data)[~invalid],
        cmap=cmocean.tools.crop_by_percent(cmocean.cm.ice_r, 5, which="min"),
        bins=bins,
        gridsize=50,
        mincnt=1,
    )
    ax.add_line(
        Line2D(
            xdata=(vMin, vMax),
            ydata=(vMin, vMax),
            linestyle="solid",
            linewidth=0.5,
            color=(0.5, 0.5, 0.5, 1),
            zorder=100,
        )
    )
    ax.set(ylabel=ylabel, xlabel=xlabel)
    ax.grid(color="black", alpha=0.2, linestyle="-", linewidth=0.5)
# Histogram in provided axes
def plotHistAxes(ax, var, vMax=None, vMin=None, xlabel="", ylabel="", bins=100):
    """Draw a normalised (density) histogram of ``var``'s data in ``ax``.

    Args:
        ax: matplotlib Axes to draw into.
        var: iris cube whose ``.data`` is histogrammed.
        vMax, vMin: histogram range. Default to the max/min of the data,
            respecting any mask.
        xlabel, ylabel: axis labels.
        bins: number of bins passed to ``ax.hist``.
    """
    values = var.data
    upper = np.max(values) if vMax is None else vMax
    lower = np.min(values) if vMin is None else vMin
    flat = values.flatten()
    # Drop masked points (only when some exist) so they are not counted.
    sample = flat.compressed() if np.ma.is_masked(flat) else flat
    ax.hist(x=sample, range=(lower, upper), bins=bins, color="blue", density=True)
    ax.set(xlabel=xlabel, ylabel=ylabel)
    ax.grid(linestyle="-", linewidth=0.5, color="black", alpha=0.2)