Introduction
In GPU programming, we often face a difficult trade-off between performance and code quality. Traditional GPU programming adopts the SIMT (Single Instruction, Multiple Threads) model: threads in the same group execute the same instruction, so when they take different branches the hardware runs those branches one after another and performance drops. To pursue maximum performance, we often have to sacrifice the type safety and maintainability of the code.
But what if we could maintain GPU performance while also enjoying the compile-time guarantees provided by Rust’s powerful type system? This article introduces a proposed language feature, “Polysemous Functions”, which attempts to balance the two.
CPU Version: A Paradise of Type Safety
Let’s start with a concrete example. Suppose we need to calculate the rolling average of an array, and sometimes we also need to calculate a randomized version of the rolling average (raising the original value to a random power).
On the CPU, we would naturally write two independent functions:
const WINDOW_N: usize = 10;
// Basic version: only calculates the normal rolling average
fn calc_rolling_mean(
arr: &[f32],
out: &mut [f32],
n: usize,
i: usize,
) {
if i >= n {
return;
}
let mut rolling_mean = 0.;
let start_i = if i < WINDOW_N { 0 } else { i - WINDOW_N };
let end_i = (i + WINDOW_N).min(n); // same result, but no usize underflow when n < WINDOW_N
for idx in start_i..end_i {
let v = arr[idx];
rolling_mean += v;
}
out[i] = rolling_mean / (end_i - start_i) as f32;
}
// Extended version: calculates both normal and randomized rolling averages
fn calc_rolling_mean_with_rand(
arr: &[f32],
out: &mut [f32],
n: usize,
i: usize,
rand_seed: u32,
rand_out: &mut [f32],
) {
if i >= n {
return;
}
let mut rolling_mean = 0.;
let mut rand_rolling_mean = 0.;
let pow = rand_pow(rand_seed); // Calculate random power
let start_i = if i < WINDOW_N { 0 } else { i - WINDOW_N };
let end_i = (i + WINDOW_N).min(n); // same result, but no usize underflow when n < WINDOW_N
for idx in start_i..end_i {
let v = arr[idx];
rolling_mean += v;
rand_rolling_mean += v.powf(pow); // Additional calculation for randomized version
}
out[i] = rolling_mean / (end_i - start_i) as f32;
rand_out[i] = rand_rolling_mean / (end_i - start_i) as f32;
}
The calling code is also very intuitive:
fn calling_f(
arr: &[f32],
out: &mut [f32],
n: usize,
i: usize,
rand_out: &mut [f32],
) {
// Choose which function to call based on condition
if arr[i] > 5. {
calc_rolling_mean_with_rand(arr, out, n, i, (i/3) as u32, rand_out);
} else {
calc_rolling_mean(arr, out, n, i);
}
}
This implementation is very clear: when random calculation is needed, pow and rand_rolling_mean are always valid f32 types; when not needed, they simply do not exist. The type system perfectly ensures the correctness of the code.
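The listings above call rand_pow without defining it. To try them out, a stand-in is needed. The sketch below is purely hypothetical: the generator (one xorshift32 round) and the exponent range (roughly 0.5 to 1.5) are assumptions made for illustration and are not taken from the original article.

```rust
/// Hypothetical stand-in for the article's undefined `rand_pow`.
/// Maps a seed to an exponent in roughly [0.5, 1.5] using one xorshift32 round.
fn rand_pow(seed: u32) -> f32 {
    // xorshift needs a non-zero state, so force the low bit on.
    let mut x = seed | 1;
    x ^= x << 13;
    x ^= x >> 17;
    x ^= x << 5;
    // Keep the top 24 bits so the fraction is exactly representable in f32.
    let unit = (x >> 8) as f32 / (1u32 << 24) as f32; // in [0, 1)
    0.5 + unit
}

fn main() {
    for seed in 0..4u32 {
        println!("seed {} -> pow {}", seed, rand_pow(seed));
    }
}
```

Any deterministic function from u32 to f32 would serve equally well here; the only property the surrounding code relies on is that the same seed yields the same exponent.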
GPU Version: The Game of Performance and Safety
Now let’s port this code to the GPU. Since the GPU adopts the SIMT model, threads within the same group need to execute the same instructions. If we keep two independent functions, problems arise:
- When arr[i] > 5. is true, threads call calc_rolling_mean_with_rand
- When the condition is false, threads call calc_rolling_mean
- The two groups of threads are actually executing completely different functions
- While one group runs, the threads of the other group sit idle, causing severe performance loss
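The cost of this divergence can be illustrated with a toy model. The sketch below is not a GPU simulator: the group size of 32 lanes and the cycle counts are assumptions chosen for illustration. It only encodes the rule described above, that a group pays for every distinct branch taken by any of its lanes.

```rust
/// Assumed group ("warp") size; 32 is typical on some hardware but is an assumption here.
const WARP_SIZE: usize = 32;

/// Toy cost model: the group executes each distinct branch taken by any lane,
/// one after another, so divergent groups pay for both branches.
fn warp_cost(takes_rand: &[bool; WARP_SIZE], cost_rand: u32, cost_plain: u32) -> u32 {
    let any_rand = takes_rand.iter().any(|&t| t);
    let any_plain = takes_rand.iter().any(|&t| !t);
    let rand_part = if any_rand { cost_rand } else { 0 };
    let plain_part = if any_plain { cost_plain } else { 0 };
    rand_part + plain_part
}

fn main() {
    // Every lane takes the plain branch: only that branch is paid for.
    let uniform = [false; WARP_SIZE];
    // A single lane takes the randomized branch: both branches are paid for.
    let mut mixed = [false; WARP_SIZE];
    mixed[0] = true;

    println!("uniform group: {} cycles", warp_cost(&uniform, 120, 100)); // 100
    println!("mixed group:   {} cycles", warp_cost(&mixed, 120, 100)); // 220
}
```

Even one lane taking the other branch makes the whole group pay for both, which is why merging the two functions so that most instructions are shared pays off.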
To optimize performance, we need to merge the two functions into one, using the Option type to handle optional calculations:
struct RandOutputArgs<'a> {
seed: u32,
arr: &'a mut [f32],
}
fn calc_rolling_mean(
arr: &[f32],
out: &mut [f32],
n: usize,
i: usize,
rand_out: Option<RandOutputArgs>, // Use Option to indicate optional functionality
) {
if i >= n {
return;
}
let mut rolling_mean = 0.;
// Problems start to arise: these variables must also be Option
let (mut rand_rolling_mean, pow) = if let Some(ref r) = rand_out {
(Some(0.), Some(rand_pow(r.seed)))
} else {
(None, None)
};
let start_i = if i < WINDOW_N { 0 } else { i - WINDOW_N };
let end_i = (i + WINDOW_N).min(n); // same result, but no usize underflow when n < WINDOW_N
for idx in start_i..end_i {
let v = arr[idx];
rolling_mean += v;
// Must handle all possible combinations
// as_mut() borrows the accumulator; matching on a copied tuple would update a temporary instead
match (pow, rand_rolling_mean.as_mut()) {
(Some(pow), Some(m)) => {
*m += v.powf(pow);
},
(None, None) => {},
_ => unreachable!(), // These states "theoretically" should not occur
}
}
out[i] = rolling_mean / (end_i - start_i) as f32;
// Again handle all combinations
match (rand_out, rand_rolling_mean) {
(Some(r), Some(m)) => {
r.arr[i] = m / (end_i - start_i) as f32;
},
(None, None) => {},
_ => unreachable!(), // Another "impossible" state
}
}
Although this version performs better on the GPU (all threads execute the same function, only briefly diverging at nested branches), the code quality significantly decreases:
- Illegal states can be represented: The type system thinks rand_rolling_mean can be None while rand_out is Some, but this is logically impossible
- Runtime checks: We have to use the unreachable!() macro to handle these “impossible” cases
- Cognitive burden: Developers need to maintain two contexts in their minds (with and without randomization)
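For comparison, current Rust can shrink (though not eliminate) the illegal states by bundling everything that only exists in the randomized case into one struct behind a single Option. This is not the approach proposed in the article, only a baseline sketch. RandState is a hypothetical name, and rand_pow is a placeholder because the article does not define it.

```rust
const WINDOW_N: usize = 10;

/// Seed and output buffer for the randomized variant (as in the article).
struct RandOutputArgs<'a> {
    seed: u32,
    arr: &'a mut [f32],
}

/// Hypothetical bundle: everything that exists only when randomization is on.
struct RandState<'a> {
    pow: f32,
    acc: f32,
    out: &'a mut [f32],
}

/// Placeholder for the article's undefined `rand_pow`.
fn rand_pow(seed: u32) -> f32 {
    1.0 + (seed % 3) as f32
}

fn calc_rolling_mean(
    arr: &[f32],
    out: &mut [f32],
    n: usize,
    i: usize,
    rand_out: Option<RandOutputArgs<'_>>,
) {
    if i >= n {
        return;
    }
    // One Option instead of three: `pow`, `acc` and `out` are present or absent together.
    let mut state = rand_out.map(|r| RandState {
        pow: rand_pow(r.seed),
        acc: 0.0,
        out: r.arr,
    });
    let start_i = i.saturating_sub(WINDOW_N);
    let end_i = (i + WINDOW_N).min(n);
    let mut rolling_mean = 0.0f32;
    for &v in &arr[start_i..end_i] {
        rolling_mean += v;
        if let Some(s) = state.as_mut() {
            s.acc += v.powf(s.pow);
        }
    }
    let len = (end_i - start_i) as f32;
    out[i] = rolling_mean / len;
    if let Some(s) = state.as_mut() {
        s.out[i] = s.acc / len;
    }
}

fn main() {
    let arr: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
    let mut out = [0.0f32; 4];
    let mut rand_buf = [0.0f32; 4];
    let args = RandOutputArgs { seed: 0, arr: &mut rand_buf };
    calc_rolling_mean(&arr, &mut out, 4, 1, Some(args));
    println!("mean = {}, rand mean = {}", out[1], rand_buf[1]);
}
```

The unreachable!() arms disappear, but every use of the state still goes through a runtime Option check, and the compiler still cannot tie the presence of the state to the caller's choice. That remaining gap is what the proposal below targets.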
Polysemous Functions: Having Your Cake and Eating It Too
The proposed “Polysemous Functions” feature attempts to solve this dilemma. The core idea is: allow a function to be type-checked multiple times at compile time (once for each variant), but compile it into a single function for execution.
Basic Syntax
// Define a polysemous function
poly<RandRollingMean, Pow> fn calc_rolling_mean[
WithRand<f32, f32>, // Variant 1: requires randomization
WithoutRand<(), ()> // Variant 2: does not require randomization
](
arr: &[f32],
out: &mut [f32],
n: usize,
i: usize,
rand_out: RandOutputArgs,
) {
// Function body
}
Key elements:
- poly<RandRollingMean, Pow>: Declares two variant generic types
- WithRand<f32, f32>: In the WithRand variant, both types are f32
- WithoutRand<(), ()>: In the WithoutRand variant, both types are unit types ()
Function Body Implementation
poly<RandRollingMean, Pow> fn calc_rolling_mean[
WithRand<f32, f32>,
WithoutRand<(), ()>
](
arr: &[f32],
out: &mut [f32],
n: usize,
i: usize,
rand_out: RandOutputArgs,
) {
// Use polymatch to execute different code in different variants
polymatch {
WithRand => {},
WithoutRand => {
drop(rand_out); // Discard parameter in WithoutRand variant
},
}
if i >= n {
return;
}
let mut rolling_mean = 0.;
// Key: these two variables have different types in different variants
let (mut rand_rolling_mean, pow): (RandRollingMean, Pow) = polymatch {
WithRand => (0., rand_pow(rand_out.seed)), // f32 type
WithoutRand => ((), ()), // () type
};
let start_i = if i < WINDOW_N { 0 } else { i - WINDOW_N };
let end_i = (i + WINDOW_N).min(n); // same result, but no usize underflow when n < WINDOW_N
for idx in start_i..end_i {
let v = arr[idx];
rolling_mean += v;
// Only execute in WithRand variant
polymatch {
WithRand => {
rand_rolling_mean += v.powf(pow); // Type safe here!
},
WithoutRand => {}, // In this branch, rand_rolling_mean is (), cannot operate
}
}
out[i] = rolling_mean / (end_i - start_i) as f32;
polymatch {
WithRand => {
rand_out.arr[i] = rand_rolling_mean / (end_i - start_i) as f32;
},
WithoutRand => {},
}
}
Calling Method
fn kernel(
arr: &[f32],
out: &mut [f32],
n: usize,
i: usize,
rand_out: &mut [f32],
) {
// Choose variant at runtime
calc_rolling_mean::[
if arr[i] > 5. {
WithRand
} else {
WithoutRand
}
](
arr,
out,
n,
i,
RandOutputArgs {
seed: (i/3) as u32,
arr: rand_out,
}
);
}
How It Works
At Compile Time:
- The compiler performs type checking separately for the WithRand and WithoutRand variants
- In the WithRand variant, rand_rolling_mean and pow are f32
- In the WithoutRand variant, they are (), so any operation on them leads to a compile error
- polymatch statements act like a preprocessor during type checking, only considering branches of the current variant
At Runtime:
- Only a single unified function is generated
- Variant selection is essentially an enumeration value
- polymatch statements turn into ordinary match statements
- Stack space is reserved for variant generic types to accommodate all possible concrete types
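The runtime half of this description can be sketched by hand in current Rust. The code below is only a guess at what a compiler or macro might emit for calc_rolling_mean. The names Variant and calc_rolling_mean_lowered are hypothetical, rand_pow is a placeholder, and the "reserved stack space" is modelled as plain f32 slots that are simply left unused in the WithoutRand case.

```rust
const WINDOW_N: usize = 10;

/// Runtime tag standing in for the selected variant (hypothetical name).
#[derive(Clone, Copy, Debug, PartialEq)]
enum Variant {
    WithRand,
    WithoutRand,
}

struct RandOutputArgs<'a> {
    seed: u32,
    arr: &'a mut [f32],
}

/// Placeholder for the article's undefined `rand_pow`.
fn rand_pow(seed: u32) -> f32 {
    1.0 + (seed % 3) as f32
}

/// Hand-written guess at the single function a `poly fn` might lower to.
fn calc_rolling_mean_lowered(
    variant: Variant,
    arr: &[f32],
    out: &mut [f32],
    n: usize,
    i: usize,
    rand_out: RandOutputArgs<'_>,
) {
    let RandOutputArgs { seed, arr: rand_arr } = rand_out;
    if i >= n {
        return;
    }
    // Slots sized for the largest concrete type (f32); unused when WithoutRand.
    let mut rand_rolling_mean_slot = 0.0f32;
    let pow_slot = match variant {
        Variant::WithRand => rand_pow(seed),
        Variant::WithoutRand => 0.0,
    };
    let start_i = i.saturating_sub(WINDOW_N);
    let end_i = (i + WINDOW_N).min(n);
    let mut rolling_mean = 0.0f32;
    for &v in &arr[start_i..end_i] {
        rolling_mean += v;
        // Each `polymatch` becomes an ordinary `match` on the tag.
        match variant {
            Variant::WithRand => rand_rolling_mean_slot += v.powf(pow_slot),
            Variant::WithoutRand => {}
        }
    }
    let len = (end_i - start_i) as f32;
    out[i] = rolling_mean / len;
    match variant {
        Variant::WithRand => rand_arr[i] = rand_rolling_mean_slot / len,
        Variant::WithoutRand => {}
    }
}

fn main() {
    let arr: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
    let mut out = [0.0f32; 4];
    let mut rand_buf = [0.0f32; 4];
    let variant = if arr[1] > 1.5 { Variant::WithRand } else { Variant::WithoutRand };
    let args = RandOutputArgs { seed: 0, arr: &mut rand_buf };
    calc_rolling_mean_lowered(variant, &arr, &mut out, 4, 1, args);
    println!("{:?}: mean = {}, rand mean = {}", variant, out[1], rand_buf[1]);
}
```

Written by hand like this, nothing stops the WithoutRand arm from touching the slots. In the proposal, that protection comes from the separate per-variant type checks at compile time, not from the generated code.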
Enumset: Going Further
The author also proposes the concept of “Enumset” to handle more complex scenarios.
Problem Scenario
Suppose we are writing a 2D grid simulation game, where the grid may contain walls, humans, mice, and hawks. The rules are as follows:
- Hawks can see through walls, mice cannot
- Mice cannot see hawks
- Mice and hawks will avoid humans
- Mice will approach other mice
- Hawks will prey on mice
CPU Version
#[derive(Clone, Copy)]
enum Occupant {
Empty,
Wall,
Human,
Mouse,
Hawk,
}
// Things mice can see (excluding hawks)
enum MouseSeen {
Wall,
Human,
Mouse,
}
// Things hawks can see (can see through walls)
enum HawkSeen {
Human,
Mouse,
Hawk,
}
// View the grid based on different vision capabilities
fn look<S>(
grid: &[Occupant],
x: usize,
y: usize,
grid_size: usize,
) -> Option<S> {
// Viewing logic...
}
fn kernel(
grid: &mut [Occupant],
x: usize,
y: usize,
grid_size: usize,
) {
match grid[index_grid(x, y, grid_size)] {
Occupant::Mouse => {
// Use MouseSeen type
match look::<MouseSeen>(grid, x, y, grid_size) {
Some(MouseSeen::Human) => { /* Avoid humans */ },
Some(MouseSeen::Mouse) => { /* Approach mice */ },
_ => {},
}
},
Occupant::Hawk => {
// Use HawkSeen type
match look::<HawkSeen>(grid, x, y, grid_size) {
Some(HawkSeen::Human) => { /* Avoid humans */ },
Some(HawkSeen::Mouse) => { /* Prey on mice */ },
Some(HawkSeen::Hawk) => { /* Avoid hawks */ },
_ => {},
}
},
_ => {},
}
}
This version is type-safe, but will cause divergence on the GPU: mice and hawks call different functions.
GPU Version (Without Enumset)
fn kernel(
grid: &mut [Occupant],
x: usize,
y: usize,
grid_size: usize,
) {
let this = grid[index_grid(x, y, grid_size)];
// Use unified Occupant enum
let seen: Occupant = /* Viewing logic */;
match (this, seen) {
(Occupant::Mouse, Occupant::Wall) => {},
(Occupant::Hawk, Occupant::Human) | (Occupant::Mouse, Occupant::Human) => {
// Avoid humans
},
(Occupant::Hawk, Occupant::Mouse) => { /* Prey */ },
(Occupant::Mouse, Occupant::Mouse) => { /* Approach */ },
(Occupant::Hawk, Occupant::Hawk) => { /* Avoid */ },
_ => unreachable!(), // Illegal state, such as a mouse seeing a hawk
}
}
Performance is better, but type safety decreases: a mouse seeing a hawk is logically impossible, but the type system cannot guarantee it.
Solution Using Enumset
// Define Enumset
enumset AnySeen = MouseSeen, HawkSeen {
Wall { MouseSeen }, // Only mice can see walls
Human, // Both can see humans
Mouse, // Both can see mice
Hawk { HawkSeen }, // Only hawks can see hawks
}
poly fn process[
Hawk,
Mouse
](
grid: &mut [Occupant],
x: usize,
y: usize,
grid_size: usize,
) {
let seen: Option<AnySeen> = {
// Use polymatch to call different viewing logic
polymatch {
Hawk => HawkSeen::look_at_tile(occ),
Mouse => MouseSeen::look_at_tile(occ),
}
};
// polymatch enum set, only need to handle valid combinations
polymatch seen {
(Mouse, AnySeen::Wall) => {},
(Hawk, AnySeen::Human) | (Mouse, AnySeen::Human) => {
// Avoid humans
},
(Hawk, AnySeen::Mouse) => { /* Prey */ },
(Mouse, AnySeen::Mouse) => { /* Approach */ },
(Hawk, AnySeen::Hawk) => { /* Avoid */ },
// Note: no unreachable!() because illegal combinations are excluded at the type system level
}
}
Advantages of Enumset:
- Type Safety: Combinations like (Mouse, AnySeen::Hawk) are illegal at compile time
- Performance Retained: Still a single function, GPU threads do not diverge
- Code Clarity: No need for the unreachable!() macro
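To see what enumset would replace, here is a sketch of the closest emulation in current Rust: a superset enum AnySeen with an infallible widening conversion and a fallible narrowing conversion. The From and TryFrom impls are illustrative assumptions, not part of the proposal.

```rust
use std::convert::TryFrom;

/// Superset of everything any creature can see.
#[derive(Clone, Copy, Debug, PartialEq)]
enum AnySeen {
    Wall,
    Human,
    Mouse,
    Hawk,
}

/// What a mouse can see (no hawks).
#[derive(Clone, Copy, Debug, PartialEq)]
enum MouseSeen {
    Wall,
    Human,
    Mouse,
}

/// Widening is always valid, so it is infallible.
impl From<MouseSeen> for AnySeen {
    fn from(seen: MouseSeen) -> Self {
        match seen {
            MouseSeen::Wall => AnySeen::Wall,
            MouseSeen::Human => AnySeen::Human,
            MouseSeen::Mouse => AnySeen::Mouse,
        }
    }
}

/// Narrowing can fail: the check happens at runtime, not in the type system.
impl TryFrom<AnySeen> for MouseSeen {
    type Error = AnySeen;

    fn try_from(seen: AnySeen) -> Result<Self, Self::Error> {
        match seen {
            AnySeen::Wall => Ok(MouseSeen::Wall),
            AnySeen::Human => Ok(MouseSeen::Human),
            AnySeen::Mouse => Ok(MouseSeen::Mouse),
            AnySeen::Hawk => Err(AnySeen::Hawk),
        }
    }
}

fn main() {
    println!("{:?}", AnySeen::from(MouseSeen::Wall));
    println!("{:?}", MouseSeen::try_from(AnySeen::Human));
    println!("{:?}", MouseSeen::try_from(AnySeen::Hawk));
}
```

The Err arm plays the same role as unreachable!() in the unified GPU version: the invalid combination is detected only when the code runs. The enumset proposal aims to move that check to compile time.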
Feasibility of Implementation
The author mentions that this feature could potentially be implemented through Rust macros:
- Macro Transformation: Transform polysemous function syntax into ordinary Rust code at compile time (macros are expanded before the program runs)
- Additional Type Checking Functions: Generate independent functions for each variant (only for type checking, not involved in actual execution)
- Stack Space Allocation: Reserve enough stack space for variant generic types
Challenges:
- Readability of error messages
- Integration with existing Rust toolchain
- IDE support (how to display and edit different variants)
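The “additional type checking functions” step can be approximated today with a trait whose associated types play the role of the variant generic types. The sketch below is one possible encoding, not the article's macro design. RandVariant, its methods and the windowed function are hypothetical names, and rand_pow is a placeholder.

```rust
/// Hypothetical trait: one impl per variant, fixing the variant generic types.
trait RandVariant {
    type Acc: Copy;
    type Pow: Copy;
    fn init(seed: u32) -> (Self::Acc, Self::Pow);
    fn step(acc: &mut Self::Acc, pow: Self::Pow, v: f32);
    fn finish(acc: Self::Acc, len: f32) -> Option<f32>;
}

struct WithRand;
struct WithoutRand;

/// Placeholder for the article's undefined `rand_pow`.
fn rand_pow(seed: u32) -> f32 {
    1.0 + (seed % 3) as f32
}

impl RandVariant for WithRand {
    type Acc = f32;
    type Pow = f32;
    fn init(seed: u32) -> (f32, f32) {
        (0.0, rand_pow(seed))
    }
    fn step(acc: &mut f32, pow: f32, v: f32) {
        *acc += v.powf(pow);
    }
    fn finish(acc: f32, len: f32) -> Option<f32> {
        Some(acc / len)
    }
}

impl RandVariant for WithoutRand {
    type Acc = ();
    type Pow = ();
    fn init(_seed: u32) -> ((), ()) {
        ((), ())
    }
    // Nothing can be done with `()`, so misuse fails to compile.
    fn step(_acc: &mut (), _pow: (), _v: f32) {}
    fn finish(_acc: (), _len: f32) -> Option<f32> {
        None
    }
}

/// Type-checked once per variant; returns (mean, optional randomized mean).
fn windowed<V: RandVariant>(values: &[f32], seed: u32) -> (f32, Option<f32>) {
    let (mut acc, pow) = V::init(seed);
    let mut sum = 0.0f32;
    for &v in values {
        sum += v;
        V::step(&mut acc, pow, v);
    }
    let len = values.len() as f32;
    (sum / len, V::finish(acc, len))
}

fn main() {
    let values = [1.0f32, 2.0, 3.0, 4.0];
    println!("{:?}", windowed::<WithRand>(&values, 0));
    println!("{:?}", windowed::<WithoutRand>(&values, 0));
}
```

Generics give the per-variant type checking, but monomorphization also produces one compiled function per variant, which is exactly the divergence the proposal wants to avoid. A macro-based implementation would therefore use functions like these for checking only and emit a separate single function for execution.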
Conclusion
The “Polysemous Functions” feature provides an interesting idea: to maintain the performance advantages of the GPU SIMT model while also enjoying the compile-time guarantees brought by a strong type system. This goal is achieved through the following mechanisms:
- Multiple Type Checks: The same function is type-checked separately for different variants at compile time
- Single Runtime Function: All variants share the same compiled function body
- Variant Generic Types: Allow variables to have different concrete types in different variants
- Polymatch Control Flow: Acts like a preprocessor at compile time, like a normal match at runtime
Although the practical value of this feature still needs to be validated in more complex GPU general computing scenarios, it provides a new perspective on how to balance performance and type safety in parallel computing. For scenarios that require writing complex GPU kernels, this may be a direction worth exploring.
Currently, most GPU applications (such as graphics rendering and machine learning) have relatively simple work for each thread, so shading languages have remained simple. However, as GPU general computing develops, the value of such features may gradually become apparent.
References
- An idea for a GPU programming language feature-Rust: https://medium.com/@calvinhirsch7/an-idea-for-a-gpu-programming-language-feature-rust-3dc23a986ca5