tilelang.cuda.intrinsics.macro.wgmma_macro_generator¶
Attributes¶
Classes¶
WGMMADescriptorParams | Pre-computed parameters for WGMMA descriptor initialization and atom offset computation.
SwizzleMode | Enum where members are also (and must be) ints
TensorCoreIntrinEmitter | To eliminate Python syntax within TIR Macro.
Module Contents¶
- tilelang.cuda.intrinsics.macro.wgmma_macro_generator.lift¶
- class tilelang.cuda.intrinsics.macro.wgmma_macro_generator.WGMMADescriptorParams¶
Pre-computed parameters for WGMMA descriptor initialization and atom offset computation.
Returned by compute_wgmma_*_desc_params() and consumed by init_wgmma_*_desc() and wgmma_*_atom() methods.
- swizzle_mode: int¶
SwizzleMode enum value (passed directly to T.initialize_wgmma_descriptor).
- leading_byte_offset: int¶
Leading byte offset (LBO) shifted right by 4 bits (LBO >> 4), ready to pass to T.initialize_wgmma_descriptor.
- stride_byte_offset: int¶
Stride byte offset (SBO) shifted right by 4 bits (SBO >> 4), ready to pass to T.initialize_wgmma_descriptor.
- swizzle_atom_elems: int¶
Number of elements per swizzle atom along the non-K dimension.
- k_atom_size: int¶
max(swizzle_atom_elems // micro_size_k, 1).
- elems_in_bytes: int¶
Byte width of a single element: DataType(dtype).bits // 8.
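The field formulas above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: `DescParamsSketch` and `make_params_sketch` are hypothetical names, and the byte offsets and sizes passed in are arbitrary example values rather than values the library would compute.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DescParamsSketch:
    """Hypothetical stand-in that mirrors the documented fields."""
    swizzle_mode: int
    leading_byte_offset: int  # stored as LBO >> 4
    stride_byte_offset: int   # stored as SBO >> 4
    swizzle_atom_elems: int
    k_atom_size: int
    elems_in_bytes: int


def make_params_sketch(swizzle_mode, lbo_bytes, sbo_bytes,
                       swizzle_atom_elems, micro_size_k, dtype_bits):
    """Apply the documented formulas to raw (assumed) inputs."""
    return DescParamsSketch(
        swizzle_mode=swizzle_mode,
        leading_byte_offset=lbo_bytes >> 4,
        stride_byte_offset=sbo_bytes >> 4,
        swizzle_atom_elems=swizzle_atom_elems,
        k_atom_size=max(swizzle_atom_elems // micro_size_k, 1),
        elems_in_bytes=dtype_bits // 8,
    )


# Illustrative inputs only: a 16-bit element type with a 64-element swizzle atom.
params = make_params_sketch(swizzle_mode=1, lbo_bytes=1024, sbo_bytes=512,
                            swizzle_atom_elems=64, micro_size_k=16,
                            dtype_bits=16)
```

With these inputs the sketch yields `k_atom_size == 4` and `elems_in_bytes == 2`. The `max(..., 1)` clamp keeps `k_atom_size` at 1 when the swizzle atom is smaller than one K micro-step.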
- class tilelang.cuda.intrinsics.macro.wgmma_macro_generator.SwizzleMode¶
Bases: enum.IntEnum
Enum where members are also (and must be) ints
- NONE = 0¶
- SWIZZLE_128B = 1¶
- SWIZZLE_64B = 2¶
- SWIZZLE_32B = 3¶
- swizzle_byte_size()¶
- Return type:
int
- swizzle_atom_size()¶
- Return type:
int
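Because the class derives from enum.IntEnum, its members compare equal to plain integers. The sketch below re-declares the four documented members in a hypothetical `SwizzleModeSketch` class to show that behaviour without importing the library. The helper `width_from_name` is also hypothetical: it only parses the byte width out of the member name, and it is an assumption that `swizzle_byte_size()` reports a related quantity.

```python
import re
from enum import IntEnum


class SwizzleModeSketch(IntEnum):
    """Hypothetical copy of the documented member values."""
    NONE = 0
    SWIZZLE_128B = 1
    SWIZZLE_64B = 2
    SWIZZLE_32B = 3


def width_from_name(mode):
    """Parse the byte width encoded in the member name (None for NONE)."""
    match = re.fullmatch(r"SWIZZLE_(\d+)B", mode.name)
    return int(match.group(1)) if match else None


# IntEnum members behave like ints and can be looked up by value.
mode = SwizzleModeSketch(1)
width = width_from_name(mode)
```

Note that the integer values do not grow with the swizzle width: `SWIZZLE_128B` is 1 while `SWIZZLE_32B` is 3, so the enum value should not be used as a size.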
- class tilelang.cuda.intrinsics.macro.wgmma_macro_generator.TensorCoreIntrinEmitter(a_dtype=T.float16, b_dtype=T.float16, accum_dtype=T.float16, a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=8, warp_col_tiles=8, chunk=16, reduce_k=1, num_elems_per_byte=1, is_m_first=False, thread_var=None)¶
Bases: tilelang.cuda.intrinsics.macro.mma_macro_generator.TensorCoreIntrinEmitter
To eliminate Python syntax within TIR Macro.
- Parameters:
- wgmma_prefix: str¶
- wgmma_inst_m: int¶
- wgmma_inst_n: int¶
- wgmma(A_region, B_region, C_region, clear_accum=False, wg_wait=0)¶
- Parameters:
A_region (tvm.tirx.BufferRegion)
B_region (tvm.tirx.BufferRegion)
C_region (tvm.tirx.BufferRegion)
clear_accum (tvm.tirx.PrimExpr)
wg_wait (int)
- wgmma_rs(A_region, B_region, C_region, clear_accum=False, wg_wait=0)¶
- Parameters:
A_region (tvm.tirx.BufferRegion)
B_region (tvm.tirx.BufferRegion)
C_region (tvm.tirx.BufferRegion)
clear_accum (tvm.tirx.PrimExpr)
wg_wait (int)
- property wgmma_num_inst_m: int¶
Number of WGMMA instruction atoms along the M dimension.
- Return type:
int
- property wgmma_num_inst_n: int¶
Number of WGMMA instruction atoms along the N dimension.
- Return type:
int
- property wgmma_num_k_atoms: int¶
Number of K-dimension micro-steps (chunk // micro_size_k).
- Return type:
int
- property wgmma_a_regs: int¶
Number of 32-bit registers occupied by the A fragment (RS variant).
- Return type:
int
- property wgmma_accum_regs: int¶
Number of 32-bit registers occupied by the accumulator fragment.
- Return type:
int
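As a rough illustration of the `chunk // micro_size_k` formula for wgmma_num_k_atoms, the sketch below computes the number of K micro-steps for a few element widths. The helper names are hypothetical. The rule `micro_size_k = 256 // dtype_bits` is an assumption based on the PTX WGMMA instruction shapes (k16 for 16-bit types, k8 for tf32, k32 for 8-bit types); the library may derive `micro_size_k` differently.

```python
def micro_size_k_for(dtype_bits):
    """Assumed K extent of one WGMMA instruction: 256 bits along K."""
    return 256 // dtype_bits


def num_k_atoms(chunk, micro_size_k):
    """Documented formula: chunk // micro_size_k."""
    return chunk // micro_size_k


# chunk=64 with a 16-bit element type gives four K micro-steps.
k_atoms_fp16 = num_k_atoms(chunk=64, micro_size_k=micro_size_k_for(16))

# The constructor default chunk=16 with a 16-bit type gives a single step.
k_atoms_default = num_k_atoms(chunk=16, micro_size_k=micro_size_k_for(16))
```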
- compute_wgmma_b_desc_params(B_region)¶
Compute B descriptor parameters from the B shared buffer region.
This is a pure-Python helper; no TIR code is emitted. The returned WGMMADescriptorParams is passed to init_wgmma_b_desc() and wgmma_*_atom() methods.
- Parameters:
B_region (tvm.tirx.BufferRegion)
- Return type:
WGMMADescriptorParams
- compute_wgmma_a_desc_params(A_region)¶
Compute A descriptor parameters from the A shared buffer region (SS variant).
This is a pure-Python helper; no TIR code is emitted. The returned WGMMADescriptorParams is passed to init_wgmma_a_desc() and wgmma_ss_atom() methods.
- Parameters:
A_region (tvm.tirx.BufferRegion)
- Return type:
WGMMADescriptorParams
- init_wgmma_b_desc(desc_b, B_region, b_params)¶
Emit TIR to initialize a pre-allocated WGMMA B descriptor.
- Parameters:
desc_b (Buffer) – A descriptor buffer allocated via T.alloc_wgmma_desc().
B_region (BufferRegion) – The B operand shared memory region.
b_params (WGMMADescriptorParams) – Pre-computed parameters from compute_wgmma_b_desc_params().
- init_wgmma_a_desc(desc_a, A_region, a_params)¶
Emit TIR to initialize a pre-allocated WGMMA A descriptor (SS variant).
- Parameters:
desc_a (Buffer) – A descriptor buffer allocated via T.alloc_wgmma_desc().
A_region (BufferRegion) – The A operand shared memory region.
a_params (WGMMADescriptorParams) – Pre-computed parameters from compute_wgmma_a_desc_params().
- wgmma_fence_a(A_buf)¶
Emit warpgroup_fence_operand for the A fragment buffer.
- Parameters:
A_buf (tvm.tirx.Buffer)
- wgmma_fence_c(C_buf)¶
Emit warpgroup_fence_operand for the accumulator buffer.
- Parameters:
C_buf (tvm.tirx.Buffer)
- wgmma_arrive()¶
Emit warpgroup_arrive().
- wgmma_commit()¶
Emit warpgroup_commit_batch().
- wgmma_wait(n=0)¶
Emit warpgroup_wait(n).
- Parameters:
n (int)
- wgmma_rs_atom(A_buf, desc_b, C_buf, inst_m_idx, inst_n_idx, ki, b_params, clear_accum=False)¶
Emit a single WGMMA RS instruction for atom (inst_m_idx, inst_n_idx, ki).
Must be called between a wgmma_fence_a / wgmma_fence_c / wgmma_arrive sequence and a wgmma_commit / wgmma_wait sequence.
Calling this for every (j, i, ki) in T.grid(wgmma_num_inst_n, wgmma_num_inst_m, wgmma_num_k_atoms) produces identical TIR to wgmma_rs().
- Parameters:
A_buf (Buffer) – Fragment buffer for operand A (in registers).
desc_b (Buffer) – Initialized B descriptor (from init_wgmma_b_desc).
C_buf (Buffer) – Accumulator fragment buffer.
inst_m_idx (int) – M-dimension atom index (0 .. wgmma_num_inst_m - 1).
inst_n_idx (int) – N-dimension atom index (0 .. wgmma_num_inst_n - 1).
ki (int) – K-dimension atom index (0 .. wgmma_num_k_atoms - 1).
b_params (WGMMADescriptorParams) – Pre-computed B descriptor parameters.
clear_accum (PrimExpr) – Whether to zero the accumulator on the first K atom.
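The iteration order described for wgmma_rs_atom can be modelled in plain Python. The sketch below is an illustration only: `atom_schedule` is a hypothetical helper, `itertools.product` stands in for `T.grid` (both nest the first range outermost), and treating clear_accum as taking effect only when `ki == 0` is an assumption drawn from the parameter description.

```python
from itertools import product


def atom_schedule(num_inst_n, num_inst_m, num_k_atoms, clear_accum=False):
    """List the (inst_m_idx, inst_n_idx, ki, zero_accumulator) atoms in order."""
    schedule = []
    # N is the outermost loop, then M, then K, matching the documented grid.
    for j, i, ki in product(range(num_inst_n), range(num_inst_m),
                            range(num_k_atoms)):
        # Assumed semantics: the accumulator is zeroed only on the first K atom.
        zero_accumulator = bool(clear_accum) and ki == 0
        schedule.append((i, j, ki, zero_accumulator))
    return schedule


# Two atoms along N, one along M, two K micro-steps.
sched = atom_schedule(num_inst_n=2, num_inst_m=1, num_k_atoms=2,
                      clear_accum=True)
```

Each tuple corresponds to one wgmma_rs_atom call; all K micro-steps for a given (inst_m_idx, inst_n_idx) pair are issued back to back before the next output atom starts.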
- wgmma_ss_atom(desc_a, desc_b, C_buf, inst_m_idx, inst_n_idx, ki, a_params, b_params, clear_accum=False)¶
Emit a single WGMMA SS instruction for atom (inst_m_idx, inst_n_idx, ki).
Must be called between fence/arrive and commit/wait sequences.
- Parameters:
desc_a (Buffer) – Initialized A descriptor (from init_wgmma_a_desc).
desc_b (Buffer) – Initialized B descriptor (from init_wgmma_b_desc).
C_buf (Buffer) – Accumulator fragment buffer.
inst_m_idx (int) – M-dimension atom index (0 .. wgmma_num_inst_m - 1).
inst_n_idx (int) – N-dimension atom index (0 .. wgmma_num_inst_n - 1).
ki (int) – K-dimension atom index (0 .. wgmma_num_k_atoms - 1).
a_params (WGMMADescriptorParams) – Pre-computed A descriptor parameters.
b_params (WGMMADescriptorParams) – Pre-computed B descriptor parameters.
clear_accum (PrimExpr) – Whether to zero the accumulator on the first K atom.
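The ordering constraint for the SS variant can be sketched with a recording stub that logs method names instead of emitting TIR. Everything here is illustrative: `RecordingEmitter` and `emit_ss_sequence` are hypothetical, the string arguments are placeholders for real buffers and regions, and the loop order is assumed to mirror the one documented for the RS variant. Only wgmma_fence_c is called in the fence step, on the assumption that wgmma_fence_a applies to a register-resident A fragment, which the SS variant does not use.

```python
class RecordingEmitter:
    """Hypothetical stub: records each method call by name."""

    def __init__(self):
        self.calls = []

    def __getattr__(self, name):
        # Invoked only for attributes that do not exist, i.e. emitter methods.
        def record(*args, **kwargs):
            self.calls.append(name)
            return name + "_result"
        return record


def emit_ss_sequence(emitter, A_region, B_region, C_buf, desc_a, desc_b,
                     num_inst_m, num_inst_n, num_k_atoms):
    # Pure-Python parameter computation, then descriptor initialization.
    a_params = emitter.compute_wgmma_a_desc_params(A_region)
    b_params = emitter.compute_wgmma_b_desc_params(B_region)
    emitter.init_wgmma_a_desc(desc_a, A_region, a_params)
    emitter.init_wgmma_b_desc(desc_b, B_region, b_params)
    # Fence/arrive before the atoms, commit/wait after them.
    emitter.wgmma_fence_c(C_buf)
    emitter.wgmma_arrive()
    for j in range(num_inst_n):
        for i in range(num_inst_m):
            for ki in range(num_k_atoms):
                emitter.wgmma_ss_atom(desc_a, desc_b, C_buf, i, j, ki,
                                      a_params, b_params)
    emitter.wgmma_commit()
    emitter.wgmma_wait(0)


emitter = RecordingEmitter()
emit_ss_sequence(emitter, "A", "B", "C", "desc_a", "desc_b",
                 num_inst_m=1, num_inst_n=1, num_k_atoms=2)
```

With a real emitter these calls would sit inside a TIR macro or kernel body; the stub only makes the required call order visible.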
- make_mma_load_layout(local_buf, matrix='A')¶
Create a layout function for loading an MMA operand into a fragment buffer; the matrix argument selects the operand ('A' by default). The layout maps fragment indices to threads and local indices.
- Parameters:
local_buf (tir.Buffer) – The local buffer representing a fragment of a matrix.
matrix (str)
- Returns:
A fragment object that describes how threads and indices in local_buf are laid out.
- Return type:
T.Fragment
- Raises:
AssertionError – If local_buf is not detected to be a fragment buffer.
- make_mma_store_layout(local_buf)¶
Create a layout function for storing MMA results into a fragment buffer. This layout is used in conjunction with inverse_mma_store_layout to map fragment indices to threads and local indices.
- Parameters:
local_buf (tir.Buffer) – The local buffer representing a fragment of a matrix.
- Returns:
A fragment object that describes how threads and indices in local_buf are laid out.
- Return type:
T.Fragment
- Raises:
AssertionError – If local_buf is not detected to be a fragment buffer.