tilelang.intrinsics.mfma_macro_generator module
- class tilelang.intrinsics.mfma_macro_generator.MatrixCoreIntrinEmitter(a_dtype: str = 'float16', b_dtype: str = 'float16', accum_dtype: str = 'float16', a_transposed: bool = False, b_transposed: bool = False, block_row_warps: int = 2, block_col_warps: int = 2, warp_row_tiles: int = 8, warp_col_tiles: int = 8, chunk: int = 16, reduce_k: int = 1, num_elems_per_byte: int = 1, k_pack: Optional[int] = None, is_m_first: Optional[bool] = False)
Bases: object
Emits AMD Matrix Core (MFMA) intrinsics as TIR macros. The class exists to keep Python syntax out of the bodies of those TIR macros.
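The constructor parameters fix the block-level geometry. The sketch below shows the arithmetic under three assumptions that this reference does not state: every warp runs WARP_SIZE lanes, warp_row_tiles and warp_col_tiles are per-warp element counts along M and N, and chunk is the K extent of one block tile. TileConfig is a hypothetical helper for illustration, not part of tilelang.

```python
from dataclasses import dataclass
from typing import Tuple

WARP_SIZE = 64  # matches MatrixCoreIntrinEmitter.WARP_SIZE


@dataclass
class TileConfig:
    """Hypothetical helper mirroring the emitter's constructor defaults."""
    block_row_warps: int = 2
    block_col_warps: int = 2
    warp_row_tiles: int = 8
    warp_col_tiles: int = 8
    chunk: int = 16

    def threads_per_block(self) -> int:
        # Assumption: one WARP_SIZE-lane warp per cell of the block's warp grid.
        return WARP_SIZE * self.block_row_warps * self.block_col_warps

    def block_tile(self) -> Tuple[int, int, int]:
        # Assumption: warp_*_tiles are element counts per warp and chunk is block K.
        return (
            self.block_row_warps * self.warp_row_tiles,
            self.block_col_warps * self.warp_col_tiles,
            self.chunk,
        )


if __name__ == "__main__":
    big = TileConfig(warp_row_tiles=64, warp_col_tiles=64, chunk=32)
    print(big.threads_per_block(), big.block_tile())
```

Under these assumptions, a 2 x 2 warp grid with 64 x 64 warp tiles covers a 128 x 128 output block with 256 threads.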
- M_DIM = 16
- N_DIM = 16
- WARP_SIZE = 64
- dtype_abbrv = {'bfloat16': 'bf16', 'e4m3_float8': 'e4m3', 'e5m2_float8': 'e5m2', 'float16': 'fp16', 'float32': 'fp32', 'int32': 'int32', 'int8': 'int8'}
- extract_thread_binding(thread_id, is_m_first=None) -> Tuple[PrimExpr, PrimExpr, PrimExpr]
Splits thread_id into three index expressions: the lane within the warp and two warp indices. is_m_first: if True, the thread binding has the form (tx, warp_n, warp_m), which corresponds to [warp_size, block_row_warps (split n), block_col_warps (split m)]. Otherwise, it has the form [warp_size, block_col_warps (split m), block_row_warps (split n)].
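A flat thread index can be split into the three coordinates with plain floor division and modulo. The sketch below assumes the lane index varies fastest, followed by the middle warp axis and then the outer warp axis. split_thread_id is a hypothetical helper operating on Python ints; the real method returns PrimExpr values, and the same // and % expressions should carry over to TIR integer expressions.

```python
from typing import Tuple


def split_thread_id(thread_id: int, warp_size: int,
                    mid_extent: int, outer_extent: int) -> Tuple[int, int, int]:
    """Hypothetical helper: flat thread id -> (lane, mid_warp, outer_warp).

    Assumes the layout [warp_size, mid_extent, outer_extent] with the lane
    index varying fastest.
    """
    lane = thread_id % warp_size
    mid = (thread_id // warp_size) % mid_extent
    outer = (thread_id // (warp_size * mid_extent)) % outer_extent
    return lane, mid, outer


if __name__ == "__main__":
    # With a 2 x 2 warp grid of 64-lane warps, thread 130 is lane 2 of warp (0, 1).
    print(split_thread_id(130, 64, 2, 2))
```

Read against the docstring, is_m_first=True would map (lane, mid, outer) to (tx, warp_n, warp_m). Swapping mid_extent and outer_extent gives the other ordering.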
- get_ldmatrix_index_map(is_b=False)
- is_m_first = False
- k_pack = 1
- ldmatrix_a(A_local_buf, A_shared_buf, ki, rk=0)
- ldmatrix_b(B_local_buf, B_shared_buf, ki, rk=0)
- mfma(A_local_buf, B_local_buf, C_local_buf)
- stmatrix(C_local_buf, C_buf, pid_m=None, pid_n=None)
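The method names suggest a load, multiply-accumulate, store sequence: for each K step ki, ldmatrix_a and ldmatrix_b fill local fragments from shared memory and mfma accumulates into C_local_buf. After the loop, stmatrix writes the accumulator out. An MFMA instruction computes D = A * B + C. The toy NumPy model below reproduces that data flow on whole arrays. ToyMatrixCoreEmitter and block_gemm are hypothetical. They borrow only the method names and assume A_shared is (M, K), B_shared is (K, N), and each ki consumes k_dim elements of K. They ignore the lane-level register layout, transposition, k_pack, and the rk, pid_m and pid_n arguments.

```python
import numpy as np


class ToyMatrixCoreEmitter:
    """Hypothetical NumPy stand-in that mimics the emitter's method names."""

    def __init__(self, k_dim: int = 16):
        self.k_dim = k_dim
        self.log = []  # records the call order

    def ldmatrix_a(self, A_local, A_shared, ki):
        k0 = ki * self.k_dim
        A_local[...] = A_shared[:, k0:k0 + self.k_dim]  # (M, k_dim) slice
        self.log.append(("ldmatrix_a", ki))

    def ldmatrix_b(self, B_local, B_shared, ki):
        k0 = ki * self.k_dim
        B_local[...] = B_shared[k0:k0 + self.k_dim, :]  # (k_dim, N) slice
        self.log.append(("ldmatrix_b", ki))

    def mfma(self, A_local, B_local, C_local):
        # C += A @ B, computed in the accumulator dtype.
        C_local += A_local.astype(C_local.dtype) @ B_local.astype(C_local.dtype)
        self.log.append(("mfma",))

    def stmatrix(self, C_local, C_buf):
        C_buf[...] = C_local
        self.log.append(("stmatrix",))


def block_gemm(A_shared, B_shared, k_dim=16):
    """Hypothetical driver: K loop of load/load/mfma, then a single store."""
    M, K = A_shared.shape
    _, N = B_shared.shape
    assert K % k_dim == 0, "K must be a multiple of k_dim in this sketch"
    emitter = ToyMatrixCoreEmitter(k_dim)
    A_local = np.empty((M, k_dim), dtype=A_shared.dtype)
    B_local = np.empty((k_dim, N), dtype=B_shared.dtype)
    C_local = np.zeros((M, N), dtype=np.float32)
    for ki in range(K // k_dim):
        emitter.ldmatrix_a(A_local, A_shared, ki)
        emitter.ldmatrix_b(B_local, B_shared, ki)
        emitter.mfma(A_local, B_local, C_local)
    C_out = np.empty_like(C_local)
    emitter.stmatrix(C_local, C_out)
    return C_out, emitter.log
```

Because the accumulator lives in float32 while the inputs are float16, the model also shows why accum_dtype is configured separately from a_dtype and b_dtype.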