tilelang.intrinsics.mma_macro_generator

Attributes

lift

Classes

TensorCoreIntrinEmitter

Generates Tensor Core (MMA) intrinsics as TIR macros, keeping Python syntax out of the TIR macro body.

TensorCoreIntrinEmitterWithLadderTransform

Variant of TensorCoreIntrinEmitter with the Ladder Transform plugin.

INT4TensorCoreIntrinEmitter

Variant of TensorCoreIntrinEmitter for INT4 operands.

INT4TensorCoreIntrinEmitterWithLadderTransform

INT4 variant of TensorCoreIntrinEmitterWithLadderTransform.

Module Contents

tilelang.intrinsics.mma_macro_generator.lift

class tilelang.intrinsics.mma_macro_generator.TensorCoreIntrinEmitter(a_dtype='float16', b_dtype='float16', accum_dtype='float16', a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=8, warp_col_tiles=8, chunk=16, reduce_k=1, num_elems_per_byte=1, is_m_first=False)

Bases: object

Generates Tensor Core (MMA) intrinsics as TIR macros, keeping Python syntax out of the TIR macro body.

Parameters:
  • a_dtype (str)

  • b_dtype (str)

  • accum_dtype (str)

  • a_transposed (bool)

  • b_transposed (bool)

  • block_row_warps (int)

  • block_col_warps (int)

  • warp_row_tiles (int)

  • warp_col_tiles (int)

  • chunk (int)

  • reduce_k (int)

  • num_elems_per_byte (int)

  • is_m_first (Optional[bool])

M_DIM = 16
N_DIM = 16
WARP_SIZE = 32
dtype_abbrv
is_m_first = False
a_dtype = 'float16'
b_dtype = 'float16'
accum_dtype = 'float16'
a_transposed = False
b_transposed = False
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 8
warp_col_tiles = 8
chunk = 16
warp_rows = 0
warp_cols = 0
reduce_k = 1
threads = 128
num_elems_per_byte = 1
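The default values above are mutually consistent if the derived attributes are computed as in the sketch below. The formulas are an assumption inferred from the listed numbers, not the library's code: threads counts one warp of WARP_SIZE threads per slot in the block_row_warps × block_col_warps grid (with reduce_k assumed to multiply it), and warp_rows / warp_cols count how many M_DIM × N_DIM MMA tiles fit in a warp tile. Under that reading, the default warp_row_tiles = 8 gives 8 // 16 = 0, which matches warp_rows = 0. DerivedConfig is a hypothetical name.

```python
from dataclasses import dataclass

# Constants as listed for TensorCoreIntrinEmitter.
M_DIM = 16
N_DIM = 16
WARP_SIZE = 32


@dataclass
class DerivedConfig:
    """Hypothetical helper: derives per-block quantities from constructor arguments."""
    block_row_warps: int = 2
    block_col_warps: int = 2
    warp_row_tiles: int = 8
    warp_col_tiles: int = 8
    reduce_k: int = 1

    @property
    def threads(self) -> int:
        # One warp per (row, col) warp slot; reduce_k is ASSUMED to add
        # further warps that split the K reduction.
        return WARP_SIZE * self.block_row_warps * self.block_col_warps * self.reduce_k

    @property
    def warp_rows(self) -> int:
        # Number of M_DIM-row MMA tiles per warp (integer division).
        return self.warp_row_tiles // M_DIM

    @property
    def warp_cols(self) -> int:
        # Number of N_DIM-column MMA tiles per warp (integer division).
        return self.warp_col_tiles // N_DIM


defaults = DerivedConfig()
print(defaults.threads, defaults.warp_rows, defaults.warp_cols)  # 128 0 0

larger = DerivedConfig(warp_row_tiles=64, warp_col_tiles=64)
print(larger.threads, larger.warp_rows, larger.warp_cols)  # 128 4 4
```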

get_store_index_map(inverse=False)
Parameters:

inverse (bool)

Return type:

tvm.tir.IndexMap
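To illustrate the kind of relation a store index map encodes, the sketch below uses the register layout that the PTX ISA documents for the mma.m16n8k16 accumulator, applied to a 16 × 16 tile made of two n8 halves: each of the 32 lanes holds 16 * 16 / 32 = 8 values. Whether get_store_index_map produces exactly this map is an assumption, and both function names are hypothetical; the second function plays the role of the inverse=True direction.

```python
from typing import Tuple

WARP_SIZE = 32
LOCAL_SIZE = 16 * 16 // WARP_SIZE  # 8 values per thread for a 16x16 tile


def thread_local_to_tile(thread_id: int, local_id: int) -> Tuple[int, int]:
    """Hypothetical: map (lane, register index) to (row, col) in a 16x16 tile."""
    # Registers 0-1 and 4-5 sit in the upper 8 rows, 2-3 and 6-7 in the lower 8.
    row = 8 * ((local_id % 4) // 2) + thread_id // 4
    # Registers 0-3 cover columns 0-7, registers 4-7 cover columns 8-15.
    col = 8 * (local_id // 4) + 2 * (thread_id % 4) + local_id % 2
    return row, col


def tile_to_thread_local(row: int, col: int) -> Tuple[int, int]:
    """Hypothetical inverse: map (row, col) back to (lane, register index)."""
    thread_id = 4 * (row % 8) + (col % 8) // 2
    local_id = 4 * (col // 8) + 2 * (row // 8) + col % 2
    return thread_id, local_id


print(thread_local_to_tile(5, 2))   # (9, 2)
print(tile_to_thread_local(9, 2))   # (5, 2)
```

Because the two functions are exact inverses, every one of the 256 tile elements is owned by exactly one (lane, register) pair.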

extract_thread_binding(thread_id, is_m_first=None)

is_m_first: if True, the thread binding takes the form (tx, warp_n, warp_m), which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]. Otherwise, it takes the form [warp_size, block_col_warps (split m), block_row_warps (split n)].

Parameters:
  • thread_id (tvm.tir.PrimExpr)

  • is_m_first (Optional[bool])

Return type:

Tuple[tvm.tir.PrimExpr, tvm.tir.PrimExpr, tvm.tir.PrimExpr]
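The decomposition can be modeled with plain integers as a mixed-radix split of the flat thread id: the lane comes first, then the faster-varying warp index, then the slower one. The sketch below follows the tuple order in the description (with is_m_first, warp_n varies fastest after the lane). It assumes warp_m ranges over block_row_warps and warp_n over block_col_warps, and that passing None would fall back to the emitter's is_m_first attribute (not modeled here). split_thread_id and its (lane_id, warp_m, warp_n) return order are hypothetical and may differ from the library's PrimExpr-based method. With the default 2 × 2 warp grid both extents are 2, so the example does not depend on that pairing.

```python
from typing import Tuple

WARP_SIZE = 32


def split_thread_id(thread_id: int, block_row_warps: int, block_col_warps: int,
                    is_m_first: bool) -> Tuple[int, int, int]:
    """Hypothetical integer model: returns (lane_id, warp_m, warp_n)."""
    lane_id = thread_id % WARP_SIZE
    warp_id = thread_id // WARP_SIZE
    if is_m_first:
        # (tx, warp_n, warp_m): warp_n varies fastest across consecutive warps.
        warp_n = warp_id % block_col_warps
        warp_m = (warp_id // block_col_warps) % block_row_warps
    else:
        # (tx, warp_m, warp_n): warp_m varies fastest across consecutive warps.
        warp_m = warp_id % block_row_warps
        warp_n = (warp_id // block_row_warps) % block_col_warps
    return lane_id, warp_m, warp_n


print(split_thread_id(32, 2, 2, is_m_first=False))  # (0, 1, 0)
print(split_thread_id(32, 2, 2, is_m_first=True))   # (0, 0, 1)
```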

ldmatrix_a(A_local_buf, A_shared_buf, ki, rk=0)
Parameters:
  • A_local_buf (tvm.tir.Buffer)

  • A_shared_buf (tvm.tir.Buffer)

  • ki (tvm.tir.PrimExpr)

  • rk (Optional[tvm.tir.PrimExpr])

ldmatrix_b(B_local_buf, B_shared_buf, ki, rk=0)
Parameters:
  • B_local_buf (tvm.tir.Buffer)

  • B_shared_buf (tvm.tir.Buffer)

  • ki (tvm.tir.PrimExpr)

  • rk (Optional[tvm.tir.PrimExpr])

mma(A_local_buf, B_local_buf, C_local_buf, k_inner=0)
Parameters:
  • A_local_buf (tvm.tir.Buffer)

  • B_local_buf (tvm.tir.Buffer)

  • C_local_buf (tvm.tir.Buffer)

  • k_inner (Optional[tvm.tir.PrimExpr])
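Judging from the ki argument, ldmatrix_a and ldmatrix_b appear to be called once per K step, followed by mma on the loaded fragments. The NumPy sketch below models only the arithmetic of that loop for a single warp: no shared memory, registers, or thread layout. The per-step K extent of 16 (the K of the PTX mma.m16n8k16 shape for float16) and the name warp_gemm_model are assumptions for illustration, and operands are taken as non-transposed.

```python
import numpy as np

M_DIM, N_DIM = 16, 16
K_DIM = 16  # ASSUMED per-MMA K extent for float16 operands (mma.m16n8k16)


def warp_gemm_model(A: np.ndarray, B: np.ndarray, warp_rows: int, warp_cols: int,
                    chunk: int) -> np.ndarray:
    """Arithmetic-only model of one warp's load/mma loop over a K chunk.

    A: (warp_rows*M_DIM, chunk), B: (chunk, warp_cols*N_DIM), both non-transposed.
    """
    C = np.zeros((warp_rows * M_DIM, warp_cols * N_DIM), dtype=np.float32)
    for ki in range(chunk // K_DIM):
        k0 = ki * K_DIM
        # Stand-in for ldmatrix_a: one M_DIM x K_DIM fragment per row tile.
        A_frags = [A[i * M_DIM:(i + 1) * M_DIM, k0:k0 + K_DIM] for i in range(warp_rows)]
        # Stand-in for ldmatrix_b: one K_DIM x N_DIM fragment per column tile.
        B_frags = [B[k0:k0 + K_DIM, j * N_DIM:(j + 1) * N_DIM] for j in range(warp_cols)]
        # Stand-in for mma: accumulate every (i, j) tile product into C in float32.
        for i in range(warp_rows):
            for j in range(warp_cols):
                C[i * M_DIM:(i + 1) * M_DIM, j * N_DIM:(j + 1) * N_DIM] += (
                    A_frags[i].astype(np.float32) @ B_frags[j].astype(np.float32)
                )
    return C


rng = np.random.default_rng(0)
A = rng.standard_normal((32, 32)).astype(np.float16)
B = rng.standard_normal((32, 32)).astype(np.float16)
C = warp_gemm_model(A, B, warp_rows=2, warp_cols=2, chunk=32)
print(np.allclose(C, A.astype(np.float32) @ B.astype(np.float32), atol=1e-3))  # True
```

Accumulating in float32 here mirrors the common choice of a wider accum_dtype; the emitter's own default accum_dtype is 'float16'.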

stmatrix(C_local_buf, C_buf, pid_m=None, pid_n=None)

make_mma_load_layout(local_buf, matrix='A')

Create a layout function for loading an MMA input operand (matrix A or B) into a fragment buffer. The layout maps fragment indices to threads and local indices.

Parameters:
  • local_buf (tir.Buffer) – The local buffer representing a fragment of a matrix.

  • matrix (Literal['A', 'B']) – The MMA operand the layout describes: 'A' or 'B'.

Returns:

A fragment object that describes how threads and indices in local_buf are laid out.

Return type:

T.Fragment

Raises:

AssertionError – If local_buf is not detected to be a fragment buffer.

make_mma_store_layout(local_buf)

Create a layout function for storing MMA results into a fragment buffer. This layout is used in conjunction with inverse_mma_store_layout to map fragment indices to threads and local indices.

Parameters:

local_buf (tir.Buffer) – The local buffer representing a fragment of a matrix.

Returns:

A fragment object that describes how threads and indices in local_buf are laid out.

Return type:

T.Fragment

Raises:

AssertionError – If local_buf is not detected to be a fragment buffer.

class tilelang.intrinsics.mma_macro_generator.TensorCoreIntrinEmitterWithLadderTransform(a_dtype='float16', b_dtype='float16', accum_dtype='float16', a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=8, warp_col_tiles=8, chunk=16, reduce_k=1, num_elems_per_byte=1, is_m_first=False, transform_kind_a=0, transform_kind_b=0)

Bases: TensorCoreIntrinEmitter

Variant of TensorCoreIntrinEmitter with the Ladder Transform plugin; transform_kind_a and transform_kind_b select the transform applied to operands A and B.

Parameters:
  • a_dtype (str)

  • b_dtype (str)

  • accum_dtype (str)

  • a_transposed (bool)

  • b_transposed (bool)

  • block_row_warps (int)

  • block_col_warps (int)

  • warp_row_tiles (int)

  • warp_col_tiles (int)

  • chunk (int)

  • reduce_k (int)

  • num_elems_per_byte (int)

  • is_m_first (Optional[bool])

  • transform_kind_a (Union[int, tilelang.common.TransformKind])

  • transform_kind_b (Union[int, tilelang.common.TransformKind])

ldmatrix_a(A_local_buf, A_shared_buf, ki, rk=0)

ldmatrix_b(B_local_buf, B_shared_buf, ki, rk=0)

mma(A_local_buf, B_local_buf, C_local_buf)

class tilelang.intrinsics.mma_macro_generator.INT4TensorCoreIntrinEmitter(a_dtype='float16', b_dtype='float16', accum_dtype='float16', a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=8, warp_col_tiles=8, chunk=16, reduce_k=1, num_elems_per_byte=1, is_m_first=False)

Bases: TensorCoreIntrinEmitter

Variant of TensorCoreIntrinEmitter for INT4 operands; it overrides mma.

Parameters:
  • a_dtype (str)

  • b_dtype (str)

  • accum_dtype (str)

  • a_transposed (bool)

  • b_transposed (bool)

  • block_row_warps (int)

  • block_col_warps (int)

  • warp_row_tiles (int)

  • warp_col_tiles (int)

  • chunk (int)

  • reduce_k (int)

  • num_elems_per_byte (int)

  • is_m_first (Optional[bool])

mma(A_local_buf, B_local_buf, C_local_buf)
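num_elems_per_byte describes how many elements share one byte of storage; for 4-bit integers that is 2. The sketch below packs and unpacks signed 4-bit values two per byte. The low-nibble-first order and the helper names pack_int4 / unpack_int4 are assumptions for illustration, not necessarily the layout INT4TensorCoreIntrinEmitter expects.

```python
from typing import List

NUM_ELEMS_PER_BYTE = 2  # two 4-bit values per 8-bit byte


def pack_int4(values: List[int]) -> bytes:
    """Hypothetical helper: pack signed int4 values (-8..7), low nibble first (ASSUMED order)."""
    if len(values) % NUM_ELEMS_PER_BYTE:
        raise ValueError("length must be a multiple of 2")
    out = bytearray()
    for lo, hi in zip(values[0::2], values[1::2]):
        for v in (lo, hi):
            if not -8 <= v <= 7:
                raise ValueError(f"{v} does not fit in int4")
        # Two's-complement nibbles: mask to 4 bits, put the second value on top.
        out.append((lo & 0xF) | ((hi & 0xF) << 4))
    return bytes(out)


def unpack_int4(data: bytes) -> List[int]:
    """Inverse of pack_int4: sign-extend each nibble back to a Python int."""
    result = []
    for byte in data:
        for nibble in (byte & 0xF, byte >> 4):
            result.append(nibble - 16 if nibble >= 8 else nibble)
    return result


packed = pack_int4([1, -2, 7, -8])
print(packed.hex())          # e187
print(unpack_int4(packed))   # [1, -2, 7, -8]
```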

class tilelang.intrinsics.mma_macro_generator.INT4TensorCoreIntrinEmitterWithLadderTransform(a_dtype='float16', b_dtype='float16', accum_dtype='float16', a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=8, warp_col_tiles=8, chunk=16, reduce_k=1, num_elems_per_byte=1, is_m_first=False, transform_kind_a=0, transform_kind_b=0)

Bases: TensorCoreIntrinEmitterWithLadderTransform

INT4 variant of TensorCoreIntrinEmitterWithLadderTransform; it overrides mma.

Parameters:
  • a_dtype (str)

  • b_dtype (str)

  • accum_dtype (str)

  • a_transposed (bool)

  • b_transposed (bool)

  • block_row_warps (int)

  • block_col_warps (int)

  • warp_row_tiles (int)

  • warp_col_tiles (int)

  • chunk (int)

  • reduce_k (int)

  • num_elems_per_byte (int)

  • is_m_first (Optional[bool])

  • transform_kind_a (Union[int, tilelang.common.TransformKind])

  • transform_kind_b (Union[int, tilelang.common.TransformKind])

mma(A_local_buf, B_local_buf, C_local_buf)