tilelang.language.gemm_op
GEMM (General Matrix Multiplication) operators exposed on the TileLang language surface.
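These operators work on tiles (for example shared-memory or fragment buffers) rather than on whole matrices. A kernel typically calls gemm once per K tile inside a reduction loop and accumulates into a per-block accumulator. The plain-Python model below is a conceptual sketch, not TileLang code. It assumes that each tile-level call performs acc += A_tile @ B_tile. The names tile_gemm and blocked_matmul are hypothetical helpers chosen for illustration.

```python
# Conceptual model in plain Python; this is NOT TileLang code.
from typing import List

Matrix = List[List[float]]


def tile_gemm(A: Matrix, B: Matrix, acc: Matrix, i0: int, j0: int, k0: int, bk: int) -> None:
    """Hypothetical tile op: acc += A[i0:i0+m, k0:k0+bk] @ B[k0:k0+bk, j0:j0+n].

    m and n are taken from the shape of the accumulator tile `acc`.
    """
    K = len(B)
    for i in range(len(acc)):
        for j in range(len(acc[0])):
            for k in range(k0, min(k0 + bk, K)):
                acc[i][j] += A[i0 + i][k] * B[k][j0 + j]


def blocked_matmul(A: Matrix, B: Matrix, bm: int, bn: int, bk: int) -> Matrix:
    """Compute A @ B tile by tile, calling tile_gemm once per K tile."""
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for i0 in range(0, M, bm):              # each (i0, j0) pair is one output tile,
        for j0 in range(0, N, bn):          # e.g. the work of one thread block
            m, n = min(bm, M - i0), min(bn, N - j0)
            acc = [[0.0] * n for _ in range(m)]   # zeroed accumulator tile
            for k0 in range(0, K, bk):      # reduction over K tiles
                tile_gemm(A, B, acc, i0, j0, k0, bk)
            for i in range(m):              # copy the finished tile into C
                for j in range(n):
                    C[i0 + i][j0 + j] = acc[i][j]
    return C
```

Tile sizes that do not divide the matrix dimensions are handled by clipping the last tile. In a real kernel, the final step corresponds to copying the accumulator fragment back to global memory.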
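Functions
- gemm_v1(A, B, C[, transpose_A, transpose_B, ...]) – GEMM v1: use op tl.gemm.
- gemm_v2(A, B, C[, transpose_A, transpose_B, ...]) – GEMM v2: use op tl.gemm_py.
- gemm(A, B, C[, transpose_A, transpose_B, ...]) – TileLang GEMM operator.
Module Contents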
- tilelang.language.gemm_op.gemm_v1(A, B, C, transpose_A=False, transpose_B=False, policy=GemmWarpPolicy.Square, clear_accum=False, k_pack=1, wg_wait=0, mbar=None)
GEMM v1: builds a call to the tl.gemm op. The parameters have the same meaning as for gemm.
- Parameters:
A (tilelang._typing.BufferLikeType)
B (tilelang._typing.BufferLikeType)
C (tilelang._typing.BufferLikeType)
transpose_A (bool)
transpose_B (bool)
policy (tilelang.tileop.base.GemmWarpPolicy)
clear_accum (bool)
k_pack (int)
wg_wait (int)
mbar (tilelang._typing.BarrierType | None)
- Return type:
tvm.tir.PrimExpr
- tilelang.language.gemm_op.gemm_v2(A, B, C, transpose_A=False, transpose_B=False, policy=GemmWarpPolicy.Square, clear_accum=False, k_pack=1, wg_wait=0, mbar=None)
GEMM v2: builds a call to the tl.gemm_py op. The parameters have the same meaning as for gemm.
- Parameters:
A (tilelang._typing.BufferLikeType)
B (tilelang._typing.BufferLikeType)
C (tilelang._typing.BufferLikeType)
transpose_A (bool)
transpose_B (bool)
policy (tilelang.tileop.base.GemmWarpPolicy)
clear_accum (bool)
k_pack (int)
wg_wait (int)
mbar (tilelang._typing.BarrierType | None)
- Return type:
tvm.tir.PrimExpr
- tilelang.language.gemm_op.gemm(A, B, C, transpose_A=False, transpose_B=False, policy=GemmWarpPolicy.Square, clear_accum=False, k_pack=1, wg_wait=0, mbar=None)
TileLang GEMM operator. Multiplies A by B, each optionally transposed, and accumulates the product into C.
- Parameters:
A (BufferLikeType, i.e. Buffer | BufferLoad | BufferRegion, or Var) – Input buffer A.
B (BufferLikeType) – Input buffer B.
C (BufferLikeType) – Output buffer C.
transpose_A (bool) – Whether to transpose A. Defaults to False.
transpose_B (bool) – Whether to transpose B. Defaults to False.
policy (GemmWarpPolicy) – Policy for partitioning the GEMM across warps. Defaults to GemmWarpPolicy.Square.
clear_accum (bool) – Whether to zero the accumulator C before adding the product, i.e. overwrite C instead of accumulating into it. Defaults to False.
k_pack (int) – Number of packed matrix cores (ROCm only). Defaults to 1.
wg_wait (int) – Integer identifier of the warpgroup MMA batch to wait on. Defaults to 0.
mbar (BarrierType, i.e. Buffer | BufferLoad, or Var, optional) – Mbarrier used on Blackwell GPUs. Defaults to None.
- Returns:
A handle to the GEMM operation.
- Return type:
tir.Call
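The parameter list above says what each flag controls but not the resulting arithmetic. The following is a minimal plain-Python reference model of a single gemm call. It assumes that transpose_A and transpose_B select whether A and B are read transposed, and that clear_accum=True overwrites C instead of accumulating into it. The helpers gemm_reference and transpose are hypothetical and not part of TileLang. The model omits policy, k_pack, wg_wait and mbar, on the assumption that they only affect how the work is mapped to hardware and synchronized, not the mathematical result.

```python
# Reference model of the assumed per-call semantics; plain Python, not TileLang.
from typing import List

Matrix = List[List[float]]


def transpose(X: Matrix) -> Matrix:
    return [list(row) for row in zip(*X)]


def gemm_reference(A: Matrix, B: Matrix, C: Matrix, transpose_A: bool = False,
                   transpose_B: bool = False, clear_accum: bool = False) -> Matrix:
    """Update C in place with op(A) @ op(B), where op() transposes when its flag is set.

    With clear_accum=False the product is added to the existing contents of C;
    with clear_accum=True the previous contents of C are discarded.
    """
    a = transpose(A) if transpose_A else A   # logical M x K operand
    b = transpose(B) if transpose_B else B   # logical K x N operand
    M, K, N = len(a), len(a[0]), len(b[0])
    if len(b) != K or len(C) != M or any(len(row) != N for row in C):
        raise ValueError("incompatible shapes for A, B and C")
    for i in range(M):
        for j in range(N):
            acc = 0.0 if clear_accum else C[i][j]   # start from 0 or the old C value
            for k in range(K):
                acc += a[i][k] * b[k][j]
            C[i][j] = acc
    return C
```

Under this model, passing transpose_B=True means B is supplied in its stored N x K layout and read as K x N. Calling the function repeatedly with clear_accum=False sums successive products into C, which is the pattern a K-tiled loop relies on.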