tilelang.language.gemm_op

GEMM (General Matrix Multiplication) operators exposed on the TileLang language surface.

Functions

gemm(A, B, C[, transpose_A, transpose_B, policy, ...])

TileLang GEMM operator.

wgmma_gemm(A, B, C[, transpose_A, transpose_B, ...])

Explicit Hopper WGMMA GEMM without an implicit wait.

tcgen05_gemm(A, B, C[, transpose_A, transpose_B, ...])

Explicit Blackwell TCGEN05 GEMM without an implicit wait.

tcgen05_gemm_blockscaled(A, B, C, SFA_tmem, SFB_tmem)

Explicit Blackwell TCGEN05 block-scaled GEMM without an implicit wait.

make_blockscaled_gemm_layout(C, A[, transpose_A])

Build the TMEM store layout for the C accumulator of a block-scaled GEMM.

Module Contents

tilelang.language.gemm_op.gemm(A, B, C, transpose_A=False, transpose_B=False, policy=GemmWarpPolicy.Square, clear_accum=False, k_pack=1, mbar=None)

TileLang GEMM operator.

This is the default synchronous GEMM interface. On Hopper, if the compiler selects WGMMA lowering, TileLang inserts the corresponding wait implicitly. On Blackwell, if the GEMM lowers to TCGEN5MMA, TileLang implicitly inserts the corresponding mbarrier_wait_parity(…) after the MMA is issued.

For manual asynchronous scheduling, use T.wgmma_gemm(…) with T.wait_wgmma(…) on Hopper, or T.tcgen05_gemm(…) with T.mbarrier_wait_parity(…) on Blackwell.

Parameters:
  • A (BufferLikeType, i.e. Buffer | BufferLoad | BufferRegion, or Var) – Input buffer A.

  • B (BufferLikeType) – Input buffer B.

  • C (BufferLikeType) – Output buffer C, used as the accumulator: the product is added to C unless clear_accum is True.

  • transpose_A (bool) – Whether to transpose A. Defaults to False.

  • transpose_B (bool) – Whether to transpose B. Defaults to False.

  • policy (GemmWarpPolicy) – GEMM warp partition policy. Defaults to GemmWarpPolicy.Square.

  • clear_accum (bool) – Whether to zero the accumulator before accumulating. Defaults to False.

  • k_pack (int) – Number of packed matrix cores (ROCm only). Defaults to 1.

  • mbar (BarrierType, i.e. Buffer | BufferLoad, or Var, optional) – Mbarrier used on Blackwell; required when this GEMM lowers to TCGEN5MMA. Defaults to None.

Returns:

A handle to the GEMM operation.

Return type:

tir.Call
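The arithmetic contract of T.gemm(…) can be pinned down with a pure-Python reference model, which is handy for checking a kernel's output on small tiles. This is a sketch, not TileLang code: `gemm_reference` is a hypothetical helper, and it assumes TileLang's usual operand convention (A stored as (M, K), or (K, M) when transpose_A=True; B stored as (K, N), or (N, K) when transpose_B=True) and that the product is added to C unless clear_accum=True. It does not model policy or k_pack, which govern how the work is partitioned rather than what is computed.

```python
from __future__ import annotations


def transpose(m: list[list[float]]) -> list[list[float]]:
    """Transpose a row-major list-of-lists matrix."""
    return [list(row) for row in zip(*m)]


def gemm_reference(
    A: list[list[float]],
    B: list[list[float]],
    C: list[list[float]],
    transpose_A: bool = False,
    transpose_B: bool = False,
    clear_accum: bool = False,
) -> list[list[float]]:
    """Hypothetical reference for T.gemm's arithmetic (not a TileLang API).

    Returns op(A) @ op(B) added to C, or to zeros when clear_accum=True.
    Assumed layout: A is (M, K), or (K, M) if transpose_A; B is (K, N),
    or (N, K) if transpose_B.
    """
    a = transpose(A) if transpose_A else A  # logical (M, K)
    b = transpose(B) if transpose_B else B  # logical (K, N)
    M, K, N = len(a), len(b), len(b[0])
    if any(len(row) != K for row in a):
        raise ValueError("inner dimensions of A and B do not match")
    # Start from zeros (clear_accum) or from the existing accumulator C.
    out = [[0.0 if clear_accum else C[i][j] for j in range(N)] for i in range(M)]
    for i in range(M):
        for j in range(N):
            for k in range(K):
                out[i][j] += a[i][k] * b[k][j]
    return out


# Usage: a 2x2 example, accumulating into a C of ones.
A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
C = [[1.0, 1.0], [1.0, 1.0]]
print(gemm_reference(A, B, C))                    # [[20.0, 23.0], [44.0, 51.0]]
print(gemm_reference(A, B, C, clear_accum=True))  # [[19.0, 22.0], [43.0, 50.0]]
```

Storing B as (N, K) and passing transpose_B=True yields the same product, which is the pattern to mirror when B tiles arrive K-major.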

tilelang.language.gemm_op.wgmma_gemm(A, B, C, transpose_A=False, transpose_B=False, policy=GemmWarpPolicy.Square, clear_accum=False)

Explicit Hopper WGMMA GEMM without an implicit wait.

This is the explicit asynchronous Hopper WGMMA counterpart to the default synchronous T.gemm(…) interface, with two stricter guarantees:

  • It always requests the WGMMA lowering path.

  • It never auto-emits an inlined warpgroup_wait.

If the current target or operand pattern cannot use Hopper WGMMA, compilation fails instead of silently falling back to MMA.

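Parameters:
  • A (tilelang._typing.BufferLikeType) – Input buffer A.

  • B (tilelang._typing.BufferLikeType) – Input buffer B.

  • C (tilelang._typing.BufferLikeType) – Accumulator buffer C.

  • transpose_A (bool) – Whether to transpose A. Defaults to False.

  • transpose_B (bool) – Whether to transpose B. Defaults to False.

  • policy (tilelang.tileop.base.GemmWarpPolicy) – GEMM warp partition policy. Defaults to GemmWarpPolicy.Square.

  • clear_accum (bool) – Whether to zero the accumulator before accumulating. Defaults to False.
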
Return type:

tvm.tir.PrimExpr
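The point of the explicit form is that issue and wait are decoupled, so a software-pipelined loop can keep one GEMM in flight while the next tile is being loaded. The toy model below sketches that bookkeeping in plain Python. It assumes T.wait_wgmma(n) follows the PTX wgmma.wait_group n rule (return once at most n of the most recently committed groups are still pending, with older groups completing first); `WgmmaQueueModel` is hypothetical and performs no GPU work.

```python
from collections import deque


class WgmmaQueueModel:
    """Toy model of WGMMA group bookkeeping (hypothetical, not a TileLang API)."""

    def __init__(self) -> None:
        self.pending = deque()  # issued but unfinished groups, oldest first
        self.completed = []     # groups in completion order

    def issue(self, name: str) -> None:
        # Explicit-async issue: the group is queued and nothing is awaited.
        self.pending.append(name)

    def wait(self, n: int) -> None:
        # Assumed wgmma.wait_group n rule: return once at most n groups
        # remain pending; the oldest groups complete first.
        while len(self.pending) > n:
            self.completed.append(self.pending.popleft())


# Usage: a pipelined K loop that keeps one GEMM in flight.
q = WgmmaQueueModel()
for k in range(3):
    q.issue(f"gemm_{k}")
    q.wait(1)  # every group older than gemm_k is done; its operand stage is reusable
q.wait(0)      # drain before the accumulator is read
print(q.completed)  # ['gemm_0', 'gemm_1', 'gemm_2']
```

With the default T.gemm(…), the equivalent of wait(0) follows every issue, which is simpler but gives up this overlap.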

tilelang.language.gemm_op.tcgen05_gemm(A, B, C, transpose_A=False, transpose_B=False, policy=GemmWarpPolicy.Square, clear_accum=False, *, mbar, use_2cta=False)

Explicit Blackwell TCGEN05 GEMM without an implicit wait.

This is the explicit asynchronous Blackwell TCGEN5MMA counterpart to the default synchronous T.gemm(…) interface, with two stricter guarantees:

  • It always requests the TCGEN5MMA lowering path.

  • It never auto-emits an inlined mbarrier_wait_parity.

When use_2cta=True, the instruction is lowered to the 2CTA variant, which requires cluster_dims to be (2,1,1) or (1,2,1).

If the current target or operand pattern cannot use Blackwell TCGEN5MMA, compilation fails instead of silently falling back to another GEMM path.

Parameters:
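  • A (tilelang._typing.BufferLikeType) – Input buffer A.

  • B (tilelang._typing.BufferLikeType) – Input buffer B.

  • C (tilelang._typing.BufferLikeType) – Accumulator buffer C.

  • transpose_A (bool) – Whether to transpose A. Defaults to False.

  • transpose_B (bool) – Whether to transpose B. Defaults to False.

  • policy (tilelang.tileop.base.GemmWarpPolicy) – GEMM warp partition policy. Defaults to GemmWarpPolicy.Square.

  • clear_accum (bool) – Whether to zero the accumulator before accumulating. Defaults to False.

  • mbar (tilelang._typing.BarrierType) – Mbarrier signaled when the MMA completes; keyword-only and required. Wait on it with T.mbarrier_wait_parity(…).

  • use_2cta (bool) – Whether to lower to the 2CTA variant. Defaults to False.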

Return type:

tvm.tir.PrimExpr
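With T.tcgen05_gemm(…), completion is reported through mbar and the caller waits with T.mbarrier_wait_parity(…), so the schedule has to track which phase parity to wait on. The sketch below models that bookkeeping. It assumes one arrival completes a barrier phase and that a wait on parity p returns once the phase with that parity has completed (the PTX mbarrier.try_wait.parity convention). `MbarrierModel` and `stage_and_parity` are hypothetical helpers, and the ring-of-barriers indexing is a common pipelining pattern rather than something this API prescribes.

```python
from __future__ import annotations


class MbarrierModel:
    """Toy mbarrier expecting one arrival per phase (hypothetical helper)."""

    def __init__(self) -> None:
        self.completed_phases = 0

    def arrive(self) -> None:
        # Stands in for the MMA signaling completion on `mbar`.
        self.completed_phases += 1

    def wait_satisfied(self, parity: int) -> bool:
        # Assumed try_wait.parity rule: the wait returns once the phase with
        # this parity has completed, i.e. the current phase parity differs.
        return (self.completed_phases & 1) != parity


def stage_and_parity(k: int, num_stages: int) -> tuple[int, int]:
    """Barrier slot and wait parity for iteration k when num_stages barriers rotate."""
    return k % num_stages, (k // num_stages) & 1


# Usage: one barrier reused every iteration, so the wait parity alternates.
bar = MbarrierModel()
for k in range(4):
    parity = k & 1
    assert not bar.wait_satisfied(parity)  # MMA k issued, not yet complete
    bar.arrive()                           # MMA k completes and arrives on mbar
    assert bar.wait_satisfied(parity)      # the wait on parity k % 2 now returns

print([stage_and_parity(k, 2) for k in range(5)])
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 0)]
```

Waiting on the wrong parity either returns immediately on a stale phase or blocks forever, which is why the parity is derived from the iteration index rather than stored separately.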

tilelang.language.gemm_op.tcgen05_gemm_blockscaled(A, B, C, SFA_tmem, SFB_tmem, transpose_A=False, transpose_B=False, clear_accum=False, wg_wait=0, mbar=None, sf_a_id=0, sf_b_id=0, *, use_2cta=False)

Explicit Blackwell TCGEN05 block-scaled GEMM without an implicit wait.

This is the explicit asynchronous Blackwell TCGEN5MMA block-scaled counterpart to T.tcgen05_gemm(…). It never auto-emits an inlined mbarrier_wait_parity, and compilation fails instead of silently falling back if the requested ISA path is unavailable.

With use_2cta=True, this lowers to the true 2CTA block-scaled TCGEN05 path only; there is no fallback or emulation. That mode requires cluster_dims to be (2,1,1) or (1,2,1).

A and B are FP8 (E4M3/E5M2) in shared memory, C is the accumulator in tensor memory, and SFA/SFB are E8M0 scale factors already resident in tensor memory. As with T.tcgen05_gemm(…), this API is explicit-async: it issues the MMA and leaves synchronization to the user schedule.

Parameters:
  • A (tilelang._typing.BufferLikeType) – FP8 input buffer A in shared memory.

  • B (tilelang._typing.BufferLikeType) – FP8 input buffer B in shared memory.

  • C (tilelang._typing.BufferLikeType) – Accumulator in tensor memory.

  • SFA_tmem (tilelang._typing.BufferLikeType) – Scale factors for A in tensor memory.

  • SFB_tmem (tilelang._typing.BufferLikeType) – Scale factors for B in tensor memory.

  • transpose_A (bool) – Whether A is MN-major. Default: False (K-major).

  • transpose_B (bool) – Whether B is K-major. Default: False (MN-major).

  • clear_accum (bool) – Whether to zero the accumulator. Defaults to False.

  • wg_wait (int) – Warp group wait identifier. Defaults to 0.

  • mbar (tilelang._typing.BarrierType | None) – Mbarrier for MMA completion signaling. Defaults to None.

  • sf_a_id (int) – Scale factor ID for A (0-3). Defaults to 0.

  • sf_b_id (int) – Scale factor ID for B (0-3). Defaults to 0.

  • use_2cta (bool) – Whether to request true cta_group::2 lowering. Defaults to False.

Return type:

tvm.tir.PrimExpr
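The numerics of the block-scaled product can be written down as a reference model: each E8M0 code e decodes to 2**(e - 127), and each scale applies to one block of consecutive K elements in its row of A or column of B. This sketch assumes a block size of 32 (the OCP MX convention) but exposes it as a parameter, takes A and B as already-decoded floats rather than FP8 bit patterns, and ignores transposition, TMEM placement, sf_a_id and sf_b_id. `e8m0_to_float` and `blockscaled_gemm_reference` are hypothetical helpers, not TileLang APIs.

```python
from __future__ import annotations


def e8m0_to_float(code: int) -> float:
    """Decode an E8M0 scale code: 2**(code - 127), with 0xFF reserved for NaN."""
    if not 0 <= code <= 0xFF:
        raise ValueError("E8M0 codes are 8-bit")
    return float("nan") if code == 0xFF else 2.0 ** (code - 127)


def blockscaled_gemm_reference(
    A: list[list[float]],  # logical (M, K), already-decoded FP8 values
    B: list[list[float]],  # logical (K, N)
    C: list[list[float]],  # (M, N) accumulator
    SFA: list[list[int]],  # (M, K // block) E8M0 codes for A
    SFB: list[list[int]],  # (N, K // block) E8M0 codes for B
    block: int = 32,       # assumed scale block size along K
    clear_accum: bool = False,
) -> list[list[float]]:
    """Hypothetical reference (not a TileLang API) for a block-scaled GEMM."""
    M, K, N = len(A), len(B), len(B[0])
    if K % block:
        raise ValueError("K must be a multiple of the scale block size")
    out = [[0.0 if clear_accum else C[m][n] for n in range(N)] for m in range(M)]
    for m in range(M):
        for n in range(N):
            for k in range(K):
                # One scale per `block` consecutive K elements of each operand.
                sa = e8m0_to_float(SFA[m][k // block])
                sb = e8m0_to_float(SFB[n][k // block])
                out[m][n] += (A[m][k] * sa) * (B[k][n] * sb)
    return out


# Usage: K = 4 with a block size of 2, so each row/column carries two scales.
A = [[1.0, 1.0, 1.0, 1.0]]
B = [[1.0], [1.0], [1.0], [1.0]]
C = [[0.0]]
SFA = [[127, 128]]  # scales 1.0 and 2.0
SFB = [[127, 126]]  # scales 1.0 and 0.5
print(blockscaled_gemm_reference(A, B, C, SFA, SFB, block=2))  # [[4.0]]
```

Because E8M0 scales are pure powers of two, applying them only shifts exponents, so a reference like this stays exact for values that fit the float range.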

tilelang.language.gemm_op.make_blockscaled_gemm_layout(C, A, transpose_A=False)

Build the TMEM store layout for the C accumulator of a block-scaled GEMM.

Users must call T.annotate_layout({C_tmem: layout}) with the returned layout so that subsequent T.copy(C_tmem, ...) can be lowered correctly.

Parameters:
  • C (tilelang._typing.BufferLikeType) – The TMEM accumulator buffer, of shape (block_M, block_N).

  • A (tilelang._typing.BufferLikeType) – The FP8 operand A buffer (used to infer K and dtype).

  • transpose_A (bool) – Whether A is MN-major. Defaults to False (K-major).

Returns:

A Layout object for C’s TMEM storage.

Return type:

tilelang.layout.Layout