tilelang.language.gemm module

The language interface for TileLang (tl) programs.

tilelang.language.gemm.gemm(A: Union[Buffer, Var], B: Union[Buffer, Var], C: Union[Buffer, Var], transpose_A: bool = False, transpose_B: bool = False, policy: GemmWarpPolicy = GemmWarpPolicy.Square, clear_accum: bool = False, k_pack: int = 1, wg_wait: int = 0)

Perform a General Matrix Multiplication (GEMM) operation.

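This function computes C += op(A) @ op(B), where op(X) is either X or its transpose, as selected by transpose_A and transpose_B. When clear_accum is True, C is zeroed first, so the result is C = op(A) @ op(B). The policy argument selects how warps are partitioned over the output tile.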
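The following is a minimal pure-Python reference model of the computation, assuming that gemm accumulates into C (C += op(A) @ op(B)) unless clear_accum is True. The names gemm_reference and _op are hypothetical and not part of TileLang. The model works on nested lists and ignores policy, k_pack and wg_wait, which this sketch assumes affect how the work is scheduled on the GPU rather than the mathematical result.

```python
def _op(X, transpose):
    """Hypothetical helper: return X transposed (as a new list of rows) if requested."""
    if transpose:
        return [list(col) for col in zip(*X)]
    return X


def gemm_reference(A, B, C, transpose_A=False, transpose_B=False, clear_accum=False):
    """Hypothetical reference model (not the TileLang API).

    Updates C in place with C += op(A) @ op(B). If clear_accum is True,
    C is zeroed first, so the result is C = op(A) @ op(B).
    """
    a = _op(A, transpose_A)  # logical shape (M, K)
    b = _op(B, transpose_B)  # logical shape (K, N)
    M, K = len(a), len(a[0])
    K_b, N = len(b), len(b[0])
    # Mirrors the documented AssertionError on mismatched K dimensions.
    assert K == K_b, f"K dimensions do not match: {K} vs {K_b}"
    for i in range(M):
        for j in range(N):
            acc = 0 if clear_accum else C[i][j]
            for k in range(K):
                acc += a[i][k] * b[k][j]
            C[i][j] = acc
    return C


A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
C = [[1, 1], [1, 1]]
gemm_reference(A, B, C)                    # C is now [[20, 23], [44, 51]]
gemm_reference(A, B, C, clear_accum=True)  # C is now [[19, 22], [43, 50]]
```

Returning C as well as mutating it is only a convenience for the example; the documented function instead returns a tir.Call handle.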

Parameters:
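  • A (Union[tir.Buffer, tir.Var]) – First input matrix, of shape (M, K), or (K, M) when transpose_A is True.

  • B (Union[tir.Buffer, tir.Var]) – Second input matrix, of shape (K, N), or (N, K) when transpose_B is True.

  • C (Union[tir.Buffer, tir.Var]) – Output (accumulator) matrix of shape (M, N).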

  • transpose_A (bool, optional) – Whether to transpose matrix A. Defaults to False.

  • transpose_B (bool, optional) – Whether to transpose matrix B. Defaults to False.

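  • policy (GemmWarpPolicy, optional) – Warp partitioning policy, i.e. how the warps of the thread block are distributed over the M and N dimensions of C. Defaults to GemmWarpPolicy.Square.

  • clear_accum (bool, optional) – Whether to zero the accumulator C before the computation instead of accumulating into its existing contents. Defaults to False.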

  • k_pack (int, optional) – Number of k dimensions packed into a single warp. Defaults to 1.

  • wg_wait (int, optional) – Warp group wait count. Defaults to 0.

Returns:

A handle to the GEMM operation

Return type:

tir.Call

Raises:

AssertionError – If the K dimensions of matrices A and B (after applying transpose_A and transpose_B) don’t match
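The shape check behind this error can be sketched as follows, assuming the usual convention that A is (M, K) and B is (K, N) before the optional transposes. gemm_shapes is a hypothetical helper for illustration, not part of TileLang, and it checks only the K dimension, as documented above.

```python
def gemm_shapes(a_shape, b_shape, transpose_A=False, transpose_B=False):
    """Hypothetical helper: resolve (M, N, K) from 2-D buffer shapes.

    A stored transposed has shape (K, M); B stored transposed has shape (N, K).
    Raises AssertionError when the K dimensions of A and B disagree.
    """
    if transpose_A:
        K_a, M = a_shape
    else:
        M, K_a = a_shape
    if transpose_B:
        N, K_b = b_shape
    else:
        K_b, N = b_shape
    assert K_a == K_b, f"K dimensions do not match: {K_a} vs {K_b}"
    return M, N, K_a


print(gemm_shapes((128, 32), (32, 64)))                    # (128, 64, 32)
print(gemm_shapes((128, 32), (64, 32), transpose_B=True))  # (128, 64, 32)
```

Note that transposing a buffer changes which axis is read as K, so two buffers that look compatible without the flags may fail the check with them, and vice versa.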