tilelang.primitives.gemm.gemm_mma module

class tilelang.primitives.gemm.gemm_mma.GemmPrimitiveMMA(A: tvm.tir.buffer.Buffer, B: tvm.tir.buffer.Buffer, C: tvm.tir.buffer.Buffer, transpose_A: bool = False, transpose_B: bool = False, block_row_warps: typing.Optional[int] = None, block_col_warps: typing.Optional[int] = None, warp_row_tiles: typing.Optional[int] = None, warp_col_tiles: typing.Optional[int] = None, chunk: typing.Optional[int] = None, policy: tilelang.primitives.gemm.base.GemmWarpPolicy = (<GemmWarpPolicy.Square: 0>, ), k_pack: int = 1)

Bases: GemmBaseParams

A GEMM (General Matrix Multiply) primitive that uses Tensor Core MMA (Matrix Multiply-Accumulate) instructions. It inherits from GemmBaseParams, which provides the basic parameters such as the A, B, and C buffers and the transposition flags.
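As a reference for the arithmetic the primitive performs, here is a plain-Python sketch. It assumes the usual GEMM convention C += op(A) · op(B), where op transposes an operand when transpose_A or transpose_B is set. The helper `reference_gemm` is hypothetical: it models only the math, not Tensor Cores, fragments, or tilelang IR.

```python
from typing import List

Matrix = List[List[float]]


def transpose(m: Matrix) -> Matrix:
    """Return the transpose of a row-major list-of-lists matrix."""
    return [list(row) for row in zip(*m)]


def reference_gemm(A: Matrix, B: Matrix, C: Matrix,
                   transpose_A: bool = False, transpose_B: bool = False) -> Matrix:
    """Hypothetical reference: C += op(A) @ op(B), updating C in place."""
    a = transpose(A) if transpose_A else A  # a is M x K
    b = transpose(B) if transpose_B else B  # b is K x N
    M, K, N = len(a), len(b), len(b[0])
    assert all(len(row) == K for row in a), "inner dimensions must match"
    for i in range(M):
        for j in range(N):
            acc = C[i][j]  # accumulate on top of the existing value in C
            for k in range(K):
                acc += a[i][k] * b[k][j]
            C[i][j] = acc
    return C


# Usage: B is stored transposed (N x K), so transpose_B=True.
A = [[1.0, 2.0], [3.0, 4.0]]    # 2 x 2
B_t = [[5.0, 7.0], [6.0, 8.0]]  # B^T, i.e. B = [[5, 6], [7, 8]]
C = [[0.0, 0.0], [0.0, 0.0]]
reference_gemm(A, B_t, C, transpose_B=True)
print(C)  # [[19.0, 22.0], [43.0, 50.0]]
```

Because C is read before it is written, the same function also models accumulation across several calls, which is how a GEMM accumulator is normally used.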

property accum_dtype: str

Returns:

The accumulation data type for C.

Return type:

str

gemm_rrr(A: Buffer, B: Buffer, C: Buffer, mma_emitter: TensorCoreIntrinEmitter) → PrimExpr

gemm_rsr(A: Buffer, B: Buffer, C: Buffer, mma_emitter: TensorCoreIntrinEmitter) → PrimExpr

gemm_srr(A: Buffer, B: Buffer, C: Buffer, mma_emitter: TensorCoreIntrinEmitter) → PrimExpr

gemm_ssr(A: Buffer, B: Buffer, C: Buffer, mma_emitter: TensorCoreIntrinEmitter) → PrimExpr

Perform a single-step reduction (SSR) GEMM using Tensor Core MMA primitives. Loads fragments of A and B from shared memory, multiplies them, and accumulates into C.

Parameters:
  • A (tir.Buffer) – The buffer for matrix A (in shared memory).

  • B (tir.Buffer) – The buffer for matrix B (in shared memory).

  • C (tir.Buffer) – The buffer for the accumulation results.

  • mma_emitter (TensorCoreIntrinEmitter) – A helper object that generates the Tensor Core instructions (ldmatrix, mma, etc.).

Returns:

The generated IR expression (macro) representing the GEMM loop.

Return type:

tir.PrimExpr
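Conceptually, the emitted loop walks the reduction dimension in small steps. At each step it loads a K-slice of A and of B into fragments (via ldmatrix in the real emitter), then issues an MMA that accumulates into C. The plain-Python model below mimics only that loop structure. `tiled_k_gemm` and `micro_k` are hypothetical names: `micro_k` stands in for the MMA instruction's K extent, and the "fragments" are ordinary list slices, not register layouts.

```python
from typing import List

Matrix = List[List[float]]


def tiled_k_gemm(A: Matrix, B: Matrix, C: Matrix, micro_k: int) -> Matrix:
    """Hypothetical model of the K loop: C += A @ B, stepping over K by micro_k.

    A is M x K, B is K x N (no transposes), and K must be divisible by micro_k.
    """
    M, K, N = len(A), len(B), len(B[0])
    assert K % micro_k == 0, "K must be a multiple of micro_k"
    for ko in range(0, K, micro_k):
        # "Load" fragments: the K-slice of A (M x micro_k) and of B (micro_k x N).
        a_frag = [row[ko:ko + micro_k] for row in A]
        b_frag = B[ko:ko + micro_k]
        # "MMA": multiply the fragments and accumulate the partial products into C.
        for i in range(M):
            for j in range(N):
                C[i][j] += sum(a_frag[i][kk] * b_frag[kk][j] for kk in range(micro_k))
    return C


A = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]      # 2 x 4
B = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]  # 4 x 2
C = [[0.0, 0.0], [0.0, 0.0]]
print(tiled_k_gemm(A, B, C, micro_k=2))  # [[4.0, 6.0], [12.0, 14.0]]
```

The result does not depend on the step size. Splitting K only changes how the partial products are grouped before they reach C, which is why the hardware can consume K in fixed-size chunks.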

property in_dtype: str

Assumes that A and B have the same dtype.

Returns:

The input data type for A and B.

Return type:

str

Raises:

AssertionError – If A and B do not share the same dtype.
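Both dtype properties can be modeled with a small stand-in class. This sketch rests on two assumptions: accum_dtype simply reports C's dtype, and in_dtype reports the dtype shared by A and B, raising AssertionError when they differ, as documented above. `FakeBuffer` and `DtypeSketch` are hypothetical names, not tilelang or TVM classes.

```python
from dataclasses import dataclass


@dataclass
class FakeBuffer:
    """Hypothetical stand-in for tir.Buffer; only the dtype string matters here."""
    dtype: str


@dataclass
class DtypeSketch:
    A: FakeBuffer
    B: FakeBuffer
    C: FakeBuffer

    @property
    def in_dtype(self) -> str:
        # Documented behaviour: A and B must agree, otherwise AssertionError.
        assert self.A.dtype == self.B.dtype, "A and B must share the same dtype"
        return self.A.dtype

    @property
    def accum_dtype(self) -> str:
        # Assumption: the accumulator type is whatever C was allocated with.
        return self.C.dtype


ok = DtypeSketch(FakeBuffer("float16"), FakeBuffer("float16"), FakeBuffer("float32"))
print(ok.in_dtype, ok.accum_dtype)  # float16 float32

bad = DtypeSketch(FakeBuffer("float16"), FakeBuffer("bfloat16"), FakeBuffer("float32"))
try:
    _ = bad.in_dtype
except AssertionError as err:
    print("rejected:", err)  # rejected: A and B must share the same dtype
```

A typical configuration pairs float16 inputs with a float32 accumulator. This keeps the running sum from losing precision over long K loops.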

invoke() → PrimExpr

Entry point that generates an SSR (single-step reduction) GEMM with Tensor Core instructions. It performs the following steps:

  1. Infers block partition parameters if necessary.

  2. Creates a TensorCoreIntrinEmitter with the correct data types and dimensions.

  3. Invokes the GEMM SSR function to generate the final IR expression.

Returns:

The generated GEMM IR expression.

Return type:

tir.PrimExpr
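Step 1 (inferring the block partition) can be pictured as splitting the block's warps over the M x N output tile. The sketch below is hypothetical. `infer_partition` is not a tilelang function, and its "closest to square" heuristic only loosely echoes the idea behind GemmWarpPolicy.Square; the real policy logic may differ. It derives block_row_warps and block_col_warps from a thread count, assuming 32 threads per warp. It then computes warp_row_tiles and warp_col_tiles as each warp's share of M and N.

```python
from typing import Tuple

WARP_SIZE = 32  # threads per warp on NVIDIA GPUs


def infer_partition(M: int, N: int, num_threads: int) -> Tuple[int, int, int, int]:
    """Hypothetical: return (block_row_warps, block_col_warps, warp_row_tiles, warp_col_tiles)."""
    assert num_threads % WARP_SIZE == 0, "thread count must be a whole number of warps"
    num_warps = num_threads // WARP_SIZE
    best = None
    for row_warps in range(1, num_warps + 1):
        if num_warps % row_warps:
            continue  # row_warps * col_warps must use every warp exactly
        col_warps = num_warps // row_warps
        if M % row_warps or N % col_warps:
            continue  # each warp must get an equal, whole slice of the tile
        # Prefer the split whose per-warp tile is closest to square.
        imbalance = abs(M // row_warps - N // col_warps)
        if best is None or imbalance < best[0]:
            best = (imbalance, row_warps, col_warps)
    if best is None:
        raise ValueError("no even warp partition for this tile shape")
    _, row_warps, col_warps = best
    return row_warps, col_warps, M // row_warps, N // col_warps


print(infer_partition(128, 128, 128))  # (2, 2, 64, 64)
print(infer_partition(128, 64, 128))   # (2, 2, 64, 32)
```

The invariant is block_row_warps * warp_row_tiles == M and block_col_warps * warp_col_tiles == N. This is why the corresponding constructor arguments can be left as None and filled in at invoke time.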