Gemma-3-4b: Train against a sample of images with transcriptions
Much the same as the script to fine-tune SmolVLM. The main difference is that Gemma requires image input to be pre-cut into squares, so each rectangular page has to be explicitly subdivided. I don't know the best way to do the subdivision, but the default used by this script works OK - at least after fine-tuning. (I have not checked, but it's probably important to use the same subdivision method for training and for extraction.)
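The subdivision itself is done by cut_image from RR_utils.image, which is not reproduced here. As a rough illustration, a minimal sketch of the kind of overlapping square tiling involved might look like the following (the function name, stride arithmetic and row-major ordering are my assumptions, not the actual RR_utils.image.cut_image implementation):

from PIL import Image

def cut_image_sketch(img: Image.Image, patch_size: int, overlap: float = 0.1) -> list[Image.Image]:
    """Cut a page image into overlapping patch_size x patch_size squares (illustrative sketch only)."""
    step = max(1, int(patch_size * (1 - overlap)))  # stride between tile origins
    xs = list(range(0, max(img.width - patch_size, 0) + 1, step))
    ys = list(range(0, max(img.height - patch_size, 0) + 1, step))
    # Make sure the right and bottom edges are always covered
    if xs[-1] + patch_size < img.width:
        xs.append(img.width - patch_size)
    if ys[-1] + patch_size < img.height:
        ys.append(img.height - patch_size)
    # Row-major order; crop pads with black if the page is smaller than patch_size
    return [img.crop((x, y, x + patch_size, y + patch_size)) for y in ys for x in xs]

Whichever routine is actually used, the point above stands: apply the same one at training and at extraction time. The full training script follows.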
#!/usr/bin/env python
# Gemma model training script adapted for RRR
import os
from PIL import Image
import argparse
from RR_utils.hf import HFlogin
from RR_utils.image import cut_image
HFlogin() # Connect to Huggingface Hub - only needed for initial model weights download
from rainfall_rescue.utils.pairs import get_index_list, load_pair, csv_to_json
import torch
from torch.utils.data import Dataset
from transformers import (
AutoProcessor,
AutoModelForImageTextToText,
TrainerCallback,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer, SFTConfig
# Text prompts - system and user
from models.smolvlm.prompts import s_prompt, u_prompt
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_id",
help="Model ID",
type=str,
required=False,
default="google/gemma-3-4b-it",
)
parser.add_argument(
"--nmax",
help="Maximum number of training cases to use",
type=int,
required=False,
default=100,
)
parser.add_argument(
"--fake",
help="Use fake cases instead of real",
action="store_true",
required=False,
default=False,
)
parser.add_argument(
"--random_seed",
help="Control the set of 'random'; choices",
type=int,
required=False,
default=None,
)
parser.add_argument(
"--epochs",
help="Number of epochs to train",
type=int,
required=False,
default=3,
)
parser.add_argument(
"--run_id",
help="Identifier for this training run",
type=str,
required=True,
default=None,
)
parser.add_argument(
"--patch_size",
help="Image patch size (pixels)",
type=int,
required=False,
default=600,
)
clargs = parser.parse_args()
# Convert dataset to OAI messages
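# Each training sample becomes one conversation:
#   system    -> s_prompt
#   user      -> one image entry per square block, followed by u_prompt
#   assistant -> the JSON-formatted transcription (the training target)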
def format_data(sample):
# Desired output is JSON-formatted CSV data
assistant_message = csv_to_json(sample[2])
# Cut the image into squares
img = sample[1]
if clargs.patch_size is not None:
blocks = cut_image(img, clargs.patch_size, overlap=0.1)
else:
blocks = [img] # Use the whole image if no patch size is specified
return {
"messages": [
{
"role": "system",
"content": [{"type": "text", "text": s_prompt}],
},
{
"role": "user",
"content": [{"type": "image", "image": block} for block in blocks]
+ [
{
"type": "text",
"text": u_prompt,
}
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": assistant_message}],
},
],
}
# Make a training dataset from the RR image/CSV pairs
class RRTrainingDataset(Dataset):
def __init__(self, max_n=None, seed=None):
self.labels = get_index_list(
max_n=max_n, seed=seed, fake=clargs.fake, training=True
)
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
img, csv = load_pair(self.labels[idx])
return self.labels[idx], img, csv
dataset = RRTrainingDataset(max_n=clargs.nmax, seed=clargs.random_seed)
# Convert dataset to OAI messages
dataset = [format_data(sample) for sample in dataset]
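# dataset is now a plain list of {"messages": [...]} dicts - one conversation per image/CSV pair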
# Define model init arguments
model_kwargs = dict(
attn_implementation="eager", # Use "flash_attention_2" when running on Ampere or newer GPU, and 'eager' for older GPUs
torch_dtype=torch.bfloat16, # What torch dtype to use, defaults to auto
device_map="auto", # Let torch decide how to load the model
)
# Load model and tokenizer
model = AutoModelForImageTextToText.from_pretrained(clargs.model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained(clargs.model_id)
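# LoRA setup: adapters are attached to every linear layer ("all-linear"); lm_head and
# embed_tokens are trained in full and saved alongside the adapters (modules_to_save)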
peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.05,
r=16,
bias="none",
target_modules="all-linear",
task_type="CAUSAL_LM",
modules_to_save=[
"lm_head",
"embed_tokens",
],
)
sargs = SFTConfig(
output_dir="%s/%s"
% (os.getenv("PDIR"), clargs.run_id), # directory to save and repository id
num_train_epochs=clargs.epochs, # number of training epochs
per_device_train_batch_size=1, # batch size per device during training
gradient_accumulation_steps=4, # number of steps before performing a backward/update pass
gradient_checkpointing=True, # use gradient checkpointing to save memory
optim="adamw_torch_fused", # use fused adamw optimizer
logging_steps=5, # log every 5 steps
save_strategy="epoch", # save checkpoint every epoch
learning_rate=1e-4,  # learning rate (the QLoRA paper uses 2e-4)
bf16=True, # use bfloat16 precision
max_grad_norm=0.3, # max gradient norm based on QLoRA paper
warmup_ratio=0.03, # warmup ratio based on QLoRA paper
lr_scheduler_type="constant", # use constant learning rate scheduler
push_to_hub=False, # push model to hub
report_to="tensorboard", # report metrics to tensorboard
logging_dir="%s/%s/logs"
% (os.getenv("PDIR"), clargs.run_id), # directory to save logs
gradient_checkpointing_kwargs={
"use_reentrant": False
},  # use non-reentrant checkpointing
dataset_text_field="", # need a dummy field for collator
dataset_kwargs={"skip_prepare_dataset": True}, # important for collator
)
sargs.remove_unused_columns = False # important for collator
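# By default the Trainer drops dataset columns not accepted by the model's forward();
# we keep them because the collator below needs the raw "messages" field.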
# Create a data collator to encode text and image pairs
# Don't understand this bit - why can't the processor just operate on messages?
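# The processor takes the chat-templated text and the PIL images as separate arguments,
# so this helper pulls the images back out of the message dicts before batching.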
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
image_inputs = []
# Iterate through each conversation
for msg in messages:
# Get content (ensure it's a list)
content = msg.get("content", [])
if not isinstance(content, list):
content = [content]
# Check each content element for images
for element in content:
if isinstance(element, dict) and (
"image" in element or element.get("type") == "image"
):
# Get the image and convert to RGB
if "image" in element:
image = element["image"]
else:
image = element
image_inputs.append(image.convert("RGB"))
return image_inputs
def collate_fn(examples):
texts = []
images = []
for example in examples:
image_inputs = process_vision_info(example["messages"])
text = processor.apply_chat_template(
example["messages"], add_generation_prompt=False, tokenize=False
)
texts.append(text.strip())
images.append(image_inputs)
# Tokenize the texts and process the images
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
# The labels are the input_ids, and we mask the padding tokens and image tokens in the loss computation
labels = batch["input_ids"].clone()
# Mask image tokens
image_token_id = [
processor.tokenizer.convert_tokens_to_ids(
processor.tokenizer.special_tokens_map["boi_token"]
)
]
# Mask padding and image tokens (including the Gemma-3 image soft token, id 262144) so they are excluded from the loss
labels[labels == processor.tokenizer.pad_token_id] = -100
labels[labels == image_token_id] = -100
labels[labels == 262144] = -100
batch["labels"] = labels
return batch
# Define a callback to save a merged version of the model at the end of each epoch
class SaveMergedCallback(TrainerCallback):
def __init__(self, base_model_id, out_dir):
self.base_model_id = base_model_id
self.out_dir = out_dir
# Called at end of epoch
def on_epoch_end(self, args, state, control, **kwargs):
print("[SaveMergedCallback] on_epoch_end called")
trainer.save_model()
torch.cuda.empty_cache()
model = AutoModelForImageTextToText.from_pretrained(
clargs.model_id, low_cpu_mem_usage=True
)
# Merge LoRA and base model and save
peft_model = PeftModel.from_pretrained(model, sargs.output_dir)
merged_model = peft_model.merge_and_unload()
merged_dir = os.path.join(
sargs.output_dir, f"merged_epoch_{int(getattr(state,'epoch',0))}"
)
os.makedirs(merged_dir, exist_ok=True)
merged_model.save_pretrained(
merged_dir, safe_serialization=True, max_shard_size="2GB"
)
# Also need to save the processor and tokenizer to make a reusable model
processor.save_pretrained(merged_dir)
processor.tokenizer.save_pretrained(merged_dir)
del merged_model
del peft_model
torch.cuda.empty_cache()
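# Saving a merged full-weight copy each epoch means the fine-tuned model can be loaded
# directly with from_pretrained() at extraction time, without a separate PEFT adapter step.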
trainer = SFTTrainer(
model=model,
args=sargs,
train_dataset=dataset,
peft_config=peft_config,
processing_class=processor,
data_collator=collate_fn,
callbacks=[SaveMergedCallback(clargs.model_id, sargs.output_dir)],
)
# Start training; checkpoints (and the merged model, via the callback) are saved to the output directory
try:
trainer.train(
resume_from_checkpoint=True
) # Auto restart if possible (job is likely to be preempted at some point)
except ValueError as e:
trainer.train(resume_from_checkpoint=False) # No checkpoint, start from scratch