Master Supervised Fine-Tuning (SFT) techniques to adapt open-source Large Language Models to specific tasks, domains, and use cases. Learn how to prepare data, optimize training, and achieve stronger performance than the base model on targeted applications.
Adapt general-purpose models to excel at specific tasks and domains with targeted training data.
Achieve excellent results with thousands of examples rather than billions, making customization accessible.
Shape model outputs to follow specific formats, styles, and behavioral patterns for your use case.
Incorporate specialized knowledge and terminology from medical, legal, financial, and other domains.
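For illustration, here is the kind of instruction record the code below assumes (an Alpaca-style dictionary with instruction, input, and output fields, as used by prepare_instruction_dataset), together with the prompt it is rendered into:

# Example instruction record (Alpaca-style) assumed by the pipelines below
example = {
    "instruction": "Summarize the following text.",
    "input": "Large language models can be adapted to new tasks with supervised fine-tuning.",
    "output": "Supervised fine-tuning adapts a pretrained LLM using labeled instruction-response pairs."
}
# Rendered training text (the output plus an EOS token is appended after the prompt):
# ### Instruction:
# Summarize the following text.
#
# ### Input:
# Large language models can be adapted to new tasks with supervised fine-tuning.
#
# ### Response:
# Supervised fine-tuning adapts a pretrained LLM using labeled instruction-response pairs.<EOS>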
# Complete SFT implementation with Hugging Face Transformers
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling,
get_linear_schedule_with_warmup,
TrainerCallback
)
from datasets import Dataset, load_dataset
import json
from typing import Dict, List
import numpy as np
class SupervisedFineTuner:
def __init__(self, model_name: str, max_length: int = 2048):
self.model_name = model_name
self.max_length = max_length
self.tokenizer = None
self.model = None
def setup_model_and_tokenizer(self):
"""Initialize model and tokenizer for SFT."""
# Load tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
trust_remote_code=True,
use_fast=True
)
# Set special tokens
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Load model
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
use_cache=False # Disable cache for training
)
# Enable gradient checkpointing for memory efficiency
self.model.gradient_checkpointing_enable()
print(f"Model loaded: {self.model_name}")
print(f"Vocabulary size: {len(self.tokenizer)}")
print(f"Model parameters: {self.model.num_parameters():,}")
def prepare_instruction_dataset(self, data: List[Dict]) -> Dataset:
"""Prepare instruction-following dataset for SFT."""
def format_instruction_example(example):
"""Format a single instruction example."""
instruction = example.get('instruction', '')
input_text = example.get('input', '')
output = example.get('output', '')
# Create formatted prompt
if input_text:
prompt = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n"
else:
prompt = f"### Instruction:\n{instruction}\n\n### Response:\n"
# Full text for training
full_text = prompt + output + self.tokenizer.eos_token
return {
'text': full_text,
'prompt': prompt,
'response': output
}
# Format all examples
formatted_data = [format_instruction_example(item) for item in data]
return Dataset.from_list(formatted_data)
def prepare_conversation_dataset(self, data: List[Dict]) -> Dataset:
"""Prepare conversational dataset for SFT."""
def format_conversation(example):
"""Format a conversation example."""
conversation = example.get('conversation', [])
formatted_text = ""
for turn in conversation:
role = turn.get('role', 'user')
content = turn.get('content', '')
if role == 'user':
formatted_text += f"Human: {content}\n\n"
elif role == 'assistant':
formatted_text += f"Assistant: {content}\n\n"
# Add EOS token
formatted_text += self.tokenizer.eos_token
return {'text': formatted_text}
# Format all conversations
formatted_data = [format_conversation(item) for item in data]
return Dataset.from_list(formatted_data)
def tokenize_dataset(self, dataset: Dataset) -> Dataset:
"""Tokenize dataset for training."""
def tokenize_function(examples):
# Tokenize texts
tokenized = self.tokenizer(
examples["text"],
truncation=True,
padding=False,
max_length=self.max_length,
return_overflowing_tokens=False,
)
# For causal LM, labels are the same as input_ids
tokenized["labels"] = tokenized["input_ids"].copy()
return tokenized
# Apply tokenization
tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
remove_columns=dataset.column_names,
desc="Tokenizing dataset"
)
# Filter out examples that are too long
original_size = len(tokenized_dataset)
tokenized_dataset = tokenized_dataset.filter(
lambda x: len(x["input_ids"]) <= self.max_length
)
final_size = len(tokenized_dataset)
print(f"Dataset size: {original_size} -> {final_size} examples")
return tokenized_dataset
def create_data_collator(self):
"""Create data collator for training."""
return DataCollatorForLanguageModeling(
tokenizer=self.tokenizer,
mlm=False, # Not masked language modeling
pad_to_multiple_of=8, # For efficiency on modern GPUs
)
def train(
self,
train_dataset: Dataset,
eval_dataset: Dataset = None,
output_dir: str = "./sft_results",
num_epochs: int = 3,
batch_size: int = 4,
learning_rate: float = 5e-5,
warmup_ratio: float = 0.03,
save_steps: int = 500,
logging_steps: int = 10,
eval_steps: int = 500,
):
"""Train the model with SFT."""
# Calculate total training steps
total_steps = (len(train_dataset) // batch_size) * num_epochs
warmup_steps = int(total_steps * warmup_ratio)
# Training arguments
training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=num_epochs,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
gradient_accumulation_steps=1,
learning_rate=learning_rate,
weight_decay=0.01,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-8,
max_grad_norm=1.0,
warmup_steps=warmup_steps,
lr_scheduler_type="linear",
logging_steps=logging_steps,
save_steps=save_steps,
eval_steps=eval_steps if eval_dataset else None,
evaluation_strategy="steps" if eval_dataset else "no",
save_strategy="steps",
load_best_model_at_end=True if eval_dataset else False,
metric_for_best_model="eval_loss" if eval_dataset else None,
greater_is_better=False,
report_to="none", # Disable wandb/tensorboard
dataloader_pin_memory=False,
gradient_checkpointing=True,
bf16=True, # Use bfloat16 for stability
remove_unused_columns=False,
push_to_hub=False,
)
# Create trainer
trainer = Trainer(
model=self.model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=self.create_data_collator(),
tokenizer=self.tokenizer,
)
# Add a custom callback for monitoring GPU memory
# (must subclass transformers.TrainerCallback and use its hook signature)
class TrainingCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 100 == 0 and torch.cuda.is_available():
            memory_allocated = torch.cuda.memory_allocated() / 1024**3
            memory_reserved = torch.cuda.memory_reserved() / 1024**3
            print(f"Step {state.global_step}: "
                  f"Memory Allocated: {memory_allocated:.2f}GB, "
                  f"Reserved: {memory_reserved:.2f}GB")
trainer.add_callback(TrainingCallback())
# Start training
print(f"Starting SFT training...")
print(f"Training examples: {len(train_dataset)}")
if eval_dataset:
print(f"Evaluation examples: {len(eval_dataset)}")
print(f"Total training steps: {total_steps}")
print(f"Warmup steps: {warmup_steps}")
trainer.train()
# Save final model
trainer.save_model()
self.tokenizer.save_pretrained(output_dir)
print(f"Training completed! Model saved to {output_dir}")
return trainer
def evaluate_model(self, test_prompts: List[str], max_new_tokens: int = 200):
"""Evaluate the fine-tuned model on test prompts."""
print("\nEvaluating fine-tuned model:")
print("=" * 60)
self.model.eval()
for i, prompt in enumerate(test_prompts, 1):
print(f"\nTest {i}:")
print(f"Prompt: {prompt}")
# Tokenize input
inputs = self.tokenizer(prompt, return_tensors="pt")
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
# Generate response
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=0.7,
top_p=0.9,
pad_token_id=self.tokenizer.eos_token_id,
eos_token_id=self.tokenizer.eos_token_id,
)
# Decode response
response = self.tokenizer.decode(
outputs[0][inputs['input_ids'].shape[1]:],
skip_special_tokens=True
)
print(f"Response: {response}")
print("-" * 40)
# Example usage and data preparation
def prepare_sample_instruction_data():
"""Prepare sample instruction-following data."""
sample_data = [
{
"instruction": "Explain the concept of machine learning in simple terms.",
"input": "",
"output": "Machine learning is a type of artificial intelligence where computers learn to make predictions or decisions by analyzing patterns in data, rather than being explicitly programmed for each task. It's like teaching a computer to recognize patterns the same way humans learn from experience."
},
{
"instruction": "Write a Python function to calculate the factorial of a number.",
"input": "",
"output": "def factorial(n):\n if n == 0 or n == 1:\n return 1\n else:\n return n * factorial(n - 1)\n\n# Example usage:\n# print(factorial(5)) # Output: 120"
},
{
"instruction": "Summarize the following text.",
"input": "Artificial intelligence (AI) is intelligence demonstrated by machines, in contrast to the natural intelligence displayed by humans and animals. Leading AI textbooks define the field as the study of intelligent agents: any device that perceives its environment and takes actions that maximize its chance of successfully achieving its goals.",
"output": "AI refers to machine intelligence that enables devices to perceive their environment and take goal-oriented actions, distinguishing it from natural intelligence found in humans and animals."
},
{
"instruction": "Translate the following English text to French.",
"input": "Hello, how are you today?",
"output": "Bonjour, comment allez-vous aujourd'hui ?"
},
{
"instruction": "Generate a creative story beginning with the given sentence.",
"input": "The old lighthouse stood alone on the rocky cliff.",
"output": "The old lighthouse stood alone on the rocky cliff, its weathered walls holding secrets of countless storms. Sarah climbed the spiral staircase, each step echoing with memories of the lighthouse keeper who had vanished mysteriously fifty years ago. At the top, she discovered a hidden journal that would change everything she thought she knew about her grandfather's disappearance."
}
]
return sample_data
# Main training script
if __name__ == "__main__":
# Initialize fine-tuner
model_name = "mistralai/Mistral-7B-v0.1" # or "meta-llama/Llama-2-7b-hf"
fine_tuner = SupervisedFineTuner(model_name, max_length=2048)
# Setup model and tokenizer
fine_tuner.setup_model_and_tokenizer()
# Prepare training data
print("Preparing training data...")
sample_data = prepare_sample_instruction_data()
# Create dataset
dataset = fine_tuner.prepare_instruction_dataset(sample_data)
tokenized_dataset = fine_tuner.tokenize_dataset(dataset)
# Split into train/eval (80/20)
train_size = int(0.8 * len(tokenized_dataset))
eval_size = len(tokenized_dataset) - train_size
train_dataset = tokenized_dataset.select(range(train_size))
eval_dataset = tokenized_dataset.select(range(train_size, train_size + eval_size))
print(f"Training examples: {len(train_dataset)}")
print(f"Evaluation examples: {len(eval_dataset)}")
# Start training
trainer = fine_tuner.train(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
output_dir="./sft_mistral_7b",
num_epochs=3,
batch_size=2, # Adjust based on your GPU memory
learning_rate=5e-5,
warmup_ratio=0.03,
save_steps=100,
logging_steps=10,
eval_steps=50,
)
# Test the fine-tuned model
test_prompts = [
"### Instruction:\nExplain quantum computing in simple terms.\n\n### Response:\n",
"### Instruction:\nWrite a Python function to find the largest number in a list.\n\n### Response:\n",
"### Instruction:\nWhat are the benefits of renewable energy?\n\n### Response:\n"
]
fine_tuner.evaluate_model(test_prompts, max_new_tokens=150)
print("\nSFT training completed successfully!")# Comprehensive data preparation utilities for SFT
import json
import re
import random
from typing import List, Dict, Tuple
from collections import Counter
import pandas as pd
from datasets import Dataset, load_dataset
import hashlib
class SFTDataProcessor:
def __init__(self):
self.processed_data = []
self.stats = {}
def load_alpaca_format(self, file_path: str) -> List[Dict]:
"""Load data in Alpaca format."""
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
# Validate format
required_keys = ['instruction', 'output']
valid_data = []
for item in data:
if all(key in item for key in required_keys):
valid_data.append({
'instruction': item['instruction'].strip(),
'input': item.get('input', '').strip(),
'output': item['output'].strip()
})
print(f"Loaded {len(valid_data)}/{len(data)} valid examples")
return valid_data
def load_conversational_format(self, file_path: str) -> List[Dict]:
"""Load conversational data (ChatML-like format)."""
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
processed_conversations = []
for conversation in data:
if 'messages' in conversation:
messages = conversation['messages']
formatted_conversation = []
for message in messages:
if 'role' in message and 'content' in message:
formatted_conversation.append({
'role': message['role'],
'content': message['content'].strip()
})
if len(formatted_conversation) >= 2: # At least one exchange
processed_conversations.append({
'conversation': formatted_conversation
})
print(f"Loaded {len(processed_conversations)} conversations")
return processed_conversations
def deduplicate_data(self, data: List[Dict], method: str = 'exact') -> List[Dict]:
"""Remove duplicate examples from dataset."""
if method == 'exact':
# Exact string matching
seen = set()
deduplicated = []
for item in data:
# Create hash of instruction + input + output
content = item['instruction'] + item.get('input', '') + item['output']
content_hash = hashlib.md5(content.encode()).hexdigest()
if content_hash not in seen:
seen.add(content_hash)
deduplicated.append(item)
elif method == 'fuzzy':
# Fuzzy matching based on similarity
from difflib import SequenceMatcher
deduplicated = []
threshold = 0.9
for item in data:
is_duplicate = False
content = item['instruction'] + ' ' + item.get('input', '') + ' ' + item['output']
for existing in deduplicated:
existing_content = existing['instruction'] + ' ' + existing.get('input', '') + ' ' + existing['output']
similarity = SequenceMatcher(None, content, existing_content).ratio()
if similarity > threshold:
is_duplicate = True
break
if not is_duplicate:
deduplicated.append(item)
print(f"Deduplication: {len(data)} -> {len(deduplicated)} examples")
return deduplicated
def filter_by_quality(self, data: List[Dict]) -> List[Dict]:
"""Filter data based on quality criteria."""
filtered_data = []
for item in data:
instruction = item['instruction']
output = item['output']
# Quality checks
checks = [
len(instruction.strip()) >= 10, # Minimum instruction length
len(output.strip()) >= 5, # Minimum output length
len(output.split()) <= 500, # Maximum output length
not self._contains_placeholder(instruction, output),
not self._contains_inappropriate_content(instruction, output),
self._is_coherent_response(instruction, output)
]
if all(checks):
filtered_data.append(item)
print(f"Quality filtering: {len(data)} -> {len(filtered_data)} examples")
return filtered_data
def _contains_placeholder(self, instruction: str, output: str) -> bool:
"""Check if text contains placeholder content."""
placeholders = ['[PLACEHOLDER]', 'TODO', 'FIXME', '...', 'Lorem ipsum']
text = (instruction + ' ' + output).lower()
return any(placeholder.lower() in text for placeholder in placeholders)
def _contains_inappropriate_content(self, instruction: str, output: str) -> bool:
"""Basic check for inappropriate content."""
# Simple keyword-based filtering (expand as needed)
inappropriate_keywords = ['hate', 'violence', 'explicit'] # Simplified list
text = (instruction + ' ' + output).lower()
return any(keyword in text for keyword in inappropriate_keywords)
def _is_coherent_response(self, instruction: str, output: str) -> bool:
"""Check if the response is coherent with the instruction."""
# Simple heuristics (can be improved with more sophisticated methods)
# Check if output is not just repeating the instruction
if instruction.lower() in output.lower() and len(output) < len(instruction) * 1.5:
return False
# Check for minimum complexity
if len(output.split()) < 3:
return False
return True
def augment_data(self, data: List[Dict], augmentation_factor: float = 0.2) -> List[Dict]:
"""Augment dataset with variations."""
augmented_data = data.copy()
num_to_augment = int(len(data) * augmentation_factor)
# Simple paraphrasing (in practice, use more sophisticated methods)
paraphrasing_patterns = [
(r"Explain (.+)", r"Describe \1"),
(r"What is (.+)?", r"Can you explain \1?"),
(r"How do I (.+)?", r"What's the way to \1?"),
(r"Write (.+)", r"Create \1"),
]
for _ in range(num_to_augment):
original = random.choice(data)
augmented = original.copy()
# Try to paraphrase the instruction
for pattern, replacement in paraphrasing_patterns:
if re.search(pattern, augmented['instruction'], re.IGNORECASE):
augmented['instruction'] = re.sub(
pattern, replacement, augmented['instruction'], flags=re.IGNORECASE
)
break
augmented_data.append(augmented)
print(f"Data augmentation: {len(data)} -> {len(augmented_data)} examples")
return augmented_data
def analyze_dataset(self, data: List[Dict]) -> Dict:
"""Analyze dataset characteristics."""
analysis = {
'total_examples': len(data),
'avg_instruction_length': 0,
'avg_output_length': 0,
'instruction_length_distribution': [],
'output_length_distribution': [],
'common_instruction_patterns': [],
}
instruction_lengths = []
output_lengths = []
instruction_starts = []
for item in data:
inst_len = len(item['instruction'].split())
out_len = len(item['output'].split())
instruction_lengths.append(inst_len)
output_lengths.append(out_len)
# Extract instruction patterns
first_words = ' '.join(item['instruction'].split()[:3]).lower()
instruction_starts.append(first_words)
analysis['avg_instruction_length'] = sum(instruction_lengths) / len(instruction_lengths)
analysis['avg_output_length'] = sum(output_lengths) / len(output_lengths)
# Length distributions
analysis['instruction_length_distribution'] = {
'min': min(instruction_lengths),
'max': max(instruction_lengths),
'median': sorted(instruction_lengths)[len(instruction_lengths)//2]
}
analysis['output_length_distribution'] = {
'min': min(output_lengths),
'max': max(output_lengths),
'median': sorted(output_lengths)[len(output_lengths)//2]
}
# Common patterns
pattern_counts = Counter(instruction_starts)
analysis['common_instruction_patterns'] = pattern_counts.most_common(10)
return analysis
def create_balanced_dataset(self, data: List[Dict], categories: List[str] = None) -> List[Dict]:
"""Create a balanced dataset across different categories."""
if categories is None:
# Auto-detect categories based on instruction patterns
categories = self._auto_detect_categories(data)
# Categorize examples
categorized_data = {cat: [] for cat in categories}
uncategorized = []
for item in data:
instruction = item['instruction'].lower()
categorized = False
for category in categories:
if category.lower() in instruction:
categorized_data[category].append(item)
categorized = True
break
if not categorized:
uncategorized.append(item)
# Balance categories
min_category_size = min(len(examples) for examples in categorized_data.values() if examples)
balanced_data = []
for category, examples in categorized_data.items():
if examples:
# Sample from each category
selected = random.sample(examples, min(len(examples), min_category_size))
balanced_data.extend(selected)
# Add some uncategorized examples
if uncategorized:
additional_size = len(balanced_data) // 4 # 25% uncategorized
selected_uncategorized = random.sample(
uncategorized,
min(len(uncategorized), additional_size)
)
balanced_data.extend(selected_uncategorized)
print(f"Balanced dataset: {len(data)} -> {len(balanced_data)} examples")
return balanced_data
def _auto_detect_categories(self, data: List[Dict]) -> List[str]:
"""Auto-detect common categories in the dataset."""
# Common instruction types
patterns = [
'explain', 'describe', 'write', 'create', 'generate',
'translate', 'summarize', 'analyze', 'compare', 'define'
]
detected_categories = []
for pattern in patterns:
count = sum(1 for item in data if pattern in item['instruction'].lower())
if count >= 5: # Minimum threshold
detected_categories.append(pattern)
return detected_categories[:10] # Limit to top 10 categories
def export_processed_data(self, data: List[Dict], output_path: str, format: str = 'json'):
"""Export processed data in specified format."""
if format == 'json':
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(data, f, indent=2, ensure_ascii=False)
elif format == 'jsonl':
with open(output_path, 'w', encoding='utf-8') as f:
for item in data:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
elif format == 'csv':
df = pd.DataFrame(data)
df.to_csv(output_path, index=False)
print(f"Data exported to {output_path} in {format} format")
# Example usage
if __name__ == "__main__":
processor = SFTDataProcessor()
# Sample data for demonstration
sample_data = [
{
"instruction": "Explain machine learning",
"input": "",
"output": "Machine learning is a subset of AI that enables computers to learn from data."
},
{
"instruction": "Write a Python function to add two numbers",
"input": "",
"output": "def add(a, b):\n return a + b"
},
# Add more examples...
]
print("Original dataset analysis:")
analysis = processor.analyze_dataset(sample_data)
for key, value in analysis.items():
if key != 'common_instruction_patterns':
print(f"{key}: {value}")
# Process the data
print("\nProcessing data...")
# Deduplicate
deduplicated = processor.deduplicate_data(sample_data)
# Filter by quality
filtered = processor.filter_by_quality(deduplicated)
# Augment data
augmented = processor.augment_data(filtered, augmentation_factor=0.3)
# Create balanced dataset
balanced = processor.create_balanced_dataset(augmented)
print("\nFinal dataset analysis:")
final_analysis = processor.analyze_dataset(balanced)
for key, value in final_analysis.items():
if key != 'common_instruction_patterns':
print(f"{key}: {value}")
# Export processed data
processor.export_processed_data(balanced, "processed_sft_data.json")
print("\nData processing completed!")# Advanced SFT techniques implementation
import torch
import torch.nn as nn
from transformers import (
AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
Trainer, get_cosine_schedule_with_warmup
)
from torch.optim import AdamW
import numpy as np
from typing import Dict, List, Optional
import wandb
from torch.cuda.amp import GradScaler, autocast
class AdvancedSFTTrainer:
def __init__(self, model_name: str, use_deepspeed: bool = False):
self.model_name = model_name
self.use_deepspeed = use_deepspeed
self.model = None
self.tokenizer = None
self.scaler = GradScaler() if torch.cuda.is_available() else None
def setup_model_with_optimizations(self,
gradient_checkpointing: bool = True,
mixed_precision: bool = True):
"""Setup model with memory and training optimizations."""
# Load tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Load model with optimizations
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.bfloat16 if mixed_precision else torch.float32,
device_map="auto" if not self.use_deepspeed else None,
use_cache=False, # Disable for training
)
if gradient_checkpointing:
self.model.gradient_checkpointing_enable()
print(f"Model loaded with optimizations:")
print(f"- Gradient checkpointing: {gradient_checkpointing}")
print(f"- Mixed precision: {mixed_precision}")
print(f"- DeepSpeed: {self.use_deepspeed}")
def create_custom_optimizer(self,
learning_rate: float = 5e-5,
weight_decay: float = 0.01,
use_layer_wise_lr: bool = False) -> AdamW:
"""Create optimized AdamW optimizer with optional layer-wise learning rates."""
if use_layer_wise_lr:
# Different learning rates for different layers
parameter_groups = []
# Embedding layers - lower LR
embedding_params = []
for name, param in self.model.named_parameters():
if 'embed' in name or 'wte' in name or 'wpe' in name:
embedding_params.append(param)
if embedding_params:
parameter_groups.append({
'params': embedding_params,
'lr': learning_rate * 0.1, # 10x lower
'weight_decay': weight_decay
})
# Output layers - higher LR
output_params = []
for name, param in self.model.named_parameters():
if 'lm_head' in name or 'output' in name:
output_params.append(param)
if output_params:
parameter_groups.append({
'params': output_params,
'lr': learning_rate * 2.0, # 2x higher
'weight_decay': weight_decay
})
# All other parameters - standard LR
other_params = []
embedding_ids = {id(p) for p in embedding_params}
output_ids = {id(p) for p in output_params}
for param in self.model.parameters():
if id(param) not in embedding_ids and id(param) not in output_ids:
other_params.append(param)
if other_params:
parameter_groups.append({
'params': other_params,
'lr': learning_rate,
'weight_decay': weight_decay
})
optimizer = AdamW(parameter_groups, betas=(0.9, 0.999), eps=1e-8)
else:
# Standard optimizer
optimizer = AdamW(
self.model.parameters(),
lr=learning_rate,
weight_decay=weight_decay,
betas=(0.9, 0.999),
eps=1e-8
)
return optimizer
def create_curriculum_dataset(self, dataset, difficulty_metric: str = 'length'):
"""Create curriculum learning dataset ordered by difficulty."""
def calculate_difficulty(example):
if difficulty_metric == 'length':
return len(example['input_ids'])
elif difficulty_metric == 'vocab_complexity':
# Simple vocabulary complexity metric
unique_tokens = len(set(example['input_ids']))
total_tokens = len(example['input_ids'])
return unique_tokens / total_tokens
else:
return 0.5 # Default neutral difficulty
# Calculate difficulty scores
difficulties = [calculate_difficulty(example) for example in dataset]
# Sort by difficulty (easy to hard)
sorted_indices = sorted(range(len(dataset)), key=lambda i: difficulties[i])
# Create curriculum dataset
curriculum_dataset = dataset.select(sorted_indices)
print(f"Created curriculum dataset with {len(curriculum_dataset)} examples")
return curriculum_dataset
def train_with_advanced_techniques(self,
train_dataset,
eval_dataset=None,
output_dir="./advanced_sft",
num_epochs=3,
batch_size=4,
learning_rate=5e-5,
use_curriculum=True,
use_label_smoothing=True,
label_smoothing_factor=0.1,
use_cosine_schedule=True,
warmup_ratio=0.03):
"""Train with advanced techniques."""
# Setup curriculum learning
if use_curriculum:
train_dataset = self.create_curriculum_dataset(train_dataset)
# Calculate training steps
total_steps = (len(train_dataset) // batch_size) * num_epochs
warmup_steps = int(total_steps * warmup_ratio)
# Create custom optimizer
optimizer = self.create_custom_optimizer(
learning_rate=learning_rate,
use_layer_wise_lr=True
)
# Create learning rate scheduler
if use_cosine_schedule:
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )
else:
scheduler = None
# Custom loss function with label smoothing for causal LM training
class LabelSmoothingLoss(nn.Module):
    def __init__(self, smoothing, vocab_size):
        super().__init__()
        self.smoothing = smoothing
        self.vocab_size = vocab_size
    def forward(self, pred, target):
        # Shift so that tokens < n predict token n (standard causal LM objective)
        pred = pred[..., :-1, :].contiguous().view(-1, pred.size(-1))
        target = target[..., 1:].contiguous().view(-1)
        # Mask ignored positions (label -100) and clamp them so scatter_ gets valid indices
        mask = (target != -100).float()
        safe_target = target.clamp(min=0)
        # Create smoothed target distribution
        confidence = 1.0 - self.smoothing
        smooth_value = self.smoothing / (self.vocab_size - 1)
        one_hot = torch.full_like(pred, smooth_value)
        one_hot.scatter_(1, safe_target.unsqueeze(1), confidence)
        # Cross entropy with smoothed labels
        log_probs = torch.log_softmax(pred, dim=-1)
        loss = -torch.sum(one_hot * log_probs, dim=1)
        loss = loss * mask
        return loss.sum() / mask.sum().clamp(min=1.0)
# Custom trainer with label smoothing support
class AdvancedTrainer(Trainer):
    def __init__(self, *args, label_smoothing_loss=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.label_smoothing_loss = label_smoothing_loss
        # Do not store a step counter named `training_step` here;
        # it would shadow Trainer.training_step()
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        if self.label_smoothing_loss is not None and labels is not None:
            loss = self.label_smoothing_loss(logits, labels)
        else:
            loss = outputs.loss
        return (loss, outputs) if return_outputs else loss
# Note: no custom training_step or optimizer_step overrides are needed here.
# Mixed precision (bf16=True), gradient clipping (max_grad_norm=1.0), and
# gradient accumulation are all configured through the TrainingArguments below
# and handled by the Trainer itself.
# Training arguments
training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=num_epochs,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
gradient_accumulation_steps=1,
learning_rate=learning_rate,
weight_decay=0.01,
max_grad_norm=1.0,
warmup_steps=warmup_steps,
logging_steps=10,
save_steps=500,
eval_steps=500 if eval_dataset else None,
evaluation_strategy="steps" if eval_dataset else "no",
save_strategy="steps",
load_best_model_at_end=True if eval_dataset else False,
metric_for_best_model="eval_loss" if eval_dataset else None,
greater_is_better=False,
report_to="wandb",
run_name="advanced_sft",
bf16=True,
dataloader_pin_memory=False,
gradient_checkpointing=True,
remove_unused_columns=False,
)
# Create loss function
loss_fn = LabelSmoothingLoss(
smoothing=label_smoothing_factor,
vocab_size=len(self.tokenizer)
) if use_label_smoothing else None
# Create trainer
trainer = AdvancedTrainer(
model=self.model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
optimizers=(optimizer, scheduler),
label_smoothing_loss=loss_fn,
)
# Initialize wandb
wandb.init(
project="advanced-sft",
config={
"model_name": self.model_name,
"num_epochs": num_epochs,
"batch_size": batch_size,
"learning_rate": learning_rate,
"use_curriculum": use_curriculum,
"use_label_smoothing": use_label_smoothing,
"label_smoothing_factor": label_smoothing_factor,
}
)
# Training
print("Starting advanced SFT training...")
trainer.train()
# Save model
trainer.save_model()
self.tokenizer.save_pretrained(output_dir)
wandb.finish()
print(f"Advanced SFT completed! Model saved to {output_dir}")
return trainer
def evaluate_with_multiple_metrics(self, test_dataset, metrics=['perplexity', 'bleu']):
"""Evaluate model with multiple metrics."""
from sklearn.metrics import accuracy_score
import sacrebleu
self.model.eval()
results = {}
if 'perplexity' in metrics:
# Calculate perplexity
total_loss = 0
total_tokens = 0
for example in test_dataset:
inputs = {k: torch.tensor(v).unsqueeze(0).to(self.model.device)
for k, v in example.items() if k in ['input_ids', 'attention_mask']}
with torch.no_grad():
outputs = self.model(**inputs, labels=inputs['input_ids'])
loss = outputs.loss.item()
num_tokens = inputs['input_ids'].numel()
total_loss += loss * num_tokens
total_tokens += num_tokens
perplexity = torch.exp(torch.tensor(total_loss / total_tokens))
results['perplexity'] = perplexity.item()
if 'bleu' in metrics:
# Calculate BLEU score (simplified example)
references = []
predictions = []
for example in test_dataset.select(range(min(100, len(test_dataset)))):  # Sample for efficiency
# This is a simplified example - adapt based on your data format
input_ids = example['input_ids'][:50] # First 50 tokens as input
target_ids = example['input_ids'][50:] # Rest as target
inputs = torch.tensor(input_ids).unsqueeze(0).to(self.model.device)
with torch.no_grad():
outputs = self.model.generate(
inputs,
max_new_tokens=len(target_ids),
do_sample=False,
pad_token_id=self.tokenizer.eos_token_id
)
pred_text = self.tokenizer.decode(outputs[0][len(input_ids):], skip_special_tokens=True)
ref_text = self.tokenizer.decode(target_ids, skip_special_tokens=True)
predictions.append(pred_text)
references.append(ref_text)
if predictions and references:
# sacrebleu expects a list of reference streams aligned with the predictions
bleu_score = sacrebleu.corpus_bleu(predictions, [references])
results['bleu'] = bleu_score.score
return results
# Example usage
if __name__ == "__main__":
# Initialize advanced trainer
trainer = AdvancedSFTTrainer("mistralai/Mistral-7B-v0.1")
# Setup model with optimizations
trainer.setup_model_with_optimizations(
gradient_checkpointing=True,
mixed_precision=True
)
# Prepare sample dataset (replace with your actual data)
from datasets import Dataset
sample_data = [
    # Equal-length examples (the default collator does not pad) with explicit labels
    {"input_ids": [1, 2, 3, 4, 5] * 100, "attention_mask": [1] * 500, "labels": [1, 2, 3, 4, 5] * 100},
    {"input_ids": [6, 7, 8, 9, 10] * 100, "attention_mask": [1] * 500, "labels": [6, 7, 8, 9, 10] * 100},
    # Add more examples...
]
train_dataset = Dataset.from_list(sample_data)
eval_dataset = Dataset.from_list(sample_data[:2]) # Small eval set
# Train with advanced techniques
trainer.train_with_advanced_techniques(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
output_dir="./advanced_sft_model",
num_epochs=2,
batch_size=2,
learning_rate=5e-5,
use_curriculum=True,
use_label_smoothing=True,
label_smoothing_factor=0.1,
use_cosine_schedule=True,
)
# Evaluate with multiple metrics
results = trainer.evaluate_with_multiple_metrics(eval_dataset)
print("Evaluation results:", results)