tilelang.intrinsics.wmma_macro_generator
WMMA intrinsic emitter for AMD RDNA architectures (gfx11 / gfx12).
Only supports the f16->f32, 16x16x16 variant with warp-size=32.
- Thread-data mapping (per AMDGPU ISA):
  - A[16][K=16]: thread t holds A[t//2][(t%2)*8 : (t%2)*8+8] (8 fp16 values, packed into four 32-bit registers per thread)
  - B[K=16][16]: same mapping as A, applied to the transposed dimension
  - C/D[16][16]: thread t holds D[t//2][(t%2)*8 : (t%2)*8+8] (8 f32 values per thread)
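A minimal sketch that transcribes the mapping above exactly as stated and checks the arithmetic behind it: 16 × 16 = 256 elements spread over 32 threads gives 8 elements per thread, and 8 fp16 values occupy four 32-bit registers. The helper `elements_of_thread` is hypothetical, and the sketch is not checked against the hardware ISA or against the gfx12 lane ordering described under `get_ldmatrix_index_map`, which uses a different formula.

```python
# Hypothetical helper: transcribes the module docstring's thread-data mapping.
WARP_SIZE = 32
TILE_ROWS, TILE_COLS = 16, 16
PER_THREAD = TILE_ROWS * TILE_COLS // WARP_SIZE  # 256 / 32 = 8 elements per thread


def elements_of_thread(t):
    """Return the (row, col) pairs thread t holds: row t // 2, cols (t % 2) * 8 .. +7."""
    row = t // 2
    base = (t % 2) * PER_THREAD
    return [(row, base + i) for i in range(PER_THREAD)]


# Register accounting: 8 fp16 values = 128 bits = four 32-bit registers (A/B);
# 8 f32 values = eight 32-bit registers (C/D).
A_REGS_PER_THREAD = PER_THREAD * 16 // 32
D_REGS_PER_THREAD = PER_THREAD * 32 // 32

# Every tile element is held by exactly one thread.
all_elements = [e for t in range(WARP_SIZE) for e in elements_of_thread(t)]
assert sorted(all_elements) == [(r, c) for r in range(TILE_ROWS) for c in range(TILE_COLS)]
```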
Attributes
- lift

Classes
- WMMAIntrinEmitter: Intrinsic emitter for AMD RDNA WMMA (16×16×16, warp-size=32).
Module Contents
- tilelang.intrinsics.wmma_macro_generator.lift
- class tilelang.intrinsics.wmma_macro_generator.WMMAIntrinEmitter(a_dtype='float16', b_dtype='float16', accum_dtype='float32', a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=16, warp_col_tiles=16, chunk=16, k_pack=1, thread_var=None, target=None)
Intrinsic emitter for AMD RDNA WMMA (16×16×16, warp-size=32).
- Supports:
 fp16 -> fp32 (f32_16x16x16_f16_w32 on gfx11, f32_16x16x16_f16_w32_gfx12 on gfx12)
- Attributes (values for the default configuration):
- M_DIM = 16
- N_DIM = 16
- K_DIM = 16
- WARP_SIZE = 32
- a_dtype = 'float16'
- b_dtype = 'float16'
- accum_dtype = 'float32'
- a_transposed = False
- b_transposed = False
- block_row_warps = 2
- block_col_warps = 2
- warp_row_tiles = 16
- warp_col_tiles = 16
- chunk = 16
- k_pack = 1
- thread_var = None
- target = None
- micro_size_x = 16
- micro_size_y = 16
- micro_size_k = 16
- local_size_a = 8
- local_size_b = 8
- local_size_out = 8
- warp_rows = 1
- warp_cols = 1
- threads = 128
- wmma_shape = 'f32_16x16x16_f16_w32'
- get_thread_binding()
- Return type:
tvm.tir.PrimExpr
- extract_thread_binding(thread_id)
Return (lane_id, warp_n, warp_m).
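The default `threads` value of 128 equals 32 × 2 × 2 (`WARP_SIZE` × `block_row_warps` × `block_col_warps`). A minimal sketch of how a flat thread index could split into the returned triple, assuming `warp_n` varies fastest across consecutive warps; the emitter's actual ordering is not documented here, so treat `extract_thread_binding_sketch` as hypothetical.

```python
WARP_SIZE = 32


def extract_thread_binding_sketch(thread_id, block_row_warps=2, block_col_warps=2):
    """Hypothetical split of a flat thread index into (lane_id, warp_n, warp_m).

    ASSUMPTION: warp_n varies fastest; the real emitter may use another order.
    """
    lane_id = thread_id % WARP_SIZE
    warp_id = thread_id // WARP_SIZE
    warp_n = warp_id % block_col_warps
    warp_m = (warp_id // block_col_warps) % block_row_warps
    return lane_id, warp_n, warp_m


# Thread 97 is lane 1 of warp 3, i.e. warp_n = 1, warp_m = 1 with a 2 x 2 warp grid.
print(extract_thread_binding_sketch(97))  # (1, 1, 1)
```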
- get_ldmatrix_index_map(is_b=False)
Return (forward, reverse) index maps for shared→local loading.
- For WMMA gfx12:
A is stored row-major [M, K]. Thread t loads A[t%16][(t//16)*8+local].
B (non-transposed) is stored row-major [K, N]. Thread t loads B[t%16][(t//16)*8+local] (same shape/pattern as A).
B (transposed) is stored [N, K]. Thread t loads B_T[t%16][(t//16)*8+local] (N-row, K-col).
- Parameters:
is_b (bool)
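The gfx12 load pattern above can be written as a plain-Python forward/reverse pair. This sketch only restates the formulas from the docstring (lanes 0 to 15 take columns 0 to 7 of their row, lanes 16 to 31 take columns 8 to 15); `load_forward` and `load_reverse` are illustrative names, not the index-map objects the method returns.

```python
PER_THREAD = 8  # local_size_a / local_size_b


def load_forward(t, local):
    """(thread, local) -> (row, col) of the stored 16x16 operand tile (gfx12 pattern)."""
    return t % 16, (t // 16) * PER_THREAD + local


def load_reverse(row, col):
    """(row, col) -> (thread, local); inverse of load_forward."""
    return (col // PER_THREAD) * 16 + row, col % PER_THREAD


# Round trip over every lane and local slot.
for t in range(32):
    for local in range(PER_THREAD):
        assert load_reverse(*load_forward(t, local)) == (t, local)

# Lane 17 reads row 1, columns 8..15 of A (or of B / B_T in their stored layout).
print([load_forward(17, local) for local in range(PER_THREAD)])
```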
- get_store_index_map(inverse=False)
Return the store index map.
The forward map is (thread_id, local_id) -> (i, j), which is affine. The inverse map is (i, j) -> (thread_id, local_id).
- Parameters:
inverse (bool)
- Return type:
tvm.tir.IndexMap
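As an illustration of the forward/inverse pair, the following sketch uses the C/D mapping stated in the module docstring (thread t holds row t // 2, columns (t % 2) * 8 onward). The function names are hypothetical, and the map actually built by `get_store_index_map` may follow a different lane ordering on gfx12.

```python
def store_forward(thread_id, local_id):
    """(thread_id, local_id) -> (i, j), per the module docstring's C/D mapping."""
    return thread_id // 2, (thread_id % 2) * 8 + local_id


def store_inverse(i, j):
    """(i, j) -> (thread_id, local_id); inverse of store_forward."""
    return i * 2 + j // 8, j % 8


# The pair is a bijection between 32 threads x 8 slots and the 16 x 16 tile.
assert all(
    store_inverse(*store_forward(t, l)) == (t, l) for t in range(32) for l in range(8)
)
assert all(store_forward(*store_inverse(i, j)) == (i, j) for i in range(16) for j in range(16))
```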
- ldmatrix_a(A_local_buf, A_shared_buf, ki, rk=0)
- ldmatrix_b(B_local_buf, B_shared_buf, ki, rk=0)
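Neither `ldmatrix_a` nor `ldmatrix_b` is described further here. As a rough model only, the sketch below fills one warp's A fragments from a shared tile using the gfx12 pattern from `get_ldmatrix_index_map`. It assumes `A_shared` is a row-major [16, chunk] tile and that `ki` selects the ki-th 16-wide K slice; it ignores `rk` and warp offsets. `ldmatrix_a_sketch` is hypothetical.

```python
MICRO_K = 16      # micro_size_k
PER_THREAD = 8    # local_size_a
WARP_SIZE = 32


def ldmatrix_a_sketch(A_shared, ki):
    """Return a [32][8] list: lane t's A fragment for K slice `ki`.

    ASSUMPTIONS: A_shared is row-major [16][chunk]; ki picks columns
    ki*16 .. ki*16+15; the `rk` argument and warp offsets are not modelled.
    """
    A_local = []
    for t in range(WARP_SIZE):
        row = t % 16
        col0 = ki * MICRO_K + (t // 16) * PER_THREAD
        A_local.append([A_shared[row][col0 + local] for local in range(PER_THREAD)])
    return A_local


# Example: a 16 x 32 shared tile (chunk = 32) whose entries encode row * 100 + col.
A_shared = [[r * 100 + c for c in range(32)] for r in range(16)]
frags = ldmatrix_a_sketch(A_shared, ki=1)
print(frags[17])  # lane 17: row 1, columns 24..31
```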
- wmma(A_local_buf, B_local_buf, C_local_buf, k_inner=0)
- Parameters:
A_local_buf (tvm.tir.Buffer)
B_local_buf (tvm.tir.Buffer)
C_local_buf (tvm.tir.Buffer)
k_inner (tvm.tir.PrimExpr | None)
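The underlying `f32_16x16x16_f16` WMMA instruction computes D = A × B + C on 16 × 16 × 16 tiles, with fp16 inputs and an fp32 accumulator. A plain-Python reference for checking results might look like this; it rounds inputs to fp16 via the `struct` module's `'e'` format but accumulates in double precision, so it is a value reference rather than a bit-exact model. `wmma_reference` is a hypothetical helper and does not model `k_inner`.

```python
import struct

N = 16  # M_DIM = N_DIM = K_DIM = 16


def to_f16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def wmma_reference(A, B, C):
    """Return D = A @ B + C for 16x16 tiles, with A and B rounded to fp16.

    Accumulation uses Python floats (double precision), not fp32.
    """
    A16 = [[to_f16(v) for v in row] for row in A]
    B16 = [[to_f16(v) for v in row] for row in B]
    return [
        [C[i][j] + sum(A16[i][k] * B16[k][j] for k in range(N)) for j in range(N)]
        for i in range(N)
    ]


# Identity times B plus zero must reproduce B (all values are exact in fp16).
identity = [[1.0 if i == j else 0.0 for j in range(N)] for i in range(N)]
B = [[float(16 * k + j) for j in range(N)] for k in range(N)]
zeros = [[0.0] * N for _ in range(N)]
D = wmma_reference(identity, B, zeros)
```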
- stmatrix(C_local_buf, C_buf, pid_m=None, pid_n=None)
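`stmatrix` presumably writes the accumulator fragments back to `C_buf`. A rough model, assuming the C/D mapping from the module docstring and that `pid_m` / `pid_n` pick a 16 × 16 tile of `C_buf` (offsets `pid_m * 16`, `pid_n * 16`); both are assumptions, and `stmatrix_sketch` is hypothetical. Under that mapping, lane t's slot l lands at row t // 2, column (t % 2) * 8 + l, so consecutive (lane, slot) values fill the tile in row-major order.

```python
def stmatrix_sketch(C_local, C, pid_m=0, pid_n=0):
    """Scatter C_local[lane][slot] (32 x 8) into the 16x16 tile of C at (pid_m, pid_n).

    ASSUMPTIONS: module-docstring C/D mapping; pid_m/pid_n are tile indices.
    """
    for t in range(32):
        for local in range(8):
            i = t // 2
            j = (t % 2) * 8 + local
            C[pid_m * 16 + i][pid_n * 16 + j] = C_local[t][local]
    return C


# Lane t, slot l holds the value t * 8 + l; store it into tile (1, 0) of a 32 x 32 C.
C_local = [[t * 8 + l for l in range(8)] for t in range(32)]
C = stmatrix_sketch(C_local, [[0] * 32 for _ in range(32)], pid_m=1, pid_n=0)
```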
- make_wmma_load_layout(local_buf, matrix='A')
- Parameters:
local_buf (tvm.tir.Buffer)
matrix (Literal['A', 'B'])
- Return type:
tilelang.language.Fragment
- make_wmma_store_layout(local_buf)
- Parameters:
local_buf (tvm.tir.Buffer)
- Return type:
tilelang.language.Fragment