tilelang.intrinsics.mfma_macro_generator

Attributes

lift

Classes

MatrixCoreIntrinEmitter

Generates AMD Matrix Core (MFMA) intrinsic calls for TIR macros, keeping Python-level logic out of the macro bodies.

Module Contents

tilelang.intrinsics.mfma_macro_generator.lift
class tilelang.intrinsics.mfma_macro_generator.MatrixCoreIntrinEmitter(a_dtype='float16', b_dtype='float16', accum_dtype='float16', a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=8, warp_col_tiles=8, chunk=16, reduce_k=1, num_elems_per_byte=1, k_pack=None, is_m_first=False)

Bases: object

Generates AMD Matrix Core (MFMA) intrinsic calls for TIR macros, keeping Python-level logic out of the macro bodies.

Parameters:
  • a_dtype (str)

  • b_dtype (str)

  • accum_dtype (str)

  • a_transposed (bool)

  • b_transposed (bool)

  • block_row_warps (int)

  • block_col_warps (int)

  • warp_row_tiles (int)

  • warp_col_tiles (int)

  • chunk (int)

  • reduce_k (int)

  • num_elems_per_byte (int)

  • k_pack (Optional[int])

  • is_m_first (Optional[bool])

M_DIM = 16
N_DIM = 16
WARP_SIZE = 64
dtype_abbrv
k_pack = 1
is_m_first = False
a_dtype = 'float16'
b_dtype = 'float16'
accum_dtype = 'float16'
a_transposed = False
b_transposed = False
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 8
warp_col_tiles = 8
chunk = 16
warp_rows = 0
warp_cols = 0
reduce_k = 1
threads = 256
num_elems_per_byte = 1
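The default attribute values listed above are consistent with a few simple relations. The sketch below recomputes them in plain Python; the formulas are inferred from the numbers (for example, 64 × 2 × 2 = 256), not copied from the tilelang source, so treat them as assumptions. The name acc_per_lane is hypothetical and assumes each 16×16 output tile is spread evenly across the 64 lanes of a wavefront.

```python
# Hypothetical sketch: relations the default attribute values appear to satisfy.
# These formulas are inferred from the listed numbers, not taken from tilelang.
WARP_SIZE = 64          # lanes per AMD wavefront
M_DIM = N_DIM = 16      # rows/cols of one MFMA output tile
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 8
warp_col_tiles = 8

# Assumed: one thread per lane of every warp in the block.
threads = WARP_SIZE * block_row_warps * block_col_warps

# Assumed: number of 16-wide MFMA tiles per warp along M and N (integer division).
warp_rows = warp_row_tiles // M_DIM
warp_cols = warp_col_tiles // N_DIM

# Assumed: accumulator elements per lane for one 16x16 tile over 64 lanes.
acc_per_lane = (M_DIM * N_DIM) // WARP_SIZE

print(threads, warp_rows, warp_cols, acc_per_lane)  # 256 0 0 4
```

Under these relations, the default warp_row_tiles = 8 gives warp_rows = 0, which suggests that real kernels pass warp tile sizes that are multiples of M_DIM and N_DIM.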
get_ldmatrix_index_map(is_b=False)
extract_thread_binding(thread_id, is_m_first=None)

is_m_first: if True, the thread binding has the form (tx, warp_n, warp_m), which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]. Otherwise, it has the form [warp_size, block_col_warps (split m), block_row_warps (split n)].

Return type:

Tuple[tvm.tir.PrimExpr, tvm.tir.PrimExpr, tvm.tir.PrimExpr]
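The two layouts described for is_m_first amount to a mixed-radix split of a flat thread index. The sketch below models this with plain integers instead of tvm.tir.PrimExpr. It is a hypothetical illustration, not tilelang's implementation: extract_thread_binding_sketch is an invented name, the bracketed lists are read as extents ordered from fastest- to slowest-varying, and the (lane_id, warp_n, warp_m) return order is an assumption.

```python
from typing import Tuple

WARP_SIZE = 64  # lanes per wavefront, matching the WARP_SIZE attribute


def extract_thread_binding_sketch(
    thread_id: int,
    block_row_warps: int = 2,
    block_col_warps: int = 2,
    is_m_first: bool = False,
) -> Tuple[int, int, int]:
    """Hypothetical plain-int model; returns (lane_id, warp_n, warp_m)."""
    lane_id = thread_id % WARP_SIZE   # fastest-varying: lane within the warp
    warp_id = thread_id // WARP_SIZE  # which warp in the block
    if is_m_first:
        # [warp_size, block_row_warps (split n), block_col_warps (split m)]
        warp_n = warp_id % block_row_warps
        warp_m = (warp_id // block_row_warps) % block_col_warps
    else:
        # [warp_size, block_col_warps (split m), block_row_warps (split n)]
        warp_m = warp_id % block_col_warps
        warp_n = (warp_id // block_col_warps) % block_row_warps
    return lane_id, warp_n, warp_m


print(extract_thread_binding_sketch(70))                   # (6, 0, 1)
print(extract_thread_binding_sketch(70, is_m_first=True))  # (6, 1, 0)
```

With the default 2 × 2 warps, thread 70 is lane 6 of warp 1, and the flag only changes whether that warp index is counted along M or along N first.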

ldmatrix_a(A_local_buf, A_shared_buf, ki, rk=0)
ldmatrix_b(B_local_buf, B_shared_buf, ki, rk=0)
mfma(A_local_buf, B_local_buf, C_local_buf)
stmatrix(C_local_buf, C_buf, pid_m=None, pid_n=None)
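The method names suggest the usual load, multiply-accumulate, store pattern: fragments of A and B are loaded for each K step ki, mfma accumulates into C_local_buf, and stmatrix writes the result out. The pure-Python reference model below shows only that dataflow on small dense matrices. It is not tilelang code: tile_k_loop, matmul_acc, and k_step are hypothetical names, and the model assumes no transposes and a K extent divisible by k_step.

```python
from typing import List

Matrix = List[List[float]]


def matmul_acc(c: Matrix, a: Matrix, b: Matrix) -> None:
    """Accumulate C += A @ B in place on small row-major matrices."""
    for i in range(len(a)):
        for j in range(len(b[0])):
            c[i][j] += sum(a[i][k] * b[k][j] for k in range(len(b)))


def tile_k_loop(a: Matrix, b: Matrix, k_step: int) -> Matrix:
    """Hypothetical reference for a ldmatrix_a / ldmatrix_b / mfma loop over ki."""
    m, k_total, n = len(a), len(b), len(b[0])
    c = [[0.0] * n for _ in range(m)]  # stands in for C_local_buf
    for ki in range(k_total // k_step):
        k0 = ki * k_step
        a_frag = [row[k0:k0 + k_step] for row in a]  # stands in for ldmatrix_a
        b_frag = b[k0:k0 + k_step]                   # stands in for ldmatrix_b
        matmul_acc(c, a_frag, b_frag)                # stands in for mfma
    return c  # a real kernel would now write C out, as stmatrix does
```

Splitting K into steps does not change the result, because matrix multiplication accumulates independently over each slice of the K dimension.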