Introduction
In machine learning, PyTorch has long been a framework of choice for many developers. But what if you are a Rust developer who wants to try deep learning without leaving the language? The Candle framework from Hugging Face opens that door.
Candle is a minimalist machine learning framework designed specifically for Rust. Its API closely resembles PyTorch's, so if you are familiar with PyTorch, getting started with Candle should be straightforward. More importantly, Candle is a pure Rust implementation that does not require a Python runtime, so you keep the performance and safety advantages of Rust.
This article walks through building and training neural networks with Candle from scratch. We start with basic tensor operations, then build a complete multilayer perceptron, and finally implement a full training process.
Environment Setup
First, add the following dependencies to your Cargo.toml file:
[dependencies]
candle-core = "0.9.1"
candle-nn = "0.9.1"
rand = "0.9.2"
The roles of these three packages are as follows:
- candle-core: Provides tensors and basic operations
- candle-nn: Provides neural network layers, optimizers, and loss functions
- rand: Used for random shuffling during data loading
If you need GPU support, you can enable CUDA features:
[dependencies]
candle-core = { version = "0.9.1", features = ["cuda"] }
candle-nn = { version = "0.9.1", features = ["cuda"] }
Note: Using a GPU requires the CUDA toolkit to be installed on your system. All examples in this article use the CPU so that you can follow along on any machine.
Understanding Tensors
Tensors are the fundamental data structure in deep learning. Mathematically, a tensor is a generalization of vectors and matrices to higher dimensions. From a computational perspective, a tensor is a multidimensional array of numbers that all share one data type.
Dimensions of Tensors
- Scalar: A 0-dimensional tensor, which is a single number
- Vector: A 1-dimensional tensor
- Matrix: A 2-dimensional tensor
- Higher Dimensions: 3D tensors, 4D tensors, etc.
Creating Tensors
In Candle, we create tensors using the Tensor::new function:
use candle_core::{Device, Tensor};
fn main() -> Result<(), candle_core::Error> {
    let device = Device::Cpu;
    // Create a 0-dimensional tensor (scalar) by passing a bare value, not a slice
    let tensor0d = Tensor::new(1u8, &device)?;
    // Create a 1-dimensional tensor (vector)
    let tensor1d = Tensor::new(&[1., 2., 3.], &device)?;
    // Create a 2-dimensional tensor (matrix)
    let tensor2d = Tensor::new(&[[1., 2.], [3., 4.]], &device)?;
    // Create a 3-dimensional tensor
    let tensor3d = Tensor::new(
        &[[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]],
        &device,
    )?;
    Ok(())
}
When creating a tensor in Candle, you must specify the device (CPU or GPU) it lives on. Tensor::new returns a Result, and the ? operator propagates any error to the caller.
Data Types
Candle supports various data types, including F64 (64-bit floating point), F32 (32-bit floating point), U8 (8-bit unsigned integer), and others.
In deep learning, 32-bit floating point precision is commonly used. It strikes a good balance between precision and performance, and GPU hardware is optimized for 32-bit operations.
We can use the to_dtype method to convert data types:
use candle_core::DType;
let tensor = Tensor::new(&[[[1., 2.], [3., 4.]]], &device)?;
println!("{:?}", tensor.dtype()); // Output: F64
let tensor = tensor.to_dtype(DType::F32)?;
println!("{:?}", tensor.dtype()); // Output: F32
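The trade-off between F32 and F64 can be made concrete without Candle. The following is a minimal plain-Rust sketch; the helper names buffer_bytes and f32_rounding_error are hypothetical and only serve this illustration. It shows that an f32 buffer takes half the memory of an f64 buffer, at the cost of a small rounding error per value.

```rust
use std::mem::size_of;

/// Bytes needed to store `n` elements of type T in a contiguous buffer.
fn buffer_bytes<T>(n: usize) -> usize {
    n * size_of::<T>()
}

/// Absolute rounding error introduced by storing `x` as f32 instead of f64.
fn f32_rounding_error(x: f64) -> f64 {
    (x - (x as f32) as f64).abs()
}
```

For one million elements, buffer_bytes::<f32>(1_000_000) is 4,000,000 bytes while buffer_bytes::<f64>(1_000_000) is 8,000,000 bytes. The rounding error for a value such as 0.1 is nonzero but tiny, which is usually acceptable for neural network weights.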
Common Tensor Operations
Viewing Shape
Use the shape() method to view the shape of a tensor:
let tensor2d = Tensor::new(&[[1., 2., 3.], [4., 5., 6.]], &device)?;
println!("{:?}", tensor2d.shape()); // Output: [2, 3]
Reshaping Tensors
Use the reshape method to change the shape of a tensor:
// Reshape a 2x3 tensor to 3x2
let reshaped = tensor2d.reshape((3, 2))?;
println!("{}", reshaped);
// Output:
// [[1., 2.],
// [3., 4.],
// [5., 6.]]
Note: The reshaped tensor must contain the same number of elements as the original.
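To see why the element counts must match, it can help to picture a tensor as a flat buffer plus a shape. The following is a minimal plain-Rust sketch, not Candle's implementation; the helper name reshape_2d is hypothetical, and it assumes row-major storage, which is consistent with the 2x3 to 3x2 output shown above.

```rust
/// Reinterpret a flat row-major buffer as a `rows x cols` matrix.
/// Returns None when the element count does not match the requested shape.
fn reshape_2d(data: &[f64], rows: usize, cols: usize) -> Option<Vec<Vec<f64>>> {
    if rows == 0 || cols == 0 || data.len() != rows * cols {
        return None;
    }
    // Each run of `cols` consecutive elements becomes one row.
    Some(data.chunks(cols).map(|row| row.to_vec()).collect())
}
```

Six elements can be viewed as 3x2 or 2x3, but asking for 4x2 returns None because eight elements would be needed.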
Transposing
Use the t() method to transpose a 2-dimensional tensor:
let transposed = tensor2d.t()?;
println!("{}", transposed);
// Output:
// [[1., 4.],
// [2., 5.],
// [3., 6.]]
Matrix Multiplication
Use the matmul method for matrix multiplication:
let result = tensor2d.matmul(&tensor2d.t()?)?;
println!("{}", result);
// Output:
// [[14., 32.],
// [32., 77.]]
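The result above can be checked by hand. The following plain-Rust sketch (hypothetical helpers matmul and transpose on nested Vecs, not Candle's optimized kernels) spells out the definition: entry (i, j) of the product is the dot product of row i of the left matrix with column j of the right matrix.

```rust
/// Multiply an (n x k) matrix by a (k x m) matrix, both stored as nested Vecs.
fn matmul(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let k = b.len();
    let m = if k > 0 { b[0].len() } else { 0 };
    a.iter()
        .map(|row| {
            assert_eq!(row.len(), k, "inner dimensions must agree");
            (0..m)
                .map(|j| (0..k).map(|i| row[i] * b[i][j]).sum::<f64>())
                .collect()
        })
        .collect()
}

/// Swap the rows and columns of a matrix.
fn transpose(a: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let cols = if a.is_empty() { 0 } else { a[0].len() };
    (0..cols)
        .map(|j| a.iter().map(|row| row[j]).collect())
        .collect()
}
```

For the 2x3 matrix used in this section, the top-left entry is 1*1 + 2*2 + 3*3 = 14 and the off-diagonal entries are 1*4 + 2*5 + 3*6 = 32, matching the output printed by Candle.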
Automatic Differentiation: The Core of Training
The training process in deep learning relies on the backpropagation algorithm, and the core of backpropagation is gradient computation. Candle’s automatic differentiation engine computes these gradients for us.
Computational Graph
When we perform a series of computations, Candle builds a computational graph in the background. This graph records all operations, allowing us to automatically compute gradients using the chain rule.
Let’s look at a concrete example, a logistic regression classifier:
use candle_core::{Device, Tensor, Var};
use candle_nn::loss::binary_cross_entropy_with_logit;
fn main() -> Result<(), candle_core::Error> {
    let device = Device::Cpu;
    // True labels
    let y = Tensor::new(&[1.], &device)?;
    // Input features
    let x1 = Tensor::new(&[1.1], &device)?;
    // Weight parameter (requires gradient computation)
    let w1 = Var::new(&[2.2], &device)?;
    // Bias unit (requires gradient computation)
    let b = Var::new(&[0.], &device)?;
    // Compute net input (the logit)
    let z = (&x1 * w1.as_tensor() + b.as_tensor())?;
    // Compute loss; the sigmoid is applied inside the loss function
    let loss = binary_cross_entropy_with_logit(&z, &y.flatten_all()?)?;
    // Backpropagation - compute gradients
    let grads = loss.backward()?;
    // Get gradients of parameters
    let grad_l_w1 = grads.get(&w1);
    let grad_l_b = grads.get(&b);
    println!("Gradient of w1: {}", grad_l_w1.unwrap());
    println!("Gradient of b: {}", grad_l_b.unwrap());
    Ok(())
}
Key points:
- Use the Var type to mark parameters that require gradient computation
- Use w1.as_tensor() to get the underlying tensor for computation
- Call loss.backward() to automatically compute all gradients
- Use grads.get(&variable) to get the gradient of a specific parameter
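To make the chain rule behind loss.backward() tangible, the same gradients can be derived by hand for this one-neuron example. The following plain-Rust sketch does not use Candle; all function names are hypothetical. It assumes the loss is the standard binary cross-entropy on a sigmoid output, for which the derivative with respect to the logit simplifies to sigmoid(z) - y.

```rust
fn sigmoid(z: f64) -> f64 {
    1.0 / (1.0 + (-z).exp())
}

/// Binary cross-entropy computed from a logit z and a label y in {0, 1}.
fn bce_with_logit(z: f64, y: f64) -> f64 {
    let p = sigmoid(z);
    -(y * p.ln() + (1.0 - y) * (1.0 - p).ln())
}

/// Analytic gradients (dL/dw, dL/db) for z = x * w + b.
/// Chain rule: dL/dz = sigmoid(z) - y, dz/dw = x, dz/db = 1.
fn gradients(x: f64, w: f64, b: f64, y: f64) -> (f64, f64) {
    let dz = sigmoid(x * w + b) - y;
    (dz * x, dz)
}

/// Central finite-difference estimate of dL/dw, used only as a sanity check.
fn numeric_grad_w(x: f64, w: f64, b: f64, y: f64) -> f64 {
    let eps = 1e-6;
    let up = bce_with_logit(x * (w + eps) + b, y);
    let down = bce_with_logit(x * (w - eps) + b, y);
    (up - down) / (2.0 * eps)
}
```

Calling gradients(1.1, 2.2, 0.0, 1.0) uses the same inputs as the Candle example. Both gradients come out negative (roughly -0.09 for w and -0.08 for b), meaning a gradient descent step would increase w and b to push the prediction closer to the label 1.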
Building a Multilayer Neural Network
Now let’s build a complete multilayer perceptron (MLP). We will define a struct to hold the network layers and implement the Module trait to specify the forward propagation process.
Defining the Network Structure
use candle_core::{Error, Tensor};
use candle_nn::{Linear, Module, VarBuilder, linear};
#[derive(Debug)]
struct NeuralNetwork {
    layer_1: Linear,      // First hidden layer
    layer_2: Linear,      // Second hidden layer
    output_layer: Linear, // Output layer
}
impl NeuralNetwork {
    fn new(num_inputs: usize, num_outputs: usize, vb: VarBuilder) -> Result<Self, Error> {
        Ok(Self {
            // First hidden layer: input -> 30 neurons
            layer_1: linear(num_inputs, 30, vb.pp("l1"))?,
            // Second hidden layer: 30 -> 20 neurons
            layer_2: linear(30, 20, vb.pp("l2"))?,
            // Output layer: 20 -> output
            output_layer: linear(20, num_outputs, vb.pp("out"))?,
        })
    }
}
impl Module for NeuralNetwork {
    fn forward(&self, x: &Tensor) -> Result<Tensor, Error> {
        // Pass through the first hidden layer, apply ReLU activation function
        let x = self.layer_1.forward(x)?;
        let x = x.relu()?;
        // Pass through the second hidden layer, apply ReLU activation function
        let x = self.layer_2.forward(&x)?;
        let x = x.relu()?;
        // Output layer (return logits)
        self.output_layer.forward(&x)
    }
}
The structure of this network is:
- Input layer: num_inputs features
- First hidden layer: 30 neurons + ReLU activation
- Second hidden layer: 20 neurons + ReLU activation
- Output layer: num_outputs classes
Creating a Model Instance
use candle_core::{DType, Device};
use candle_nn::{VarBuilder, VarMap};
fn main() -> Result<(), candle_core::Error> {
    let device = Device::Cpu;
    // Create parameter mapping
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    // Create a neural network with 50 inputs and 3 outputs
    let model = NeuralNetwork::new(50, 3, vb)?;
    println!("{:#?}", model);
    Ok(())
}
Counting Trainable Parameters
We can count the total number of trainable parameters in the model:
fn num_trainable_params(varmap: &VarMap) -> usize {
    let mut total_params = 0;
    for var in varmap.all_vars().iter() {
        let tensor = var.as_tensor();
        total_params += tensor.elem_count();
    }
    total_params
}
println!(
    "Total number of trainable parameters: {}",
    num_trainable_params(&varmap)
);
// Output: Total number of trainable parameters: 2213
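The figure 2213 can be verified by hand: each fully connected layer holds one weight per input-output pair plus one bias per output. The following plain-Rust sketch uses a hypothetical helper, count_params, that takes the layer widths from input to output and assumes every layer has a bias.

```rust
/// Parameters in a stack of fully connected layers with biases.
/// `sizes` lists the width of every layer, input first, output last.
fn count_params(sizes: &[usize]) -> usize {
    sizes
        .windows(2)
        .map(|pair| pair[0] * pair[1] + pair[1]) // weights + biases
        .sum()
}
```

For the 50 -> 30 -> 20 -> 3 network above, count_params(&[50, 30, 20, 3]) gives (50*30 + 30) + (30*20 + 20) + (20*3 + 3) = 1530 + 620 + 63 = 2213.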
Data Loader
Before training, we need to set up a data loader to handle data in batches.
Defining the Dataset
First, create a simple toy dataset:
use candle_core::{DType, Device, Tensor};
fn main() -> Result<(), candle_core::Error> {
    let device = Device::Cpu;
    // Training data: 5 samples, each with 2 features
    let x_train = Tensor::new(
        &[
            [-1.2, 3.1],
            [-0.9, 2.9],
            [-0.5, 2.6],
            [2.3, -1.1],
            [2.7, -1.5],
        ],
        &device,
    )?
    .to_dtype(DType::F32)?;
    // Labels: 3 belong to class 0, 2 belong to class 1
    let y_train = Tensor::new(&[0., 0., 0., 1., 1.], &device)?
        .to_dtype(DType::U32)?;
    Ok(())
}
Implementing the Dataset Structure
struct Dataset {
    features: Tensor,
    labels: Tensor,
}
impl Dataset {
    fn new(x: Tensor, y: Tensor) -> Self {
        Self {
            features: x,
            labels: y,
        }
    }
    // Get a single sample
    fn get_item(&self, index: usize) -> Result<(Tensor, Tensor), Error> {
        let one_x = self.features.get(index)?;
        let one_y = self.labels.get(index)?;
        Ok((one_x, one_y))
    }
    // Get dataset size
    fn len(&self) -> Result<usize, Error> {
        self.labels.dims1()
    }
}
Implementing the DataLoader
The DataLoader is responsible for batch sampling and shuffling data:
use rand::seq::SliceRandom;
struct DataLoader {
    dataset: Dataset,
    batch_size: usize,
    indices: Vec<usize>,
    drop_last: bool,
}
impl DataLoader {
    fn new(
        dataset: Dataset,
        batch_size: usize,
        shuffle: bool,
        drop_last: bool,
    ) -> Result<Self, Error> {
        let len = dataset.len()?;
        let mut indices: Vec<usize> = (0..len).collect();
        // If shuffling is requested, randomly permute the indices (once, at construction)
        if shuffle {
            indices.shuffle(&mut rand::rng());
        }
        Ok(Self {
            dataset,
            batch_size,
            indices,
            drop_last,
        })
    }
    // Return an iterator over batches
    fn iter(&self) -> DataLoaderIter<'_> {
        DataLoaderIter {
            dataset: &self.dataset,
            indices: &self.indices,
            batch_size: self.batch_size,
            drop_last: self.drop_last,
            current: 0,
        }
    }
    // Number of batches produced by one pass over the data (used for logging)
    fn total_batches(&self) -> usize {
        let n = self.indices.len();
        if self.drop_last {
            n / self.batch_size
        } else {
            (n + self.batch_size - 1) / self.batch_size
        }
    }
}
struct DataLoaderIter<'a> {
    dataset: &'a Dataset,
    indices: &'a [usize],
    batch_size: usize,
    drop_last: bool,
    current: usize,
}
impl<'a> Iterator for DataLoaderIter<'a> {
    type Item = Result<(Tensor, Tensor), Error>;
    fn next(&mut self) -> Option<Self::Item> {
        if self.current >= self.indices.len() {
            return None;
        }
        let end = (self.current + self.batch_size).min(self.indices.len());
        let batch_indices = &self.indices[self.current..end];
        // If drop_last is specified and the last batch is incomplete, skip it
        if self.drop_last && batch_indices.len() < self.batch_size {
            return None;
        }
        self.current = end;
        // Collect all samples in the batch
        let mut batch_features = Vec::new();
        let mut batch_labels = Vec::new();
        for &idx in batch_indices {
            match self.dataset.get_item(idx) {
                Ok((features, labels)) => {
                    batch_features.push(features);
                    batch_labels.push(labels);
                }
                Err(e) => return Some(Err(e)),
            }
        }
        // Stack samples into batches along a new leading dimension
        match (
            Tensor::stack(&batch_features, 0),
            Tensor::stack(&batch_labels, 0),
        ) {
            (Ok(batch_x), Ok(batch_y)) => Some(Ok((batch_x, batch_y))),
            (Err(e), _) | (_, Err(e)) => Some(Err(e)),
        }
    }
}
Using the DataLoader
let train_ds = Dataset::new(x_train.clone(), y_train.clone());
let train_loader = DataLoader::new(train_ds, 2, true, true)?;
// Iterate over batches
for (idx, batch_result) in train_loader.iter().enumerate() {
    let (x, y) = batch_result?;
    println!("Batch {}: Features {} Labels {}", idx + 1, x, y);
}
Parameter explanations:
- batch_size = 2: 2 samples per batch
- shuffle = true: Shuffle the data order
- drop_last = true: Discard the last incomplete batch
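How these parameters interact determines how many batches one pass over the data produces. The following plain-Rust sketch uses a hypothetical helper, num_batches, that mirrors the skipping logic of the iterator above; it is an illustration rather than part of Candle.

```rust
/// Number of batches one pass over `n_samples` samples yields.
fn num_batches(n_samples: usize, batch_size: usize, drop_last: bool) -> usize {
    assert!(batch_size > 0, "batch_size must be positive");
    if drop_last {
        // An incomplete final batch is discarded.
        n_samples / batch_size
    } else {
        // The final batch may be smaller than batch_size (ceiling division).
        (n_samples + batch_size - 1) / batch_size
    }
}
```

With the 5-sample toy dataset and batch_size = 2, drop_last = true yields 2 batches (one sample is skipped in each pass), while drop_last = false would yield 3 batches, the last containing a single sample.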
Training the Neural Network
Now that we have the model and data loader, we can start training:
use candle_core::{DType, Device};
use candle_nn::loss::cross_entropy;
use candle_nn::{Module, Optimizer, VarBuilder, VarMap, SGD};
fn main() -> Result<(), candle_core::Error> {
    let device = Device::Cpu;
    // Assumes `train_loader` has been built from x_train and y_train as shown above
    // Create model (2 input features, 2 classes)
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    let model = NeuralNetwork::new(2, 2, vb)?;
    // Create optimizer with learning rate of 0.5
    let mut optimizer = SGD::new(varmap.all_vars(), 0.5)?;
    let num_epochs = 3;
    for epoch in 0..num_epochs {
        for (batch_idx, batch_result) in train_loader.iter().enumerate() {
            let (features, labels) = batch_result?;
            // Forward propagation
            let logits = model.forward(&features)?;
            let loss = cross_entropy(&logits, &labels)?;
            // Backpropagation and parameter update
            optimizer.backward_step(&loss)?;
            // Print log (the loss is a scalar F32 tensor)
            let loss_val = loss.to_vec0::<f32>()?;
            println!(
                "Epoch: {:03}/{:03} | Batch {:03}/{:03} | Training Loss: {:.2}",
                epoch + 1,
                num_epochs,
                batch_idx + 1,
                train_loader.total_batches(),
                loss_val
            );
        }
    }
    Ok(())
}
Example of training output:
Epoch: 001/003 | Batch 001/002 | Training Loss: 3.64
Epoch: 001/003 | Batch 002/002 | Training Loss: 0.00
Epoch: 002/003 | Batch 001/002 | Training Loss: 3.33
Epoch: 002/003 | Batch 002/002 | Training Loss: 0.00
Epoch: 003/003 | Batch 001/002 | Training Loss: 0.00
Epoch: 003/003 | Batch 002/002 | Training Loss: 0.00
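In this run, the loss drops to roughly 0 by the third epoch, indicating that the model has converged on the training set. Exact values will differ between runs because the weights are initialized randomly and the batches are shuffled.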
Training Process Analysis
- Learning Rate: We use a learning rate of 0.5. This is a hyperparameter that needs to be experimentally adjusted.
- Number of Epochs: Train for 3 epochs. Each epoch will traverse the entire training set once.
- Loss Function: Use cross-entropy loss. It applies log-softmax internally, so the model should output raw logits.
- Optimizer: optimizer.backward_step(&loss) performs both gradient computation and parameter update.
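The two core steps listed above, computing the loss from logits and updating parameters, can be sketched in plain Rust. This is an illustration of the underlying math under stated assumptions, not Candle's code: the names cross_entropy_from_logits and sgd_step are hypothetical, the loss is for a single sample, and the update is vanilla SGD without momentum.

```rust
/// Cross-entropy for one sample: -log(softmax(logits)[target]).
/// Subtracting the maximum (log-sum-exp trick) keeps large logits from overflowing.
fn cross_entropy_from_logits(logits: &[f64], target: usize) -> f64 {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let sum_exp: f64 = logits.iter().map(|&z| (z - max).exp()).sum();
    (max + sum_exp.ln()) - logits[target]
}

/// One plain SGD update: move each parameter against its gradient.
fn sgd_step(params: &mut [f64], grads: &[f64], lr: f64) {
    for (p, g) in params.iter_mut().zip(grads) {
        *p -= lr * g;
    }
}
```

With two equal logits the model is maximally unsure and the loss is ln 2 (about 0.69). With logits such as [47.0, -38.0] and target 0, the loss is effectively 0, which matches the near-zero losses in the training log once the network has separated the classes.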
Model Prediction
After training, we can use the model for predictions:
use candle_nn::ops::softmax;
// Predict on training data
let outputs = model.forward(&x_train)?;
println!("Model Outputs (logits): {}", outputs);
// Convert to probabilities
let probas = softmax(&outputs, 1)?;
println!("Class Probabilities: {}", probas);
// Get predicted classes
let predictions = probas.argmax(1)?;
println!("Predicted Classes: {}", predictions);
Example output:
Model Outputs (logits):
[[ 4.7010e1, -3.7848e1],
[ 4.2802e1, -3.4469e1],
[ 3.7059e1, -2.9854e1],
[-6.4291e0, -3.9145e-1],
[-8.2538e0, 3.6420e-2]]
Class Probabilities:
[[ 1.0000e0, 1.4011e-37],
[ 1.0000e0, 2.7653e-34],
[ 1.0000e0, 8.7143e-30],
[ 2.3815e-3, 9.9762e-1],
[ 2.5089e-4, 9.9975e-1]]
Predicted Classes:
[0, 0, 0, 1, 1]
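The conversion from logits to probabilities to predicted classes can be reproduced for a single row. The following plain-Rust sketch uses hypothetical helpers softmax and argmax that operate on one row of logits; Candle's versions work on whole tensors along a chosen dimension.

```rust
/// Numerically stable softmax over one row of logits.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|&z| (z - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Index of the largest value (the predicted class); None for an empty row.
fn argmax(values: &[f64]) -> Option<usize> {
    values
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i)
}
```

Applying softmax to the fourth row of the example logits, [-6.4291, -0.39145], gives approximately [0.0024, 0.9976], and argmax picks class 1, in line with the fourth row of the probabilities and predictions printed above. Because softmax preserves the ordering of the logits, taking argmax of the logits directly gives the same prediction.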
Calculating Accuracy
We can implement a generic accuracy calculation function:
fn compute_accuracy(model: &NeuralNetwork, dataloader: &DataLoader) -> Result<f32, Error> {
    let mut correct = 0u32;
    let mut total_examples = 0usize;
    for batch_result in dataloader.iter() {
        let (features, labels) = batch_result?;
        // Forward propagation
        let logits = model.forward(&features)?;
        let predictions = logits.argmax(1)?;
        // Compare predictions with true labels (1 where equal, 0 otherwise)
        let compare = predictions.eq(&labels)?;
        // Cast to U32 before summing so large batches cannot overflow a u8
        let correct_batch = compare.to_dtype(DType::U32)?.sum_all()?;
        correct += correct_batch.to_vec0::<u32>()?;
        total_examples += compare.elem_count();
    }
    Ok(correct as f32 / total_examples as f32)
}
// Calculate training set accuracy
let train_accuracy = compute_accuracy(&model, &train_loader)?;
println!("Training Set Accuracy: {:.2} %", train_accuracy * 100.0);
// Calculate test set accuracy
// (assumes a `test_loader` built like `train_loader` from held-out data, which is not shown in this article)
let test_accuracy = compute_accuracy(&model, &test_loader)?;
println!("Test Set Accuracy: {:.2} %", test_accuracy * 100.0);
Model Saving and Loading
After training, we need to save the model for future use.
Saving the Model
Candle uses the SafeTensors format to save models, which is a safe and efficient tensor storage format:
// Save model parameters
varmap.save("model.safetensors")?;
Loading the Model
Loading a model requires three steps:
fn main() -> Result<(), candle_core::Error> {
    let device = Device::Cpu;
    // 1. Create a new parameter mapping
    let mut varmap_loaded = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap_loaded, DType::F32, &device);
    // 2. Rebuild the model with the same architecture
    let model_loaded = NeuralNetwork::new(2, 2, vb)?;
    // 3. Load the saved parameters, overwriting the random initial values
    varmap_loaded.load("model.safetensors")?;
    // Verify the loaded model (assumes `train_loader` was built as shown earlier)
    let accuracy = compute_accuracy(&model_loaded, &train_loader)?;
    println!("Accuracy after loading the model: {:.2} %", accuracy * 100.0);
    Ok(())
}
Notes:
- The model architecture must match exactly (inputs, outputs, number of layers, etc.)
- Variable names must be consistent (e.g., “l1”, “l2”, “out”)
- Advantages of the SafeTensors format:
  - Fast serialization and deserialization
  - Safe: loading a file cannot execute arbitrary code
  - Portable across frameworks and languages
Conclusion
This article systematically introduced how to build and train neural networks using the Candle framework from scratch. We learned:
- Environment Setup: How to add Candle dependencies in a Rust project
- Tensor Operations: Basic methods for creating, manipulating, and transforming tensors
- Automatic Differentiation: Understanding the principles of computational graphs and gradient computation
- Network Construction: Defining the structure and forward propagation of multilayer neural networks
- Data Loading: Implementing custom Dataset and DataLoader
- Model Training: A complete training loop, including forward propagation, loss computation, and backpropagation
- Model Evaluation: Calculating accuracy and making predictions
- Model Persistence: Saving and loading trained models
Although this article used a simple toy dataset, these concepts can be directly applied to real projects. The design philosophy of Candle is to be simple and easy to use while maintaining similarity to the PyTorch API, making the transition from PyTorch to Candle very smooth.
If you are interested in both Rust and machine learning, Candle is a framework worth trying. It combines the performance and safety of Rust with the usability of deep learning frameworks, bringing powerful machine learning capabilities to the Rust ecosystem.
A complete code example can be found on GitHub. I hope this article helps you embark on your Rust machine learning journey!
References
- Neural Networks with Candle: https://pranitha.dev/posts/neural-networks-with-candle/