tilelang.carver.roller.shape_inference.tir module

class tilelang.carver.roller.shape_inference.tir.DependencyAnalysis(deps)

Bases: object

analyze()
find_path_from_source(start_name, target_name)

Finds the path, if one exists, from a starting (source) node to a target node, and returns it as a list of nodes.

get_or_create_node(name)
print_dependencies()
traverse_dependencies(compute)
class tilelang.carver.roller.shape_inference.tir.InputShapeInference(deps: List[Statement])

Bases: object

construct_dependency_target(targets: Tuple[str])
get_input_exprs(output_exprs)
infer(shape: Dict[str, List[ConstIntBound]], rstep: Optional[Dict[str, int]] = None, targets=None)
class tilelang.carver.roller.shape_inference.tir.Statement(block_analyzer, block: BlockRV)

Bases: object

make_reverse(input_name: str, input_iter: List[PrimExpr])
class tilelang.carver.roller.shape_inference.tir.TensorDepNode(name)

Bases: object

Node type used for tensor dependency analysis.

add_next(node)
add_prev(node)
deduplicate(lst)
tilelang.carver.roller.shape_inference.tir.get_analyzer_by_tir(block_analyzer, args) → InputShapeInference
tilelang.carver.roller.shape_inference.tir.region_exist_in_list(a, list) → bool
tilelang.carver.roller.shape_inference.tir.walk_indice(expr)