tilelang.carver.matmul_analysis module
A GEMM schedule rule for GPU operators.
- class tilelang.carver.matmul_analysis.IterKind(value)
Bases: Enum
Iter kinds for GEMM-like programs. The computation can be simplified to C[S, I, J] += A[S, I, K] * B[S, J, K], where I, J and K are the fundamental GEMM axes and S represents all other spatial axes (e.g., batches). The kinds are: kIter_S (spatial axes), kIter_I (I axes), kIter_J (J axes), kIter_K (K axes) and kIter_T (trivial axes, i.e., with extent 1).
- kIter_I = 1
- kIter_J = 2
- kIter_K = 3
- kIter_S = 0
- kIter_T = 4
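To make the canonical form concrete, the following NumPy sketch (not part of tilelang; the extents are arbitrary) writes the loop nest C[S, I, J] += A[S, I, K] * B[S, J, K] directly and checks it against an equivalent einsum. The comments mark which IterKind each loop variable would receive.

```python
import numpy as np

# Arbitrary extents for the S (batch), I, J and K axes.
S, I, J, K = 2, 3, 4, 5
rng = np.random.default_rng(0)
A = rng.standard_normal((S, I, K))
B = rng.standard_normal((S, J, K))  # B is indexed [S, J, K]: K is its innermost axis

# Reference loop nest written exactly as C[S, I, J] += A[S, I, K] * B[S, J, K].
C_loop = np.zeros((S, I, J))
for s in range(S):                  # kIter_S: used by A, B and C
    for i in range(I):              # kIter_I: used by A and C
        for j in range(J):          # kIter_J: used by B and C
            for k in range(K):      # kIter_K: used by A and B only (reduction)
                C_loop[s, i, j] += A[s, i, k] * B[s, j, k]

# The same computation expressed as a single einsum.
C_einsum = np.einsum("sik,sjk->sij", A, B)
print(np.allclose(C_loop, C_einsum))  # True
```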
- class tilelang.carver.matmul_analysis.IterTrait(kind: tilelang.carver.matmul_analysis.IterKind, extent: tvm.ir.expr.PrimExpr)
Bases: object
- extent: PrimExpr
- tilelang.carver.matmul_analysis.auto_inline_consumer_chain(sch: Schedule, block: BlockRV)
- tilelang.carver.matmul_analysis.auto_inline_consumers(sch: Schedule, block: BlockRV)
- tilelang.carver.matmul_analysis.auto_inline_producers(sch: Schedule, block: BlockRV, skip_blocks: Optional[List[BlockRV]] = None)
- tilelang.carver.matmul_analysis.collect_vars_from_expr(prim_expr)
- tilelang.carver.matmul_analysis.detect_iter_traits(block: Block) → Optional[Tuple[List[IterTrait]]]
Detect the iter traits of a block by matching it against the pattern C[S, I, J] += A[S, I, K] * B[S, J, K].
- Parameters:
block (tir.Block) – The block to be analyzed
- Returns:
traits – The detected iter traits for the axes of A, B and C, or None if the block does not match the pattern.
- Return type:
Optional[Tuple[List[IterTrait]]]
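The classification behind this pattern can be sketched in plain Python. The helper classify_axes below is hypothetical and only models the idea: each block axis is labelled according to which of A, B and C index with it, and extent-1 axes are treated as trivial. It assumes each buffer is described by the set of loop variables in its indices; the real function works on a tir.Block and returns per-buffer trait lists rather than a single mapping. IterKind is re-declared locally with the same values so the sketch runs without tilelang.

```python
from enum import Enum
from typing import Dict, Optional, Set


class IterKind(Enum):
    """Local stand-in with the same values as tilelang's IterKind."""
    kIter_S = 0
    kIter_I = 1
    kIter_J = 2
    kIter_K = 3
    kIter_T = 4


def classify_axes(
    a_vars: Set[str], b_vars: Set[str], c_vars: Set[str], extents: Dict[str, int]
) -> Optional[Dict[str, IterKind]]:
    """Hypothetical helper: assign an IterKind to every axis based on which buffers
    (A, B, C) use it in their indices. Returns None on a pattern mismatch."""
    kinds: Dict[str, IterKind] = {}
    for var, extent in extents.items():
        in_a, in_b, in_c = var in a_vars, var in b_vars, var in c_vars
        if extent == 1:
            kinds[var] = IterKind.kIter_T   # trivial axis
        elif in_a and in_b and in_c:
            kinds[var] = IterKind.kIter_S   # batch-like spatial axis
        elif in_a and in_c and not in_b:
            kinds[var] = IterKind.kIter_I
        elif in_b and in_c and not in_a:
            kinds[var] = IterKind.kIter_J
        elif in_a and in_b and not in_c:
            kinds[var] = IterKind.kIter_K   # reduction axis
        else:
            return None                     # not C[S, I, J] += A[S, I, K] * B[S, J, K]
    return kinds


# Batched matmul C[b, i, j, u] += A[b, i, k] * B[b, j, k] with a unit axis u.
extents = {"b": 8, "i": 128, "j": 256, "k": 64, "u": 1}
result = classify_axes({"b", "i", "k"}, {"b", "j", "k"}, {"b", "i", "j", "u"}, extents)
print({v: kind.name for v, kind in result.items()})
# {'b': 'kIter_S', 'i': 'kIter_I', 'j': 'kIter_J', 'k': 'kIter_K', 'u': 'kIter_T'}
```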
- tilelang.carver.matmul_analysis.find_arg_idx_from_buffer_chain(sch: Schedule, main_block: BlockRV, buffer: Buffer) → int
Traverse the buffer chain to find the index of the function argument that corresponds to the given buffer.
- tilelang.carver.matmul_analysis.find_first_similar_buffer(regions: List[BufferRegion], buffer: Buffer)
- tilelang.carver.matmul_analysis.find_first_similar_region(regions: List[BufferRegion], buffer: Buffer)
- tilelang.carver.matmul_analysis.find_last_producer_from_buffer(sch, main_block, buffer: Buffer) → Optional[BlockRV]
- tilelang.carver.matmul_analysis.get_dequantize_block(sch, blocks) → Optional[BlockRV]
- tilelang.carver.matmul_analysis.get_in_out_dtypes(block: Block) → Tuple[str]
Detect the input/output data types of the given block by analyzing its read/write buffers.
- tilelang.carver.matmul_analysis.get_index_map(block: Block, layout: Optional[List[str]] = None) → Optional[Tuple[IndexMap, ...]]
Get the index maps for the block.
- Parameters:
block (tir.Block) – The block to be analyzed
layout (List[str]) – The target layout used to build the index maps: ‘n’ for an [i, k] layout, ‘t’ for a [k, j] layout, or ‘a’ to infer the layout automatically based on whether the last axis is a reduction axis.
- Returns:
index_maps – The index maps for the block, or None if the block is not a GEMM-like kernel.
- Return type:
Optional[Tuple[tir.IndexMap]]
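A minimal sketch of the ‘a’ option, under stated assumptions: batch axes stay outermost, and an input operand keeps its reduction axis innermost only when its buffer already stores it last. infer_auto_layout is a hypothetical helper that returns the chosen axis order as IterKind names; it does not build a tir.IndexMap.

```python
from typing import List


def infer_auto_layout(operand: str, last_axis_is_reduction: bool) -> List[str]:
    """Hypothetical model of layout 'a' for a GEMM input operand ("A" or "B").

    Assumptions: kIter_S axes stay outermost; the spatial axis is kIter_I for A
    and kIter_J for B; kIter_K stays last only if the buffer already stores it last.
    """
    spatial = {"A": "kIter_I", "B": "kIter_J"}[operand]
    if last_axis_is_reduction:
        return ["kIter_S", spatial, "kIter_K"]  # e.g. A stored as [i, k]
    return ["kIter_S", "kIter_K", spatial]      # e.g. B stored as [k, j]


print(infer_auto_layout("A", True))   # ['kIter_S', 'kIter_I', 'kIter_K']
print(infer_auto_layout("B", False))  # ['kIter_S', 'kIter_K', 'kIter_J']
```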
- tilelang.carver.matmul_analysis.get_ladder_stage3_map(dtype='float16', index_dtype='int32')
- tilelang.carver.matmul_analysis.get_propagate_map(trans: bool = True, dtype='float16', matrix_name='A', index_dtype='int32')
- tilelang.carver.matmul_analysis.get_tensorized_func_and_tags(func: PrimFunc, target: Target, layout: Optional[List[str]] = None, skip_normalize: bool = False, allow_gemv: bool = False) → Tuple[PrimFunc, Dict[str, Union[List[int], int]]]
Transform the function into a matmul if necessary (e.g., transform conv2d via im2col).
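The im2col rewrite can be illustrated with NumPy. This sketch assumes an NCHW input, an [F, C, KH, KW] weight, stride 1 and no padding; the function names are illustrative and are not tilelang APIs. After unfolding, the convolution becomes a plain matmul whose B operand is stored as [J, K], matching the C[I, J] += A[I, K] * B[J, K] form (with no S axis) used throughout this module.

```python
import numpy as np


def im2col(x: np.ndarray, kh: int, kw: int) -> np.ndarray:
    """Unfold each KH x KW patch of an NCHW tensor (stride 1, no padding) into one row."""
    n, c, h, w = x.shape
    oh, ow = h - kh + 1, w - kw + 1
    cols = np.empty((n * oh * ow, c * kh * kw), dtype=x.dtype)
    row = 0
    for b in range(n):
        for i in range(oh):
            for j in range(ow):
                # Flatten in (C, KH, KW) order to line up with the weight reshape below.
                cols[row] = x[b, :, i:i + kh, j:j + kw].reshape(-1)
                row += 1
    return cols


def conv2d_as_matmul(x: np.ndarray, weight: np.ndarray) -> np.ndarray:
    n, _, h, w = x.shape
    f, c, kh, kw = weight.shape
    oh, ow = h - kh + 1, w - kw + 1
    a = im2col(x, kh, kw)                 # A[I, K]: I = N*OH*OW, K = C*KH*KW
    b = weight.reshape(f, c * kh * kw)    # B[J, K]: J = F
    out = a @ b.T                         # C[I, J] += A[I, K] * B[J, K]
    return out.reshape(n, oh, ow, f).transpose(0, 3, 1, 2)  # back to NCHW


def direct_conv2d(x: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Reference: naive sliding-window convolution (cross-correlation)."""
    n, _, h, w = x.shape
    f, _, kh, kw = weight.shape
    out = np.zeros((n, f, h - kh + 1, w - kw + 1), dtype=x.dtype)
    for b in range(n):
        for o in range(f):
            for i in range(h - kh + 1):
                for j in range(w - kw + 1):
                    out[b, o, i, j] = np.sum(x[b, :, i:i + kh, j:j + kw] * weight[o])
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 6, 6))
weight = rng.standard_normal((4, 3, 3, 3))
print(np.allclose(conv2d_as_matmul(x, weight), direct_conv2d(x, weight)))  # True
```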
- tilelang.carver.matmul_analysis.inline_transpose_block(sch: Schedule, blocks: List[BlockRV])
- tilelang.carver.matmul_analysis.is_identity_block(block_stmt: Block) → bool
- tilelang.carver.matmul_analysis.is_identity_or_transpose_block(block_stmt: Block) → bool
- tilelang.carver.matmul_analysis.is_transpose_block(block_stmt: Block) → bool
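One way to read these predicates, sketched under an assumption the source does not spell out: for a block with a single read and a single write, the block is an identity if it reads the same loop variables in the same order it writes them, and a transpose if it reads a reordering of them. The helpers below are hypothetical and operate on lists of index-variable names rather than tir.Block objects.

```python
from typing import Sequence


def is_identity_pattern(write_idx: Sequence[str], read_idx: Sequence[str]) -> bool:
    """Hypothetical check: out[w0, w1, ...] = in[r0, r1, ...] is a plain copy."""
    return list(write_idx) == list(read_idx)


def is_transpose_pattern(write_idx: Sequence[str], read_idx: Sequence[str]) -> bool:
    """Hypothetical check: the same loop variables, read in a different order."""
    return sorted(write_idx) == sorted(read_idx) and list(write_idx) != list(read_idx)


def is_identity_or_transpose_pattern(write_idx: Sequence[str], read_idx: Sequence[str]) -> bool:
    return is_identity_pattern(write_idx, read_idx) or is_transpose_pattern(write_idx, read_idx)


print(is_identity_pattern(["i", "j"], ["i", "j"]))               # True:  B[i, j] = A[i, j]
print(is_transpose_pattern(["i", "j"], ["j", "i"]))              # True:  B[i, j] = A[j, i]
print(is_identity_or_transpose_pattern(["i", "j"], ["i", "i"]))  # False: B[i, j] = A[i, i]
```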
- tilelang.carver.matmul_analysis.layout_propagate_chain(sch: Schedule, start_block: BlockRV, start_buffer: Buffer, end_block: BlockRV, index_map: IndexMap)
- tilelang.carver.matmul_analysis.make_iter_fusion_index_map(traits: List[IterTrait], kind_order: List[IterKind]) → IndexMap
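Judging only from its signature, make_iter_fusion_index_map appears to build an index map that fuses all axes of the same IterKind into one axis per entry of kind_order. The sketch below models only that arithmetic in plain Python, under these assumptions: same-kind axes are fused row-major in their original order, a kind with no axes contributes an extent-1 axis, and axes whose kind is not listed are dropped (lossless only if their extent is 1). IterKind and IterTrait are local stand-ins, fuse_indices and fused_extents are hypothetical helpers, and the real function returns a tir.IndexMap.

```python
from enum import Enum
from typing import NamedTuple, Sequence, Tuple


class IterKind(Enum):
    """Local stand-in with the same values as tilelang's IterKind."""
    kIter_S = 0
    kIter_I = 1
    kIter_J = 2
    kIter_K = 3
    kIter_T = 4


class IterTrait(NamedTuple):
    """Local stand-in: an axis kind plus a concrete (integer) extent."""
    kind: IterKind
    extent: int


def fused_extents(traits: Sequence[IterTrait], kind_order: Sequence[IterKind]) -> Tuple[int, ...]:
    """Extent of each fused axis: the product of the extents of its member axes."""
    result = []
    for kind in kind_order:
        total = 1
        for trait in traits:
            if trait.kind == kind:
                total *= trait.extent
        result.append(total)
    return tuple(result)


def fuse_indices(traits: Sequence[IterTrait], kind_order: Sequence[IterKind],
                 indices: Sequence[int]) -> Tuple[int, ...]:
    """Map one index per original axis to one fused index per kind (row-major)."""
    fused = []
    for kind in kind_order:
        value = 0
        for trait, idx in zip(traits, indices):
            if trait.kind == kind:
                value = value * trait.extent + idx
        fused.append(value)
    return tuple(fused)


# A[b0, b1, i, k] with two batch axes, fused into [S, I, K].
traits = [
    IterTrait(IterKind.kIter_S, 2),
    IterTrait(IterKind.kIter_S, 3),
    IterTrait(IterKind.kIter_I, 4),
    IterTrait(IterKind.kIter_K, 5),
]
order = [IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_K]
print(fused_extents(traits, order))               # (6, 4, 5)
print(fuse_indices(traits, order, (1, 2, 3, 4)))  # (5, 3, 4)
```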
- tilelang.carver.matmul_analysis.normalize_to_matmul(sch: Schedule, main_block: BlockRV, layout: Optional[List[str]] = None) → Optional[Schedule]