Documentation
Last updated: October 10, 2025

Supervised Fine-Tuning (SFT)

This guide covers Supervised Fine-Tuning (SFT) for adapting open-source Large Language Models to specific tasks, domains, and use cases: how to prepare data, configure training, and evaluate the fine-tuned model on targeted applications.

Key Features

🎯

Task Specialization

Adapt general-purpose models to excel at specific tasks and domains with targeted training data.

📊

Data Efficiency

Useful results often need only thousands of curated examples, compared with the trillions of tokens used in pre-training, which makes customization accessible.

🎛️

Behavioral Control

Shape model outputs to follow specific formats, styles, and behavioral patterns for your use case.

🏥

Domain Expertise

Incorporate specialized knowledge and terminology from medical, legal, financial, and other domains.

Understanding SFT

Supervised Fine-Tuning (SFT) is the process of adapting a pre-trained language model to specific tasks by training it on labeled, task-specific data. Unlike self-supervised pre-training on raw text, SFT uses input-output pairs to teach the model desired behaviors.

Core Concepts:

What is SFT?
SFT takes a pre-trained language model and continues training it on a curated dataset of input-output examples. This process adapts the model's general language understanding to specific tasks, domains, or behaviors.

SFT vs. Pre-training:
- Pre-training: Learns general language patterns from massive unlabeled text
- SFT: Learns specific behaviors from labeled examples
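- Data Volume: Pre-training uses trillions of tokens; SFT typically uses thousands to millions of examples
- Objective: Both usually minimize next-token cross-entropy, but pre-training scores every token of raw text, while SFT scores the response tokens of curated examples (often masking the prompt) to optimize task-specific behavior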
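
To make the objective difference concrete: in the common setup both stages compute the same next-token cross-entropy, and SFT restricts it to the response tokens. A minimal standard-library sketch, assuming per-token log-probabilities have already been produced by a model; the values below are hypothetical, chosen only for illustration.

```python
import math
from typing import Sequence


def mean_nll(token_logprobs: Sequence[float], loss_mask: Sequence[bool]) -> float:
    """Average negative log-likelihood over the positions where loss_mask is True."""
    selected = [lp for lp, keep in zip(token_logprobs, loss_mask) if keep]
    if not selected:
        raise ValueError("loss_mask selects no tokens")
    return -sum(selected) / len(selected)


# Hypothetical log-probs for a 6-token sequence: 3 prompt tokens, then 3 response tokens.
logprobs = [math.log(0.5), math.log(0.25), math.log(0.5),
            math.log(0.8), math.log(0.9), math.log(0.7)]
is_response = [False, False, False, True, True, True]

pretraining_style = mean_nll(logprobs, [True] * len(logprobs))  # every token contributes
sft_style = mean_nll(logprobs, is_response)                      # only response tokens contribute
print(f"all tokens: {pretraining_style:.3f}, response only: {sft_style:.3f}")
# all tokens: 0.576, response only: 0.228
```

Masking the prompt is a common choice rather than a requirement; some recipes also train on prompt tokens.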

Types of SFT:

1. Instruction Tuning:
- Teaching models to follow instructions
- Format: Instruction → Response
- Examples: "Summarize this text" → Summary
- Results in general-purpose instruction-following models

2. Task-Specific Fine-tuning:
- Adapting models for specific tasks
- Examples: Question answering, sentiment analysis, code generation
- Highly optimized for single use cases

3. Domain Adaptation:
- Specializing models for specific domains
- Examples: Medical, legal, financial, scientific domains
- Incorporates domain-specific knowledge and terminology

4. Behavioral Alignment:
- Training models to exhibit desired behaviors
- Examples: Being helpful, harmless, and honest
- Often combined with reinforcement learning techniques

Key Benefits:
- Task Performance: Typically yields large gains on the target tasks
- Efficiency: Requires far less data and compute than training from scratch
- Customization: Allows tailoring to specific requirements
- Control: Better control over model outputs and behavior
- Domain Expertise: Incorporates specialized knowledge

When to Use SFT:
- Adapting general models to specific domains
- Improving performance on targeted tasks
- Teaching new formats or behaviors
- Incorporating proprietary or domain-specific data
- Creating specialized AI assistants
Code Example
# Complete SFT implementation with Hugging Face Transformers
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    get_linear_schedule_with_warmup
)
from datasets import Dataset, load_dataset
import json
from typing import Dict, List
import numpy as np

class SupervisedFineTuner:
    def __init__(self, model_name: str, max_length: int = 2048):
        self.model_name = model_name
        self.max_length = max_length
        self.tokenizer = None
        self.model = None
        
    def setup_model_and_tokenizer(self):
        """Initialize model and tokenizer for SFT."""
        
        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            trust_remote_code=True,
            use_fast=True
        )
        
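        # Set special tokens: reuse EOS for padding if the tokenizer has no pad token.
        # The collator should then pad *labels* with -100 rather than masking every
        # pad_token_id, or the real EOS at the end of each example is masked too.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token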
        
        # Load model
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
            use_cache=False  # Disable cache for training
        )
        
        # Enable gradient checkpointing for memory efficiency
        self.model.gradient_checkpointing_enable()
        
        print(f"Model loaded: {self.model_name}")
        print(f"Vocabulary size: {len(self.tokenizer)}")
        print(f"Model parameters: {self.model.num_parameters():,}")
    
    def prepare_instruction_dataset(self, data: List[Dict]) -> Dataset:
        """Prepare instruction-following dataset for SFT."""
        
        def format_instruction_example(example):
            """Format a single instruction example."""
            
            instruction = example.get('instruction', '')
            input_text = example.get('input', '')
            output = example.get('output', '')
            
            # Create formatted prompt
            if input_text:
                prompt = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n"
            else:
                prompt = f"### Instruction:\n{instruction}\n\n### Response:\n"
            
            # Full text for training
            full_text = prompt + output + self.tokenizer.eos_token
            
            return {
                'text': full_text,
                'prompt': prompt,
                'response': output
            }
        
        # Format all examples
        formatted_data = [format_instruction_example(item) for item in data]
        
        return Dataset.from_list(formatted_data)
    
    def prepare_conversation_dataset(self, data: List[Dict]) -> Dataset:
        """Prepare conversational dataset for SFT."""
        
        def format_conversation(example):
            """Format a conversation example."""
            
            conversation = example.get('conversation', [])
            formatted_text = ""
            
            for turn in conversation:
                role = turn.get('role', 'user')
                content = turn.get('content', '')
                
                if role == 'user':
                    formatted_text += f"Human: {content}\n\n"
                elif role == 'assistant':
                    formatted_text += f"Assistant: {content}\n\n"
            
            # Add EOS token
            formatted_text += self.tokenizer.eos_token
            
            return {'text': formatted_text}
        
        # Format all conversations
        formatted_data = [format_conversation(item) for item in data]
        
        return Dataset.from_list(formatted_data)
    
    def tokenize_dataset(self, dataset: Dataset) -> Dataset:
        """Tokenize dataset for training."""
        
        def tokenize_function(examples):
            # Tokenize without truncation: over-long examples are dropped by the
            # filter below instead of being cut off (which would also drop their EOS)
            tokenized = self.tokenizer(
                examples["text"],
                truncation=False,
                padding=False,
            )
            
            # For causal LM, labels are a copy of input_ids (the model shifts them internally)
            tokenized["labels"] = [ids.copy() for ids in tokenized["input_ids"]]
            
            return tokenized
        
        # Apply tokenization
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=dataset.column_names,
            desc="Tokenizing dataset"
        )
        
        # Filter out examples that are too long
        original_size = len(tokenized_dataset)
        tokenized_dataset = tokenized_dataset.filter(
            lambda x: len(x["input_ids"]) <= self.max_length
        )
        final_size = len(tokenized_dataset)
        
        print(f"Dataset size: {original_size} -> {final_size} examples")
        
        return tokenized_dataset
    
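    def create_data_collator(self):
        """Create a collator that pads input_ids and labels per batch."""
        from transformers import DataCollatorForSeq2Seq
        
        # DataCollatorForSeq2Seq pads the labels built in tokenize_dataset with -100.
        # DataCollatorForLanguageModeling(mlm=False) does not pad a "labels" field of
        # varying length, and it rebuilds labels from input_ids, masking every pad
        # token (including the EOS token when pad_token == eos_token).
        return DataCollatorForSeq2Seq(
            tokenizer=self.tokenizer,
            padding=True,
            label_pad_token_id=-100,
            pad_to_multiple_of=8,  # For efficiency on modern GPUs
        )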
    
    def train(
        self, 
        train_dataset: Dataset,
        eval_dataset: Dataset = None,
        output_dir: str = "./sft_results",
        num_epochs: int = 3,
        batch_size: int = 4,
        learning_rate: float = 5e-5,
        warmup_ratio: float = 0.03,
        save_steps: int = 500,
        logging_steps: int = 10,
        eval_steps: int = 500,
    ):
        """Train the model with SFT."""
        
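        # Approximate optimizer steps (single device, no gradient accumulation);
        # ceiling division counts the final partial batch
        steps_per_epoch = (len(train_dataset) + batch_size - 1) // batch_size
        total_steps = steps_per_epoch * num_epochs
        warmup_steps = int(total_steps * warmup_ratio)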
        
        # Training arguments
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            gradient_accumulation_steps=1,
            learning_rate=learning_rate,
            weight_decay=0.01,
            adam_beta1=0.9,
            adam_beta2=0.999,
            adam_epsilon=1e-8,
            max_grad_norm=1.0,
            warmup_steps=warmup_steps,
            lr_scheduler_type="linear",
            logging_steps=logging_steps,
            save_steps=save_steps,
            eval_steps=eval_steps if eval_dataset else None,
            eval_strategy="steps" if eval_dataset else "no",  # "evaluation_strategy" in transformers < 4.41
            save_strategy="steps",
            load_best_model_at_end=True if eval_dataset else False,
            metric_for_best_model="eval_loss" if eval_dataset else None,
            greater_is_better=False,
            report_to="none",  # Disable wandb/tensorboard
            dataloader_pin_memory=False,
            gradient_checkpointing=True,
            bf16=True,  # Use bfloat16 for stability
            remove_unused_columns=False,
            push_to_hub=False,
        )
        
        # Create trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=self.create_data_collator(),
            tokenizer=self.tokenizer,
        )
        
        # Custom callback for monitoring GPU memory. Callbacks must subclass
        # TrainerCallback and use the (args, state, control, **kwargs) signature.
        from transformers import TrainerCallback
        
        class MemoryMonitorCallback(TrainerCallback):
            def on_step_end(self, args, state, control, **kwargs):
                if state.global_step % 100 == 0 and torch.cuda.is_available():
                    memory_allocated = torch.cuda.memory_allocated() / 1024**3
                    memory_reserved = torch.cuda.memory_reserved() / 1024**3
                    print(f"Step {state.global_step}: "
                          f"Memory Allocated: {memory_allocated:.2f}GB, "
                          f"Reserved: {memory_reserved:.2f}GB")
        
        trainer.add_callback(MemoryMonitorCallback())
        
        # Start training
        print(f"Starting SFT training...")
        print(f"Training examples: {len(train_dataset)}")
        if eval_dataset:
            print(f"Evaluation examples: {len(eval_dataset)}")
        print(f"Total training steps: {total_steps}")
        print(f"Warmup steps: {warmup_steps}")
        
        trainer.train()
        
        # Save final model
        trainer.save_model()
        self.tokenizer.save_pretrained(output_dir)
        
        print(f"Training completed! Model saved to {output_dir}")
        
        return trainer
    
    def evaluate_model(self, test_prompts: List[str], max_new_tokens: int = 200):
        """Evaluate the fine-tuned model on test prompts."""
        
        print("\nEvaluating fine-tuned model:")
        print("=" * 60)
        
        self.model.eval()
        
        for i, prompt in enumerate(test_prompts, 1):
            print(f"\nTest {i}:")
            print(f"Prompt: {prompt}")
            
            # Tokenize input
            inputs = self.tokenizer(prompt, return_tensors="pt")
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            
            # Generate response
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )
            
            # Decode response
            response = self.tokenizer.decode(
                outputs[0][inputs['input_ids'].shape[1]:], 
                skip_special_tokens=True
            )
            
            print(f"Response: {response}")
            print("-" * 40)

# Example usage and data preparation
def prepare_sample_instruction_data():
    """Prepare sample instruction-following data."""
    
    sample_data = [
        {
            "instruction": "Explain the concept of machine learning in simple terms.",
            "input": "",
            "output": "Machine learning is a type of artificial intelligence where computers learn to make predictions or decisions by analyzing patterns in data, rather than being explicitly programmed for each task. It's like teaching a computer to recognize patterns the same way humans learn from experience."
        },
        {
            "instruction": "Write a Python function to calculate the factorial of a number.",
            "input": "",
            "output": "def factorial(n):\n    if n == 0 or n == 1:\n        return 1\n    else:\n        return n * factorial(n - 1)\n\n# Example usage:\n# print(factorial(5))  # Output: 120"
        },
        {
            "instruction": "Summarize the following text.",
            "input": "Artificial intelligence (AI) is intelligence demonstrated by machines, in contrast to the natural intelligence displayed by humans and animals. Leading AI textbooks define the field as the study of intelligent agents: any device that perceives its environment and takes actions that maximize its chance of successfully achieving its goals.",
            "output": "AI refers to machine intelligence that enables devices to perceive their environment and take goal-oriented actions, distinguishing it from natural intelligence found in humans and animals."
        },
        {
            "instruction": "Translate the following English text to French.",
            "input": "Hello, how are you today?",
            "output": "Bonjour, comment allez-vous aujourd'hui ?"
        },
        {
            "instruction": "Generate a creative story beginning with the given sentence.",
            "input": "The old lighthouse stood alone on the rocky cliff.",
            "output": "The old lighthouse stood alone on the rocky cliff, its weathered walls holding secrets of countless storms. Sarah climbed the spiral staircase, each step echoing with memories of the lighthouse keeper who had vanished mysteriously fifty years ago. At the top, she discovered a hidden journal that would change everything she thought she knew about her grandfather's disappearance."
        }
    ]
    
    return sample_data

# Main training script
if __name__ == "__main__":
    # Initialize fine-tuner
    model_name = "mistralai/Mistral-7B-v0.1"  # or "meta-llama/Llama-2-7b-hf"
    fine_tuner = SupervisedFineTuner(model_name, max_length=2048)
    
    # Setup model and tokenizer
    fine_tuner.setup_model_and_tokenizer()
    
    # Prepare training data
    print("Preparing training data...")
    sample_data = prepare_sample_instruction_data()
    
    # Create dataset
    dataset = fine_tuner.prepare_instruction_dataset(sample_data)
    tokenized_dataset = fine_tuner.tokenize_dataset(dataset)
    
    # Split into train/eval (80/20)
    train_size = int(0.8 * len(tokenized_dataset))
    eval_size = len(tokenized_dataset) - train_size
    
    train_dataset = tokenized_dataset.select(range(train_size))
    eval_dataset = tokenized_dataset.select(range(train_size, train_size + eval_size))
    
    print(f"Training examples: {len(train_dataset)}")
    print(f"Evaluation examples: {len(eval_dataset)}")
    
    # Start training
    trainer = fine_tuner.train(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        output_dir="./sft_mistral_7b",
        num_epochs=3,
        batch_size=2,  # Adjust based on your GPU memory
        learning_rate=5e-5,
        warmup_ratio=0.03,
        save_steps=100,
        logging_steps=10,
        eval_steps=50,
    )
    
    # Test the fine-tuned model
    test_prompts = [
        "### Instruction:\nExplain quantum computing in simple terms.\n\n### Response:\n",
        "### Instruction:\nWrite a Python function to find the largest number in a list.\n\n### Response:\n",
        "### Instruction:\nWhat are the benefits of renewable energy?\n\n### Response:\n"
    ]
    
    fine_tuner.evaluate_model(test_prompts, max_new_tokens=150)
    
    print("\nSFT training completed successfully!")
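
The `tokenize_dataset` method above copies `input_ids` into `labels`, so the loss also covers the prompt template. A common refinement masks the prompt with -100, the label value that PyTorch's `CrossEntropyLoss` ignores by default and that Hugging Face causal-LM models rely on. The sketch below assumes the prompt and response have already been tokenized separately; `build_masked_example` is a hypothetical helper and the token ids are made up.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # label value skipped by torch.nn.CrossEntropyLoss (its default ignore_index)


def build_masked_example(prompt_ids: List[int], response_ids: List[int],
                         eos_id: int, max_length: int) -> Dict[str, List[int]]:
    """Concatenate prompt + response + EOS and mask the prompt out of the labels."""
    input_ids = prompt_ids + response_ids + [eos_id]
    labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids + [eos_id]
    # Truncate both sequences identically so label positions stay aligned with inputs.
    input_ids, labels = input_ids[:max_length], labels[:max_length]
    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "labels": labels,
    }


example = build_masked_example(prompt_ids=[11, 12, 13], response_ids=[21, 22], eos_id=2, max_length=16)
print(example["labels"])  # [-100, -100, -100, 21, 22, 2]
```

Tokenizing the two parts separately can yield slightly different tokens at the boundary than tokenizing the joined text, so it is worth decoding a few examples to check. Labels built this way need a collator that pads them with -100 (for example `DataCollatorForSeq2Seq`) rather than one that rebuilds labels from `input_ids`.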

Data Preparation for SFT

The quality and format of your training data are crucial for successful SFT. Here's how to prepare different types of datasets:

1. Instruction-Following Datasets:

Alpaca Format (Recommended):
- Structure: Instruction + Optional Input + Response
- Use Case: General instruction following
- Benefits: Standardized format, widely supported

ChatML Format:
- Structure: Role-tagged turns (system, user, assistant) wrapped in special delimiter tokens
- Use Case: Multi-turn conversations
- Benefits: Natural conversation flow
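
The ChatML convention wraps each turn in `<|im_start|>` and `<|im_end|>` markers, with the role on the first line. A hand-rolled sketch of that rendering follows; it assumes the target model was trained on ChatML and, as a simplification, only accepts system/user/assistant roles. In practice, prefer the tokenizer's own template (`tokenizer.apply_chat_template` in Hugging Face Transformers), because models differ in their special tokens.

```python
from typing import Dict, List


def to_chatml(messages: List[Dict[str, str]], add_generation_prompt: bool = False) -> str:
    """Render role/content messages as a ChatML string."""
    allowed_roles = {"system", "user", "assistant"}
    parts = []
    for msg in messages:
        if msg["role"] not in allowed_roles:
            raise ValueError(f"unexpected role: {msg['role']!r}")
        parts.append(f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n")
    if add_generation_prompt:
        # Leave an open assistant turn for the model to complete at inference time.
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)


conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello, how are you today?"},
    {"role": "assistant", "content": "I'm doing well, thanks for asking."},
]
print(to_chatml(conversation))
```

For training, the whole rendered string is used; at inference time, `add_generation_prompt=True` leaves the assistant turn open for the model to fill.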

2. Domain-Specific Datasets:

Medical Domain:
- Data Sources: Medical literature, Q&A pairs, case studies
- Format: Question-answer or case-diagnosis pairs
- Considerations: Accuracy, regulatory compliance, bias reduction

Legal Domain:
- Data Sources: Legal documents, case law, regulations
- Format: Legal query-response or document analysis
- Considerations: Jurisdiction accuracy, ethical guidelines

Code Generation:
- Data Sources: GitHub repositories, coding competitions, documentation
- Format: Natural language description to code
- Considerations: Code quality, security, best practices
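
For code-generation data, one cheap quality gate is to drop examples whose output does not parse. A sketch assuming the outputs are bare Python source (no Markdown fences or surrounding prose); `ast.parse` checks syntax only and never executes the code, so correctness and security still need tests and review. The helper names are hypothetical.

```python
import ast
from typing import Dict, List


def is_valid_python(source: str) -> bool:
    """Return True if the source parses as Python (syntax only; nothing is executed)."""
    try:
        ast.parse(source)
    except (SyntaxError, ValueError):  # ValueError: e.g. null bytes on older Pythons
        return False
    return True


def filter_parseable_code(examples: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Keep only examples whose 'output' field is syntactically valid Python."""
    return [ex for ex in examples if is_valid_python(ex["output"])]


examples = [
    {"instruction": "Add two numbers", "output": "def add(a, b):\n    return a + b"},
    {"instruction": "Broken example", "output": "def add(a, b)\n    return a + b"},  # missing colon
]
print(len(filter_parseable_code(examples)))  # 1
```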

3. Data Quality Guidelines:

High-Quality Characteristics:
- Accuracy: Factually correct information
- Relevance: Aligned with target use case
- Diversity: Covers various scenarios and edge cases
- Consistency: Uniform formatting and style
- Completeness: Complete responses without truncation
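
Completeness is hard to verify automatically, but simple heuristics can flag responses that look cut off. A rough sketch; the signals used here (an odd number of Markdown code fences, more opening than closing brackets, a trailing comma or connector word) are assumptions chosen for illustration, and flagged items are better reviewed than dropped blindly.

```python
from typing import List


def looks_truncated(response: str) -> List[str]:
    """Return reasons a response may be incomplete; an empty list means no flags."""
    text = response.rstrip()
    if not text:
        return ["empty response"]
    reasons = []
    fence = "`" * 3  # a Markdown code fence
    if text.count(fence) % 2 == 1:
        reasons.append("unclosed code fence")
    # Crude bracket check: counts characters, so brackets inside strings also count.
    for open_ch, close_ch in (("(", ")"), ("[", "]"), ("{", "}")):
        if text.count(open_ch) > text.count(close_ch):
            reasons.append(f"unbalanced {open_ch}{close_ch}")
    # Trailing clause punctuation or a dangling connector suggests the text was cut off.
    dangling_words = {"and", "or", "but", "the", "a", "an", "to", "of"}
    if text.endswith((",", ":", ";")) or text.split()[-1].lower() in dangling_words:
        reasons.append("ends mid-sentence")
    return reasons


print(looks_truncated("The three main benefits are cost, reliability, and"))  # ['ends mid-sentence']
```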

Data Preprocessing Steps:
1. Deduplication: Remove duplicate or near-duplicate examples
2. Quality Filtering: Remove low-quality, inappropriate, or biased content
3. Length Filtering: Remove examples that are too short or too long
4. Format Validation: Ensure consistent formatting
5. Language Detection: Filter for target language(s)
6. Content Filtering: Remove harmful, toxic, or inappropriate content
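
Steps 1, 3, and 4 can be combined into a single pass over Alpaca-style records. The sketch below is an assumption-laden illustration: the word-count thresholds are placeholders, "near-duplicate" is approximated by comparing lowercased, whitespace-collapsed text, and quality, language, and content filtering are omitted because they usually need classifiers or external tools.

```python
import re
from typing import Dict, List

REQUIRED_KEYS = ("instruction", "output")


def normalize(text: str) -> str:
    """Lowercase and collapse whitespace so trivially different copies compare equal."""
    return re.sub(r"\s+", " ", text).strip().lower()


def preprocess(records: List[Dict], min_words: int = 3, max_words: int = 500) -> List[Dict]:
    """Validate format, drop normalized duplicates, then filter by response length."""
    seen = set()
    kept = []
    for rec in records:
        # Format validation: required fields must be non-empty strings; input is optional.
        if not all(isinstance(rec.get(k), str) and rec[k].strip() for k in REQUIRED_KEYS):
            continue
        input_text = rec.get("input") or ""
        if not isinstance(input_text, str):
            continue
        # Deduplication on normalized text; the tuple keeps the fields separate.
        key = (normalize(rec["instruction"]), normalize(input_text), normalize(rec["output"]))
        if key in seen:
            continue
        seen.add(key)
        # Length filtering on the response, counted in whitespace-separated words.
        if not min_words <= len(rec["output"].split()) <= max_words:
            continue
        kept.append({
            "instruction": rec["instruction"].strip(),
            "input": input_text.strip(),
            "output": rec["output"].strip(),
        })
    return kept


raw = [
    {"instruction": "Explain recursion", "output": "A function that calls itself on smaller inputs."},
    {"instruction": "explain  recursion", "output": "A function that calls itself on smaller inputs. "},
    {"instruction": "Say hi", "output": "Hi"},
    {"instruction": "Missing output"},
]
print(len(preprocess(raw)))  # 1
```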

4. Data Augmentation Techniques:

Paraphrasing:
- Rewrite instructions and responses in different ways
- Increases dataset diversity
- Helps model generalize better

Back-Translation:
- Translate to another language and back
- Creates natural variations
- Useful for multilingual applications

Synthetic Data Generation:
- Use existing LLMs to generate training examples
- Helpful for expanding small datasets
- Requires careful quality control
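
Back-translation needs a translation model, which is outside the scope of this page, so this sketch takes the translator as a parameter. `back_translate` and `fake_translate` are hypothetical helpers; the toy lookup table only stands in for a real machine-translation system so the round trip can be shown end to end.

```python
from typing import Callable, List

Translator = Callable[[str, str, str], str]  # (text, source_lang, target_lang) -> translation


def back_translate(texts: List[str], translate: Translator,
                   source_lang: str = "en", pivot_lang: str = "fr") -> List[str]:
    """Round-trip each text through a pivot language; keep only variants that changed."""
    variants = []
    for text in texts:
        pivot = translate(text, source_lang, pivot_lang)
        round_trip = translate(pivot, pivot_lang, source_lang).strip()
        # An identical round trip adds no diversity, so it is dropped.
        if round_trip and round_trip != text.strip():
            variants.append(round_trip)
    return variants


def fake_translate(text: str, src: str, tgt: str) -> str:
    """Toy stand-in translator for demonstration only; unknown text passes through unchanged."""
    table = {
        ("en", "fr"): {"Summarize this text.": "Résume ce texte."},
        ("fr", "en"): {"Résume ce texte.": "Give a summary of this text."},
    }
    return table.get((src, tgt), {}).get(text, text)


print(back_translate(["Summarize this text.", "Hello"], fake_translate))
# ['Give a summary of this text.']
```

The resulting variants still need the same quality filtering as the original data, since machine translation can drift from the intended meaning.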

5. Dataset Size Recommendations:

Task Complexity vs. Dataset Size (rough starting points; the right size depends on the task and data quality):
- Simple Tasks: 500-2,000 examples
- Moderate Tasks: 2,000-10,000 examples
- Complex Tasks: 10,000-50,000+ examples
- Domain Adaptation: 1,000-5,000 high-quality examples

Quality vs. Quantity:
- Prefer 1,000 high-quality examples over 10,000 low-quality ones
- Focus on covering diverse scenarios
- Ensure balanced representation of different use cases
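
Dataset size matters in part because it sets how many optimizer updates training performs. A back-of-the-envelope sketch; the batch size, gradient accumulation, device count, and rounding below are illustrative assumptions, and the Hugging Face Trainer's exact step count can differ slightly depending on how it handles the final partial batch.

```python
import math


def training_schedule(num_examples: int, per_device_batch: int, grad_accum: int,
                      num_devices: int, epochs: int, warmup_ratio: float) -> dict:
    """Rough optimizer-step arithmetic for planning a fine-tuning run."""
    effective_batch = per_device_batch * grad_accum * num_devices
    steps_per_epoch = math.ceil(num_examples / effective_batch)  # count the last partial batch
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return {
        "effective_batch": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": warmup_steps,
    }


print(training_schedule(5_000, per_device_batch=4, grad_accum=4, num_devices=2, epochs=3, warmup_ratio=0.03))
# {'effective_batch': 32, 'steps_per_epoch': 157, 'total_steps': 471, 'warmup_steps': 15}
```

With very small datasets, such as the five-example sample in the training script, the run is only a handful of steps, so warmup and evaluation intervals need to be scaled down accordingly.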
Code Example
# Comprehensive data preparation utilities for SFT
import json
import re
import random
from typing import List, Dict, Tuple
from collections import Counter
import pandas as pd
from datasets import Dataset, load_dataset
import hashlib

class SFTDataProcessor:
    def __init__(self):
        self.processed_data = []
        self.stats = {}
    
    def load_alpaca_format(self, file_path: str) -> List[Dict]:
        """Load data in Alpaca format."""
        
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        
        # Validate format
        required_keys = ['instruction', 'output']
        valid_data = []
        
        for item in data:
            if all(key in item for key in required_keys):
                valid_data.append({
                    'instruction': item['instruction'].strip(),
                    'input': item.get('input', '').strip(),
                    'output': item['output'].strip()
                })
        
        print(f"Loaded {len(valid_data)}/{len(data)} valid examples")
        return valid_data
    
    def load_conversational_format(self, file_path: str) -> List[Dict]:
        """Load conversational data (ChatML-like format)."""
        
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        
        processed_conversations = []
        
        for conversation in data:
            if 'messages' in conversation:
                messages = conversation['messages']
                formatted_conversation = []
                
                for message in messages:
                    if 'role' in message and 'content' in message:
                        formatted_conversation.append({
                            'role': message['role'],
                            'content': message['content'].strip()
                        })
                
                if len(formatted_conversation) >= 2:  # At least one exchange
                    processed_conversations.append({
                        'conversation': formatted_conversation
                    })
        
        print(f"Loaded {len(processed_conversations)} conversations")
        return processed_conversations
    
    def deduplicate_data(self, data: List[Dict], method: str = 'exact') -> List[Dict]:
        """Remove duplicate examples from dataset."""
        
        if method == 'exact':
            # Exact string matching
            seen = set()
            deduplicated = []
            
            for item in data:
                # Hash instruction, input, and output as a JSON list so field
                # boundaries are preserved ("ab" + "c" vs. "a" + "bc")
                content = json.dumps([item['instruction'], item.get('input', ''), item['output']])
                content_hash = hashlib.md5(content.encode()).hexdigest()
                
                if content_hash not in seen:
                    seen.add(content_hash)
                    deduplicated.append(item)
            
        elif method == 'fuzzy':
            # Fuzzy matching based on similarity (compares every pair: O(n^2), slow on large datasets)
            from difflib import SequenceMatcher
            
            deduplicated = []
            threshold = 0.9
            
            for item in data:
                is_duplicate = False
                content = item['instruction'] + ' ' + item.get('input', '') + ' ' + item['output']
                
                for existing in deduplicated:
                    existing_content = existing['instruction'] + ' ' + existing.get('input', '') + ' ' + existing['output']
                    similarity = SequenceMatcher(None, content, existing_content).ratio()
                    
                    if similarity > threshold:
                        is_duplicate = True
                        break
                
                if not is_duplicate:
                    deduplicated.append(item)
        
        else:
            raise ValueError(f"Unknown deduplication method: {method!r}")
        
        print(f"Deduplication: {len(data)} -> {len(deduplicated)} examples")
        return deduplicated
    
    def filter_by_quality(self, data: List[Dict]) -> List[Dict]:
        """Filter data based on quality criteria."""
        
        filtered_data = []
        
        for item in data:
            instruction = item['instruction']
            output = item['output']
            
            # Quality checks
            checks = [
                len(instruction.strip()) >= 10,  # Minimum instruction length
                len(output.strip()) >= 5,       # Minimum output length
                len(output.split()) <= 500,     # Maximum output length
                not self._contains_placeholder(instruction, output),
                not self._contains_inappropriate_content(instruction, output),
                self._is_coherent_response(instruction, output)
            ]
            
            if all(checks):
                filtered_data.append(item)
        
        print(f"Quality filtering: {len(data)} -> {len(filtered_data)} examples")
        return filtered_data
    
    def _contains_placeholder(self, instruction: str, output: str) -> bool:
        """Check if text contains placeholder content."""
        # Note: '...' also matches legitimate ellipses (and Python's Ellipsis); remove it if too strict
        placeholders = ['[PLACEHOLDER]', 'TODO', 'FIXME', '...', 'Lorem ipsum']
        text = (instruction + ' ' + output).lower()
        return any(placeholder.lower() in text for placeholder in placeholders)
    
    def _contains_inappropriate_content(self, instruction: str, output: str) -> bool:
        """Basic check for inappropriate content."""
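        # Simple keyword-based filtering (expand as needed). Match whole words only:
        # plain substring checks would flag "whatever" (contains "hate") or "explicitly".
        inappropriate_keywords = ['hate', 'violence', 'explicit']  # Simplified list
        text = (instruction + ' ' + output).lower()
        return any(re.search(rf"\b{re.escape(keyword)}\b", text) for keyword in inappropriate_keywords)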
    
    def _is_coherent_response(self, instruction: str, output: str) -> bool:
        """Check if the response is coherent with the instruction."""
        # Simple heuristics (can be improved with more sophisticated methods)
        
        # Check if output is not just repeating the instruction
        if instruction.lower() in output.lower() and len(output) < len(instruction) * 1.5:
            return False
        
        # Check for minimum complexity
        if len(output.split()) < 3:
            return False
        
        return True
    
    def augment_data(self, data: List[Dict], augmentation_factor: float = 0.2) -> List[Dict]:
        """Augment dataset with variations."""
        
        augmented_data = data.copy()
        num_to_augment = int(len(data) * augmentation_factor)
        
        # Simple template-based paraphrasing (in practice, use more sophisticated methods).
        # Patterns are anchored at the start; the question mark is escaped and optional.
        paraphrasing_patterns = [
            (r"^Explain (.+)", r"Describe \1"),
            (r"^What is (.+?)\??$", r"Can you explain \1?"),
            (r"^How do I (.+?)\??$", r"What's the way to \1?"),
            (r"^Write (.+)", r"Create \1"),
        ]
        
        for _ in range(num_to_augment):
            original = random.choice(data)
            augmented = original.copy()
            
            # Try to paraphrase the instruction
            for pattern, replacement in paraphrasing_patterns:
                if re.search(pattern, augmented['instruction'], re.IGNORECASE):
                    augmented['instruction'] = re.sub(
                        pattern, replacement, augmented['instruction'], flags=re.IGNORECASE
                    )
                    break
            
            # Skip unchanged copies: they add weight, not diversity
            if augmented['instruction'] != original['instruction']:
                augmented_data.append(augmented)
        
        print(f"Data augmentation: {len(data)} -> {len(augmented_data)} examples")
        return augmented_data
    
    def analyze_dataset(self, data: List[Dict]) -> Dict:
        """Analyze dataset characteristics."""
        if not data:
            raise ValueError("Cannot analyze an empty dataset")
        
        analysis = {
            'total_examples': len(data),
            'avg_instruction_length': 0,
            'avg_output_length': 0,
            'instruction_length_distribution': [],
            'output_length_distribution': [],
            'common_instruction_patterns': [],
        }
        
        instruction_lengths = []
        output_lengths = []
        instruction_starts = []
        
        for item in data:
            inst_len = len(item['instruction'].split())
            out_len = len(item['output'].split())
            
            instruction_lengths.append(inst_len)
            output_lengths.append(out_len)
            
            # Extract instruction patterns
            first_words = ' '.join(item['instruction'].split()[:3]).lower()
            instruction_starts.append(first_words)
        
        analysis['avg_instruction_length'] = sum(instruction_lengths) / len(instruction_lengths)
        analysis['avg_output_length'] = sum(output_lengths) / len(output_lengths)
        
        # Length distributions
        analysis['instruction_length_distribution'] = {
            'min': min(instruction_lengths),
            'max': max(instruction_lengths),
            'median': sorted(instruction_lengths)[len(instruction_lengths)//2]
        }
        
        analysis['output_length_distribution'] = {
            'min': min(output_lengths),
            'max': max(output_lengths),
            'median': sorted(output_lengths)[len(output_lengths)//2]
        }
        
        # Common patterns
        pattern_counts = Counter(instruction_starts)
        analysis['common_instruction_patterns'] = pattern_counts.most_common(10)
        
        return analysis
    
    def create_balanced_dataset(self, data: List[Dict], categories: List[str] = None) -> List[Dict]:
        """Create a balanced dataset across different categories."""
        
        if categories is None:
            # Auto-detect categories based on instruction patterns
            categories = self._auto_detect_categories(data)
        
        # Categorize examples
        categorized_data = {cat: [] for cat in categories}
        uncategorized = []
        
        for item in data:
            instruction = item['instruction'].lower()
            categorized = False
            
            for category in categories:
                if category.lower() in instruction:
                    categorized_data[category].append(item)
                    categorized = True
                    break
            
            if not categorized:
                uncategorized.append(item)
        
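        # Balance categories by downsampling each to the smallest non-empty category
        non_empty_sizes = [len(examples) for examples in categorized_data.values() if examples]
        if not non_empty_sizes:
            print("No examples matched any category; returning data unchanged")
            return data
        min_category_size = min(non_empty_sizes)
        balanced_data = []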
        
        for category, examples in categorized_data.items():
            if examples:
                # Sample from each category
                selected = random.sample(examples, min(len(examples), min_category_size))
                balanced_data.extend(selected)
        
        # Add some uncategorized examples
        if uncategorized:
            additional_size = len(balanced_data) // 4  # 25% uncategorized
            selected_uncategorized = random.sample(
                uncategorized, 
                min(len(uncategorized), additional_size)
            )
            balanced_data.extend(selected_uncategorized)
        
        print(f"Balanced dataset: {len(data)} -> {len(balanced_data)} examples")
        return balanced_data
    
    def _auto_detect_categories(self, data: List[Dict]) -> List[str]:
        """Auto-detect common categories in the dataset."""
        
        # Common instruction types
        patterns = [
            'explain', 'describe', 'write', 'create', 'generate',
            'translate', 'summarize', 'analyze', 'compare', 'define'
        ]
        
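        # Count how often each pattern appears in the instructions
        counts = {
            pattern: sum(1 for item in data if pattern in item['instruction'].lower())
            for pattern in patterns
        }
        
        # Keep patterns that meet the minimum threshold, most frequent first
        detected_categories = [
            pattern for pattern in sorted(patterns, key=lambda p: counts[p], reverse=True)
            if counts[pattern] >= 5
        ]
        
        return detected_categories[:10]  # Limit to the 10 most frequent categories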
    
    def export_processed_data(self, data: List[Dict], output_path: str, format: str = 'json'):
        """Export processed data in specified format."""
        
        if format == 'json':
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2, ensure_ascii=False)
        
        elif format == 'jsonl':
            with open(output_path, 'w', encoding='utf-8') as f:
                for item in data:
                    f.write(json.dumps(item, ensure_ascii=False) + '\n')
        
        elif format == 'csv':
            df = pd.DataFrame(data)
            df.to_csv(output_path, index=False)
        
        else:
            raise ValueError(f"Unsupported format: {format!r} (expected 'json', 'jsonl' or 'csv')")
        
        print(f"Data exported to {output_path} in {format} format")

# Example usage
if __name__ == "__main__":
    processor = SFTDataProcessor()
    
    # Sample data for demonstration
    sample_data = [
        {
            "instruction": "Explain machine learning",
            "input": "",
            "output": "Machine learning is a subset of AI that enables computers to learn from data."
        },
        {
            "instruction": "Write a Python function to add two numbers",
            "input": "",
            "output": "def add(a, b):\n    return a + b"
        },
        # Add more examples...
    ]
    
    print("Original dataset analysis:")
    analysis = processor.analyze_dataset(sample_data)
    for key, value in analysis.items():
        if key != 'common_instruction_patterns':
            print(f"{key}: {value}")
    
    # Process the data
    print("\nProcessing data...")
    
    # Deduplicate
    deduplicated = processor.deduplicate_data(sample_data)
    
    # Filter by quality
    filtered = processor.filter_by_quality(deduplicated)
    
    # Augment data
    augmented = processor.augment_data(filtered, augmentation_factor=0.3)
    
    # Create balanced dataset
    balanced = processor.create_balanced_dataset(augmented)
    
    print("\nFinal dataset analysis:")
    final_analysis = processor.analyze_dataset(balanced)
    for key, value in final_analysis.items():
        if key != 'common_instruction_patterns':
            print(f"{key}: {value}")
    
    # Export processed data
    processor.export_processed_data(balanced, "processed_sft_data.json")
    
    print("\nData processing completed!")

Advanced SFT Techniques

Beyond basic supervised fine-tuning, several advanced techniques can improve training efficiency and model performance:

1. Gradient Checkpointing and Memory Optimization:

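Gradient Checkpointing:
- Trades computation for memory: only selected activations are stored, and the rest are recomputed during the backward pass
- Enables training larger models (or longer sequences) on limited hardware
- Savings apply to activation memory (often a large share of the total for long sequences), not to weights or optimizer states; the recomputation costs roughly one extra forward pass per step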
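To make the trade-off concrete, here is a rough counting sketch rather than a profiler. It assumes every layer's activation has the same size and that a checkpoint is kept every `k` layers, with one segment recomputed and held at a time during the backward pass. Under that model, choosing `k` near the square root of the layer count minimizes the number of stored activations. The function names and the 32-layer example are illustrative.

```python
import math
from typing import Optional


def stored_activations(num_layers: int, segment_size: Optional[int]) -> int:
    """Rough peak count of layer activations kept in memory for the backward pass.

    Counting model (an assumption): every layer output has the same size.
    - segment_size=None: no checkpointing, so every layer output is stored.
    - otherwise: one checkpoint is kept per segment of `segment_size` layers,
      and during backward one segment at a time is recomputed and held.
    """
    if segment_size is None:
        return num_layers
    num_checkpoints = math.ceil(num_layers / segment_size)
    return num_checkpoints + segment_size


def best_segment_size(num_layers: int) -> int:
    """Segment size that minimizes the count above (close to sqrt(num_layers))."""
    return min(range(1, num_layers + 1),
               key=lambda k: stored_activations(num_layers, k))


num_layers = 32  # hypothetical 32-layer decoder
k = best_segment_size(num_layers)
print("without checkpointing:", stored_activations(num_layers, None))
print(f"checkpoint every {k} layers:", stored_activations(num_layers, k))
```

Real frameworks usually checkpoint at transformer-block granularity (Hugging Face's `gradient_checkpointing_enable()` wraps each decoder layer), so actual savings depend on the model, sequence length and implementation.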

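Mixed Precision Training:
- Runs most operations in 16-bit (fp16 or bf16) while keeping precision-sensitive operations, and usually a master copy of the weights, in 32-bit
- Speeds up training and reduces memory while maintaining numerical stability
- Automatic Mixed Precision (AMP) chooses the precision per operation; fp16 additionally needs loss scaling to avoid gradient underflow, while bf16 usually does not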
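The need for loss scaling with fp16 can be shown with Python's standard `struct` module, whose `'e'` format packs IEEE 754 half-precision floats. In the sketch below, a small gradient value (chosen for illustration) is less than half of fp16's smallest positive subnormal (2**-24, about 6e-8) and rounds to zero, while multiplying it by a scale factor before the cast and dividing afterward preserves it. AMP's `GradScaler` picks and adjusts the scale dynamically; the fixed factor of 1024 here is an assumption for illustration.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


grad = 1e-8      # illustrative small gradient value
scale = 1024.0   # illustrative fixed loss scale (AMP adjusts this dynamically)

unscaled = to_fp16(grad)                # too small for fp16: rounds to 0.0
scaled = to_fp16(grad * scale) / scale  # scale up, store in fp16, unscale in full precision

print(unscaled)  # 0.0
print(scaled)    # approximately 1e-8
```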

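DeepSpeed Integration:
- ZeRO (Zero Redundancy Optimizer) partitions optimizer states (stage 1), then gradients (stage 2) and parameters (stage 3) across data-parallel GPUs
- Significantly reduces per-GPU memory usage
- Enables training models that wouldn't fit on a single GPU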
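A back-of-the-envelope estimate shows why partitioning helps. The sketch below uses the per-parameter accounting from the ZeRO paper for mixed-precision Adam (2 bytes of fp16 weights, 2 bytes of fp16 gradients, 12 bytes of fp32 master weights and Adam moments) and ignores activations, temporary buffers and fragmentation, so real usage is higher. The 7B-parameter, 8-GPU setting is a hypothetical example.

```python
def zero_model_state_gb(num_params: float, num_gpus: int, stage: int) -> float:
    """Approximate per-GPU memory (GB) for weights, gradients and Adam states.

    Mixed-precision accounting per parameter: 2 bytes fp16 weights,
    2 bytes fp16 gradients, 12 bytes fp32 master weights + Adam moments.
    Activations, buffers and fragmentation are ignored.
    """
    weights, grads, optim = 2 * num_params, 2 * num_params, 12 * num_params
    if stage == 0:    # plain data parallelism: everything replicated on each GPU
        total = weights + grads + optim
    elif stage == 1:  # optimizer states partitioned across GPUs
        total = weights + grads + optim / num_gpus
    elif stage == 2:  # optimizer states + gradients partitioned
        total = weights + (grads + optim) / num_gpus
    elif stage == 3:  # optimizer states + gradients + weights partitioned
        total = (weights + grads + optim) / num_gpus
    else:
        raise ValueError("stage must be 0, 1, 2 or 3")
    return total / 1e9


for stage in range(4):
    print(f"stage {stage}: {zero_model_state_gb(7e9, 8, stage):.2f} GB per GPU")
```

For 7B parameters on 8 GPUs this gives about 112 GB (no partitioning), 38.5 GB (stage 1), 26.25 GB (stage 2) and 14 GB (stage 3) of model state per GPU.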

2. Learning Rate Scheduling:

Warmup Strategies:
- Linear warmup: Gradually increase learning rate
- Cosine warmup: Smooth increase with cosine function
- Prevents early training instability

Decay Strategies:
- Linear decay: Steady decrease over time
- Cosine decay: Smooth decrease following cosine curve
- Step decay: Discrete steps down at intervals
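The warmup and decay shapes above can be expressed as a multiplier on the base learning rate. The sketch below combines linear warmup with either linear or cosine decay to zero. It mirrors the shape of `transformers`' `get_linear_schedule_with_warmup` and `get_cosine_schedule_with_warmup` (with its default single half-cycle), but it is a standalone illustration rather than that library's code; step decay is omitted.

```python
import math


def lr_multiplier(step: int, warmup_steps: int, total_steps: int,
                  decay: str = "cosine") -> float:
    """Learning-rate multiplier: linear warmup from 0, then linear or cosine decay to 0."""
    if warmup_steps > 0 and step < warmup_steps:
        return step / warmup_steps
    # Fraction of the post-warmup phase completed, clamped to [0, 1]
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(max(progress, 0.0), 1.0)
    if decay == "linear":
        return 1.0 - progress
    if decay == "cosine":
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    raise ValueError(f"unknown decay: {decay}")


base_lr = 5e-5
for step in (0, 5, 10, 55, 100):
    print(step, base_lr * lr_multiplier(step, warmup_steps=10, total_steps=100))
```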

Layer-wise (Discriminative) Learning Rates:
- Different rates for different parameter groups
- Higher rates for task-specific layers (e.g., the output head)
- Lower rates for pre-trained parameters, especially embeddings and early layers
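One common way to set such rates is layer-wise learning-rate decay: the top layer gets the base rate, and each layer below it gets that rate multiplied by a decay factor. The sketch assumes Hugging Face-style parameter names such as `model.layers.3.mlp.up_proj.weight` and a decay of 0.9; both are illustrative assumptions. The resulting rates would then be grouped into optimizer parameter groups, similar to the fixed multipliers used by `create_custom_optimizer` in the code example further down this page.

```python
import re
from typing import Dict, List


def layerwise_lr(param_names: List[str], base_lr: float, num_layers: int,
                 decay: float = 0.9) -> Dict[str, float]:
    """Map parameter names to learning rates with layer-wise decay.

    Layer i (0 = closest to the input) gets base_lr * decay ** (num_layers - 1 - i);
    embeddings get one extra decay step; other parameters (final norm,
    lm_head) keep base_lr. The name patterns are an assumption about the model.
    """
    rates = {}
    for name in param_names:
        match = re.search(r"\.layers\.(\d+)\.", name)
        if match:
            depth_from_top = num_layers - 1 - int(match.group(1))
            rates[name] = base_lr * decay ** depth_from_top
        elif "embed" in name:
            rates[name] = base_lr * decay ** num_layers
        else:
            rates[name] = base_lr
    return rates


names = [
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.1.self_attn.q_proj.weight",
    "lm_head.weight",
]
for name, lr in layerwise_lr(names, base_lr=1e-4, num_layers=2).items():
    print(f"{name}: {lr:.2e}")
```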

3. Regularization Techniques:

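Dropout:
- Apply to attention layers and feedforward networks
- Helps prevent overfitting on small datasets
- Typical values: 0.1-0.3 in classic setups; many recent decoder-only LLMs are pre-trained with dropout disabled, and fine-tuning them commonly uses 0.1 or less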

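Weight Decay:
- Shrinks weights toward zero to discourage large values
- With AdamW, applied as decoupled weight decay rather than as an L2 penalty in the loss (the two differ for adaptive optimizers such as Adam)
- Helps with generalization
- Typical values: 0.01-0.1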
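The difference between an L2 penalty and decoupled weight decay is easiest to see on a single Adam step. Folding `wd * w` into the gradient means it is divided by the same adaptive denominator as the task gradient, so for a weight with a large gradient the penalty is almost normalized away; AdamW instead shrinks the weight by `lr * wd * w` directly. The sketch below starts from zero moments for one scalar weight; the function name and numbers are illustrative.

```python
import math


def adam_first_step(w: float, grad: float, lr: float, wd: float, mode: str,
                    beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8) -> float:
    """First Adam update of a single scalar weight, starting from zero moments.

    mode="none":  no regularization
    mode="l2":    L2 penalty folded into the gradient (Adam with weight_decay)
    mode="adamw": decoupled weight decay, applied to the weight itself
    """
    if mode == "l2":
        grad = grad + wd * w           # decay becomes part of the gradient...
    elif mode == "adamw":
        w = w * (1 - lr * wd)          # ...or shrinks the weight directly
    elif mode != "none":
        raise ValueError(f"unknown mode: {mode}")

    m_hat = ((1 - beta1) * grad) / (1 - beta1)          # bias-corrected first moment
    v_hat = ((1 - beta2) * grad * grad) / (1 - beta2)   # bias-corrected second moment
    return w - lr * m_hat / (math.sqrt(v_hat) + eps)    # adaptive normalization


for mode in ("none", "l2", "adamw"):
    print(mode, adam_first_step(w=1.0, grad=1.0, lr=0.1, wd=0.01, mode=mode))
```

With `w=1.0, grad=1.0, lr=0.1, wd=0.01`, the L2 variant lands essentially where the unregularized update does, while the AdamW variant ends about `lr * wd * w = 0.001` lower.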

Label Smoothing:
- Softens hard targets
- Improves calibration and generalization
- Typical epsilon: 0.1
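Label smoothing replaces the one-hot target with a mixture of the one-hot vector and a uniform distribution. For smoothing `epsilon`, the loss becomes `(1 - epsilon)` times the usual negative log-likelihood plus `epsilon` times the average negative log-probability over all classes, which is the formulation used by the `label_smoothing` argument of PyTorch's `torch.nn.functional.cross_entropy`. A pure-Python sketch for a single prediction, with illustrative logits:

```python
import math
from typing import List


def smoothed_cross_entropy(logits: List[float], target: int, epsilon: float = 0.1) -> float:
    """Cross-entropy against (1 - epsilon) * one_hot(target) + epsilon * uniform."""
    max_logit = max(logits)  # subtract the max for numerical stability
    log_z = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    log_probs = [x - log_z for x in logits]
    nll = -log_probs[target]                     # usual cross-entropy term
    uniform = -sum(log_probs) / len(log_probs)   # cross-entropy against the uniform distribution
    return (1 - epsilon) * nll + epsilon * uniform


logits = [4.0, 1.0, 0.5, 0.1]
print(smoothed_cross_entropy(logits, target=0, epsilon=0.0))  # plain cross-entropy
print(smoothed_cross_entropy(logits, target=0, epsilon=0.1))  # higher: confident predictions are penalized
```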

4. Data Efficiency Techniques:

Curriculum Learning:
- Start with easier examples
- Gradually increase difficulty
- Can improve convergence and final performance
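A curriculum needs a difficulty score and a pacing rule. The sketch below assumes length in words as the difficulty proxy and a linear pacing function that starts with the easiest 20% of examples and grows to the full dataset by the end of training; both choices are illustrative assumptions, and batches would then be sampled from the returned pool.

```python
from typing import List, Sequence


def curriculum_batch_pool(examples: Sequence[str], step: int, total_steps: int,
                          start_fraction: float = 0.2) -> List[str]:
    """Return the examples a sampler may draw from at a given training step.

    Difficulty proxy (assumption): length in whitespace-separated words.
    Pacing (assumption): grows linearly from start_fraction of the data to all of it.
    """
    ordered = sorted(examples, key=lambda text: len(text.split()))  # easy -> hard
    fraction = min(1.0, start_fraction + (1.0 - start_fraction) * step / max(1, total_steps))
    cutoff = max(1, round(fraction * len(ordered)))
    return ordered[:cutoff]


examples = [" ".join(["word"] * n) for n in range(10, 0, -1)]  # 10 examples, 10..1 words
for step in (0, 50, 100):
    print(step, len(curriculum_batch_pool(examples, step, total_steps=100)))
```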

Active Learning:
- Iteratively select most informative examples
- Maximizes learning from limited data
- Particularly useful for domain-specific applications
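A common selection rule is uncertainty sampling: send the examples the current model is least sure about for labeling. The sketch below scores each unlabeled item by the entropy of a predicted distribution over answer options; the item IDs and probabilities are hypothetical. For free-form generation a proxy such as mean per-token entropy would be needed, which is an assumption left out of this sketch.

```python
import math
from typing import Dict, List


def entropy(probs: List[float]) -> float:
    """Shannon entropy (in nats) of a predicted distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def select_most_uncertain(predictions: Dict[str, List[float]], budget: int) -> List[str]:
    """Uncertainty sampling: pick the `budget` unlabeled items with the highest entropy."""
    ranked = sorted(predictions, key=lambda item_id: entropy(predictions[item_id]), reverse=True)
    return ranked[:budget]


# Hypothetical model outputs over 3 answer options for unlabeled items
predictions = {
    "q1": [0.98, 0.01, 0.01],  # confident
    "q2": [0.34, 0.33, 0.33],  # very uncertain
    "q3": [0.60, 0.30, 0.10],
}
print(select_most_uncertain(predictions, budget=2))  # ['q2', 'q3']
```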

Few-Shot Learning:
- Learn from very few examples per task
- Useful when data is scarce
- Can be combined with meta-learning approaches

5. Multi-Task and Transfer Learning:

Multi-Task Fine-Tuning:
- Train on multiple related tasks simultaneously
- Shared representations benefit all tasks
- Requires careful task balancing
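One common way to balance tasks is temperature-scaled mixing, where task `i` with `n_i` examples is sampled with probability proportional to `n_i ** (1 / T)`. With `T = 1` sampling follows dataset size; larger `T` flattens the mixture toward uniform so that small tasks are not drowned out. The task names and sizes below are hypothetical.

```python
from typing import Dict


def task_sampling_probs(task_sizes: Dict[str, int], temperature: float = 1.0) -> Dict[str, float]:
    """Temperature-scaled mixing: p_i proportional to n_i ** (1 / T)."""
    weights = {task: size ** (1.0 / temperature) for task, size in task_sizes.items()}
    total = sum(weights.values())
    return {task: w / total for task, w in weights.items()}


sizes = {"summarization": 90_000, "qa": 9_000, "classification": 1_000}  # hypothetical
print(task_sampling_probs(sizes, temperature=1.0))
print(task_sampling_probs(sizes, temperature=3.0))
```

With these sizes, `T = 1` gives 0.9 / 0.09 / 0.01, while `T = 3` raises the smallest task to roughly 13% of samples.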

Sequential Fine-Tuning:
- Fine-tune on related tasks first
- Then fine-tune on target task
- Can improve performance on low-resource tasks

Domain Adaptation:
- Gradual adaptation from source to target domain
- Useful when target domain data is limited
- Can use domain adversarial training

6. Evaluation and Monitoring:

Perplexity Tracking:
- Monitor language modeling performance
- Lower perplexity generally indicates better fit
- Watch for overfitting: validation perplexity rising while training perplexity keeps falling
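Perplexity is the exponential of the average per-token negative log-likelihood. When aggregating over batches, weight each batch's mean loss by the number of tokens it predicted (for a causal LM, a sequence of `L` tokens yields `L - 1` predictions) instead of averaging per-batch perplexities. A minimal sketch, with illustrative inputs:

```python
import math
from typing import List, Tuple


def corpus_perplexity(batches: List[Tuple[float, int]]) -> float:
    """Perplexity from (mean_token_loss, num_predicted_tokens) pairs.

    Weights each batch's mean loss by its token count, so the result equals
    exp(total negative log-likelihood / total predicted tokens).
    """
    total_nll = sum(mean_loss * n_tokens for mean_loss, n_tokens in batches)
    total_tokens = sum(n_tokens for _, n_tokens in batches)
    return math.exp(total_nll / total_tokens)


# A model that is uniform over a 1,000-token vocabulary has loss ln(1000) per token
uniform_loss = math.log(1000)
print(corpus_perplexity([(uniform_loss, 10), (uniform_loss, 500)]))  # about 1000
```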

Task-Specific Metrics:
- ROUGE for summarization
- BLEU for translation
- Exact match for QA
- Custom metrics for domain tasks
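Exact match is usually computed after light normalization, so that trivial differences in case, punctuation or articles do not count as errors. The sketch below follows the style of normalization used by the SQuAD evaluation script (lowercase, strip punctuation, drop English articles, collapse whitespace); other benchmarks normalize differently, so treat it as an example rather than a standard.

```python
import re
import string

_PUNCTUATION = set(string.punctuation)


def normalize_answer(text: str) -> str:
    """Lowercase, strip punctuation, drop English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in _PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, reference: str) -> bool:
    """True if prediction and reference are identical after normalization."""
    return normalize_answer(prediction) == normalize_answer(reference)


print(exact_match("The Eiffel Tower.", "eiffel tower"))  # True
print(exact_match("Paris, France", "Paris"))             # False
```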

Validation Strategies:
- Hold-out validation set
- Cross-validation for small datasets
- Temporal splits for time-sensitive data
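The last two strategies come down to index bookkeeping: k-fold splitting (every example is used for validation exactly once) and a temporal split (validate only on data from after a cutoff). Both functions are illustrative and assume the examples are already in a list; the temporal version assumes one timestamp per example.

```python
from typing import List, Tuple


def k_fold_indices(n: int, k: int) -> List[Tuple[List[int], List[int]]]:
    """Split range(n) into k contiguous folds and return (train, val) index lists.

    Shuffle the dataset beforehand if its order is not already random.
    """
    fold_sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    splits, start = [], 0
    for size in fold_sizes:
        val = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n))
        splits.append((train, val))
        start += size
    return splits


def temporal_split(timestamps: List[int], cutoff: int) -> Tuple[List[int], List[int]]:
    """Train on items before `cutoff`; validate on items at or after it."""
    train = [i for i, t in enumerate(timestamps) if t < cutoff]
    val = [i for i, t in enumerate(timestamps) if t >= cutoff]
    return train, val


for train, val in k_fold_indices(10, 3):
    print(len(train), len(val))
print(temporal_split([2020, 2021, 2022, 2023], cutoff=2022))
```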

7. Hyperparameter Optimization:

Grid Search:
- Systematic exploration of hyperparameter space
- Good for a small number of hyperparameters
- Computationally expensive: the number of trials grows exponentially with the number of hyperparameters

Random Search:
- Random sampling of hyperparameter combinations
- Often more efficient than grid search
- Good for high-dimensional spaces
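A minimal random search fits in a loop. In the sketch below, `fake_validation_loss` is a hypothetical stand-in for "fine-tune with this configuration and return the validation loss", and the search ranges are illustrative. The learning rate is sampled on a log scale, which is usually more effective than sampling it uniformly. Libraries such as Optuna, Hyperopt or Ray Tune add smarter samplers, pruning and parallelism on top of this pattern.

```python
import math
import random
from typing import Callable, Dict, Optional


def random_search(objective: Callable[[Dict[str, float]], float], n_trials: int,
                  seed: int = 0) -> Optional[Dict[str, float]]:
    """Sample hyperparameter configurations at random; keep the lowest-scoring one."""
    rng = random.Random(seed)  # seeded for reproducibility
    best_config, best_score = None, math.inf
    for _ in range(n_trials):
        config = {
            "learning_rate": 10 ** rng.uniform(-6, -3),  # log-uniform in [1e-6, 1e-3]
            "weight_decay": rng.choice([0.0, 0.01, 0.1]),
            "warmup_ratio": rng.uniform(0.0, 0.1),
        }
        score = objective(config)
        if score < best_score:
            best_config, best_score = config, score
    return best_config


def fake_validation_loss(config: Dict[str, float]) -> float:
    """Hypothetical objective with its optimum near learning_rate = 10 ** -4.5."""
    return (math.log10(config["learning_rate"]) + 4.5) ** 2 + config["weight_decay"]


print(random_search(fake_validation_loss, n_trials=50))
```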

Bayesian Optimization:
- Uses previous results to guide search
- More sample-efficient than random/grid search
- Tools: Optuna, Hyperopt, Ray Tune

8. Model Distillation:

Knowledge Distillation:
- Train smaller student model to mimic larger teacher
- Maintains much of the performance with less compute
- Useful for deployment constraints
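The standard distillation objective compares temperature-softened output distributions. The sketch below computes `KL(teacher || student)` on `softmax(logits / T)` and multiplies it by `T ** 2`, which keeps gradient magnitudes comparable across temperatures; in practice this term is usually combined with ordinary cross-entropy on the ground-truth labels. The logit values and temperature are illustrative.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax of logits / temperature."""
    scaled = [x / temperature for x in logits]
    max_val = max(scaled)
    exps = [math.exp(x - max_val) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits: List[float], teacher_logits: List[float],
                      temperature: float = 2.0) -> float:
    """Soft-target loss: T**2 * KL(teacher || student) on softened distributions."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(p * math.log(p / q) for p, q in zip(p_teacher, p_student) if p > 0)
    return temperature ** 2 * kl


teacher = [2.0, 0.0, -1.0]  # hypothetical teacher logits for one token position
student = [1.0, 0.5, -0.5]  # hypothetical student logits
print(distillation_loss(student, teacher, temperature=2.0))
```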

Progressive Distillation:
- Gradually reduce model size through multiple stages
- Can achieve better size/performance trade-offs
- Particularly effective for transformer models
Code Example
# Advanced SFT techniques implementation
import torch
import torch.nn as nn
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, TrainingArguments, 
    Trainer, get_linear_schedule_with_warmup
)
from torch.optim import AdamW
import numpy as np
from typing import Dict, List, Optional
import wandb
from torch.cuda.amp import GradScaler, autocast

class AdvancedSFTTrainer:
    def __init__(self, model_name: str, use_deepspeed: bool = False):
        self.model_name = model_name
        self.use_deepspeed = use_deepspeed
        self.model = None
        self.tokenizer = None
        self.scaler = GradScaler() if torch.cuda.is_available() else None
        
    def setup_model_with_optimizations(self, 
                                     gradient_checkpointing: bool = True,
                                     mixed_precision: bool = True):
        """Setup model with memory and training optimizations."""
        
        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        
        # Load model with optimizations
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            # Loads the weights themselves in bf16 (pure bf16 training, no fp32 master copy)
            torch_dtype=torch.bfloat16 if mixed_precision else torch.float32,
            device_map="auto" if not self.use_deepspeed else None,
            use_cache=False,  # Disable for training
        )
        
        if gradient_checkpointing:
            self.model.gradient_checkpointing_enable()
        
        print(f"Model loaded with optimizations:")
        print(f"- Gradient checkpointing: {gradient_checkpointing}")
        print(f"- Mixed precision: {mixed_precision}")
        print(f"- DeepSpeed: {self.use_deepspeed}")
    
    def create_custom_optimizer(self, 
                               learning_rate: float = 5e-5,
                               weight_decay: float = 0.01,
                               use_layer_wise_lr: bool = False) -> AdamW:
        """Create optimized AdamW optimizer with optional layer-wise learning rates."""
        
        if use_layer_wise_lr:
            # Different learning rates for different layers
            parameter_groups = []
            
            # Embedding layers - lower LR
            embedding_params = []
            for name, param in self.model.named_parameters():
                if 'embed' in name or 'wte' in name or 'wpe' in name:
                    embedding_params.append(param)
            
            if embedding_params:
                parameter_groups.append({
                    'params': embedding_params,
                    'lr': learning_rate * 0.1,  # 10x lower
                    'weight_decay': weight_decay
                })
            
            # Output layers - higher LR
            output_params = []
            for name, param in self.model.named_parameters():
                if 'lm_head' in name or 'output' in name:
                    output_params.append(param)
            
            if output_params:
                parameter_groups.append({
                    'params': output_params,
                    'lr': learning_rate * 2.0,  # 2x higher
                    'weight_decay': weight_decay
                })
            
            # All other parameters - standard LR
            other_params = []
            embedding_names = {id(p) for p in embedding_params}
            output_names = {id(p) for p in output_params}
            
            for param in self.model.parameters():
                if id(param) not in embedding_names and id(param) not in output_names:
                    other_params.append(param)
            
            if other_params:
                parameter_groups.append({
                    'params': other_params,
                    'lr': learning_rate,
                    'weight_decay': weight_decay
                })
            
            optimizer = AdamW(parameter_groups, betas=(0.9, 0.999), eps=1e-8)
            
        else:
            # Standard optimizer
            optimizer = AdamW(
                self.model.parameters(),
                lr=learning_rate,
                weight_decay=weight_decay,
                betas=(0.9, 0.999),
                eps=1e-8
            )
        
        return optimizer
    
    def create_curriculum_dataset(self, dataset, difficulty_metric: str = 'length'):
        """Create curriculum learning dataset ordered by difficulty."""
        
        def calculate_difficulty(example):
            if difficulty_metric == 'length':
                return len(example['input_ids'])
            elif difficulty_metric == 'vocab_complexity':
                # Simple vocabulary complexity metric
                unique_tokens = len(set(example['input_ids']))
                total_tokens = max(len(example['input_ids']), 1)  # avoid division by zero
                return unique_tokens / total_tokens
            else:
                return 0.5  # Default neutral difficulty
        
        # Calculate difficulty scores
        difficulties = [calculate_difficulty(example) for example in dataset]
        
        # Sort by difficulty (easy to hard)
        sorted_indices = sorted(range(len(dataset)), key=lambda i: difficulties[i])
        
        # Create curriculum dataset
        curriculum_dataset = dataset.select(sorted_indices)
        
        print(f"Created curriculum dataset with {len(curriculum_dataset)} examples")
        return curriculum_dataset
    
    def train_with_advanced_techniques(self,
                                     train_dataset,
                                     eval_dataset=None,
                                     output_dir="./advanced_sft",
                                     num_epochs=3,
                                     batch_size=4,
                                     learning_rate=5e-5,
                                     use_curriculum=True,
                                     use_label_smoothing=True,
                                     label_smoothing_factor=0.1,
                                     use_cosine_schedule=True,
                                     warmup_ratio=0.03):
        """Train with advanced techniques."""
        
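        from transformers import get_cosine_schedule_with_warmup
        
        # Setup curriculum learning (easy -> hard ordering).
        # Note: Trainer shuffles the training set by default, which undoes this
        # ordering unless a sequential sampler is used.
        if use_curriculum:
            train_dataset = self.create_curriculum_dataset(train_dataset)
        
        # Calculate optimizer steps (single device, no gradient accumulation);
        # round up so that a final partial batch is counted
        steps_per_epoch = (len(train_dataset) + batch_size - 1) // batch_size
        total_steps = steps_per_epoch * num_epochs
        warmup_steps = int(total_steps * warmup_ratio)
        
        # Create custom optimizer
        optimizer = self.create_custom_optimizer(
            learning_rate=learning_rate,
            use_layer_wise_lr=True
        )
        
        # Create learning rate scheduler: warmup, then cosine or linear decay
        if use_cosine_schedule:
            scheduler = get_cosine_schedule_with_warmup(
                optimizer,
                num_warmup_steps=warmup_steps,
                num_training_steps=total_steps
            )
        else:
            scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=warmup_steps,
                num_training_steps=total_steps
            )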
        
        from torch.utils.data import SequentialSampler
        from transformers import DataCollatorForLanguageModeling
        
        # Custom loss function with label smoothing for causal language models
        class LabelSmoothingLoss(nn.Module):
            def __init__(self, smoothing=0.1, vocab_size=None):
                super().__init__()
                self.smoothing = smoothing
                self.vocab_size = vocab_size  # tokenizer size, used only for a sanity check
                
            def forward(self, logits, labels):
                if self.vocab_size is not None and logits.size(-1) < self.vocab_size:
                    raise ValueError("Tokenizer has more tokens than the model's output layer; "
                                     "call model.resize_token_embeddings(len(tokenizer))")
                
                # Causal LM: the logits at position t predict the token at t + 1
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                
                # Flatten to (num_tokens, vocab) and (num_tokens,); compute in fp32 for stability
                log_probs = torch.log_softmax(
                    shift_logits.view(-1, shift_logits.size(-1)).float(), dim=-1
                )
                target = shift_labels.view(-1)
                
                # -100 marks ignored positions (padding, prompt tokens): substitute a valid
                # index so gather() works, then zero those positions out with the mask
                mask = (target != -100).float()
                safe_target = target.clamp(min=0)
                
                # (1 - eps) * NLL of the true token + eps * average NLL over the vocabulary
                nll = -log_probs.gather(1, safe_target.unsqueeze(1)).squeeze(1)
                smooth = -log_probs.mean(dim=-1)
                loss = (1.0 - self.smoothing) * nll + self.smoothing * smooth
                
                return (loss * mask).sum() / mask.sum().clamp(min=1.0)
        
        # Custom trainer: Trainer itself already handles mixed precision (bf16=True),
        # gradient clipping (max_grad_norm), gradient accumulation and the
        # optimizer/scheduler steps, so only the loss and the sampling order are customized
        class AdvancedTrainer(Trainer):
            def __init__(self, *args, label_smoothing_loss=None, sequential_sampling=False, **kwargs):
                super().__init__(*args, **kwargs)
                self.label_smoothing_loss = label_smoothing_loss
                self.sequential_sampling = sequential_sampling
                
            def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
                # **kwargs absorbs extra arguments (e.g. num_items_in_batch) passed by newer releases
                labels = inputs.get("labels")
                outputs = model(**inputs)
                
                if self.label_smoothing_loss is not None and labels is not None:
                    loss = self.label_smoothing_loss(outputs.logits, labels)
                else:
                    loss = outputs.loss
                
                return (loss, outputs) if return_outputs else loss
            
            def _get_train_sampler(self, *args, **kwargs):
                # Preserve the curriculum order instead of shuffling.
                # _get_train_sampler is a private Trainer hook and may change between releases.
                if self.sequential_sampling:
                    return SequentialSampler(self.train_dataset)
                return super()._get_train_sampler(*args, **kwargs)
        
        # Training arguments
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            gradient_accumulation_steps=1,
            learning_rate=learning_rate,
            weight_decay=0.01,
            max_grad_norm=1.0,
            warmup_steps=warmup_steps,
            logging_steps=10,
            save_steps=500,
            eval_steps=500 if eval_dataset else None,
            # Named `evaluation_strategy` in older transformers releases
            eval_strategy="steps" if eval_dataset else "no",
            save_strategy="steps",
            load_best_model_at_end=eval_dataset is not None,
            metric_for_best_model="eval_loss" if eval_dataset else None,
            greater_is_better=False,
            report_to="wandb",
            run_name="advanced_sft",
            bf16=True,
            dataloader_pin_memory=False,
            gradient_checkpointing=True,
            remove_unused_columns=False,
        )
        
        # Create loss function
        loss_fn = LabelSmoothingLoss(
            smoothing=label_smoothing_factor,
            vocab_size=len(self.tokenizer)
        ) if use_label_smoothing else None
        
        # Pads each batch and builds causal-LM labels (padding positions become -100)
        data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
        
        # Create trainer
        trainer = AdvancedTrainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
            optimizers=(optimizer, scheduler),
            label_smoothing_loss=loss_fn,
            sequential_sampling=use_curriculum,
        )
        
        # Initialize wandb
        wandb.init(
            project="advanced-sft",
            config={
                "model_name": self.model_name,
                "num_epochs": num_epochs,
                "batch_size": batch_size,
                "learning_rate": learning_rate,
                "use_curriculum": use_curriculum,
                "use_label_smoothing": use_label_smoothing,
                "label_smoothing_factor": label_smoothing_factor,
            }
        )
        
        # Training
        print("Starting advanced SFT training...")
        trainer.train()
        
        # Save model
        trainer.save_model()
        self.tokenizer.save_pretrained(output_dir)
        
        wandb.finish()
        print(f"Advanced SFT completed! Model saved to {output_dir}")
        
        return trainer
    
    def evaluate_with_multiple_metrics(self, test_dataset, metrics=('perplexity', 'bleu')):
        """Evaluate model with multiple metrics."""
        
        import sacrebleu
        
        self.model.eval()
        results = {}
        
        if 'perplexity' in metrics:
            # Calculate perplexity
            total_loss = 0
            total_tokens = 0
            
            for example in test_dataset:
                inputs = {k: torch.tensor(v).unsqueeze(0).to(self.model.device) 
                         for k, v in example.items() if k in ['input_ids', 'attention_mask']}
                
                with torch.no_grad():
                    outputs = self.model(**inputs, labels=inputs['input_ids'])
                    loss = outputs.loss.item()
                    # The loss is averaged over the L - 1 shifted next-token predictions
                    num_tokens = inputs['input_ids'].size(1) - 1
                    
                    total_loss += loss * num_tokens
                    total_tokens += num_tokens
            
            perplexity = torch.exp(torch.tensor(total_loss / total_tokens))
            results['perplexity'] = perplexity.item()
        
        if 'bleu' in metrics:
            # Calculate BLEU score (simplified example)
            references = []
            predictions = []
            
            # Sample for efficiency; select() yields row dicts, whereas slicing a
            # Hugging Face Dataset (test_dataset[:100]) returns a dict of columns
            sample = test_dataset.select(range(min(100, len(test_dataset))))
            
            for example in sample:
                # This is a simplified example - adapt based on your data format
                if len(example['input_ids']) <= 50:
                    continue  # Too short to leave any target tokens after the prompt
                input_ids = example['input_ids'][:50]  # First 50 tokens as input
                target_ids = example['input_ids'][50:]  # Rest as target
                
                inputs = torch.tensor(input_ids).unsqueeze(0).to(self.model.device)
                
                with torch.no_grad():
                    outputs = self.model.generate(
                        inputs,
                        max_new_tokens=len(target_ids),
                        do_sample=False,
                        pad_token_id=self.tokenizer.eos_token_id
                    )
                
                pred_text = self.tokenizer.decode(outputs[0][len(input_ids):], skip_special_tokens=True)
                ref_text = self.tokenizer.decode(target_ids, skip_special_tokens=True)
                
                predictions.append(pred_text)
                references.append(ref_text)
            
            if predictions:
                # sacrebleu expects a list of reference streams (one list per reference
                # set, aligned with predictions), hence the extra brackets
                bleu_score = sacrebleu.corpus_bleu(predictions, [references])
                results['bleu'] = bleu_score.score
        
        return results

# Example usage
if __name__ == "__main__":
    # Initialize advanced trainer
    trainer = AdvancedSFTTrainer("mistralai/Mistral-7B-v0.1")
    
    # Setup model with optimizations
    trainer.setup_model_with_optimizations(
        gradient_checkpointing=True,
        mixed_precision=True
    )
    
    # Prepare sample dataset (replace with your actual data)
    from datasets import Dataset
    
    sample_data = [
        {"input_ids": [1, 2, 3, 4, 5] * 100, "attention_mask": [1] * 500},
        {"input_ids": [6, 7, 8, 9, 10] * 80, "attention_mask": [1] * 400},
        # Add more examples...
    ]
    
    train_dataset = Dataset.from_list(sample_data)
    eval_dataset = Dataset.from_list(sample_data[:2])  # Small eval set
    
    # Train with advanced techniques
    trainer.train_with_advanced_techniques(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        output_dir="./advanced_sft_model",
        num_epochs=2,
        batch_size=2,
        learning_rate=5e-5,
        use_curriculum=True,
        use_label_smoothing=True,
        label_smoothing_factor=0.1,
        use_cosine_schedule=True,
    )
    
    # Evaluate with multiple metrics
    results = trainer.evaluate_with_multiple_metrics(eval_dataset)
    print("Evaluation results:", results)
