tilelang.carver.matmul_analysis module#

A GEMM schedule rule for GPU operators.

class tilelang.carver.matmul_analysis.IterKind(value)#

Bases: Enum

Iter kinds for GEMM-like programs. The computation can be simplified to C[S, I, J] += A[S, I, K] * B[S, J, K], where I, J, K are the fundamental GEMM axes and S represents all other spatial axes (e.g. batches).

kIter_S: spatial axes
kIter_I: I axes
kIter_J: J axes
kIter_K: K axes
kIter_T: trivial axes (i.e. with extent 1)

kIter_I = 1#
kIter_J = 2#
kIter_K = 3#
kIter_S = 0#
kIter_T = 4#
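
The loop structure these iter kinds describe can be reproduced with a plain-Python sketch (illustrative only, not the tilelang API):

```python
# Reference loop nest for C[S, I, J] += A[S, I, K] * B[S, J, K].
# Note that B is indexed [S, J, K], i.e. the second operand is read in
# transposed (N, K) layout.
def batched_gemm(a, b, S, I, J, K):
    c = [[[0 for _ in range(J)] for _ in range(I)] for _ in range(S)]
    for s in range(S):              # kIter_S: spatial (batch) axis
        for i in range(I):          # kIter_I
            for j in range(J):      # kIter_J
                for k in range(K):  # kIter_K: reduction axis
                    c[s][i][j] += a[s][i][k] * b[s][j][k]
    return c
```

Any axis with extent 1 in this nest would be classified as kIter_T (trivial).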
class tilelang.carver.matmul_analysis.IterTrait(kind: tilelang.carver.matmul_analysis.IterKind, extent: tvm.ir.expr.PrimExpr)#

Bases: object

extent: PrimExpr#
kind: IterKind#
tilelang.carver.matmul_analysis.auto_inline_consumer_chain(sch: Schedule, block: BlockRV)#
tilelang.carver.matmul_analysis.auto_inline_consumers(sch: Schedule, block: BlockRV)#
tilelang.carver.matmul_analysis.auto_inline_producers(sch: Schedule, block: BlockRV, skip_blocks: Optional[List[BlockRV]] = None)#
tilelang.carver.matmul_analysis.collect_vars_from_expr(prim_expr)#
tilelang.carver.matmul_analysis.detect_iter_traits(block: Block) Optional[Tuple[List[IterTrait]]]#

Detect iter traits based on the pattern C[S, I, J] += A[S, I, K] * B[S, J, K]

Parameters:

block (tir.Block) – The block to be analyzed

Returns:

traits – The detected iter traits for axes in A, B, and C, or None if the block does not match the pattern.

Return type:

Optional[Tuple[List[IterTrait]]]
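
The classification rule behind this detection can be sketched in pure Python (a simplified illustration of the pattern match, not the tilelang implementation):

```python
# Classify one iteration variable by which buffers of the pattern
# C[S, I, J] += A[S, I, K] * B[S, J, K] index it:
#   extent == 1               -> kIter_T (trivial)
#   appears in C, A, and B    -> kIter_S (spatial/batch)
#   appears in C and A only   -> kIter_I
#   appears in C and B only   -> kIter_J
#   appears in A and B only   -> kIter_K (reduction)
def classify_iter(in_a, in_b, in_c, extent):
    if extent == 1:
        return "kIter_T"
    if in_c and in_a and in_b:
        return "kIter_S"
    if in_c and in_a:
        return "kIter_I"
    if in_c and in_b:
        return "kIter_J"
    if in_a and in_b:
        return "kIter_K"
    return None  # the variable does not fit the GEMM pattern
```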

tilelang.carver.matmul_analysis.find_arg_idx_from_buffer_chain(sch: Schedule, main_block: BlockRV, buffer: Buffer) int#

Traverse the buffer chain to find the argument index of the given buffer.

tilelang.carver.matmul_analysis.find_first_similar_buffer(regions: List[BufferRegion], buffer: Buffer)#
tilelang.carver.matmul_analysis.find_first_similar_region(regions: List[BufferRegion], buffer: Buffer)#
tilelang.carver.matmul_analysis.find_last_producer_from_buffer(sch, main_block, buffer: Buffer) Optional[BlockRV]#
tilelang.carver.matmul_analysis.get_dequantize_block(sch, blocks) Optional[BlockRV]#
tilelang.carver.matmul_analysis.get_in_out_dtypes(block: Block) Tuple[str]#

Detect input/output data types for the given block based on analysis of its read/write buffers.

tilelang.carver.matmul_analysis.get_index_map(block: Block, layout: Optional[List[str]] = None) Optional[Tuple[IndexMap, ...]]#

Get index maps for the block

Parameters:
  • block (tir.Block) – The block to be analyzed

  • layout (List[str]) – the target layout index map to be used: 'n' for [i, k] layout, 't' for [k, j] layout, 'a' for automatic inference based on whether the last axis is a reduction.

Returns:

index_maps – The index maps for the block, or None if the block is not a GEMM-like kernel

Return type:

Optional[Tuple[tir.IndexMap]]

tilelang.carver.matmul_analysis.get_ladder_stage3_map(dtype='float16', index_dtype='int32')#
tilelang.carver.matmul_analysis.get_propagate_map(trans: bool = True, dtype='float16', matrix_name='A', index_dtype='int32')#
tilelang.carver.matmul_analysis.get_tensorized_func_and_tags(func: PrimFunc, target: Target, layout: Optional[List[str]] = None, skip_normalize: bool = False, allow_gemv: bool = False) Tuple[PrimFunc, Dict[str, Union[List[int], int]]]#

Transform the function to a matmul if necessary (e.g. transform conv2d with im2col).
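
The im2col lowering mentioned above can be sketched in plain Python (a minimal single-channel, stride-1 illustration; not the tilelang transform itself):

```python
# im2col: lower a convolution over an H x W input with a kh x kw kernel to a
# matmul. Each output position becomes one row of the LHS matrix, holding the
# kh*kw input values the kernel window covers at that position.
def im2col(x, kh, kw):
    H, W = len(x), len(x[0])
    oh, ow = H - kh + 1, W - kw + 1
    rows = []
    for i in range(oh):
        for j in range(ow):
            rows.append([x[i + di][j + dj]
                         for di in range(kh) for dj in range(kw)])
    return rows  # shape: (oh * ow, kh * kw)

def conv_as_matmul(x, w):
    # Flatten the kernel and contract each im2col row against it,
    # i.e. a (oh*ow, kh*kw) x (kh*kw,) matrix-vector product.
    kh, kw = len(w), len(w[0])
    cols = im2col(x, kh, kw)
    flat_w = [w[di][dj] for di in range(kh) for dj in range(kw)]
    return [sum(a * b for a, b in zip(row, flat_w)) for row in cols]
```

With multiple input/output channels the same idea yields a full matmul, whose axes can then be classified by the iter-trait analysis above.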

tilelang.carver.matmul_analysis.inline_transpose_block(sch: Schedule, blocks: List[BlockRV])#
tilelang.carver.matmul_analysis.is_identity_block(block_stmt: Block) bool#
tilelang.carver.matmul_analysis.is_identity_or_transpose_block(block_stmt: Block) bool#
tilelang.carver.matmul_analysis.is_transpose_block(block_stmt: Block) bool#
tilelang.carver.matmul_analysis.layout_propagate_chain(sch: Schedule, start_block: BlockRV, start_buffer: Buffer, end_block: BlockRV, index_map: IndexMap)#
tilelang.carver.matmul_analysis.make_iter_fusion_index_map(traits: List[IterTrait], kind_order: List[IterKind]) IndexMap#
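
The fusion such an index map performs can be illustrated with a pure-Python sketch (the real function builds a tir.IndexMap; names and the index-based interface here are illustrative only):

```python
# Fuse all axes of the same IterKind into a single axis per kind, row-major,
# emitted in the requested kind order (e.g. [S, I, J, K]). Each input axis is
# described by its kind, its extent, and its current index value.
def fuse_iters(kinds, extents, indices, kind_order):
    fused = []
    for kind in kind_order:
        idx = 0
        for k, e, i in zip(kinds, extents, indices):
            if k == kind:
                idx = idx * e + i  # row-major accumulation
        fused.append(idx)
    return fused
```

For example, two I axes of extents 3 and 4 fuse into one I axis of extent 12, turning an arbitrary GEMM-like loop nest into the canonical [S, I, J, K] form.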
tilelang.carver.matmul_analysis.normalize_to_matmul(sch: Schedule, main_block: BlockRV, layout: Optional[List[str]] = None) Optional[Schedule]#