tilelang.cuda.intrinsics.macro.wgmma_macro_generator

Attributes

Classes

WGMMADescriptorParams

Pre-computed parameters for WGMMA descriptor initialization and atom offset computation.

SwizzleMode

Enum where members are also (and must be) ints

TensorCoreIntrinEmitter

Helper class whose purpose is to eliminate Python syntax within TIR macros.

Module Contents

tilelang.cuda.intrinsics.macro.wgmma_macro_generator.lift

class tilelang.cuda.intrinsics.macro.wgmma_macro_generator.WGMMADescriptorParams

Pre-computed parameters for WGMMA descriptor initialization and atom offset computation.

Returned by compute_wgmma_*_desc_params() and consumed by init_wgmma_*_desc() and wgmma_*_atom() methods.

swizzle_mode: int

SwizzleMode enum value (passed directly to T.initialize_wgmma_descriptor).

leading_byte_offset: int

LBO >> 4, ready to pass to T.initialize_wgmma_descriptor.

stride_byte_offset: int

SBO >> 4, ready to pass to T.initialize_wgmma_descriptor.

swizzle_atom_elems: int

Number of elements per swizzle atom along the non-K dimension.

k_atom_size: int

max(swizzle_atom_elems // micro_size_k, 1).

elems_in_bytes: int

Byte width of a single element: DataType(dtype).bits // 8.

is_k_major: bool

Whether the matrix is stored in K-major order (affects offset formula branching).
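The formulas quoted in the field descriptions above can be reproduced in plain Python. The following is a minimal sketch, not tilelang's implementation: `DescParamsSketch` and `make_params_sketch` are hypothetical names, the raw byte offsets and `micro_size_k` are assumed to be supplied by the caller, and the numeric inputs are illustrative only.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DescParamsSketch:
    """Stand-in mirroring the documented fields (not the tilelang class)."""
    swizzle_mode: int
    leading_byte_offset: int  # stored as LBO >> 4
    stride_byte_offset: int   # stored as SBO >> 4
    swizzle_atom_elems: int
    k_atom_size: int
    elems_in_bytes: int
    is_k_major: bool


def make_params_sketch(swizzle_mode, lbo_bytes, sbo_bytes, swizzle_atom_elems,
                       micro_size_k, dtype_bits, is_k_major):
    """Hypothetical helper applying the formulas quoted in the field docs."""
    return DescParamsSketch(
        swizzle_mode=swizzle_mode,
        leading_byte_offset=lbo_bytes >> 4,   # byte offset in 16-byte units
        stride_byte_offset=sbo_bytes >> 4,
        swizzle_atom_elems=swizzle_atom_elems,
        k_atom_size=max(swizzle_atom_elems // micro_size_k, 1),
        elems_in_bytes=dtype_bits // 8,
        is_k_major=is_k_major,
    )


# Illustrative inputs: 16-bit elements, 64-element swizzle atom, micro_size_k of 16.
params = make_params_sketch(swizzle_mode=1, lbo_bytes=1024, sbo_bytes=512,
                            swizzle_atom_elems=64, micro_size_k=16,
                            dtype_bits=16, is_k_major=True)
```

The `max(..., 1)` clamp keeps `k_atom_size` at least 1 when the swizzle atom is narrower than one K micro-step.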

class tilelang.cuda.intrinsics.macro.wgmma_macro_generator.SwizzleMode

Bases: enum.IntEnum

Enum where members are also (and must be) ints

NONE = 0
SWIZZLE_128B = 1
SWIZZLE_64B = 2
SWIZZLE_32B = 3

is_none()
Return type:

bool

is_swizzle_32b()
Return type:

bool

is_swizzle_64b()
Return type:

bool

is_swizzle_128b()
Return type:

bool

swizzle_byte_size()
Return type:

int

swizzle_atom_size()
Return type:

int
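Because the class derives from enum.IntEnum, its members behave as plain integers. The sketch below uses a hypothetical stand-in, `SwizzleModeSketch`, that copies only the member values listed above and two of the predicates. The predicate bodies are an assumption about how such methods are typically written, not tilelang's source.

```python
from enum import IntEnum


class SwizzleModeSketch(IntEnum):
    """Stand-in with the member values listed above (not the tilelang class)."""
    NONE = 0
    SWIZZLE_128B = 1
    SWIZZLE_64B = 2
    SWIZZLE_32B = 3

    def is_none(self) -> bool:
        # Assumed implementation: a plain identity comparison.
        return self is SwizzleModeSketch.NONE

    def is_swizzle_128b(self) -> bool:
        return self is SwizzleModeSketch.SWIZZLE_128B


mode = SwizzleModeSketch.SWIZZLE_128B
as_int = int(mode)                  # IntEnum members convert to plain ints
round_trip = SwizzleModeSketch(1)   # and can be rebuilt from the raw value
```

Note that the numeric values do not follow swizzle width: SWIZZLE_128B is 1 while SWIZZLE_32B is 3, so comparing members with `<` does not order them by size.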

class tilelang.cuda.intrinsics.macro.wgmma_macro_generator.TensorCoreIntrinEmitter(a_dtype=T.float16, b_dtype=T.float16, accum_dtype=T.float16, a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=8, warp_col_tiles=8, chunk=16, reduce_k=1, num_elems_per_byte=1, is_m_first=False, thread_var=None)

Bases: tilelang.cuda.intrinsics.macro.mma_macro_generator.TensorCoreIntrinEmitter

Helper class whose purpose is to eliminate Python syntax within TIR macros.

Parameters:
  • a_dtype (str)

  • b_dtype (str)

  • accum_dtype (str)

  • a_transposed (bool)

  • b_transposed (bool)

  • block_row_warps (int)

  • block_col_warps (int)

  • warp_row_tiles (int)

  • warp_col_tiles (int)

  • chunk (int)

  • reduce_k (int)

  • num_elems_per_byte (int)

  • is_m_first (bool | None)

  • thread_var (tvm.tirx.Var | None)

wgmma_prefix: str
wgmma_inst_m: int
wgmma_inst_n: int
a_shared_layout: tilelang.layout.Layout = None
b_shared_layout: tilelang.layout.Layout = None

wgmma(A_region, B_region, C_region, clear_accum=False, wg_wait=0)
Parameters:
  • A_region (tvm.tirx.BufferRegion)

  • B_region (tvm.tirx.BufferRegion)

  • C_region (tvm.tirx.BufferRegion)

  • clear_accum (tvm.tirx.PrimExpr)

  • wg_wait (int)

wgmma_rs(A_region, B_region, C_region, clear_accum=False, wg_wait=0)
Parameters:
  • A_region (tvm.tirx.BufferRegion)

  • B_region (tvm.tirx.BufferRegion)

  • C_region (tvm.tirx.BufferRegion)

  • clear_accum (tvm.tirx.PrimExpr)

  • wg_wait (int)

property wgmma_num_inst_m: int

Number of WGMMA instruction atoms along the M dimension.

Return type:

int

property wgmma_num_inst_n: int

Number of WGMMA instruction atoms along the N dimension.

Return type:

int

property wgmma_num_k_atoms: int

Number of K-dimension micro-steps (chunk // micro_size_k).

Return type:

int
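The K-step count is plain floor division, as the formula above states. A minimal sketch follows; `num_k_atoms_sketch` is a hypothetical helper and the `micro_size_k` value of 16 is an illustrative assumption, not a value taken from the library.

```python
def num_k_atoms_sketch(chunk: int, micro_size_k: int) -> int:
    """Hypothetical helper mirroring the documented chunk // micro_size_k."""
    return chunk // micro_size_k


k_atoms = num_k_atoms_sketch(chunk=64, micro_size_k=16)
# Floor division drops any remainder, so a chunk that is not a multiple of
# micro_size_k would leave trailing K elements uncovered by the atom count.
k_atoms_uneven = num_k_atoms_sketch(chunk=24, micro_size_k=16)
```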

property wgmma_a_regs: int

Number of 32-bit registers occupied by the A fragment (RS variant).

Return type:

int

property wgmma_accum_regs: int

Number of 32-bit registers occupied by the accumulator fragment.

Return type:

int

compute_wgmma_b_desc_params(B_region)

Compute B descriptor parameters from the B shared buffer region.

This is a pure-Python helper – no TIR code is emitted. The returned WGMMADescriptorParams is passed to init_wgmma_b_desc() and wgmma_*_atom() methods.

Parameters:

B_region (tvm.tirx.BufferRegion)

Return type:

WGMMADescriptorParams

compute_wgmma_a_desc_params(A_region)

Compute A descriptor parameters from the A shared buffer region (SS variant).

This is a pure-Python helper – no TIR code is emitted. The returned WGMMADescriptorParams is passed to init_wgmma_a_desc() and wgmma_ss_atom() methods.

Parameters:

A_region (tvm.tirx.BufferRegion)

Return type:

WGMMADescriptorParams

init_wgmma_b_desc(desc_b, B_region, b_params)

Emit TIR to initialize a pre-allocated WGMMA B descriptor.

Parameters:
  • desc_b (Buffer) – A descriptor buffer allocated via T.alloc_wgmma_desc().

  • B_region (BufferRegion) – The B operand shared memory region.

  • b_params (WGMMADescriptorParams) – Pre-computed parameters from compute_wgmma_b_desc_params().

init_wgmma_a_desc(desc_a, A_region, a_params)

Emit TIR to initialize a pre-allocated WGMMA A descriptor (SS variant).

Parameters:
  • desc_a (Buffer) – A descriptor buffer allocated via T.alloc_wgmma_desc().

  • A_region (BufferRegion) – The A operand shared memory region.

  • a_params (WGMMADescriptorParams) – Pre-computed parameters from compute_wgmma_a_desc_params().

wgmma_fence_a(A_buf)

Emit warpgroup_fence_operand for the A fragment buffer.

Parameters:

A_buf (tvm.tirx.Buffer)

wgmma_fence_c(C_buf)

Emit warpgroup_fence_operand for the accumulator buffer.

Parameters:

C_buf (tvm.tirx.Buffer)

wgmma_arrive()

Emit warpgroup_arrive().

wgmma_commit()

Emit warpgroup_commit_batch().

wgmma_wait(n=0)

Emit warpgroup_wait(n).

Parameters:

n (int)

wgmma_rs_atom(A_buf, desc_b, C_buf, inst_m_idx, inst_n_idx, ki, b_params, clear_accum=False)

Emit a single WGMMA RS instruction for atom (inst_m_idx, inst_n_idx, ki).

Must be called between a wgmma_fence_a/wgmma_fence_c/wgmma_arrive sequence and a wgmma_commit/wgmma_wait sequence.

Calling this for every (j, i, ki) in T.grid(wgmma_num_inst_n, wgmma_num_inst_m, wgmma_num_k_atoms) produces identical TIR to wgmma_rs().

Parameters:
  • A_buf (Buffer) – Fragment buffer for operand A (in registers).

  • desc_b (Buffer) – Initialized B descriptor (from init_wgmma_b_desc).

  • C_buf (Buffer) – Accumulator fragment buffer.

  • inst_m_idx (int) – M-dimension atom index (0 .. wgmma_num_inst_m - 1).

  • inst_n_idx (int) – N-dimension atom index (0 .. wgmma_num_inst_n - 1).

  • ki (int) – K-dimension atom index (0 .. wgmma_num_k_atoms - 1).

  • b_params (WGMMADescriptorParams) – Pre-computed B descriptor parameters.

  • clear_accum (PrimExpr) – Whether to zero the accumulator on the first K atom.
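The iteration order stated above, T.grid(wgmma_num_inst_n, wgmma_num_inst_m, wgmma_num_k_atoms), can be enumerated in plain Python to see which atom comes first. This sketch uses a hypothetical helper, `rs_atom_order_sketch`, and assumes that `j` maps to inst_n_idx and `i` to inst_m_idx, which is inferred from the grid extents rather than confirmed by the source. It emits no TIR.

```python
from itertools import product


def rs_atom_order_sketch(num_inst_n, num_inst_m, num_k_atoms):
    """Yield (inst_m_idx, inst_n_idx, ki) in the documented grid order."""
    # product() varies its last range fastest, like nested loops with
    # j outermost and ki innermost.
    for j, i, ki in product(range(num_inst_n), range(num_inst_m), range(num_k_atoms)):
        yield (i, j, ki)  # assumed mapping: i -> inst_m_idx, j -> inst_n_idx


order = list(rs_atom_order_sketch(num_inst_n=2, num_inst_m=1, num_k_atoms=2))
```

With these illustrative extents, all K atoms for the first N atom are visited before the N index advances.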

wgmma_ss_atom(desc_a, desc_b, C_buf, inst_m_idx, inst_n_idx, ki, a_params, b_params, clear_accum=False)

Emit a single WGMMA SS instruction for atom (inst_m_idx, inst_n_idx, ki).

Must be called between fence/arrive and commit/wait sequences.

Parameters:
  • desc_a (Buffer) – Initialized A descriptor (from init_wgmma_a_desc).

  • desc_b (Buffer) – Initialized B descriptor (from init_wgmma_b_desc).

  • C_buf (Buffer) – Accumulator fragment buffer.

  • inst_m_idx (int) – M-dimension atom index (0 .. wgmma_num_inst_m - 1).

  • inst_n_idx (int) – N-dimension atom index (0 .. wgmma_num_inst_n - 1).

  • ki (int) – K-dimension atom index (0 .. wgmma_num_k_atoms - 1).

  • a_params (WGMMADescriptorParams) – Pre-computed A descriptor parameters.

  • b_params (WGMMADescriptorParams) – Pre-computed B descriptor parameters.

  • clear_accum (PrimExpr) – Whether to zero the accumulator on the first K atom.
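The ordering constraint above (fence/arrive before the atoms, commit/wait after) can be illustrated with a stand-in that records method names instead of emitting TIR. Everything here is a sketch under stated assumptions: `RecordingEmitter` and `ss_sequence_sketch` are hypothetical, strings stand in for buffers and regions, the use of wgmma_fence_c as the fence and the (j, i, ki) loop order are borrowed from the RS description rather than confirmed for the SS path.

```python
class RecordingEmitter:
    """Hypothetical stand-in: logs method names instead of emitting TIR."""

    def __init__(self):
        self.calls = []

    def __getattr__(self, name):
        # Only reached for undefined attributes, i.e. the emitter methods.
        def record(*args, **kwargs):
            self.calls.append(name)
            return name  # placeholder return value
        return record


def ss_sequence_sketch(emitter, num_inst_m, num_inst_n, num_k_atoms):
    # Pure-Python parameter computation first, then descriptor initialization.
    a_params = emitter.compute_wgmma_a_desc_params("A_region")
    b_params = emitter.compute_wgmma_b_desc_params("B_region")
    emitter.init_wgmma_a_desc("desc_a", "A_region", a_params)
    emitter.init_wgmma_b_desc("desc_b", "B_region", b_params)
    # Assumed fence/arrive prologue for the SS variant.
    emitter.wgmma_fence_c("C_buf")
    emitter.wgmma_arrive()
    for j in range(num_inst_n):
        for i in range(num_inst_m):
            for ki in range(num_k_atoms):
                emitter.wgmma_ss_atom("desc_a", "desc_b", "C_buf",
                                      i, j, ki, a_params, b_params)
    # Commit/wait epilogue.
    emitter.wgmma_commit()
    emitter.wgmma_wait(0)
    return emitter.calls


calls = ss_sequence_sketch(RecordingEmitter(), num_inst_m=1, num_inst_n=2, num_k_atoms=4)
```

Inspecting `calls` shows the atoms bracketed by the prologue and epilogue, which is the property the method description requires.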

make_mma_load_layout(local_buf, matrix='A')

Create a layout function for loading an MMA operand (selected by matrix) into a fragment buffer. This layout is used in conjunction with inverse_mma_store_layout to map fragment indices to threads and local indices.

Parameters:
  • local_buf (tir.Buffer) – The local buffer representing a fragment of a matrix.

  • matrix (str) – Which operand matrix the layout is built for; defaults to 'A'.

Returns:

A fragment object that describes how threads and indices in local_buf are laid out.

Return type:

T.Fragment

Raises:

AssertionError – If local_buf is not detected to be a fragment buffer.

make_mma_store_layout(local_buf)

Create a layout function for storing MMA results into a fragment buffer. This layout is used in conjunction with inverse_mma_store_layout to map fragment indices to threads and local indices.

Parameters:

local_buf (tir.Buffer) – The local buffer representing a fragment of a matrix.

Returns:

A fragment object that describes how threads and indices in local_buf are laid out.

Return type:

T.Fragment

Raises:

AssertionError – If local_buf is not detected to be a fragment buffer.