tilelang.language.gemm module
The language interface for tl programs.
- tilelang.language.gemm.gemm(A: Union[Buffer, Var], B: Union[Buffer, Var], C: Union[Buffer, Var], transpose_A: bool = False, transpose_B: bool = False, policy: GemmWarpPolicy = GemmWarpPolicy.Square, clear_accum: bool = False, k_pack: int = 1, wg_wait: int = 0)
Perform a General Matrix Multiplication (GEMM) operation.
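This function computes C += op(A) @ op(B), where op(X) is X or its transpose as selected by transpose_A and transpose_B. When clear_accum is True, C is zeroed first, so the result is C = op(A) @ op(B). The warp partitioning of the computation is selected by policy.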
- Parameters:
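A (Union[tir.Buffer, tir.Var]) – First input matrix, of shape (M, K), or (K, M) when transpose_A is True.
B (Union[tir.Buffer, tir.Var]) – Second input matrix, of shape (K, N), or (N, K) when transpose_B is True.
C (Union[tir.Buffer, tir.Var]) – Output (accumulator) matrix of shape (M, N).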
transpose_A (bool, optional) – Whether to transpose matrix A. Defaults to False.
transpose_B (bool, optional) – Whether to transpose matrix B. Defaults to False.
policy (GemmWarpPolicy, optional) – Warp execution policy. Defaults to GemmWarpPolicy.Square.
clear_accum (bool, optional) – Whether to clear accumulator before computation. Defaults to False.
k_pack (int, optional) – Number of k dimensions packed into a single warp. Defaults to 1.
wg_wait (int, optional) – Warp group wait count. Defaults to 0.
- Returns:
A handle to the GEMM operation
- Return type:
tir.Call
- Raises:
AssertionError – If the K (reduction) dimensions of A and B, after applying transpose_A and transpose_B, do not match.
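The arithmetic that gemm performs can be checked against a small host-side reference. The sketch below is a hypothetical pure-Python helper, `gemm_reference`, which is not part of tilelang. It assumes C is accumulated into unless clear_accum is set, and it ignores policy, k_pack and wg_wait on the assumption that they affect only how the work is scheduled on the GPU, not the numerical result. Unlike the real operation, which updates C in place, it returns a new matrix.

```python
# Hypothetical pure-Python reference for the documented GEMM semantics.
# NOT part of tilelang: it only illustrates how transpose_A, transpose_B
# and clear_accum affect the result, using small nested lists.
from typing import List

Matrix = List[List[float]]


def _transpose(m: Matrix) -> Matrix:
    # Swap rows and columns: result[j][i] == m[i][j].
    return [list(row) for row in zip(*m)]


def gemm_reference(A: Matrix, B: Matrix, C: Matrix,
                   transpose_A: bool = False,
                   transpose_B: bool = False,
                   clear_accum: bool = False) -> Matrix:
    """Return op(A) @ op(B) accumulated into a copy of C."""
    a = _transpose(A) if transpose_A else A  # a has shape (M, K)
    b = _transpose(B) if transpose_B else B  # b has shape (K, N)
    M, K = len(a), len(a[0])
    K_b, N = len(b), len(b[0])
    # Mirrors the documented AssertionError on mismatched K dimensions.
    assert K == K_b, f"K mismatch: {K} vs {K_b}"
    assert len(C) == M and all(len(row) == N for row in C), "C must be (M, N)"
    # Start from zeros when clear_accum is set, otherwise from a copy of C.
    out = [[0.0] * N for _ in range(M)] if clear_accum else [list(row) for row in C]
    for i in range(M):
        for j in range(N):
            acc = out[i][j]
            for k in range(K):
                acc += a[i][k] * b[k][j]
            out[i][j] = acc
    return out


A = [[1.0, 2.0], [3.0, 4.0]]  # (M=2, K=2)
B = [[5.0, 6.0], [7.0, 8.0]]  # (K=2, N=2)
C = [[1.0, 1.0], [1.0, 1.0]]  # (M=2, N=2)

print(gemm_reference(A, B, C))                    # [[20.0, 23.0], [44.0, 51.0]]
print(gemm_reference(A, B, C, clear_accum=True))  # [[19.0, 22.0], [43.0, 50.0]]
```

Passing B already stored as (N, K) together with transpose_B=True yields the same product as passing the (K, N) layout, which is the relationship the transpose flags describe.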