A high-performance Reinforcement Learning framework built from the ground up in Rust.
Features • Architecture • Quick Start • Roadmap • Contributing
The Reinforcement Learning ecosystem is dominated by Python frameworks (Stable-Baselines3, RLlib, CleanRL), which are powerful but carry inherent limitations:
| Pain Point | Python Frameworks | RustForge RL |
|---|---|---|
| Runtime Speed | GIL bottleneck, interpreter overhead | Zero-cost abstractions, native speed |
| Memory Safety | Runtime errors, memory leaks | Compile-time guarantees via ownership |
| Concurrency | Fragile multiprocessing | Fearless concurrency with Send/Sync |
| Deployment | Heavy runtimes, dependency hell | Single static binary, no runtime |
| Reproducibility | Floating-point non-determinism | Deterministic seeding at every layer |
RustForge RL aims to be the first comprehensive, production-grade RL framework in Rust — not just a toy implementation, but a framework you can use to train real agents and deploy them anywhere.
A PyTorch-style tensor library built on top of `ndarray`:

- Creation: `from_vec`, `zeros`, `ones`, `eye`, `arange`, `linspace`, `scalar`, `full`
- Shape Transforms: `reshape`, `flatten`, `transpose`, `permute`, `unsqueeze`, `squeeze`
- Arithmetic: Overloaded `+`, `-`, `*`, `/` with full broadcasting support
- Matrix Math: `matmul` supporting dot products, matrix-vector, and batch matrix multiplication
- Reductions: `sum`, `mean`, `max`, `argmax`, `var`, `std_dev` (with axis + keepdim support)
- Activations: `relu`, `sigmoid`, `tanh`, `softmax`, `log_softmax` (numerically stable)
- Math Ops: `exp`, `log`, `pow`, `sqrt`, `abs`, `clamp`, `neg`, `reciprocal`
- Concatenation: `cat` and `stack` with arbitrary axis support
- Random Init: Uniform, Normal, Xavier/Glorot, Kaiming/He initialization strategies
- Display: PyTorch-style pretty printing with automatic truncation for large tensors
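The broadcasting rules above follow NumPy/PyTorch semantics: shapes are aligned from the trailing dimension, and each dimension pair must be equal or contain a 1. As a standalone illustration (not the actual `rustforge-tensor` code), the shape-resolution logic can be sketched as:

```rust
/// Compute the broadcast shape of two shapes, NumPy/PyTorch-style:
/// align trailing dimensions; each pair must match or contain a 1.
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let n = a.len().max(b.len());
    let mut out = Vec::with_capacity(n);
    for i in 0..n {
        // Missing leading dimensions are treated as 1.
        let da = if i < n - a.len() { 1 } else { a[i - (n - a.len())] };
        let db = if i < n - b.len() { 1 } else { b[i - (n - b.len())] };
        match (da, db) {
            (x, y) if x == y => out.push(x),
            (1, y) => out.push(y),
            (x, 1) => out.push(x),
            _ => return None, // incompatible shapes
        }
    }
    Some(out)
}

fn main() {
    assert_eq!(broadcast_shape(&[3, 1], &[1, 4]), Some(vec![3, 4]));
    assert_eq!(broadcast_shape(&[32, 64], &[64]), Some(vec![32, 64]));
    assert_eq!(broadcast_shape(&[2, 3], &[4]), None);
    println!("broadcasting rules check out");
}
```

Returning `Option` here mirrors how a tensor library can reject incompatible shapes at the API boundary instead of panicking mid-operation.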
- `Variable` wrapper with gradient tracking (`Rc<RefCell<...>>`)
- Dynamic computational graph construction via the `GradFn` trait
- Backward pass via topological sort + chain rule
- 17 gradient mappings for operations and math functions
- Numerical gradient checking (finite difference method)
- Optimizers: SGD (w/ momentum), Adam (bias-corrected)
- Layers: `Linear`, `Conv2d`, `BatchNorm`, `LayerNorm`; `Sequential` container; `Module` trait
- Loss functions: MSE, CrossEntropy, Huber
- Model serialization and checkpointing
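The numerical gradient checking mentioned above validates analytic gradients with the central-difference formula, (f(x+h) − f(x−h)) / 2h. A minimal standalone sketch of the idea (function names are illustrative, not the autograd crate's API):

```rust
/// Central-difference numerical gradient: (f(x+h) - f(x-h)) / (2h).
/// The error is O(h^2), much better than the one-sided difference's O(h).
fn numerical_grad(f: impl Fn(f64) -> f64, x: f64, h: f64) -> f64 {
    (f(x + h) - f(x - h)) / (2.0 * h)
}

fn main() {
    // f(x) = x^2 has the analytic gradient 2x.
    let f = |x: f64| x * x;
    let x = 3.0;
    let analytic = 2.0 * x;
    let numeric = numerical_grad(f, x, 1e-5);
    // The two should agree to within the finite-difference error.
    assert!((analytic - numeric).abs() < 1e-6);
    println!("analytic = {analytic}, numeric = {numeric}");
}
```

Running this check against every backward implementation is a cheap way to catch sign errors and missing chain-rule terms.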
- Value-Based: DQN, Double DQN, Dueling DQN, Prioritized Experience Replay
- Policy Gradient: REINFORCE, A2C, PPO (clip & penalty variants)
- Off-Policy: SAC, TD3, DDPG
- Environment Interface: Gymnasium-compatible trait for custom environments
- Replay Buffers: Uniform, Prioritized (SumTree), HER
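The SumTree backing the prioritized replay buffer stores priorities in the leaves of a binary tree and partial sums in the internal nodes, so sampling proportionally to priority is O(log n). A minimal illustrative sketch (not the planned rustforge-rl implementation):

```rust
/// Minimal sum-tree for prioritized experience replay.
/// Leaves hold per-transition priorities; internal nodes hold subtree sums.
struct SumTree {
    tree: Vec<f64>, // 2*capacity - 1 nodes; leaves occupy the tail
    capacity: usize,
}

impl SumTree {
    fn new(capacity: usize) -> Self {
        Self { tree: vec![0.0; 2 * capacity - 1], capacity }
    }

    /// Sum of all priorities (the root).
    fn total(&self) -> f64 {
        self.tree[0]
    }

    /// Set the priority of leaf `idx` and propagate the change upward.
    fn update(&mut self, idx: usize, priority: f64) {
        let mut i = idx + self.capacity - 1;
        let delta = priority - self.tree[i];
        self.tree[i] = priority;
        while i > 0 {
            i = (i - 1) / 2;
            self.tree[i] += delta;
        }
    }

    /// Find the leaf whose cumulative-priority interval contains `value`,
    /// where `value` is drawn uniformly from [0, total()).
    fn sample(&self, mut value: f64) -> usize {
        let mut i = 0;
        while 2 * i + 1 < self.tree.len() {
            let left = 2 * i + 1;
            if value <= self.tree[left] {
                i = left;
            } else {
                value -= self.tree[left];
                i = left + 1;
            }
        }
        i - (self.capacity - 1)
    }
}

fn main() {
    let mut tree = SumTree::new(4);
    for (i, p) in [1.0, 2.0, 3.0, 4.0].iter().enumerate() {
        tree.update(i, *p);
    }
    assert_eq!(tree.total(), 10.0);
    // value 0.5 falls in leaf 0's [0, 1) interval; 9.5 in leaf 3's [6, 10).
    assert_eq!(tree.sample(0.5), 0);
    assert_eq!(tree.sample(9.5), 3);
}
```

The flat `Vec` layout (children of node i at 2i+1 and 2i+2) keeps the structure cache-friendly and avoids pointer chasing.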
- PyO3-powered Python API for seamless integration
- NumPy array interop (zero-copy where possible)
- Drop-in replacement for select PyTorch/SB3 workflows
- Real-time web-based monitoring via Axum + WebSocket
- Live reward curves, loss plots, episode statistics
- Hyperparameter tracking and experiment comparison
RustForge RL is organized as a Cargo workspace with strict, one-directional dependencies:
```
rustforge-rl/
├── Cargo.toml                 # Workspace root
├── crates/
│   ├── rustforge-tensor/      # 🧮 Tensor computation engine
│   │   └── src/
│   │       ├── lib.rs         # Crate entry point & re-exports
│   │       ├── tensor.rs      # Core Tensor struct + operations
│   │       ├── ops.rs         # Operator overloading (+, -, *, /, matmul)
│   │       ├── shape.rs       # Broadcasting rules & shape utilities
│   │       ├── random.rs      # Random initialization strategies
│   │       ├── display.rs     # Pretty-print formatting
│   │       └── error.rs       # Type-safe error definitions
│   │
│   ├── rustforge-autograd/    # 🔄 Automatic differentiation
│   ├── rustforge-nn/          # 🧠 Neural network layers
│   └── rustforge-rl/          # 🎮 RL algorithms
│
├── examples/                  # Runnable examples (coming soon)
└── benches/                   # Performance benchmarks (coming soon)
```
Dependency graph:
```
tensor ← autograd ← nn ← rl
                         ↓
           dashboard, python bindings
```
Each layer only depends on the layer below it, ensuring clean separation of concerns and independent testability.
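In Cargo terms, that one-directional layering is enforced by each crate's dependency list. A hypothetical sketch of the workspace manifests (crate names per the tree above; paths and layout illustrative):

```toml
# Workspace root Cargo.toml (illustrative)
[workspace]
members = [
    "crates/rustforge-tensor",
    "crates/rustforge-autograd",
    "crates/rustforge-nn",
    "crates/rustforge-rl",
]

# crates/rustforge-nn/Cargo.toml — depends only on the layers below it
[dependencies]
rustforge-tensor = { path = "../rustforge-tensor" }
rustforge-autograd = { path = "../rustforge-autograd" }
```

Because `cargo` rejects dependency cycles, an accidental upward dependency (e.g. tensor depending on nn) fails at build time rather than slipping in silently.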
- Rust 1.75+ (2021 edition)
- A C compiler for `ndarray`'s BLAS backend (MSVC on Windows, GCC/Clang on Linux/macOS)
```bash
# Clone the repository
git clone https://github.com/YOUR_USERNAME/rustforge-rl.git
cd rustforge-rl

# Build the entire workspace
cargo build

# Run the tensor crate's tests (51 unit tests + 9 doc tests)
cargo test -p rustforge-tensor

# Run with optimizations for benchmarking
cargo build --release
```

```rust
use rustforge_tensor::Tensor;

fn main() {
    // Create tensors
    let weights = Tensor::xavier_uniform(&[128, 64], Some(42));
    let input = Tensor::rand_normal(&[32, 64], 0.0, 1.0, None);
    let bias = Tensor::zeros(&[128]);

    // Forward pass: output = input @ weights^T + bias
    let output = input.matmul(&weights.t()) + bias;

    // Activation
    let activated = output.relu();

    // Softmax for probability distribution
    let probs = activated.softmax(1).unwrap();

    println!("Output shape: {:?}", probs.shape());
    println!("Probabilities:\n{}", probs);
}
```

```rust
use rustforge_tensor::Tensor;

// Broadcasting: [3, 1] + [1, 4] → [3, 4]
let a = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3, 1]);
let b = Tensor::from_vec(vec![10.0, 20.0, 30.0, 40.0], &[1, 4]);
let c = &a + &b; // shape: [3, 4]

// Reductions
let data = Tensor::rand_uniform(&[100, 50], 0.0, 1.0, Some(0));
println!("Mean: {:.4}", data.mean().item());
println!("Std: {:.4}", data.std_dev().item());

// Matrix multiplication: scaled dot-product attention scores
let q = Tensor::randn(&[8, 64], Some(1)); // queries
let k = Tensor::randn(&[8, 64], Some(2)); // keys
let attention = q.matmul(&k.t()); // [8, 8] attention scores
let weights = (&attention / (64.0_f32).sqrt()).softmax(1).unwrap(); // scale by sqrt(d_k) = 8
```

| Phase | Milestone | Status |
|---|---|---|
| Phase 1 | Tensor Engine | ✅ Complete (51 tests passing) |
| Phase 1 | Autograd Engine | ✅ Complete (49 tests passing) |
| Phase 2 | Neural Network Modules | ✅ Complete (37 tests passing) |
| Phase 2 | Optimizers (SGD, Adam) | ✅ Complete |
| Phase 3 | DQN + CartPole | 📋 Planned |
| Phase 3 | PPO + Continuous Control | 📋 Planned |
| Phase 4 | SAC, TD3, DDPG | 📋 Planned |
| Phase 4 | Python Bindings (PyO3) | 📋 Planned |
| Phase 5 | Training Dashboard | 📋 Planned |
| Phase 5 | GPU Support (wgpu) | 📋 Planned |
Every operation is backed by comprehensive unit tests with numerical precision checks. We use approx for floating-point comparisons and deterministic seeding for reproducibility.
If you've used PyTorch, you'll feel right at home. Method names, broadcasting rules, and tensor semantics are intentionally aligned with PyTorch conventions.
Rust's ownership system lets us provide a safe, high-level API without runtime overhead. No garbage collector, no reference counting at the tensor layer — just stack-allocated wrappers around contiguous memory.
Each crate is independently usable. Need just tensors? Use rustforge-tensor. Want autograd without RL? Use rustforge-autograd. The workspace structure enforces clean boundaries.
RustForge RL is built on `ndarray`, which can leverage BLAS for matrix operations. Preliminary benchmarks on common operations:
| Operation | Shape | RustForge | Notes |
|---|---|---|---|
| MatMul | [512, 512] × [512, 512] | ~2ms | With OpenBLAS |
| Softmax | [1024, 1024] | ~1ms | Numerically stable |
| Xavier Init | [1024, 1024] | ~3ms | ChaCha20 RNG |
| Broadcasting Add | [1000, 1] + [1, 1000] | ~0.5ms | Native ndarray |
Note: Benchmarks are from development builds. Release builds (`--release`) are typically 10-30× faster.
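The "numerically stable" softmax in the table refers to the standard max-subtraction trick: shifting inputs by their maximum before exponentiating so `exp` never overflows. A standalone sketch of the technique (not the `rustforge-tensor` source):

```rust
/// Numerically stable softmax over a slice: subtract the max before
/// exponentiating so exp() never overflows for large inputs.
/// The shift cancels out, so the result is mathematically unchanged.
fn softmax(xs: &[f32]) -> Vec<f32> {
    let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    // Without the max-subtraction trick, exp(1000.0) overflows to infinity.
    let probs = softmax(&[1000.0, 1000.0]);
    assert!((probs[0] - 0.5).abs() < 1e-6);

    // The output always sums to 1.
    let sum: f32 = softmax(&[1.0, 2.0, 3.0]).iter().sum();
    assert!((sum - 1.0).abs() < 1e-6);
    println!("{probs:?}");
}
```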
We welcome contributions of all kinds! RustForge RL is in its early stages, making it an excellent time to get involved.
- 🐛 Bug Reports: Found an issue? Open a GitHub Issue with reproduction steps
- 📖 Documentation: Improve doc comments, add examples, write tutorials
- 🧪 Tests: Add edge cases, property-based tests, or integration tests
- 🚀 Features: Pick an item from the roadmap and submit a PR
- 💡 Ideas: Suggest new RL algorithms, optimizations, or API improvements
```bash
# Fork and clone
git clone https://github.com/YOUR_USERNAME/rustforge-rl.git
cd rustforge-rl

# Create a feature branch
git checkout -b feat/your-feature

# Make changes and run tests
cargo test --workspace
cargo clippy --workspace

# Submit a PR!
```

- Run `cargo fmt` before committing
- Run `cargo clippy` and address all warnings
- Add doc comments with `///` for all public items
- Include unit tests for new functionality
- Use Conventional Commits for commit messages
| Component | Technology |
|---|---|
| Language | Rust 2021 Edition |
| Tensor Backend | ndarray 0.16 |
| Random Number Generation | rand 0.8 + ChaCha20 |
| Serialization | serde + bincode |
| Logging | tracing |
| Testing | Built-in + approx |
| Future: Python Bindings | PyO3 |
| Future: Web Dashboard | Axum + WebSocket |
| Future: GPU | wgpu |
This project draws inspiration from and builds upon ideas in:
- PyTorch — API design and tensor semantics
- tch-rs — Rust bindings for libtorch
- candle — Minimalist ML framework in Rust
- burn — Deep learning framework in Rust
- Stable-Baselines3 — RL algorithm implementations
- CleanRL — Single-file RL implementations
- Mnih et al., "Playing Atari with Deep Reinforcement Learning" (DQN, 2013)
- Schulman et al., "Proximal Policy Optimization Algorithms" (PPO, 2017)
- Haarnoja et al., "Soft Actor-Critic" (SAC, 2018)
- Glorot & Bengio, "Understanding the difficulty of training deep feedforward neural networks" (Xavier Init, 2010)
- He et al., "Delving Deep into Rectifiers" (Kaiming Init, 2015)
This project is dual-licensed under:

- MIT License
- Apache License 2.0

You may choose either license. See LICENSE-MIT and LICENSE-APACHE for details.
Built with 🦀 Rust and ❤️ passion for RL
RustForge RL — Forging intelligent agents, one tensor at a time.