tilelang.carver.roller.shape_inference.tir

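Classes

Statement(block_analyzer, block)

TensorDepNode(name)    For tensor dependency analysis.

DependencyAnalysis(deps)

InputShapeInference(deps)

Functions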

region_exist_in_list(a, list)

walk_indice(expr)

get_analyzer_by_tir(block_analyzer, args)

Module Contents

class tilelang.carver.roller.shape_inference.tir.Statement(block_analyzer, block)
Parameters:

block (tvm.tir.schedule.schedule.BlockRV)

block_analyzer
block
dep_name
dependent_region
reverse_bound_inference
make_reverse(input_name, input_iter)
Parameters:
  • input_name (str)

  • input_iter (List[tvm.tir.PrimExpr])

class tilelang.carver.roller.shape_inference.tir.TensorDepNode(name)

Bases: object

For tensor dependency analysis.

name
add_next(node)
add_prev(node)
deduplicate(lst)
__str__()
__repr__()

class tilelang.carver.roller.shape_inference.tir.DependencyAnalysis(deps)

Bases: object

deps
name2dep
mapping
get_or_create_node(name)
traverse_dependencies(compute)
analyze()
print_dependencies()
find_path_from_source(start_name, target_name)

Finds the path (if it exists) from a starting node (source) to a target node. Returns the path as a list of nodes.
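
The path search described above can be illustrated with a minimal, self-contained sketch. It does not use tilelang: `DepNode` and `find_path` are hypothetical stand-ins, and both the breadth-first strategy and the empty-list result when no path exists are assumptions, since this page does not state how `find_path_from_source` searches or what it returns on failure.

```python
from collections import deque


class DepNode:
    """Hypothetical stand-in for a tensor dependency node (not tilelang's class)."""

    def __init__(self, name):
        self.name = name
        self.next = []  # nodes that consume this tensor

    def add_next(self, node):
        # Skip duplicates so each edge is stored once.
        if node not in self.next:
            self.next.append(node)

    def __repr__(self):
        return f"DepNode({self.name!r})"


def find_path(nodes, start_name, target_name):
    """Breadth-first search; returns the path as a list of nodes, or [] if none.

    Returning [] for a missing path is an assumption made for this sketch.
    """
    if start_name not in nodes or target_name not in nodes:
        return []
    start, target = nodes[start_name], nodes[target_name]
    queue = deque([[start]])
    visited = {start.name}
    while queue:
        path = queue.popleft()
        if path[-1] is target:
            return path
        for nxt in path[-1].next:
            if nxt.name not in visited:
                visited.add(nxt.name)
                queue.append(path + [nxt])
    return []


# Tiny dependency graph: A -> B -> C, and A -> D
nodes = {n: DepNode(n) for n in "ABCD"}
nodes["A"].add_next(nodes["B"])
nodes["B"].add_next(nodes["C"])
nodes["A"].add_next(nodes["D"])

path_names = [n.name for n in find_path(nodes, "A", "C")]
print(path_names)  # ['A', 'B', 'C']
```

Tracking visited names keeps the search from looping if the graph contains a cycle, and breadth-first order means the first path found has the fewest hops.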

class tilelang.carver.roller.shape_inference.tir.InputShapeInference(deps)
Parameters:

deps (List[Statement])

deps
target_mapping
buffer_mapping
reduce_axes = []
dep_analysis
construct_dependency_target(targets)
Parameters:

targets (Tuple[str])

infer(shape, rstep=None, targets=None)
Parameters:
  • shape (Dict[str, List[tvm.arith.ConstIntBound]])

  • rstep (Dict[str, int])

get_input_exprs(output_exprs)

tilelang.carver.roller.shape_inference.tir.region_exist_in_list(a, list)
Return type:

bool

tilelang.carver.roller.shape_inference.tir.walk_indice(expr)

tilelang.carver.roller.shape_inference.tir.get_analyzer_by_tir(block_analyzer, args)
Return type:

InputShapeInference