tilelang.cuda.op.gemm.gemm_wgmma
================================

.. py:module:: tilelang.cuda.op.gemm.gemm_wgmma


Attributes
----------

.. autoapisummary::

   tilelang.cuda.op.gemm.gemm_wgmma.GEMM_INST_WGMMA


Classes
-------

.. autoapisummary::

   tilelang.cuda.op.gemm.gemm_wgmma.GemmWGMMA


Module Contents
---------------

.. py:data:: GEMM_INST_WGMMA
   :value: 'cuda.wgmma'


.. py:class:: GemmWGMMA

   Bases: :py:obj:`tilelang.tileop.gemm.gemm_base.GemmBase`


   GEMM tile operator lowered to warpgroup matrix multiply-accumulate (WGMMA) instructions.

   Inherits from ``GemmBase``, the base class for GEMM tile operators, which
   classifies the GEMM variant by the memory scopes of operands A and B
   (SS, SR, RS, TS, RR) and provides common property accessors for the
   underlying ``gemm_node`` IR node.


   .. py:method:: infer_shared_layout(continuity)

      Infer the shared-memory swizzle layout from the continuity of the
      contiguous dimension.

      WGMMA can read its operands directly from shared memory, so the swizzle
      layout must match the tensor core's access pattern. The swizzle
      granularity is determined by the size of the contiguous dimension:

      - 128B swizzle (Full): ``continuity % (vectorized_size * 8) == 0``
      - 64B swizzle (Half): ``continuity % (vectorized_size * 4) == 0``
      - 32B swizzle (Quarter): ``continuity % (vectorized_size * 2) == 0``
      - Linear (no swizzle): otherwise

      See: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html


   .. py:method:: infer_layout(target, thread_nums)


   .. py:method:: lower(layout_map, target, thread_bounds, thread_var, mbar_phase_expr = None)


   .. py:method:: is_gemm_ss()

      Return ``True`` if both A and B are in shared memory (SS variant).


   .. py:method:: is_gemm_sr()

      Return ``True`` if A is in shared memory and B is in registers (SR variant).


   .. py:method:: is_gemm_rs()

      Return ``True`` if A is in registers and B is in shared memory (RS variant).


   .. py:method:: is_gemm_rr()

      Return ``True`` if both A and B are in registers (RR variant).
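
The swizzle selection rule documented for ``GemmWGMMA.infer_shared_layout`` can be sketched in plain Python as below. This is an illustrative sketch, not TileLang's implementation: ``select_swizzle`` and ``SwizzleMode`` are hypothetical names, the conditions are assumed to be tested from widest to narrowest, and ``vectorized_size`` is assumed to be the number of elements in one 128-bit (16-byte) vector access, for example 8 for ``float16``.

.. code-block:: python

   from enum import Enum


   class SwizzleMode(Enum):
       """Hypothetical labels for the swizzle granularities in the docstring."""

       FULL = "128B"
       HALF = "64B"
       QUARTER = "32B"
       LINEAR = "none"


   def select_swizzle(continuity: int, vectorized_size: int) -> SwizzleMode:
       """Pick the widest swizzle whose granularity divides ``continuity``.

       Assumptions: ``continuity`` is the length (in elements) of the
       contiguous dimension, and ``vectorized_size`` is the number of
       elements per 16-byte vector access.
       """
       # Test from widest to narrowest so the first match is the largest swizzle.
       if continuity % (vectorized_size * 8) == 0:
           return SwizzleMode.FULL
       if continuity % (vectorized_size * 4) == 0:
           return SwizzleMode.HALF
       if continuity % (vectorized_size * 2) == 0:
           return SwizzleMode.QUARTER
       # No swizzle granularity divides the contiguous dimension evenly.
       return SwizzleMode.LINEAR


   if __name__ == "__main__":
       # float16 (assumed): 16-byte vector / 2 bytes per element = 8 elements.
       for continuity in (64, 32, 16, 8):
           print(continuity, select_swizzle(continuity, vectorized_size=8).value)

Under these assumptions, a ``float16`` contiguous dimension of 64 elements (128 bytes) selects the 128B swizzle, 32 elements select 64B, 16 elements select 32B, and 8 elements fall back to a linear layout.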
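
The ``is_gemm_ss``, ``is_gemm_sr``, ``is_gemm_rs`` and ``is_gemm_rr`` predicates of ``GemmWGMMA`` amount to reading a two-letter code built from the memory scopes of A and B. The sketch below models that classification with a hypothetical ``OperandLocation`` enum instead of TileLang's real buffer scopes; every name in it is illustrative, and reading ``T`` as tensor memory (for the TS variant) is an assumption.

.. code-block:: python

   from enum import Enum


   class OperandLocation(Enum):
       """Hypothetical stand-in for an operand buffer's memory scope."""

       SHARED = "S"
       REGISTER = "R"
       TENSOR_MEMORY = "T"  # assumed meaning of "T" in the TS variant


   def gemm_variant(a: OperandLocation, b: OperandLocation) -> str:
       """Return the variant code: A's letter followed by B's letter."""
       return a.value + b.value


   def is_gemm_ss(a: OperandLocation, b: OperandLocation) -> bool:
       # Both operands are in shared memory.
       return gemm_variant(a, b) == "SS"


   def is_gemm_sr(a: OperandLocation, b: OperandLocation) -> bool:
       # A in shared memory, B in registers.
       return gemm_variant(a, b) == "SR"


   def is_gemm_rs(a: OperandLocation, b: OperandLocation) -> bool:
       # A in registers, B in shared memory.
       return gemm_variant(a, b) == "RS"


   def is_gemm_rr(a: OperandLocation, b: OperandLocation) -> bool:
       # Both operands are in registers.
       return gemm_variant(a, b) == "RR"

Deriving every predicate from a single code keeps them mutually exclusive: for any pair of locations, at most one of the four returns ``True``.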