tilelang.language.gemm_op

GEMM (General Matrix Multiplication) operators exposed on the TileLang language surface.

Functions

gemm_v1(A, B, C[, transpose_A, transpose_B, policy, ...])

GEMM v1: uses the tl.gemm op.

gemm_v2(A, B, C[, transpose_A, transpose_B, policy, ...])

GEMM v2: uses the tl.gemm_py op.

gemm(A, B, C[, transpose_A, transpose_B, policy, ...])

TileLang GEMM operator.

Module Contents

tilelang.language.gemm_op.gemm_v1(A, B, C, transpose_A=False, transpose_B=False, policy=GemmWarpPolicy.Square, clear_accum=False, k_pack=1, wg_wait=0, mbar=None)

GEMM v1: uses the tl.gemm op.

Parameters:
  • A (tilelang._typing.BufferLikeType)

  • B (tilelang._typing.BufferLikeType)

  • C (tilelang._typing.BufferLikeType)

  • transpose_A (bool)

  • transpose_B (bool)

  • policy (tilelang.tileop.base.GemmWarpPolicy)

  • clear_accum (bool)

  • k_pack (int)

  • wg_wait (int)

  • mbar (tilelang._typing.BarrierType | None)

Return type:

tvm.tir.PrimExpr

tilelang.language.gemm_op.gemm_v2(A, B, C, transpose_A=False, transpose_B=False, policy=GemmWarpPolicy.Square, clear_accum=False, k_pack=1, wg_wait=0, mbar=None)

GEMM v2: uses the tl.gemm_py op.

Parameters:
  • A (tilelang._typing.BufferLikeType)

  • B (tilelang._typing.BufferLikeType)

  • C (tilelang._typing.BufferLikeType)

  • transpose_A (bool)

  • transpose_B (bool)

  • policy (tilelang.tileop.base.GemmWarpPolicy)

  • clear_accum (bool)

  • k_pack (int)

  • wg_wait (int)

  • mbar (tilelang._typing.BarrierType | None)

Return type:

tvm.tir.PrimExpr

tilelang.language.gemm_op.gemm(A, B, C, transpose_A=False, transpose_B=False, policy=GemmWarpPolicy.Square, clear_accum=False, k_pack=1, wg_wait=0, mbar=None)

TileLang GEMM operator.

Parameters:
  • A (BufferLikeType, i.e. Buffer | BufferLoad | BufferRegion, or Var) – Input buffer A.

  • B (BufferLikeType) – Input buffer B.

  • C (BufferLikeType) – Output buffer C.

  • transpose_A (bool) – Whether to transpose A. Defaults to False.

  • transpose_B (bool) – Whether to transpose B. Defaults to False.

  • policy (GemmWarpPolicy) – GEMM warp partition policy. Defaults to GemmWarpPolicy.Square.

  • clear_accum (bool) – Whether to clear the accumulator. Defaults to False.

  • k_pack (int) – Number of packed matrix cores (ROCm only). Defaults to 1.

  • wg_wait (int) – Integer identifier of the warpgroup MMA batch to wait on. Defaults to 0.

  • mbar (BarrierType, i.e. Buffer | BufferLoad, or Var, optional) – Mbarrier to use on Blackwell. Defaults to None.

Returns:

A handle to the GEMM operation.

Return type:

tir.Call
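
The interaction of transpose_A, transpose_B, and clear_accum can be illustrated with a small host-side reference model. This is only a sketch in plain Python on nested lists, not TileLang code: the helper name gemm_reference is hypothetical, and it assumes the operator accumulates into C (C += op(A) @ op(B)) unless clear_accum is set, in which case C is overwritten. It does not model policy, k_pack, wg_wait, or mbar, which affect scheduling and hardware mapping rather than the numerical result.

```python
def gemm_reference(A, B, C, transpose_A=False, transpose_B=False, clear_accum=False):
    """Hypothetical reference model: C (+)= op(A) @ op(B), updating C in place.

    Assumes accumulate-into-C semantics; clear_accum=True discards the
    previous contents of C before the product is written.
    """

    def op(X, transposed):
        # Return X^T as a list of rows when transposed, otherwise X unchanged.
        return [list(col) for col in zip(*X)] if transposed else X

    a, b = op(A, transpose_A), op(B, transpose_B)
    m, k, n = len(a), len(b), len(b[0])
    if len(a[0]) != k or len(C) != m or len(C[0]) != n:
        raise ValueError("incompatible shapes for op(A) @ op(B) -> C")

    for i in range(m):
        for j in range(n):
            # Start from zero when clearing, otherwise from the existing value.
            acc = 0 if clear_accum else C[i][j]
            for p in range(k):
                acc += a[i][p] * b[p][j]
            C[i][j] = acc
    return C
```

A model like this can serve as a correctness oracle when checking a kernel's output on small tiles, under the accumulation assumption stated above.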