tilelang.intrinsics.mfma_macro_generator module

class tilelang.intrinsics.mfma_macro_generator.MatrixCoreIntrinEmitter(a_dtype: str = 'float16', b_dtype: str = 'float16', accum_dtype: str = 'float16', a_transposed: bool = False, b_transposed: bool = False, block_row_warps: int = 2, block_col_warps: int = 2, warp_row_tiles: int = 8, warp_col_tiles: int = 8, chunk: int = 16, reduce_k: int = 1, num_elems_per_byte: int = 1, k_pack: Optional[int] = None, is_m_first: Optional[bool] = False)

Bases: object

Helper that emits MFMA (Matrix Core) intrinsic code. Its purpose is to eliminate Python syntax within TIR macros.

M_DIM = 16
N_DIM = 16
WARP_SIZE = 64
dtype_abbrv = {'bfloat16': 'bf16', 'e4m3_float8': 'e4m3', 'e5m2_float8': 'e5m2', 'float16': 'fp16', 'float32': 'fp32', 'int32': 'int32', 'int8': 'int8'}
extract_thread_binding(thread_id, is_m_first=None) → Tuple[PrimExpr, PrimExpr, PrimExpr]

is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m), which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]. Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)].

get_ldmatrix_index_map(is_b=False)
is_m_first = False
k_pack = 1
ldmatrix_a(A_local_buf, A_shared_buf, ki, rk=0)
ldmatrix_b(B_local_buf, B_shared_buf, ki, rk=0)
mfma(A_local_buf, B_local_buf, C_local_buf)
stmatrix(C_local_buf, C_buf, pid_m=None, pid_n=None)