tilelang.primitives.gemm.gemm_mma
Classes
GemmPrimitiveMMA: A GEMM (General Matrix Multiply) primitive that uses Tensor Core MMA (Matrix Multiply and Accumulate) instructions.
Module Contents
- class tilelang.primitives.gemm.gemm_mma.GemmPrimitiveMMA
Bases: tilelang.primitives.gemm.base.GemmBaseParams
A GEMM (General Matrix Multiply) primitive that uses Tensor Core MMA (Matrix Multiply and Accumulate) instructions. Inherits from GemmBaseParams which provides basic parameters such as A, B, C buffers and transposition flags.
- abstract gemm_rrr(A, B, C, mma_emitter)
- Parameters:
A (tvm.tir.Buffer)
B (tvm.tir.Buffer)
C (tvm.tir.Buffer)
mma_emitter (tilelang.intrinsics.mma_macro_generator.TensorCoreIntrinEmitter)
- Return type:
tvm.tir.PrimExpr
- gemm_rsr(A, B, C, mma_emitter)
- Parameters:
A (tvm.tir.Buffer)
B (tvm.tir.Buffer)
C (tvm.tir.Buffer)
mma_emitter (tilelang.intrinsics.mma_macro_generator.TensorCoreIntrinEmitter)
- Return type:
tvm.tir.PrimExpr
- abstract gemm_srr(A, B, C, mma_emitter)
- Parameters:
A (tvm.tir.Buffer)
B (tvm.tir.Buffer)
C (tvm.tir.Buffer)
mma_emitter (tilelang.intrinsics.mma_macro_generator.TensorCoreIntrinEmitter)
- Return type:
tvm.tir.PrimExpr
- gemm_ssr(A, B, C, mma_emitter)
Perform a single-step reduction (SSR) GEMM using Tensor Core MMA primitives. Loads fragments of A and B from shared memory, multiplies them, and accumulates into C.
- Parameters:
A (tir.Buffer) – The buffer for matrix A (in shared memory).
B (tir.Buffer) – The buffer for matrix B (in shared memory).
C (tir.Buffer) – The buffer for the accumulation results.
mma_emitter (TensorCoreIntrinEmitter) – A helper object responsible for generating Tensor Core MMA instructions (ldmatrix, mma, etc.).
- Returns:
The generated IR expression (macro) representing the GEMM loop.
- Return type:
tir.PrimExpr
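The load, multiply, and accumulate behaviour described for gemm_ssr can be checked against a plain numerical model. The following is a minimal sketch in pure Python, not the IR that tilelang emits: the function name `gemm_accumulate` and the `k_step` parameter are hypothetical, and the chunked reduction loop only loosely mirrors how a Tensor Core kernel consumes the K dimension one fragment at a time.

```python
def gemm_accumulate(A, B, C, k_step=2):
    """Accumulate A @ B into C in place, k_step reduction elements at a time.

    Hypothetical reference model: A is M x K, B is K x N, C is M x N,
    all given as lists of rows.
    """
    M, K, N = len(A), len(A[0]), len(B[0])
    assert len(B) == K, "inner dimensions of A and B must match"
    assert len(C) == M and len(C[0]) == N, "C must be M x N"
    for k0 in range(0, K, k_step):  # one chunk of the reduction axis per step
        for i in range(M):
            for j in range(N):
                acc = C[i][j]  # start from the existing accumulator value
                for k in range(k0, min(k0 + k_step, K)):
                    acc += A[i][k] * B[k][j]
                C[i][j] = acc
    return C


A = [[1, 2, 3], [4, 5, 6]]
B = [[1, 0], [0, 1], [1, 1]]
C = [[10, 10], [10, 10]]  # non-zero start shows that C is accumulated into
gemm_accumulate(A, B, C)
print(C)  # [[14, 15], [20, 21]]
```

Note that C is updated rather than overwritten, so a caller that wants a plain product must zero C first.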
- invoke()
Entry point to generate a GEMM SSR (single-step reduction) with Tensor Core instructions. Performs the following steps:
1. Infers block partition parameters if necessary.
2. Creates a TensorCoreIntrinEmitter with the correct data types and dimensions.
3. Invokes the GEMM SSR function to generate the final IR expression.
- Returns:
The generated GEMM IR expression.
- Return type:
tir.PrimExpr
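The three steps listed for invoke() can be sketched with stand-in classes. Everything here is hypothetical and exists only to show the control flow: `EmitterSketch` and `GemmSketch` are not tilelang classes, the "all warps along rows" partition rule is a placeholder rather than the library's actual policy, and a string is returned where the real method returns a tir.PrimExpr.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class EmitterSketch:
    """Hypothetical stand-in for the emitter; it only records its configuration."""
    a_dtype: str
    b_dtype: str
    accum_dtype: str
    block_row_warps: int
    block_col_warps: int


@dataclass
class GemmSketch:
    """Hypothetical stand-in for the primitive; not the tilelang class."""
    in_dtype: str
    accum_dtype: str
    block_row_warps: Optional[int] = None
    block_col_warps: Optional[int] = None

    def infer_block_partition(self, num_warps: int) -> None:
        # Step 1: fill in the partition only if the caller left it unset.
        # Placeholder policy (an assumption): put every warp along the rows.
        if self.block_row_warps is None or self.block_col_warps is None:
            self.block_row_warps, self.block_col_warps = num_warps, 1

    def gemm_ssr(self, emitter: EmitterSketch) -> str:
        # Step 3: stand-in for IR generation; returns a description string.
        return (f"gemm_ssr({emitter.a_dtype}->{emitter.accum_dtype}, "
                f"warps={emitter.block_row_warps}x{emitter.block_col_warps})")

    def invoke(self, num_warps: int = 4) -> str:
        self.infer_block_partition(num_warps)
        # Step 2: build the emitter from the dtypes and the partition.
        emitter = EmitterSketch(self.in_dtype, self.in_dtype, self.accum_dtype,
                                self.block_row_warps, self.block_col_warps)
        return self.gemm_ssr(emitter)


print(GemmSketch("float16", "float32").invoke())        # partition inferred
print(GemmSketch("float16", "float32", 2, 2).invoke())  # partition kept as given
```

The second call shows the "if necessary" part of step 1: a partition supplied by the caller is left untouched.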
- property in_dtype: str
The input data type for A and B. Assumes both have the same dtype.
- Raises:
AssertionError – If A and B do not share the same dtype.
- Return type:
str
- property accum_dtype: str
The accumulation data type for C.
- Return type:
str
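The contract of the two dtype properties above can be illustrated with a small model. This is a sketch under stated assumptions, not the library's implementation: `BufferSketch` and `DtypeSketch` are hypothetical stand-ins, and the only behaviour taken from the documentation is that in_dtype raises AssertionError when A and B disagree and that accum_dtype reflects C.

```python
from dataclasses import dataclass


@dataclass
class BufferSketch:
    """Hypothetical stand-in for a buffer; only the dtype matters here."""
    dtype: str


@dataclass
class DtypeSketch:
    """Hypothetical holder for the A, B, and C buffers."""
    A: BufferSketch
    B: BufferSketch
    C: BufferSketch

    @property
    def in_dtype(self) -> str:
        # A and B must agree, otherwise there is no single input dtype.
        assert self.A.dtype == self.B.dtype, "A and B must share a dtype"
        return self.A.dtype

    @property
    def accum_dtype(self) -> str:
        # The accumulator type is taken from C.
        return self.C.dtype


params = DtypeSketch(BufferSketch("float16"), BufferSketch("float16"),
                     BufferSketch("float32"))
print(params.in_dtype, params.accum_dtype)  # float16 float32

mixed = DtypeSketch(BufferSketch("float16"), BufferSketch("int8"),
                    BufferSketch("float32"))
try:
    mixed.in_dtype
except AssertionError as err:
    print("rejected:", err)  # rejected: A and B must share a dtype
```

In this sketch the check is a plain assert statement, so it would be skipped when Python runs with the -O flag; whether that matters for the real property depends on how the library implements the check.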