Prefix Sum on WebGPU: from Hillis–Steele, Blelloch, to Subgroups

Writing some WebGPU compute shaders that can scan faster than the CPU


In this article I implement prefix sum using WebGPU compute shaders, aiming to beat a straightforward CPU sequential prefix sum. The implementation is written in Rust with wgpu.

For everything except the subgroup part, I mainly followed the classic explanation in GPU Gems 3, Chapter 39.

Full source code is here.

You can find the Japanese version of this article below.

What is a prefix sum?

Given an array $A$, a prefix sum array $S$ is defined as $S[i] = S[i-1] + A[i]$. In other words, $S[i]$ contains the sum of $A[0..i]$ for an inclusive scan, or of $A[0..i-1]$ for an exclusive scan. For example, $A = [3, 1, 7, 0]$ gives the inclusive scan $[3, 4, 11, 11]$ and the exclusive scan $[0, 3, 4, 11]$. Prefix sums show up everywhere, e.g. radix sort, quadtrees, and a bunch of other “build a structure from counts/offsets” workloads.

A naive in-place CPU implementation looks like the following pseudocode.

function inclusiveScan(A[0..n-1])
    for i := 1, ..., n-1 do
        A[i] := A[i-1] + A[i]
    end for
end function
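
As a reference point for the benchmarks later, a CPU baseline in Rust could look something like this (a minimal sketch; the exact baseline used for the measurements lives in the full source):

// cpu sequential inclusive scan (sketch)
fn inclusive_scan(a: &mut [u32]) {
    for i in 1..a.len() {
        // wrapping add to match u32 overflow behaviour on the GPU
        a[i] = a[i].wrapping_add(a[i - 1]);
    }
}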

The problem for a GPU implementation is obvious: each element depends on the previous result, so the naive form doesn’t parallelize. That’s why we need scan algorithms designed for parallel hardware.

Hillis–Steele scan

Hillis–Steele is an iterative algorithm where the dependency distance doubles each step. If you draw a binary-tree-like reduction that computes $S[i]$ from the leaves $A[0..i]$, you’ll notice that at each height you only depend on two values, which is a good start for parallel execution.

Hillis-Steele scan

Implementation

Hillis–Steele is typically described as “in-place”, but if you try to do it with a single buffer on a GPU you’ll quickly run into read-after-write hazards: a thread might read values that were already updated in the same step. One solution is double buffering.

// hillis_steele_scan.rs
let byte_len = (n * size_of::<u32>()) as u64;

let data0 = device.create_buffer(&wgpu::BufferDescriptor {
    label: Some("data0"),
    size: byte_len,
    usage: wgpu::BufferUsages::STORAGE
        | wgpu::BufferUsages::COPY_DST
        | wgpu::BufferUsages::COPY_SRC,
    mapped_at_creation: false,
});
let data1 = device.create_buffer(&wgpu::BufferDescriptor {
    label: Some("data1"),
    size: byte_len,
    usage: wgpu::BufferUsages::STORAGE
        | wgpu::BufferUsages::COPY_DST
        | wgpu::BufferUsages::COPY_SRC,
    mapped_at_creation: false,
});

Next, we need to pass the current step size into the shader: $1, 2, 4, 8, \ldots$. WebGPU provides uniform buffers for this kind of small “constant-ish” data.

You could update the same uniform buffer every iteration, dispatching and submitting one step at a time, and repeat… but submit is not free (validation, scheduling, driver work, etc.). Instead, I precompute all steps into a single uniform buffer and use dynamic offsets so that one submit can dispatch every step.

// hillis_steele_scan.rs
let align = device.limits().min_uniform_buffer_offset_alignment as usize;
let uni_size = size_of::<Uniforms>();
let stride = align_up(uni_size, align);
let uniform_stride = stride as u32;

// Pack uniforms with required stride
let mut blob = vec![0u8; stride * (max_steps as usize)];
for i in 0..max_steps {
    let u = Uniforms {
        step: 1u32 << i,
        _pad: [0; 3],
    };
    let bytes = bytemuck::bytes_of(&u);
    let offset = (i as usize) * stride;
    blob[offset..offset + bytes.len()].copy_from_slice(bytes);
}

let uniform = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
    label: Some("uniform"),
    contents: &blob,
    usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
});
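
For reference, the `Uniforms` struct and the `align_up` helper used above could be defined roughly like this (a sketch; the actual definitions are in the full source):

// sketch of the pieces referenced above
use bytemuck::{Pod, Zeroable};

#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct Uniforms {
    step: u32,
    _pad: [u32; 3], // pad the struct out to 16 bytes
}

/// Round `size` up to the next multiple of `align` (assumed to be a power of two).
fn align_up(size: usize, align: usize) -> usize {
    (size + align - 1) & !(align - 1)
}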

To enable dynamic offsets, the bind group layout needs has_dynamic_offset: true for the uniform binding.

// hillis_steele_scan.rs
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
    label: Some("prefix-sum bgl"),
    entries: &[
        // src
        wgpu::BindGroupLayoutEntry {
            binding: 0,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty: wgpu::BufferBindingType::Storage { read_only: true },
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        },
        // dst
        wgpu::BindGroupLayoutEntry {
            binding: 1,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty: wgpu::BufferBindingType::Storage { read_only: false },
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        },
        // uniforms (dynamic offset!)
        wgpu::BindGroupLayoutEntry {
            binding: 2,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty: wgpu::BufferBindingType::Uniform,
                has_dynamic_offset: true,
                min_binding_size: NonZeroU64::new(size_of::<Uniforms>() as u64),
            },
            count: None,
        },
    ],
});

let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
    label: Some("prefix-sum pipeline layout"),
    bind_group_layouts: &[&bind_group_layout],
    immediate_size: 0,
});
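
The two bind groups then just swap the roles of the data buffers, while sharing the uniform buffer with a Uniforms-sized binding window so the dynamic offset can select the step. A sketch (the helper below is an assumption, not the exact code from the repository):

// double-buffered bind groups (sketch)
let make_bind_group = |label: &str, src: &wgpu::Buffer, dst: &wgpu::Buffer| {
    device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some(label),
        layout: &bind_group_layout,
        entries: &[
            wgpu::BindGroupEntry { binding: 0, resource: src.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 1, resource: dst.as_entire_binding() },
            wgpu::BindGroupEntry {
                binding: 2,
                // bind a single Uniforms-sized window; the dynamic offset picks the step
                resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
                    buffer: &uniform,
                    offset: 0,
                    size: NonZeroU64::new(size_of::<Uniforms>() as u64),
                }),
            },
        ],
    })
};
// even steps read data0 and write data1; odd steps go the other way
let bind_group_0 = make_bind_group("bg0", &data0, &data1);
let bind_group_1 = make_bind_group("bg1", &data1, &data0);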

The WGSL shader itself is fairly compact.

// hillis_steele_scan.wgsl
struct Uniforms {
  step: u32,
};

@group(0) @binding(0) var<storage, read> src: array<u32>;
@group(0) @binding(1) var<storage, read_write> dst: array<u32>;
@group(0) @binding(2) var<uniform> uni: Uniforms;

@compute
@workgroup_size(64)
fn main(
  @builtin(global_invocation_id) gid: vec3<u32>,
  @builtin(num_workgroups) nwg: vec3<u32>,
) {
    let total = arrayLength(&src);

    let width = nwg.x * 64u;
    let i = gid.x + gid.y * width;

    if (i >= total) {
        return;
    }

    if (i < uni.step) {
        dst[i] = src[i];
    } else {
        dst[i] = src[i] + src[i - uni.step];
    }
}

Then we just dispatch all steps in one compute pass.

// hillis_steele_scan.rs
pub fn run_prefix_scan(&self) {
    const WG_SIZE: u32 = 64;
    let workgroups_needed = self.n.div_ceil(WG_SIZE as usize) as u32;

    let max_dim = self.device.limits().max_compute_workgroups_per_dimension;
    let x = workgroups_needed.min(max_dim);
    let y = workgroups_needed.div_ceil(x);

    let mut encoder = self.device.create_command_encoder(&Default::default());
    {
        let mut pass = encoder.begin_compute_pass(&Default::default());
        pass.set_pipeline(&self.pipeline);
        for i in 0..self.max_steps {
            let offset_bytes = i * self.uniform_stride;
            let bg = if i % 2 == 0 { &self.bind_group_0 } else { &self.bind_group_1 };
            pass.set_bind_group(0, bg, &[offset_bytes]);
            pass.dispatch_workgroups(x, y, 1);
        }
    }
    self.queue.submit([encoder.finish()]);
}
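
One detail to keep in mind with the ping-pong scheme: the final result lives in whichever buffer the last step wrote to. Assuming the even-step bind group writes into data1 (as in the sketch above), that would be:

// sketch: the buffer holding the scanned result after all ping-pong steps
let result_buffer = if self.max_steps % 2 == 0 { &self.data0 } else { &self.data1 };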

Benchmark

Test machine: M4 Pro Mac mini (24GB). I used Criterion for benchmarking. Ideally timestamp queries would be the best way to measure this, but on Metal I couldn’t get them to work for some reason, so I measure from submit until the device becomes usable again. This includes polling overhead (500µs to 1ms on my machine), but that’s good enough if we compare at large $N$.

CPU Sequential vs GPU Hillis-Steele

…and way slower than sequential scan on CPU.

Hillis–Steele has an ideal parallel time of $O(\log N)$, so it sounds like it could be faster than an $O(N)$ sequential scan. But in practice it’s bottlenecked by memory bandwidth and work inefficiency.

Let’s do a rough bandwidth-based estimate. On M4 Pro, the unified memory bandwidth is advertised as 273 GB/s. If $N = 10^8$ and we store u32, the array size is $10^8 \times 4$ bytes $= 400{,}000{,}000$ bytes $\approx$ 400 MB. Hillis–Steele needs $\lceil \log_2 N \rceil = 27$ steps. With double buffering, each step touches the whole array with one read and one write → about 2× array traffic per step.

  • Total transfers ≈ $400\text{ MB} \times 54 = 21.6\text{ GB}$
  • Ideal bandwidth time ≈ $21.6 / 273 \approx 0.0791$ s ≈ 79.1 ms

In my measurement at $N = 10^8$, it took 104.36 ms, which is consistent with “bandwidth dominates, plus overheads”.

Also, if you execute Hillis–Steele on a single thread, the total work is $O(N \log N)$, which is not optimal compared to $O(N)$ algorithms.

Blelloch scan

Blelloch scan fixes the two big Hillis–Steele issues:

  • Work: $O(N)$
  • Parallel time: $O(\frac{N}{P} + \log N)$ (where $P$ is the number of threads)

Blelloch has two phases: an up-sweep that builds partial sums in a tree-reduction pattern, then a down-sweep that converts the tree sums into a prefix sum. The pictures below demonstrate an exclusive scan with Blelloch’s algorithm.

Up-sweep phase of Blelloch scan

Before the down-sweep, you set the last element to 0. In the down-sweep, if we call the previous index $L$ and the current index $R$, the core update is $S[L] \leftarrow S[R], \quad S[R] \leftarrow S[R] + S[L]$, where both right-hand sides use the old values.

Down-sweep phase of Blelloch scan

It looks like wizardry, but if you trace one element backward through the tree it really does become the right prefix sum.
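
To convince yourself, here is a single-threaded Rust trace of the two phases on a small example (a sketch for intuition only, not the GPU code):

// CPU trace of Blelloch's exclusive scan (sketch)
fn blelloch_exclusive_scan(a: &mut [u32]) {
    let n = a.len(); // assumed to be a power of two
    // up-sweep: build partial sums in a tree-reduction pattern
    let mut step = 2;
    while step <= n {
        for i in ((step - 1)..n).step_by(step) {
            a[i] += a[i - step / 2];
        }
        step *= 2;
    }
    // set the last element (the total) to 0, then down-sweep
    a[n - 1] = 0;
    let mut step = n;
    while step >= 2 {
        for i in ((step - 1)..n).step_by(step) {
            let prev = i - step / 2;
            let old_prev = a[prev];
            a[prev] = a[i];   // L gets the old R
            a[i] += old_prev; // R gets old R + old L
        }
        step /= 2;
    }
}

fn main() {
    let mut a = [3u32, 1, 7, 0, 4, 1, 6, 3];
    blelloch_exclusive_scan(&mut a);
    assert_eq!(a, [0, 3, 4, 11, 11, 15, 16, 22]);
}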

Implementation

The overall structure is similar to Hillis–Steele, but no double buffer is needed anymore. Here are the WGSL kernels of the up-sweep and the down-sweep steps.

// global_blelloch_scan_up_sweep.wgsl
struct Uniform {
  step: u32, // power of 2
};

@group(0) @binding(0) var<storage, read_write> data: array<u32>;
@group(0) @binding(1) var<uniform> uni: Uniform;

@compute
@workgroup_size(64)
fn main(
  @builtin(global_invocation_id) gid: vec3<u32>,
  @builtin(num_workgroups) nwg: vec3<u32>,
) {
    let n = arrayLength(&data);
    let step = uni.step;
    let half = step >> 1u;

    let width = nwg.x * 64u;
    let plane = width * nwg.y;
    let t = gid.x + gid.y * width + gid.z * plane;

    let active_idx = n / step;
    if (t >= active_idx) { return; }

    let i = (step - 1u) + t * step;
    data[i] += data[i - half];
}

// global_blelloch_scan_down_sweep.wgsl
struct Uniform {
  step: u32, // power of 2
};

@group(0) @binding(0) var<storage, read_write> data: array<u32>;
@group(0) @binding(1) var<uniform> uni: Uniform;

@compute
@workgroup_size(64)
fn main(
  @builtin(global_invocation_id) gid: vec3<u32>,
  @builtin(num_workgroups) nwg: vec3<u32>,
) {
    let n = arrayLength(&data);
    let step = uni.step;
    let half = step >> 1u;

    let width = nwg.x * 64u;
    let plane = width * nwg.y;
    let t = gid.x + gid.y * width + gid.z * plane;

    let active_idx = n / step;
    if (t >= active_idx) { return; }

    let i = (step - 1u) + t * step;
    let prev = i - half;

    let left = data[i];
    data[i] = data[i] + data[prev];
    data[prev] = left;
}
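
The host side for these two kernels is not shown here, but the dispatch order is the important part: up-sweep passes with step = 2, 4, …, n, zero the last element, then down-sweep passes with step = n, …, 4, 2. A rough sketch, assuming n is a power of two and that the per-step uniforms are packed with dynamic offsets as in the Hillis–Steele version (slot k holding step = 2^(k+1)); the pipeline, bind group, stride, and helper names below are assumptions:

// global Blelloch scan, host side (sketch; `split_workgroups` is an assumed helper
// that splits a workgroup count across x/y/z to respect per-dimension limits)
let mut encoder = device.create_command_encoder(&Default::default());
{
    let mut pass = encoder.begin_compute_pass(&Default::default());
    // up-sweep: step = 2, 4, ..., n
    pass.set_pipeline(&up_sweep_pipeline);
    let mut step = 2u32;
    while step <= n {
        let slot = step.trailing_zeros() - 1; // uniform slot k holds step = 2^(k+1)
        pass.set_bind_group(0, &bind_group, &[slot * uniform_stride]);
        let (x, y, z) = split_workgroups((n / step).div_ceil(64));
        pass.dispatch_workgroups(x, y, z);
        step <<= 1;
    }
}
// set the last element to 0 between the sweeps
// (the exact signature of clear_buffer varies slightly across wgpu versions)
encoder.clear_buffer(&data, ((n - 1) * 4) as u64, Some(4));
{
    let mut pass = encoder.begin_compute_pass(&Default::default());
    // down-sweep: step = n, ..., 4, 2
    pass.set_pipeline(&down_sweep_pipeline);
    let mut step = n;
    while step >= 2 {
        let slot = step.trailing_zeros() - 1;
        pass.set_bind_group(0, &bind_group, &[slot * uniform_stride]);
        let (x, y, z) = split_workgroups((n / step).div_ceil(64));
        pass.dispatch_workgroups(x, y, z);
        step >>= 1;
    }
}
queue.submit([encoder.finish()]);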

Benchmark

CPU vs Hillis-Steele vs Blelloch

Much better than Hillis–Steele, but still not ideal. Even though we reduced the work, this implementation still hits global memory a lot (I’d guess roughly $7N$ reads and writes in my implementation), so there’s room to improve.

Blocked Blelloch scan

The next optimization is using workgroup (shared) memory. In WGSL that’s var<workgroup>.

We pick the block size to match the workgroup size. Then,

  • Compute an exclusive scan inside each block
  • Write each block’s total sum into a separate block sum array
  • Scan the block sum array to get the per-block offsets
  • Add those offsets back into the already-scanned blocks (“carry add”)

Blocked scanning

Scanning the block sum array is itself a prefix sum problem, so we apply the same logic recursively until the block sum array fits into a single block.
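
For a sense of scale, with $N = 10^8$ and 64-element blocks the recursion bottoms out after only a few levels:

// block-sum array sizes for n = 100_000_000 with 64-element blocks (sketch)
fn main() {
    let mut n: u64 = 100_000_000;
    while n > 1 {
        n = n.div_ceil(64);
        println!("{n}"); // prints 1562500, 24415, 382, 6, 1
    }
}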

Implementation

The core idea is still Blelloch’s, but we use workgroup memory to reduce the amount of global memory access. This also lets us move the up-sweep and down-sweep loops into a single shader, since a workgroup barrier can synchronize the work within a workgroup. That greatly simplifies the implementation of the recursive calls to the scan.

// blelloch_block_scan.wgsl
const WG_SIZE: u32 = 64u;

@group(0) @binding(0) var<storage, read_write> global_data: array<u32>;
@group(0) @binding(1) var<storage, read_write> block_sum: array<u32>;

var<workgroup> local_data: array<u32, 64u>;

fn linearize_workgroup_id(wid: vec3<u32>, num_wg: vec3<u32>) -> u32 {
    return wid.x + wid.y * num_wg.x + wid.z * (num_wg.x * num_wg.y);
}

fn get_indices(lid: vec3<u32>, wid: vec3<u32>, num_wg: vec3<u32>) -> array<u32, 2> {
    let local_idx = lid.x;
    let wg_linear = linearize_workgroup_id(wid, num_wg);
    let block_base = wg_linear * WG_SIZE;
    let global_idx = block_base + local_idx;
    return array<u32, 2>(local_idx, global_idx);
}

fn copy_global_data_to_local(n: u32, local_idx: u32, global_idx: u32) {
    var global_val = 0u;
    if (global_idx < n) { global_val = global_data[global_idx]; }
    local_data[local_idx] = global_val;
    workgroupBarrier();
}

fn up_sweep(local_idx: u32) {
    var step = 2u;
    while (step <= WG_SIZE) {
        let num_targets = WG_SIZE / step;
        if (local_idx < num_targets) {
            let target_idx = (local_idx + 1u) * step - 1u;
            local_data[target_idx] += local_data[target_idx - (step >> 1u)];
        }
        workgroupBarrier();
        step = step << 1u;
    }
}

fn down_sweep(local_idx: u32) {
    var step = WG_SIZE;
    while (step >= 2u) {
        let num_targets = WG_SIZE / step;
        if (local_idx < num_targets) {
            let target_idx = (local_idx + 1u) * step - 1u;
            let prev_idx = target_idx - (step >> 1u);
            let prev_val = local_data[prev_idx];
            local_data[prev_idx] = local_data[target_idx];
            local_data[target_idx] += prev_val;
        }
        workgroupBarrier();
        step = step >> 1u;
    }
}

@compute @workgroup_size(WG_SIZE)
fn block_scan_write_sum(
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wid: vec3<u32>,
    @builtin(num_workgroups) num_wg: vec3<u32>
) {
    let n = arrayLength(&global_data);
    let indices = get_indices(lid, wid, num_wg);
    let local_idx = indices[0];
    let global_idx = indices[1];
    copy_global_data_to_local(n, local_idx, global_idx);

    up_sweep(local_idx);

    // write block sum, then set last element to 0
    let wg_linear = linearize_workgroup_id(wid, num_wg);
    let n_blocks = arrayLength(&block_sum);
    if (local_idx == 0u) {
        if (wg_linear < n_blocks) {
            block_sum[wg_linear] = local_data[WG_SIZE - 1u];
        }
        local_data[WG_SIZE - 1u] = 0u;
    }
    workgroupBarrier();

    down_sweep(local_idx);

    if (global_idx < n) {
        global_data[global_idx] = local_data[local_idx];
    }
}

Carry add is just reading the block offset and adding it to each element in the block.

// blelloch_add_carry.wgsl
const WG_SIZE: u32 = 64u;

@group(0) @binding(0) var<storage, read_write> global_data: array<u32>;
@group(0) @binding(1) var<storage, read_write> block_sum: array<u32>;

fn linearize_workgroup_id(wid: vec3<u32>, num_wg: vec3<u32>) -> u32 {
    return wid.x + wid.y * num_wg.x + wid.z * (num_wg.x * num_wg.y);
}

@compute @workgroup_size(WG_SIZE)
fn add_carry(
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wid: vec3<u32>,
    @builtin(num_workgroups) num_wg: vec3<u32>,
) {
    let n_data = arrayLength(&global_data);
    let n_blocks = arrayLength(&block_sum);

    let wg_linear = linearize_workgroup_id(wid, num_wg);
    if (wg_linear >= n_blocks) { return; }

    let global_idx = wg_linear * WG_SIZE + lid.x;
    if (global_idx >= n_data) { return; }

    let carry = block_sum[wg_linear];
    global_data[global_idx] += carry;
}

On the CPU side, we build the intermediate buffers and bind groups for each recursive level, then dispatch them in sequence.
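
A rough sketch of that dispatch sequence, mirroring the single-pass style used earlier (the level buffers, bind groups, and workgroup counts below are assumptions; the real setup is in the full source):

// recursive blocked scan, host side (sketch)
// levels[0] is the input, levels[k + 1] holds the block sums of levels[k], and the
// deepest level fits in a single 64-element block (its block sum goes to a 1-element dummy).
let mut encoder = device.create_command_encoder(&Default::default());
{
    let mut pass = encoder.begin_compute_pass(&Default::default());

    // top-down: scan each level in place and emit its block sums for the next level
    pass.set_pipeline(&block_scan_pipeline);
    for k in 0..num_levels {
        pass.set_bind_group(0, &scan_bind_groups[k], &[]);
        let [x, y, z] = level_workgroups[k]; // per-level counts, split across dimensions
        pass.dispatch_workgroups(x, y, z);
    }

    // bottom-up: deeper levels are final first, so add their offsets back upward
    pass.set_pipeline(&carry_pipeline);
    for k in (0..num_levels - 1).rev() {
        pass.set_bind_group(0, &carry_bind_groups[k], &[]);
        let [x, y, z] = level_workgroups[k];
        pass.dispatch_workgroups(x, y, z);
    }
}
queue.submit([encoder.finish()]);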

Benchmark

CPU vs Blelloch global vs Blelloch block

This was slightly faster than the global-memory version, and it became roughly comparable to the CPU.

I also tried the “process 2 or 4 elements per thread” optimization mentioned in GPU Gems 3, but on my setup it actually got slower. I tested workgroup sizes from 32 to 256, and 64 was consistently best. My guess is that with more barriers, larger workgroups (or heavier per-thread work) create more imbalance and longer waits at the barriers, but this part is speculation.

Anyway, even with blocked Blelloch scan, it wasn’t clearly beating the CPU.

Scan by subgroups

WebGPU has a feature called subgroups. Subgroups map to hardware concepts like warps or simdgroups, and provide fast cross-lane communication.
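
One practical note: subgroup operations are an optional feature, so they have to be enabled when requesting the device, and depending on the toolchain the WGSL may also need an enable subgroups; directive at the top of the shader. A minimal check on the wgpu side (a sketch; descriptor fields vary a bit across wgpu versions):

// request the subgroups feature (sketch)
assert!(adapter.features().contains(wgpu::Features::SUBGROUP));
let desc = wgpu::DeviceDescriptor {
    required_features: wgpu::Features::SUBGROUP,
    ..Default::default()
};
// ...then pass `desc` to adapter.request_device(...)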

Implementation

WGSL provides subgroupExclusiveAdd, which is literally an exclusive scan within a subgroup. This is what we wanted! It also fits the blocked scan approach pretty well: we just need to replace the per-block scanning part with it.

Subgroup sizes vary by hardware and backend, but they are typically 32 or 64, which is usually smaller than the workgroup. So we still need a “subgroup sums → scan → add carry” step, but inside each workgroup.

// subgroup_block_scan.wgsl
@compute @workgroup_size(WG_SIZE)
fn block_scan_write_sum(
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wid: vec3<u32>,
    @builtin(num_workgroups) num_wg: vec3<u32>,
    @builtin(subgroup_size) sg_size: u32,
    @builtin(subgroup_invocation_id) sg_lane: u32,
    @builtin(subgroup_id) sg_id: u32,
) {
    let n = arrayLength(&global_data);

    let wg_linear = linearize_workgroup_id(wid, num_wg);
    let global_idx = wg_linear * WG_SIZE + lid.x;
    let in_range = global_idx < n;

    var v = 0u;
    if (in_range) {
        v = global_data[global_idx];
    }

    // exclusive scan within the subgroup
    let sg_prefix = subgroupExclusiveAdd(v);

    // subgroup-wide sum (same value for all lanes in the subgroup)
    let sg_sum = subgroupAdd(v);

    // write each subgroup sum to workgroup memory
    if (sg_lane == 0u) {
        local_data[sg_id] = sg_sum;
    }
    workgroupBarrier();

    // scan subgroup sums (serial, small)
    let num_sg = (WG_SIZE + sg_size - 1u) / sg_size;
    if (lid.x == 0u) {
        var sg_sum_total = 0u;
        for (var i = 0u; i < num_sg; i = i + 1u) {
            let tmp = local_data[i];
            local_data[i] = sg_sum_total;
            sg_sum_total = sg_sum_total + tmp;
        }

        let n_blocks = arrayLength(&block_sum);
        if (wg_linear < n_blocks) {
            block_sum[wg_linear] = sg_sum_total;
        }
    }
    workgroupBarrier();

    // add subgroup offset + intra-subgroup prefix
    if (in_range) {
        global_data[global_idx] = local_data[sg_id] + sg_prefix;
    }
}

Benchmark

CPU vs Blelloch block vs subgroup

Finally the GPU wins!

Some notes

  • If you only need scan as a standalone operation, subgroups can be a big win. But if you want to fuse scan with other computations, a Blelloch-style approach may still be useful as a base. (Maybe?)
  • There are even faster scan variants like Decoupled Look-back.
  • My benchmark measures from submit until the device becomes usable again, so it does not include the overhead of creating pipelines, bind groups, etc. Even so, for small $N$, GPU scan is often not worth it.
  • Blocked scans are basically mandatory for performance and for handling large arrays, but the recursive intermediate buffers increase memory usage. I suspect there are smarter ways to manage that, but I didn’t dig further.
  • WebGPU is fun🙂

References