Experiments in Fine-tuning Llama 3.2 for Classification Using LoRA

An experiment in using LoRA to fine-tune Llama 3.2 into a model specialized for classifying support requests

I started out wanting to play around with fine-tuning Llama 3.2 to see if I could specialize its capabilities for classifying support requests.

After diving in fairly deep, writing a fair amount of code, and researching a few alternative approaches, I would not try this method first, nor would I continue with it before investigating the alternatives.

For starters, Meta provides tooling and documentation for fine-tuning their Llama models. The documentation covers full fine-tuning of the model parameters as well as PEFT (Parameter-Efficient Fine-Tuning) using LoRA/QLoRA. (Keep in mind that you will likely have to request access to the Llama models on Hugging Face in order to train with these methods.) LoRA is particularly ingenious in that it lets you train a model far more efficiently and produces a much smaller "adaptation" of the base model than a full retraining of all its parameters.
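To make the "much smaller adaptation" point concrete, here is a minimal sketch (assuming the Hugging Face transformers and peft libraries, with illustrative hyperparameters that mirror the script further down) of wrapping a base model in a LoRA adapter and checking how few parameters actually get trained:

from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, get_peft_model, TaskType

# Load a base Llama 3.2 checkpoint as a sequence classifier
base = AutoModelForSequenceClassification.from_pretrained(
    "unsloth/Llama-3.2-1B",
    num_labels=9,
)

# LoRA injects small low-rank matrices into the attention projections;
# only those (plus the classification head) are trained
config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    task_type=TaskType.SEQ_CLS,
)
model = get_peft_model(base, config)

# Typically reports well under 1% of the parameters as trainable
model.print_trainable_parameters()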

Once you figure out how to use them, Meta's prebuilt tools would undoubtedly get you better results faster than writing your own.

However, there are alternative, and likely less costly, approaches to classification tasks. One is using a vector database to look up labeled examples similar to the incoming request. This is particularly useful when there are too many examples or variations to fit in the context and you only want to supply the relevant ones.

If I were to implement a classification mechanism more complex than determining whether something is spam or not, I would go the route of a vector database of examples and base the classification on that (see the sketch below). I believe the results for most tasks would be acceptable and the implementation cost much lower.
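As a rough illustration of that alternative, here is a minimal sketch of embedding-based classification. It assumes the sentence-transformers library, a handful of hypothetical labeled examples, and a brute-force cosine search; a real setup would swap in an actual vector database (FAISS, Chroma, pgvector, and so on), but the shape of the approach is the same: embed the labeled examples once, then classify a new request by majority vote over its nearest neighbors.

from collections import Counter

import numpy as np
from sentence_transformers import SentenceTransformer

# Hypothetical labeled examples; in practice these come from your own dataset
examples = [
    ("Where is my order?", "shipping"),
    ("I was charged twice this month", "billing"),
    ("How do I reset my password?", "account"),
]

encoder = SentenceTransformer("all-MiniLM-L6-v2")
example_vectors = encoder.encode(
    [text for text, _ in examples], normalize_embeddings=True
)

def classify(request: str, k: int = 3) -> str:
    # Embed the incoming request and score it against every stored example
    query = encoder.encode([request], normalize_embeddings=True)[0]
    scores = example_vectors @ query  # cosine similarity (vectors are normalized)
    top_k = np.argsort(scores)[::-1][:k]
    # Majority vote over the nearest labeled examples
    votes = Counter(examples[i][1] for i in top_k)
    return votes.most_common(1)[0][0]

print(classify("My package never arrived"))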

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
from datasets import Dataset
import json
from torch.utils.data import DataLoader
from tqdm import tqdm
import argparse
import os

def model_name():
    return "unsloth/Llama-3.2-1B"

def compute_engine():
    return "cuda" if torch.cuda.is_available() else "cpu"

class JsonFile:
    def save(self, filename, json_data):
        with open(self.file_path(filename), 'w') as f:
            json.dump(json_data, f, indent=4)

    def load(self, file_name):
        with open(self.file_path(file_name), 'r') as f:
            return json.load(f)

    def file_path(self, file_name):
        current_dir = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(current_dir, file_name)

class Labels:
    @classmethod
    def instance(cls, json_data):
        unique_labels = sorted(list(set(item['tag'] for item in json_data)))
        json_file = JsonFile()
        label2id = {label: i for i, label in enumerate(unique_labels)}
        id2label = {str(i): label for label, i in label2id.items()}  # string keys, matching the JSON round-trip

        instance = cls(json_file, label2id, id2label)
        instance.persist()
        return instance

    @classmethod
    def from_files(cls):
        json_file = JsonFile()
        label2id = json_file.load('label2id.json')
        id2label = json_file.load('id2label.json')
        return cls(json_file, label2id, id2label)

    def __init__(self, json_file, label2id, id2label):
        self.label2id = label2id
        self.id2label = id2label
        self.json_file = json_file
        self.count = len(label2id)

        print("Label mapping:", self.label2id)
        print("Label mapping (reverse):", self.id2label)

    def persist(self):
        self.json_file.save('label2id.json', self.label2id)
        self.json_file.save('id2label.json', self.id2label)

    def id_from_label(self, label):
        return self.label2id[label]

    def label_from_id(self, id):
        # keys are strings whether built in memory or loaded back from JSON
        return self.id2label[str(id)]


class JsonDataset:
    @classmethod
    def instance(cls):
        json_file = JsonFile()
        data = json_file.load('customer_service_dataset.json')
        labels = Labels.instance(data)
        return cls(data, labels)

    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def prepare(self):
        return Dataset.from_dict({
            'text': [item['question'] for item in self.data],
            'label': [self.labels.id_from_label(item['tag']) for item in self.data]
        })

class Collator:
    def __init__(self, tokenizer, device):
        self.tokenizer = tokenizer
        self.device = device

    def collate(self, batch):
        texts = [item['text'] for item in batch]
        labels = [item['label'] for item in batch]

        # Tokenize with proper padding
        inputs = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=512  # Add reasonable max length
        )

        return {
            'input_ids': inputs['input_ids'].to(self.device),
            'attention_mask': inputs['attention_mask'].to(self.device),
            'labels': torch.tensor(labels, dtype=torch.long).to(self.device)
        }

class TokenizerFactory:
    def create(self, full_model_name: str = model_name()):
        tokenizer = AutoTokenizer.from_pretrained(full_model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.padding_side = "right"  # Ensure right padding
        return tokenizer

class ModelFactory:
    def create(self, full_model_name: str = model_name(), device: str = compute_engine(), tokenizer=None, apply_lora: bool = True):
        # Default arguments are evaluated at definition time, so build the
        # tokenizer lazily instead of constructing one in the method signature
        if tokenizer is None:
            tokenizer = TokenizerFactory().create(full_model_name)

        # Load base model with padding token configuration
        model = AutoModelForSequenceClassification.from_pretrained(
            full_model_name,
            num_labels=9,  # Must match the number of unique tags in the dataset
            pad_token_id=tokenizer.pad_token_id,
        ).to(device)

        # Ensure model knows about padding token
        model.config.pad_token_id = tokenizer.pad_token_id
        model.config.padding_side = "right"

        # For inference, return the plain base model so the trained LoRA adapter
        # can be attached with PeftModel.from_pretrained
        if not apply_lora:
            return model

        # Initialize LoRA config targeting the attention projection modules used by Llama
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
            lora_dropout=0.1,
            bias="none",
            task_type=TaskType.SEQ_CLS
        )

        # Apply LoRA
        model = get_peft_model(model, lora_config)
        return model


class SimpleFineTuner:
    @classmethod
    def create(
        cls,
        device = compute_engine()
    ):
        tokenizer = TokenizerFactory().create()
        model = ModelFactory().create()
        collator = Collator(tokenizer, device)
        json_file = JsonFile()
        return cls(model, tokenizer, device, collator, json_file)


    def __init__(
        self,
        model,
        tokenizer,
        device: str,
        collator,
        json_file
    ):
        self.optimizer = None
        self.train_loader = None
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.collator = collator
        self.json_file = json_file

    def train(
        self,
        dataset: Dataset,
        num_epochs: int = 3,
    ):
        self.model.train()
        for epoch in range(num_epochs):
            total_loss = 0
            progress_bar = tqdm(self._train_loader(dataset), desc=f"Epoch {epoch + 1}/{num_epochs}")
            
            for batch in progress_bar:
                total_loss = self._process_batch(batch, self._optimizer(), progress_bar, total_loss)

            avg_loss = total_loss / len(self._train_loader(dataset))
            print(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

        self.model.save_pretrained(self.json_file.file_path('customer_service_lora_weights'))

    def _process_batch(self, batch, optimizer, progress_bar, total_loss):
        optimizer.zero_grad()
        # Forward pass
        outputs = self.model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels']
        )
        loss = outputs.loss
        total_loss += loss.item()
        # Backward pass
        loss.backward()
        optimizer.step()
        progress_bar.set_postfix({'loss': loss.item()})
        return total_loss

    def _train_loader(self, dataset):
        if self.train_loader is None:
            self.train_loader = DataLoader(
                dataset,
                batch_size=8,
                shuffle=True,
                collate_fn=self.collator.collate
            )

        return self.train_loader

    def _optimizer(self):
        if self.optimizer is None:
            learning_rate: float = 2e-4
            self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)

        return self.optimizer

class LabelPredictor:
    @classmethod
    def instance(cls):
        """Load a saved model and label mappings"""
        model_path = JsonFile().file_path('customer_service_lora_weights')
        tokenizer = TokenizerFactory().create()

        model = PeftModel.from_pretrained(
            # Attach the saved adapter to the plain base model (apply_lora=False)
            # so the adapter weights line up with the module names they were saved under
            ModelFactory().create(tokenizer=tokenizer, apply_lora=False),
            model_path,
            is_trainable=False
        )
        model.eval()
        labels = Labels.from_files()

        return cls(model, tokenizer, labels)

    def __init__(self, model, tokenizer, labels, device: str = compute_engine()):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.labels = labels

    def predict(self, text: str) -> tuple[int, str]:
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(self.device)

        # Get prediction
        with torch.no_grad():
            outputs = self.model(**inputs)
            predicted_class = torch.argmax(outputs.logits, dim=-1).item()

        return predicted_class, self.labels.label_from_id(predicted_class)


def run_predict(text: str):
    label_predictor = LabelPredictor.instance()
    class_id, label = label_predictor.predict(text)
    print("Input text:", text)
    print(f"Prediction: {label} (class ID: {class_id})")

def run_train():
    dataset = JsonDataset.instance().prepare()
    SimpleFineTuner.create().train(dataset)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Customer Service Simple Fine Tuner")
    parser.add_argument('--train', action='store_true', help='Train the model')
    parser.add_argument('--predict', type=str, help='Input text to classify', required=False)
    args = parser.parse_args()

    if args.train:
        run_train()
    else:
        if args.predict is None:
            raise ValueError("Please provide an input for prediction")

        run_predict(args.predict)
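
For reference, assuming the script is saved as something like fine_tune.py (the filename here is arbitrary), training and prediction would be run along these lines:

python fine_tune.py --train
python fine_tune.py --predict "I was charged twice for my subscription"

Training writes the LoRA adapter weights and the label-mapping JSON files next to the script, which is what LabelPredictor relies on when loading.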