tilelang.contrib.cutedsl.gemm_tcgen05

tcgen05 (SM100/Blackwell) MMA support for the CuTeDSL backend.

Provides:
  • Tcgen05SmemDescriptor: 64-bit SMEM descriptor for tcgen05 MMA

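  • initialize_tcgen05_descriptor: bitfield packing that matches the common.h layout

  • tcgen05mma_ss / tcgen05mma_ws_ss / tcgen05mma_ts: inline-PTX wrappers for tcgen05.mma

  • tcgen05_mma_arrive: mbarrier arrive for MMA commit

  • tmem_allocate / tmem_deallocate: TMEM allocation/deallocation

  • tcgen05_ld_32dp32bNx / tcgen05_ld_32dp64bNx / tcgen05_ld_32dp128bNx / tcgen05_ld_32dp256bNx: TMEM-to-register loads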

Classes

Tcgen05SmemDescriptor

64-bit shared-memory descriptor for tcgen05 MMA (Blackwell).

Functions

initialize_tcgen05_descriptor(desc, start_address, ...)

Pack the tcgen05 SMEM descriptor bitfields.

tcgen05mma_ss(kind_dtype, desc_a, desc_b, tmem_c, ...)

tcgen05.mma.cta_group::1.kind::{kind} [tmem_c], desc_a, desc_b, desc_val, {masks}, p;

tcgen05mma_ws_ss(kind_dtype, desc_a, desc_b, tmem_c, ...)

tcgen05.mma.ws.cta_group::1.kind::{kind} [tmem_c], desc_a, desc_b, desc_val, p, 0;

tcgen05mma_ts(kind_dtype, tmem_a, desc_b, tmem_c, ...)

tcgen05.mma.cta_group::1.kind::{kind} [tmem_c], [tmem_a], desc_b, desc_val, {masks}, p;

tcgen05_mma_arrive(mbar_ptr)

tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [mbar];

tmem_allocate(tmem_buffer_ptr, num_cols)

tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [dst], num_cols;

tmem_deallocate(tmem_ptr, num_cols)

tcgen05.dealloc.cta_group::1.sync.aligned.b32 tmem_addr, num_cols;

tcgen05_ld_32dp32bNx(N, pack16, tmem_start_col, ...)

Load N uint32 values from TMEM using tcgen05.ld.sync.aligned.32x32b.

tcgen05_ld_32dp64bNx(N, pack16, tmem_start_col, ...)

Load from TMEM using 32dp64b pattern (2x 16x64b for lower/upper 16 rows).

tcgen05_ld_32dp128bNx(N, pack16, tmem_start_col, ...)

Load from TMEM using 32dp128b pattern (2x 16x128b for lower/upper 16 rows).

tcgen05_ld_32dp256bNx(N, pack16, tmem_start_col, ...)

Load from TMEM using 32dp256b pattern (2x 16x256b for lower/upper 16 rows).

Module Contents

class tilelang.contrib.cutedsl.gemm_tcgen05.Tcgen05SmemDescriptor(desc_64=None)

64-bit shared-memory descriptor for tcgen05 MMA (Blackwell).

Mirrors tl::Tcgen05SMemDescriptor from common.h. Stored as two Int32 registers; recast to Int64 for the PTX operand.

Parameters:

desc_64 (cutlass.cute.Int64)

desc
desc_i64
__add__(offset)

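Return the descriptor advanced by a byte offset. As with the C++ operator+, the offset is shifted right by 4 (converted to 16-byte units) before it is added.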
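The two-word storage and the offset arithmetic of __add__ can be modelled on the host in plain Python. This is a minimal sketch, not the CuTeDSL class: Tcgen05DescModel is a hypothetical name, ordinary Python ints stand in for the two Int32 registers, and adding the shifted offset to the whole low word (the word that holds start_address) is an assumption about what the C++ operator+ does.

```python
class Tcgen05DescModel:
    """Hypothetical host-side stand-in for Tcgen05SmemDescriptor.

    Keeps the descriptor as two 32-bit words, like the two Int32
    registers used by the CuTeDSL class.
    """

    MASK32 = 0xFFFFFFFF

    def __init__(self, desc_64: int = 0):
        self.lo = desc_64 & self.MASK32          # reg32_[0]: start address, LBO
        self.hi = (desc_64 >> 32) & self.MASK32  # reg32_[1]: SBO, version, base offset, lbo_mode, layout

    @property
    def desc_i64(self) -> int:
        # Recombine both words into the single 64-bit value used as the PTX operand.
        return (self.hi << 32) | self.lo

    def __add__(self, offset: int) -> "Tcgen05DescModel":
        # Assumption: the byte offset is converted to 16-byte units and
        # added to the whole low word; the high word is copied unchanged.
        out = Tcgen05DescModel()
        out.lo = (self.lo + (offset >> 4)) & self.MASK32
        out.hi = self.hi
        return out
```

Because the offset is divided by 16 before the add, byte offsets passed to __add__ should be multiples of 16; any remainder is dropped.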

tilelang.contrib.cutedsl.gemm_tcgen05.initialize_tcgen05_descriptor(desc, start_address, leading_byte_offset, stride_byte_offset, base_offset, leading_abs, swizzle_mode)

Pack the tcgen05 SMEM descriptor bitfields.

Matches the C++ initialize_tcgen05_descriptor in common.h. Bit ranges are half-open.

Low 32 bits (reg32_[0]):
  • [0:14) start_address >> 4

  • [16:30) leading_byte_offset (already >> 4 from TIR)

High 32 bits (reg32_[1]):
  • [0:14) stride_byte_offset (already >> 4 from TIR)

  • [14:16) version = 1

  • [17:20) base_offset & 0x7

  • [20:21) lbo_mode (leading_is_absolute ? 1 : 0)

  • [29:32) layout_type (swizzle_mode & 0x7)
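As a cross-check of the layout above, the same packing can be written with plain Python integers. pack_tcgen05_desc is a hypothetical helper, not part of this module; it masks each field to its width, as a C bitfield assignment would. In the full 64-bit word the high-half ranges land at bits [32:46), [46:48), [49:52), [52:53) and [61:64).

```python
def pack_tcgen05_desc(start_address: int, leading_byte_offset: int,
                      stride_byte_offset: int, base_offset: int,
                      leading_abs: bool, swizzle_mode: int) -> int:
    """Hypothetical sketch: pack the descriptor fields into one 64-bit int.

    start_address is a raw byte address (shifted >> 4 here); the two byte
    offsets are assumed to arrive already shifted >> 4, as TIR provides them.
    """
    # Low word (reg32_[0])
    lo = (start_address >> 4) & 0x3FFF              # [0:14)  start address
    lo |= (leading_byte_offset & 0x3FFF) << 16      # [16:30) leading byte offset

    # High word (reg32_[1])
    hi = stride_byte_offset & 0x3FFF                # [0:14)  stride byte offset
    hi |= 1 << 14                                   # [14:16) version = 1
    hi |= (base_offset & 0x7) << 17                 # [17:20) base offset
    hi |= (1 if leading_abs else 0) << 20           # [20:21) lbo_mode
    hi |= (swizzle_mode & 0x7) << 29                # [29:32) layout type

    return (hi << 32) | lo
```

For example, pack_tcgen05_desc(0x400, 1, 64, 0, False, 2) yields a low word of 0x00010040 and a high word of 0x40004040.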

tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05mma_ss(kind_dtype, desc_a, desc_b, tmem_c, desc_val, scale_out, mask0, mask1, mask2, mask3)

tcgen05.mma.cta_group::1.kind::{kind} [tmem_c], desc_a, desc_b, desc_val, {masks}, p;

Guarded by elect_one_sync: only one thread in the warp issues the MMA. The TIR codegen additionally wraps calls in if (threadIdx.x >> 5) == 0, which restricts issue to warp 0.
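The scale_out argument decides whether the MMA accumulates into tmem_c. The sketch below models that arithmetic on the host. It assumes (this page does not confirm it) that the wrapper sets the predicate p to scale_out != 0; per the PTX ISA, tcgen05.mma computes D = A*B + D when that predicate is true and D = A*B when it is false. mma_reference is hypothetical and ignores the lane masks and the SMEM/TMEM data layouts.

```python
def mma_reference(a, b, d, scale_out):
    """Hypothetical host-side model of one MMA's effect on the accumulator.

    a is M x K, b is K x N, d is M x N (lists of lists).
    Assumption: p = (scale_out != 0); p true -> D = A*B + D,
    p false -> D = A*B and the old accumulator value is ignored.
    """
    m, k, n = len(a), len(b), len(b[0])
    accumulate = scale_out != 0
    out = []
    for i in range(m):
        row = []
        for j in range(n):
            prod = sum(a[i][kk] * b[kk][j] for kk in range(k))
            row.append(prod + d[i][j] if accumulate else prod)
        out.append(row)
    return out
```

A K loop typically passes scale_out = 0 for the first K block and 1 afterwards, so stale TMEM contents never reach the result.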

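Parameters:
  • kind_dtype (str)

  • desc_a (Tcgen05SmemDescriptor)

  • desc_b (Tcgen05SmemDescriptor)

  • tmem_c (int)

  • desc_val (int)

  • scale_out (int)

  • mask0 (int)

  • mask1 (int)

  • mask2 (int)

  • mask3 (int)

tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05mma_ws_ss(kind_dtype, desc_a, desc_b, tmem_c, desc_val, scale_out)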

tcgen05.mma.ws.cta_group::1.kind::{kind} [tmem_c], desc_a, desc_b, desc_val, p, 0;

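Parameters:
  • kind_dtype (str)

  • desc_a (Tcgen05SmemDescriptor)

  • desc_b (Tcgen05SmemDescriptor)

  • tmem_c (int)

  • desc_val (int)

  • scale_out (int)

tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05mma_ts(kind_dtype, tmem_a, desc_b, tmem_c, desc_val, scale_out, mask0, mask1, mask2, mask3)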

tcgen05.mma.cta_group::1.kind::{kind} [tmem_c], [tmem_a], desc_b, desc_val, {masks}, p;

Parameters:
  • kind_dtype (str)

  • tmem_a (int)

  • desc_b (Tcgen05SmemDescriptor)

  • tmem_c (int)

  • desc_val (int)

  • scale_out (int)

  • mask0 (int)

  • mask1 (int)

  • mask2 (int)

  • mask3 (int)

tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05_mma_arrive(mbar_ptr)

tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [mbar];

Guarded by elect_one_sync — only one thread in the warp issues the commit.

Parameters:

mbar_ptr (cutlass.cute.Pointer)

tilelang.contrib.cutedsl.gemm_tcgen05.tmem_allocate(tmem_buffer_ptr, num_cols)

tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [dst], num_cols;

tmem_buffer_ptr: SMEM pointer that receives the allocated TMEM address. num_cols: number of columns to allocate.

Parameters:
  • tmem_buffer_ptr (cutlass.cute.Pointer)

  • num_cols (int)
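The PTX ISA requires the column count given to tcgen05.alloc and tcgen05.dealloc to be a power of two in the range [32, 512]. A host-side check and a round-up helper might look like the sketch below. Both function names are hypothetical; round_up_tmem_cols simply rounds a required column count (for example, the N of an fp32 accumulator tile that uses one 32-bit column per output column) up to the next legal size.

```python
def check_tmem_num_cols(num_cols: int) -> int:
    """Hypothetical check mirroring the PTX constraint on nCols:
    a power of two between 32 and 512 inclusive."""
    if not (32 <= num_cols <= 512) or num_cols & (num_cols - 1):
        raise ValueError(f"num_cols={num_cols} must be a power of two in [32, 512]")
    return num_cols


def round_up_tmem_cols(required_cols: int) -> int:
    """Hypothetical helper: smallest legal allocation covering required_cols."""
    cols = 32
    while cols < required_cols:
        cols *= 2
    # Raises if the requirement exceeds the 512-column maximum.
    return check_tmem_num_cols(cols)
```

For instance, a 96-column requirement rounds up to a 128-column allocation.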

tilelang.contrib.cutedsl.gemm_tcgen05.tmem_deallocate(tmem_ptr, num_cols)

tcgen05.dealloc.cta_group::1.sync.aligned.b32 tmem_addr, num_cols;

tmem_ptr: SMEM pointer to the uint32 holding the TMEM address. num_cols: number of columns to deallocate.

Parameters:
  • tmem_ptr (cutlass.cute.Pointer)

  • num_cols (int)
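The uint32 written by tmem_allocate and read back through tmem_ptr is a Tensor Memory address. Per the PTX ISA, its upper 16 bits select the lane and its lower 16 bits select the column, so offsetting by columns is an add on the low 16 bits. The two helpers below are hypothetical and only illustrate that encoding.

```python
from typing import Tuple


def tmem_addr(lane: int, col: int) -> int:
    """Hypothetical: compose a 32-bit TMEM address (lane in bits [16:32), column in bits [0:16))."""
    return ((lane & 0xFFFF) << 16) | (col & 0xFFFF)


def tmem_lane_col(addr: int) -> Tuple[int, int]:
    """Hypothetical: split a 32-bit TMEM address into (lane, column)."""
    return (addr >> 16) & 0xFFFF, addr & 0xFFFF
```

For example, tmem_addr(16, 64) is 0x00100040, and adding 8 to tmem_addr(0, 256) addresses lane 0, column 264.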

tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05_ld_32dp32bNx(N, pack16, tmem_start_col, tmem_col_offset, dst_ptr)

Load N uint32 values from TMEM using tcgen05.ld.sync.aligned.32x32b.

Matches tl::tcgen05_ld_32dp32bNx from copy_sm100.h.

N: number of 32-bit elements to load (x-count, compile-time constant).
pack16: if True, use 16-bit packing (not implemented yet).
tmem_start_col: TMEM base column address.
tmem_col_offset: additional column offset.
dst_ptr: destination pointer (register memory).

Parameters:
  • N (cutlass.cutlass_dsl.Constexpr[int])

  • pack16 (cutlass.cutlass_dsl.Constexpr[bool])

  • tmem_start_col (int)

  • tmem_col_offset (int)

  • dst_ptr (cutlass.cute.Pointer)

tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05_ld_32dp64bNx(N, pack16, tmem_start_col, tmem_col_offset, dst_ptr)

Load from TMEM using 32dp64b pattern (2x 16x64b for lower/upper 16 rows).

Matches tl::tmem_ld_32dp64bNx from tcgen_05_ld.h.

N: x-count for each 16x64b instruction. Each 16x64b.xN produces N i32 regs per half, so the total output is 2*N i32 regs.

Parameters:
  • N (cutlass.cutlass_dsl.Constexpr[int])

  • pack16 (cutlass.cutlass_dsl.Constexpr[bool])

  • tmem_start_col (int)

  • tmem_col_offset (int)

  • dst_ptr (cutlass.cute.Pointer)

tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05_ld_32dp128bNx(N, pack16, tmem_start_col, tmem_col_offset, dst_ptr)

Load from TMEM using 32dp128b pattern (2x 16x128b for lower/upper 16 rows).

Matches tl::tmem_ld_32dp128bNx from tcgen_05_ld.h.

N: x-count for each 16x128b instruction. Each 16x128b.xN produces 2*N i32 regs per half, so the total output is 4*N i32 regs.

Parameters:
  • N (cutlass.cutlass_dsl.Constexpr[int])

  • pack16 (cutlass.cutlass_dsl.Constexpr[bool])

  • tmem_start_col (int)

  • tmem_col_offset (int)

  • dst_ptr (cutlass.cute.Pointer)

tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05_ld_32dp256bNx(N, pack16, tmem_start_col, tmem_col_offset, dst_ptr)

Load from TMEM using 32dp256b pattern (2x 16x256b for lower/upper 16 rows).

Matches tl::tmem_ld_32dp256bNx from tcgen_05_ld.h.

N: x-count for each 16x256b instruction. Each 16x256b.xN produces 4*N i32 regs per half, so the total output is 8*N i32 regs.

Parameters:
  • N (cutlass.cutlass_dsl.Constexpr[int])

  • pack16 (cutlass.cutlass_dsl.Constexpr[bool])

  • tmem_start_col (int)

  • tmem_col_offset (int)

  • dst_ptr (cutlass.cute.Pointer)
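The register totals quoted for the four load functions follow from the per-thread output of each tcgen05.ld shape. As we read the PTX ISA, one x1 instance of 32x32b or 16x64b writes 1 register per thread, 16x128b writes 2 and 16x256b writes 4, and the largest x-counts are x128, x128, x64 and x32 respectively. The sketch tabulates this for the four wrappers. The table and function names are hypothetical, and half_addresses assumes (this page does not confirm it) that the upper 16-row half is reached by adding 16 to the lane field, i.e. 16 << 16.

```python
# Registers per thread written by one x1 instance of each tcgen05.ld shape.
REGS_PER_X1 = {"32x32b": 1, "16x64b": 1, "16x128b": 2, "16x256b": 4}
# Largest .num x-count each shape accepts.
MAX_X = {"32x32b": 128, "16x64b": 128, "16x128b": 64, "16x256b": 32}

# Hypothetical table: wrapper -> (PTX shape issued, number of instructions issued).
WRAPPERS = {
    "tcgen05_ld_32dp32bNx": ("32x32b", 1),
    "tcgen05_ld_32dp64bNx": ("16x64b", 2),   # lower and upper 16 rows
    "tcgen05_ld_32dp128bNx": ("16x128b", 2),
    "tcgen05_ld_32dp256bNx": ("16x256b", 2),
}


def total_regs(wrapper: str, n: int) -> int:
    """Total i32 registers per thread that the wrapper writes to dst_ptr."""
    shape, issues = WRAPPERS[wrapper]
    # x-counts are powers of two up to the per-shape maximum.
    if n < 1 or n & (n - 1) or n > MAX_X[shape]:
        raise ValueError(f"N={n} is not a valid x-count for {shape}")
    return issues * REGS_PER_X1[shape] * n


def half_addresses(tmem_start_col: int, tmem_col_offset: int):
    """Assumed TMEM addresses of the lower and upper 16-lane halves."""
    base = tmem_start_col + tmem_col_offset
    return base, base + (16 << 16)
```

With N = 8 this gives 8, 16, 32 and 64 registers for the 32dp32b, 32dp64b, 32dp128b and 32dp256b variants, matching N, 2*N, 4*N and 8*N above.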