tilelang.carver.matmul_analysis
===============================

.. py:module:: tilelang.carver.matmul_analysis

.. autoapi-nested-parse::

   A GEMM schedule rule for GPU operators.


Attributes
----------

.. autoapisummary::

   tilelang.carver.matmul_analysis.logger


Classes
-------

.. autoapisummary::

   tilelang.carver.matmul_analysis.IterKind
   tilelang.carver.matmul_analysis.IterTrait


Functions
---------

.. autoapisummary::

   tilelang.carver.matmul_analysis.collect_vars_from_expr
   tilelang.carver.matmul_analysis.auto_inline_producers
   tilelang.carver.matmul_analysis.auto_inline_consumers
   tilelang.carver.matmul_analysis.auto_inline_consumer_chain
   tilelang.carver.matmul_analysis.find_first_similar_region
   tilelang.carver.matmul_analysis.find_first_similar_buffer
   tilelang.carver.matmul_analysis.find_last_producer_from_buffer
   tilelang.carver.matmul_analysis.find_arg_idx_from_buffer_chain
   tilelang.carver.matmul_analysis.make_iter_fusion_index_map
   tilelang.carver.matmul_analysis.detect_iter_traits
   tilelang.carver.matmul_analysis.get_index_map
   tilelang.carver.matmul_analysis.get_in_out_dtypes
   tilelang.carver.matmul_analysis.get_dequantize_block
   tilelang.carver.matmul_analysis.is_identity_or_transpose_block
   tilelang.carver.matmul_analysis.is_identity_block
   tilelang.carver.matmul_analysis.is_transpose_block
   tilelang.carver.matmul_analysis.inline_transpose_block
   tilelang.carver.matmul_analysis.normalize_to_matmul
   tilelang.carver.matmul_analysis.get_tensorized_func_and_tags
   tilelang.carver.matmul_analysis.get_propagate_map
   tilelang.carver.matmul_analysis.get_ladder_stage3_map
   tilelang.carver.matmul_analysis.layout_propagate_chain


Module Contents
---------------

.. py:data:: logger

.. py:function:: collect_vars_from_expr(prim_expr)

.. py:function:: auto_inline_producers(sch, block, skip_blocks = None)

.. py:function:: auto_inline_consumers(sch, block)

.. py:function:: auto_inline_consumer_chain(sch, block)

.. py:function:: find_first_similar_region(regions, buffer)

.. py:function:: find_first_similar_buffer(regions, buffer)

.. py:function:: find_last_producer_from_buffer(sch, main_block, buffer)

.. py:function:: find_arg_idx_from_buffer_chain(sch, main_block, buffer)

   Traverse the buffer chain to find the argument index of the buffer.


.. py:class:: IterKind

   Bases: :py:obj:`enum.Enum`


   Iter kinds for GEMM-like programs.

   The computation can be simplified to ``C[S, I, J] += A[S, I, K] * B[S, J, K]``,
   where ``I``, ``J`` and ``K`` are the fundamental GEMM axes and ``S`` represents
   all other spatial axes (e.g. batches).

   - ``kIter_S``: spatial axes
   - ``kIter_I``: I axes
   - ``kIter_J``: J axes
   - ``kIter_K``: K axes
   - ``kIter_T``: trivial axes (i.e. with extent 1)


   .. py:attribute:: kIter_S
      :value: 0


   .. py:attribute:: kIter_I
      :value: 1


   .. py:attribute:: kIter_J
      :value: 2


   .. py:attribute:: kIter_K
      :value: 3


   .. py:attribute:: kIter_T
      :value: 4


.. py:class:: IterTrait

   .. py:attribute:: kind
      :type: IterKind


   .. py:attribute:: extent
      :type: tvm.tir.PrimExpr


.. py:function:: make_iter_fusion_index_map(traits, kind_order)

.. py:function:: detect_iter_traits(block)

   Detect iter traits based on the pattern ``C[S, I, J] += A[S, I, K] * B[S, J, K]``.

   :param block: The block to be analyzed.
   :type block: tir.Block

   :returns: **traits** -- The detected iter traits for axes in A, B and C.
             None if the block does not match the pattern.
   :rtype: Optional[Tuple[List[IterTrait]]]


.. py:function:: get_index_map(block, layout = None)

   Get index maps for the block.

   :param block: The block to be analyzed.
   :type block: tir.Block
   :param layout: The target layout index map to be used:
                  ``'n'`` for the [i, k] layout,
                  ``'t'`` for the [k, j] layout, and
                  ``'a'`` for automatic inference based on whether the last axis is a reduction axis.
   :type layout: List[str]

   :returns: **index_maps** -- The index maps for the block, or None if the block is not a GEMM-like kernel.
   :rtype: Optional[Tuple[tir.IndexMap]]


.. py:function:: get_in_out_dtypes(block)

   Detect the input and output data types for the given block by analyzing its read/write buffers.


.. py:function:: get_dequantize_block(sch, blocks)

.. py:function:: is_identity_or_transpose_block(block_stmt)

.. py:function:: is_identity_block(block_stmt)

.. py:function:: is_transpose_block(block_stmt)

.. py:function:: inline_transpose_block(sch, blocks)

.. py:function:: normalize_to_matmul(sch, main_block, layout = None)

.. py:function:: get_tensorized_func_and_tags(func, target, layout = None, skip_normalize = False, allow_gemv = False)

   Transform the function to matmul if necessary (e.g. transform conv2d with im2col).


.. py:function:: get_propagate_map(trans = True, dtype='float16', matrix_name='A', index_dtype='int32')

.. py:function:: get_ladder_stage3_map(dtype='float16', index_dtype='int32')

.. py:function:: layout_propagate_chain(sch, start_block, start_buffer, end_block, index_map)
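
The axis classification described for ``IterKind`` and ``detect_iter_traits`` can be illustrated with a small pure-Python sketch. It is not the module's implementation and does not import ``tilelang`` or TVM. The ``IterKind`` class below is a stand-in that only mirrors the documented enum values, and ``classify_axes`` is a hypothetical helper that assumes an axis can be classified purely by which of the buffers A, B and C it indexes in the pattern C[S, I, J] += A[S, I, K] * B[S, J, K].

```python
from enum import Enum


class IterKind(Enum):
    """Stand-in mirroring the documented enum values (not imported from tilelang)."""

    kIter_S = 0  # spatial axes
    kIter_I = 1  # I axes
    kIter_J = 2  # J axes
    kIter_K = 3  # K axes
    kIter_T = 4  # trivial axes (extent 1)


def classify_axes(a_axes, b_axes, c_axes, extents):
    """Hypothetical helper: classify each axis by the buffers that index it.

    Assumes the pattern C[S, I, J] += A[S, I, K] * B[S, J, K].
    """
    a, b, c = set(a_axes), set(b_axes), set(c_axes)
    kinds = {}
    for axis, extent in extents.items():
        in_a, in_b, in_c = axis in a, axis in b, axis in c
        if extent == 1:
            kinds[axis] = IterKind.kIter_T  # trivial axis
        elif in_a and in_b and in_c:
            kinds[axis] = IterKind.kIter_S  # shared by A, B and C
        elif in_a and in_c:
            kinds[axis] = IterKind.kIter_I  # only in A and C
        elif in_b and in_c:
            kinds[axis] = IterKind.kIter_J  # only in B and C
        elif in_a and in_b:
            kinds[axis] = IterKind.kIter_K  # reduction: in A and B, not in C
        else:
            raise ValueError(f"axis {axis!r} does not fit the GEMM pattern")
    return kinds


# Batched matmul: C[b, i, j] += A[b, i, k] * B[b, j, k]
kinds = classify_axes(
    a_axes=["b", "i", "k"],
    b_axes=["b", "j", "k"],
    c_axes=["b", "i", "j"],
    extents={"b": 8, "i": 128, "j": 256, "k": 64},
)
for axis, kind in kinds.items():
    print(axis, kind.name)
```

In this sketch a block whose axes cannot all be classified raises an error, whereas the documented ``detect_iter_traits`` returns None when the block does not match the pattern.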