Granite: Train against a sample of images with transcriptionsΒΆ

Much the same as the script to fine tune SmolVLM.

#!/usr/bin/env python

# Granite model training script adapted for RRR
import os
import random
from PIL import Image
import argparse
from RR_utils.hf import HFlogin
from RR_utils.image import cut_image

HFlogin()  # Connect to Huggingface Hub - only needed for initial model weights download

from rainfall_rescue.utils.pairs import get_index_list, load_pair, csv_to_json

import torch
from torch.utils.data import Dataset

from transformers import AutoProcessor, AutoModelForImageTextToText, TrainerCallback

from peft import LoraConfig, PeftModel
from trl import SFTTrainer

# Text prompts - system and user
from models.smolvlm.prompts import s_prompt, u_prompt

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_id",
    help="Model ID",
    type=str,
    required=False,
    default="ibm-granite/granite-vision-3.3-2b",
)
parser.add_argument(
    "--purpose",
    help="Training or test or neither",
    type=str,
    required=False,
    default="Training",
)
parser.add_argument(
    "--nmax",
    help="Maximum number of training cases to use",
    type=int,
    required=False,
    default=100,
)
parser.add_argument(
    "--fake",
    help="Use fake cases instead of real",
    action="store_true",
    required=False,
    default=False,
)
parser.add_argument(
    "--random_seed",
    help="Control the set of 'random'; choices",
    type=int,
    required=False,
    default=None,
)
parser.add_argument(
    "--epochs",
    help="Number of epochs to train",
    type=int,
    required=False,
    default=3,
)
parser.add_argument(
    "--run_id",
    help="Identifier for this training run",
    type=str,
    required=True,
    default=None,
)
parser.add_argument(
    "--patch_size",
    help="Image patch size (pixels)",
    type=int,
    required=False,
    default=None,
)

clargs = parser.parse_args()

device = "cuda" if torch.cuda.is_available() else "cpu"


# Convert dataset to OAI messages
def format_data(sample):
    # Desired output is JSON ormated csv data
    assistant_message = csv_to_json(sample[2])
    # Cut the image into squares
    img = sample[1]
    if clargs.patch_size is not None:
        blocks = cut_image(img, clargs.patch_size, overlap=0.1)
    else:
        blocks = [img]  # Use the whole image if no patch size is specified

    return {
        "messages": [
            {
                "role": "system",
                "content": [{"type": "text", "text": s_prompt}],
            },
            {
                "role": "user",
                "content": [{"type": "image", "image": block} for block in blocks]
                + [
                    {
                        "type": "text",
                        "text": s_prompt,
                    },
                    {
                        "type": "text",
                        "text": u_prompt,
                    },
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": assistant_message}],
            },
        ],
    }


# Make a training dataset from the RR image/CSV pairs
class RRTrainingDataset(Dataset):
    def __init__(self, max_n=None, seed=None):
        if clargs.purpose.lower() == "test":
            self.labels = get_index_list(
                max_n=max_n, seed=seed, fake=clargs.fake, training=False
            )
        elif clargs.purpose[:5].lower() == "train":
            self.labels = get_index_list(
                max_n=max_n, seed=seed, fake=clargs.fake, training=True
            )
        else:
            self.labels = get_index_list(
                max_n=max_n, seed=seed, fake=clargs.fake, training=None
            )

        if clargs.random_seed is not None:
            torch.manual_seed(clargs.random_seed)
            random.seed(clargs.random_seed)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img, csv = load_pair(self.labels[idx])
        return self.labels[idx], img, csv


dataset = RRTrainingDataset(max_n=clargs.nmax, seed=clargs.random_seed)
# Convert dataset to OAI messages
dataset = [format_data(sample) for sample in dataset]


# Load model and tokenizer
processor = AutoProcessor.from_pretrained(clargs.model_id)
model = AutoModelForImageTextToText.from_pretrained(clargs.model_id).to(device)

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=[
        "lm_head",
        "embed_tokens",
    ],
)
from trl import SFTConfig

sargs = SFTConfig(
    output_dir="%s/%s"
    % (os.getenv("PDIR"), clargs.run_id),  # directory to save and repository id
    num_train_epochs=clargs.epochs,  # number of training epochs
    per_device_train_batch_size=1,  # batch size per device during training
    gradient_accumulation_steps=4,  # number of steps before performing a backward/update pass
    gradient_checkpointing=True,  # use gradient checkpointing to save memory
    optim="adamw_torch_fused",  # use fused adamw optimizer
    logging_steps=5,  # log every 5 steps
    save_strategy="epoch",  # save checkpoint every epoch
    learning_rate=1e-4,  # 2e-4,  # learning rate, based on QLoRA paper
    bf16=True,  # use bfloat16 precision
    max_grad_norm=0.3,  # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,  # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",  # use constant learning rate scheduler
    push_to_hub=False,  # push model to hub
    report_to="tensorboard",  # report metrics to tensorboard
    logging_dir="%s/%s/logs"
    % (os.getenv("PDIR"), clargs.run_id),  # directory to save logs
    gradient_checkpointing_kwargs={
        "use_reentrant": False
    },  # use reentrant checkpointing
    dataset_text_field="",  # need a dummy field for collator
    dataset_kwargs={"skip_prepare_dataset": True},  # important for collator
)
sargs.remove_unused_columns = False  # important for collator


# Create a data collator to encode text and image pairs
# Don't understand this bit - why can't the processor just operate on messages?


def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    image_inputs = []
    # Iterate through each conversation
    for msg in messages:
        # Get content (ensure it's a list)
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]

        # Check each content element for images
        for element in content:
            if isinstance(element, dict) and (
                "image" in element or element.get("type") == "image"
            ):
                # Get the image and convert to RGB
                if "image" in element:
                    image = element["image"]
                else:
                    image = element
                image_inputs.append(image.convert("RGB"))
    return image_inputs


def collate_fn(examples):
    texts = []
    images = []
    for example in examples:
        image_inputs = process_vision_info(example["messages"])
        text = processor.apply_chat_template(
            example["messages"], add_generation_prompt=False, tokenize=False
        )
        texts.append(text.strip())
        images.append(image_inputs)

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # The labels are the input_ids, and we mask the padding tokens and image tokens in the loss computation
    labels = batch["input_ids"].clone()

    assistant_tokens = processor.tokenizer("<|assistant|>", return_tensors="pt")[
        "input_ids"
    ][0]
    eos_token = processor.tokenizer("<|end_of_text|>", return_tensors="pt")[
        "input_ids"
    ][0]

    for i in range(batch["input_ids"].shape[0]):
        apply_loss = False
        for j in range(batch["input_ids"].shape[1]):
            if not apply_loss:
                labels[i][j] = -100
            if (j >= len(assistant_tokens) + 1) and torch.all(
                batch["input_ids"][i][j + 1 - len(assistant_tokens) : j + 1]
                == assistant_tokens
            ):
                apply_loss = True
            if batch["input_ids"][i][j] == eos_token:
                apply_loss = False

    batch["labels"] = labels
    return batch


# Define a callback to save a merged version of the model at the end of each epoch
class SaveMergedCallback(TrainerCallback):
    def __init__(self, base_model_id, out_dir):
        self.base_model_id = base_model_id
        self.out_dir = out_dir

    # Called at end of epoch
    def on_epoch_end(self, args, state, control, **kwargs):
        print("[SaveMergedCallback] on_epoch_end called")
        trainer.save_model()
        torch.cuda.empty_cache()
        model = AutoModelForImageTextToText.from_pretrained(
            clargs.model_id, low_cpu_mem_usage=True
        )
        # Merge LoRA and base model and save
        peft_model = PeftModel.from_pretrained(model, sargs.output_dir)
        merged_model = peft_model.merge_and_unload()
        merged_dir = os.path.join(
            sargs.output_dir, f"merged_epoch_{int(getattr(state,'epoch',0))}"
        )
        os.makedirs(merged_dir, exist_ok=True)
        merged_model.save_pretrained(
            merged_dir, safe_serialization=True, max_shard_size="2GB"
        )
        # Also need to save the processor and tokenizer to make a reusable model
        processor.save_pretrained(merged_dir)
        processor.tokenizer.save_pretrained(merged_dir)
        del merged_model
        del peft_model
        torch.cuda.empty_cache()


trainer = SFTTrainer(
    model=model,
    args=sargs,
    train_dataset=dataset,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
    callbacks=[SaveMergedCallback(clargs.model_id, sargs.output_dir)],
)


# Start training, the model will be automatically saved to the Hub and the output directory
try:
    trainer.train(
        resume_from_checkpoint=True
    )  # Auto restart if possible (job is likely to be preempted at some point)
except ValueError as e:
    trainer.train(resume_from_checkpoint=False)  # No checkpoint, start from scratch

# Save the final model
trainer.save_model()
# free the memory again
del model
del trainer
torch.cuda.empty_cache()


# We trained a LORA model, that's additional weights instead of of the base model.
# We need to merge the LORA weights with the base model weights to make a new base model

# Load Model base model
# model = AutoModelForImageTextToText.from_pretrained(args.model_id, low_cpu_mem_usage=True)

# # Merge LoRA and base model and save
# peft_model = PeftModel.from_pretrained(model, sargs.output_dir)
# merged_model = peft_model.merge_and_unload()
# merged_model.save_pretrained(
#     "merged_model", safe_serialization=True, max_shard_size="2GB"
# )