New Approaches to Rust GPU Programming: Taming SIMT with Type Systems

Introduction

In the world of GPU programming, we often face difficult trade-offs between performance and code quality. GPUs follow the SIMT (Single Instruction, Multiple Threads) execution model: threads are grouped so that each group shares one instruction stream. When threads in a group take different branches, the branches run one after another and performance drops. To get the best performance, we often have to sacrifice the type safety and maintainability of the code.

But what if we could keep GPU performance while also enjoying the compile-time guarantees of Rust’s powerful type system? This article introduces a proposed language feature, “Polysemous Functions”, which attempts to get both.

CPU Version: A Paradise of Type Safety

Let’s start with a concrete example. Suppose we need to calculate the rolling average of an array, and sometimes we also need to calculate a randomized version of the rolling average (raising the original value to a random power).

On the CPU, we would naturally write two independent functions:

const WINDOW_N: usize = 10;

// Basic version: only calculates the normal rolling average
fn calc_rolling_mean(
    arr: &[f32],
    out: &mut [f32],
    n: usize,
    i: usize,
) {
    if i >= n {
        return;
    }
    
    let mut rolling_mean = 0.;
    // Window [i - WINDOW_N, i + WINDOW_N), clamped to [0, n) without usize underflow
    let start_i = i.saturating_sub(WINDOW_N);
    let end_i = (i + WINDOW_N).min(n);
    
    for idx in start_i..end_i {
        let v = arr[idx];
        rolling_mean += v;
    }
    
    out[i] = rolling_mean / (end_i - start_i) as f32;
}

// Extended version: calculates both normal and randomized rolling averages
fn calc_rolling_mean_with_rand(
    arr: &[f32],
    out: &mut [f32],
    n: usize,
    i: usize,
    rand_seed: u32,
    rand_out: &mut [f32],
) {
    if i >= n {
        return;
    }
    
    let mut rolling_mean = 0.;
    let mut rand_rolling_mean = 0.;
    let pow = rand_pow(rand_seed);  // Calculate random power
    
    // Window [i - WINDOW_N, i + WINDOW_N), clamped to [0, n) without usize underflow
    let start_i = i.saturating_sub(WINDOW_N);
    let end_i = (i + WINDOW_N).min(n);
    
    for idx in start_i..end_i {
        let v = arr[idx];
        rolling_mean += v;
        rand_rolling_mean += v.powf(pow);  // Additional calculation for randomized version
    }
    
    out[i] = rolling_mean / (end_i - start_i) as f32;
    rand_out[i] = rand_rolling_mean / (end_i - start_i) as f32;
}

The calling code is also very intuitive:

fn calling_f(
    arr: &[f32],
    out: &mut [f32],
    n: usize,
    i: usize,
    rand_out: &mut [f32],
) {
    // Out-of-range threads have nothing to do (and arr[i] would panic)
    if i >= n {
        return;
    }
    // Choose which function to call based on condition
    if arr[i] > 5. {
        calc_rolling_mean_with_rand(arr, out, n, i, (i/3) as u32, rand_out);
    } else {
        calc_rolling_mean(arr, out, n, i);
    }
}

This implementation is very clear: when the random calculation is needed, pow and rand_rolling_mean are always valid f32 values; when it is not needed, they simply do not exist. The type system guarantees that the randomized values can never be used where they do not exist.
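For reference, here is a self-contained, runnable sketch of the two CPU functions. The article never shows rand_pow, so the version below is a hypothetical stand-in that hashes the seed into an exponent in [1, 2). The window helper is also an addition: it clamps the bounds with saturating_sub and min, so arrays shorter than WINDOW_N work.

```rust
const WINDOW_N: usize = 10;

// Hypothetical stand-in for the article's rand_pow (its body is never shown):
// hashes the seed and maps it to an exponent in [1.0, 2.0).
fn rand_pow(seed: u32) -> f32 {
    let h = seed.wrapping_mul(2_654_435_761) >> 8;
    1.0 + (h % 1000) as f32 / 1000.0
}

// Window [i - WINDOW_N, i + WINDOW_N), clamped to [0, n) without usize underflow.
fn window(i: usize, n: usize) -> (usize, usize) {
    (i.saturating_sub(WINDOW_N), (i + WINDOW_N).min(n))
}

// Basic version: only the normal rolling average.
fn calc_rolling_mean(arr: &[f32], out: &mut [f32], n: usize, i: usize) {
    if i >= n {
        return;
    }
    let (start_i, end_i) = window(i, n);
    let sum: f32 = arr[start_i..end_i].iter().sum();
    out[i] = sum / (end_i - start_i) as f32;
}

// Extended version: normal and randomized rolling averages in one pass.
fn calc_rolling_mean_with_rand(
    arr: &[f32],
    out: &mut [f32],
    n: usize,
    i: usize,
    rand_seed: u32,
    rand_out: &mut [f32],
) {
    if i >= n {
        return;
    }
    let pow = rand_pow(rand_seed);
    let (start_i, end_i) = window(i, n);
    let mut rolling_mean = 0.0;
    let mut rand_rolling_mean = 0.0;
    for &v in &arr[start_i..end_i] {
        rolling_mean += v;
        rand_rolling_mean += v.powf(pow);
    }
    let len = (end_i - start_i) as f32;
    out[i] = rolling_mean / len;
    rand_out[i] = rand_rolling_mean / len;
}
```

With n = 5 (shorter than WINDOW_N), every window covers the whole array, so each output is the plain mean of the five values.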

GPU Version: The Game of Performance and Safety

Now let’s port this code to the GPU. Under the SIMT model, the threads of a warp (a group of typically 32 threads on NVIDIA hardware) execute the same instruction at the same time. If we keep two independent functions, problems arise:

  • When arr[i] > 5. is true, a thread calls calc_rolling_mean_with_rand
  • When the condition is false, it calls calc_rolling_mean
  • Threads in the same warp can therefore end up executing two completely different functions
  • The hardware runs the two paths one after the other, masking off the threads that are not on the current path; with an even split and paths of similar length, about half of the warp sits idle at any moment, causing severe performance loss
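How much of a warp sits idle depends on how the branch splits and on the length of each path. The toy cost model below is an illustration under simplifying assumptions, not a description of real hardware scheduling. It assumes a warp pays for every path that at least one of its lanes takes, and that lanes not on the current path are masked off. The function names are hypothetical.

```rust
/// Toy model of SIMT divergence: the warp executes every path that at least
/// one lane takes, one after another; lanes not on the current path are masked off.
/// `lane_takes_a[k]` says whether lane k takes path A (cost_a) or path B (cost_b).
fn warp_cost(lane_takes_a: &[bool], cost_a: u32, cost_b: u32) -> u32 {
    let any_a = lane_takes_a.iter().any(|&t| t);
    let any_b = lane_takes_a.iter().any(|&t| !t);
    (if any_a { cost_a } else { 0 }) + (if any_b { cost_b } else { 0 })
}

/// Same model for a merged function: a shared path that every lane runs, plus a
/// short extra region that is paid once if any lane needs it.
fn merged_cost(lane_takes_a: &[bool], shared: u32, extra: u32) -> u32 {
    shared + if lane_takes_a.iter().any(|&t| t) { extra } else { 0 }
}

/// Fraction of lane-time spent masked off under the two-function model.
/// Assumes a non-empty warp.
fn idle_fraction(lane_takes_a: &[bool], cost_a: u32, cost_b: u32) -> f64 {
    let total = warp_cost(lane_takes_a, cost_a, cost_b) as f64 * lane_takes_a.len() as f64;
    let busy: u32 = lane_takes_a
        .iter()
        .map(|&t| if t { cost_a } else { cost_b })
        .sum();
    1.0 - busy as f64 / total
}
```

With 32 lanes alternating between a 12-unit path and a 10-unit path, warp_cost gives 22 units. A merged function with a 10-unit shared path and a 2-unit randomized extra costs 12 units for the same warp. That gap is the motivation for merging the two functions.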

To optimize performance, we need to merge the two functions into one, using the Option type to handle optional calculations:

struct RandOutputArgs<'a> {
    seed: u32,
    arr: &'a mut [f32],
}

fn calc_rolling_mean(
    arr: &[f32],
    out: &mut [f32],
    n: usize,
    i: usize,
    rand_out: Option<RandOutputArgs<'_>>,  // Use Option to indicate optional functionality
) {
    if i >= n {
        return;
    }
    
    let mut rolling_mean = 0.;
    
    // Problems start to arise: these variables must also be Option
    let (mut rand_rolling_mean, pow) = if let Some(ref r) = rand_out {
        (Some(0.), Some(rand_pow(r.seed)))
    } else {
        (None, None)
    };
    
    // Window [i - WINDOW_N, i + WINDOW_N), clamped to [0, n) without usize underflow
    let start_i = i.saturating_sub(WINDOW_N);
    let end_i = (i + WINDOW_N).min(n);
    
    for idx in start_i..end_i {
        let v = arr[idx];
        rolling_mean += v;
        
        // Must handle all possible combinations; borrow the accumulator mutably,
        // otherwise the match would only update a copy of it
        match (pow, &mut rand_rolling_mean) {
            (Some(pow), Some(m)) => {
                *m += v.powf(pow);
            },
            (None, None) => {},
            _ => unreachable!(),  // These states "theoretically" should not occur
        }
    }
    
    out[i] = rolling_mean / (end_i - start_i) as f32;
    
    // Again handle all combinations
    match (rand_out, rand_rolling_mean) {
        (Some(r), Some(m)) => {
            r.arr[i] = m / (end_i - start_i) as f32;
        },
        (None, None) => {},
        _ => unreachable!(),  // Another "impossible" state
    }
}

Although this version performs better on the GPU (all threads execute the same function, only briefly diverging at nested branches), the code quality significantly decreases:

  1. Illegal states can be represented: The types allow rand_rolling_mean to be None while rand_out is Some, even though this combination is logically impossible
  2. Runtime checks: We have to use the unreachable!() macro to handle these “impossible” cases
  3. Cognitive burden: Developers need to maintain two contexts in their minds (with and without randomization)
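One partial mitigation in today’s Rust is to bundle every value that exists only in the randomized case into one struct and wrap that in a single Option. This is a common workaround, not part of the proposal. The mismatched states then cannot be written down, although every use site still needs an if let, so the cognitive burden remains. In this sketch, RandState is a hypothetical name, and pow is passed in already computed, standing in for rand_pow(seed).

```rust
const WINDOW_N: usize = 10;

/// Everything that exists only in the randomized case, kept together so that
/// "pow present but accumulator missing" cannot be represented.
struct RandState<'a> {
    pow: f32,
    acc: f32,
    out: &'a mut [f32],
}

/// `rand_out` carries an already computed exponent (standing in for
/// `rand_pow(seed)`) and the output slice, or `None` for the plain version.
fn calc_rolling_mean(
    arr: &[f32],
    out: &mut [f32],
    n: usize,
    i: usize,
    rand_out: Option<(f32, &mut [f32])>,
) {
    if i >= n {
        return;
    }
    // One Option instead of three correlated ones.
    let mut rand_state = rand_out.map(|(pow, rand_arr)| RandState { pow, acc: 0.0, out: rand_arr });

    let start_i = i.saturating_sub(WINDOW_N);
    let end_i = (i + WINDOW_N).min(n);
    let mut rolling_mean = 0.0;
    for &v in &arr[start_i..end_i] {
        rolling_mean += v;
        if let Some(r) = rand_state.as_mut() {
            r.acc += v.powf(r.pow);
        }
    }

    let len = (end_i - start_i) as f32;
    out[i] = rolling_mean / len;
    if let Some(r) = rand_state.as_mut() {
        r.out[i] = r.acc / len;
    }
}
```

No unreachable!() is needed here, but the correlation is only captured for this one bundle; the proposal below aims to make it a property of the whole function.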

Polysemous Functions: Having Your Cake and Eating It Too

The proposed “Polysemous Functions” feature attempts to solve this dilemma. The core idea is: allow a function to be type-checked multiple times at compile time (once for each variant), but compile it into a single function body that all variants share at runtime.

Basic Syntax

// Define a polysemous function
poly<RandRollingMean, Pow> fn calc_rolling_mean[
    WithRand<f32, f32>,      // Variant 1: requires randomization
    WithoutRand<(), ()>      // Variant 2: does not require randomization
](
    arr: &[f32],
    out: &mut [f32],
    n: usize,
    i: usize,
    rand_out: RandOutputArgs,
) {
    // Function body
}

Key elements:

  • poly<RandRollingMean, Pow>: Declares two variant generic types
  • WithRand<f32, f32>: In the WithRand variant, both types are f32
  • WithoutRand<(), ()>: In the WithoutRand variant, both types are unit types ()
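The closest analogue in stable Rust is probably a trait whose associated types play the role of the variant generic types. In the sketch below, all names besides calc_rolling_mean, WithRand, and WithoutRand are hypothetical, and the seed-to-exponent mapping is a placeholder for rand_pow. Per-variant type checking does happen, but inside each impl rather than in one function body: in WithoutRand the accumulator is (), so there is nothing to add to. The bigger difference is that monomorphization turns calc_rolling_mean::<WithRand> and calc_rolling_mean::<WithoutRand> into two separate functions, so choosing between them at runtime with an if still diverges on a GPU. The proposal is meant to avoid exactly that.

```rust
const WINDOW_N: usize = 10;

/// A variant is a zero-sized marker type; its associated types play the role
/// of the proposal's variant generic types RandRollingMean and Pow.
trait RandVariant {
    type RandRollingMean;
    type Pow: Copy;
    fn init(seed: u32) -> (Self::RandRollingMean, Self::Pow);
    fn accumulate(acc: &mut Self::RandRollingMean, pow: Self::Pow, v: f32);
    fn finish(acc: Self::RandRollingMean, len: f32, i: usize, rand_out: &mut [f32]);
}

struct WithRand;
struct WithoutRand;

impl RandVariant for WithRand {
    type RandRollingMean = f32;
    type Pow = f32;
    fn init(seed: u32) -> (f32, f32) {
        // Placeholder for rand_pow(seed): an exponent in {1.0, 1.25, 1.5, 1.75}.
        (0.0, 1.0 + (seed % 4) as f32 * 0.25)
    }
    fn accumulate(acc: &mut f32, pow: f32, v: f32) {
        *acc += v.powf(pow);
    }
    fn finish(acc: f32, len: f32, i: usize, rand_out: &mut [f32]) {
        rand_out[i] = acc / len;
    }
}

impl RandVariant for WithoutRand {
    // Unit types: there is nothing to accumulate, and the compiler enforces it.
    type RandRollingMean = ();
    type Pow = ();
    fn init(_seed: u32) -> ((), ()) {
        ((), ())
    }
    fn accumulate(_acc: &mut (), _pow: (), _v: f32) {}
    fn finish(_acc: (), _len: f32, _i: usize, _rand_out: &mut [f32]) {}
}

/// Checked once against the trait bounds, then monomorphized into one copy per variant.
fn calc_rolling_mean<V: RandVariant>(
    arr: &[f32],
    out: &mut [f32],
    n: usize,
    i: usize,
    seed: u32,
    rand_out: &mut [f32],
) {
    if i >= n {
        return;
    }
    let (mut acc, pow) = V::init(seed);
    let start_i = i.saturating_sub(WINDOW_N);
    let end_i = (i + WINDOW_N).min(n);
    let mut rolling_mean = 0.0;
    for &v in &arr[start_i..end_i] {
        rolling_mean += v;
        V::accumulate(&mut acc, pow, v);
    }
    let len = (end_i - start_i) as f32;
    out[i] = rolling_mean / len;
    V::finish(acc, len, i, rand_out);
}
```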

Function Body Implementation

poly<RandRollingMean, Pow> fn calc_rolling_mean[
    WithRand<f32, f32>,
    WithoutRand<(), ()>
](
    arr: &[f32],
    out: &mut [f32],
    n: usize,
    i: usize,
    rand_out: RandOutputArgs,
) {
    // Use polymatch to execute different code in different variants
    polymatch {
        WithRand => {},
        WithoutRand => {
            drop(rand_out);  // Discard parameter in WithoutRand variant
        },
    }
    
    if i >= n {
        return;
    }
    
    let mut rolling_mean = 0.;
    
    // Key: these two variables have different types in different variants
    let (mut rand_rolling_mean, pow): (RandRollingMean, Pow) = polymatch {
        WithRand => (0., rand_pow(rand_out.seed)),      // f32 type
        WithoutRand => ((), ()),                         // () type
    };
    
    // Window [i - WINDOW_N, i + WINDOW_N), clamped to [0, n) without usize underflow
    let start_i = i.saturating_sub(WINDOW_N);
    let end_i = (i + WINDOW_N).min(n);
    
    for idx in start_i..end_i {
        let v = arr[idx];
        rolling_mean += v;
        
        // Only execute in WithRand variant
        polymatch {
            WithRand => {
                rand_rolling_mean += v.powf(pow);  // Type safe here!
            },
            WithoutRand => {},  // In this variant rand_rolling_mean is (), so it cannot be added to
        }
    }
    
    out[i] = rolling_mean / (end_i - start_i) as f32;
    
    polymatch {
        WithRand => {
            rand_out.arr[i] = rand_rolling_mean / (end_i - start_i) as f32;
        },
        WithoutRand => {},
    }
}

Calling Method

fn kernel(
    arr: &[f32],
    out: &mut [f32],
    n: usize,
    i: usize,
    rand_out: &mut [f32],
) {
    // Out-of-range threads must not read arr[i]
    if i >= n {
        return;
    }
    // Choose variant at runtime
    calc_rolling_mean::[
        if arr[i] > 5. {
            WithRand
        } else {
            WithoutRand
        }
    ](
        arr,
        out,
        n,
        i,
        RandOutputArgs {
            seed: (i/3) as u32,
            arr: rand_out,
        }
    );
}

How It Works

  1. At Compile Time:

  • The compiler type-checks the function separately for the WithRand and WithoutRand variants
  • In the WithRand variant, rand_rolling_mean and pow are f32
  • In the WithoutRand variant, they are (), so any arithmetic on them is a compile error
  • polymatch acts like a preprocessor during type checking: only the branch for the variant being checked is considered

  2. At Runtime:

  • Only a single function body is generated
  • The selected variant is essentially an enum value
  • polymatch statements become ordinary match statements on that value
  • Stack space for each variant generic type is reserved large enough to hold any of its concrete types
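To make the runtime side concrete, here is a hand-written sketch of what the lowered function might look like. It assumes the variant becomes a plain enum argument, each polymatch becomes an ordinary match on it, and each variant generic type gets an f32 slot (the larger of f32 and ()). This is an illustration of the lowering described above, not the output of any existing tool; Variant is a hypothetical name and rand_pow is again a placeholder.

```rust
const WINDOW_N: usize = 10;

struct RandOutputArgs<'a> {
    seed: u32,
    arr: &'a mut [f32],
}

/// The compile-time variant list becomes an ordinary runtime tag.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum Variant {
    WithRand,
    WithoutRand,
}

// Hypothetical placeholder for the article's rand_pow.
fn rand_pow(seed: u32) -> f32 {
    1.0 + (seed % 4) as f32 * 0.25
}

fn calc_rolling_mean(
    variant: Variant,
    arr: &[f32],
    out: &mut [f32],
    n: usize,
    i: usize,
    rand_out: RandOutputArgs<'_>,
) {
    if i >= n {
        return;
    }
    // Slots for RandRollingMean and Pow, sized for their largest concrete type
    // (f32 rather than ()). In WithoutRand they exist but are never used.
    let mut rand_rolling_mean: f32 = 0.0;
    let mut pow: f32 = 0.0;
    match variant {
        Variant::WithRand => pow = rand_pow(rand_out.seed),
        Variant::WithoutRand => {}
    }

    let start_i = i.saturating_sub(WINDOW_N);
    let end_i = (i + WINDOW_N).min(n);
    let mut rolling_mean = 0.0;
    for &v in &arr[start_i..end_i] {
        rolling_mean += v;
        // Each polymatch lowers to a plain match on the tag.
        match variant {
            Variant::WithRand => rand_rolling_mean += v.powf(pow),
            Variant::WithoutRand => {}
        }
    }

    let len = (end_i - start_i) as f32;
    out[i] = rolling_mean / len;
    match variant {
        Variant::WithRand => rand_out.arr[i] = rand_rolling_mean / len,
        Variant::WithoutRand => {}
    }
}

/// Every thread calls the same function; only the tag differs.
fn kernel(arr: &[f32], out: &mut [f32], n: usize, i: usize, rand_out: &mut [f32]) {
    if i >= n {
        return;
    }
    let variant = if arr[i] > 5.0 { Variant::WithRand } else { Variant::WithoutRand };
    let args = RandOutputArgs { seed: (i / 3) as u32, arr: rand_out };
    calc_rolling_mean(variant, arr, out, n, i, args);
}
```

Threads split only inside the short match arms. The f32 slots exist even in WithoutRand, and at this level nothing but the tag keeps code from reading them. The per-variant compile-time check is what closes that gap.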

Enumset: Going Further

The original author also proposes the concept of an “Enumset” to handle more complex scenarios.

Problem Scenario

Suppose we are writing a 2D grid simulation game, where the grid may contain walls, humans, mice, and hawks. The rules are as follows:

  • Hawks can see through walls, mice cannot
  • Mice cannot see hawks
  • Mice and hawks will avoid humans
  • Mice will approach other mice
  • Hawks will prey on mice
  • Hawks will avoid other hawks

CPU Version

#[derive(Clone, Copy)]
enum Occupant {
    Empty,
    Wall,
    Human,
    Mouse,
    Hawk,
}

// Things mice can see (excluding hawks)
enum MouseSeen {
    Wall,
    Human,
    Mouse,
}

// Things hawks can see (can see through walls)
enum HawkSeen {
    Human,
    Mouse,
    Hawk,
}

// View the grid based on different vision capabilities;
// S is the vision type (MouseSeen or HawkSeen)
fn look<S>(
    grid: &[Occupant],
    x: usize,
    y: usize,
    grid_size: usize,
) -> Option<S> {
    // Viewing logic...
}

fn kernel(
    grid: &mut [Occupant],
    x: usize,
    y: usize,
    grid_size: usize,
) {
    match grid[index_grid(x, y, grid_size)] {
        Occupant::Mouse => {
            // Use MouseSeen type
            match look::<MouseSeen>(grid, x, y, grid_size) {
                Some(MouseSeen::Human) => { /* Avoid humans */ },
                Some(MouseSeen::Mouse) => { /* Approach mice */ },
                _ => {},
            }
        },
        Occupant::Hawk => {
            // Use HawkSeen type
            match look::<HawkSeen>(grid, x, y, grid_size) {
                Some(HawkSeen::Human) => { /* Avoid humans */ },
                Some(HawkSeen::Mouse) => { /* Prey on mice */ },
                Some(HawkSeen::Hawk) => { /* Avoid hawks */ },
                _ => {},
            }
        },
        _ => {},
    }
}

This version is type-safe, but it diverges on the GPU: mice and hawks call different functions, since look::<MouseSeen> and look::<HawkSeen> are separate instantiations of the generic look.
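The CPU version above leaves the viewing logic and index_grid undefined. The runnable sketch below fills them in under explicit assumptions. A hypothetical Seeable trait with a look_at_tile method (the name is borrowed from the enumset example later in the article) describes how each viewer perceives one tile. The grid is assumed to be row-major, and look is assumed to scan to the right of (x, y) and report the first thing the viewer can see. Under these assumptions a mouse’s view passes through hawks and a hawk’s view passes through walls.

```rust
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum Occupant {
    Empty,
    Wall,
    Human,
    Mouse,
    Hawk,
}

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum MouseSeen { Wall, Human, Mouse }

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum HawkSeen { Human, Mouse, Hawk }

/// Hypothetical trait: how one kind of viewer perceives a single tile.
/// `None` means the viewer perceives nothing there and keeps looking.
trait Seeable: Sized {
    fn look_at_tile(occ: Occupant) -> Option<Self>;
}

impl Seeable for MouseSeen {
    fn look_at_tile(occ: Occupant) -> Option<Self> {
        match occ {
            Occupant::Wall => Some(MouseSeen::Wall),
            Occupant::Human => Some(MouseSeen::Human),
            Occupant::Mouse => Some(MouseSeen::Mouse),
            // Mice cannot see hawks.
            Occupant::Empty | Occupant::Hawk => None,
        }
    }
}

impl Seeable for HawkSeen {
    fn look_at_tile(occ: Occupant) -> Option<Self> {
        match occ {
            Occupant::Human => Some(HawkSeen::Human),
            Occupant::Mouse => Some(HawkSeen::Mouse),
            Occupant::Hawk => Some(HawkSeen::Hawk),
            // Hawks see through walls.
            Occupant::Empty | Occupant::Wall => None,
        }
    }
}

/// Hypothetical row-major indexing helper.
fn index_grid(x: usize, y: usize, grid_size: usize) -> usize {
    y * grid_size + x
}

/// Hypothetical viewing logic: scan right from (x, y) and return the first
/// tile the viewer type S can perceive.
fn look<S: Seeable>(grid: &[Occupant], x: usize, y: usize, grid_size: usize) -> Option<S> {
    (x + 1..grid_size).find_map(|cx| S::look_at_tile(grid[index_grid(cx, y, grid_size)]))
}
```

Because look is generic, the compiler emits one copy per vision type, which is where the divergence between mouse threads and hawk threads comes from.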

GPU Version (Without Enumset)

fn kernel(
    grid: &mut [Occupant],
    x: usize,
    y: usize,
    grid_size: usize,
) {
    let this = grid[index_grid(x, y, grid_size)];
    
    // Use unified Occupant enum (assuming Occupant::Empty means "nothing in view")
    let seen: Occupant = /* Viewing logic */;
    
    match (this, seen) {
        // Tiles that never act, and animals that see nothing
        (Occupant::Empty | Occupant::Wall | Occupant::Human, _) => {},
        (_, Occupant::Empty) => {},
        (Occupant::Mouse, Occupant::Wall) => {},
        (Occupant::Hawk, Occupant::Human) | (Occupant::Mouse, Occupant::Human) => {
            // Avoid humans
        },
        (Occupant::Hawk, Occupant::Mouse) => { /* Prey */ },
        (Occupant::Mouse, Occupant::Mouse) => { /* Approach */ },
        (Occupant::Hawk, Occupant::Hawk) => { /* Avoid */ },
        _ => unreachable!(),  // Illegal state, such as a mouse seeing a hawk
    }
}

Performance is better, but type safety decreases: a mouse seeing a hawk is logically impossible, yet the type system cannot rule it out.

Solution Using Enumset

// Define an Enumset
enumset AnySeen = MouseSeen, HawkSeen {
    Wall { MouseSeen },     // Only mice can see walls
    Human,                  // Both can see humans
    Mouse,                  // Both can see mice
    Hawk { HawkSeen },      // Only hawks can see hawks
}

poly fn process[
    Hawk,
    Mouse
](
    grid: &mut [Occupant],
    x: usize,
    y: usize,
    grid_size: usize,
) {
    let seen: Option<AnySeen> = {
        // Use polymatch to call different viewing logic
        // (occ, the tile being looked at, comes from viewing logic elided here)
        polymatch {
            Hawk => HawkSeen::look_at_tile(occ),
            Mouse => MouseSeen::look_at_tile(occ),
        }
    };
    
    // polymatch over the enumset: only valid combinations need to be handled
    polymatch seen {
        (_, None) => {},
        (Mouse, Some(AnySeen::Wall)) => {},
        (Hawk, Some(AnySeen::Human)) | (Mouse, Some(AnySeen::Human)) => {
            // Avoid humans
        },
        (Hawk, Some(AnySeen::Mouse)) => { /* Prey */ },
        (Mouse, Some(AnySeen::Mouse)) => { /* Approach */ },
        (Hawk, Some(AnySeen::Hawk)) => { /* Avoid */ },
        // Note: no unreachable!() because illegal combinations are excluded at the type system level
    }
}

Advantages of Enumset:

  1. Type Safety: Combinations like (Mouse, AnySeen::Hawk) are rejected at compile time
  2. Performance Retained: It is still a single function, so mice and hawks in the same warp do not diverge at the function level, only at the short polymatch branches
  3. Code Clarity: No unreachable!() macro is needed
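Today’s Rust can express the “superset” half of an enumset with a hand-written union enum and From conversions, as in the sketch below. AnySeen mirrors the article; Viewer, Action, and react are hypothetical names. What it cannot express is the per-variant restriction: react still has to list (Hawk, Wall) and (Mouse, Hawk) and call unreachable!(), which is the gap the enumset proposal targets.

```rust
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum MouseSeen { Wall, Human, Mouse }

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum HawkSeen { Human, Mouse, Hawk }

/// Hand-written union of both vision enums, standing in for `enumset AnySeen`.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum AnySeen { Wall, Human, Mouse, Hawk }

impl From<MouseSeen> for AnySeen {
    fn from(s: MouseSeen) -> Self {
        match s {
            MouseSeen::Wall => AnySeen::Wall,
            MouseSeen::Human => AnySeen::Human,
            MouseSeen::Mouse => AnySeen::Mouse,
        }
    }
}

impl From<HawkSeen> for AnySeen {
    fn from(s: HawkSeen) -> Self {
        match s {
            HawkSeen::Human => AnySeen::Human,
            HawkSeen::Mouse => AnySeen::Mouse,
            HawkSeen::Hawk => AnySeen::Hawk,
        }
    }
}

/// Runtime tag playing the role of the poly variants Hawk and Mouse.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum Viewer { Mouse, Hawk }

/// Hypothetical labels for the behaviours the article leaves as comments.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum Action { Ignore, AvoidHuman, Prey, Approach, AvoidHawk }

fn react(viewer: Viewer, seen: AnySeen) -> Action {
    match (viewer, seen) {
        (Viewer::Mouse, AnySeen::Wall) => Action::Ignore,
        (_, AnySeen::Human) => Action::AvoidHuman,
        (Viewer::Hawk, AnySeen::Mouse) => Action::Prey,
        (Viewer::Mouse, AnySeen::Mouse) => Action::Approach,
        (Viewer::Hawk, AnySeen::Hawk) => Action::AvoidHawk,
        // The compiler cannot know these pairs never occur, so they must be listed.
        (Viewer::Hawk, AnySeen::Wall) | (Viewer::Mouse, AnySeen::Hawk) => unreachable!(),
    }
}
```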

Feasibility of Implementation

The original author suggests that this feature could potentially be implemented with Rust macros:

  1. Macro Transformation: Transform the polysemous function syntax into ordinary Rust code at compile time
  2. Additional Type-Checking Functions: Generate a separate function for each variant, used only for type checking and never executed
  3. Stack Space Allocation: Reserve enough stack space for the variant generic types

Challenges:

  • Readability of error messages
  • Integration with the existing Rust toolchain
  • IDE support (how to display and edit different variants)

Conclusion

The “Polysemous Functions” feature offers an interesting idea: keep the performance advantages of the GPU SIMT model while also enjoying the compile-time guarantees of a strong type system. It does this through the following mechanisms:

  1. Multiple Type Checks: The same function is type-checked separately for each variant at compile time
  2. Single Compiled Body: All variants share the same compiled function body at runtime
  3. Variant Generic Types: Variables can have different concrete types in different variants
  4. Polymatch Control Flow: Acts like a preprocessor at compile time and like a normal match at runtime

Although the practical value of this feature still needs to be validated in more complex general-purpose GPU computing scenarios, it offers a new perspective on balancing performance and type safety in parallel computing. For anyone writing complex GPU kernels, it may be a direction worth exploring.

Currently, the per-thread work in most GPU applications (such as graphics rendering and machine learning) is relatively simple, which is why shading languages have remained simple. As general-purpose GPU computing develops, the value of such features may become more apparent.

References

  1. “An idea for a GPU programming language feature” (Rust), Medium: https://medium.com/@calvinhirsch7/an-idea-for-a-gpu-programming-language-feature-rust-3dc23a986ca5
