Summary
The current matmul kernels implement `C = A * B` (effectively `alpha = 1`, `beta = 0`), always overwriting the output tensor. Supporting the full GEMM interface `C = alpha * A * B + beta * C` would enable fused matmul+bias operations without extra kernel launches.
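For reference, the proposed contract in naive scalar form (illustration only; the real kernels are tiled CubeCL implementations, and `gemm` here is just a stand-in name):

```rust
// Naive row-major reference for C = alpha * A * B + beta * C.
// A is m x k, B is k x n, C is m x n.
fn gemm(alpha: f32, a: &[f32], b: &[f32], beta: f32, c: &mut [f32], m: usize, k: usize, n: usize) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0f32;
            for p in 0..k {
                acc += a[i * k + p] * b[p * n + j];
            }
            let idx = i * n + j;
            // beta == 0 overwrites without reading c, so an uninitialized
            // output buffer is safe (matching BLAS convention).
            c[idx] = if beta == 0.0 { alpha * acc } else { alpha * acc + beta * c[idx] };
        }
    }
}

fn main() {
    // 1x2 @ 2x1: 1*3 + 2*4 = 11; with beta = 1 and c pre-loaded to 100 -> 111.
    let a = [1.0, 2.0];
    let b = [3.0, 4.0];
    let mut c = [100.0];
    gemm(1.0, &a, &b, 1.0, &mut c, 1, 2, 1);
    println!("{}", c[0]); // 111
}
```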
Motivation
Burn is adding a fused `linear` op to `ModuleOps` (tracel-ai/burn#4737) so backends can optimize the common `y = x @ W + b` pattern. With `beta = 1` support, the bias could be pre-loaded into the output tensor, and a single matmul kernel call would produce `C = A * B + bias` with no intermediate allocation or separate add kernel.
This is the single most common operation in transformer models (every attention projection and FFN layer), so even small gains compound across dozens of layers.
Proposed API
Add optional `alpha` and `beta` parameters to the matmul launch path:
```rust
pub fn matmul<R: CubeRuntime>(
    lhs: CubeTensor<R>,
    rhs: CubeTensor<R>,
    out: Option<CubeTensor<R>>,
    strategy: MatmulStrategy,
    out_dtype: DType,
    alpha: f32, // default 1.0
    beta: f32,  // default 0.0
) -> Result<CubeTensor<R>, MatmulSetupError>
```
When `beta != 0`, the kernel would accumulate into the existing output values instead of overwriting them.
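One subtlety worth encoding in the epilogue (and consistent with BLAS convention, where `C` may be uninitialized when `beta = 0`): the `beta == 0` case must skip reading the output entirely rather than multiply by zero, because uninitialized memory can hold NaN and `0.0 * NaN` is NaN. A sketch of the per-element write-back, with `epilogue` as a hypothetical helper name:

```rust
// Per-element write-back: branch on beta == 0 so the previous output
// value is never read in the overwrite case.
fn epilogue(acc: f32, prev: f32, alpha: f32, beta: f32) -> f32 {
    if beta == 0.0 {
        alpha * acc // never reads `prev`
    } else {
        alpha * acc + beta * prev
    }
}

fn main() {
    let uninit = f32::NAN; // stand-in for uninitialized output memory
    let naive = 1.0 * 2.0 + 0.0 * uninit; // naive formula propagates NaN
    let safe = epilogue(2.0, uninit, 1.0, 0.0);
    println!("{} {}", naive.is_nan(), safe); // true 2
}
```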
Use Case
Burn's linear forward pass would then become:
```rust
// Pseudocode: fused linear using the proposed beta parameter
fn linear(x, weight, bias) -> Tensor {
    let out = bias.broadcast_to(output_shape); // or unsqueeze + expand
    matmul(x, weight, Some(out), strategy, dtype, 1.0, 1.0)
}
```
One kernel launch instead of two (matmul + add).
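To illustrate the equivalence the fused path relies on, a scalar sketch (with `matmul_acc` as a naive stand-in for the proposed kernel, not the real API): pre-loading the broadcast bias into `out` and calling once with `beta = 1` matches the current two-launch matmul-then-add sequence.

```rust
// Naive stand-in: out = x @ w + beta * out, where x is m x k, w is k x n.
fn matmul_acc(x: &[f32], w: &[f32], out: &mut [f32], m: usize, k: usize, n: usize, beta: f32) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0f32;
            for p in 0..k {
                acc += x[i * k + p] * w[p * n + j];
            }
            out[i * n + j] = acc + beta * out[i * n + j];
        }
    }
}

fn main() {
    let x = [1.0, 2.0];
    let w = [1.0, 0.0, 0.0, 1.0]; // 2x2 identity
    let bias = [10.0, 20.0];

    // Fused path: out pre-loaded with the broadcast bias, beta = 1.
    let mut fused = [bias[0], bias[1]];
    matmul_acc(&x, &w, &mut fused, 1, 2, 2, 1.0);

    // Two-launch path: matmul into a fresh buffer, then a separate add.
    let mut tmp = [0.0, 0.0];
    matmul_acc(&x, &w, &mut tmp, 1, 2, 2, 0.0);
    let unfused = [tmp[0] + bias[0], tmp[1] + bias[1]];

    println!("{:?} {:?}", fused, unfused); // [11.0, 22.0] [11.0, 22.0]
}
```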