tilelang.primitives.gemm.gemm_mma
=================================

.. py:module:: tilelang.primitives.gemm.gemm_mma

Classes
-------

.. autoapisummary::

   tilelang.primitives.gemm.gemm_mma.GemmPrimitiveMMA

Module Contents
---------------

.. py:class:: GemmPrimitiveMMA

   Bases: :py:obj:`tilelang.primitives.gemm.base.GemmBaseParams`

   A GEMM (General Matrix Multiply) primitive that uses Tensor Core MMA
   (Matrix Multiply and Accumulate) instructions. Inherits from GemmBaseParams
   which provides basic parameters such as A, B, C buffers and transposition flags.

   .. py:method:: gemm_rrr(A, B, C, mma_emitter)
      :abstractmethod:

   .. py:method:: gemm_rsr(A, B, C, mma_emitter)

   .. py:method:: gemm_srr(A, B, C, mma_emitter)
      :abstractmethod:

   .. py:method:: gemm_ssr(A, B, C, mma_emitter)

      Perform a single-step reduction (SSR) GEMM using Tensor Core MMA
      primitives. Loads fragments of A and B from shared memory, multiplies
      them, and accumulates into C.

      :param A: The buffer for matrix A (in shared memory).
      :type A: tir.Buffer
      :param B: The buffer for matrix B (in shared memory).
      :type B: tir.Buffer
      :param C: The buffer for the accumulation results.
      :type C: tir.Buffer
      :param mma_emitter: A helper object responsible for generating Tensor Core MMA
                          instructions (ldmatrix, mma, etc.).
      :type mma_emitter: TensorCoreIntrinEmitter

      :returns: The generated IR expression (macro) representing the GEMM loop.
      :rtype: tir.PrimExpr

   .. py:method:: invoke()

      Entry point to generate a GEMM SSR (single-step reduction) with Tensor
      Core instructions. Performs the following steps:

      1. Infers block partition parameters if necessary.
      2. Creates a `TensorCoreIntrinEmitter` with the correct data types and dimensions.
      3. Invokes the GEMM SSR function to generate the final IR expression.

      :returns: The generated GEMM IR expression.
      :rtype: tir.PrimExpr

   .. py:property:: in_dtype
      :type: str

      :returns: The input data type for A and B. Assumes both have the same dtype.
      :rtype: str

      :raises AssertionError: If A and B do not share the same dtype.

   .. py:property:: accum_dtype
      :type: str

      :returns: The accumulation data type for C.
      :rtype: str
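
The ``gemm_ssr`` description (load fragments of A and B, multiply them, accumulate into C) can be illustrated with a small reference computation. The sketch below is plain Python and does not use tilelang or Tensor Cores. The function name ``blocked_gemm_reference`` and the ``block_k`` parameter are hypothetical and chosen only for illustration. It assumes row-major, non-transposed inputs and shows only the arithmetic that the generated loop is expected to reproduce, not how the MMA instructions are emitted.

```python
def blocked_gemm_reference(A, B, C, block_k=2):
    """Accumulate A @ B into C in place, one K-block at a time.

    Hypothetical pure-Python reference; not part of tilelang.
    A is M x K, B is K x N, C is M x N (lists of lists).
    """
    M, K = len(A), len(A[0])
    assert len(B) == K, "inner dimensions of A and B must match"
    N = len(B[0])
    assert len(C) == M and len(C[0]) == N, "C must be M x N"

    for k0 in range(0, K, block_k):
        k1 = min(k0 + block_k, K)
        # "Load" one fragment of A (M x kb) and one fragment of B (kb x N).
        a_frag = [row[k0:k1] for row in A]
        b_frag = B[k0:k1]
        # Multiply the two fragments and accumulate the result into C.
        for i in range(M):
            for j in range(N):
                C[i][j] += sum(
                    a_frag[i][k] * b_frag[k][j] for k in range(k1 - k0)
                )
    return C


A = [[1, 2, 3], [4, 5, 6]]
B = [[7, 8], [9, 10], [11, 12]]
C = [[0, 0], [0, 0]]
blocked_gemm_reference(A, B, C)
print(C)
```

Because the function adds into ``C`` rather than overwriting it, calling it a second time on the same ``C`` doubles the stored values. This accumulate behaviour is the reason the sketch requires ``C`` to be initialized before the first call.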