tilelang.carver.roller.shape_inference.tir
==========================================

.. py:module:: tilelang.carver.roller.shape_inference.tir


Classes
-------

.. autoapisummary::

   tilelang.carver.roller.shape_inference.tir.Statement
   tilelang.carver.roller.shape_inference.tir.TensorDepNode
   tilelang.carver.roller.shape_inference.tir.DependencyAnalysis
   tilelang.carver.roller.shape_inference.tir.InputShapeInference


Functions
---------

.. autoapisummary::

   tilelang.carver.roller.shape_inference.tir.region_exist_in_list
   tilelang.carver.roller.shape_inference.tir.walk_indice
   tilelang.carver.roller.shape_inference.tir.get_analyzer_by_tir


Module Contents
---------------

.. py:class:: Statement(block_analyzer, block)

   .. py:attribute:: block_analyzer

   .. py:attribute:: block

   .. py:attribute:: dep_name

   .. py:attribute:: dependent_region

   .. py:attribute:: reverse_bound_inference

   .. py:method:: make_reverse(input_name, input_iter)


.. py:class:: TensorDepNode(name)

   Bases: :py:obj:`object`

   For tensor dependency analysis.

   .. py:attribute:: name

   .. py:method:: add_next(node)

   .. py:method:: add_prev(node)

   .. py:method:: deduplicate(lst)

   .. py:method:: __str__()

   .. py:method:: __repr__()


.. py:class:: DependencyAnalysis(deps)

   Bases: :py:obj:`object`

   .. py:attribute:: deps

   .. py:attribute:: name2dep

   .. py:attribute:: mapping

   .. py:method:: get_or_create_node(name)

   .. py:method:: traverse_dependencies(compute)

   .. py:method:: analyze()

   .. py:method:: print_dependencies()

   .. py:method:: find_path_from_source(start_name, target_name)

      Finds the path (if it exists) from a starting node (source) to a target node.
      Returns the path as a list of nodes.


.. py:class:: InputShapeInference(deps)

   .. py:attribute:: deps

   .. py:attribute:: target_mapping

   .. py:attribute:: buffer_mapping

   .. py:attribute:: reduce_axes
      :value: []

   .. py:attribute:: dep_analysis

   .. py:method:: construct_dependency_target(targets)

   .. py:method:: infer(shape, rstep = None, targets=None)

   .. py:method:: get_input_exprs(output_exprs)


.. py:function:: region_exist_in_list(a, list)

.. py:function:: walk_indice(expr)

.. py:function:: get_analyzer_by_tir(block_analyzer, args)
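
The description of ``DependencyAnalysis.find_path_from_source`` above says only that it returns a path from a source node to a target node as a list of nodes. As a rough illustration of that kind of lookup, the sketch below runs a breadth-first search over a plain dictionary graph. It is not the tilelang implementation: the function ``find_path``, the ``graph`` layout, the tensor names, and the choice to return an empty list when no path exists are all assumptions made for this example.

```python
from collections import deque


def find_path(graph, start_name, target_name):
    """Return a list of node names from start to target, or [] if none exists.

    `graph` is a hypothetical adjacency mapping: node name -> list of
    successor names. This is an illustrative stand-in, not tilelang's API.
    """
    if start_name not in graph or target_name not in graph:
        return []

    queue = deque([[start_name]])  # each queue entry is a partial path
    visited = {start_name}

    while queue:
        path = queue.popleft()
        node = path[-1]
        if node == target_name:
            return path
        for successor in graph.get(node, []):
            if successor not in visited:
                visited.add(successor)
                queue.append(path + [successor])

    return []  # assumption: an empty list signals "no path"


# Hypothetical tensor dependency graph: A and B feed C, and C feeds D.
graph = {"A": ["C"], "B": ["C"], "C": ["D"], "D": []}
path_a_to_d = find_path(graph, "A", "D")
path_d_to_a = find_path(graph, "D", "A")
```

Breadth-first search is used here because it yields a shortest path in terms of hop count; whether the real method makes the same choice is not stated in this reference.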