tilelang.contrib.cutedsl.gemm_tcgen05
tcgen05 (SM100/Blackwell) MMA support for the CuTeDSL backend.
Provides:
- Tcgen05SmemDescriptor: 64-bit SMEM descriptor for tcgen05 MMA
- initialize_tcgen05_descriptor: bitfield packing matching the common.h layout
- tcgen05mma_ss / tcgen05mma_ws_ss / tcgen05mma_ts: MMA PTX inline asm
- tcgen05_mma_arrive: mbarrier arrive for MMA commit
- tmem_allocate / tmem_deallocate: TMEM allocation/deallocation
- tcgen05_ld_32dp32bNx / tcgen05_ld_32dp64bNx / tcgen05_ld_32dp128bNx / tcgen05_ld_32dp256bNx: TMEM-to-register loads
Classes
| Name | Summary |
|---|---|
| Tcgen05SmemDescriptor | 64-bit shared-memory descriptor for tcgen05 MMA (Blackwell). |
Functions
| Name | Summary |
|---|---|
| initialize_tcgen05_descriptor | Pack the tcgen05 SMEM descriptor bitfields. |
| tcgen05mma_ss | tcgen05.mma.cta_group::1.kind::{kind} [tmem_c], desc_a, desc_b, desc_val, {masks}, p; |
| tcgen05mma_ws_ss | tcgen05.mma.ws.cta_group::1.kind::{kind} [tmem_c], desc_a, desc_b, desc_val, p, 0; |
| tcgen05mma_ts | tcgen05.mma.cta_group::1.kind::{kind} [tmem_c], [tmem_a], desc_b, desc_val, {masks}, p; |
| tcgen05_mma_arrive | tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [mbar]; |
| tmem_allocate | tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [dst], num_cols; |
| tmem_deallocate | tcgen05.dealloc.cta_group::1.sync.aligned.b32 tmem_addr, num_cols; |
| tcgen05_ld_32dp32bNx | Load N uint32 values from TMEM using tcgen05.ld.sync.aligned.32x32b. |
| tcgen05_ld_32dp64bNx | Load from TMEM using 32dp64b pattern (2x 16x64b for lower/upper 16 rows). |
| tcgen05_ld_32dp128bNx | Load from TMEM using 32dp128b pattern (2x 16x128b for lower/upper 16 rows). |
| tcgen05_ld_32dp256bNx | Load from TMEM using 32dp256b pattern (2x 16x256b for lower/upper 16 rows). |
Module Contents
- class tilelang.contrib.cutedsl.gemm_tcgen05.Tcgen05SmemDescriptor(desc_64=None)
64-bit shared-memory descriptor for tcgen05 MMA (Blackwell).
Mirrors tl::Tcgen05SMemDescriptor from common.h. The descriptor is stored as two Int32 registers and recast to Int64 when passed as the PTX operand.
- Parameters:
desc_64 (cutlass.cute.Int64)
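- desc
- desc_i64
- __add__(offset)
Returns a descriptor advanced by offset bytes. As with the C++ operator+, the byte offset is shifted right by 4 (converted to 16-byte units) before it is applied.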
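To make the offset arithmetic concrete, here is a plain-Python model operating on a raw 64-bit integer rather than on CuTeDSL registers. It assumes, as in the C++ operator+, that the shifted offset is added to the low 32-bit word (reg32_[0]), whose bits [0:14) hold the start address in 16-byte units; advance_descriptor is a hypothetical name, not part of this module.

```python
MASK32 = 0xFFFFFFFF


def advance_descriptor(desc_64: int, byte_offset: int) -> int:
    """Hypothetical plain-int model of Tcgen05SmemDescriptor.__add__.

    Assumption: as in the C++ operator+, the byte offset is converted to
    16-byte units (>> 4) and added to the low 32-bit word (reg32_[0]) only;
    the high word is copied unchanged.
    """
    lo = desc_64 & MASK32
    hi = (desc_64 >> 32) & MASK32
    lo = (lo + (byte_offset >> 4)) & MASK32
    return (hi << 32) | lo


# Advancing by 1024 bytes moves the start-address field by 64 (16-byte units).
d = (0x40004040 << 32) | 0x00010040
print(hex(advance_descriptor(d, 1024)))  # 0x4000404000010080
```

Because the offset is pre-shifted, byte offsets that are not multiples of 16 lose their low bits; SMEM operand tiles are normally 16-byte aligned, so this is rarely an issue in practice.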
- tilelang.contrib.cutedsl.gemm_tcgen05.initialize_tcgen05_descriptor(desc, start_address, leading_byte_offset, stride_byte_offset, base_offset, leading_abs, swizzle_mode)
Pack the tcgen05 SMEM descriptor bitfields.
Matches the C++ initialize_tcgen05_descriptor in common.h:
- Low 32 bits (reg32_[0]):
  - [0:14) start_address >> 4
  - [16:30) leading_byte_offset (already >> 4 from TIR)
- High 32 bits (reg32_[1]):
  - [0:14) stride_byte_offset (already >> 4 from TIR)
  - [14:16) version = 1
  - [17:20) base_offset & 0x7
  - [20:21) lbo_mode (leading_is_absolute ? 1 : 0)
  - [29:32) layout_type (swizzle_mode & 0x7)
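The bitfield layout listed above can be checked with a plain-Python model that packs integers instead of CuTeDSL registers. It follows the field positions as documented; masking each field to its width is a choice of this sketch (the source does not say whether the real implementation masks), and swizzle_mode=2 in the example is an arbitrary value, not a recommendation.

```python
from typing import Tuple


def pack_tcgen05_descriptor(start_address: int, leading_byte_offset: int,
                            stride_byte_offset: int, base_offset: int,
                            leading_abs: bool, swizzle_mode: int) -> Tuple[int, int, int]:
    """Plain-Python model of the tcgen05 SMEM descriptor bitfields.

    Returns (low word, high word, 64-bit value). leading_byte_offset and
    stride_byte_offset are taken as already >> 4 (as TIR passes them);
    start_address is a byte address and is shifted here. Masking each
    field to its width is an assumption of this sketch.
    """
    lo = (start_address >> 4) & 0x3FFF             # [0:14) start address
    lo |= (leading_byte_offset & 0x3FFF) << 16     # [16:30) LBO
    hi = stride_byte_offset & 0x3FFF               # [0:14) SBO
    hi |= 1 << 14                                  # [14:16) version = 1
    hi |= (base_offset & 0x7) << 17                # [17:20) base offset
    hi |= (1 if leading_abs else 0) << 20          # [20:21) LBO mode
    hi |= (swizzle_mode & 0x7) << 29               # [29:32) layout type
    return lo, hi, (hi << 32) | lo


lo, hi, desc = pack_tcgen05_descriptor(0x400, 1, 64, 0, False, 2)
print(hex(lo), hex(hi))  # 0x10040 0x40004040
```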
- tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05mma_ss(kind_dtype, desc_a, desc_b, tmem_c, desc_val, scale_out, mask0, mask1, mask2, mask3)
tcgen05.mma.cta_group::1.kind::{kind} [tmem_c], desc_a, desc_b, desc_val, {masks}, p;
Guarded by elect_one_sync: only one thread in the warp issues the MMA. The TIR codegen also wraps calls in if (threadIdx.x >> 5) == 0, which selects warp 0.
- Parameters:
kind_dtype (str)
desc_a (Tcgen05SmemDescriptor)
desc_b (Tcgen05SmemDescriptor)
tmem_c (int)
desc_val (int)
scale_out (int)
mask0 (int)
mask1 (int)
mask2 (int)
mask3 (int)
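In the PTX line, p is the enable-input-d predicate: when it is true the MMA computes D = A·B + D, otherwise it overwrites D with A·B. The plain-Python reference below models only that tile update. It assumes scale_out is turned into p as scale_out != 0, and it ignores the descriptors, the data kind, and the mask0..mask3 disable-output-lane operands; mma_reference is a hypothetical helper.

```python
from typing import List

Matrix = List[List[float]]


def mma_reference(a: Matrix, b: Matrix, d: Matrix, scale_out: int) -> Matrix:
    """Hypothetical reference for the tile update done by one MMA.

    Assumption: p = (scale_out != 0). With p true the existing accumulator
    is added (D = A @ B + D); with p false it is overwritten (D = A @ B).
    Descriptors, data kinds and the disable-output-lane masks are ignored.
    """
    accumulate = scale_out != 0
    m, k, n = len(a), len(b), len(b[0])
    out = []
    for i in range(m):
        row = []
        for j in range(n):
            acc = sum(a[i][kk] * b[kk][j] for kk in range(k))
            row.append(acc + d[i][j] if accumulate else acc)
        out.append(row)
    return out


a = [[1, 2], [3, 4]]
eye = [[1, 0], [0, 1]]
d = [[10, 10], [10, 10]]
print(mma_reference(a, eye, d, scale_out=1))  # [[11, 12], [13, 14]]
print(mma_reference(a, eye, d, scale_out=0))  # [[1, 2], [3, 4]]
```

A GEMM main loop typically issues its first k-block with accumulation disabled, so stale TMEM contents are never read, and accumulates on every later k-block.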
- tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05mma_ws_ss(kind_dtype, desc_a, desc_b, tmem_c, desc_val, scale_out)
tcgen05.mma.ws.cta_group::1.kind::{kind} [tmem_c], desc_a, desc_b, desc_val, p, 0;
- Parameters:
kind_dtype (str)
desc_a (Tcgen05SmemDescriptor)
desc_b (Tcgen05SmemDescriptor)
tmem_c (int)
desc_val (int)
scale_out (int)
- tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05mma_ts(kind_dtype, tmem_a, desc_b, tmem_c, desc_val, scale_out, mask0, mask1, mask2, mask3)
tcgen05.mma.cta_group::1.kind::{kind} [tmem_c], [tmem_a], desc_b, desc_val, {masks}, p;
- Parameters:
kind_dtype (str)
tmem_a (int)
desc_b (Tcgen05SmemDescriptor)
tmem_c (int)
desc_val (int)
scale_out (int)
mask0 (int)
mask1 (int)
mask2 (int)
mask3 (int)
- tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05_mma_arrive(mbar_ptr)
tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [mbar];
Guarded by elect_one_sync: only one thread in the warp issues the commit.
- Parameters:
mbar_ptr (cutlass.cute.Pointer)
- tilelang.contrib.cutedsl.gemm_tcgen05.tmem_allocate(tmem_buffer_ptr, num_cols)
tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [dst], num_cols;
tmem_buffer_ptr: SMEM pointer that receives the allocated TMEM address. num_cols: number of columns to allocate.
- Parameters:
tmem_buffer_ptr (cutlass.cute.Pointer)
num_cols (int)
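Two PTX-level constraints are useful when reasoning about tmem_allocate: the column count passed to tcgen05.alloc must be a power of two in the range [32, 512], and the 32-bit TMEM address it writes packs the lane index in bits [16:32) and the column index in bits [0:16). The helpers below are hypothetical plain-Python checks for those rules, not functions exported by this module.

```python
from typing import Tuple


def check_tmem_num_cols(num_cols: int) -> None:
    """Raise if num_cols is not a legal tcgen05.alloc column count."""
    is_pow2 = num_cols > 0 and (num_cols & (num_cols - 1)) == 0
    if not (is_pow2 and 32 <= num_cols <= 512):
        raise ValueError(f"num_cols must be a power of two in [32, 512], got {num_cols}")


def make_tmem_address(lane: int, column: int) -> int:
    """Pack a TMEM address: lane in bits [16:32), column in bits [0:16)."""
    return ((lane & 0xFFFF) << 16) | (column & 0xFFFF)


def split_tmem_address(addr: int) -> Tuple[int, int]:
    """Inverse of make_tmem_address: return (lane, column)."""
    return (addr >> 16) & 0xFFFF, addr & 0xFFFF


check_tmem_num_cols(128)                   # passes silently
print(hex(make_tmem_address(32, 64)))      # 0x200040
print(split_tmem_address(0x00200040))      # (32, 64)
```

The same num_cols should be passed to tmem_deallocate when the buffer is released.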
- tilelang.contrib.cutedsl.gemm_tcgen05.tmem_deallocate(tmem_ptr, num_cols)
tcgen05.dealloc.cta_group::1.sync.aligned.b32 tmem_addr, num_cols;
tmem_ptr: SMEM pointer to the uint32 holding the TMEM address. num_cols: number of columns to deallocate.
- Parameters:
tmem_ptr (cutlass.cute.Pointer)
num_cols (int)
- tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05_ld_32dp32bNx(N, pack16, tmem_start_col, tmem_col_offset, dst_ptr)
Load N uint32 values from TMEM using tcgen05.ld.sync.aligned.32x32b.
Matches tl::tcgen05_ld_32dp32bNx from copy_sm100.h. N: number of 32-bit elements to load (x-count, compile-time constant). pack16: if True, use 16-bit packing (not implemented yet). tmem_start_col: TMEM base column address. tmem_col_offset: additional column offset. dst_ptr: destination pointer (register memory).
- Parameters:
N (cutlass.cutlass_dsl.Constexpr[int])
pack16 (cutlass.cutlass_dsl.Constexpr[bool])
tmem_start_col (int)
tmem_col_offset (int)
dst_ptr (cutlass.cute.Pointer)
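In the 32x32b shape each of the 32 threads in a warp reads from its own TMEM lane, and the .xN repeat walks N consecutive 32-bit columns, so every thread receives N registers. The sketch below models that mapping in plain Python. It assumes the starting column is tmem_start_col + tmem_col_offset and that the warp's lane base is 0; ld_32dp32b_footprint is a hypothetical helper.

```python
from typing import List, Tuple


def ld_32dp32b_footprint(n: int, tmem_start_col: int, tmem_col_offset: int,
                         lane_base: int = 0) -> List[List[Tuple[int, int]]]:
    """Hypothetical model of tcgen05.ld .32x32b.xN addressing.

    Entry [t][j] is the (lane, column) TMEM cell that thread t of the warp
    loads into its j-th destination register. Assumes each thread owns one
    lane and the xN repeat walks n consecutive 32-bit columns.
    """
    col = tmem_start_col + tmem_col_offset
    return [[(lane_base + t, col + j) for j in range(n)] for t in range(32)]


fp = ld_32dp32b_footprint(4, 64, 8)
print(fp[0])   # [(0, 72), (0, 73), (0, 74), (0, 75)]
print(fp[31])  # [(31, 72), (31, 73), (31, 74), (31, 75)]
```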
- tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05_ld_32dp64bNx(N, pack16, tmem_start_col, tmem_col_offset, dst_ptr)
Load from TMEM using 32dp64b pattern (2x 16x64b for lower/upper 16 rows).
Matches tl::tmem_ld_32dp64bNx from tcgen_05_ld.h. N: x-count for 16x64b instructions. Total output: 2*N i32 regs. 16x64b.xN produces N i32 regs per half.
- Parameters:
N (cutlass.cutlass_dsl.Constexpr[int])
pack16 (cutlass.cutlass_dsl.Constexpr[bool])
tmem_start_col (int)
tmem_col_offset (int)
dst_ptr (cutlass.cute.Pointer)
- tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05_ld_32dp128bNx(N, pack16, tmem_start_col, tmem_col_offset, dst_ptr)
Load from TMEM using 32dp128b pattern (2x 16x128b for lower/upper 16 rows).
Matches tl::tmem_ld_32dp128bNx from tcgen_05_ld.h. N: x-count for 16x128b instructions. Total output: 4*N i32 regs. 16x128b.xN produces 2*N i32 regs per half.
- Parameters:
N (cutlass.cutlass_dsl.Constexpr[int])
pack16 (cutlass.cutlass_dsl.Constexpr[bool])
tmem_start_col (int)
tmem_col_offset (int)
dst_ptr (cutlass.cute.Pointer)
- tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05_ld_32dp256bNx(N, pack16, tmem_start_col, tmem_col_offset, dst_ptr)
Load from TMEM using 32dp256b pattern (2x 16x256b for lower/upper 16 rows).
Matches tl::tmem_ld_32dp256bNx from tcgen_05_ld.h. N: x-count for 16x256b instructions. Total output: 8*N i32 regs. 16x256b.xN produces 4*N i32 regs per half.
- Parameters:
N (cutlass.cutlass_dsl.Constexpr[int])
pack16 (cutlass.cutlass_dsl.Constexpr[bool])
tmem_start_col (int)
tmem_col_offset (int)
dst_ptr (cutlass.cute.Pointer)
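Across the four loaders, the number of registers dst_ptr must hold per thread follows from the totals documented above: 32x32b.xN is a single instruction yielding N registers, while the 64b, 128b and 256b variants issue one 16-row instruction per half, each yielding N, 2N or 4N registers. A small hypothetical helper that sizes the destination accordingly:

```python
# Registers per thread produced by one x1 instruction of each shape, and how
# many instructions each loader issues (the 64b/128b/256b loaders split the
# 32 lanes into lower and upper 16-lane halves). Values follow the totals
# documented above: N, 2*N, 4*N and 8*N.
_REGS_PER_X1 = {"32dp32b": 1, "32dp64b": 1, "32dp128b": 2, "32dp256b": 4}
_INSTRUCTIONS = {"32dp32b": 1, "32dp64b": 2, "32dp128b": 2, "32dp256b": 2}


def dst_regs_per_thread(pattern: str, n: int) -> int:
    """32-bit registers each thread's dst_ptr must hold (hypothetical helper)."""
    if n < 1:
        raise ValueError("N must be at least 1")
    return _INSTRUCTIONS[pattern] * _REGS_PER_X1[pattern] * n


# Prints 4, 8, 16 and 32 registers respectively for N = 4.
for p in ("32dp32b", "32dp64b", "32dp128b", "32dp256b"):
    print(p, dst_regs_per_thread(p, 4))
```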