Gemma-3-4b: Train against a sample of images with transcriptions

Much the same as the script to fine-tune SmolVLM. The main difference is that Gemma requires image input to be pre-cut into squares, so one rectangular page has to be explicitly subdivided. I don’t know how best to do the subdivision, but the default used by this script works OK - at least after fine-tuning. (I have not checked, but it’s probably important to use the same subdivision method for training and for extraction.)

#!/usr/bin/env python

# Gemma model training script adapted for RRR
import os
from PIL import Image
import argparse
from RR_utils.hf import HFlogin
from RR_utils.image import cut_image

HFlogin()  # Connect to Huggingface Hub - only needed for initial model weights download

from rainfall_rescue.utils.pairs import get_index_list, load_pair, csv_to_json

import torch
from torch.utils.data import Dataset

from transformers import (
    AutoProcessor,
    AutoModelForImageTextToText,
    TrainerCallback,
)

from peft import LoraConfig, PeftModel
from trl import SFTTrainer

# Text prompts - system and user
from models.smolvlm.prompts import s_prompt, u_prompt


# Command-line interface for one training run.
# Only --run_id is mandatory; every other option has a default.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_id",
    help="Model ID",
    type=str,
    default="google/gemma-3-4b-it",
)
parser.add_argument(
    "--nmax",
    help="Maximum number of training cases to use",
    type=int,
    default=100,
)
parser.add_argument(
    "--fake",
    help="Use fake cases instead of real",
    action="store_true",
)
parser.add_argument(
    "--random_seed",
    help="Control the set of 'random' choices",
    type=int,
    default=None,
)
parser.add_argument(
    "--epochs",
    help="Number of epochs to train",
    type=int,
    default=3,
)
parser.add_argument(
    "--run_id",
    help="Identifier for this training run",
    type=str,
    required=True,
)
parser.add_argument(
    "--patch_size",
    help="Image patch size (pixels)",
    type=int,
    default=600,
)

# Parsed once at import time; the rest of the script reads options from `clargs`
clargs = parser.parse_args()


# Convert dataset to OAI messages
def format_data(sample):
    """Turn one training sample into an OAI-style chat record.

    Args:
        sample: Indexable sample; element 1 is the page image and element 2
            is the CSV transcription (element 0 is not used here).

    Returns:
        A dict with a single "messages" key holding three turns: the system
        prompt, a user turn made of the image block(s) followed by the user
        prompt text, and the assistant turn with the expected answer.
    """
    # Desired output is the CSV data formatted as JSON
    target_text = csv_to_json(sample[2])

    # Cut the page into squares, or pass it whole if no patch size is set
    page = sample[1]
    if clargs.patch_size is None:
        tiles = [page]
    else:
        tiles = cut_image(page, clargs.patch_size, overlap=0.1)

    # All image entries come first; the text instruction goes last
    user_content = [{"type": "image", "image": tile} for tile in tiles]
    user_content = user_content + [{"type": "text", "text": u_prompt}]

    system_turn = {
        "role": "system",
        "content": [{"type": "text", "text": s_prompt}],
    }
    user_turn = {"role": "user", "content": user_content}
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": target_text}],
    }
    return {"messages": [system_turn, user_turn, assistant_turn]}


# Make a training dataset from the RR image/CSV pairs
class RRTrainingDataset(Dataset):
    """Dataset of RR image/CSV training pairs.

    Only the list of case labels is stored; the image and CSV for a case
    are fetched with `load_pair` each time an item is requested.
    """

    def __init__(self, max_n=None, seed=None):
        """Select the training cases.

        Args:
            max_n: Passed to `get_index_list` as `max_n`.
            seed: Passed to `get_index_list` as `seed`.
        """
        selection = {
            "max_n": max_n,
            "seed": seed,
            "fake": clargs.fake,  # real or fake cases, from the command line
            "training": True,
        }
        self.labels = get_index_list(**selection)

    def __len__(self):
        """Return the number of selected cases."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return (label, image, csv) for the case at position `idx`."""
        label = self.labels[idx]
        image, csv_text = load_pair(label)
        return label, image, csv_text


# Load the selected training pairs and convert each one to OAI messages.
# The result is a plain list, so every sample is loaded and formatted up front.
dataset = list(
    map(
        format_data,
        RRTrainingDataset(max_n=clargs.nmax, seed=clargs.random_seed),
    )
)


# Model init arguments
model_kwargs = {
    # Use "flash_attention_2" when running on Ampere or newer GPU, and 'eager' for older GPUs
    "attn_implementation": "eager",
    # Load the weights as bfloat16
    "torch_dtype": torch.bfloat16,
    # Automatic device placement
    "device_map": "auto",
}

# Load the base model
model = AutoModelForImageTextToText.from_pretrained(
    clargs.model_id,
    **model_kwargs,
)

# Load the matching processor (used later for chat templating, tokenizing and image handling)
processor = AutoProcessor.from_pretrained(clargs.model_id)

# LoRA adapter configuration handed to the trainer
_lora_settings = {
    "task_type": "CAUSAL_LM",
    "r": 16,
    "lora_alpha": 16,
    "lora_dropout": 0.05,
    "bias": "none",
    "target_modules": "all-linear",
    # These modules are listed to be saved alongside the adapter weights
    "modules_to_save": ["lm_head", "embed_tokens"],
}
peft_config = LoraConfig(**_lora_settings)
from trl import SFTConfig

# Everything this run writes (checkpoints, merged models, logs) goes under
# $PDIR/<run_id>. Fail early if PDIR is unset rather than silently writing
# to a directory literally named "None".
_pdir = os.getenv("PDIR")
if not _pdir:
    raise RuntimeError(
        "Environment variable PDIR must be set to the base directory for training output"
    )
_run_dir = f"{_pdir}/{clargs.run_id}"

sargs = SFTConfig(
    output_dir=_run_dir,  # directory to save checkpoints and the final adapter
    num_train_epochs=clargs.epochs,  # number of training epochs
    per_device_train_batch_size=1,  # batch size per device during training
    gradient_accumulation_steps=4,  # number of steps before performing a backward/update pass
    gradient_checkpointing=True,  # use gradient checkpointing to save memory
    optim="adamw_torch_fused",  # use fused adamw optimizer
    logging_steps=5,  # log every 5 steps
    save_strategy="epoch",  # save checkpoint every epoch
    learning_rate=1e-4,  # 2e-4,  # learning rate, based on QLoRA paper
    bf16=True,  # use bfloat16 precision
    max_grad_norm=0.3,  # max gradient norm based on QLoRA paper
    # NOTE(review): with lr_scheduler_type="constant" the warmup ratio is
    # presumably ignored ("constant_with_warmup" would apply it) - confirm
    warmup_ratio=0.03,  # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",  # use constant learning rate scheduler
    push_to_hub=False,  # do not push the model to the hub
    report_to="tensorboard",  # report metrics to tensorboard
    logging_dir=f"{_run_dir}/logs",  # directory to save logs
    gradient_checkpointing_kwargs={
        "use_reentrant": False
    },  # use non-reentrant checkpointing
    dataset_text_field="",  # need a dummy field for collator
    dataset_kwargs={"skip_prepare_dataset": True},  # important for collator
)
sargs.remove_unused_columns = False  # important for collator


# Create a data collator to encode text and image pairs
# Don't understand this bit - why can't the processor just operate on messages?


def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    """Collect every image in one conversation, converted to RGB.

    Args:
        messages: The turns of a single conversation. Each turn is a dict
            whose optional "content" is either a list of content elements
            or a single element.

    Returns:
        The images found, in order of appearance, each passed through
        `.convert("RGB")`. Turns without content and non-dict or text-only
        elements contribute nothing.

    Raises:
        ValueError: If an element is tagged `"type": "image"` but carries
            no "image" entry, so there is nothing to convert.
    """
    image_inputs = []
    for msg in messages:
        # Normalise the content to a list so one loop handles both shapes
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]

        for element in content:
            if not isinstance(element, dict):
                continue
            if "image" in element:
                image_inputs.append(element["image"].convert("RGB"))
            elif element.get("type") == "image":
                # An image placeholder with no payload cannot be processed;
                # say so explicitly instead of failing on the dict itself.
                raise ValueError(
                    "Content element has type 'image' but no 'image' entry"
                )
    return image_inputs


def collate_fn(examples):
    """Encode a list of chat examples into one padded training batch.

    Args:
        examples: Dataset items, each a dict with a "messages" conversation.

    Returns:
        The processor output for the batch with a "labels" entry added.
        Labels are a copy of `input_ids` in which padding and image tokens
        are set to -100 so that they are excluded from the loss.
    """
    texts = []
    images = []
    for example in examples:
        conversation = example["messages"]
        # Render the conversation to text and gather its images separately
        image_inputs = process_vision_info(conversation)
        text = processor.apply_chat_template(
            conversation, add_generation_prompt=False, tokenize=False
        )
        texts.append(text.strip())
        images.append(image_inputs)

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # The labels are the input_ids, with the tokens below masked out
    labels = batch["input_ids"].clone()

    tokenizer = processor.tokenizer
    # Must be a scalar id: comparing the label tensor against a Python list
    # does not give a reliable element-wise mask.
    boi_token_id = tokenizer.convert_tokens_to_ids(
        tokenizer.special_tokens_map["boi_token"]
    )
    ignored_token_ids = (
        tokenizer.pad_token_id,  # padding
        boi_token_id,  # begin-of-image marker
        262144,  # presumably Gemma 3's image soft-token id - TODO confirm
    )
    for token_id in ignored_token_ids:
        labels[labels == token_id] = -100

    batch["labels"] = labels
    return batch


# Define a callback to save a merged version of the model at the end of each epoch
class SaveMergedCallback(TrainerCallback):
    """Save a merged (base + LoRA) copy of the model at the end of each epoch.

    Each merged copy is written, together with the processor and tokenizer,
    to `<out_dir>/merged_epoch_<N>`.
    """

    def __init__(self, base_model_id, out_dir):
        """Store where to get the base weights and where to write output.

        Args:
            base_model_id: Model ID of the base model to merge the adapter into.
            out_dir: Directory the adapter is saved to and reloaded from;
                merged models are written to sub-directories of it.
        """
        self.base_model_id = base_model_id
        self.out_dir = out_dir

    # Called at end of epoch
    def on_epoch_end(self, args, state, control, **kwargs):
        """Save the current adapter, merge it into a fresh base model, and save that."""
        print("[SaveMergedCallback] on_epoch_end called")
        # Relies on the module-level `trainer`, which is created before
        # training starts. Save to self.out_dir so that the adapter is
        # reloaded from the same place it was just written.
        trainer.save_model(self.out_dir)
        torch.cuda.empty_cache()
        # NOTE(review): loaded without an explicit torch_dtype, unlike the
        # training model - confirm the precision of the merged weights is as intended
        base_model = AutoModelForImageTextToText.from_pretrained(
            self.base_model_id, low_cpu_mem_usage=True
        )
        # Merge LoRA and base model and save
        peft_model = PeftModel.from_pretrained(base_model, self.out_dir)
        merged_model = peft_model.merge_and_unload()
        merged_dir = os.path.join(
            self.out_dir, f"merged_epoch_{int(getattr(state, 'epoch', 0))}"
        )
        os.makedirs(merged_dir, exist_ok=True)
        merged_model.save_pretrained(
            merged_dir, safe_serialization=True, max_shard_size="2GB"
        )
        # Also need to save the processor and tokenizer to make a reusable model
        processor.save_pretrained(merged_dir)
        processor.tokenizer.save_pretrained(merged_dir)
        # Drop the references to the merged copy before training continues
        del merged_model
        del peft_model
        del base_model
        torch.cuda.empty_cache()


# Callback that writes a merged model after every epoch
_merge_callback = SaveMergedCallback(clargs.model_id, sargs.output_dir)

# Supervised fine-tuning trainer: base model + LoRA config, the pre-formatted
# message list as training data, and the custom collator for text+image encoding
trainer = SFTTrainer(
    model=model,
    peft_config=peft_config,
    args=sargs,
    train_dataset=dataset,
    data_collator=collate_fn,
    processing_class=processor,
    callbacks=[_merge_callback],
)


# Start training. Checkpoints are written to sargs.output_dir
# (push_to_hub is disabled, so nothing is uploaded to the Hub).
# The job is likely to be preempted at some point, so resume from the newest
# checkpoint when there is one and start from scratch otherwise. Looking for
# the checkpoint explicitly, rather than catching ValueError around train(),
# means an unrelated ValueError raised during training is not mistaken for
# "no checkpoint" and answered with a restart from scratch.
from transformers.trainer_utils import get_last_checkpoint

last_checkpoint = (
    get_last_checkpoint(sargs.output_dir)
    if os.path.isdir(sargs.output_dir)
    else None
)
trainer.train(resume_from_checkpoint=last_checkpoint)