feat: support GEMM (C = alpha*A*B + beta*C) in matmul kernels #1280

@antimora

Summary

The current matmul kernels implement C = A * B (effectively alpha=1, beta=0), always overwriting the output tensor. Supporting the full GEMM interface C = alpha * A * B + beta * C would enable fused matmul+bias operations without extra kernel launches.

Motivation

Burn is adding a fused linear op to ModuleOps (tracel-ai/burn#4737) so backends can optimize the common y = x @ W + b pattern. With beta=1 support, the bias could be pre-loaded into the output tensor, and a single matmul kernel call would produce C = A * B + bias with no intermediate allocation or separate add kernel.

This is the single most common operation in transformer models (every attention projection and FFN layer), so even small gains compound across dozens of layers.

Proposed API

Add optional alpha and beta parameters to the matmul launch path:

pub fn matmul<R: CubeRuntime>(
    lhs: CubeTensor<R>,
    rhs: CubeTensor<R>,
    out: Option<CubeTensor<R>>,
    strategy: MatmulStrategy,
    out_dtype: DType,
    alpha: f32,  // default 1.0
    beta: f32,   // default 0.0
) -> Result<CubeTensor<R>, MatmulSetupError>

When beta != 0, the kernel would accumulate into the existing output values instead of overwriting them.
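To make the proposed semantics concrete, here is a plain-CPU reference of what the kernel would compute. The `gemm` function below is a hypothetical illustration of the contract (row-major, naive triple loop), not the CubeCL kernel itself:

```rust
// Reference semantics for the proposal: C = alpha * A * B + beta * C.
// A is m x k, B is k x n, C is m x n, all row-major.
// This is an illustrative CPU sketch, not the actual matmul kernel.
fn gemm(alpha: f32, a: &[f32], b: &[f32], beta: f32, c: &mut [f32], m: usize, k: usize, n: usize) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0;
            for p in 0..k {
                acc += a[i * k + p] * b[p * n + j];
            }
            // beta == 0 overwrites (current behavior); beta != 0 accumulates
            // into whatever is already in the output buffer.
            c[i * n + j] = alpha * acc + beta * c[i * n + j];
        }
    }
}

fn main() {
    let a = [1.0, 2.0, 3.0, 4.0]; // 2x2
    let b = [1.0, 0.0, 0.0, 1.0]; // 2x2 identity
    let mut c = [10.0, 10.0, 10.0, 10.0]; // output pre-loaded with a "bias"
    gemm(1.0, &a, &b, 1.0, &mut c, 2, 2, 2);
    // A * I + 10 in every slot:
    assert_eq!(c, [11.0, 12.0, 13.0, 14.0]);
}
```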

Use Case

Burn's linear forward pass would then become:

fn linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor {
    // Pre-load the bias into the output buffer, broadcast to the matmul output shape.
    let out = bias.broadcast_to(output_shape); // or unsqueeze + expand
    // alpha = 1.0, beta = 1.0: accumulate x @ weight on top of the bias.
    matmul(x, weight, Some(out), strategy, dtype, 1.0, 1.0)
}

One kernel launch instead of two (matmul + add).
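The fused pattern can be sketched end to end on the CPU: broadcast the bias into the output buffer, then a single GEMM pass with alpha = 1, beta = 1 produces the same result as matmul followed by add. `linear_fused` is a hypothetical helper for illustration only:

```rust
// CPU sketch of the fused linear forward: pre-load the broadcast bias into
// `out`, then one GEMM with alpha = 1, beta = 1 replaces matmul + add.
// `x` is batch x in_features, `w` is in_features x out_features, row-major.
fn linear_fused(x: &[f32], w: &[f32], bias: &[f32], batch: usize, inf: usize, outf: usize) -> Vec<f32> {
    // Step 1: broadcast the bias across the batch dimension into the output.
    let mut out: Vec<f32> = (0..batch * outf).map(|i| bias[i % outf]).collect();
    // Step 2: a single GEMM pass with beta = 1 accumulates x @ w on top.
    for i in 0..batch {
        for j in 0..outf {
            let mut acc = 0.0;
            for p in 0..inf {
                acc += x[i * inf + p] * w[p * outf + j];
            }
            out[i * outf + j] += acc; // alpha = 1, beta = 1
        }
    }
    out
}

fn main() {
    let x = [1.0, 2.0, 3.0, 4.0]; // 2x2
    let w = [1.0, 1.0, 1.0, 1.0]; // 2x2
    let bias = [0.5, -0.5];
    let out = linear_fused(&x, &w, &bias, 2, 2, 2);
    // Unfused check: x @ w = [[3, 3], [7, 7]], plus bias per output column.
    assert_eq!(out, [3.5, 2.5, 7.5, 6.5]);
}
```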

Labels: enhancement (New feature or request)