rlarson20/Volta

Volta ⚡

A PyTorch-like deep learning framework in pure Rust

Volta is a minimal deep learning and automatic differentiation library built from scratch in pure Rust, heavily inspired by PyTorch. It provides a dynamic computation graph, NumPy-style broadcasting, and common neural network primitives.

This project is an educational endeavor to demystify the inner workings of modern autograd engines. It prioritizes correctness, clarity, and a clean API over raw performance, while still providing hooks for hardware acceleration.

Key Features

  • Dynamic Computation Graph: Build and backpropagate through graphs on the fly, just like PyTorch.
  • Reverse-Mode Autodiff: Efficient reverse-mode automatic differentiation with topological sorting.
  • Rich Tensor Operations: A comprehensive set of unary, binary, reduction, and matrix operations via an ergonomic TensorOps trait.
  • Broadcasting: Full NumPy-style broadcasting support for arithmetic operations.
  • Neural Network Layers: Linear, Conv2d, ConvTranspose2d, MaxPool2d, Embedding, LSTMCell, PixelShuffle, Flatten, ReLU, Sigmoid, Tanh, Dropout, BatchNorm1d, BatchNorm2d.
  • Optimizers: SGD (momentum + weight decay), Adam (bias-corrected + weight decay), and experimental Muon.
  • External Model Loading: Load weights from PyTorch, HuggingFace, and other frameworks via StateDictMapper with automatic weight transposition and key remapping. Supports SafeTensors format.
  • Named Layers: Human-readable state dict keys with Sequential::builder() pattern for robust serialization.
  • Multi-dtype Support: Initial support for f16, bf16, f32, f64, i32, i64, u8, and bool tensors.
  • IO System: Save and load model weights (state dicts) via bincode or SafeTensors format.
  • BLAS Acceleration (macOS): Optional acceleration for matrix multiplication via Apple's Accelerate framework.
  • GPU Acceleration: Experimental WGPU-based GPU support for core tensor operations (elementwise, matmul, reductions, movement ops) with automatic backward pass on GPU.
  • Validation-Focused: Includes a robust numerical gradient checker to ensure the correctness of all implemented operations.
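As a rough illustration of the NumPy-style broadcasting rule listed above, the sketch below computes the result shape for two operand shapes. It is standalone Rust and does not call into Volta; the `broadcast_shape` name is hypothetical, and Volta's own implementation may be organized differently.

```rust
/// Computes the broadcast shape of two shapes under NumPy's rules,
/// or returns `None` when the shapes are incompatible.
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let rank = a.len().max(b.len());
    let mut out = vec![0; rank];
    for i in 0..rank {
        // Align shapes from the trailing dimension; a missing leading
        // dimension is treated as size 1.
        let da = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
        let db = if i < b.len() { b[b.len() - 1 - i] } else { 1 };
        out[rank - 1 - i] = match (da, db) {
            (x, y) if x == y => x, // equal sizes pass through
            (1, y) => y,           // size 1 stretches to the other size
            (x, 1) => x,
            _ => return None,      // anything else cannot be broadcast
        };
    }
    Some(out)
}
```

For example, shapes `[4, 1]` and `[3]` broadcast to `[4, 3]`, while `[2, 3]` and `[3, 2]` are rejected.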

Project Status

This library is functional for training MLPs, CNNs, RNNs, GANs, VAEs, and other architectures on CPU. It features a verified autograd engine and correctly implemented im2col convolutions.

  • What's Working:

    • Core Autograd: All operations verified with numerical gradient checking
    • Layers: Linear, Conv2d, ConvTranspose2d, MaxPool2d, Embedding, LSTMCell, PixelShuffle, BatchNorm1d/2d, Dropout
    • Optimizers: SGD (with momentum), Adam, Muon
    • External Loading: PyTorch/HuggingFace model weights via SafeTensors with automatic transposition
    • Named Layers: Robust serialization with human-readable state dict keys
    • Loss Functions: MSE, Cross-Entropy, NLL, BCE, KL Divergence
    • Examples: MNIST, CIFAR, character LM, VAE, DCGAN, super-resolution, LSTM time series
  • ⚠️ What's in Progress:

    • Performance: Not yet a primary focus. BLAS acceleration is available for matmul on macOS; most other ops use naive loops.
    • GPU Support: Experimental WGPU-based acceleration via gpu feature:
      • ✅ Core ops on GPU: elementwise (unary/binary), matmul, reductions (sum/max/mean), movement ops (permute/expand/pad/shrink/stride)
      • ✅ GPU backward pass for autograd with lazy CPU↔GPU transfers
      • ⚠️ Neural network layers still CPU-only (Linear, Conv2d forward passes being ported)
      • ⚠️ Broadcasting preprocessing happens on CPU before GPU dispatch
  • What's Missing:

    • Production-ready GPU integration, distributed training, learning-rate schedulers, attention/transformer layers

Installation

Add Volta to your Cargo.toml:

[dependencies]
volta = "0.2.0"

Enabling BLAS on macOS

For a significant performance boost in matrix multiplication on macOS, enable the accelerate feature:

[dependencies]
volta = { version = "0.2.0", features = ["accelerate"] }

Enabling GPU Support

For experimental GPU acceleration via WGPU, enable the gpu feature:

[dependencies]
volta = { version = "0.2.0", features = ["gpu"] }

Or combine both for maximum performance:

[dependencies]
volta = { version = "0.2.0", features = ["accelerate", "gpu"] }

Examples

Training an MLP

Here's how to define a simple Multi-Layer Perceptron (MLP) with named layers, train it on synthetic data, and save the model.

use volta::{nn::*, tensor::*, Adam, Sequential, TensorOps, io};

fn main() {
    // 1. Define a simple model with named layers: 2 -> 8 -> 1
    let model = Sequential::builder()
        .add_named("fc1", Box::new(Linear::new(2, 8, true)))
        .add_unnamed(Box::new(ReLU))
        .add_named("fc2", Box::new(Linear::new(8, 1, true)))
        .build();

    // 2. Create synthetic data
    let x_data = vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0];
    let x = RawTensor::new(x_data, &[4, 2], false); // Batch size 4, 2 features

    let y_data = vec![0.0, 1.0, 1.0, 0.0];
    let y = RawTensor::new(y_data, &[4], false); // Flattened targets

    // 3. Set up the optimizer
    let params = model.parameters();
    let mut optimizer = Adam::new(params, 0.1, (0.9, 0.999), 1e-8, 0.0);

    // 4. Training loop
    println!("Training a simple MLP to learn XOR...");
    for epoch in 0..=300 {
        optimizer.zero_grad();

        let pred = model.forward(&x).reshape(&[4]); // Reshape predictions to [4] so they match the targets
        let loss = mse_loss(&pred, &y);

        if epoch % 20 == 0 {
            println!("Epoch {}: loss = {:.6}", epoch, loss.borrow().data[0]);
        }

        loss.backward();
        optimizer.step();
    }

    // 5. Save and Load State Dict (human-readable keys: "fc1.weight", "fc1.bias", etc.)
    let state = model.state_dict();
    io::save_state_dict(&state, "model.bin").expect("Failed to save");

    // Verify loading
    let mut new_model = Sequential::builder()
        .add_named("fc1", Box::new(Linear::new(2, 8, true)))
        .add_unnamed(Box::new(ReLU))
        .add_named("fc2", Box::new(Linear::new(8, 1, true)))
        .build();
    let loaded_state = io::load_state_dict("model.bin").expect("Failed to load");
    new_model.load_state_dict(&loaded_state);
}
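The optimizer call in the example passes what appear to be a learning rate, a pair of betas, an epsilon, and a weight decay; that reading is inferred from the argument values, not from documented parameter names. To make the bias correction concrete, here is a minimal standalone sketch of one Adam update for a single scalar weight. `AdamState` and `adam_step` are hypothetical names, weight decay is left out, and this is not Volta's optimizer code.

```rust
/// Per-parameter Adam state for a single scalar weight.
struct AdamState {
    m: f64, // first-moment (mean) estimate
    v: f64, // second-moment (uncentered variance) estimate
    t: i32, // number of steps taken so far
}

/// Applies one bias-corrected Adam update and returns the new parameter value.
fn adam_step(
    param: f64,
    grad: f64,
    state: &mut AdamState,
    lr: f64,
    betas: (f64, f64),
    eps: f64,
) -> f64 {
    let (beta1, beta2) = betas;
    state.t += 1;
    // Exponential moving averages of the gradient and its square.
    state.m = beta1 * state.m + (1.0 - beta1) * grad;
    state.v = beta2 * state.v + (1.0 - beta2) * grad * grad;
    // Both averages start at zero, so early estimates are biased toward zero;
    // dividing by (1 - beta^t) corrects for that.
    let m_hat = state.m / (1.0 - beta1.powi(state.t));
    let v_hat = state.v / (1.0 - beta2.powi(state.t));
    param - lr * m_hat / (v_hat.sqrt() + eps)
}
```

On the first step the corrected ratio `m_hat / sqrt(v_hat)` has magnitude close to 1, so the parameter moves by roughly `lr` regardless of the gradient's scale.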

LeNet-style CNN training on CPU

The following defines a small LeNet-style CNN and runs a single training step on random data.

use volta::{Sequential, Conv2d, MaxPool2d, Flatten, Linear, ReLU, Adam};
use volta::nn::Module;
use volta::TensorOps;

fn main() {
    // 1. Define Model
    let model = Sequential::new(vec![
        // Input: 1x28x28
        Box::new(Conv2d::new(1, 6, 5, 1, 2, true)), // Padding 2
        Box::new(ReLU),
        Box::new(MaxPool2d::new(2, 2, 0)),
        // Feature map size here: 6x14x14
        Box::new(Flatten::new()),
        Box::new(Linear::new(6 * 14 * 14, 10, true)),
    ]);

    // 2. Data & Optimizer
    let input = volta::randn(&[4, 1, 28, 28]); // Batch 4
    let target = volta::randn(&[4, 10]);       // Dummy targets
    let params = model.parameters();
    let mut optim = Adam::new(params, 1e-3, (0.9, 0.999), 1e-8, 0.0);

    // 3. Training Step
    optim.zero_grad();
    let output = model.forward(&input);
    let loss = volta::mse_loss(&output, &target);
    loss.backward();
    optim.step();

    println!("Loss: {:?}", loss.borrow().data[0]);
}
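The `6 * 14 * 14` input size of the final `Linear` layer follows from the usual output-size formula for convolution and pooling windows. The sketch below works through that arithmetic, assuming the `Conv2d::new` arguments are (in channels, out channels, kernel, stride, padding, bias) and the `MaxPool2d::new` arguments are (kernel, stride, padding). That ordering is inferred from the comments in the example, and `out_size` is a hypothetical helper, not part of Volta.

```rust
/// Output length along one spatial axis for a convolution or pooling window:
/// floor((input + 2 * padding - kernel) / stride) + 1.
fn out_size(input: usize, kernel: usize, stride: usize, padding: usize) -> usize {
    (input + 2 * padding - kernel) / stride + 1
}

/// Number of features seen by the Linear layer after Flatten.
fn flattened_features(channels: usize, height: usize, width: usize) -> usize {
    channels * height * width
}
```

With a 28x28 input, the 5x5 convolution with stride 1 and padding 2 keeps the size at 28, the 2x2 pool with stride 2 halves it to 14, and flattening 6 channels of 14x14 gives 1176 features.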

Loading External PyTorch Models

Volta can load weights from PyTorch, HuggingFace, and other frameworks using SafeTensors format with automatic weight mapping and transposition.

use volta::{
    Linear, Module, ReLU, Sequential,
    io::{load_safetensors, mapping::StateDictMapper},
};

fn main() {
    // 1. Build matching architecture with named layers
    let mut model = Sequential::builder()
        .add_named("fc1", Box::new(Linear::new(784, 128, true)))
        .add_unnamed(Box::new(ReLU))
        .add_named("fc2", Box::new(Linear::new(128, 10, true)))
        .build();

    // 2. Load PyTorch weights with automatic transposition
    // PyTorch Linear stores weights as [out, in]; Volta uses [in, out]
    let pytorch_state = load_safetensors("pytorch_model.safetensors")
        .expect("Failed to load SafeTensors");

    let mapper = StateDictMapper::new()
        .transpose("fc1.weight")  // [128,784] → [784,128]
        .transpose("fc2.weight"); // [10,128] → [128,10]

    let volta_state = mapper.map(pytorch_state);

    // 3. Load into model
    model.load_state_dict(&volta_state);

    // 4. Run inference
    let input = volta::randn(&[1, 784]);
    let output = model.forward(&input);
    println!("Output shape: {:?}", output.borrow().shape);
}

Weight Mapping Features:

  • rename(from, to) - Rename individual keys
  • rename_prefix(old, new) - Rename all keys with prefix
  • strip_prefix(prefix) - Remove prefix from keys
  • transpose(key) - Transpose 2D weight matrices (PyTorch compatibility)
  • transpose_pattern(pattern) - Transpose all matching keys
  • select_keys(keys) / exclude_keys(keys) - Filter state dict

See examples/load_external_mnist.rs for a complete end-to-end example with validation.
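To show what transposing a weight matrix involves, here is a minimal standalone sketch that converts a row-major `[out, in]` buffer into a row-major `[in, out]` buffer. It assumes weights are stored as a flat, row-major `f32` buffer; `transpose_2d` is a hypothetical helper and not the code behind `StateDictMapper`.

```rust
/// Transposes a row-major [rows, cols] matrix into a row-major [cols, rows] matrix.
fn transpose_2d(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(data.len(), rows * cols, "shape does not match data length");
    let mut out = vec![0.0; rows * cols];
    for r in 0..rows {
        for c in 0..cols {
            // Element (r, c) of the source becomes element (c, r) of the result.
            out[c * rows + r] = data[r * cols + c];
        }
    }
    out
}
```

A `[2, 3]` matrix stored as `[1, 2, 3, 4, 5, 6]` becomes the `[3, 2]` matrix `[1, 4, 2, 5, 3, 6]`; transposing twice returns the original buffer.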

GPU Acceleration Example

use volta::{Device, TensorOps, randn};

fn main() {
    // Create tensors on CPU
    let a = randn(&[1024, 1024]);
    let b = randn(&[1024, 1024]);

    // Move to GPU
    let device = Device::gpu().expect("GPU required");
    let a_gpu = a.to_device(device.clone());
    let b_gpu = b.to_device(device.clone());

    // Operations execute on GPU automatically
    let c_gpu = a_gpu.matmul(&b_gpu);  // GPU matmul
    let sum_gpu = c_gpu.sum();          // GPU reduction

    // Gradients computed on GPU when possible
    sum_gpu.backward();
    println!("Gradient shape: {:?}", a_gpu.borrow().grad.as_ref().unwrap().shape());
}

API Overview

The library is designed around a few core concepts:

  • Tensor: The central data structure, an Rc<RefCell<RawTensor>>, which holds data, shape, gradient information, and device location. Supports multiple data types (f32, f16, bf16, f64, i32, i64, u8, bool).
  • TensorOps: A trait implemented for Tensor that provides the ergonomic, user-facing API for all operations (e.g., tensor.add(&other), tensor.matmul(&weights)).
  • nn::Module: A trait for building neural network layers and composing them into larger models. Provides forward(), parameters(), state_dict(), load_state_dict(), and to_device() methods.
  • Sequential::builder(): Builder pattern for composing layers with named parameters for robust serialization. Supports both add_named() for human-readable state dict keys and add_unnamed() for activation layers.
  • Optimizers (Adam, SGD, Muon): Structures that take a list of model parameters and update their weights based on computed gradients during step().
  • Device: Abstraction for CPU/GPU compute. Tensors can be moved between devices with to_device(), and operations automatically dispatch to GPU kernels when available.
  • External Model Loading: StateDictMapper provides transformations (rename, transpose, prefix handling) to load weights from PyTorch, HuggingFace, and other frameworks via SafeTensors format.
  • Vision Support: Conv2d, ConvTranspose2d (for GANs/VAEs), MaxPool2d, PixelShuffle (for super-resolution), BatchNorm1d/2d, and Dropout.
  • Sequence Support: Embedding layers for discrete inputs, LSTMCell for recurrent architectures.
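To give a feel for how an `Rc<RefCell<...>>` node type supports a dynamic graph with reverse-mode differentiation, here is a deliberately tiny scalar sketch. Every name in it (`Node`, `Value`, `leaf`, `add`, `mul`, `backward`) is hypothetical; it shares only the general idea with Volta's `RawTensor`, which works on full tensors and many more operations.

```rust
use std::cell::RefCell;
use std::collections::HashSet;
use std::rc::Rc;

/// Operation that produced a node; inputs created by the user are leaves.
enum Op {
    Leaf,
    Add,
    Mul,
}

struct Node {
    value: f64,
    grad: f64,
    op: Op,
    parents: Vec<Value>,
}

/// Shared, mutable handle to a graph node.
type Value = Rc<RefCell<Node>>;

fn leaf(value: f64) -> Value {
    Rc::new(RefCell::new(Node { value, grad: 0.0, op: Op::Leaf, parents: vec![] }))
}

fn binary(op: Op, value: f64, a: &Value, b: &Value) -> Value {
    let parents = vec![a.clone(), b.clone()];
    Rc::new(RefCell::new(Node { value, grad: 0.0, op, parents }))
}

fn add(a: &Value, b: &Value) -> Value {
    let value = a.borrow().value + b.borrow().value;
    binary(Op::Add, value, a, b)
}

fn mul(a: &Value, b: &Value) -> Value {
    let value = a.borrow().value * b.borrow().value;
    binary(Op::Mul, value, a, b)
}

/// Depth-first post-order: every node is pushed after all of its parents.
fn topo_sort(node: &Value, visited: &mut HashSet<*const RefCell<Node>>, order: &mut Vec<Value>) {
    if !visited.insert(Rc::as_ptr(node)) {
        return; // already visited through another path
    }
    let parents = node.borrow().parents.clone();
    for p in &parents {
        topo_sort(p, visited, order);
    }
    order.push(node.clone());
}

/// Accumulates d(root)/d(node) into `grad` for every node reachable from `root`.
fn backward(root: &Value) {
    let mut order = Vec::new();
    topo_sort(root, &mut HashSet::new(), &mut order);
    root.borrow_mut().grad = 1.0;
    // Reverse post-order visits each node only after all of its consumers,
    // so its gradient is complete before it is passed on.
    for node in order.iter().rev() {
        // Copy what is needed so no borrow of `node` is held while parents change.
        let (grad, parents, is_mul) = {
            let n = node.borrow();
            (n.grad, n.parents.clone(), matches!(n.op, Op::Mul))
        };
        if parents.len() != 2 {
            continue; // leaf: nothing to propagate
        }
        let (a_val, b_val) = (parents[0].borrow().value, parents[1].borrow().value);
        let (da, db) = if is_mul { (b_val, a_val) } else { (1.0, 1.0) };
        parents[0].borrow_mut().grad += grad * da;
        parents[1].borrow_mut().grad += grad * db;
    }
}
```

Building `z = x * y + x` with `x = 3` and `y = 4` and calling `backward(&z)` leaves a gradient of 5 on `x` (which is `y + 1`) and 3 on `y`. Because `x` feeds two operations, its gradient is accumulated across both paths, which is why the traversal order matters.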

Running the Test Suite

Volta has an extensive test suite that validates the correctness of every operation and its gradient. To run the tests:

cargo test -- --nocapture

To run tests with BLAS acceleration enabled (on macOS):

cargo test --features accelerate -- --nocapture

To run tests with GPU support:

cargo test --features gpu -- --nocapture

Run specific test categories:

cargo test core          # Core tensor tests
cargo test grad_check    # Numerical gradient validation
cargo test broadcasting  # Broadcasting rules
cargo test neural        # Neural network layers
cargo test optimizer     # Optimizer convergence
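Numerical gradient validation, as in the `grad_check` category, generally compares an analytic gradient against a finite-difference estimate. The sketch below shows that idea with central differences on a plain closure. `numerical_grad` and `gradients_match` are hypothetical names, and the step size and tolerance are illustrative choices, not the values Volta's checker uses.

```rust
/// Central-difference estimate of df/dx_i for each input coordinate.
fn numerical_grad<F: Fn(&[f64]) -> f64>(f: F, x: &[f64], eps: f64) -> Vec<f64> {
    let mut point = x.to_vec();
    let mut grad = Vec::with_capacity(x.len());
    for i in 0..x.len() {
        let original = point[i];
        point[i] = original + eps;
        let plus = f(&point);
        point[i] = original - eps;
        let minus = f(&point);
        point[i] = original; // restore before moving to the next coordinate
        grad.push((plus - minus) / (2.0 * eps));
    }
    grad
}

/// Returns true when every analytic entry is within `tol` of the numerical estimate.
fn gradients_match(analytic: &[f64], numeric: &[f64], tol: f64) -> bool {
    analytic.len() == numeric.len()
        && analytic.iter().zip(numeric).all(|(a, n)| (a - n).abs() <= tol)
}
```

For `f(x) = x0^2 + 3 * x0 * x1` at `(2, -1)`, the analytic gradient is `[2 * x0 + 3 * x1, 3 * x0] = [1, 6]`, and the estimate agrees with it to well within a `1e-6` tolerance.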

Available Examples

The examples/ directory contains complete working examples demonstrating various capabilities:

# Basic examples
cargo run --example readme1                    # Simple MLP training
cargo run --example readme2                    # LeNet-style CNN
cargo run --example showcase                   # Feature showcase

# Vision tasks
cargo run --example mnist_cnn                  # MNIST digit classification
cargo run --example super_resolution           # Image upscaling with PixelShuffle
cargo run --example dcgan                      # Deep Convolutional GAN

# Generative models
cargo run --example vae                        # Variational Autoencoder

# Sequence models
cargo run --example char_lm                    # Character-level language model
cargo run --example lstm_time_series           # Time series prediction

# External model loading
cargo run --example load_external_mnist        # Load PyTorch weights via SafeTensors

# GPU acceleration
cargo run --example gpu --features gpu         # GPU tensor operations
cargo run --example gpu_training --features gpu # GPU-accelerated training

# Regression
cargo run --example polynomial_regression      # Polynomial curve fitting

Roadmap

The next major steps for Volta are focused on expanding its capabilities to handle more complex models and improving performance.

  1. Complete GPU Integration: Port remaining neural network layers (Linear, Conv2d) to GPU, optimize GEMM kernels with shared memory tiling.
  2. Performance Optimization: Implement SIMD for element-wise operations, optimize broadcasting on GPU, kernel fusion for composite operations.
  3. Transformer Support: Add attention mechanisms, positional encodings, layer normalization.
  4. Learning Rate Schedulers: Cosine annealing, step decay, warmup schedules.

Outstanding Issues

  • Conv2d Memory Inefficiency: The im2col implementation in src/nn/layers/conv.rs materializes the entire matrix in memory. Large batch sizes or high-resolution images can exhaust memory even on high-end machines.
  • GPU Kernel Efficiency: The current GPU matmul uses a naive implementation without shared memory tiling. Significant performance gains are possible with optimized GEMM kernels.
  • Multi-dtype Completeness: While storage supports multiple dtypes (f16, bf16, f64, etc.), most operations still assume f32. Full dtype support requires operation kernels for each type.
  • Single-threaded: Uses Rc<RefCell> instead of Arc<Mutex>, which limits execution to a single thread on CPU.
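To put a number on the im2col memory issue, the following sketch estimates the size of a fully materialized im2col matrix. It assumes `f32` elements, a square kernel, and that the whole batch is unrolled at once; `im2col_bytes` is a hypothetical helper, and the row/column layout in the comments is one common convention, not necessarily the one used in src/nn/layers/conv.rs.

```rust
/// Bytes needed to materialize the full im2col matrix for one convolution,
/// assuming f32 elements and a square kernel.
fn im2col_bytes(
    batch: usize,
    in_channels: usize,
    height: usize,
    width: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
) -> usize {
    let out_h = (height + 2 * padding - kernel) / stride + 1;
    let out_w = (width + 2 * padding - kernel) / stride + 1;
    let rows = batch * out_h * out_w; // one row per output position
    let cols = in_channels * kernel * kernel; // one column per receptive-field value
    rows * cols * std::mem::size_of::<f32>()
}
```

The first layer of the LeNet-style example (batch 4, 1 channel, 28x28, 5x5 kernel, padding 2) needs only 313,600 bytes. A 3x3 convolution over 64 channels at 224x224 with batch 64 and padding 1 needs 7,398,752,256 bytes, roughly 7.4 GB, for the unrolled matrix alone.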

Contributing

Contributions, issues, and feature requests are welcome! Feel free to check the issues page.

License

This project is licensed under the MIT License - see the LICENSE file for details.
