Experiments in Fine-tuning Llama 3.2 for classification using LoRA
I started out wanting to play around with fine-tuning Llama 3.2 to see if I could specialize its capabilities for classifying support requests.
After diving in fairly deep, writing a fair amount of code, and researching a few alternative approaches, I would not try this method first, nor would I continue with it before investigating the alternatives.
For starters, Meta provides tooling and documentation for fine-tuning their Llama models. The documentation covers full fine-tuning of the model parameters as well as PEFT (Parameter-Efficient Fine-Tuning) using LoRA/QLoRA. (Keep in mind that you will likely have to request access to Llama models on Hugging Face in order to train using these methods.) LoRA is particularly ingenious in that it lets you train a model much more efficiently and produces a much smaller "adaptation" of the base model than a full retraining of all its parameters.
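To make the "much smaller adaptation" point concrete, here is a minimal sketch. It mirrors the configuration the full script below uses, but the exact target modules and rank are choices rather than requirements, so treat it as illustrative only.

# Wrap a base classifier with a LoRA adapter and report how few parameters actually train.
from transformers import AutoModelForSequenceClassification
from peft import get_peft_model, LoraConfig, TaskType

base = AutoModelForSequenceClassification.from_pretrained("unsloth/Llama-3.2-1B", num_labels=9)
lora = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"], task_type=TaskType.SEQ_CLS)
model = get_peft_model(base, lora)
model.print_trainable_parameters()  # typically well under 1% of the base model's parameters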
Using Meta's prebuilt tools would undoubtedly lead to better results faster than writing your own tooling, once you figure out how to use them.
However, there are alternative, and likely less costly, approaches for classification tasks. One is to use a vector database to look up sample classifications from a store of labeled examples. This is particularly useful when there are too many examples or variations to fit in the context and you only want to provide the relevant ones.
If I were to implement a classification mechanism more complex than spam/not-spam, I would go down the route of using a vector database of examples and basing the classification on those lookups; a rough sketch of the idea follows. I believe the results for most tasks would be acceptable, and the cost of the implementation would be much lower.
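The sketch below is only illustrative: it uses sentence-transformers and a brute-force cosine search where a real vector database would handle storage and nearest-neighbor lookup, and the model name, example data, and k are assumptions rather than anything I benchmarked.

# Classify by majority vote over the most similar labeled examples.
# Model name, examples, and k are illustrative assumptions.
from collections import Counter
import numpy as np
from sentence_transformers import SentenceTransformer

encoder = SentenceTransformer("all-MiniLM-L6-v2")

examples = [
    ("How do I reset my password?", "account"),
    ("My invoice amount looks wrong", "billing"),
    ("The app crashes when I upload a file", "bug"),
]
example_vectors = encoder.encode([text for text, _ in examples], normalize_embeddings=True)

def classify(text: str, k: int = 3) -> str:
    query = encoder.encode([text], normalize_embeddings=True)[0]
    scores = example_vectors @ query        # cosine similarity, since the vectors are normalized
    top_k = np.argsort(scores)[::-1][:k]    # indices of the k most similar stored examples
    votes = Counter(examples[i][1] for i in top_k)
    return votes.most_common(1)[0][0]       # majority label among the nearest examples

print(classify("I can't log into my account"))

For reference, the LoRA fine-tuning script I ended up with follows.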
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
from datasets import Dataset
import json
from torch.utils.data import DataLoader
from tqdm import tqdm
import argparse
import os
def model_name():
return "unsloth/Llama-3.2-1B"
def compute_engine():
return "cuda" if torch.cuda.is_available() else "cpu"
class JsonFile:
def save(self, filename, json_data):
with open(self.file_path(filename), 'w') as f:
json.dump(json_data, f, indent=4)
def load(self, file_name):
with open(self.file_path(file_name), 'r') as f:
return json.load(f)
def file_path(self, file_name):
current_dir = os.path.dirname(os.path.abspath(__file__))
return os.path.join(current_dir, file_name)
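# Maps support-request tags to integer class ids and back, persisting the mappings as JSON so prediction runs can reload them.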
class Labels:
@classmethod
def instance(cls, json_data):
unique_labels = sorted(list(set(item['tag'] for item in json_data)))
json_file = JsonFile()
label2id = {label: i for i, label in enumerate(unique_labels)}
        # Use string keys so lookups behave the same before and after the mapping round-trips through JSON
        id2label = {str(i): label for label, i in label2id.items()}
instance = cls(json_file, label2id, id2label)
instance.persist()
return instance
@classmethod
def from_files(cls):
json_file = JsonFile()
label2id = json_file.load('label2id.json')
id2label = json_file.load('id2label.json')
return cls(json_file, label2id, id2label)
def __init__(self, json_file, label2id, id2label):
self.label2id = label2id
self.id2label = id2label
self.json_file = json_file
self.count = len(label2id)
print("Label mapping:", self.label2id)
print("Label mapping (reverse):", self.id2label)
def persist(self):
self.json_file.save('label2id.json', self.label2id)
self.json_file.save('id2label.json', self.id2label)
def id_from_label(self, label):
return self.label2id[label]
def label_from_id(self, id):
return self.id2label[str(id)]
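# Loads the raw customer_service_dataset.json examples and turns them into a HuggingFace Dataset of text/label-id pairs.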
class JsonDataset:
@classmethod
def instance(cls):
json_file = JsonFile()
data = json_file.load('customer_service_dataset.json')
labels = Labels.instance(data)
return cls(data, labels)
def __init__(self, data, labels):
self.data = data
self.labels = labels
    def prepare(self):
return Dataset.from_dict({
'text': [item['question'] for item in self.data],
'label': [self.labels.id_from_label(item['tag']) for item in self.data]
})
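# Collates raw examples into padded, tokenized tensors placed on the training device.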
class Collator:
def __init__(self, tokenizer, device):
self.tokenizer = tokenizer
self.device = device
def collate(self, batch):
texts = [item['text'] for item in batch]
labels = [item['label'] for item in batch]
# Tokenize with proper padding
inputs = self.tokenizer(
texts,
padding=True,
truncation=True,
return_tensors="pt",
            max_length=512  # Cap sequence length to keep memory bounded
)
return {
'input_ids': inputs['input_ids'].to(self.device),
'attention_mask': inputs['attention_mask'].to(self.device),
'labels': torch.tensor(labels, dtype=torch.long).to(self.device)
}
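# Builds the tokenizer, reusing the EOS token for padding since Llama tokenizers ship without a pad token.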
class TokenizerFactory:
def create(self, full_model_name: str = model_name()):
tokenizer = AutoTokenizer.from_pretrained(full_model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right" # Ensure right padding
return tokenizer
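# Loads the base model for sequence classification and wraps it with a LoRA adapter.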
class ModelFactory:
    def create(self, full_model_name: str = model_name(), device: str = compute_engine(), tokenizer=None):
        # Default arguments are evaluated when the class is defined, so build the tokenizer lazily here instead
        if tokenizer is None:
            tokenizer = TokenizerFactory().create()
# Load base model with padding token configuration
model = AutoModelForSequenceClassification.from_pretrained(
full_model_name,
            num_labels=9,  # Number of customer service categories; must match the unique tags in the dataset
pad_token_id=tokenizer.pad_token_id,
).to(device)
# Ensure model knows about padding token
model.config.pad_token_id = tokenizer.pad_token_id
model.config.padding_side = "right"
# Initialize LoRA config with correct target modules for Llama
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"], # Updated for Llama
lora_dropout=0.1,
bias="none",
task_type=TaskType.SEQ_CLS
)
# Apply LoRA
model = get_peft_model(model, lora_config)
return model
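# Minimal training loop: AdamW over the LoRA-wrapped model, saving the adapter weights when training finishes.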
class SimpleFineTuner:
@classmethod
def create(
cls,
device = compute_engine()
):
tokenizer = TokenizerFactory().create()
model = ModelFactory().create()
collator = Collator(tokenizer, device)
json_file = JsonFile()
return cls(model, tokenizer, device, collator, json_file)
def __init__(
self,
model,
tokenizer,
device: str,
collator,
json_file
):
self.optimizer = None
self.train_loader = None
self.model = model
self.tokenizer = tokenizer
self.device = device
self.collator = collator
self.json_file = json_file
def train(
self,
dataset: Dataset,
num_epochs: int = 3,
):
self.model.train()
for epoch in range(num_epochs):
total_loss = 0
progress_bar = tqdm(self._train_loader(dataset), desc=f"Epoch {epoch + 1}/{num_epochs}")
for batch in progress_bar:
total_loss = self._process_batch(batch, self._optimizer(), progress_bar, total_loss)
avg_loss = total_loss / len(self._train_loader(dataset))
print(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {avg_loss:.4f}")
self.model.save_pretrained(self.json_file.file_path('customer_service_lora_weights'))
def _process_batch(self, batch, optimizer, progress_bar, total_loss):
optimizer.zero_grad()
# Forward pass
outputs = self.model(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
labels=batch['labels']
)
loss = outputs.loss
total_loss += loss.item()
# Backward pass
loss.backward()
optimizer.step()
progress_bar.set_postfix({'loss': loss.item()})
return total_loss
def _train_loader(self, dataset):
if self.train_loader is None:
self.train_loader = DataLoader(
dataset,
batch_size=8,
shuffle=True,
collate_fn=self.collator.collate
)
return self.train_loader
    def _optimizer(self):
if self.optimizer is None:
learning_rate: float = 2e-4
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)
return self.optimizer
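# Reloads the saved LoRA adapter weights and maps predicted class ids back to tag names.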
class LabelPredictor:
@classmethod
def instance(cls):
"""Load a saved model and label mappings"""
model_path = JsonFile().file_path('customer_service_lora_weights')
tokenizer = TokenizerFactory().create()
model = PeftModel.from_pretrained(
ModelFactory().create(tokenizer=tokenizer),
model_path,
is_trainable=False
)
model.eval()
labels = Labels.from_files()
return cls(model, tokenizer, labels)
def __init__(self, model, tokenizer, labels, device: str = compute_engine()):
self.model = model
self.tokenizer = tokenizer
self.device = device
self.labels = labels
def predict(self, text: str) -> tuple[int, str]:
inputs = self.tokenizer(
text,
return_tensors="pt",
padding=True,
truncation=True
).to(self.device)
# Get prediction
with torch.no_grad():
outputs = self.model(**inputs)
predicted_class = torch.argmax(outputs.logits, dim=-1).item()
return predicted_class, self.labels.label_from_id(predicted_class)
def run_predict(text: str):
    label_predictor = LabelPredictor.instance()
    class_id, label = label_predictor.predict(text)
    print("Input text:", text)
    print(f"Prediction: {label} (class ID: {class_id})")
def run_train():
dataset = JsonDataset.instance().prepare()
SimpleFineTuner.create().train(dataset)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Customer Service Simple Fine Tuner")
parser.add_argument('--train', action='store_true', help='Train the model')
    parser.add_argument('--predict', type=str, help='Text to classify', required=False)
args = parser.parse_args()
if args.train:
run_train()
else:
if args.predict is None:
raise ValueError("Please provide an input for prediction")
        run_predict(args.predict)