A Mini-o1 Implementation for Chain-of-Thought Reasoning
Features β’ Installation β’ Quick Start β’ Documentation β’ Results
Reasoning-RL implements Group Relative Policy Optimization (GRPO) to train language models for mathematical reasoning without supervised fine-tuning data. Inspired by DeepSeek-R1 and OpenAI's o1, this project demonstrates how reinforcement learning can elicit Chain-of-Thought (CoT) behaviors and self-correction capabilities in base language models.
- π GRPO Algorithm: Group-relative advantages for stable policy updates
- β Verifiable Rewards: Symbolic math parsing for objective correctness feedback
- π§ Self-Correction Emergence: Models learn to re-evaluate and correct mistakes
- β‘ Parallel Rollouts: Efficient multi-GPU generation with Ray
| Feature | Description |
|---|---|
| GRPO Training | Group Relative Policy Optimization with KL regularization |
| Verifiable Rewards | Symbolic parsing for mathematical answer verification |
| Multi-Dataset Support | GSM8K (grade school) and MATH (competition) datasets |
| Self-Correction | Reward structure encourages error detection and correction |
| Distributed Training | Ray-based parallel rollout generation across GPUs |
| Comprehensive Eval | Pass@k, accuracy, self-correction rate metrics |
- π Flash Attention 2 support for efficient training
- π Weights & Biases integration for experiment tracking
- π§ LoRA support for memory-efficient fine-tuning
- π Curriculum learning for progressive difficulty
- ποΈ YAML-based configuration system
- Python 3.10+
- CUDA 11.8+ (for GPU training)
- 24GB+ VRAM recommended (A100/H100 for full fine-tuning)
# Clone the repository
git clone https://github.com/yourusername/reasoning-rl.git
cd reasoning-rl
# Create virtual environment and install
uv venv --python 3.11
source .venv/bin/activate
uv pip install -e ".[dev]"pip install -e ".[dev]"# Required for model access
export HF_TOKEN="your_huggingface_token"
# Optional: Weights & Biases
export WANDB_API_KEY="your_wandb_key"# Basic training on GSM8K
python scripts/train.py --model Qwen/Qwen2.5-7B --dataset gsm8k
# Training with LoRA (memory-efficient)
python scripts/train.py --model Qwen/Qwen2.5-7B --dataset gsm8k --lora
# Using config file
python scripts/train.py --config configs/gsm8k.yaml
# Or use the shell script
./scripts/train.sh train# Training
reasoning-rl train --model Qwen/Qwen2.5-7B --dataset gsm8k --wandb my-project
# Evaluation
reasoning-rl evaluate ./outputs/best --dataset gsm8k
# Interactive demo
reasoning-rl demo ./outputs/bestfrom reasoning_rl import GRPOTrainer, GRPOConfig
from reasoning_rl.data import load_gsm8k
# Load data
train_data = load_gsm8k("train")
eval_data = load_gsm8k("test", max_samples=500)
# Configure training
config = GRPOConfig(
model_name="Qwen/Qwen2.5-7B",
group_size=8,
learning_rate=1e-6,
kl_coef=0.05,
)
# Train
trainer = GRPOTrainer(config, train_data, eval_data)
trainer.train()reasoning-rl/
βββ configs/ # YAML configuration files
β βββ default.yaml # Base configuration
β βββ gsm8k.yaml # GSM8K-specific config
β βββ math.yaml # MATH dataset config
β βββ lora.yaml # LoRA training config
βββ scripts/
β βββ train.py # Main training script
β βββ train.sh # Shell wrapper
βββ src/reasoning_rl/
β βββ trainer/ # GRPO trainer implementation
β β βββ grpo_trainer.py # Core trainer
β β βββ grpo_config.py # Configuration dataclass
β βββ rewards/ # Reward functions
β β βββ reward_function.py # Main reward computation
β β βββ symbolic_parser.py # Math expression parsing
β β βββ format_checker.py # CoT format validation
β βββ rollout/ # Generation system
β β βββ generator.py # Rollout generation
β β βββ ray_rollout.py # Distributed generation
β βββ data/ # Dataset loaders
β β βββ gsm8k.py # GSM8K dataset
β β βββ math_dataset.py # MATH dataset
β βββ evaluation/ # Evaluation module
β β βββ evaluator.py # Model evaluation
β β βββ metrics.py # Evaluation metrics
β β βββ benchmark.py # Benchmark runner
β βββ cli.py # Command-line interface
βββ tests/ # Unit tests
βββ pyproject.toml # Project configuration
βββ README.md
GRPO (Group Relative Policy Optimization) trains the model by:
- Group Sampling: Generate G completions per prompt
- Reward Computation: Compute verifiable rewards for each completion
- Advantage Estimation: Normalize rewards within each group
- Policy Update: Update with clipped objective + KL penalty
L = -E[min(r(ΞΈ) Γ A, clip(r(ΞΈ), 1-Ξ΅, 1+Ξ΅) Γ A)] + Ξ² Γ KL(Ο || Ο_ref)
Where:
r(ΞΈ)= policy ratio (new/old)A= group-relative advantageΞ΅= clip range (default 0.2)Ξ²= KL coefficient (default 0.05)
The verifiable reward combines:
| Component | Weight | Description |
|---|---|---|
| Correctness | 1.0 | Binary reward for correct answer |
| Format | 0.1 | CoT structure compliance |
| Self-Correction | 0.2 | Bonus for successful corrections |
| Length | -0.001 | Penalty for excessive length |
Key hyperparameters in configs/default.yaml:
grpo:
group_size: 8 # Completions per prompt
temperature: 0.7 # Sampling temperature
kl_coef: 0.05 # KL divergence penalty
clip_range: 0.2 # PPO clipping range
training:
learning_rate: 1e-6
batch_size: 4
num_epochs: 3| Model | Zero-Shot | GRPO (Ours) | Improvement |
|---|---|---|---|
| Qwen2.5-7B | 57.2% | 71.5% | +25% |
| Llama-3.1-8B | 52.1% | 65.3% | +25% |
- Chain-of-Thought: Models develop step-by-step reasoning
- Self-Correction: Models learn to identify and fix errors
- Verification: Models check intermediate calculations
Question: A store sells apples for $2 each. If John buys 5 apples
and gets a 20% discount, how much does he pay?
<think>
Let me solve this step by step.
Step 1: Calculate the original price
5 apples Γ $2 = $10
Step 2: Calculate the discount
20% of $10 = 0.20 Γ $10 = $2
Step 3: Calculate the final price
$10 - $2 = $8
Let me verify: 5 Γ 2 = 10, and 10 Γ 0.8 = 8 β
</think>
#### 8
# Start Ray cluster
ray start --head --num-gpus=4
# Run distributed training
python scripts/train.py --config configs/distributed.yamlfrom reasoning_rl.rewards import VerifiableRewardFunction
class CustomReward(VerifiableRewardFunction):
def compute_reward(self, completion, ground_truth):
base_reward = super().compute_reward(completion, ground_truth)
# Add custom logic
return base_reward + custom_bonusfrom reasoning_rl.data import CurriculumDataset, load_math
# Start with easier problems
dataset = CurriculumDataset(
load_math("train"),
difficulty_key="level",
initial_max_difficulty=2,
final_max_difficulty=5,
)
# Progress curriculum during training
dataset.set_progress(0.5) # 50% through training# Run all tests
pytest tests/ -v
# Run with coverage
pytest tests/ --cov=reasoning_rl --cov-report=html- DeepSeek-R1 - Incentivizing Reasoning Capability in LLMs
- GSM8K - Grade School Math Dataset
- MATH - Competition Mathematics Dataset
- PPO - Proximal Policy Optimization
Contributions are welcome! Please see CONTRIBUTING.md for guidelines.
This project is licensed under the Apache 2.0 License - see LICENSE for details.
- HuggingFace for Transformers and TRL libraries
- DeepSeek for GRPO algorithm insights
- OpenAI for inspiration from o1 reasoning capabilities
Built with β€οΈ for the AI research community