tilelang.carver.roller.shape_inference.tir module
- class tilelang.carver.roller.shape_inference.tir.DependencyAnalysis(deps)
Bases: object
- analyze()
- find_path_from_source(start_name, target_name)
Finds a path, if one exists, from the source node start_name to the target node target_name, and returns it as a list of nodes.
- get_or_create_node(name)
- print_dependencies()
- traverse_dependencies(compute)
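The behaviour described for find_path_from_source can be illustrated with a breadth-first search over a dependency graph keyed by tensor name. This is a minimal sketch, not tilelang's implementation: it assumes edges run from producer to consumer and that a missing path is signalled by None. The function find_path and the dict-based graph are hypothetical.

```python
from collections import deque
from typing import Dict, List, Optional


def find_path(graph: Dict[str, List[str]], start: str, target: str) -> Optional[List[str]]:
    """Return one shortest path from start to target as a list of names, or None.

    Hypothetical stand-in for DependencyAnalysis.find_path_from_source; the
    real method's traversal order and its result when no path exists may differ.
    """
    if start == target:
        return [start]
    # parent records how each node was first reached, so the path can be rebuilt.
    parent: Dict[str, Optional[str]] = {start: None}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        for nxt in graph.get(node, []):
            if nxt in parent:
                continue  # already visited
            parent[nxt] = node
            if nxt == target:
                # Walk parent links back to the source, then reverse.
                path = [nxt]
                while parent[path[-1]] is not None:
                    path.append(parent[path[-1]])
                return path[::-1]
            queue.append(nxt)
    return None
```

For example, with C computed from A and B and D computed from C, find_path({"A": ["C"], "B": ["C"], "C": ["D"]}, "A", "D") returns ["A", "C", "D"], while searching from "D" to "A" returns None because edges only point from producers to consumers.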
- class tilelang.carver.roller.shape_inference.tir.InputShapeInference(deps: List[Statement])
Bases: object
- construct_dependency_target(targets: Tuple[str])
- get_input_exprs(output_exprs)
- infer(shape: Dict[str, List[ConstIntBound]], rstep: Optional[Dict[str, int]] = None, targets=None)
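Only the signature of infer is documented. Judging from it, the method appears to take per-tensor bounds on output shapes (ConstIntBound values) plus an optional reduction step rstep and derive the shapes that the inputs must cover; that reading is an assumption. The TVM-free sketch below illustrates the underlying idea of pushing integer ranges through an affine index expression. Bound and affine_bound are hypothetical names, not part of tilelang; Bound only mirrors the min_value/max_value fields of TVM's ConstIntBound.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass(frozen=True)
class Bound:
    """Inclusive integer range; a hypothetical stand-in for ConstIntBound."""
    min_value: int
    max_value: int

    @property
    def extent(self) -> int:
        return self.max_value - self.min_value + 1


def affine_bound(coeffs: Dict[str, int], const: int, bounds: Dict[str, Bound]) -> Bound:
    """Range of sum(coeffs[v] * v) + const, given an inclusive range for each v."""
    lo = hi = const
    for var, c in coeffs.items():
        b = bounds[var]
        # A negative coefficient swaps which end of the range yields the minimum.
        lo += min(c * b.min_value, c * b.max_value)
        hi += max(c * b.min_value, c * b.max_value)
    return Bound(lo, hi)
```

For a 1-D stencil B[i] = sum over r of A[i + r], an output tile with i in [0, 15] and a reduction step of 3 (r in [0, 2]) gives affine_bound({"i": 1, "r": 1}, 0, {"i": Bound(0, 15), "r": Bound(0, 2)}) == Bound(0, 17), so the tile reads 18 elements of A.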
- class tilelang.carver.roller.shape_inference.tir.Statement(block_analyzer, block: BlockRV)
Bases: object
- make_reverse(input_name: str, input_iter: List[PrimExpr])
- class tilelang.carver.roller.shape_inference.tir.TensorDepNode(name)
Bases: object
A node in the graph used for tensor dependency analysis.
- add_next(node)
- add_prev(node)
- deduplicate(lst)
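TensorDepNode is documented only as a node for tensor dependency analysis with add_next, add_prev and deduplicate methods. The following is a minimal sketch of such a node, assuming add_next records a consumer, add_prev records a producer, and deduplicate drops repeated entries while keeping their order. DepNode, link, next_nodes and prev_nodes are hypothetical names, not the library's code.

```python
from typing import List


class DepNode:
    """Hypothetical, simplified analogue of TensorDepNode (not its real code)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.next_nodes: List["DepNode"] = []  # tensors that consume this one
        self.prev_nodes: List["DepNode"] = []  # tensors this one is computed from

    def add_next(self, node: "DepNode") -> None:
        # Record a consumer, ignoring duplicates.
        self.next_nodes = self.deduplicate(self.next_nodes + [node])

    def add_prev(self, node: "DepNode") -> None:
        # Record a producer, ignoring duplicates.
        self.prev_nodes = self.deduplicate(self.prev_nodes + [node])

    @staticmethod
    def deduplicate(lst: List["DepNode"]) -> List["DepNode"]:
        # Keep the first occurrence of each node (compared by identity), in order.
        seen = set()
        out = []
        for item in lst:
            if item not in seen:
                seen.add(item)
                out.append(item)
        return out

    def __repr__(self) -> str:
        return f"DepNode({self.name!r})"


def link(producer: DepNode, consumer: DepNode) -> None:
    """Hypothetical helper: record the edge producer -> consumer on both ends."""
    producer.add_next(consumer)
    consumer.add_prev(producer)
```

Linking A and B into C, then linking A into C a second time, leaves c.prev_nodes as [DepNode('A'), DepNode('B')]: the repeated edge is absorbed by deduplicate rather than stored twice.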
- tilelang.carver.roller.shape_inference.tir.get_analyzer_by_tir(block_analyzer, args) → InputShapeInference
- tilelang.carver.roller.shape_inference.tir.region_exist_in_list(a, list) → bool
- tilelang.carver.roller.shape_inference.tir.walk_indice(expr)