tilelang.primitives.gemm.gemm_mma module
- class tilelang.primitives.gemm.gemm_mma.GemmPrimitiveMMA(A: tvm.tir.buffer.Buffer, B: tvm.tir.buffer.Buffer, C: tvm.tir.buffer.Buffer, transpose_A: bool = False, transpose_B: bool = False, block_row_warps: typing.Optional[int] = None, block_col_warps: typing.Optional[int] = None, warp_row_tiles: typing.Optional[int] = None, warp_col_tiles: typing.Optional[int] = None, chunk: typing.Optional[int] = None, policy: tilelang.primitives.gemm.base.GemmWarpPolicy = (<GemmWarpPolicy.Square: 0>,), k_pack: int = 1)
Bases: GemmBaseParams
A GEMM (General Matrix Multiply) primitive that uses Tensor Core MMA (Matrix Multiply and Accumulate) instructions. Inherits from GemmBaseParams, which provides the basic parameters: the A, B, and C buffers and the transposition flags.
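As a point of reference for what the primitive is meant to compute, the following is a minimal pure-Python sketch of a GEMM with transposition flags. It does not use tilelang or Tensor Cores; the function name `reference_gemm` and the nested-list representation are assumptions made for illustration, and accumulating into C (rather than overwriting it) is assumed from the "Multiply and Accumulate" wording.

```python
def reference_gemm(A, B, C, transpose_A=False, transpose_B=False):
    """Accumulate op(A) @ op(B) into C in place (hypothetical reference helper).

    op(X) is X transposed when the matching flag is set, otherwise X itself.
    Matrices are plain nested lists; this is not tilelang code.
    """
    def transpose(M):
        return [list(row) for row in zip(*M)]

    A_op = transpose(A) if transpose_A else A
    B_op = transpose(B) if transpose_B else B

    m, k, n = len(A_op), len(B_op), len(B_op[0])
    assert all(len(row) == k for row in A_op), "inner dimensions must match"

    for i in range(m):
        for j in range(n):
            acc = C[i][j]  # start from the existing accumulator value
            for p in range(k):
                acc += A_op[i][p] * B_op[p][j]
            C[i][j] = acc
    return C


A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]

plain = reference_gemm(A, B, [[0, 0], [0, 0]])
# plain == [[19, 22], [43, 50]]

with_bt = reference_gemm(A, B, [[0, 0], [0, 0]], transpose_B=True)
# with_bt == [[17, 23], [39, 53]]
```

A reference like this can be handy for checking the output of a generated kernel on small inputs.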
- property accum_dtype: str
- Returns:
The accumulation data type for C.
- Return type:
str
- gemm_rrr(A: Buffer, B: Buffer, C: Buffer, mma_emitter: TensorCoreIntrinEmitter) → PrimExpr
- gemm_rsr(A: Buffer, B: Buffer, C: Buffer, mma_emitter: TensorCoreIntrinEmitter) → PrimExpr
- gemm_srr(A: Buffer, B: Buffer, C: Buffer, mma_emitter: TensorCoreIntrinEmitter) → PrimExpr
- gemm_ssr(A: Buffer, B: Buffer, C: Buffer, mma_emitter: TensorCoreIntrinEmitter) → PrimExpr
Perform a single-step reduction (SSR) GEMM using Tensor Core MMA primitives. Loads fragments of A and B from shared memory, multiplies them, and accumulates into C.
- Parameters:
A (tir.Buffer) – The buffer for matrix A (in shared memory).
B (tir.Buffer) – The buffer for matrix B (in shared memory).
C (tir.Buffer) – The buffer for the accumulation results.
mma_emitter (TensorCoreIntrinEmitter) – A helper object responsible for generating Tensor Core MMA instructions (ldmatrix, mma, etc.).
- Returns:
The generated IR expression (macro) representing the GEMM loop.
- Return type:
tir.PrimExpr
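The load, multiply, accumulate pattern described for gemm_ssr can be sketched in plain Python as a loop over chunks of the reduction (K) dimension. This is only an illustration of the loop shape, not the IR that tilelang generates: the function name `chunked_gemm`, the `chunk` argument, and the use of list slices to stand in for fragment loads are all assumptions, and no Tensor Core instructions are involved.

```python
def chunked_gemm(A, B, C, chunk=2):
    """Accumulate A @ B into C, walking the K dimension in chunks (hypothetical sketch).

    Each iteration copies a slice of A and B (standing in for a fragment load),
    multiplies the slices, and adds the partial product into C.
    """
    m, k, n = len(A), len(B), len(B[0])
    for k0 in range(0, k, chunk):
        k1 = min(k0 + chunk, k)  # the last chunk may be shorter
        # "Load" fragments: columns k0..k1 of A and rows k0..k1 of B.
        a_frag = [row[k0:k1] for row in A]
        b_frag = B[k0:k1]
        # Multiply the fragments and accumulate into C.
        for i in range(m):
            for j in range(n):
                for p in range(k1 - k0):
                    C[i][j] += a_frag[i][p] * b_frag[p][j]
    return C


A = [[1, 2, 3, 4], [5, 6, 7, 8]]
B = [[1, 0], [0, 1], [1, 0], [0, 1]]
result = chunked_gemm(A, B, [[0, 0], [0, 0]], chunk=2)
# result == [[4, 6], [12, 14]]
```

The chunk size changes only how the reduction is split, not the result, so any positive `chunk` should give the same C.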
- property in_dtype: str
- Returns:
The input data type for A and B. Assumes both have the same dtype.
- Return type:
str
- Raises:
AssertionError – If A and B do not share the same dtype.
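The behaviour of the two dtype properties can be illustrated with a small stand-alone sketch. The classes `FakeBuffer` and `DtypeView` are hypothetical stand-ins invented for this example (they are not tvm or tilelang types), and the dtype strings are arbitrary; the sketch only mirrors what the documentation states: in_dtype asserts that A and B agree, and accum_dtype reports the dtype of C.

```python
from dataclasses import dataclass


@dataclass
class FakeBuffer:
    """Hypothetical stand-in for a buffer; only carries a dtype string."""
    dtype: str


@dataclass
class DtypeView:
    """Hypothetical holder for the three GEMM operands."""
    A: FakeBuffer
    B: FakeBuffer
    C: FakeBuffer

    @property
    def in_dtype(self) -> str:
        # Mirrors the documented contract: A and B must share a dtype.
        assert self.A.dtype == self.B.dtype, "A and B must share the same dtype"
        return self.A.dtype

    @property
    def accum_dtype(self) -> str:
        # The accumulation dtype is taken from C.
        return self.C.dtype


view = DtypeView(FakeBuffer("float16"), FakeBuffer("float16"), FakeBuffer("float32"))
print(view.in_dtype, view.accum_dtype)
```

Reading `in_dtype` on a view whose A and B dtypes differ raises AssertionError in this sketch, matching the documented Raises entry.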
- invoke() → PrimExpr
Entry point to generate a GEMM SSR (single-step reduction) with Tensor Core instructions. Performs the following steps:
1. Infers block partition parameters if necessary.
2. Creates a TensorCoreIntrinEmitter with the correct data types and dimensions.
3. Invokes the GEMM SSR function to generate the final IR expression.
- Returns:
The generated GEMM IR expression.
- Return type:
tir.PrimExpr
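The three steps listed for invoke() can be sketched as a plain-Python control flow. Everything in this block is hypothetical: `SketchParams`, `SketchEmitter`, `infer_partition`, `sketch_invoke`, and the placeholder fallback values are invented for illustration and are not tilelang's classes, inference rules, or defaults. The sketch only shows the ordering: fill in unset partition parameters, build an emitter from the resolved values, then hand the emitter to the GEMM routine.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class SketchParams:
    """Hypothetical parameter bundle; None means "not specified by the caller"."""
    in_dtype: str
    accum_dtype: str
    block_row_warps: Optional[int] = None
    block_col_warps: Optional[int] = None
    chunk: Optional[int] = None


# Placeholder fallbacks for this sketch only; NOT the values tilelang infers.
PLACEHOLDER_DEFAULTS = {"block_row_warps": 2, "block_col_warps": 2, "chunk": 32}


@dataclass(frozen=True)
class SketchEmitter:
    """Hypothetical stand-in for an intrinsic emitter; it only records its config."""
    a_dtype: str
    b_dtype: str
    accum_dtype: str
    block_row_warps: int
    block_col_warps: int
    chunk: int


def infer_partition(params: SketchParams) -> SketchParams:
    """Step 1: fill in any partition parameter the caller left as None."""
    missing = {
        name: value
        for name, value in PLACEHOLDER_DEFAULTS.items()
        if getattr(params, name) is None
    }
    return replace(params, **missing)


def sketch_invoke(params: SketchParams):
    resolved = infer_partition(params)          # step 1
    emitter = SketchEmitter(                    # step 2
        a_dtype=resolved.in_dtype,
        b_dtype=resolved.in_dtype,
        accum_dtype=resolved.accum_dtype,
        block_row_warps=resolved.block_row_warps,
        block_col_warps=resolved.block_col_warps,
        chunk=resolved.chunk,
    )
    # Step 3: a real implementation would return generated IR here;
    # the sketch returns a tag plus the emitter so the flow can be inspected.
    return ("gemm_ssr", emitter)


tag, emitter = sketch_invoke(SketchParams("float16", "float32", chunk=64))
print(tag, emitter)
```

Values passed explicitly (here `chunk=64`) are kept as given, while unset ones fall back to the sketch's placeholders.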