tilelang.carver.matmul_analysis

A GEMM schedule rule for GPU operators.

Attributes

logger

Classes

IterKind

Iterator kinds for GEMM-like programs.

IterTrait

Functions

collect_vars_from_expr(prim_expr)

auto_inline_producers(sch, block[, skip_blocks])

auto_inline_consumers(sch, block)

auto_inline_consumer_chain(sch, block)

find_first_similar_region(regions, buffer)

find_first_similar_buffer(regions, buffer)

find_last_producer_from_buffer(sch, main_block, buffer)

find_arg_idx_from_buffer_chain(sch, main_block, buffer)

Traverse the buffer chain to find the argument index of the given buffer.

make_iter_fusion_index_map(traits, kind_order)

detect_iter_traits(block)

Detect iter traits based on the pattern C[S, I, J] += A[S, I, K] * B[S, J, K]

get_index_map(block[, layout])

Get the index maps for the block.

get_in_out_dtypes(block)

Detect the input/output data types of the given block by analyzing its read/write buffers.

get_dequantize_block(sch, blocks)

is_identity_or_transpose_block(block_stmt)

is_identity_block(block_stmt)

is_transpose_block(block_stmt)

inline_transpose_block(sch, blocks)

normalize_to_matmul(sch, main_block[, layout])

get_tensorized_func_and_tags(func, target[, layout, ...])

Transform the function into a matmul if necessary (e.g., transform conv2d with im2col).

get_propagate_map([trans, dtype, matrix_name, index_dtype])

get_ladder_stage3_map([dtype, index_dtype])

layout_propagate_chain(sch, start_block, start_buffer, ...)

Module Contents

tilelang.carver.matmul_analysis.logger
tilelang.carver.matmul_analysis.collect_vars_from_expr(prim_expr)
tilelang.carver.matmul_analysis.auto_inline_producers(sch, block, skip_blocks=None)
Parameters:
  • sch (tvm.tir.Schedule)

  • block (tvm.tir.schedule.BlockRV)

  • skip_blocks (Optional[List[tvm.tir.schedule.BlockRV]])

tilelang.carver.matmul_analysis.auto_inline_consumers(sch, block)
Parameters:
  • sch (tvm.tir.Schedule)

  • block (tvm.tir.schedule.BlockRV)

tilelang.carver.matmul_analysis.auto_inline_consumer_chain(sch, block)
Parameters:
  • sch (tvm.tir.Schedule)

  • block (tvm.tir.schedule.BlockRV)

tilelang.carver.matmul_analysis.find_first_similar_region(regions, buffer)
Parameters:
  • regions (List[tvm.tir.BufferRegion])

  • buffer (tvm.tir.Buffer)

tilelang.carver.matmul_analysis.find_first_similar_buffer(regions, buffer)
Parameters:
  • regions (List[tvm.tir.BufferRegion])

  • buffer (tvm.tir.Buffer)

tilelang.carver.matmul_analysis.find_last_producer_from_buffer(sch, main_block, buffer)
Parameters:

buffer (tvm.tir.Buffer)

Return type:

Optional[tvm.tir.schedule.schedule.BlockRV]

tilelang.carver.matmul_analysis.find_arg_idx_from_buffer_chain(sch, main_block, buffer)

Traverse the buffer chain to find the argument index of the given buffer.

Parameters:
  • sch (tvm.tir.Schedule)

  • main_block (tvm.tir.schedule.BlockRV)

  • buffer (tvm.tir.Buffer)

Return type:

int

class tilelang.carver.matmul_analysis.IterKind

Bases: enum.Enum

Iterator kinds for GEMM-like programs. The computation can be simplified to C[S, I, J] += A[S, I, K] * B[S, J, K], where I, J and K are the fundamental GEMM axes and S represents all other spatial axes (e.g., batches):
  • kIter_S: spatial axes
  • kIter_I: I axes
  • kIter_J: J axes
  • kIter_K: K axes
  • kIter_T: trivial axes (i.e., with extent 1)

kIter_S = 0
kIter_I = 1
kIter_J = 2
kIter_K = 3
kIter_T = 4
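As a concrete reading of the canonical form, the plain-Python loop nest below evaluates C[S, I, J] += A[S, I, K] * B[S, J, K], with comments marking which loop plays the role of which IterKind. Nested lists stand in for buffers and `reference_gemm` is a hypothetical helper; this is an illustrative reference, not code taken from tilelang.

```python
def reference_gemm(A, B):
    """Compute C[s][i][j] = sum_k A[s][i][k] * B[s][j][k] with plain loops.

    A has shape S x I x K and B has shape S x J x K (nested lists).
    """
    S, I, K = len(A), len(A[0]), len(A[0][0])
    J = len(B[0])
    C = [[[0] * J for _ in range(I)] for _ in range(S)]
    for s in range(S):              # kIter_S: indexes A, B and C
        for i in range(I):          # kIter_I: indexes A and C
            for j in range(J):      # kIter_J: indexes B and C
                for k in range(K):  # kIter_K: indexes A and B only, so it is reduced
                    C[s][i][j] += A[s][i][k] * B[s][j][k]
    return C


print(reference_gemm([[[1, 2], [3, 4]]], [[[5, 6], [7, 8]]]))  # [[[17, 23], [39, 53]]]
```

Note that B is indexed as [S, J, K], so its reduction axis is innermost, just like A's.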
class tilelang.carver.matmul_analysis.IterTrait
kind: IterKind
extent: tvm.tir.PrimExpr
tilelang.carver.matmul_analysis.make_iter_fusion_index_map(traits, kind_order)
Parameters:
  • traits

  • kind_order

Return type:

tvm.tir.IndexMap
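Judging only by its name and parameters, make_iter_fusion_index_map appears to fuse the axes of each IterKind into a single axis, emitting one fused axis per kind in kind_order. The sketch below illustrates that reading on concrete integers. The (kind, extent) tuples, the row-major fusion order, and the helpers `fuse_indices` and `fused_extents` are assumptions for illustration only; the real function returns a tvm.tir.IndexMap rather than evaluating points.

```python
from math import prod


def fused_extents(traits, kind_order):
    """Extent of each fused axis: product of the extents of that kind (1 if none)."""
    return [prod(extent for kind, extent in traits if kind == k) for k in kind_order]


def fuse_indices(traits, kind_order, indices):
    """Map one point of the original iteration space to fused coordinates.

    traits: list of (kind, extent) pairs, one per original axis.
    kind_order: kinds to emit, outermost first, e.g. ["S", "I", "K"].
    indices: the value of each original axis at this point.
    """
    fused = []
    for k in kind_order:
        value = 0
        for (kind, extent), idx in zip(traits, indices):
            if kind == k:
                # Assumption: axes of one kind are fused row-major in original order.
                value = value * extent + idx
        fused.append(value)
    return fused


traits = [("S", 2), ("S", 4), ("I", 16), ("K", 32)]
print(fused_extents(traits, ["S", "I", "K"]))                # [8, 16, 32]
print(fuse_indices(traits, ["S", "I", "K"], [1, 3, 5, 7]))  # [7, 5, 7]
```

A kind listed in kind_order but absent from traits yields a fused axis of extent 1 whose index is always 0.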

tilelang.carver.matmul_analysis.detect_iter_traits(block)

Detect iter traits based on the pattern C[S, I, J] += A[S, I, K] * B[S, J, K]

Parameters:

block (tir.Block) – The block to be analyzed

Returns:

traits – The detected iter traits for axes in A, B and C. None if the block does not match the pattern.

Return type:

Optional[Tuple[List[IterTrait]]]
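The pattern C[S, I, J] += A[S, I, K] * B[S, J, K] implies a simple membership rule: an axis indexing A, B and C is spatial (S), one indexing only A and C is I, one indexing only B and C is J, and one indexing A and B but not C is the reduction axis K. The sketch below applies that rule to sets of loop-variable names. It is a simplified illustration, not the tilelang implementation (which inspects a tir.Block): `Kind` and `Trait` are local stand-ins for IterKind and IterTrait, `classify_axes` is hypothetical, and classifying extent-1 axes as trivial before any other check is an assumption.

```python
from dataclasses import dataclass
from enum import Enum


class Kind(Enum):
    # Local stand-in mirroring the documented IterKind values.
    S = 0  # spatial axes (e.g. batch)
    I = 1
    J = 2
    K = 3
    T = 4  # trivial axes (extent 1)


@dataclass
class Trait:
    # Local stand-in mirroring IterTrait's kind/extent fields.
    kind: Kind
    extent: int


def classify_axes(a_axes, b_axes, c_axes, extents):
    """Map each loop variable to a Trait, or return None if the pattern does not match.

    a_axes, b_axes, c_axes: sets of loop-variable names indexing A, B and C.
    extents: dict from loop-variable name to its extent.
    """
    traits = {}
    for axis, extent in extents.items():
        in_a, in_b, in_c = axis in a_axes, axis in b_axes, axis in c_axes
        if extent == 1:
            kind = Kind.T  # assumption: extent-1 axes are trivial regardless of use
        elif in_a and in_b and in_c:
            kind = Kind.S
        elif in_a and in_c and not in_b:
            kind = Kind.I
        elif in_b and in_c and not in_a:
            kind = Kind.J
        elif in_a and in_b and not in_c:
            kind = Kind.K  # reduced: absent from the output C
        else:
            return None
        traits[axis] = Trait(kind, extent)
    return traits


traits = classify_axes({"b", "i", "k"}, {"b", "j", "k"}, {"b", "i", "j"},
                       {"b": 8, "i": 128, "j": 64, "k": 256})
print({axis: t.kind.name for axis, t in traits.items()})
```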

tilelang.carver.matmul_analysis.get_index_map(block, layout=None)

Get the index maps for the block.

Parameters:
  • block (tir.Block) – The block to be analyzed

  • layout (List[str]) – The target layout for the index maps: ‘n’ for the [i, k] layout, ‘t’ for the [k, j] layout, or ‘a’ to infer the layout automatically based on whether the last axis is a reduction axis.

Returns:

index_maps – The index maps for the block, or None if the block is not a GEMM-like kernel.

Return type:

Optional[Tuple[tir.IndexMap]]
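The ‘a’ option is described as inferring the layout from whether the last axis is a reduction axis. A minimal sketch of that rule follows, using a hypothetical helper `infer_layout` that takes one boolean per buffer axis; it is not part of tilelang and only restates the documented rule.

```python
def infer_layout(axis_is_reduction):
    """Return 'n' if the last axis is a reduction axis (the [i, k] case), else 't'.

    axis_is_reduction: one bool per buffer axis, outermost first.
    """
    if not axis_is_reduction:
        raise ValueError("buffer must have at least one axis")
    return "n" if axis_is_reduction[-1] else "t"


print(infer_layout([False, True]))  # n  (e.g. A[i, k])
print(infer_layout([True, False]))  # t  (e.g. B[k, j])
```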

tilelang.carver.matmul_analysis.get_in_out_dtypes(block)

Detect the input/output data types of the given block by analyzing its read/write buffers.

Parameters:

block (tvm.tir.Block)

Return type:

Tuple[str]

tilelang.carver.matmul_analysis.get_dequantize_block(sch, blocks)
Return type:

Optional[tvm.tir.schedule.schedule.BlockRV]

tilelang.carver.matmul_analysis.is_identity_or_transpose_block(block_stmt)
Parameters:

block_stmt (tvm.tir.Block)

Return type:

bool

tilelang.carver.matmul_analysis.is_identity_block(block_stmt)
Parameters:

block_stmt (tvm.tir.Block)

Return type:

bool

tilelang.carver.matmul_analysis.is_transpose_block(block_stmt)
Parameters:

block_stmt (tvm.tir.Block)

Return type:

bool
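The three predicates above classify simple copy blocks. Assuming "identity" means the output is written at exactly the indices it reads from (B[i, j] = A[i, j]) and "transpose" means the two innermost indices are swapped (B[i, j] = A[j, i]), the check reduces to comparing index tuples. `classify_copy` below is a hypothetical helper over loop-variable names, a sketch of the idea rather than the tilelang implementation, which inspects tvm.tir.Block statements.

```python
def classify_copy(store_indices, load_indices):
    """Classify a copy block Out[store_indices] = In[load_indices].

    Both arguments are sequences of loop-variable names. Returns "identity" if
    the indices are identical, "transpose" if they differ only by swapping the
    two innermost axes, and None otherwise.
    """
    store, load = tuple(store_indices), tuple(load_indices)
    if store == load:
        return "identity"
    if (len(store) == len(load) >= 2
            and store[:-2] == load[:-2]
            and store[-2:] == load[-2:][::-1]):
        return "transpose"
    return None


print(classify_copy(["i", "j"], ["i", "j"]))            # identity
print(classify_copy(["b", "i", "j"], ["b", "j", "i"]))  # transpose
print(classify_copy(["i", "j"], ["i", "i"]))            # None
```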

tilelang.carver.matmul_analysis.inline_transpose_block(sch, blocks)
Parameters:
  • sch (tvm.tir.Schedule)

  • blocks (List[tvm.tir.schedule.BlockRV])

tilelang.carver.matmul_analysis.normalize_to_matmul(sch, main_block, layout=None)
Parameters:
  • sch (tvm.tir.Schedule)

  • main_block (tvm.tir.schedule.schedule.BlockRV)

  • layout (Optional[List[str]])

Return type:

Optional[tvm.tir.Schedule]

tilelang.carver.matmul_analysis.get_tensorized_func_and_tags(func, target, layout=None, skip_normalize=False, allow_gemv=False)

Transform the function into a matmul if necessary (e.g., transform conv2d with im2col).

Parameters:
  • func (tvm.tir.PrimFunc)

  • target (tvm.target.target.Target)

  • layout (Optional[List[str]])

  • skip_normalize (bool)

  • allow_gemv (bool)

Return type:

Tuple[tvm.tir.PrimFunc, Dict[str, Union[List[int], int]]]
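The docstring names im2col as one way a conv2d becomes a matmul. The plain-Python sketch below shows the idea: each output pixel's receptive field is unfolded into one row, the filter is flattened in the same order, and a single GEMM produces the convolution. It is not how get_tensorized_func_and_tags rewrites TIR; the H x W x C input layout, KH x KW x C x F filter layout, stride 1, no padding, and all helper names are assumptions for illustration.

```python
def im2col(x, kh, kw):
    """Unfold an H x W x C input into (OH*OW) rows of KH*KW*C values.

    Row oh * OW + ow holds the receptive field of output pixel (oh, ow),
    flattened in (dh, dw, c) order. Stride 1, no padding.
    """
    H, W, C = len(x), len(x[0]), len(x[0][0])
    OH, OW = H - kh + 1, W - kw + 1
    return [[x[oh + dh][ow + dw][c]
             for dh in range(kh) for dw in range(kw) for c in range(C)]
            for oh in range(OH) for ow in range(OW)]


def matmul(a, b):
    """Plain GEMM: (M x K) @ (K x N)."""
    return [[sum(a[m][k] * b[k][n] for k in range(len(b))) for n in range(len(b[0]))]
            for m in range(len(a))]


def conv2d_via_im2col(x, w):
    """w has shape KH x KW x C x F; returns an (OH*OW) x F output matrix."""
    kh, kw, C, F = len(w), len(w[0]), len(w[0][0]), len(w[0][0][0])
    # Flatten the filter in the same (dh, dw, c) order used for the im2col rows.
    w_mat = [[w[dh][dw][c][f] for f in range(F)]
             for dh in range(kh) for dw in range(kw) for c in range(C)]
    return matmul(im2col(x, kh, kw), w_mat)


def conv2d_direct(x, w):
    """Reference direct convolution (cross-correlation) with the same output layout."""
    kh, kw, C, F = len(w), len(w[0]), len(w[0][0]), len(w[0][0][0])
    OH, OW = len(x) - kh + 1, len(x[0]) - kw + 1
    return [[sum(x[oh + dh][ow + dw][c] * w[dh][dw][c][f]
                 for dh in range(kh) for dw in range(kw) for c in range(C))
             for f in range(F)]
            for oh in range(OH) for ow in range(OW)]


x = [[[r * 3 + c] for c in range(3)] for r in range(3)]  # 3 x 3 x 1 input
w = [[[[1]] for _ in range(2)] for _ in range(2)]        # 2 x 2 x 1 x 1 filter of ones
print(conv2d_via_im2col(x, w))  # [[8], [12], [20], [24]]
```

The resulting GEMM has M = OH*OW, N = F and K = KH*KW*C, which is the shape a matmul schedule can then target.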

tilelang.carver.matmul_analysis.get_propagate_map(trans=True, dtype='float16', matrix_name='A', index_dtype='int32')
Parameters:

trans (bool)

tilelang.carver.matmul_analysis.get_ladder_stage3_map(dtype='float16', index_dtype='int32')
tilelang.carver.matmul_analysis.layout_propagate_chain(sch, start_block, start_buffer, end_block, index_map)
Parameters:
  • sch (tvm.tir.Schedule)

  • start_block (tvm.tir.schedule.schedule.BlockRV)

  • start_buffer (tvm.tir.Buffer)

  • end_block (tvm.tir.schedule.schedule.BlockRV)

  • index_map (tvm.tir.IndexMap)