tilelang.intrinsics.mma_macro_generator module#
- class tilelang.intrinsics.mma_macro_generator.INT4TensorCoreIntrinEmitter(a_dtype: str = 'float16', b_dtype: str = 'float16', accum_dtype: str = 'float16', a_transposed: bool = False, b_transposed: bool = False, block_row_warps: int = 2, block_col_warps: int = 2, warp_row_tiles: int = 8, warp_col_tiles: int = 8, chunk: int = 16, reduce_k: int = 1, num_elems_per_byte: int = 1, is_m_first: Optional[bool] = False)#
Bases: TensorCoreIntrinEmitter
- mma(A_local_buf, B_local_buf, C_local_buf)#
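For int4 operands, two elements pack into one byte, which is what the num_elems_per_byte parameter expresses. A minimal construction sketch follows; the storage dtype strings and tile sizes are illustrative assumptions, not a tested configuration.

```python
# Sketch: emitter for an int4 GEMM. Two int4 values pack into one byte,
# hence num_elems_per_byte=2. Dtype strings and tile sizes are assumptions.
from tilelang.intrinsics.mma_macro_generator import INT4TensorCoreIntrinEmitter

int4_emitter = INT4TensorCoreIntrinEmitter(
    a_dtype="int8",           # assumed storage dtype of the packed int4 data
    b_dtype="int8",
    accum_dtype="int32",
    block_row_warps=2,
    block_col_warps=2,
    warp_row_tiles=32,
    warp_col_tiles=32,
    chunk=32,
    num_elems_per_byte=2,     # two int4 elements per byte
)
```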
- class tilelang.intrinsics.mma_macro_generator.INT4TensorCoreIntrinEmitterWithLadderTransform(a_dtype: str = 'float16', b_dtype: str = 'float16', accum_dtype: str = 'float16', a_transposed: bool = False, b_transposed: bool = False, block_row_warps: int = 2, block_col_warps: int = 2, warp_row_tiles: int = 8, warp_col_tiles: int = 8, chunk: int = 16, reduce_k: int = 1, num_elems_per_byte: int = 1, is_m_first: Optional[bool] = False, transform_kind_a: Union[int, TransformKind] = 0, transform_kind_b: Union[int, TransformKind] = 0)#
Bases: TensorCoreIntrinEmitterWithLadderTransform
- mma(A_local_buf, B_local_buf, C_local_buf)#
- class tilelang.intrinsics.mma_macro_generator.TensorCoreIntrinEmitter(a_dtype: str = 'float16', b_dtype: str = 'float16', accum_dtype: str = 'float16', a_transposed: bool = False, b_transposed: bool = False, block_row_warps: int = 2, block_col_warps: int = 2, warp_row_tiles: int = 8, warp_col_tiles: int = 8, chunk: int = 16, reduce_k: int = 1, num_elems_per_byte: int = 1, is_m_first: Optional[bool] = False)#
Bases: object
Helper for emitting Tensor Core MMA intrinsics, eliminating raw Python syntax within TIR macros.
- M_DIM = 16#
- N_DIM = 16#
- WARP_SIZE = 32#
- dtype_abbrv = {'bfloat16': 'bf16', 'e4m3_float8': 'e4m3', 'e5m2_float8': 'e5m2', 'float16': 'fp16', 'float32': 'fp32', 'int32': 'int32', 'int8': 'int8'}#
- extract_thread_binding(thread_id: PrimExpr, is_m_first: Optional[bool] = None) → Tuple[PrimExpr, PrimExpr, PrimExpr] #
is_m_first: if True, the thread binding has the form (tx, warp_n, warp_m), representing [warp_size, block_row_warps (split n), block_col_warps (split m)]; otherwise it has the form [warp_size, block_col_warps (split m), block_row_warps (split n)]. See the sketch below.
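As an illustration, here is a plain-Python sketch of one plausible decomposition. The real method returns TIR PrimExprs, and the exact strides below are an assumption, not the library's definitive implementation.

```python
# Hypothetical plain-Python rendering of the thread-id split described
# above. The strides are an assumption for illustration only.
WARP_SIZE = 32
block_row_warps, block_col_warps = 2, 2

def extract_thread_binding_sketch(thread_id: int, is_m_first: bool = False):
    lane_id = thread_id % WARP_SIZE
    if is_m_first:
        # (tx, warp_n, warp_m): warp_n varies faster, warp_m is outermost.
        warp_n = (thread_id // WARP_SIZE) % block_col_warps
        warp_m = (thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps
    else:
        # (tx, warp_m, warp_n): warp_m varies faster, warp_n is outermost.
        warp_m = (thread_id // WARP_SIZE) % block_row_warps
        warp_n = (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps
    return lane_id, warp_m, warp_n
```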
- get_store_index_map(inverse: bool = False) → IndexMap #
- is_m_first = False#
- ldmatrix_a(A_local_buf: Buffer, A_shared_buf: Buffer, ki: PrimExpr, rk: Optional[PrimExpr] = 0)#
- ldmatrix_b(B_local_buf: Buffer, B_shared_buf: Buffer, ki: PrimExpr, rk: Optional[PrimExpr] = 0)#
- make_mma_load_layout(local_buf: Buffer, matrix: Literal['A', 'B'] = 'A') → Fragment #
Create a layout function for loading an operand fragment (matrix A or B) into a fragment buffer. The layout maps fragment indices to threads and per-thread local indices so that the registers line up with what the MMA intrinsic expects.
- Parameters:
local_buf (tir.Buffer) – The local buffer representing a fragment of a matrix.
matrix (Literal['A', 'B']) – Which operand the fragment holds; selects the A-side or B-side layout.
- Returns:
A fragment object that describes how threads and indices in local_buf are laid out.
- Return type:
T.Fragment
- Raises:
AssertionError – If local_buf is not detected to be a fragment buffer.
- make_mma_store_layout(local_buf: Buffer) → Fragment #
Create a layout function for storing MMA results into a fragment buffer. This layout is used in conjunction with inverse_mma_store_layout to map fragment indices to threads and local indices.
- Parameters:
local_buf (tir.Buffer) – The local buffer representing a fragment of a matrix.
- Returns:
A fragment object that describes how threads and indices in local_buf are laid out.
- Return type:
T.Fragment
- Raises:
AssertionError – If local_buf is not detected to be a fragment buffer.
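In a kernel body, these layouts are typically attached to the fragment buffers via layout annotation. A minimal sketch, assuming T.annotate_layout from tilelang.language; the buffer names and the `emitter` instance are hypothetical and defined elsewhere in the kernel:

```python
# Sketch: attach the emitter's fragment layouts to the local buffers so the
# lowering pass knows how threads own fragment elements. A_local, B_local,
# C_local, and `emitter` are hypothetical names.
T.annotate_layout({
    A_local: emitter.make_mma_load_layout(A_local, matrix="A"),
    B_local: emitter.make_mma_load_layout(B_local, matrix="B"),
    C_local: emitter.make_mma_store_layout(C_local),
})
```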
- mma(A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: Optional[PrimExpr] = 0)#
- stmatrix(C_local_buf, C_buf, pid_m=None, pid_n=None)#
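Putting the pieces together, the intended call sequence is ldmatrix_a/ldmatrix_b to move tiles from shared memory into register fragments, mma to accumulate, and stmatrix to write the accumulator back out. The skeleton below is a sketch under assumed tile sizes, fragment sizes, and buffer shapes; consult the tilelang GEMM examples for a tested configuration.

```python
# A minimal GEMM skeleton driving the emitter. All sizes are illustrative
# assumptions, not a tested configuration.
import tilelang.language as T
from tilelang.intrinsics.mma_macro_generator import TensorCoreIntrinEmitter

M = N = K = 256
block_M = block_N = 64          # 2 warps x 32 tiles per block dimension
block_K = 32
micro_size_k = 16               # k-extent of one ldmatrix/mma step (assumed)

emitter = TensorCoreIntrinEmitter(
    a_dtype="float16",
    b_dtype="float16",
    accum_dtype="float32",
    b_transposed=True,
    block_row_warps=2,
    block_col_warps=2,
    warp_row_tiles=32,
    warp_col_tiles=32,
    chunk=block_K,
)

@T.prim_func
def gemm(A: T.Buffer((M, K), "float16"),
         B: T.Buffer((N, K), "float16"),
         C: T.Buffer((M, N), "float32")):
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
        A_shared = T.alloc_shared((block_M, block_K), "float16")
        B_shared = T.alloc_shared((block_N, block_K), "float16")
        C_shared = T.alloc_shared((block_M, block_N), "float32")
        # Per-thread fragment buffers; sizes are assumed from the tiling.
        A_local = T.alloc_local((16,), "float16")
        B_local = T.alloc_local((16,), "float16")
        C_local = T.alloc_local((32,), "float32")

        T.clear(C_local)
        for ko in T.Pipelined(K // block_K, num_stages=2):
            T.copy(A[by * block_M, ko * block_K], A_shared)
            T.copy(B[bx * block_N, ko * block_K], B_shared)
            for ki in T.serial(block_K // micro_size_k):
                emitter.ldmatrix_a(A_local, A_shared, ki)  # shared -> registers
                emitter.ldmatrix_b(B_local, B_shared, ki)
                emitter.mma(A_local, B_local, C_local)     # accumulate fragments
        emitter.stmatrix(C_local, C_shared)                # registers -> shared
        T.copy(C_shared, C[by * block_M, bx * block_N])
```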
- class tilelang.intrinsics.mma_macro_generator.TensorCoreIntrinEmitterWithLadderTransform(a_dtype: str = 'float16', b_dtype: str = 'float16', accum_dtype: str = 'float16', a_transposed: bool = False, b_transposed: bool = False, block_row_warps: int = 2, block_col_warps: int = 2, warp_row_tiles: int = 8, warp_col_tiles: int = 8, chunk: int = 16, reduce_k: int = 1, num_elems_per_byte: int = 1, is_m_first: Optional[bool] = False, transform_kind_a: Union[int, TransformKind] = 0, transform_kind_b: Union[int, TransformKind] = 0)#
Bases: TensorCoreIntrinEmitter
Helper for emitting Tensor Core MMA intrinsics while eliminating raw Python syntax within TIR macros, extended with the Ladder transform plugin.
- ldmatrix_a(A_local_buf, A_shared_buf, ki, rk=0)#
- ldmatrix_b(B_local_buf, B_shared_buf, ki, rk=0)#
- mma(A_local_buf, B_local_buf, C_local_buf)#
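A construction sketch for the ladder variant. Per the signature above, transform_kind_a and transform_kind_b accept plain integers or TransformKind values; the specific kinds chosen below are illustrative assumptions.

```python
# Sketch: ladder-transform emitter. The transform kinds below are
# illustrative assumptions (0 is taken to mean "no transform").
from tilelang.intrinsics.mma_macro_generator import (
    TensorCoreIntrinEmitterWithLadderTransform,
)

ladder_emitter = TensorCoreIntrinEmitterWithLadderTransform(
    a_dtype="float16",
    b_dtype="float16",
    accum_dtype="float32",
    block_row_warps=2,
    block_col_warps=2,
    warp_row_tiles=32,
    warp_col_tiles=32,
    chunk=32,
    transform_kind_a=0,  # leave A untransformed (assumed semantics)
    transform_kind_b=2,  # assume B was pre-laid out with ladder kind 2
)
```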