Utility functions for Rainfall Rescue dataΒΆ

This is a pair of library utility files with miscelanious useful functions:

  • pairs.py - functions to read and write pairs of images and CSV files containing the original Rainfall Rescue transcriptions.

  • validate.py - plotting functions used in comparison plots showing the skill of AI transcriptions

# Utility functions for handling RR image/CSV pairs

import os
import csv
import json
import random
import re

from PIL import Image


# Get the image/csv names
def get_index_list(max_n=None, shuffle=True, seed=None, fake=False, training=None):
    if seed is not None:
        random.seed(seed)
    if fake:
        image_path = os.path.join(f"{os.getenv('PDIR')}/fake_training_data/images")
        result = [x[:-4] for x in os.listdir(image_path) if x.endswith(".jpg")]
    else:
        image_path = os.path.join(f"{os.getenv('PDIR')}/from_Ed/images")
        result = [x[:-4] for x in os.listdir(image_path) if x.endswith(".jpg")]
    if training is not None:
        if training:
            result = [x for x in result if x[-1] != "0"]
        else:
            result = [x for x in result if x[-1] == "0"]
    if shuffle:
        random.shuffle(result)
    if max_n is not None and len(result) > max_n:
        result = result[:max_n]
    return result


# Load a csv file into a data structure (dictionary)
def load_station_csv(csv_path):
    result = {}
    with open(csv_path, mode="r") as file:
        reader = csv.reader(file)
        for index, row in enumerate(reader):
            row = ["null" if x == "" else x for x in row]
            if index == 0:
                result["Name"] = row[0]
            if index == 2:
                result["Number"] = row[1]
            if index == 4:
                result["Years"] = row[1:11]
            if index >= 5 and index <= 16:  # Monthly data
                result[row[0]] = row[1:11]
            if index == 17:
                result["Totals"] = row[1:11]
    return result


# Convert the csv data to a json string
# To serve as target for the model
def csv_to_json(csv_data):
    """
    Convert CSV data to a JSON string.

    Args:
        csv_data (dict): The CSV data as a dictionary.

    Returns:
        str: JSON string representation of the CSV data.
    """
    csv_data["Number"] = str(csv_data["Number"])  # Easier if this is a string
    j = json.dumps(csv_data, separators=(",\n", ":"))
    # reformat to match LLM schema
    j = re.sub(
        r"\[([^\]]*)\]",
        lambda m: "[" + m.group(1).replace("\r", "").replace("\n", "") + "]",
        j,
        flags=re.DOTALL,
    )
    j = j.replace("{", "{\n")
    j = j.replace("}", "\n}")
    return j


# Load a pair of image and csv data
def load_pair(label):
    """
    Load a pair of image and CSV file based on the label.

    Args:
        label (str): The label of the image and CSV file.

    Returns:
        tuple: A tuple containing the image and the CSV file path.
    """
    if len(label) != 4:  # Real data
        image_path = os.path.join(f"{os.getenv('PDIR')}/from_Ed/images", f"{label}.jpg")
        csv_path = os.path.join(
            f"{os.getenv('PDIR')}/from_Ed/csvs",
            f"{label}.csv",
        )

        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image {image_path} does not exist.")
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"CSV {csv_path} does not exist.")

        image = Image.open(image_path)
        csv = load_station_csv(csv_path)
    else:
        # Fake data
        image_path = os.path.join(
            f"{os.getenv('PDIR')}/fake_training_data/images", f"{label}.jpg"
        )
        csv_path = os.path.join(
            f"{os.getenv('PDIR')}/fake_training_data/csvs",
            f"{label}.csv",
        )

        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image {image_path} does not exist.")
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"CSV {csv_path} does not exist.")

        image = Image.open(image_path)
        csv = load_fake_csv(csv_path)

    return image, csv


# Load a pair of image and csv data - from the fake training data
def load_fake_pair(label):
    image_path = os.path.join(
        f"{os.getenv('PDIR')}/fake_training_data/images", f"{label}.jpg"
    )
    csv_path = os.path.join(
        f"{os.getenv('PDIR')}/fake_training_data/csvs",
        f"{label}.csv",
    )

    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image {image_path} does not exist.")
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"CSV {csv_path} does not exist.")

    image = Image.open(image_path)
    csv = load_fake_csv(csv_path)
    return image, csv


# Load a csv file into a data structure (dictionary)
def load_fake_csv(csv_path):
    result = {}
    with open(csv_path, mode="r") as file:
        contents = file.read()
        contents = contents.replace("'", '"')  # Json needs double quotes
        result = json.loads(contents)
    return result
# Functions to make validation plots for rainfall_rescue

from rainfall_rescue.utils.pairs import load_pair, csv_to_json
import re
import json
import os
from matplotlib.text import TextPath
from matplotlib.patches import PathPatch
from matplotlib.transforms import Affine2D
import numpy as np
from collections import Counter


def plot_two_colored_text(
    ax, x, y, text1, text2, size=12, colour1="blue", colour2="red"
):

    # Calculate scaling factors from points to axes coordinates
    fig = ax.figure
    axes_bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    axes_width_inch = axes_bbox.width
    axes_height_inch = axes_bbox.height
    scale_x = 1 / axes_width_inch * (1 / 72)
    scale_y = 1 / axes_height_inch * (1 / 72)

    path = TextPath((0, 0), text1, size=size)
    trans = Affine2D().scale(scale_x, scale_y).translate(x, y) + ax.transAxes
    patch = PathPatch(path, color=colour1, linewidth=0, transform=trans)
    ax.add_patch(patch)
    verts = patch.get_path().vertices
    trans = patch.get_transform()
    verts_data = trans.transform(verts)
    verts_axes = ax.transAxes.inverted().transform(verts_data)
    right_limit_text1 = np.max(verts_axes[:, 0])
    path = TextPath((0, 0), text2, size=size)
    trans = (
        Affine2D().scale(scale_x, scale_y).translate(right_limit_text1, y)
        + ax.transAxes
    )
    patch = PathPatch(path, color=colour2, linewidth=0, transform=trans)
    ax.add_patch(patch)


# Present extracted data as a %.2f string as far as possible
def format_value(data, key, year_idx):
    if key == "Name" or key == "Number":
        try:
            return data[key]
        except KeyError:
            return "N/A"
    if key == "Years":
        try:
            return str(data[key][year_idx])
        except (IndexError, KeyError):
            return "N/A"
    try:
        value = data[key][year_idx]
    except (IndexError, KeyError):
        return "N/A"

    return format_as_2f(value)


def format_as_2f(value):
    """Format a value as a string with two decimal places."""
    if value is None or value == "null":
        return "null"
    try:
        return "%.2f" % float(value)
    except ValueError:
        return str(value)  # Return as is if it cannot be converted to float


# Plot the image into a given axes
def plot_image(ax, img):
    ax.set_axis_off()
    imgplot = ax.imshow(img, zorder=10)


# Plot target and retrieved image metadata into a given axes
def plot_metadata(ax, extracted, jcsv):

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xticks([])
    ax.set_yticks([])

    ymp = 0.8
    for metad in ("Number", "Name"):
        exv = extracted[metad]
        rrv = jcsv[metad]
        if exv == rrv:
            ax.text(
                0.05,
                ymp,
                "%s: %s" % (metad, exv),
                fontsize=12,
                color="black",
            )
        else:
            ax.text(
                0.05,
                ymp,
                "%s: %s" % (metad, exv),
                fontsize=12,
                color="red",
            )
            ax.text(
                0.05,
                ymp - 0.1,
                "%s: %s" % (metad, rrv),
                fontsize=12,
                color="blue",
            )
        ymp -= 0.3


def models_agree(extracted, value, idx=None, agreement_count=2):
    """Check if the models agree on a value."""
    values = []
    for model_id in extracted.keys():
        if idx is not None:
            if value == "Years":
                val = extracted[model_id][value][idx]
            else:
                val = format_as_2f(extracted[model_id][value][idx])
        else:
            val = extracted[model_id][value]
        values.append(val)
    counts = Counter(values)
    top_two = counts.most_common(2)
    if top_two[0][0] == "N/A":  # Special case - 'agreed' on "can't do it"
        return (False, "N/A")  # Not counted as agreement
    if len(top_two) < 2:
        return (True, top_two[0][0])  # Only one unique value
    if top_two[0][1] == top_two[1][1]:  # No one most common value
        return (False, top_two[0][0])
    if top_two[0][1] >= agreement_count:  # Most common value is popular enough
        return (True, top_two[0][0])
    return (False, top_two[0][0])


def plot_metadata_agreement(ax, extracted, jcsv, agreement_count=2):

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xticks([])
    ax.set_yticks([])

    ymp = 0.8
    for metad in ("Number", "Name"):
        match, exv = models_agree(extracted, metad, agreement_count=agreement_count)
        rrv = jcsv[metad]
        if match:  # Models agree
            if exv == rrv:  # on the right answer
                colour = (0, 0, 1)  # Blue
            else:  # on the wrong answer
                colour = (1, 0, 0)  # Red
        else:  # Models disagree
            colour = (0.5, 0.5, 0.5)  # Grey
        ax.text(
            0.05,
            ymp,
            "%s: %s" % (metad, exv),
            fontsize=14,
            color=colour,
        )
        ymp -= 0.3


# Plot fractional success at metadata into a given axes
def plot_metadata_fraction_agreement(ax, merged, cmp=None):

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xticks([])
    ax.set_yticks([])

    ymp = 0.8
    for metad in ("Number", "Name"):
        blue = merged[metad].count("blue") / len(merged[metad])
        plot_two_colored_text(
            ax,
            0.05,
            ymp,
            "%s: " % metad,
            "  %d" % int((blue) * 100),
            size=12,
            colour1="black",
            colour2="blue",
        )
        red = merged[metad].count("red") / len(merged[metad])
        if int(red * 100) > 0:
            plot_two_colored_text(
                ax,
                0.05,
                ymp - 0.15,
                "%s: " % metad,
                "  %d" % int((red) * 100),
                size=12,
                colour1="white",  # Invisible
                colour2="red",
            )

        ymp -= 0.3


def plot_metadata_fraction(ax, merged, cmp=None):

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xticks([])
    ax.set_yticks([])

    ymp = 0.8
    for metad in ("Number", "Name"):
        fraction = sum(merged[metad]) / len(merged[metad])
        ax.text(
            0.05,
            ymp,
            "%s: %d" % (metad, int(fraction * 100)),
            fontsize=12,
            color="black",
        )
        if cmp is not None:
            cmp_fraction = sum(cmp[metad]) / len(cmp[metad])
            color = "blue" if cmp_fraction < fraction else "red"
            plot_two_colored_text(
                ax,
                0.05,
                ymp - 0.15,
                "%s: " % metad,
                "  %d" % abs(int((fraction - cmp_fraction) * 100)),
                size=12,
                colour1="white",  # Invisible
                colour2=color,
            )

        ymp -= 0.3


# Plot the digitised numbers into a given axes
def plot_monthly_table(ax, extracted, jcsv, yticks=True):
    ax.set_xlim(0.5, 10.5)
    ax.set_xticks(range(1, 11))
    ax.xaxis.set_ticks_position("top")
    labels = ax.set_xticklabels(extracted["Years"])
    # Note - this has to be the last change made to the xtics, or the colours will be reset
    for year_idx, label in enumerate(labels):
        if extracted["Years"][year_idx] != jcsv["Years"][year_idx]:
            label.set_color("red")
    ax.set_ylim(0.5, 13)
    if yticks:
        ax.set_yticks(range(1, 13))
        ax.set_yticklabels(
            (
                "Jan",
                "Feb",
                "Mar",
                "Apr",
                "May",
                "Jun",
                "Jul",
                "Aug",
                "Sep",
                "Oct",
                "Nov",
                "Dec",
            )
        )
    else:
        ax.set_yticks([])
    ax.xaxis.set_label_position("top")
    ax.invert_yaxis()
    ax.set_aspect("auto")

    monthNumbers = {
        "January": 1,
        "February": 2,
        "March": 3,
        "April": 4,
        "May": 5,
        "June": 6,
        "July": 7,
        "August": 8,
        "September": 9,
        "October": 10,
        "November": 11,
        "December": 12,
    }

    for year_idx in range(10):
        for month in monthNumbers.keys():
            try:
                exv = format_value(extracted, month, year_idx)
                rrv = format_value(jcsv, month, year_idx)
                try:
                    if exv == rrv:
                        ax.text(
                            year_idx + 1,
                            monthNumbers[month],
                            exv,
                            ha="center",
                            va="center",
                            fontsize=12,
                            color="black",
                        )
                    else:
                        ax.text(
                            year_idx + 1,
                            monthNumbers[month],
                            exv,
                            ha="center",
                            va="center",
                            fontsize=12,
                            color="red",
                        )
                        ax.text(
                            year_idx + 1,
                            monthNumbers[month] + 0.5,
                            rrv,
                            ha="center",
                            va="center",
                            fontsize=12,
                            color="blue",
                        )
                except Exception as e:
                    print(rrv, exv)
                    print(e)
            except KeyError as e:
                continue


def plot_monthly_table_agreement(ax, extracted, jcsv, agreement_count=2, yticks=True):
    ax.set_xlim(0.5, 10.5)
    ax.set_xticks(range(1, 11))
    ax.xaxis.set_ticks_position("top")
    ax.xaxis.set_label_position("top")
    xtl = ["N/A"] * 10
    for year_idx in range(10):
        match, xtl[year_idx] = models_agree(
            extracted, "Years", idx=year_idx, agreement_count=agreement_count
        )
    labels = ax.set_xticklabels(xtl)
    for year_idx, label in enumerate(labels):
        match, exv = models_agree(
            extracted, "Years", idx=year_idx, agreement_count=agreement_count
        )
        if match:  # Models agree
            if exv == jcsv["Years"][year_idx]:
                label.set_color("blue")
            else:  # on the wrong answer
                label.set_color("red")
        else:  # Models disagree
            label.set_color("grey")
    ax.set_ylim(0.5, 13)
    if yticks:
        ax.set_yticks(range(1, 13))
        ax.set_yticklabels(
            (
                "Jan",
                "Feb",
                "Mar",
                "Apr",
                "May",
                "Jun",
                "Jul",
                "Aug",
                "Sep",
                "Oct",
                "Nov",
                "Dec",
            )
        )
    else:
        ax.set_yticks([])
    ax.invert_yaxis()
    ax.set_aspect("auto")

    monthNumbers = {
        "January": 1,
        "February": 2,
        "March": 3,
        "April": 4,
        "May": 5,
        "June": 6,
        "July": 7,
        "August": 8,
        "September": 9,
        "October": 10,
        "November": 11,
        "December": 12,
    }

    for year_idx in range(10):
        for month in monthNumbers.keys():
            try:
                match, exv = models_agree(
                    extracted, month, idx=year_idx, agreement_count=agreement_count
                )
                if match:  # Models agree
                    if exv == jcsv[month][year_idx]:  # on the right answer
                        colour = (0, 0, 1)  # Blue
                    else:  # on the wrong answer
                        colour = (1, 0, 0)  # Red
                else:  # Models disagree
                    colour = (0.5, 0.5, 0.5)
                ax.text(
                    year_idx + 1,
                    monthNumbers[month],
                    exv,
                    ha="center",
                    va="center",
                    fontsize=14,
                    color=colour,
                )
            except KeyError as e:
                continue


def plot_monthly_table_fraction_agreement(ax, merged, cmp=None, yticks=True):
    years = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    ax.set_xlim(years[0] - 0.5, years[-1] + 0.5)
    ax.xaxis.set_ticks_position("top")
    ax.xaxis.set_label_position("top")
    ax.set_xticks(range(years[0], years[-1] + 1))
    xtfraction = [
        merged["Years"][year_idx].count("blue") / len(merged["Years"][year_idx])
        for year_idx in range(10)
    ]
    xtl = [f"{int(fraction * 100)}" for fraction in xtfraction]
    ax.set_xticklabels(xtl)
    ax.set_ylim(0.5, 13)
    if yticks:
        ax.set_yticks(range(1, 13))
        ax.set_yticklabels(
            (
                "Jan",
                "Feb",
                "Mar",
                "Apr",
                "May",
                "Jun",
                "Jul",
                "Aug",
                "Sep",
                "Oct",
                "Nov",
                "Dec",
            )
        )
    else:
        ax.set_yticks([])
    ax.invert_yaxis()
    ax.set_aspect("auto")

    monthNumbers = {
        "January": 1,
        "February": 2,
        "March": 3,
        "April": 4,
        "May": 5,
        "June": 6,
        "July": 7,
        "August": 8,
        "September": 9,
        "October": 10,
        "November": 11,
        "December": 12,
    }

    for year in years:
        for month in monthNumbers.keys():
            blue = merged[month][year - 1].count("blue") / len(merged[month][year - 1])
            ax.text(
                year,
                monthNumbers[month],
                f"{int(blue * 100)}",
                ha="center",
                va="center",
                fontsize=12,
                color="blue",
            )
            red = merged[month][year - 1].count("red") / len(merged[month][year - 1])
            if (int(red * 100)) > 0:
                ax.text(
                    year,
                    monthNumbers[month] + 0.5,
                    f"{int(red * 100)}",
                    ha="center",
                    va="center",
                    fontsize=12,
                    color="red",
                )


def plot_monthly_table_fraction(ax, merged, cmp=None, yticks=True):
    years = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    ax.set_xlim(years[0] - 0.5, years[-1] + 0.5)
    ax.xaxis.set_ticks_position("top")
    ax.xaxis.set_label_position("top")
    ax.set_xticks(range(years[0], years[-1] + 1))
    xtfraction = [
        sum(merged["Years"][year_idx]) / len(merged["Years"][year_idx])
        for year_idx in range(10)
    ]
    xtl = [f"{int(fraction * 100)}" for fraction in xtfraction]
    ax.set_xticklabels(xtl)
    ax.set_ylim(0.5, 13)
    if yticks:
        ax.set_yticks(range(1, 13))
        ax.set_yticklabels(
            (
                "Jan",
                "Feb",
                "Mar",
                "Apr",
                "May",
                "Jun",
                "Jul",
                "Aug",
                "Sep",
                "Oct",
                "Nov",
                "Dec",
            )
        )
    else:
        ax.set_yticks([])
    ax.invert_yaxis()
    ax.set_aspect("auto")

    monthNumbers = {
        "January": 1,
        "February": 2,
        "March": 3,
        "April": 4,
        "May": 5,
        "June": 6,
        "July": 7,
        "August": 8,
        "September": 9,
        "October": 10,
        "November": 11,
        "December": 12,
    }

    for year in years:
        for month in monthNumbers.keys():
            fraction = sum(merged[month][year - 1]) / len(merged[month][year - 1])
            ax.text(
                year,
                monthNumbers[month],
                f"{int(fraction * 100)}",
                ha="center",
                va="center",
                fontsize=12,
                color="black",
            )
            if cmp is not None:
                cmp_fraction = sum(cmp[month][year - 1]) / len(cmp[month][year - 1])
                color = "blue" if cmp_fraction < fraction else "red"
                ax.text(
                    year,
                    monthNumbers[month] + 0.5,
                    "%d" % abs(int((fraction - cmp_fraction) * 100)),
                    ha="center",
                    va="center",
                    fontsize=12,
                    color=color,
                )


# plot the extracted totals into a given axes
def plot_totals(ax, extracted, jcsv):
    ax.set_xlim(0.5, 10.5)
    ax.set_ylim(0, 1)
    ax.set_xticks([])
    ax.set_yticks([])

    for year_idx in range(10):
        exv = format_value(extracted, "Totals", year_idx)
        rrv = format_value(jcsv, "Totals", year_idx)
        if exv == rrv:
            ax.text(
                year_idx + 1,
                0.7,
                exv,
                ha="center",
                va="center",
                fontsize=12,
                color="black",
            )
        else:
            ax.text(
                year_idx + 1,
                0.7,
                exv,
                ha="center",
                va="center",
                fontsize=12,
                color="red",
            )
            ax.text(
                year_idx + 1,
                0.3,
                rrv,
                ha="center",
                va="center",
                fontsize=12,
                color="blue",
            )


# Mark where multi models agreed - for totals
def plot_totals_agreement(
    ax,
    extracted,
    jcsv,
    agreement_count=2,
):
    ax.set_xlim(0.5, 10.5)
    ax.set_ylim(0, 1)
    ax.set_xticks([])
    ax.set_yticks([])

    for year_idx in range(0, 10):
        match, exv = models_agree(
            extracted, "Totals", idx=year_idx, agreement_count=agreement_count
        )
        rrv = format_value(jcsv, "Totals", year_idx)

        if match:  # Models agree
            if exv == rrv:  # on the right answer
                colour = (0, 0, 1)  # Blue
            else:  # on the wrong answer
                colour = (1, 0, 0)  # Red
        else:  # Models disagree
            colour = (0.5, 0.5, 0.5)  # Grey
        ax.text(
            year_idx + 1,
            0.5,
            exv,
            ha="center",
            va="center",
            fontsize=14,
            color=colour,
        )


def plot_totals_fraction_agreement(ax, merged, cmp=None):
    years = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    ax.set_xlim(years[0] - 0.5, years[-1] + 0.5)
    ax.set_ylim(0, 1)
    ax.set_xticks([])
    ax.set_yticks([])

    for year in years:
        blue = merged["Totals"][year - 1].count("blue") / len(
            merged["Totals"][year - 1]
        )
        ax.text(
            year,
            0.7,
            f"{int(blue * 100)}",
            ha="center",
            va="center",
            fontsize=12,
            color="blue",
        )
        red = merged["Totals"][year - 1].count("red") / len(merged["Totals"][year - 1])
        if (int(red * 100)) > 0:
            ax.text(
                year,
                0.3,
                f"{int(red * 100)}",
                ha="center",
                va="center",
                fontsize=12,
                color="red",
            )


def plot_totals_fraction(ax, merged, cmp=None):
    years = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    ax.set_xlim(years[0] - 0.5, years[-1] + 0.5)
    ax.set_ylim(0, 1)
    ax.set_xticks([])
    ax.set_yticks([])

    for year in years:
        fraction = sum(merged["Totals"][year - 1]) / len(merged["Totals"][year - 1])
        ax.text(
            year,
            0.7,
            f"{int(fraction * 100)}",
            ha="center",
            va="center",
            fontsize=12,
            color="black",
        )
        if cmp is not None:
            cmp_fraction = sum(cmp["Totals"][year - 1]) / len(cmp["Totals"][year - 1])
            color = "blue" if cmp_fraction < fraction else "red"
            ax.text(
                year,
                0.3,
                "%d" % abs(int((fraction - cmp_fraction) * 100)),
                ha="center",
                va="center",
                fontsize=12,
                color=color,
            )


def make_null_json():
    """Create a null JSON object for cases where no data is available."""
    all_na = {"Name": "N/A", "Number": "N/A", "Years": ["N/A"] * 10}
    for month in (
        "January",
        "February",
        "March",
        "April",
        "May",
        "June",
        "July",
        "August",
        "September",
        "October",
        "November",
        "December",
    ):
        all_na[month] = ["N/A"] * 10
    all_na["Totals"] = ["N/A"] * 10
    return all_na


# Models don't always make good JSON - fix the egregious problems so it parses
def quote_list_items(match):
    # Get the content inside the brackets
    items = [v.strip() for v in match.group(1).split(",")]
    quoted = ['"%s"' % v for v in items]
    return "[" + ", ".join(quoted) + "]"


def jsonfix(input):
    """Fix JSON that has numbers like .12 instead of 0.12"""
    fixed = re.sub(r"(?<!\d)\.(\d+)", r"0.\1", input)  # Fix numbers like .12 -> 0.12
    fixed = re.sub(r"(\d+):", r'"\1":', fixed)  # Fix keys like 2023: -> "2023":
    fixed = re.sub(r"\[([^\[\]]+)\]", quote_list_items, fixed)  # Quote list items
    fixed = fixed.replace('""', '"')
    fixed = "".join(
        c for c in fixed if c.isprintable()
    )  # Get rid of line breaks and other non-printable characters
    # Deal with bad terminations
    if not fixed.endswith("]}"):
        fixed = (
            fixed[: fixed.rfind("]") + 1] + "}"
        )  # Might cut off too much, but if so we're screwed anyway.
    # Get rid of any junk after the totals
    last_match = None
    for m in re.finditer(r'"Totals"\s*:\s*\[.*?\]', fixed, flags=re.DOTALL):
        last_match = m
    fixed = fixed[: last_match.end()] + "}" if last_match else fixed
    return fixed


# map JSON jeys like 'TOTALS' to 'Totals'
def cap_first_key(k: str) -> str:
    return k[:1].upper() + k[1:] if isinstance(k, str) else k


# Load the extracted data from a model, for a label
def load_extracted(model_id, label):
    """Load the extracted data from a model for a given label."""
    null_j = make_null_json()
    opfile = f"{os.getenv('PDIR')}/extracted/{model_id}/{label}.json"
    if not os.path.exists(opfile):
        print(f"No extraction for {model_id} {label}")
        return null_j
    with open(opfile, "r") as f:
        raw_j = f.read()
        fixed_j = jsonfix(raw_j)
        try:
            extracted = json.loads(fixed_j)
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON for {model_id} {label}: {e}")
            return null_j
    for key in null_j:
        if key not in extracted:
            extracted[key] = null_j[key]
        elif isinstance(extracted[key], list) and len(extracted[key]) < 10:
            # Ensure lists have 10 items
            extracted[key] += ["N/A"] * (10 - len(extracted[key]))
    # Fix boring common error
    if extracted["Name"].lower().startswith("rainfall at"):
        extracted["Name"] = extracted["Name"][12:]
    extracted = {cap_first_key(k): v for k, v in extracted.items()}
    return extracted


# find where the model is accurate for each value in one case
def validate_case(model_id, label):

    # load the image/data pair
    img, csv = load_pair(label)
    jcsv = json.loads(csv_to_json(csv))

    # Load the model extracted data
    extracted = load_extracted(model_id, label)

    # Check if the extracted data matches the CSV data
    correct = {}
    try:
        correct["Name"] = jcsv["Name"] == extracted["Name"]
    except KeyError:
        correct["Name"] = False
    try:
        correct["Number"] = jcsv["Number"] == extracted["Number"]
    except KeyError:
        correct["Number"] = False
    correct["Years"] = [False] * 10
    correct["Totals"] = [False] * 10
    for month in (
        "January",
        "February",
        "March",
        "April",
        "May",
        "June",
        "July",
        "August",
        "September",
        "October",
        "November",
        "December",
    ):
        correct[month] = [False] * 10
    for yr in range(10):
        try:
            correct["Years"][yr] = jcsv["Years"][yr] == extracted["Years"][yr]
        except (KeyError, IndexError):
            correct["Years"][yr] = False
        try:
            correct["Totals"][yr] = jcsv["Totals"][yr] == extracted["Totals"][yr]
        except (KeyError, IndexError):
            correct["Totals"][yr] = False
        for month in (
            "January",
            "February",
            "March",
            "April",
            "May",
            "June",
            "July",
            "August",
            "September",
            "October",
            "November",
            "December",
        ):
            try:
                correct[month][yr] = jcsv[month][yr] == extracted[month][yr]
            except (KeyError, IndexError):
                correct[month][yr] = False

    return correct


# Merge validated cases into a single dictionary
def merge_validated_cases(merged, case):
    if merged is None:
        merged = case
        for key in merged:
            if isinstance(merged[key], list):
                for i in range(len(merged[key])):
                    merged[key][i] = [case[key][i]]
            else:
                merged[key] = [case[key]]
    else:
        for key in case:
            if isinstance(case[key], list):
                for i in range(len(case[key])):
                    merged[key][i].append(case[key][i])
            else:
                merged[key].append(case[key])
    return merged