tilelang.contrib.cutedsl.gemm_tcgen05
=====================================

.. py:module:: tilelang.contrib.cutedsl.gemm_tcgen05

.. autoapi-nested-parse::

   tcgen05 (SM100/Blackwell) MMA support for the CuTeDSL backend.

   Provides:

   - ``Tcgen05SmemDescriptor``: 64-bit SMEM descriptor for tcgen05 MMA
   - ``initialize_tcgen05_descriptor``: bitfield packing matching the ``common.h`` layout
   - ``tcgen05mma_ss`` / ``tcgen05mma_ws_ss`` / ``tcgen05mma_ts``: MMA PTX inline asm
   - ``tcgen05_mma_arrive``: mbarrier arrive for MMA commit
   - ``tmem_allocate`` / ``tmem_deallocate``: TMEM allocation/deallocation
   - ``tcgen05_ld_32dp32bNx`` / ``tcgen05_ld_32dp64bNx`` / ``tcgen05_ld_32dp128bNx`` /
     ``tcgen05_ld_32dp256bNx``: TMEM-to-register loads



Classes
-------

.. autoapisummary::

   tilelang.contrib.cutedsl.gemm_tcgen05.Tcgen05SmemDescriptor


Functions
---------

.. autoapisummary::

   tilelang.contrib.cutedsl.gemm_tcgen05.initialize_tcgen05_descriptor
   tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05mma_ss
   tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05mma_ws_ss
   tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05mma_ts
   tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05_mma_arrive
   tilelang.contrib.cutedsl.gemm_tcgen05.tmem_allocate
   tilelang.contrib.cutedsl.gemm_tcgen05.tmem_deallocate
   tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05_ld_32dp32bNx
   tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05_ld_32dp64bNx
   tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05_ld_32dp128bNx
   tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05_ld_32dp256bNx


Module Contents
---------------

.. py:class:: Tcgen05SmemDescriptor(desc_64 = None)

   64-bit shared-memory descriptor for tcgen05 MMA (Blackwell).

   Mirrors ``tl::Tcgen05SMemDescriptor`` from ``common.h``. The value is stored
   as two Int32 registers and recast to Int64 for the PTX operand.


   .. py:attribute:: desc


   .. py:attribute:: desc_i64


   .. py:method:: __add__(offset)

      Add a byte offset. As in the C++ ``operator+``, the offset is shifted
      right by 4 (``offset >> 4``, i.e. converted to 16-byte units) before it
      is applied.


.. py:function:: initialize_tcgen05_descriptor(desc, start_address, leading_byte_offset, stride_byte_offset, base_offset, leading_abs, swizzle_mode)

   Pack the tcgen05 SMEM descriptor bitfields.
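
   As a rough model of this packing, the sketch below computes the two 32-bit
   words on plain Python integers. It is illustrative only:
   ``pack_descriptor_words`` and ``combine_words`` are hypothetical helpers,
   not part of this module; the bit positions follow the layout listed for this
   function; and treating ``reg32_[0]`` as the low half of the 64-bit operand
   is an assumption.

   .. code-block:: python

      def pack_descriptor_words(start_address, leading_byte_offset, stride_byte_offset,
                                base_offset, leading_abs, swizzle_mode):
          """Hypothetical model: return the (lo, hi) 32-bit descriptor words."""
          # Low word; 14-bit fields are masked here to model their width
          lo = (start_address >> 4) & 0x3FFF            # [0:14)  start address in 16-byte units
          lo |= (leading_byte_offset & 0x3FFF) << 16    # [16:30) LBO, already >> 4
          # High word
          hi = stride_byte_offset & 0x3FFF              # [0:14)  SBO, already >> 4
          hi |= 1 << 14                                 # [14:16) version = 1
          hi |= (base_offset & 0x7) << 17               # [17:20) base offset
          hi |= (1 if leading_abs else 0) << 20         # [20:21) lbo_mode
          hi |= (swizzle_mode & 0x7) << 29              # [29:32) layout type
          return lo, hi


      def combine_words(lo, hi):
          """Assumed word order: reg32_[0] is the low half of the Int64 operand."""
          return lo | (hi << 32)


      lo, hi = pack_descriptor_words(0x400, 1, 64, 0, False, 2)
      print(hex(combine_words(lo, hi)))  # 0x4000404000010040
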

   Matches the C++ ``initialize_tcgen05_descriptor`` in ``common.h``.

   Low 32 bits (``reg32_[0]``):

   - ``[0:14)``: ``start_address >> 4``
   - ``[16:30)``: ``leading_byte_offset`` (already ``>> 4`` from TIR)

   High 32 bits (``reg32_[1]``):

   - ``[0:14)``: ``stride_byte_offset`` (already ``>> 4`` from TIR)
   - ``[14:16)``: version = 1
   - ``[17:20)``: ``base_offset & 0x7``
   - ``[20:21)``: ``lbo_mode`` (1 if ``leading_abs`` is true, else 0)
   - ``[29:32)``: ``layout_type`` (``swizzle_mode & 0x7``)


.. py:function:: tcgen05mma_ss(kind_dtype, desc_a, desc_b, tmem_c, desc_val, scale_out, mask0, mask1, mask2, mask3)

   Issues the PTX instruction::

      tcgen05.mma.cta_group::1.kind::{kind} [tmem_c], desc_a, desc_b, desc_val, {masks}, p;

   Guarded by ``elect_one_sync``: only one thread in the warp issues the MMA.
   The TIR codegen also wraps calls in ``if (threadIdx.x >> 5) == 0``, which
   selects warp 0.


.. py:function:: tcgen05mma_ws_ss(kind_dtype, desc_a, desc_b, tmem_c, desc_val, scale_out)

   Issues the PTX instruction::

      tcgen05.mma.ws.cta_group::1.kind::{kind} [tmem_c], desc_a, desc_b, desc_val, p, 0;


.. py:function:: tcgen05mma_ts(kind_dtype, tmem_a, desc_b, tmem_c, desc_val, scale_out, mask0, mask1, mask2, mask3)

   Issues the PTX instruction::

      tcgen05.mma.cta_group::1.kind::{kind} [tmem_c], [tmem_a], desc_b, desc_val, {masks}, p;


.. py:function:: tcgen05_mma_arrive(mbar_ptr)

   Issues the PTX instruction::

      tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [mbar];

   Guarded by ``elect_one_sync``: only one thread in the warp issues the commit.


.. py:function:: tmem_allocate(tmem_buffer_ptr, num_cols)

   Issues the PTX instruction::

      tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [dst], num_cols;

   - ``tmem_buffer_ptr``: SMEM pointer that receives the allocated TMEM address.
   - ``num_cols``: number of columns to allocate.


.. py:function:: tmem_deallocate(tmem_ptr, num_cols)

   Issues the PTX instruction::

      tcgen05.dealloc.cta_group::1.sync.aligned.b32 tmem_addr, num_cols;

   - ``tmem_ptr``: SMEM pointer to the uint32 holding the TMEM address.
   - ``num_cols``: number of columns to deallocate.


.. py:function:: tcgen05_ld_32dp32bNx(N, pack16, tmem_start_col, tmem_col_offset, dst_ptr)

   Load ``N`` uint32 values from TMEM using ``tcgen05.ld.sync.aligned.32x32b``.
   Matches ``tl::tcgen05_ld_32dp32bNx`` from ``copy_sm100.h``.
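
   For illustration, the inline-asm text for a given ``N`` could be assembled
   as below. This is a sketch, not this module's code: ``format_ld_32x32b`` is
   a hypothetical helper, and it assumes the PTX ISA form
   ``tcgen05.ld.sync.aligned.32x32b.xN.b32`` with ``N`` a power of two from 1
   to 128, one 32-bit destination register per repetition, and the TMEM
   address passed as the operand after the outputs.

   .. code-block:: python

      def format_ld_32x32b(n: int) -> str:
          """Hypothetical helper: PTX text for a 32x32b.x<n> TMEM load."""
          if n not in {1 << i for i in range(8)}:  # .x1, .x2, ..., .x128
              raise ValueError(f"unsupported x-count: {n}")
          # Output registers come first in the inline-asm operand list: %0 .. %(n-1)
          regs = ", ".join(f"%{i}" for i in range(n))
          # The TMEM address input follows the outputs, so it is operand %n
          return f"tcgen05.ld.sync.aligned.32x32b.x{n}.b32 {{{regs}}}, [%{n}];"


      print(format_ld_32x32b(2))
      # tcgen05.ld.sync.aligned.32x32b.x2.b32 {%0, %1}, [%2];
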

   - ``N``: number of 32-bit elements to load (the x-count; a compile-time constant).
   - ``pack16``: if True, use 16-bit packing (not implemented yet).
   - ``tmem_start_col``: TMEM base column address.
   - ``tmem_col_offset``: additional column offset.
   - ``dst_ptr``: destination pointer (register memory).


.. py:function:: tcgen05_ld_32dp64bNx(N, pack16, tmem_start_col, tmem_col_offset, dst_ptr)

   Load from TMEM using the 32dp64b pattern: two ``16x64b`` loads, one for the
   lower 16 rows and one for the upper 16 rows.
   Matches ``tl::tmem_ld_32dp64bNx`` from ``tcgen_05_ld.h``.

   - ``N``: x-count for the ``16x64b`` instructions. ``16x64b.xN`` produces
     ``N`` i32 registers per half, so the total output is ``2*N`` i32 registers.


.. py:function:: tcgen05_ld_32dp128bNx(N, pack16, tmem_start_col, tmem_col_offset, dst_ptr)

   Load from TMEM using the 32dp128b pattern: two ``16x128b`` loads, one for the
   lower 16 rows and one for the upper 16 rows.
   Matches ``tl::tmem_ld_32dp128bNx`` from ``tcgen_05_ld.h``.

   - ``N``: x-count for the ``16x128b`` instructions. ``16x128b.xN`` produces
     ``2*N`` i32 registers per half, so the total output is ``4*N`` i32 registers.


.. py:function:: tcgen05_ld_32dp256bNx(N, pack16, tmem_start_col, tmem_col_offset, dst_ptr)

   Load from TMEM using the 32dp256b pattern: two ``16x256b`` loads, one for the
   lower 16 rows and one for the upper 16 rows.
   Matches ``tl::tmem_ld_32dp256bNx`` from ``tcgen_05_ld.h``.

   - ``N``: x-count for the ``16x256b`` instructions. ``16x256b.xN`` produces
     ``4*N`` i32 registers per half, so the total output is ``8*N`` i32 registers.
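
The register counts quoted for the four load functions are consistent with one
piece of arithmetic: a ``16xWb`` load spreads 16 lanes of ``W`` bits per
repetition across the 32 threads of a warp, a ``32x32b`` load spreads 32 lanes
of 32 bits, and the 32dp64b/128b/256b patterns issue two ``16xWb`` loads (lower
and upper 16 rows). The sketch below reproduces the per-thread counts;
``regs_per_thread`` is a hypothetical helper, not part of this module.

.. code-block:: python

   WARP_SIZE = 32
   REG_BITS = 32  # each destination register holds one 32-bit (i32) value


   def regs_per_thread(row_bits: int, n: int) -> int:
       """Hypothetical: i32 registers per thread for tcgen05_ld_32dp<row_bits>bNx."""
       if row_bits == 32:
           lanes, loads = 32, 1   # a single 32x32b.xN load
       elif row_bits in (64, 128, 256):
           lanes, loads = 16, 2   # two 16x<row_bits>b.xN loads (lower/upper rows)
       else:
           raise ValueError(f"unsupported pattern: 32dp{row_bits}b")
       # Total bits moved by the warp, divided evenly over its 32 threads
       total_bits = loads * lanes * row_bits * n
       return total_bits // (WARP_SIZE * REG_BITS)


   for bits in (32, 64, 128, 256):
       print(f"32dp{bits}bNx, N=4 -> {regs_per_thread(bits, 4)} regs")
   # 4, 8, 16 and 32 registers, i.e. N, 2*N, 4*N and 8*N
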