Master Reinforcement Learning from Human Feedback (RLHF) to align language models with human preferences, build safer AI systems, and improve response quality by integrating human feedback. The standard pipeline has three stages: supervised fine-tuning (SFT) on human demonstrations, reward model training on preference comparisons, and reinforcement learning (typically PPO) against that reward model.
Align language models with human preferences and values through iterative feedback and reinforcement learning.
Build safer AI systems by incorporating human judgment to reduce harmful or inappropriate outputs.
Improve response quality, relevance, and engagement through preference-based optimization.
Continuously refine model behavior through ongoing human feedback collection and model updates.
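The training signal at the heart of this pipeline is a pairwise preference loss: the reward model is trained so that responses humans chose score higher than responses they rejected, and the policy is then optimized against that reward while staying close to a reference model. As a minimal, self-contained sketch of that preference loss (the function name is illustrative, not part of any library):

import torch
import torch.nn.functional as F

def pairwise_preference_loss(chosen_rewards: torch.Tensor,
                             rejected_rewards: torch.Tensor) -> torch.Tensor:
    """Bradley-Terry style objective: -log sigmoid(r_chosen - r_rejected)."""
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()

# Toy scores for three preference pairs; the loss shrinks as chosen scores exceed rejected ones
chosen = torch.tensor([1.2, 0.7, 2.0])
rejected = torch.tensor([0.3, 0.9, 0.5])
print(pairwise_preference_loss(chosen, rejected))
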
# RLHF implementation overview using TRL (Transformer Reinforcement Learning)
import torch
from transformers import (
AutoTokenizer, AutoModelForCausalLM,
AutoModelForSequenceClassification
)
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from datasets import Dataset
import numpy as np
from typing import List, Dict
class RLHFPipeline:
def __init__(self, model_name: str, reward_model_name: str = None):
self.model_name = model_name
self.reward_model_name = reward_model_name
self.tokenizer = None
self.sft_model = None
self.reward_model = None
self.ppo_model = None
def stage1_supervised_fine_tuning(self, sft_dataset: Dataset):
"""Stage 1: Supervised Fine-Tuning on human demonstrations."""
print("Stage 1: Supervised Fine-Tuning")
print("=" * 40)
# Load tokenizer and model
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.sft_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.bfloat16,
device_map="auto"
)
# Prepare SFT dataset
def format_sft_example(example):
"""Format example for SFT training."""
prompt = f"Human: {example['prompt']}\n\nAssistant: "
response = example['chosen'] # Use human-preferred response
full_text = prompt + response + self.tokenizer.eos_token
return {"text": full_text}
formatted_dataset = sft_dataset.map(format_sft_example)
# SFT training (simplified - use full Trainer in practice)
print(f"SFT dataset size: {len(formatted_dataset)}")
print("SFT training would happen here...")
print("✓ Stage 1 completed: SFT model ready")
# Save SFT model
self.sft_model.save_pretrained("./sft_model")
self.tokenizer.save_pretrained("./sft_model")
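        # Note: in a full implementation, stage 1 would run an actual causal-LM fine-tuning
        # pass over the formatted "text" field (e.g. with transformers.Trainer or trl's
        # SFTTrainer) before saving the checkpoint that the later stages load from "./sft_model".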
def stage2_reward_model_training(self, preference_dataset: Dataset):
"""Stage 2: Train reward model on human preferences."""
print("\nStage 2: Reward Model Training")
print("=" * 40)
# Load model for reward modeling (typically smaller than main model)
reward_model_name = self.reward_model_name or self.model_name
self.reward_model = AutoModelForSequenceClassification.from_pretrained(
reward_model_name,
num_labels=1, # Single scalar reward
torch_dtype=torch.bfloat16,
device_map="auto"
)
# Prepare preference dataset
def format_preference_example(example):
"""Format preference comparison for reward model training."""
prompt = example['prompt']
chosen = example['chosen']
rejected = example['rejected']
# Create prompt + response pairs
chosen_text = f"Human: {prompt}\n\nAssistant: {chosen}"
rejected_text = f"Human: {prompt}\n\nAssistant: {rejected}"
return {
'chosen': chosen_text,
'rejected': rejected_text
}
formatted_preferences = preference_dataset.map(format_preference_example)
# Reward model training (simplified)
print(f"Preference dataset size: {len(formatted_preferences)}")
print("Reward model training would happen here...")
print("Training on preference pairs (chosen > rejected)")
print("✓ Stage 2 completed: Reward model ready")
# Save reward model
self.reward_model.save_pretrained("./reward_model")
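        # Note: a complete reward-model training loop (pairwise Bradley-Terry loss over
        # chosen/rejected pairs) is implemented in the RewardModelTrainer class later in this section.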
def stage3_ppo_training(self, prompts_dataset: Dataset, ppo_config: PPOConfig = None):
"""Stage 3: PPO training using reward model."""
print("\nStage 3: PPO Training")
print("=" * 40)
# Default PPO configuration
if ppo_config is None:
ppo_config = PPOConfig(
model_name="./sft_model",
learning_rate=1.41e-5,
batch_size=16,
mini_batch_size=4,
gradient_accumulation_steps=1,
optimize_cuda_cache=True,
early_stopping=False,
target_kl=0.1,
ppo_epochs=4,
seed=0,
init_kl_coef=0.2,
adap_kl_ctrl=True,
)
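            # target_kl, init_kl_coef and adap_kl_ctrl control how strongly the policy is kept
            # close to the SFT starting point: TRL adds a KL penalty against the reference model
            # and, with adap_kl_ctrl=True, adjusts its coefficient toward target_kl during training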
# Load model with value head for PPO
self.ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained(
"./sft_model",
torch_dtype=torch.bfloat16,
device_map="auto"
)
# Load reward model
reward_model = AutoModelForSequenceClassification.from_pretrained(
"./reward_model",
torch_dtype=torch.bfloat16,
device_map="auto"
)
# Create PPO trainer
ppo_trainer = PPOTrainer(
config=ppo_config,
model=self.ppo_model,
ref_model=None, # Will use model copy as reference
tokenizer=self.tokenizer,
)
# Prepare prompts for PPO training
prompts = [f"Human: {prompt}\n\nAssistant: " for prompt in prompts_dataset['prompt']]
print(f"PPO training on {len(prompts)} prompts")
# PPO training loop (simplified)
for epoch in range(3): # Limited epochs for demonstration
print(f"\nPPO Epoch {epoch + 1}")
for batch_idx in range(0, len(prompts), ppo_config.batch_size):
batch_prompts = prompts[batch_idx:batch_idx + ppo_config.batch_size]
                # Tokenize prompts and move them to the policy model's device
                device = self.ppo_model.pretrained_model.device
                prompt_tensors = [
                    self.tokenizer.encode(prompt, return_tensors="pt")[0].to(device)
                    for prompt in batch_prompts
                ]
                # Generate responses from the current policy, keeping only the newly
                # generated tokens (PPOTrainer.step expects queries and responses separately)
                response_tensors = []
                for prompt_tensor in prompt_tensors:
                    response = self.ppo_model.generate(
                        prompt_tensor.unsqueeze(0),
                        max_new_tokens=50,
                        do_sample=True,
                        temperature=0.7,
                        pad_token_id=self.tokenizer.eos_token_id
                    )
                    response_tensors.append(response[0][prompt_tensor.shape[0]:])
                # Calculate rewards using reward model
                rewards = []
                for prompt_tensor, response_tensor in zip(prompt_tensors, response_tensors):
                    # Score the full prompt + response text
                    full_text = self.tokenizer.decode(
                        torch.cat([prompt_tensor, response_tensor]), skip_special_tokens=True
                    )
# Get reward score
inputs = self.tokenizer(full_text, return_tensors="pt", truncation=True, max_length=512)
inputs = {k: v.to(reward_model.device) for k, v in inputs.items()}
with torch.no_grad():
reward_score = reward_model(**inputs).logits[0, 0].item()
rewards.append(reward_score)
# Convert to tensors
rewards = [torch.tensor(r) for r in rewards]
# PPO training step
stats = ppo_trainer.step(prompt_tensors, response_tensors, rewards)
if batch_idx % 4 == 0: # Log every few batches
print(f" Batch {batch_idx//ppo_config.batch_size + 1}: "
f"Mean reward: {np.mean([r.item() for r in rewards]):.3f}")
print("✓ Stage 3 completed: RLHF training finished")
# Save final model
self.ppo_model.save_pretrained("./rlhf_model")
return ppo_trainer
def evaluate_rlhf_model(self, test_prompts: List[str]):
"""Evaluate the RLHF-trained model."""
print("\nEvaluating RLHF Model")
print("=" * 40)
if self.ppo_model is None:
# Load the trained model
self.ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained("./rlhf_model")
self.ppo_model.eval()
for i, prompt in enumerate(test_prompts, 1):
formatted_prompt = f"Human: {prompt}\n\nAssistant: "
inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
            inputs = {k: v.to(self.ppo_model.pretrained_model.device) for k, v in inputs.items()}
# Generate response
with torch.no_grad():
outputs = self.ppo_model.generate(
**inputs,
max_new_tokens=150,
do_sample=True,
temperature=0.7,
top_p=0.9,
pad_token_id=self.tokenizer.eos_token_id
)
response = self.tokenizer.decode(
outputs[0][inputs['input_ids'].shape[1]:],
skip_special_tokens=True
)
print(f"\nTest {i}:")
print(f"Human: {prompt}")
print(f"Assistant: {response}")
print("-" * 40)
# Sample data preparation functions
def create_sft_dataset():
"""Create sample SFT dataset with human demonstrations."""
sft_examples = [
{
"prompt": "Explain quantum computing in simple terms.",
"chosen": "Quantum computing uses quantum mechanical phenomena like superposition and entanglement to process information in ways that classical computers cannot. Unlike classical bits that are either 0 or 1, quantum bits (qubits) can exist in multiple states simultaneously, potentially allowing quantum computers to solve certain problems exponentially faster than classical computers."
},
{
"prompt": "How can I improve my sleep quality?",
"chosen": "Here are some evidence-based strategies to improve sleep quality: 1) Maintain a consistent sleep schedule, 2) Create a relaxing bedtime routine, 3) Ensure your bedroom is cool, dark, and quiet, 4) Avoid caffeine and screens before bedtime, 5) Get regular exercise during the day, and 6) Consider relaxation techniques like meditation or deep breathing."
},
{
"prompt": "Write a Python function to reverse a string.",
"chosen": "Here's a simple Python function to reverse a string:\n\ndef reverse_string(s):\n return s[::-1]\n\n# Example usage:\noriginal = 'hello'\nreversed_str = reverse_string(original)\nprint(reversed_str) # Output: 'olleh'\n\nThis uses Python's slice notation with a step of -1 to reverse the string efficiently."
}
]
return Dataset.from_list(sft_examples)
def create_preference_dataset():
"""Create sample preference dataset for reward model training."""
preference_examples = [
{
"prompt": "What's the capital of France?",
"chosen": "The capital of France is Paris. It's a beautiful city known for its art, culture, cuisine, and iconic landmarks like the Eiffel Tower and Louvre Museum.",
"rejected": "Paris is the capital. It's in France and has some buildings and stuff."
},
{
"prompt": "How do I bake a chocolate cake?",
"chosen": "To bake a chocolate cake, you'll need: flour, sugar, cocoa powder, eggs, butter, baking powder, and milk. Mix dry ingredients, cream butter and sugar, add eggs, then alternate adding dry ingredients and milk. Bake at 350°F for 25-30 minutes. Let me know if you'd like a detailed recipe!",
"rejected": "Just mix some chocolate stuff together and put it in the oven until it looks done. Should work fine."
},
{
"prompt": "Is it safe to eat raw eggs?",
"chosen": "Eating raw eggs carries some risk of Salmonella infection, though the risk is relatively low (about 1 in 20,000 eggs). Pasteurized eggs are safer for raw consumption. If you're pregnant, elderly, or immunocompromised, it's best to avoid raw eggs. For recipes requiring raw eggs, consider pasteurized alternatives.",
"rejected": "Raw eggs are totally fine to eat, there's no risk at all. Eat as many as you want!"
}
]
return Dataset.from_list(preference_examples)
def create_prompts_dataset():
"""Create sample prompts for PPO training."""
prompts = [
"Explain the importance of exercise.",
"What's the best way to learn a new language?",
"How does photosynthesis work?",
"Give me tips for public speaking.",
"What are the benefits of meditation?",
"How do I start investing in stocks?",
"Explain machine learning to a beginner.",
"What's the difference between weather and climate?",
]
return Dataset.from_dict({"prompt": prompts})
# Main execution example
if __name__ == "__main__":
# Initialize RLHF pipeline
rlhf = RLHFPipeline(
model_name="microsoft/DialoGPT-medium", # Use smaller model for demo
reward_model_name="microsoft/DialoGPT-small"
)
# Create sample datasets
sft_data = create_sft_dataset()
preference_data = create_preference_dataset()
prompts_data = create_prompts_dataset()
# Run RLHF pipeline
print("Starting RLHF Pipeline")
print("=" * 50)
# Stage 1: SFT
rlhf.stage1_supervised_fine_tuning(sft_data)
# Stage 2: Reward Model
rlhf.stage2_reward_model_training(preference_data)
# Stage 3: PPO Training
ppo_config = PPOConfig(
model_name="./sft_model",
learning_rate=1.41e-5,
batch_size=4, # Small batch for demo
mini_batch_size=2,
ppo_epochs=2, # Fewer epochs for demo
target_kl=0.1,
)
rlhf.stage3_ppo_training(prompts_data, ppo_config)
# Evaluate final model
test_prompts = [
"What's the best way to stay healthy?",
"Explain artificial intelligence.",
"How do I write a good email?"
]
rlhf.evaluate_rlhf_model(test_prompts)
print("\nRLHF Pipeline completed successfully!")# Comprehensive reward model implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
AutoTokenizer, AutoModelForSequenceClassification,
AutoConfig, Trainer, TrainingArguments
)
from datasets import Dataset
from sklearn.metrics import accuracy_score
import numpy as np
from typing import Dict, List, Tuple, Optional
class RewardModelTrainer:
def __init__(self, base_model_name: str, max_length: int = 512):
self.base_model_name = base_model_name
self.max_length = max_length
self.tokenizer = None
self.model = None
def setup_reward_model(self, dropout_rate: float = 0.1):
"""Setup reward model architecture."""
# Load tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
        # Load config and modify for reward modeling
        config = AutoConfig.from_pretrained(self.base_model_name)
        config.num_labels = 1  # Single scalar reward
        config.pad_token_id = self.tokenizer.pad_token_id  # Needed for padded batches on GPT-style models
        # Dropout attribute names differ by architecture (e.g. hidden_dropout_prob for BERT,
        # resid_pdrop for GPT-2); only set the ones the config actually defines
        for attr in ("hidden_dropout_prob", "attention_probs_dropout_prob", "resid_pdrop", "attn_pdrop"):
            if hasattr(config, attr):
                setattr(config, attr, dropout_rate)
        # Load model for sequence classification
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.base_model_name,
            config=config,
            torch_dtype=torch.bfloat16
        )
        # Replace the classification head with a small MLP that outputs a single reward score.
        # The head attribute is "classifier" on BERT/RoBERTa-style models and "score" on
        # GPT-2-style models such as DialoGPT.
        reward_head = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.Tanh(),
            nn.Dropout(dropout_rate),
            nn.Linear(config.hidden_size, 1)  # Single reward score
        ).to(dtype=self.model.dtype)
        head_attr = "classifier" if hasattr(self.model, "classifier") else "score"
        setattr(self.model, head_attr, reward_head)
print(f"Reward model setup completed:")
print(f"- Base model: {self.base_model_name}")
print(f"- Parameters: {self.model.num_parameters():,}")
print(f"- Dropout rate: {dropout_rate}")
def prepare_preference_dataset(self, preference_data: List[Dict]) -> Dataset:
"""Prepare preference comparison dataset."""
def tokenize_pair(example):
"""Tokenize chosen and rejected responses."""
prompt = example['prompt']
chosen = example['chosen']
rejected = example['rejected']
# Create full texts
chosen_text = f"{prompt}\n\n{chosen}"
rejected_text = f"{prompt}\n\n{rejected}"
# Tokenize both
chosen_tokens = self.tokenizer(
chosen_text,
truncation=True,
padding='max_length',
max_length=self.max_length,
return_tensors="pt"
)
rejected_tokens = self.tokenizer(
rejected_text,
truncation=True,
padding='max_length',
max_length=self.max_length,
return_tensors="pt"
)
return {
'chosen_input_ids': chosen_tokens['input_ids'].squeeze(),
'chosen_attention_mask': chosen_tokens['attention_mask'].squeeze(),
'rejected_input_ids': rejected_tokens['input_ids'].squeeze(),
'rejected_attention_mask': rejected_tokens['attention_mask'].squeeze(),
}
# Convert to dataset and tokenize
dataset = Dataset.from_list(preference_data)
tokenized_dataset = dataset.map(tokenize_pair, remove_columns=dataset.column_names)
print(f"Prepared preference dataset with {len(tokenized_dataset)} pairs")
return tokenized_dataset
    def create_pairwise_dataset(self, tokenized_dataset: Dataset) -> Dataset:
        """Create pairwise dataset for Bradley-Terry training.

        Each example keeps its chosen and rejected encodings together; the collator below
        stacks all chosen sequences first and all rejected sequences second within a batch,
        which is exactly the layout compute_loss expects. (Interleaving chosen/rejected as
        separate rows would break the pairing once the dataloader shuffles.)
        """
        return tokenized_dataset

    @staticmethod
    def pairwise_data_collator(features: List[Dict]) -> Dict[str, torch.Tensor]:
        """Collate preference pairs into one batch: chosen half first, rejected half second."""
        chosen_ids = torch.tensor([f['chosen_input_ids'] for f in features])
        chosen_mask = torch.tensor([f['chosen_attention_mask'] for f in features])
        rejected_ids = torch.tensor([f['rejected_input_ids'] for f in features])
        rejected_mask = torch.tensor([f['rejected_attention_mask'] for f in features])
        return {
            'input_ids': torch.cat([chosen_ids, rejected_ids], dim=0),
            'attention_mask': torch.cat([chosen_mask, rejected_mask], dim=0),
        }
def compute_pairwise_loss(self, chosen_rewards, rejected_rewards):
"""Compute Bradley-Terry pairwise ranking loss."""
# Bradley-Terry loss: -log(sigmoid(chosen - rejected))
diff = chosen_rewards - rejected_rewards
loss = -F.logsigmoid(diff).mean()
return loss
def train_reward_model(self,
train_dataset: Dataset,
eval_dataset: Dataset = None,
output_dir: str = "./reward_model",
num_epochs: int = 3,
batch_size: int = 8,
learning_rate: float = 2e-5,
warmup_ratio: float = 0.1):
"""Train the reward model on preference data."""
# Custom trainer for pairwise ranking
class RewardTrainer(Trainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.prediction_step_count = 0
            def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
                """Compute pairwise ranking loss (extra kwargs absorb newer Trainer arguments)."""
# Split batch into chosen and rejected
batch_size = inputs['input_ids'].size(0) // 2
chosen_inputs = {
'input_ids': inputs['input_ids'][:batch_size],
'attention_mask': inputs['attention_mask'][:batch_size]
}
rejected_inputs = {
'input_ids': inputs['input_ids'][batch_size:],
'attention_mask': inputs['attention_mask'][batch_size:]
}
# Get reward scores
chosen_outputs = model(**chosen_inputs)
rejected_outputs = model(**rejected_inputs)
chosen_rewards = chosen_outputs.logits.squeeze(-1)
rejected_rewards = rejected_outputs.logits.squeeze(-1)
# Compute pairwise loss
loss = self.compute_pairwise_loss(chosen_rewards, rejected_rewards)
return (loss, {'chosen_rewards': chosen_rewards, 'rejected_rewards': rejected_rewards}) if return_outputs else loss
def compute_pairwise_loss(self, chosen_rewards, rejected_rewards):
"""Bradley-Terry loss implementation."""
diff = chosen_rewards - rejected_rewards
return -F.logsigmoid(diff).mean()
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
"""Custom prediction step for evaluation."""
inputs = self._prepare_inputs(inputs)
with torch.no_grad():
outputs = model(**inputs)
rewards = outputs.logits.squeeze(-1)
# For evaluation, we compute accuracy of preference predictions
batch_size = rewards.size(0) // 2
chosen_rewards = rewards[:batch_size]
rejected_rewards = rewards[batch_size:]
# Accuracy: how often chosen > rejected
correct = (chosen_rewards > rejected_rewards).float()
accuracy = correct.mean()
loss = self.compute_pairwise_loss(chosen_rewards, rejected_rewards)
return (loss, accuracy, accuracy) # Return accuracy as both predictions and labels
# Prepare dataset for pairwise training
pairwise_train = self.create_pairwise_dataset(train_dataset)
pairwise_eval = self.create_pairwise_dataset(eval_dataset) if eval_dataset else None
# Training arguments
training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=num_epochs,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
gradient_accumulation_steps=1,
learning_rate=learning_rate,
weight_decay=0.01,
warmup_ratio=warmup_ratio,
logging_steps=50,
save_steps=500,
eval_steps=500 if pairwise_eval else None,
evaluation_strategy="steps" if pairwise_eval else "no",
save_strategy="steps",
load_best_model_at_end=True if pairwise_eval else False,
metric_for_best_model="eval_loss" if pairwise_eval else None,
greater_is_better=False,
report_to="none",
bf16=True,
dataloader_pin_memory=False,
remove_unused_columns=False,
)
        # Create trainer (the custom collator keeps the chosen/rejected halves of each batch aligned)
        trainer = RewardTrainer(
            model=self.model,
            args=training_args,
            train_dataset=pairwise_train,
            eval_dataset=pairwise_eval,
            tokenizer=self.tokenizer,
            data_collator=self.pairwise_data_collator,
        )
        # Train
        print("Starting reward model training...")
        print(f"Training pairs: {len(pairwise_train)}")
        if pairwise_eval:
            print(f"Evaluation pairs: {len(pairwise_eval)}")
trainer.train()
# Save model
trainer.save_model()
self.tokenizer.save_pretrained(output_dir)
print(f"Reward model training completed! Saved to {output_dir}")
return trainer
def evaluate_reward_model(self, test_data: List[Dict]) -> Dict:
"""Evaluate reward model on test data."""
self.model.eval()
results = {
'accuracy': 0.0,
'mean_chosen_reward': 0.0,
'mean_rejected_reward': 0.0,
'reward_difference': 0.0
}
correct_predictions = 0
chosen_rewards = []
rejected_rewards = []
for example in test_data:
prompt = example['prompt']
chosen = example['chosen']
rejected = example['rejected']
# Score chosen response
chosen_text = f"{prompt}\n\n{chosen}"
chosen_inputs = self.tokenizer(
chosen_text,
return_tensors="pt",
truncation=True,
max_length=self.max_length
)
chosen_inputs = {k: v.to(self.model.device) for k, v in chosen_inputs.items()}
with torch.no_grad():
chosen_score = self.model(**chosen_inputs).logits[0, 0].item()
# Score rejected response
rejected_text = f"{prompt}\n\n{rejected}"
rejected_inputs = self.tokenizer(
rejected_text,
return_tensors="pt",
truncation=True,
max_length=self.max_length
)
rejected_inputs = {k: v.to(self.model.device) for k, v in rejected_inputs.items()}
with torch.no_grad():
rejected_score = self.model(**rejected_inputs).logits[0, 0].item()
# Check if model correctly prefers chosen over rejected
if chosen_score > rejected_score:
correct_predictions += 1
chosen_rewards.append(chosen_score)
rejected_rewards.append(rejected_score)
# Calculate metrics
results['accuracy'] = correct_predictions / len(test_data)
results['mean_chosen_reward'] = np.mean(chosen_rewards)
results['mean_rejected_reward'] = np.mean(rejected_rewards)
results['reward_difference'] = results['mean_chosen_reward'] - results['mean_rejected_reward']
return results
def get_reward_score(self, prompt: str, response: str) -> float:
"""Get reward score for a prompt-response pair."""
self.model.eval()
text = f"{prompt}\n\n{response}"
inputs = self.tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=self.max_length
)
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
with torch.no_grad():
score = self.model(**inputs).logits[0, 0].item()
return score
# Example usage and testing
def create_sample_preference_data():
"""Create sample preference data for training."""
preference_data = [
{
"prompt": "Explain the concept of gravity.",
"chosen": "Gravity is a fundamental force of nature that causes objects with mass to attract each other. According to Einstein's theory of general relativity, gravity is not actually a force, but rather the curvature of spacetime caused by mass and energy. This curvature guides the motion of objects, making them appear to be attracted to each other.",
"rejected": "Gravity is when things fall down because they're heavy."
},
{
"prompt": "How do I cook pasta?",
"chosen": "To cook pasta: 1) Bring a large pot of salted water to boil, 2) Add pasta and stir occasionally, 3) Cook according to package directions (usually 8-12 minutes) until al dente, 4) Drain and serve immediately. The key is using plenty of water and not overcooking.",
"rejected": "Put pasta in water and heat it until it's soft. Should be fine."
},
{
"prompt": "What causes climate change?",
"chosen": "Climate change is primarily caused by increased concentrations of greenhouse gases in the atmosphere, mainly from human activities like burning fossil fuels, deforestation, and industrial processes. These gases trap heat from the sun, leading to global warming and associated climate impacts like sea level rise, extreme weather, and ecosystem disruption.",
"rejected": "The sun gets hotter sometimes and that changes the climate. It's natural."
}
]
return preference_data
if __name__ == "__main__":
# Initialize reward model trainer
trainer = RewardModelTrainer("microsoft/DialoGPT-small", max_length=256)
# Setup model
trainer.setup_reward_model(dropout_rate=0.1)
# Create sample data
preference_data = create_sample_preference_data()
# Prepare dataset
dataset = trainer.prepare_preference_dataset(preference_data)
# Split for training and evaluation
train_size = int(0.8 * len(dataset))
train_dataset = dataset.select(range(train_size))
eval_dataset = dataset.select(range(train_size, len(dataset)))
# Train reward model
reward_trainer = trainer.train_reward_model(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
output_dir="./sample_reward_model",
num_epochs=2,
batch_size=2, # Small batch for demo
learning_rate=5e-5
)
# Evaluate model
results = trainer.evaluate_reward_model(preference_data)
print("\nReward Model Evaluation Results:")
for metric, value in results.items():
print(f"{metric}: {value:.4f}")
# Test individual scoring
print("\nTesting individual reward scoring:")
test_prompt = "What's the best way to learn programming?"
good_response = "Start with a beginner-friendly language like Python, practice regularly with small projects, and don't be afraid to make mistakes - they're part of learning!"
bad_response = "Just read some books about it."
good_score = trainer.get_reward_score(test_prompt, good_response)
bad_score = trainer.get_reward_score(test_prompt, bad_response)
print(f"Good response score: {good_score:.4f}")
print(f"Bad response score: {bad_score:.4f}")
print(f"Difference: {good_score - bad_score:.4f}")
print("\nReward model training completed!")# Comprehensive PPO implementation for RLHF
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
import numpy as np
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
import wandb
@dataclass
class PPOConfig:
"""Configuration for PPO training."""
model_name: str = "gpt2"
learning_rate: float = 1.41e-5
batch_size: int = 64
mini_batch_size: int = 16
gradient_accumulation_steps: int = 1
ppo_epochs: int = 4
max_grad_norm: float = 1.0
clip_range: float = 0.2
clip_range_vf: Optional[float] = None
vf_coef: float = 0.1
target_kl: float = 0.1
init_kl_coef: float = 0.2
adap_kl_ctrl: bool = True
gamma: float = 1.0
lam: float = 0.95
use_score_scaling: bool = False
use_score_norm: bool = False
score_clip: Optional[float] = None
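# The trainer below maximizes the reward-model score while penalizing divergence from the frozen
# reference model, i.e. it approximately optimizes  E[r(x, y)] - beta * KL(pi_theta || pi_ref),
# where beta is the (optionally adaptive) KL coefficient managed by AdaptiveKLController.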
class PPOTrainer:
def __init__(self,
config: PPOConfig,
model: nn.Module,
ref_model: nn.Module,
reward_model: nn.Module,
tokenizer):
self.config = config
self.model = model # Policy model
self.ref_model = ref_model # Reference model (frozen)
self.reward_model = reward_model
self.tokenizer = tokenizer
# Freeze reference model
for param in self.ref_model.parameters():
param.requires_grad = False
# Setup optimizer
self.optimizer = AdamW(
self.model.parameters(),
lr=config.learning_rate,
eps=1e-8,
weight_decay=0.01
)
# KL controller for adaptive penalty
self.kl_ctl = AdaptiveKLController(config.init_kl_coef, config.target_kl)
# Training statistics
self.stats = {
'policy_loss': [],
'value_loss': [],
'total_loss': [],
'kl_divergence': [],
'rewards': [],
'advantages': [],
'approx_kl': [],
}
def generate_responses(self,
prompts: List[str],
max_new_tokens: int = 50,
temperature: float = 0.7,
top_p: float = 0.9) -> Tuple[List[str], torch.Tensor, torch.Tensor]:
"""Generate responses from current policy."""
self.model.eval()
all_responses = []
all_response_tensors = []
all_log_probs = []
for prompt in prompts:
# Tokenize prompt
prompt_tokens = self.tokenizer.encode(prompt, return_tensors="pt")
prompt_tokens = prompt_tokens.to(self.model.device)
# Generate response
with torch.no_grad():
response_tokens = self.model.generate(
prompt_tokens,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=temperature,
top_p=top_p,
pad_token_id=self.tokenizer.eos_token_id,
return_dict_in_generate=True,
output_scores=True
)
# Extract generated tokens (without prompt)
generated_tokens = response_tokens.sequences[0][prompt_tokens.shape[1]:]
# Calculate log probabilities
log_probs = []
for i, token_id in enumerate(generated_tokens):
if i < len(response_tokens.scores):
scores = response_tokens.scores[i][0] # [vocab_size]
log_prob = F.log_softmax(scores, dim=-1)[token_id].item()
log_probs.append(log_prob)
# Decode response
response_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
all_responses.append(response_text)
all_response_tensors.append(generated_tokens)
all_log_probs.append(torch.tensor(log_probs))
return all_responses, all_response_tensors, all_log_probs
def compute_rewards(self, prompts: List[str], responses: List[str]) -> List[float]:
"""Compute rewards using reward model."""
self.reward_model.eval()
rewards = []
for prompt, response in zip(prompts, responses):
# Create full text
full_text = f"{prompt}\n\n{response}"
# Tokenize and get reward
inputs = self.tokenizer(
full_text,
return_tensors="pt",
truncation=True,
max_length=512
)
inputs = {k: v.to(self.reward_model.device) for k, v in inputs.items()}
with torch.no_grad():
reward = self.reward_model(**inputs).logits[0, 0].item()
rewards.append(reward)
return rewards
def compute_advantages(self,
rewards: torch.Tensor,
values: torch.Tensor,
masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute GAE (Generalized Advantage Estimation) advantages."""
# Add terminal value of 0
values = torch.cat([values, torch.zeros(1).to(values.device)])
advantages = torch.zeros_like(rewards)
last_gae_lam = 0
# Compute advantages using GAE
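        # GAE recursion used below:
        #   delta_t = r_t + gamma * V_{t+1} * nonterminal_{t+1} - V_t
        #   A_t     = delta_t + gamma * lam * nonterminal_{t+1} * A_{t+1}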
for t in reversed(range(len(rewards))):
if t == len(rewards) - 1:
next_non_terminal = 0
next_values = 0
else:
next_non_terminal = masks[t + 1]
next_values = values[t + 1]
delta = rewards[t] + self.config.gamma * next_values * next_non_terminal - values[t]
advantages[t] = last_gae_lam = delta + self.config.gamma * self.config.lam * next_non_terminal * last_gae_lam
# Compute returns
returns = advantages + values[:-1]
return advantages, returns
def compute_policy_loss(self,
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
advantages: torch.Tensor,
masks: torch.Tensor) -> torch.Tensor:
"""Compute clipped PPO policy loss."""
# Compute probability ratio
log_ratio = log_probs - old_log_probs
ratio = torch.exp(log_ratio)
# Compute clipped surrogate loss
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - self.config.clip_range, 1 + self.config.clip_range) * advantages
policy_loss = -torch.min(surr1, surr2)
# Apply mask and average
policy_loss = (policy_loss * masks).sum() / masks.sum()
return policy_loss
def compute_value_loss(self,
values: torch.Tensor,
old_values: torch.Tensor,
returns: torch.Tensor,
masks: torch.Tensor) -> torch.Tensor:
"""Compute value function loss."""
if self.config.clip_range_vf is not None:
# Clipped value loss
values_clipped = old_values + torch.clamp(
values - old_values,
-self.config.clip_range_vf,
self.config.clip_range_vf
)
vf_loss1 = (values - returns) ** 2
vf_loss2 = (values_clipped - returns) ** 2
vf_loss = torch.max(vf_loss1, vf_loss2)
else:
# Standard MSE loss
vf_loss = (values - returns) ** 2
# Apply mask and average
vf_loss = (vf_loss * masks).sum() / masks.sum()
return vf_loss
    def compute_kl_penalty(self,
                           log_probs: torch.Tensor,
                           ref_log_probs: torch.Tensor,
                           masks: torch.Tensor) -> torch.Tensor:
        """Compute the KL penalty, estimated as E[log pi - log pi_ref] over sampled tokens."""
        kl_div = log_probs - ref_log_probs
        kl_penalty = (kl_div * masks).sum() / masks.sum()
        return kl_penalty
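    # Helper added for this sketch (not part of any library API): recompute the mean per-token
    # log-probability of already-sampled responses under a given model. Generation runs under
    # torch.no_grad(), so the PPO losses need a fresh forward pass like this to get gradients.
    def compute_response_log_probs(self,
                                   model: nn.Module,
                                   prompts: List[str],
                                   responses: List[torch.Tensor]) -> torch.Tensor:
        """Mean log-prob of each response under `model` (gradients flow if grad is enabled)."""
        log_probs = []
        for prompt, response_ids in zip(prompts, responses):
            prompt_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(model.device)
            response_ids = response_ids.to(model.device).unsqueeze(0)
            input_ids = torch.cat([prompt_ids, response_ids], dim=1)
            logits = model(input_ids).logits
            # Logits at position t predict token t + 1, so this slice scores the response tokens
            response_logits = logits[:, prompt_ids.shape[1] - 1:-1, :]
            token_log_probs = F.log_softmax(response_logits, dim=-1).gather(
                2, response_ids.unsqueeze(-1)
            ).squeeze(-1)
            log_probs.append(token_log_probs.mean())
        return torch.stack(log_probs)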
def train_step(self, batch_data: Dict) -> Dict:
"""Perform one PPO training step."""
self.model.train()
# Extract batch data
prompts = batch_data['prompts']
responses = batch_data['responses']
old_log_probs = batch_data['log_probs']
rewards = batch_data['rewards']
# Convert to tensors
rewards = torch.tensor(rewards, dtype=torch.float32, device=self.model.device)
        # Reuse the responses sampled in train(); regenerating here would desynchronize the
        # rewards and old log-probs from the sequences actually being optimized
        response_tensors = batch_data['response_tensors']
# Compute values (simplified - in practice, use separate value head)
values = torch.zeros_like(rewards) # Placeholder
# Compute advantages
masks = torch.ones_like(rewards) # Simplified masking
advantages, returns = self.compute_advantages(rewards, values, masks)
# Normalize advantages
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# PPO training loop
total_policy_loss = 0
total_value_loss = 0
total_kl_penalty = 0
        for ppo_epoch in range(self.config.ppo_epochs):
            # Recompute current log probs with gradients enabled; the values returned by
            # generate_responses come from torch.no_grad() generation and cannot backpropagate
            curr_log_probs = self.compute_response_log_probs(self.model, prompts, response_tensors)
            old_lp = torch.stack([lp.mean() for lp in old_log_probs]).to(curr_log_probs.device)
# Compute losses
policy_loss = self.compute_policy_loss(curr_log_probs, old_lp, advantages, masks)
value_loss = self.compute_value_loss(values, values, returns, masks) # Simplified
            # Get reference model log probs for the same sampled responses
            ref_log_probs = self.get_ref_log_probs(prompts, response_tensors).to(curr_log_probs.device)
            kl_penalty = self.compute_kl_penalty(curr_log_probs, ref_log_probs, masks)
# Total loss
total_loss = (
policy_loss +
self.config.vf_coef * value_loss +
self.kl_ctl.value * kl_penalty
)
# Backward pass
self.optimizer.zero_grad()
total_loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
self.optimizer.step()
# Accumulate losses
total_policy_loss += policy_loss.item()
total_value_loss += value_loss.item()
total_kl_penalty += kl_penalty.item()
# Update KL controller
mean_kl = total_kl_penalty / self.config.ppo_epochs
self.kl_ctl.update(mean_kl, batch_data['batch_size'])
# Record statistics
stats = {
'policy_loss': total_policy_loss / self.config.ppo_epochs,
'value_loss': total_value_loss / self.config.ppo_epochs,
'kl_penalty': mean_kl,
'kl_coef': self.kl_ctl.value,
'mean_reward': rewards.mean().item(),
'mean_advantage': advantages.mean().item(),
}
return stats
    def get_ref_log_probs(self, prompts: List[str], responses: List[torch.Tensor]) -> torch.Tensor:
        """Get mean response log-probabilities from the frozen reference model."""
        self.ref_model.eval()
        with torch.no_grad():
            return self.compute_response_log_probs(self.ref_model, prompts, responses)
def train(self, prompts: List[str], num_steps: int = 1000):
"""Main training loop."""
print(f"Starting PPO training for {num_steps} steps...")
for step in range(num_steps):
# Sample batch of prompts
batch_prompts = np.random.choice(prompts, size=self.config.batch_size, replace=True).tolist()
# Generate responses
responses, response_tensors, log_probs = self.generate_responses(batch_prompts)
# Compute rewards
rewards = self.compute_rewards(batch_prompts, responses)
            # Prepare batch data (keep the sampled token tensors so train_step can
            # recompute log-probs for exactly these responses)
            batch_data = {
                'prompts': batch_prompts,
                'responses': responses,
                'response_tensors': response_tensors,
                'log_probs': log_probs,
                'rewards': rewards,
                'batch_size': len(batch_prompts)
            }
# Training step
stats = self.train_step(batch_data)
# Log statistics
if step % 10 == 0:
print(f"Step {step}:")
for key, value in stats.items():
print(f" {key}: {value:.4f}")
print()
# Record stats
for key, value in stats.items():
if key in self.stats:
self.stats[key].append(value)
print("PPO training completed!")
class AdaptiveKLController:
"""Adaptive KL divergence controller."""
def __init__(self, init_kl_coef: float, target_kl: float):
self.value = init_kl_coef
self.target = target_kl
def update(self, current_kl: float, n_steps: int):
"""Update KL coefficient based on current KL divergence."""
if current_kl < self.target / 1.5:
# KL too low, decrease penalty
self.value *= 0.98
elif current_kl > self.target * 1.5:
# KL too high, increase penalty
self.value *= 1.02
# Clamp to reasonable range
self.value = max(0.01, min(2.0, self.value))
# Example usage
if __name__ == "__main__":
# Initialize models
model_name = "microsoft/DialoGPT-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Policy model (trainable)
policy_model = AutoModelForCausalLM.from_pretrained(model_name)
# Reference model (frozen copy)
ref_model = AutoModelForCausalLM.from_pretrained(model_name)
    # Reward model (placeholder - in practice, load the reward model trained in stage 2;
    # a single-label sequence-classification head matches how compute_rewards reads the score)
    reward_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)
    reward_model.config.pad_token_id = tokenizer.eos_token_id
# PPO configuration
config = PPOConfig(
model_name=model_name,
learning_rate=1.41e-5,
batch_size=8, # Small for demo
mini_batch_size=4,
ppo_epochs=2,
target_kl=0.1,
)
# Initialize PPO trainer
ppo_trainer = PPOTrainer(
config=config,
model=policy_model,
ref_model=ref_model,
reward_model=reward_model,
tokenizer=tokenizer
)
# Sample prompts for training
prompts = [
"Human: What's the best way to learn programming?\n\nAssistant:",
"Human: Explain climate change in simple terms.\n\nAssistant:",
"Human: How do I make a good first impression?\n\nAssistant:",
"Human: What are the benefits of exercise?\n\nAssistant:",
]
# Train with PPO
ppo_trainer.train(prompts, num_steps=50) # Short training for demo
# Save trained model
policy_model.save_pretrained("./ppo_trained_model")
tokenizer.save_pretrained("./ppo_trained_model")
print("PPO training completed and model saved!")