Assemble data for the ML Diffusion video

I got the data from Ben Booth - I’m not including it here.

Convenience functions to load and interpolate data.:

# Functions to load The ML precipitation data

import os
import iris
import iris.cube
import iris.coords
import iris.coord_systems
from iris.util import squeeze
import iris.coord_systems
import iris.util
import numpy as np
import cftime
import datetime

# Don't really understand this, but it gets rid of the error messages.
iris.FUTURE.datum_support = True

# Coordinate system for the data
ML_cs = iris.coord_systems.RotatedGeogCS(
    37.5, 177.5, ellipsoid=iris.coord_systems.GeogCS(6371229.0)
)

# Central data directory
Ddir = "/data/users/backup/datadir/ben.booth/ML/UNET_output/Tests_Mar25"


# And a function to add the coord system to a cube (in-place)
def add_coord_system(cbe):
    cbe.coord("grid_latitude").coord_system = ML_cs
    cbe.coord("grid_longitude").coord_system = ML_cs


# Make a higher-res cube on the same grid as the data for plotting
def plot_cube(resolution=0.025, xmin=357.75, xmax=363.25, ymin=-2.75, ymax=2.75):

    lat_values = np.arange(ymin, ymax + resolution, resolution)
    latitude = iris.coords.DimCoord(
        lat_values,
        standard_name="grid_latitude",
        units="degrees_north",
        coord_system=ML_cs,
    )
    lon_values = np.arange(xmin, xmax + resolution, resolution)
    longitude = iris.coords.DimCoord(
        lon_values,
        standard_name="grid_longitude",
        units="degrees_east",
        coord_system=ML_cs,
    )
    dummy_data = np.zeros((len(lat_values), len(lon_values)))
    plot_cube = iris.cube.Cube(
        dummy_data, dim_coords_and_dims=[(latitude, 0), (longitude, 1)]
    )
    return plot_cube


default_cube = plot_cube()


def get_land_mask(grid=default_cube):
    fname = "%s/fixed_fields/land_mask/opfc_global_2019.nc" % os.getenv("DATADIR")
    if not os.path.isfile(fname):
        raise Exception("No data file %s" % fname)
    land_mask = squeeze(iris.load_cube(fname))
    land_mask = land_mask.regrid(grid, iris.analysis.Linear())
    return land_mask


def load_daily(
    model="target",
    year=None,
    month=None,
    day=None,
    member=1,
    grid=default_cube,
):
    if year is None or month is None or day is None:
        raise Exception("Year, month, and day, must be specified")
    c1 = iris.Constraint(
        time=lambda cell: cell.point.year == year
        and cell.point.month == month
        and cell.point.day == day
    )
    c2 = iris.Constraint(member=lambda cell: cell.point == member)
    if model == "target":
        fname = "%s/predictions_mse_corrected_structure.nc" % Ddir
        c3 = iris.Constraint(cube_func=lambda cube: cube.name() == "target")
        varC = iris.load_cube(
            fname,
            c1 & c2 & c3,
        )
    elif model == "unet-mse":
        fname = "%s/predictions_mse_corrected_structure.nc" % Ddir
        c3 = iris.Constraint(cube_func=lambda cube: cube.name() == "prediction")
        varC = iris.load_cube(
            fname,
            c1 & c2 & c3,
        )
    elif model == "unet-asym":
        fname = "%s/predictions_emulasym_corrected_structure.nc" % Ddir
        c3 = iris.Constraint(cube_func=lambda cube: cube.name() == "prediction")
        varC = iris.load_cube(
            fname,
            c1 & c2 & c3,
        )
    elif model == "diffusion":
        fname = "%s/predictions-ensemble01-sample0_Diffusion.nc" % Ddir
        c2 = iris.Constraint()
        c3 = iris.Constraint(cube_func=lambda cube: cube.name() == "pred_pr")
        varC = iris.load_cube(
            fname,
            c1 & c2 & c3,
        )
        varC.coord("ensemble_member").points = 1  # Hackety hack
        varC = iris.util.squeeze(varC)
        varC.data *= 86400.0  # Convert from m/s to m/day
    else:
        raise Exception("Unknown model %s" % model)
    add_coord_system(varC)
    if grid is not None:  # Regrid, but mask out areas outside original grid
        lat_max = varC.coord("grid_latitude").points.max()
        lat_min = varC.coord("grid_latitude").points.min()
        lon_max = varC.coord("grid_longitude").points.max()
        lon_min = varC.coord("grid_longitude").points.min()
        varC = varC.regrid(grid, iris.analysis.Nearest())
        latlon = np.meshgrid(
            varC.coord("grid_longitude").points, varC.coord("grid_latitude").points
        )
        varC.data = np.ma.masked_where(
            (latlon[0] < lon_min)
            | (latlon[0] > lon_max)
            | (latlon[1] < lat_min)
            | (latlon[1] > lat_max),
            varC.data,
        )
    return varC


# Switch dates using cftime


def load(
    model="target",
    year=None,
    month=None,
    day=None,
    hour=None,
    member=1,
    grid=default_cube,
):
    if year is None or month is None or day is None or hour is None:
        raise Exception("Year, month, day, and hour must be specified")
    today = load_daily(model, year, month, day, member, grid)
    current_date = cftime.datetime(year, month, day, hour, calendar="360_day")
    if hour == 12:
        return today
    elif hour < 12:
        previous_day = current_date - datetime.timedelta(days=1)
        prev = load_daily(
            model,
            previous_day.year,
            previous_day.month,
            previous_day.day,
            member,
            grid,
        )
        f = iris.cube.CubeList([prev, today])
        f = f.merge_cube()
        interpolated = f.interpolate(
            [("time", current_date.replace(hour=hour))], iris.analysis.Linear()
        )
        return interpolated
    else:  # hour> 12
        next_day = current_date + datetime.timedelta(days=1)
        next = load_daily(
            model,
            next_day.year,
            next_day.month,
            next_day.day,
            member,
            grid,
        )
        f = iris.cube.CubeList([today, next])
        f = f.merge_cube()
        interpolated = f.interpolate(
            [("time", current_date.replace(hour=hour))], iris.analysis.Linear()
        )
        return interpolated


def load_3hr(
    model="target",
    year=None,
    month=None,
    day=None,
    hour=None,
    grid=default_cube,
):
    if year is None or month is None or day is None or hour is None:
        raise Exception("Year, month, and day, must be specified")
    fyr = year
    if month == 12 and (day > 1 or hour > 0):
        fyr = year + 1
    fname = (
        "/data/scratch/tomas.wetherell/create_dataset/01/3hrinst_dataset_%04d.nc" % fyr
    )
    c1 = iris.Constraint(
        time=lambda cell: cell.point.year == year
        and cell.point.month == month
        and cell.point.day == day
        and cell.point.hour == hour
    )
    if model == "target":
        c3 = iris.Constraint(cube_func=lambda cube: cube.name() == "precipitation_flux")
        varC = iris.load_cube(
            fname,
            c1 & c3,
        )
        varC = iris.util.squeeze(varC)
        varC.data = varC.data * 86400.0  # Convert from m/s to m/day
    else:
        raise Exception("Unknown model %s" % model)
    add_coord_system(varC)
    if grid is not None:  # Regrid, but mask out areas outside original grid
        lat_max = varC.coord("grid_latitude").points.max()
        lat_min = varC.coord("grid_latitude").points.min()
        lon_max = varC.coord("grid_longitude").points.max()
        lon_min = varC.coord("grid_longitude").points.min()
        varC = varC.regrid(grid, iris.analysis.Nearest())
        latlon = np.meshgrid(
            varC.coord("grid_longitude").points, varC.coord("grid_latitude").points
        )
        varC.data = np.ma.masked_where(
            (latlon[0] < lon_min)
            | (latlon[0] > lon_max)
            | (latlon[1] < lat_min)
            | (latlon[1] > lat_max),
            varC.data,
        )
    return varC


def load_3hr_i(
    model="target",
    year=None,
    month=None,
    day=None,
    hour=None,
    minute=0,
    grid=default_cube,
):
    if year is None or month is None or day is None or hour is None:
        raise Exception("Year, month, day, and hour must be specified")
    if hour % 3 == 0 and minute == 0:
        return load_3hr(model, year, month, day, hour, grid)
    if hour % 3 != 0:
        b_hour = hour - hour % 3
    else:
        b_hour = hour
    b_field = load_3hr(model, year, month, day, b_hour, grid)
    e_hour = b_hour + 3
    current_date = cftime.datetime(year, month, day, hour, minute, calendar="360_day")
    if e_hour == 24:
        e_hour = 0
        next_day = current_date + datetime.timedelta(days=1)
        year = next_day.year
        month = next_day.month
        day = next_day.day
    e_field = load_3hr(model, year, month, day, e_hour, grid)
    f = iris.cube.CubeList([b_field, e_field])
    f = f.merge_cube()
    interpolated = f.interpolate([("time", current_date)], iris.analysis.Linear())
    return interpolated