Building Neural Networks with Rust and Candle: A Complete Guide from Scratch

Introduction

PyTorch has long been one of the most widely used frameworks in machine learning. But what if you are a Rust developer who wants to work on deep learning without leaving the language? Candle, a framework from Hugging Face, makes that possible.

Candle is a minimalist machine learning framework designed specifically for Rust. Its API is highly similar to PyTorch, which means that if you are familiar with PyTorch, getting started with Candle will be very easy. More importantly, Candle is a pure Rust implementation that does not require a Python runtime, allowing you to enjoy the performance and safety advantages of Rust.

This article walks through building and training neural networks with Candle from scratch. We start with basic tensor operations, move on to building a multilayer perceptron, and finish with a complete training loop.

Environment Setup

First, add the following dependencies to your Cargo.toml file:

[dependencies]
candle-core = "0.9.1"
candle-nn = "0.9.1"
rand = "0.9.2"

The roles of these three packages are as follows:

  • candle-core: Provides tensors and basic operations
  • candle-nn: Provides neural network layers, optimizers, and loss functions
  • rand: Used for random shuffling during data loading

If you need GPU support, you can enable CUDA features:

[dependencies]
candle-core = { version = "0.9.1", features = ["cuda"] }
candle-nn = { version = "0.9.1", features = ["cuda"] }

Note: Using a GPU requires the CUDA toolkit to be installed on your system. All examples in this article use the CPU so that you can follow along on any machine.

Understanding Tensors

Tensors are the fundamental data structure in deep learning. Mathematically, a tensor generalizes vectors and matrices to higher dimensions. Computationally, it is a multidimensional array.

Dimensions of Tensors

  • Scalar: A 0-dimensional tensor, which is a single number
  • Vector: A 1-dimensional tensor
  • Matrix: A 2-dimensional tensor
  • Higher Dimensions: 3D tensors, 4D tensors, etc.

Creating Tensors

In Candle, we create tensors using the Tensor::new function:

use candle_core::{Tensor, Device};

fn main() -> Result<(), candle_core::Error> {
    let device = Device::Cpu;
    
    // Create a 0-dimensional tensor (scalar): pass the bare value, not a one-element slice
    let tensor0d = Tensor::new(1u8, &device)?;
    
    // Create a 1-dimensional tensor (vector)
    let tensor1d = Tensor::new(&[1., 2., 3.], &device)?;
    
    // Create a 2-dimensional tensor (matrix)
    let tensor2d = Tensor::new(&[[1., 2.], [3., 4.]], &device)?;
    
    // Create a 3-dimensional tensor
    let tensor3d = Tensor::new(
        &[[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], 
        &device
    )?;
    
    Ok(())
}

In Candle, when creating a tensor, you need to specify the device (CPU or GPU). The ? operator is used to handle potential errors.

Data Types

Candle supports various data types, including F64 (64-bit floating point), F32 (32-bit floating point), U8 (8-bit unsigned integer), etc.

In deep learning, 32-bit floating point precision is commonly used. It strikes a good balance between precision and performance, and GPU hardware is optimized for 32-bit operations.

We can use the to_dtype method to convert data types:

use candle_core::DType;

let tensor = Tensor::new(&[[[1., 2.], [3., 4.]]], &device)?;
println!("{:?}", tensor.dtype()); // Output: F64

let tensor = tensor.to_dtype(DType::F32)?;
println!("{:?}", tensor.dtype()); // Output: F32

Common Tensor Operations

Viewing Shape

Use the shape() method to view the shape of a tensor:

let tensor2d = Tensor::new(&[[1., 2., 3.], [4., 5., 6.]], &device)?;
println!("{:?}", tensor2d.shape()); // Output: [2, 3]

Reshaping Tensors

Use the reshape method to change the shape of a tensor:

// Reshape a 2x3 tensor to 3x2
let reshaped = tensor2d.reshape((3, 2))?;
println!("{}", reshaped);
// Output:
// [[1., 2.],
//  [3., 4.],
//  [5., 6.]]

Note: The reshaped tensor must contain the same number of elements.
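To see why the reshaped output has that layout, here is a minimal sketch in plain Rust (no Candle), assuming the usual row-major layout of a contiguous tensor: reshaping keeps the flat element order and only changes how rows are cut from it. The helper name reshape_2d is hypothetical and exists only for this illustration.

```rust
/// Reinterpret a flat, row-major buffer as `rows x cols`.
/// Returns None when the element counts do not match.
fn reshape_2d(flat: &[f64], rows: usize, cols: usize) -> Option<Vec<Vec<f64>>> {
    if rows * cols != flat.len() {
        return None;
    }
    if cols == 0 {
        // Zero-width rows: nothing to copy, and chunks(0) would panic.
        return Some(vec![Vec::new(); rows]);
    }
    // Cut the flat buffer into consecutive rows of `cols` elements.
    Some(flat.chunks(cols).map(|row| row.to_vec()).collect())
}

fn main() {
    // The 2x3 tensor from the text, stored row by row.
    let flat = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

    println!("{:?}", reshape_2d(&flat, 2, 3));
    println!("{:?}", reshape_2d(&flat, 3, 2));
    // 4 * 2 = 8 elements are requested but only 6 exist, so this is None.
    println!("{:?}", reshape_2d(&flat, 4, 2));
}
```

The last call mirrors the note above: a target shape with a different element count is rejected instead of silently padded or truncated.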

Transposing

Use the t() method to transpose a 2-dimensional tensor:

let transposed = tensor2d.t()?;
println!("{}", transposed);
// Output:
// [[1., 4.],
//  [2., 5.],
//  [3., 6.]]

Matrix Multiplication

Use the matmul method for matrix multiplication:

let result = tensor2d.matmul(&tensor2d.t()?)?;
println!("{}", result);
// Output:
// [[14., 32.],
//  [32., 77.]]
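The result above can be checked by hand. The following sketch uses plain Rust vectors rather than Candle tensors; the Matrix alias and the transpose and matmul helpers are hypothetical names written only for this check.

```rust
type Matrix = Vec<Vec<f64>>;

/// Swap rows and columns of a rectangular matrix.
fn transpose(a: &Matrix) -> Matrix {
    let rows = a.len();
    let cols = if rows == 0 { 0 } else { a[0].len() };
    (0..cols)
        .map(|j| (0..rows).map(|i| a[i][j]).collect())
        .collect()
}

/// Multiply `a` (m x k) by `b` (k x n). Returns None on a shape mismatch.
fn matmul(a: &Matrix, b: &Matrix) -> Option<Matrix> {
    let inner = b.len();
    if a.iter().any(|row| row.len() != inner) {
        return None;
    }
    let cols = if inner == 0 { 0 } else { b[0].len() };
    Some(
        a.iter()
            .map(|row| {
                (0..cols)
                    // Entry (i, j) is the dot product of row i of a and column j of b.
                    .map(|j| (0..inner).map(|k| row[k] * b[k][j]).sum::<f64>())
                    .collect()
            })
            .collect(),
    )
}

fn main() {
    let a: Matrix = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
    let product = matmul(&a, &transpose(&a));
    // Top-left entry: 1*1 + 2*2 + 3*3 = 14; bottom-right: 4*4 + 5*5 + 6*6 = 77.
    println!("{:?}", product);
}
```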

Automatic Differentiation: The Core of Training

The training process in deep learning relies on the backpropagation algorithm, and the core of backpropagation is gradient computation. Candle’s automatic differentiation engine can automatically compute these gradients for us.

Computational Graph

When we perform a series of computations, Candle builds a computational graph in the background. This graph records all operations, allowing us to automatically compute gradients using the chain rule.

Let’s look at a specific example, a logistic regression classifier:

use candle_core::{Tensor, Device, Var};
use candle_nn::loss::binary_cross_entropy_with_logit;

fn main() -> Result<(), candle_core::Error> {
    let device = Device::Cpu;
    
    // True labels
    let y = Tensor::new(&[1.], &device)?;
    
    // Input features
    let x1 = Tensor::new(&[1.1], &device)?;
    
    // Weight parameter (requires gradient computation)
    let w1 = Var::new(&[2.2], &device)?;
    
    // Bias unit (requires gradient computation)
    let b = Var::new(&[0.], &device)?;
    
    // Compute net input
    let z = (&x1 * w1.as_tensor() + b.as_tensor())?;
    
    // Compute loss
    let loss = binary_cross_entropy_with_logit(&z, &y.flatten_all()?)?;
    
    // Backpropagation - compute gradients
    let grads = loss.backward()?;
    
    // Get gradients of parameters
    let grad_l_w1 = grads.get(&w1);
    let grad_l_b = grads.get(&b);
    
    println!("Gradient of w1: {}", grad_l_w1.unwrap());
    println!("Gradient of b: {}", grad_l_b.unwrap());
    
    Ok(())
}

Key points:

  1. Use the Var type to mark parameters that require gradient computation
  2. Use w1.as_tensor() to get the underlying tensor for computation
  3. Call loss.backward() to automatically compute all gradients
  4. Use grads.get(&variable) to get the gradient of a specific parameter
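As a sanity check on what backward() is expected to return, the gradients for this one-weight example can be derived by hand. The sketch below is plain Rust without Candle and assumes the loss is the standard sigmoid cross-entropy on the logit z = w*x + b, for which dL/dz = sigmoid(z) - y. It compares those closed-form gradients with a finite-difference estimate. All function names are hypothetical.

```rust
fn sigmoid(z: f64) -> f64 {
    1.0 / (1.0 + (-z).exp())
}

/// Binary cross-entropy on the logit z = w * x + b with target y.
fn loss(w: f64, b: f64, x: f64, y: f64) -> f64 {
    let p = sigmoid(w * x + b);
    -(y * p.ln() + (1.0 - y) * (1.0 - p).ln())
}

/// Chain rule: dL/dw = (sigmoid(z) - y) * x and dL/db = sigmoid(z) - y.
fn analytic_grads(w: f64, b: f64, x: f64, y: f64) -> (f64, f64) {
    let delta = sigmoid(w * x + b) - y;
    (delta * x, delta)
}

/// Central finite differences, used only to cross-check the formulas.
fn numeric_grads(w: f64, b: f64, x: f64, y: f64) -> (f64, f64) {
    let h = 1e-6;
    let dw = (loss(w + h, b, x, y) - loss(w - h, b, x, y)) / (2.0 * h);
    let db = (loss(w, b + h, x, y) - loss(w, b - h, x, y)) / (2.0 * h);
    (dw, db)
}

fn main() {
    // Same values as the Candle example: x1 = 1.1, y = 1, w1 = 2.2, b = 0.
    let (x, y, w, b) = (1.1, 1.0, 2.2, 0.0);
    let (dw, db) = analytic_grads(w, b, x, y);
    let (ndw, ndb) = numeric_grads(w, b, x, y);
    println!("analytic: dw = {dw:.6}, db = {db:.6}");
    println!("numeric:  dw = {ndw:.6}, db = {ndb:.6}");
}
```

Both gradients come out negative here: the prediction sigmoid(2.42) is below the target of 1, so gradient descent would increase w1 and b.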

Building a Multilayer Neural Network

Now let’s build a complete multilayer perceptron (MLP). We will define a struct to hold the network layers and implement the Module trait to specify the forward propagation process.

Defining the Network Structure

use candle_core::{Error, Tensor};
use candle_nn::{Linear, Module, VarBuilder, linear};

#[derive(Debug)]
struct NeuralNetwork {
    layer_1: Linear,      // First hidden layer
    layer_2: Linear,      // Second hidden layer
    output_layer: Linear, // Output layer
}

impl NeuralNetwork {
    fn new(num_inputs: usize, num_outputs: usize, vb: VarBuilder) -> Result<Self, Error> {
        Ok(Self {
            // First hidden layer: input -> 30 neurons
            layer_1: linear(num_inputs, 30, vb.pp("l1"))?,
            // Second hidden layer: 30 -> 20 neurons
            layer_2: linear(30, 20, vb.pp("l2"))?,
            // Output layer: 20 -> output
            output_layer: linear(20, num_outputs, vb.pp("out"))?,
        })
    }
}

impl Module for NeuralNetwork {
    fn forward(&self, x: &Tensor) -> Result<Tensor, Error> {
        // Pass through the first hidden layer, apply ReLU activation function
        let x = self.layer_1.forward(x)?;
        let x = x.relu()?;
        
        // Pass through the second hidden layer, apply ReLU activation function
        let x = self.layer_2.forward(&x)?;
        let x = x.relu()?;
        
        // Output layer (return logits)
        self.output_layer.forward(&x)
    }
}

The structure of this network is:

  • Input layer: num_inputs features
  • First hidden layer: 30 neurons + ReLU activation
  • Second hidden layer: 20 neurons + ReLU activation
  • Output layer: num_outputs classes
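Each hidden layer in this structure does two things: an affine map (weights times input, plus bias) followed by ReLU. The sketch below spells that out with plain Rust slices rather than Candle tensors, assuming the weight matrix is stored with one row per output neuron. The function names and the tiny 2-to-2 layer are hypothetical and chosen only to keep the arithmetic easy to follow.

```rust
/// Affine map y = W x + b, with `weight` holding one row per output neuron.
fn linear(weight: &[Vec<f32>], bias: &[f32], x: &[f32]) -> Vec<f32> {
    weight
        .iter()
        .zip(bias)
        .map(|(row, b)| row.iter().zip(x).map(|(w, xi)| w * xi).sum::<f32>() + b)
        .collect()
}

/// ReLU keeps positive values and replaces negative ones with zero.
fn relu(v: &[f32]) -> Vec<f32> {
    v.iter().map(|&a| a.max(0.0)).collect()
}

fn main() {
    // A toy layer with 2 inputs and 2 neurons.
    let weight = vec![vec![1.0, 0.0], vec![0.0, -1.0]];
    let bias = [0.5, 0.0];
    let x = [1.0, 2.0];

    let pre_activation = linear(&weight, &bias, &x);
    let activated = relu(&pre_activation);
    println!("{:?} -> {:?}", pre_activation, activated);
}
```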

Creating Model Instance

use candle_core::{DType, Device};
use candle_nn::{VarBuilder, VarMap};

fn main() -> Result<(), candle_core::Error> {
    let device = Device::Cpu;
    
    // Create parameter mapping
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    
    // Create a neural network with 50 inputs and 3 outputs
    let model = NeuralNetwork::new(50, 3, vb)?;
    
    println!("{:#?}", model);
    
    Ok(())
}

Counting Trainable Parameters

We can count the total number of trainable parameters in the model:

fn num_trainable_params(varmap: &VarMap) -> usize {
    let mut total_params = 0;
    for var in varmap.all_vars().iter() {
        let tensor = var.as_tensor();
        total_params += tensor.elem_count();
    }
    total_params
}

println!(
    "Total number of trainable parameters: {}", 
    num_trainable_params(&varmap)
);
// Output: Total number of trainable parameters: 2213
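The figure 2213 can be verified with simple arithmetic: each fully connected layer has inputs x outputs weights plus one bias per output. The following plain-Rust sketch does that sum; count_linear_params is a hypothetical helper, not part of Candle.

```rust
/// Count weights and biases for a chain of fully connected layers.
/// `sizes` lists the layer widths from input to output, e.g. [50, 30, 20, 3].
fn count_linear_params(sizes: &[usize]) -> usize {
    sizes
        .windows(2)
        // Each adjacent pair (n_in, n_out) contributes n_in * n_out weights and n_out biases.
        .map(|pair| pair[0] * pair[1] + pair[1])
        .sum()
}

fn main() {
    // 50 -> 30 -> 20 -> 3: (50*30 + 30) + (30*20 + 20) + (20*3 + 3) = 1530 + 620 + 63
    println!("{}", count_linear_params(&[50, 30, 20, 3]));
}
```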

Data Loader

Before training, we need to set up a data loader to handle data in batches.

Defining the Dataset

First, create a simple toy dataset:

use candle_core::{DType, Device, Tensor};

fn main() -> Result<(), candle_core::Error> {
    let device = Device::Cpu;
    
    // Training data: 5 samples, each with 2 features
    let x_train = Tensor::new(
        &[
            [-1.2, 3.1],
            [-0.9, 2.9],
            [-0.5, 2.6],
            [2.3, -1.1],
            [2.7, -1.5],
        ],
        &device,
    )?.to_dtype(DType::F32)?;
    
    // Labels: 3 belong to class 0, 2 belong to class 1
    let y_train = Tensor::new(&[0., 0., 0., 1., 1.], &device)?
        .to_dtype(DType::U32)?;
    
    Ok(())
}

Implementing the Dataset Structure

struct Dataset {
    features: Tensor,
    labels: Tensor,
}

impl Dataset {
    fn new(x: Tensor, y: Tensor) -> Self {
        Self {
            features: x,
            labels: y,
        }
    }
    
    // Get a single sample
    fn get_item(&self, index: usize) -> Result<(Tensor, Tensor), Error> {
        let one_x = self.features.get(index)?;
        let one_y = self.labels.get(index)?;
        Ok((one_x, one_y))
    }
    
    // Get dataset size
    fn len(&self) -> Result<usize, Error> {
        self.labels.dims1()
    }
}

Implementing the DataLoader

The DataLoader is responsible for batch sampling and shuffling data:

use rand::seq::SliceRandom;

struct DataLoader {
    dataset: Dataset,
    batch_size: usize,
    indices: Vec<usize>,
    drop_last: bool,
}

impl DataLoader {
    fn new(
        dataset: Dataset,
        batch_size: usize,
        shuffle: bool,
        drop_last: bool,
    ) -> Result<Self, Error> {
        let len = dataset.len()?;
        let mut indices: Vec<usize> = (0..len).collect();
        
        // If shuffling is required, randomly permute indices
        if shuffle {
            indices.shuffle(&mut rand::rng());
        }
        
        Ok(Self {
            dataset,
            batch_size,
            indices,
            drop_last,
        })
    }
    
    // Return an iterator
    fn iter(&self) -> DataLoaderIter {
        DataLoaderIter {
            dataset: &self.dataset,
            indices: &self.indices,
            batch_size: self.batch_size,
            drop_last: self.drop_last,
            current: 0,
        }
    }
}

struct DataLoaderIter<'a> {
    dataset: &'a Dataset,
    indices: &'a [usize],
    batch_size: usize,
    drop_last: bool,
    current: usize,
}

impl<'a> Iterator for DataLoaderIter<'a> {
    type Item = Result<(Tensor, Tensor), Error>;
    
    fn next(&mut self) -> Option<Self::Item> {
        if self.current >= self.indices.len() {
            return None;
        }
        
        let end = (self.current + self.batch_size).min(self.indices.len());
        let batch_indices = &self.indices[self.current..end];
        
        // If drop_last is specified and the last batch is incomplete, skip it
        if self.drop_last && batch_indices.len() < self.batch_size {
            return None;
        }
        
        self.current = end;
        
        // Collect all samples in the batch
        let mut batch_features = Vec::new();
        let mut batch_labels = Vec::new();
        
        for &idx in batch_indices {
            match self.dataset.get_item(idx) {
                Ok((features, labels)) => {
                    batch_features.push(features);
                    batch_labels.push(labels);
                }
                Err(e) => return Some(Err(e)),
            }
        }
        
        // Stack samples into batches
        match (
            Tensor::stack(&batch_features, 0),
            Tensor::stack(&batch_labels, 0),
        ) {
            (Ok(batch_x), Ok(batch_y)) => Some(Ok((batch_x, batch_y))),
            (Err(e), _) | (_, Err(e)) => Some(Err(e)),
        }
    }
}

Using the DataLoader

let train_ds = Dataset::new(x_train.clone(), y_train.clone());
let train_loader = DataLoader::new(train_ds, 2, true, true)?;

// Iterate over batches
for (idx, batch_result) in train_loader.iter().enumerate() {
    let (x, y) = batch_result?;
    println!("Batch {}: Features {} Labels {}", idx + 1, x, y);
}

Parameter explanations:

  • batch_size = 2: 2 samples per batch
  • shuffle = true: Shuffle the data order
  • drop_last = true: Discard the last incomplete batch
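With 5 samples and a batch size of 2, drop_last = true yields 2 batches per epoch and leaves one sample unused. Without it there would be 3 batches, the last holding a single sample. The training loop later in this article calls a total_batches() method that the DataLoader above does not define, so here is one way such a count could be computed. This is a sketch written as a free function with hypothetical names, not the original author's implementation.

```rust
/// Number of batches produced per epoch.
/// With `drop_last`, a trailing partial batch is not counted.
fn total_batches(num_samples: usize, batch_size: usize, drop_last: bool) -> usize {
    if batch_size == 0 {
        // Treated as "no batches" here to avoid dividing by zero.
        return 0;
    }
    if drop_last {
        num_samples / batch_size
    } else {
        // Round up so the partial batch is included.
        (num_samples + batch_size - 1) / batch_size
    }
}

fn main() {
    println!("drop_last = true:  {}", total_batches(5, 2, true));
    println!("drop_last = false: {}", total_batches(5, 2, false));
}
```

The same arithmetic could be wrapped in a method on DataLoader that reads indices.len(), batch_size, and drop_last from the struct.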

Training the Neural Network

Now that we have the model and data loader, we can start training:

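use candle_core::{DType, Device};
use candle_nn::{Module, Optimizer, SGD, VarBuilder, VarMap};
use candle_nn::loss::cross_entropy;

// NeuralNetwork and train_loader are the model type and data loader from the previous sections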

fn main() -> Result<(), candle_core::Error> {
    let device = Device::Cpu;
    
    // Create model (2 input features, 2 classes)
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    let model = NeuralNetwork::new(2, 2, vb)?;
    
    // Create optimizer with learning rate of 0.5
    let mut optimizer = SGD::new(varmap.all_vars(), 0.5)?;
    
    let num_epochs = 3;
    
    for epoch in 0..num_epochs {
        for (batch_idx, batch_result) in train_loader.iter().enumerate() {
            let (features, labels) = batch_result?;
            
            // Forward propagation
            let logits = model.forward(&features)?;
            let loss = cross_entropy(&logits, &labels)?;
            
            // Backpropagation and parameter update
            optimizer.backward_step(&loss)?;
            
            // Print log
            let loss_val = loss.to_vec0::<f32>()?;
            println!(
                "Epoch: {:03}/{:03} | Batch {:03}/{:03} | Training Loss: {:.2}",
                epoch + 1,
                num_epochs,
                batch_idx + 1,
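                // total_batches() is not part of the DataLoader shown earlier; it is assumed
                // to be a helper that returns the number of batches per epoch
                train_loader.total_batches(),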
                loss_val
            );
        }
    }
    
    Ok(())
}

Example of training output:

Epoch: 001/003 | Batch 001/002 | Training Loss: 3.64
Epoch: 001/003 | Batch 002/002 | Training Loss: 0.00
Epoch: 002/003 | Batch 001/002 | Training Loss: 3.33
Epoch: 002/003 | Batch 002/002 | Training Loss: 0.00
Epoch: 003/003 | Batch 001/002 | Training Loss: 0.00
Epoch: 003/003 | Batch 002/002 | Training Loss: 0.00

From the output, we can see that the loss approaches 0 after 3 epochs, indicating that the model has converged on the training set.

Training Process Analysis

  1. Learning Rate: We use a learning rate of 0.5. This is a hyperparameter that needs to be experimentally adjusted.
  2. Number of Epochs: Train for 3 epochs. Each epoch will traverse the entire training set once.
  3. Loss Function: Use cross-entropy loss. Candle’s cross_entropy applies log-softmax internally, so the model should return raw logits.
  4. Optimizer: optimizer.backward_step(&loss) performs both gradient computation and parameter update.
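To make the loss function in item 3 concrete, here is a minimal plain-Rust sketch of cross-entropy computed directly from logits via log-softmax, averaged over the batch. It illustrates the formula and is not Candle's implementation. The function names are hypothetical, and subtracting the row maximum is a common numerical-stability trick assumed here.

```rust
/// Cross-entropy for one sample: -log_softmax(logits)[label].
/// Panics if `label` is out of range for `logits`.
fn cross_entropy_from_logits(logits: &[f64], label: usize) -> f64 {
    // Subtract the maximum before exponentiating so large logits do not overflow.
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum_exp = max + logits.iter().map(|&z| (z - max).exp()).sum::<f64>().ln();
    log_sum_exp - logits[label]
}

/// Mean loss over a batch of rows of logits.
fn mean_cross_entropy(batch: &[Vec<f64>], labels: &[usize]) -> f64 {
    let total: f64 = batch
        .iter()
        .zip(labels)
        .map(|(row, &label)| cross_entropy_from_logits(row, label))
        .sum();
    total / batch.len() as f64
}

fn main() {
    // Equal logits mean the model is undecided between two classes: loss = ln(2).
    println!("{:.4}", cross_entropy_from_logits(&[0.0, 0.0], 0));
    // A confident, correct prediction gives a loss close to 0.
    println!("{:.4}", cross_entropy_from_logits(&[10.0, -10.0], 0));
    // Batch average of a confident-correct and an undecided sample.
    let batch = vec![vec![10.0, -10.0], vec![0.0, 0.0]];
    println!("{:.4}", mean_cross_entropy(&batch, &[0, 1]));
}
```

Because softmax is folded into the loss, applying softmax to the model output before calling the loss would apply it twice.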

Model Prediction

After training, we can use the model for predictions:

use candle_nn::ops::softmax;

// Predict on training data
let outputs = model.forward(&x_train)?;
println!("Model Outputs (logits): {}", outputs);

// Convert to probabilities
let probas = softmax(&outputs, 1)?;
println!("Class Probabilities: {}", probas);

// Get predicted classes
let predictions = probas.argmax(1)?;
println!("Predicted Classes: {}", predictions);

Example output:

Model Outputs (logits):
[[ 4.7010e1, -3.7848e1],
 [ 4.2802e1, -3.4469e1],
 [ 3.7059e1, -2.9854e1],
 [-6.4291e0, -3.9145e-1],
 [-8.2538e0,  3.6420e-2]]

Class Probabilities:
[[ 1.0000e0, 1.4011e-37],
 [ 1.0000e0, 2.7653e-34],
 [ 1.0000e0, 8.7143e-30],
 [ 2.3815e-3, 9.9762e-1],
 [ 2.5089e-4, 9.9975e-1]]

Predicted Classes:
[0, 0, 0, 1, 1]
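The probabilities above follow directly from the logits. As a check, this plain-Rust sketch recomputes softmax and argmax for the fourth row of the example output. The helper names are hypothetical, and the row values are copied from the printed logits, which are rounded.

```rust
/// Softmax over one row of logits, shifted by the maximum for stability.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|&z| (z - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Index of the largest value, or None for an empty slice.
fn argmax(values: &[f64]) -> Option<usize> {
    values
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(index, _)| index)
}

fn main() {
    // Fourth sample from the example output: logits -6.4291 and -0.39145.
    let row = [-6.4291, -0.39145];
    let probas = softmax(&row);
    println!("probabilities: {:?}", probas);
    println!("predicted class: {:?}", argmax(&probas));
}
```

Since softmax preserves the ordering of its inputs, taking argmax of the logits gives the same class as taking argmax of the probabilities.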

Calculating Accuracy

We can implement a generic accuracy calculation function:

fn compute_accuracy(model: &NeuralNetwork, dataloader: &DataLoader) -> Result<f32, Error> {
    let mut correct = 0usize;
    let mut total_examples = 0usize;
    
    for batch_result in dataloader.iter() {
        let (features, labels) = batch_result?;
        
        // Forward propagation
        let logits = model.forward(&features)?;
        let predictions = logits.argmax(1)?;
        
        // Compare predictions with true labels (eq yields a U8 mask of 0s and 1s)
        let compare = predictions.eq(&labels)?;
        // Widen to U32 before summing so the count is not limited to the u8 range
        let correct_batch = compare.to_dtype(candle_core::DType::U32)?.sum_all()?;
        
        correct += correct_batch.to_vec0::<u32>()? as usize;
        total_examples += compare.elem_count();
    }
    
    Ok(correct as f32 / total_examples as f32)
}

// Calculate training set accuracy
let train_accuracy = compute_accuracy(&model, &train_loader)?;
println!("Training Set Accuracy: {:.2} %", train_accuracy * 100.0);

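// Calculate test set accuracy (test_loader is not defined in this article; it is assumed
// to be a DataLoader built from held-out data in the same way as train_loader)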
let test_accuracy = compute_accuracy(&model, &test_loader)?;
println!("Test Set Accuracy: {:.2} %", test_accuracy * 100.0);
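Stripped of tensors and batching, the accuracy computation is a count of matching positions divided by the number of examples. The sketch below shows that core in plain Rust. The accuracy helper is a hypothetical name, and returning None for empty or mismatched inputs is a design choice of this sketch, not something the Candle version above does.

```rust
/// Fraction of predictions that match the labels.
/// Returns None for empty input or mismatched lengths.
fn accuracy(predictions: &[u32], labels: &[u32]) -> Option<f32> {
    if predictions.len() != labels.len() || labels.is_empty() {
        return None;
    }
    let correct = predictions
        .iter()
        .zip(labels)
        .filter(|(predicted, actual)| predicted == actual)
        .count();
    Some(correct as f32 / labels.len() as f32)
}

fn main() {
    let labels = [0, 0, 0, 1, 1];
    // All five predictions match the labels.
    println!("{:?}", accuracy(&[0, 0, 0, 1, 1], &labels));
    // One of five predictions is wrong.
    println!("{:?}", accuracy(&[0, 1, 0, 1, 1], &labels));
}
```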

Model Saving and Loading

After training, we need to save the model for future use.

Saving the Model

Candle uses the SafeTensors format to save models, which is a safe and efficient tensor storage format:

// Save model parameters
varmap.save("model.safetensors")?;

Loading the Model

Loading a model requires three steps:

fn main() -> Result<(), candle_core::Error> {
    let device = Device::Cpu;
    
    // 1. Create a new parameter mapping
    let mut varmap_loaded = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap_loaded, DType::F32, &device);
    
    // 2. Rebuild the model with the same architecture
    let model_loaded = NeuralNetwork::new(2, 2, vb)?;
    
    // 3. Load the saved parameters
    varmap_loaded.load("model.safetensors")?;
    
    // Verify the loaded model
    let accuracy = compute_accuracy(&model_loaded, &train_loader)?;
    println!("Accuracy after loading the model: {:.2} %", accuracy * 100.0);
    
    Ok(())
}

Notes:

  1. The model architecture must match exactly (inputs, outputs, number of layers, etc.)
  2. Variable names must be consistent (e.g., “l1”, “l2”, “out”)
  3. Advantages of SafeTensors format:
      • Fast serialization and deserialization
      • Safe, preventing code injection
      • Cross-framework and language portability

Conclusion

This article systematically introduced how to build and train neural networks using the Candle framework from scratch. We learned:

  1. Environment Setup: How to add Candle dependencies in a Rust project
  2. Tensor Operations: Basic methods for creating, manipulating, and transforming tensors
  3. Automatic Differentiation: Understanding the principles of computational graphs and gradient computation
  4. Network Construction: Defining the structure and forward propagation of multilayer neural networks
  5. Data Loading: Implementing custom Dataset and DataLoader
  6. Model Training: A complete training loop, including forward propagation, loss computation, and backpropagation
  7. Model Evaluation: Calculating accuracy and making predictions
  8. Model Persistence: Saving and loading trained models

Although this article used a simple toy dataset, these concepts can be directly applied to real projects. The design philosophy of Candle is to be simple and easy to use while maintaining similarity to the PyTorch API, making the transition from PyTorch to Candle very smooth.

If you are interested in both Rust and machine learning, Candle is a framework worth trying. It combines the performance and safety of Rust with the usability of deep learning frameworks, bringing powerful machine learning capabilities to the Rust ecosystem.

A complete code example can be found on GitHub. I hope this article helps you embark on your Rust machine learning journey!

References

  1. Neural Networks with Candle: https://pranitha.dev/posts/neural-networks-with-candle/
