tilelang.language.loop
Loop related language interfaces in TileLang.
Functions
Parallel | Tools to construct a nested parallel for loop.
Persistent | Tools to construct a persistent for loop.
Pipelined | Tools to construct a pipelined for loop.
serial | The serial For statement.
unroll | The unrolled For statement.
Serial | Alias of T.serial.
Unroll | Alias of T.unroll.
Module Contents
- tilelang.language.loop.Parallel(*extents, coalesced_width=None, loop_layout=None)
- Tools to construct a nested parallel for loop.
This can be used to create element-wise tensor expressions.
- Parameters:
extents (PrimExpr) – The extents of the iteration.
coalesced_width (Optional[int]) – The coalesced width of the parallel loop.
loop_layout – A layout annotation for the parallel loop nest, expressed as a T.Fragment. When provided, it is attached as the "parallel_loop_layout" annotation on the outermost parallel loop. For a k-dimensional T.Parallel(...) nest, the fragment's InputDim must equal k.
- Notes on layout constraints:
TileLang validates parallel loop layout annotations during tl.transform.LayoutInference with ParallelLoopLayoutValidator. The key constraints are:
- Every parallel loop must be covered by a layout annotation after layout inference. For a nested parallel nest, this annotation must live on the outermost loop; inner parallel loops must not carry the layout annotation themselves.
- For a nest depth of k, the layout must satisfy InputDim == k.
- Violations (missing annotation on the outermost loop, annotation on inner loops, or mismatched InputDim) cause a compilation error.
Rationale: inner loops cannot control or annotate their outer loops, while the outermost loop can manage its inner nest. Therefore the layout is placed on the outermost loop so lowering passes can rewrite the entire region.
To make this easy, the provided loop_layout is attached to the outermost generated loop only. If you omit loop_layout, the compiler will try to infer a valid layout and attach it during the LayoutInference pass.
- Returns:
res – The ForFrame.
- Return type:
frame.ForFrame
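The layout constraints listed for Parallel can be restated as a small checker. The sketch below is a plain-Python model of those rules only; it does not call TileLang, and `LoopModel` and `check_parallel_nest` are hypothetical names introduced for illustration, not part of the library or of ParallelLoopLayoutValidator.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class LoopModel:
    """One parallel loop in a nest (hypothetical stand-in, not a TileLang class)."""
    extent: int
    layout_input_dim: Optional[int] = None  # InputDim of the attached layout, if any


def check_parallel_nest(nest: List[LoopModel]) -> List[str]:
    """Return the rule violations for a nest given outermost-first."""
    if not nest:
        raise ValueError("a nest needs at least one loop")
    errors = []
    k = len(nest)  # nest depth
    outer, inner = nest[0], nest[1:]
    if outer.layout_input_dim is None:
        errors.append("missing layout annotation on the outermost loop")
    elif outer.layout_input_dim != k:
        errors.append(f"InputDim {outer.layout_input_dim} != nest depth {k}")
    for depth, loop in enumerate(inner, start=1):
        if loop.layout_input_dim is not None:
            errors.append(f"inner loop at depth {depth} carries a layout annotation")
    return errors


ok = [LoopModel(128, layout_input_dim=2), LoopModel(64)]
bad = [LoopModel(128), LoopModel(64, layout_input_dim=2)]
print(check_parallel_nest(ok))   # []
print(check_parallel_nest(bad))
```

The model reports every violation it finds, whereas the documentation above only states that a violation causes a compilation error.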
- tilelang.language.loop.Persistent(domain, wave_size, index, group_size=8)
Tools to construct a persistent for loop.
- Parameters:
domain (List[tir.PrimExpr]) – The list of domain extents.
wave_size (int) – The wave size.
index (int) – The tile index in one wave.
group_size (tir.PrimExpr) – The group size.
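The parameter descriptions above are terse, so the following plain-Python sketch shows the general persistent-scheduling pattern that the names suggest. It is a model under stated assumptions, not TileLang's actual mapping: it assumes the worker at position `index` visits flat tiles `index`, `index + wave_size`, and so on, and that flat tile ids unravel in row-major order over `domain`. Any reordering driven by `group_size` is deliberately left out, and `persistent_tiles` is a hypothetical helper name.

```python
from typing import Iterator, List, Tuple


def persistent_tiles(domain: List[int], wave_size: int, index: int) -> Iterator[Tuple[int, ...]]:
    """Yield the tile coordinates one worker visits in a strided persistent loop.

    Assumptions: strided assignment of flat tile ids, row-major unraveling.
    group_size is NOT modeled here.
    """
    total = 1
    for extent in domain:
        total *= extent
    for flat in range(index, total, wave_size):
        rest = flat
        coords = []
        for extent in reversed(domain):  # peel off the fastest-varying axis first
            coords.append(rest % extent)
            rest //= extent
        yield tuple(reversed(coords))


# A 2 x 3 tile grid shared by a wave of 4 workers.
for worker in range(4):
    print(worker, list(persistent_tiles([2, 3], wave_size=4, index=worker)))
```

Under these assumptions every tile is visited by exactly one worker, and workers whose index is small pick up the extra tiles when the grid size is not a multiple of the wave size.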
- tilelang.language.loop.Pipelined(start, stop=None, num_stages=0, order=None, stage=None, sync=None, group=None)
Tools to construct a pipelined for loop.
- Parameters:
start (PrimExpr) – The minimum value of iteration.
stop (PrimExpr) – The maximum value of iteration.
num_stages (int) – The maximum number of buffers used between pipeline producers and consumers. If num_stages is 0, pipelining is not enabled.
order (list[int] | None)
stage (list[int] | None)
sync (list[list[int]] | None)
group (list[list[int]] | None)
- Returns:
res – The ForFrame.
- Return type:
frame.ForFrame
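To make the role of num_stages concrete, here is a generic multi-buffering model in plain Python. It is a sketch of the textbook technique, not the schedule TileLang generates: it assumes a ring of `num_stages` buffer slots with the producer running `num_stages - 1` tiles ahead, and it treats a single stage the same as a disabled pipeline. `pipeline_schedule` and the event tuples are hypothetical names, and the order, stage, sync and group parameters are not modeled.

```python
from typing import List, Tuple

Event = Tuple[str, int, int]  # (operation, tile index, buffer slot)


def pipeline_schedule(n_tiles: int, num_stages: int) -> List[Event]:
    """Issue order of loads and computes for a multi-buffered loop (generic model)."""
    events: List[Event] = []
    if num_stages <= 1:
        # No overlap: each tile is loaded and then consumed in a single buffer.
        for k in range(n_tiles):
            events += [("load", k, 0), ("compute", k, 0)]
        return events
    ahead = num_stages - 1  # how many tiles the producer runs ahead
    for k in range(min(ahead, n_tiles)):  # prologue: fill the pipeline
        events.append(("load", k, k % num_stages))
    for k in range(n_tiles):
        nxt = k + ahead
        if nxt < n_tiles:
            # Reuses the slot of tile k - 1, which has already been consumed.
            events.append(("load", nxt, nxt % num_stages))
        events.append(("compute", k, k % num_stages))
    return events


for event in pipeline_schedule(n_tiles=4, num_stages=2):
    print(event)
```

In this model the number of tiles that are loaded but not yet consumed never exceeds `num_stages`, which matches the description of num_stages as an upper bound on the buffers in use.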
- tilelang.language.loop.serial(start, stop=None, step=None, *, annotations=None)
The serial For statement.
- Parameters:
start (PrimExpr) – The minimum value of iteration.
stop (PrimExpr) – The maximum value of iteration.
step (PrimExpr) – The step size of the iteration.
annotations (Dict[str, Any]) – The optional annotations of the For statement.
- Returns:
res – The ForFrame.
- Return type:
frame.ForFrame
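The start, stop and step parameters read much like the arguments of Python's `range`. The sketch below models the index sequence under two assumptions that this page does not state: the interval is half-open (stop is excluded), and a single positional argument is treated as the extent, so iteration runs over [0, start). `serial_model` is a hypothetical helper and does not build a ForFrame.

```python
from typing import Iterator, Optional


def serial_model(start: int, stop: Optional[int] = None, step: Optional[int] = None) -> Iterator[int]:
    """Plain-Python model of the index sequence; not the TileLang implementation."""
    if stop is None:
        # Assumed single-argument form: iterate over [0, start).
        start, stop = 0, start
    return iter(range(start, stop, 1 if step is None else step))


print(list(serial_model(4)))         # [0, 1, 2, 3]
print(list(serial_model(2, 10, 3)))  # [2, 5, 8]
```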
- tilelang.language.loop.unroll(start, stop=None, step=None, *, explicit=False, unroll_factor=None, annotations=None)
The unrolled For statement.
- Parameters:
start (PrimExpr) – The minimum value of iteration.
stop (PrimExpr) – The maximum value of iteration.
step (PrimExpr) – The step size of the iteration.
explicit (bool) – Whether to explicitly unroll the loop.
unroll_factor (int) – The unroll factor of the loop.
annotations (Dict[str, Any]) – The optional annotations of the For statement.
- Returns:
res – The ForFrame.
- Return type:
frame.ForFrame
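For readers unfamiliar with unroll factors, the following plain-Python sketch shows what unrolling a loop by a factor means conceptually: the body is replicated `factor` times per trip, with leftover iterations handled separately. It is a generic illustration, not a description of how TileLang lowers unroll_factor or explicit, and `run_unrolled` is a hypothetical helper.

```python
from typing import Callable, List


def run_unrolled(n: int, factor: int, body: Callable[[int], None]) -> None:
    """Execute body(0), ..., body(n - 1) with the loop unrolled by `factor`."""
    main_end = n - n % factor
    for base in range(0, main_end, factor):
        # `factor` copies of the body per trip; a compiler would emit these
        # as straight-line code rather than as an inner loop.
        for lane in range(factor):
            body(base + lane)
    for i in range(main_end, n):  # remainder iterations
        body(i)


trace: List[int] = []
run_unrolled(10, 4, trace.append)
print(trace)
```

The iteration order is the same as in the original loop; only the number of loop-control steps changes.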
- tilelang.language.loop.Serial(start, stop=None, step=None, *, annotations=None)
Alias of T.serial.
- Parameters:
start (tvm.tir.PrimExpr)
stop (tvm.tir.PrimExpr | None)
step (tvm.tir.PrimExpr | None)
annotations (dict[str, Any] | None)
- tilelang.language.loop.Unroll(start, stop=None, step=None, *, explicit=False, unroll_factor=None, annotations=None)
Alias of T.unroll.
- Parameters:
start (tvm.tir.PrimExpr)
stop (tvm.tir.PrimExpr | None)
step (tvm.tir.PrimExpr | None)
explicit (bool)
unroll_factor (int | None)
annotations (dict[str, Any] | None)