tilelang.cuda.intrinsics.macro.mma_macro_generator

Attributes

Classes

TensorCoreIntrinEmitter

Emits Tensor Core (MMA) intrinsics for use in TIR macros, so that Python syntax stays out of the macro bodies.

TensorCoreIntrinEmitterWithLadderTransform

Variant of TensorCoreIntrinEmitter with the Ladder transform plugin.

Module Contents

tilelang.cuda.intrinsics.macro.mma_macro_generator.lift
class tilelang.cuda.intrinsics.macro.mma_macro_generator.TensorCoreIntrinEmitter(a_dtype=T.float16, b_dtype=T.float16, accum_dtype=T.float16, a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=8, warp_col_tiles=8, chunk=16, reduce_k=1, num_elems_per_byte=1, is_m_first=False, thread_var=None)

Emits Tensor Core (MMA) intrinsics for use in TIR macros, so that Python syntax stays out of the macro bodies.

Parameters:
  • a_dtype (str)

  • b_dtype (str)

  • accum_dtype (str)

  • a_transposed (bool)

  • b_transposed (bool)

  • block_row_warps (int)

  • block_col_warps (int)

  • warp_row_tiles (int)

  • warp_col_tiles (int)

  • chunk (int)

  • reduce_k (int)

  • num_elems_per_byte (int)

  • is_m_first (bool | None)

  • thread_var (tvm.tirx.Var | None)

M_DIM = 16
n_dim = 16
WARP_SIZE = 32
dtype_abbrv
is_m_first: bool = False
warp_rows: int = 1
warp_cols: int = 1
a_dtype
b_dtype
accum_dtype
a_transposed = False
b_transposed = False
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 8
warp_col_tiles = 8
chunk = 16
reduce_k = 1
threads = 128
num_elems_per_byte = 1
thread_var = None
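The class-level defaults are consistent with threads counting every thread in one block: 32 (WARP_SIZE) × 2 (block_row_warps) × 2 (block_col_warps) = 128. A minimal sketch of that arithmetic follows. It assumes reduce_k also multiplies the thread count and that the block's C tile spans block_row_warps × warp_row_tiles rows by block_col_warps × warp_col_tiles columns; both formulas and the helper names are hypothetical, not taken from the tilelang source.

```python
# Hypothetical helpers modelling how `threads` and the block tile could be
# derived from the constructor arguments. Assumption: one warp per
# (row warp, col warp, reduce_k) slot.
WARP_SIZE = 32


def block_threads(block_row_warps: int = 2, block_col_warps: int = 2,
                  reduce_k: int = 1) -> int:
    """Total threads in one thread block under the assumption above."""
    return WARP_SIZE * block_row_warps * block_col_warps * reduce_k


def block_tile_shape(block_row_warps: int, block_col_warps: int,
                     warp_row_tiles: int, warp_col_tiles: int) -> tuple[int, int]:
    """(rows, cols) of the C tile covered by all warps of the block (assumed)."""
    return block_row_warps * warp_row_tiles, block_col_warps * warp_col_tiles


print(block_threads())                  # 128, matching `threads = 128`
print(block_tile_shape(2, 2, 64, 64))   # (128, 128)
```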
get_thread_binding()
get_store_index_map(inverse=False)
Parameters:

inverse (bool)

Return type:

tvm.tirx.IndexMap

extract_thread_binding(thread_id, is_m_first=None)

is_m_first: True if the thread binding is in the form (tx, warp_n, warp_m), which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]. Otherwise, it is in the form [warp_size, block_col_warps (split m), block_row_warps (split n)].

Parameters:
  • thread_id (tvm.tirx.PrimExpr)

  • is_m_first (bool | None)

Return type:

tuple[tvm.tirx.PrimExpr, tvm.tirx.PrimExpr, tvm.tirx.PrimExpr]
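A plain-integer model of extract_thread_binding can make the two orderings concrete. This sketch assumes the tuples in the docstring list components from fastest- to slowest-varying, that warp_m indexes block_row_warps and warp_n indexes block_col_warps, and that the result comes back as (lane_id, warp_n, warp_m). The real method builds TIR PrimExprs, and its exact index arithmetic may differ.

```python
# Hypothetical pure-Python model of extract_thread_binding; thread_id is a
# plain int here rather than a tvm PrimExpr.
WARP_SIZE = 32


def extract_thread_binding(thread_id: int, block_row_warps: int = 2,
                           block_col_warps: int = 2, is_m_first: bool = False):
    """Split a flat thread id into (lane_id, warp_n, warp_m)."""
    lane_id = thread_id % WARP_SIZE
    warp_id = thread_id // WARP_SIZE
    if is_m_first:
        # (tx, warp_n, warp_m): warp_n varies faster than warp_m.
        warp_n = warp_id % block_col_warps
        warp_m = (warp_id // block_col_warps) % block_row_warps
    else:
        # (tx, warp_m, warp_n): warp_m varies faster than warp_n.
        warp_m = warp_id % block_row_warps
        warp_n = (warp_id // block_row_warps) % block_col_warps
    return lane_id, warp_n, warp_m


print(extract_thread_binding(32, is_m_first=True))   # (0, 1, 0)
print(extract_thread_binding(32, is_m_first=False))  # (0, 0, 1)
```

Either ordering is a bijection over the 128 threads of the default 2 × 2 warp grid; only the warp that thread 32 lands in changes.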

ldmatrix_a(A_local_buf, A_shared_buf, ki, rk=0)
Parameters:
  • A_local_buf (tvm.tirx.Buffer)

  • A_shared_buf (tvm.tirx.Buffer | tvm.tirx.BufferRegion)

  • ki (tvm.tirx.PrimExpr)

  • rk (tvm.tirx.PrimExpr | None)

ldmatrix_b(B_local_buf, B_shared_buf, ki, rk=0)
Parameters:
  • B_local_buf (tvm.tirx.Buffer)

  • B_shared_buf (tvm.tirx.Buffer | tvm.tirx.BufferRegion)

  • ki (tvm.tirx.PrimExpr)

  • rk (tvm.tirx.PrimExpr | None)
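The names ldmatrix_a and ldmatrix_b suggest these methods lower to the PTX ldmatrix instruction, although this page does not say so. Assuming the .m8n8.x4.b16 form, the sketch below models which lane supplies each row address and which elements of an 8x8 matrix each lane receives. The helper names are hypothetical.

```python
# Pure-Python model of PTX `ldmatrix.sync.aligned.m8n8.x4.shared.b16`.
# Assumption: ldmatrix_a / ldmatrix_b lower to this instruction.


def row_address_owner(matrix: int, row: int) -> int:
    """Lane that supplies the shared-memory address of `row` of `matrix` (0..3)."""
    return matrix * 8 + row


def fragment_elements(lane: int) -> list[tuple[int, int]]:
    """(row, col) pairs of an 8x8 matrix that land in one register of `lane`.

    The same pattern fills each of d0..d3, one register per 8x8 matrix;
    each register packs two consecutive 16-bit values.
    """
    row = lane // 4          # groupID
    col = (lane % 4) * 2     # threadID_in_group * 2
    return [(row, col), (row, col + 1)]


print(row_address_owner(matrix=2, row=5))  # 21
print(fragment_elements(5))                # [(1, 2), (1, 3)]
```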

mma(A_local_buf, B_local_buf, C_local_buf, k_inner=0)
Parameters:
  • A_local_buf (tvm.tirx.Buffer)

  • B_local_buf (tvm.tirx.Buffer)

  • C_local_buf (tvm.tirx.Buffer)

  • k_inner (tvm.tirx.PrimExpr | None)

property mma_num_inst_m: int

Number of MMA instruction atoms along the M dimension.

Return type:

int

property mma_num_inst_n: int

Number of MMA instruction atoms along the N dimension.

Return type:

int
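Assuming each warp covers a warp_row_tiles × warp_col_tiles region of C with atoms of M_DIM × n_dim (16 × 16 per the class attributes), the atom counts reduce to integer division. The function below is a hypothetical stand-in for the two properties, not their implementation.

```python
# Hypothetical model of mma_num_inst_m / mma_num_inst_n.
# Assumption: atoms of M_DIM x n_dim tile the warp's region of C exactly.
M_DIM = 16
N_DIM = 16  # mirrors the class attribute `n_dim`


def mma_num_inst(warp_row_tiles: int, warp_col_tiles: int) -> tuple[int, int]:
    """Number of MMA atoms along M and N for one warp."""
    if warp_row_tiles % M_DIM or warp_col_tiles % N_DIM:
        raise ValueError("warp tile must be a multiple of the MMA atom shape")
    return warp_row_tiles // M_DIM, warp_col_tiles // N_DIM


print(mma_num_inst(64, 32))  # (4, 2)
```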

mma_atom(A_local_buf, B_local_buf, C_local_buf, inst_m_idx, inst_n_idx, k_inner=0)

Emit a single MMA atom for tile (inst_m_idx, inst_n_idx).

This is the atomic building block of mma(). Calling this method for every (i, j) in T.grid(mma_num_inst_m, mma_num_inst_n) produces identical TIR to a single mma() call.

Parameters:
  • A_local_buf (Buffer) – Fragment buffer for operand A.

  • B_local_buf (Buffer) – Fragment buffer for operand B.

  • C_local_buf (Buffer) – Accumulator fragment buffer.

  • inst_m_idx (int or PrimExpr) – M-dimension atom index (0 .. mma_num_inst_m - 1).

  • inst_n_idx (int or PrimExpr) – N-dimension atom index (0 .. mma_num_inst_n - 1).

  • k_inner (int or PrimExpr) – K-inner step index used to offset A/B fragments.
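The equivalence stated for mma_atom can be checked numerically with a plain-Python model that works on nested lists instead of emitting TIR. The atom shape, including K_DIM = 16 per k_inner step, and both function bodies are assumptions chosen for illustration.

```python
# Pure-Python model: one mma() call == mma_atom for every (i, j) in the grid.
# Assumption: atom shape M_DIM x N_DIM, consuming K_DIM of K per k_inner step.
M_DIM, N_DIM, K_DIM = 16, 16, 16


def mma_atom(A, B, C, inst_m_idx, inst_n_idx, k_inner=0):
    """C[atom] += A[atom rows, k slice] @ B[k slice, atom cols]."""
    k0 = k_inner * K_DIM
    for i in range(inst_m_idx * M_DIM, (inst_m_idx + 1) * M_DIM):
        for j in range(inst_n_idx * N_DIM, (inst_n_idx + 1) * N_DIM):
            C[i][j] += sum(A[i][k] * B[k][j] for k in range(k0, k0 + K_DIM))


def mma(A, B, C, num_inst_m, num_inst_n, k_inner=0):
    """Whole-tile MMA expressed as a grid of atoms."""
    for i in range(num_inst_m):
        for j in range(num_inst_n):
            mma_atom(A, B, C, i, j, k_inner)


A = [[(i + k) % 5 for k in range(32)] for i in range(32)]
B = [[(k * j) % 3 for j in range(32)] for k in range(32)]
C = [[0] * 32 for _ in range(32)]
for k_inner in range(2):  # two K steps of K_DIM each
    mma(A, B, C, num_inst_m=2, num_inst_n=2, k_inner=k_inner)
reference = [[sum(A[i][k] * B[k][j] for k in range(32)) for j in range(32)]
             for i in range(32)]
print(C == reference)  # True
```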

stmatrix(C_local_buf, C_buf, pid_m=None, pid_n=None)
make_mma_load_layout(local_buf, matrix='A')

Create a layout function for loading an MMA operand (matrix A or B) into a fragment buffer. The returned fragment maps fragment indices to threads and local indices.

Parameters:
  • local_buf (tirx.Buffer) – The local buffer representing a fragment of a matrix.

  • matrix (Literal['A', 'B']) – Which MMA operand the fragment holds: 'A' (the default) or 'B'.

Returns:

A fragment object that describes how threads and indices in local_buf are laid out.

Return type:

T.Fragment

Raises:

AssertionError – If local_buf is not detected to be a fragment buffer.
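As an illustration of the thread/local-index mapping such a fragment describes, the sketch below encodes the A-operand register layout that the PTX ISA specifies for mma.m16n8k16 with 16-bit inputs. Whether make_mma_load_layout(local_buf, matrix='A') produces exactly this mapping is an assumption, and a_fragment_coord is a hypothetical name.

```python
# Pure-Python model of the PTX mma.m16n8k16 (.f16/.bf16) A-operand fragment:
# each of the 32 lanes holds 8 elements a0..a7 of the 16x16 A tile.


def a_fragment_coord(lane: int, i: int) -> tuple[int, int]:
    """(row, col) in the 16x16 A tile of register element a{i} held by `lane`."""
    group_id = lane >> 2       # lane // 4
    tid_in_group = lane % 4
    row = group_id if (i < 2 or 4 <= i < 6) else group_id + 8
    col = tid_in_group * 2 + (i & 1) + (8 if i >= 4 else 0)
    return row, col


print([a_fragment_coord(0, i) for i in range(8)])
# [(0, 0), (0, 1), (8, 0), (8, 1), (0, 8), (0, 9), (8, 8), (8, 9)]
```

Every one of the 256 elements of the tile is owned by exactly one (lane, i) pair, which is the property a fragment layout must guarantee.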

make_mma_store_layout(local_buf)

Create a layout function for storing MMA results into a fragment buffer. This layout is used in conjunction with inverse_mma_store_layout to map fragment indices to threads and local indices.

Parameters:

local_buf (tirx.Buffer) – The local buffer representing a fragment of a matrix.

Returns:

A fragment object that describes how threads and indices in local_buf are laid out.

Return type:

T.Fragment

Raises:

AssertionError – If local_buf is not detected to be a fragment buffer.
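Similarly, the PTX ISA fixes where each accumulator element of a 16x8 mma.m16n8k16 atom lives. The sketch models that forward mapping and its inverse; with n_dim = 16, a 16x16 region would presumably span two such atoms. Treating this as what make_mma_store_layout encodes is an assumption, and inverse_c_coord is a hypothetical helper, not inverse_mma_store_layout.

```python
# Pure-Python model of the PTX mma.m16n8 accumulator (C/D) fragment:
# each lane holds 4 elements c0..c3 of a 16x8 tile.


def c_fragment_coord(lane: int, i: int) -> tuple[int, int]:
    """(row, col) of accumulator element c{i} held by `lane`."""
    group_id, tid_in_group = lane // 4, lane % 4
    row = group_id + (8 if i >= 2 else 0)
    col = tid_in_group * 2 + (i & 1)
    return row, col


def inverse_c_coord(row: int, col: int) -> tuple[int, int]:
    """(lane, local index) that owns element (row, col) of the 16x8 tile."""
    lane = (row % 8) * 4 + col // 2
    local = (row // 8) * 2 + (col % 2)
    return lane, local


print(c_fragment_coord(5, 3))  # (9, 3)
print(inverse_c_coord(9, 3))   # (5, 3)
```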

class tilelang.cuda.intrinsics.macro.mma_macro_generator.TensorCoreIntrinEmitterWithLadderTransform(a_dtype=T.float16, b_dtype=T.float16, accum_dtype=T.float16, a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=8, warp_col_tiles=8, chunk=16, reduce_k=1, num_elems_per_byte=1, is_m_first=False, transform_kind_a=0, transform_kind_b=0)

Bases: TensorCoreIntrinEmitter

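Variant of TensorCoreIntrinEmitter with the Ladder transform plugin, configured per operand through transform_kind_a and transform_kind_b. Like the base class, it keeps Python syntax out of TIR macros.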

Parameters:
  • a_dtype (str)

  • b_dtype (str)

  • accum_dtype (str)

  • a_transposed (bool)

  • b_transposed (bool)

  • block_row_warps (int)

  • block_col_warps (int)

  • warp_row_tiles (int)

  • warp_col_tiles (int)

  • chunk (int)

  • reduce_k (int)

  • num_elems_per_byte (int)

  • is_m_first (bool | None)

  • transform_kind_a (int | tilelang.common.TransformKind)

  • transform_kind_b (int | tilelang.common.TransformKind)

ldmatrix_a(A_local_buf, A_shared_buf, ki, rk=0)
ldmatrix_b(B_local_buf, B_shared_buf, ki, rk=0)
mma(A_local_buf, B_local_buf, C_local_buf)