Assemble ERA5 raw data into a set of tf.tensors

The data download scripts assemble selected ERA5 data in netCDF files. To use that data efficiently in analysis and modelling, it needs to be reformatted as a set of tf.tensors. These have a consistent format and resolution and can be reassembled into a tf.data.Dataset for ML model training.
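
For illustration, here is a minimal sketch of that reassembly step. It assumes the per-variable zarr arrays and their `AvailableMonths` metadata (created by the scripts described below) already exist on $SCRATCH; the variable name is just an example:

# Minimal sketch: stream stored monthly fields into a tf.data.Dataset
import os
import zarr
import tensorflow as tf

fn = "%s/DCVAE-Climate/raw_datasets/ERA5/%s_zarr" % (
    os.getenv("SCRATCH"),
    "2m_temperature",  # example variable
)
zarr_ds = zarr.open(fn, mode="r")
indices = sorted(zarr_ds.attrs["AvailableMonths"].values())


def field_generator():
    for idx in indices:
        yield zarr_ds[:, :, idx]  # one (721, 1440) field per month


dataset = tf.data.Dataset.from_generator(
    field_generator,
    output_signature=tf.TensorSpec(shape=(721, 1440), dtype=tf.float32),
)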

So, for each month in the training period and for each variable (2m_temperature, sea_surface_temperature, mean_sea_level_pressure, total_precipitation), we read in the data from netCDF, regrid it to a common grid, and save it as a tf.tensor.

The script `make_all_raw_tensors.sh` generates the commands needed to make all the tensors: it outputs one command per year, month, and variable. Running all of those commands will create the full set of tensors. (Use GNU parallel to run them efficiently, or submit them as jobs to a cluster.)

#!/bin/bash

# Make all the raw tensors
# Requires downloaded data

(cd ERA5 && ./make_all_tensors.py --variable=2m_temperature)
(cd ERA5 && ./make_all_tensors.py --variable=sea_surface_temperature)
(cd ERA5 && ./make_all_tensors.py --variable=mean_sea_level_pressure)
(cd ERA5 && ./make_all_tensors.py --variable=total_precipitation)
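
Note that this script only prints the commands; nothing is made until they are run. One simple approach, assuming GNU parallel is installed and the script is executable, is something like `./make_all_raw_tensors.sh | parallel -j 8`, which runs eight of the printed commands at a time.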

When the main script has completed and all the raw tensors have been made, we need to add some metadata to them (used by subsequent scripts to find out how much data is available). The script `update_tensor_metadata.sh` does this: run it to set the metadata and to check that the tensors were created successfully.

#!/bin/bash

# Update the metadata for the raw tensors (date::index lists for everything available)
# Make all the tensors first

(cd ERA5 && ./update_tensor_metadata.py --variable=2m_temperature)
(cd ERA5 && ./update_tensor_metadata.py --variable=sea_surface_temperature)
(cd ERA5 && ./update_tensor_metadata.py --variable=mean_sea_level_pressure)
(cd ERA5 && ./update_tensor_metadata.py --variable=total_precipitation)

Other scripts used by that main script:

Script (`make_all_tensors.py`) to make the set of tensors for one variable. Takes the argument `--variable`:

#!/usr/bin/env python

# Make raw data tensors for normalization

import os
from shutil import rmtree
import argparse
import zarr
import tensorstore as ts
import numpy as np

from tensor_utils import date_to_index, FirstYear, LastYear

sDir = os.path.dirname(os.path.realpath(__file__))

parser = argparse.ArgumentParser()
parser.add_argument(
    "--variable",
    help="Variable name",
    type=str,
    required=True,
)
args = parser.parse_args()

# Create the output zarr array if it doesn't exist
fn = "%s/DCVAE-Climate/raw_datasets/ERA5/%s_zarr" % (
    os.getenv("SCRATCH"),
    args.variable,
)

# Create TensorStore dataset if it doesn't exist
try:
    dataset = ts.open(
        {
            "driver": "zarr",
            "kvstore": "file://" + fn,
        },
        dtype=ts.float32,
        chunk_layout=ts.ChunkLayout(chunk_shape=[721, 1440, 1]),
        create=True,
        fill_value=np.nan,
        shape=[
            721,
            1440,
            date_to_index(LastYear, 12) + 1,
        ],
    ).result()
except ValueError:  # Already exists
    pass

# Add date range to array as metadata
# TensorStore doesn't support metadata, so use the underlying zarr array
zarr_ds = zarr.open(fn, mode="r+")
zarr_ds.attrs["FirstYear"] = FirstYear
zarr_ds.attrs["LastYear"] = LastYear

count = 0
for year in range(FirstYear, LastYear + 1):
    for month in range(1, 13):
        idx = date_to_index(year, month)
        slice = zarr_ds[:, :, idx]
        if np.all(np.isnan(slice)):  # Data missing, so make it
            cmd = (
                "%s/make_training_tensor.py --year=%04d --month=%02d --variable=%s"
                % (
                    sDir,
                    year,
                    month,
                    args.variable,
                )
            )
            print(cmd)

Each output command calls another script, `make_training_tensor.py`, to make a single tensor (one variable, one month):

#!/usr/bin/env python

# Read in monthly variable from ERA5 - regrid to model resolution
# Convert into a TensorFlow tensor.
# Serialise and store on $SCRATCH.

import os
import sys
import numpy as np
import warnings

# Suppress TensorFlow moaning about CUDA - we don't need a GPU for this
# Also the warning message confuses people.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf
import tensorstore as ts

import dask

# Going to do external parallelism - run this on one core
tf.config.threading.set_inter_op_parallelism_threads(1)
dask.config.set(scheduler="single-threaded")

from tensor_utils import load_raw, raw_to_tensor, date_to_index

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--year", help="Year", type=int, required=True)
parser.add_argument("--month", help="Integer month", type=int, required=True)
parser.add_argument("--variable", help="Variable name", type=str, required=True)
args = parser.parse_args()

fn = "%s/DCVAE-Climate/raw_datasets/ERA5/%s_zarr" % (
    os.getenv("SCRATCH"),
    args.variable,
)

dataset = ts.open(
    {
        "driver": "zarr",
        "kvstore": "file://" + fn,
    }
).result()

# Load and standardise data
try:
    qd = load_raw(args.year, args.month, variable=args.variable)
    ict = raw_to_tensor(qd)
except Exception:
    warnings.warn(
        "Failed to load data for %s %04d-%02d" % (args.variable, args.year, args.month)
    )
    ict = tf.fill([721, 1440], tf.constant(np.nan, dtype=tf.float32))

# Write to file
didx = date_to_index(args.year, args.month)
op = dataset[:, :, didx].write(ict)
op.result()  # Ensure write completes before exiting
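
For reference, a stored month can be read back from the same array and converted into a tf.tensor. A minimal sketch (the variable and date are just examples):

# Minimal sketch: read one stored month back as a tf.tensor
import os
import tensorflow as tf
import tensorstore as ts

from tensor_utils import date_to_index

fn = "%s/DCVAE-Climate/raw_datasets/ERA5/%s_zarr" % (
    os.getenv("SCRATCH"),
    "total_precipitation",  # example variable
)
dataset = ts.open(
    {
        "driver": "zarr",
        "kvstore": "file://" + fn,
    }
).result()
didx = date_to_index(1979, 3)  # example date
field = dataset[:, :, didx].read().result()  # numpy array, shape (721, 1440)
ict = tf.convert_to_tensor(field, tf.float32)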

Library functions (`tensor_utils.py`) to convert dates to array indices and to convert between tf.tensor and iris.cube.cube:

# Utility functions for creating and manipulating raw tensors

import numpy as np
import tensorflow as tf

from get_data.ERA5 import ERA5_monthly
from utilities import grids

# Convert date into an array index
FirstYear = 1940
LastYear = 2035


def date_to_index(year, month):
    return (year - FirstYear) * 12 + month - 1


def index_to_date(idx):
    return (idx // 12) + FirstYear, (idx % 12) + 1
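
# Example: date_to_index(FirstYear, 1) == 0 and
# date_to_index(LastYear, 12) == (2035 - 1940) * 12 + 11 == 1151,
# so the time axis created in make_all_tensors.py holds 1152 monthly fields.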


# Load the data for 1 month (on the standard cube).
def load_raw(year, month, member=None, variable="total_precipitation"):
    raw = ERA5_monthly.load(
        variable=variable,
        year=year,
        month=month,
        grid=grids.E5sCube,
    )
    raw.data.data[raw.data.mask] = np.nan
    return raw


# Convert raw cube to tensor
def raw_to_tensor(raw):
    ict = tf.convert_to_tensor(raw.data, tf.float32)
    return ict


# Convert tensor to cube
def tensor_to_cube(tensor):
    cube = grids.E5sCube.copy()
    cube.data = tensor.numpy()
    cube.data = np.ma.MaskedArray(cube.data, np.isnan(cube.data))
    return cube
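
A typical round trip through these helpers, as a minimal sketch (the date and variable are just examples, and the corresponding ERA5 month must already have been downloaded):

# Minimal sketch: cube -> tensor -> cube round trip for one month
from tensor_utils import load_raw, raw_to_tensor, tensor_to_cube

raw = load_raw(1979, 3, variable="2m_temperature")  # iris cube on the standard grid
ict = raw_to_tensor(raw)  # (721, 1440) float32 tf.tensor, masked points as NaN
cube = tensor_to_cube(ict)  # back to an iris cube, with NaNs re-masked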

Metadata update script (`update_tensor_metadata.py`) for an ERA5 variable:

#!/usr/bin/env python

# Update the raw tensor zarr array with metadata giving dates and indices for each field present

import os
import argparse
import zarr
import numpy as np

from tensor_utils import date_to_index, FirstYear, LastYear

sDir = os.path.dirname(os.path.realpath(__file__))

parser = argparse.ArgumentParser()
parser.add_argument(
    "--variable",
    help="Variable name",
    type=str,
    required=True,
)
args = parser.parse_args()

# Find the raw_tensor zarr array
fn = "%s/DCVAE-Climate/raw_datasets/ERA5/%s_zarr" % (
    os.getenv("SCRATCH"),
    args.variable,
)

# Add date range to array as metadata
zarr_ds = zarr.open(fn, mode="r+")

AvailableMonths = {}
# Initialise start to the latest possible date and end to the earliest,
# so the comparisons in the loop below find the true first and last dates
start = "%04d-%02d" % (LastYear, 12)
end = "%04d-%02d" % (FirstYear, 1)
for year in range(FirstYear, LastYear + 1):
    for month in range(1, 13):
        dte = "%d-%02d" % (year, month)
        idx = date_to_index(year, month)
        slice = zarr_ds[:, :, idx]
        if not np.all(np.isnan(slice)):
            AvailableMonths["%d-%02d" % (year, month)] = idx
            if dte < start:
                start = dte
            if dte > end:
                end = dte

zarr_ds.attrs["AvailableMonths"] = AvailableMonths

missing = 0
for year in range(FirstYear, LastYear + 1):
    for month in range(1, 13):
        dte = "%d-%02d" % (year, month)
        if dte < start or dte > end:
            continue
        if dte not in AvailableMonths:
            missing += 1

print(args.variable)
print("Start date:", start)
print("End date:", end)
print("Missing months:", missing)
print("Total months:", len(AvailableMonths))
print("\n")