tilelang.language.tile_schedule
Tile schedulers built on TileLang meta_class.
@meta_class auto-inlines every method that emits a buffer store, so state
methods can freely read scalar state via self.x[0] and write via
self.x[0] = ... (the latter lowers to a BufferStore). Store-free methods
(valid, coord) stay plain Python and just return PrimExpr values, so
they work in conditions and can be reused statelessly. State lives in
T.alloc_var buffers allocated in __init__ (only when stateful). Works
in both eager (@tilelang.jit) and lazy (@T.prim_func) modes.
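Plain Python semantics show why writes must go through self.x[0] = ... rather than self.x = .... Item assignment calls the buffer object's __setitem__, which is a hook a tracer can turn into a store. Attribute assignment only rebinds a Python name. The sketch below uses a hypothetical RecordingBuffer class (not a TileLang API) that simply logs item writes. It models the idea only, not how meta_class is actually implemented.

```python
class RecordingBuffer:
    """Hypothetical stand-in for a single-element T.alloc_var buffer.

    Every item write is logged, so we can see which assignments a tracer
    would be able to intercept.
    """

    def __init__(self, name, init=0):
        self.name = name
        self.data = [init]
        self.stores = []  # log of (name, index, value) item writes

    def __getitem__(self, idx):
        return self.data[idx]

    def __setitem__(self, idx, value):
        # Item assignment is observable: a tracer could emit a store here.
        self.stores.append((self.name, idx, value))
        self.data[idx] = value


class Counter:
    def __init__(self):
        self.x = RecordingBuffer("x")

    def bump_with_index(self):
        self.x[0] = self.x[0] + 1  # goes through __setitem__, so it is recorded

    def bump_without_index(self):
        # Rebinds the attribute to a plain int; the buffer never sees a write.
        self.x = self.x[0] + 1


c = Counter()
buf = c.x
c.bump_with_index()
print(buf.stores)  # [('x', 0, 1)]
c.bump_without_index()
print(buf.stores)  # still [('x', 0, 1)]; c.x is now the int 2
```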
Note on compile-time branching: the lazy TVMScript parser does not constant-fold
a Python if on a compile-time value inside an inlined method (it would emit a
dead TIR branch). So all compile-time decisions (traversal order, clustering)
are made in __init__ (plain Python); the inlined hot methods contain only
uniform runtime arithmetic. The store-free coord may use a compile-time
if because, being plain Python, it is evaluated (not traced) per call.
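The following is a minimal pure-Python sketch of that split. It uses a hypothetical OrderDecoder class that is not part of TileLang. The plain Python if on column_major runs once in __init__ and only picks divisor constants. As a result, decode performs the same branch-free arithmetic for either traversal order. The sketch covers plain traversal order only, with no swizzle or clustering.

```python
class OrderDecoder:
    """Hypothetical illustration: resolve a compile-time choice up front."""

    def __init__(self, num_m, num_n, column_major=True):
        self.num_m, self.num_n = num_m, num_n
        self.total = num_m * num_n
        # Compile-time decision, made once in plain Python.
        if column_major:
            # m varies fastest: m = t % num_m, n = t // num_m
            self.m_div, self.n_div = 1, num_m
        else:
            # n varies fastest: m = t // num_n, n = t % num_n
            self.m_div, self.n_div = num_n, 1

    def decode(self, t):
        # "Hot" path: identical, branch-free arithmetic for both orders.
        m = (t // self.m_div) % self.num_m
        n = (t // self.n_div) % self.num_n
        return m, n


col = OrderDecoder(3, 2, column_major=True)
row = OrderDecoder(3, 2, column_major=False)
print([col.decode(t) for t in range(col.total)])
# [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)]
print([row.decode(t) for t in range(row.total)])
# [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)]
```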
Classes
BaseTileScheduler | Common state and tile-traversal skeleton.
PersistentTileScheduler | Persistent, grid-strided tile scheduler with L2 swizzle and clustering.
Module Contents
- class tilelang.language.tile_schedule.BaseTileScheduler(prefix, stateful=True)
Common state and tile-traversal skeleton.
State (single-element T.alloc_var buffers): m_idx / n_idx are the current tile coordinates; linear_idx is the worker’s global linear cursor; current_iter is the 0-based iteration count (how many tiles this worker has advanced past, i.e. the “wave” index). current_iter is the single iteration clock the kernel can read for pipeline/double-buffer state (e.g. sched.current_iter[0] & 1), which removes the need for a separate for w loop counter. Subclasses implement update_current_idx to decode linear_idx into (m_idx, n_idx) and to set self._total_tiles (used by valid).
- Parameters:
prefix (str)
stateful (bool)
- abstract update_current_idx(linear_idx)
- init(linear_init)
- next_tile(step)
- valid()
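The init / next_tile / valid contract and the State description above can be modeled in pure Python. This is a sketch under assumptions. The class names are hypothetical. Single-element lists stand in for the T.alloc_var buffers. The linear_idx < total_tiles test in valid is inferred from the description, not taken from TileLang's source. A step of 4 plays the role of the worker stride.

```python
from abc import ABC, abstractmethod


class ModelBaseScheduler(ABC):
    """Hypothetical pure-Python model of the BaseTileScheduler state machine."""

    def __init__(self):
        # Single-element lists stand in for T.alloc_var buffers.
        self.m_idx = [0]
        self.n_idx = [0]
        self.linear_idx = [0]
        self.current_iter = [0]
        self._total_tiles = 0  # set by the subclass

    @abstractmethod
    def update_current_idx(self, linear_idx):
        """Decode linear_idx into m_idx[0] / n_idx[0]."""

    def init(self, linear_init):
        self.linear_idx[0] = linear_init
        self.current_iter[0] = 0
        self.update_current_idx(linear_init)

    def next_tile(self, step):
        self.linear_idx[0] = self.linear_idx[0] + step
        self.current_iter[0] = self.current_iter[0] + 1
        # May decode past the grid; valid() rejects that position.
        self.update_current_idx(self.linear_idx[0])

    def valid(self):
        return self.linear_idx[0] < self._total_tiles


class ColumnMajorModel(ModelBaseScheduler):
    def __init__(self, num_m, num_n):
        super().__init__()
        self.num_m = num_m
        self._total_tiles = num_m * num_n

    def update_current_idx(self, linear_idx):
        self.m_idx[0] = linear_idx % self.num_m
        self.n_idx[0] = linear_idx // self.num_m


# Worker 1 of 4 on a 3 x 4 grid of tiles.
sched = ColumnMajorModel(3, 4)
sched.init(1)
visited = []
while sched.valid():
    visited.append((sched.current_iter[0], sched.m_idx[0], sched.n_idx[0]))
    sched.next_tile(4)
print(visited)  # [(0, 1, 0), (1, 2, 1), (2, 0, 3)]
```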
- class tilelang.language.tile_schedule.PersistentTileScheduler(prefix, num_m_tiles, num_n_tiles, num_workers=None, swizzle_size=1, column_major=True, cluster_size=1, stateful=True)
Bases: BaseTileScheduler
Persistent, grid-strided tile scheduler with L2 swizzle and clustering.
The 2D tile grid (num_m_tiles x num_n_tiles) is flattened into a linear order and distributed across num_workers persistent workers: worker w processes linear indices w, w + num_workers, ... until they run past the total. The linear-to-2D decode is controlled by three independent but interacting factors:
- column_major – traversal order. With swizzle_size == 1 this is pure column-major (m varies fastest) when True, or pure row-major (n varies fastest) when False.
- swizzle_size – L2-locality panel width. The “fast” axis is widened into panels of swizzle_size tiles along the “slow” axis: a full strip of the fast axis is swept for each group of swizzle_size slow-axis tiles before advancing. swizzle_size == 1 disables swizzling. This is the CUTLASS-style threadblock swizzle used by the SM100 persistent GEMM examples. Non-divisible tails are handled: the last panel is simply narrower.
- cluster_size – block clustering along M (x) only. The M tiles are grouped into clusters of cluster_size rows, and the scheduler runs at cluster granularity: it produces the cluster-row in m_idx over a grid of ceildiv(num_m_tiles, cluster_size) cluster-rows x num_n_tiles columns. The caller turns the cluster-row into a real block row by adding the in-cluster rank:
    bx = sched.m_idx[0] * cluster_size + cta_rank_in_cluster
  (For cluster_size == 1 this reduces to bx = sched.m_idx[0].)
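One reading of the column_major / swizzle_size description is sketched below as a pure-Python decode. It is an illustrative reconstruction from the prose above, not TileLang's code. The exact arithmetic, in particular that the slow-axis offset varies fastest inside a panel, is an assumption. With clustering, the same decode would run on the reshaped grid, i.e. with num_m set to ceildiv(num_m_tiles, cluster_size).

```python
def swizzle_decode(t, num_m, num_n, swizzle_size=1, column_major=True):
    """Hypothetical decode of linear tile index t into (m, n).

    The slow axis is cut into panels of `swizzle_size` tiles, and each panel
    is swept over the whole fast axis before moving on. Inside a panel the
    slow-axis offset varies fastest, so swizzle_size == 1 degenerates to a
    plain column-major (or row-major) order.
    """
    num_fast, num_slow = (num_m, num_n) if column_major else (num_n, num_m)
    panel = t // (swizzle_size * num_fast)
    panel_start = panel * swizzle_size
    # The last panel is narrower when num_slow % swizzle_size != 0.
    width = min(swizzle_size, num_slow - panel_start)
    local = t - panel * swizzle_size * num_fast
    fast = local // width
    slow = panel_start + local % width
    return (fast, slow) if column_major else (slow, fast)


# 4 x 5 tile grid, column-major, panels of 2 along N (the last panel is 1 wide).
order = [swizzle_decode(t, 4, 5, swizzle_size=2) for t in range(20)]
print(order[:6])   # [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)]
print(order[16:])  # [(0, 4), (1, 4), (2, 4), (3, 4)]
```

Each panel reuses a narrow band of N tiles while walking all of M. That is the source of the L2 locality; the tail panel at n = 4 degenerates to a single column.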
Interaction: clustering first reshapes the grid to M' x N' (with M' = ceildiv(num_m_tiles, cluster_size) and N' = num_n_tiles); the swizzle and traversal order then operate on that reshaped grid. num_workers is the number of resident workers (i.e. resident clusters when cluster_size > 1).
- Parameters:
prefix (str) – Name prefix for the scheduler’s state buffers in the generated IR ({prefix}_m_idx / {prefix}_n_idx / {prefix}_linear_idx).
num_m_tiles (int | PrimExpr) – Number of tiles along M (ceildiv(M, block_M)).
num_n_tiles (int | PrimExpr) – Number of tiles along N (ceildiv(N, block_N)).
num_workers (int | PrimExpr, optional) – Persistent stride, i.e. the number of resident workers/clusters. Defaults to driver.get_num_sms() // cluster_size (one block per SM, grouped into clusters).
swizzle_size (int) – L2 swizzle panel width (default 1 = no swizzle).
column_major (bool) – Traversal order (default True = column-major, M fastest).
cluster_size (int) – Block-cluster size along M only (default 1 = no clustering).
stateful (bool) – If True (default), allocate the m_idx / n_idx / linear_idx / current_iter state buffers and enable init / next_tile / valid for a while persistent loop. If False, allocate no state and expose only the pure coord(tile_id) -> (m, n) decode. Use this for loops of the form for w in range(waves) and for auto warp-specialization, where the loop owns the iteration clock and the WS pass owns the pipeline phase.
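A few lines of plain Python can check the grid-stride assignment (worker w takes w, w + num_workers, ...) and the default num_workers. The SM count of 132 is only an example value. default_num_workers and worker_tiles are hypothetical helpers, not part of the API.

```python
def default_num_workers(num_sms, cluster_size=1):
    # Mirrors the documented default: driver.get_num_sms() // cluster_size.
    return num_sms // cluster_size


def worker_tiles(worker, num_workers, total_tiles):
    # Linear tile indices one persistent worker visits: w, w + W, w + 2W, ...
    return list(range(worker, total_tiles, num_workers))


num_workers = default_num_workers(132)  # 132 SMs is just an example value
total = 16 * 24  # e.g. 16 M tiles x 24 N tiles = 384 tiles
per_worker = [worker_tiles(w, num_workers, total) for w in range(num_workers)]
print(num_workers, max(map(len, per_worker)), min(map(len, per_worker)))
# 132 3 2
```

Every tile is covered exactly once. The uneven 3 vs 2 split is the ragged last wave that the loop's validity check has to handle.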
Examples
Plain persistent GEMM (one block per SM):
```python
m_blocks, n_blocks = T.ceildiv(M, block_M), T.ceildiv(N, block_N)
with T.Kernel(driver.get_num_sms(), threads=threads) as (block_id,):
    sched = T.PersistentTileScheduler("sched", m_blocks, n_blocks)
    sched.init(block_id)
    while sched.valid():
        bx, by = sched.m_idx[0], sched.n_idx[0]
        # ... compute tile (bx, by) ...
        sched.next_tile()
```
With L2 swizzle (panel width 8):
```python
sched = T.PersistentTileScheduler(
    "sched", m_blocks, n_blocks, swizzle_size=8)
```
With a 2-CTA cluster along M (e.g. SM100 2-SM MMA):
```python
sm_num = driver.get_num_sms()
cluster_size = 2
with T.ClusterKernel(sm_num, threads=256, cluster_dims=cluster_size) as (block_id):
    cta_rank = T.block_rank_in_cluster()
    sched = T.PersistentTileScheduler(
        "sched", m_blocks, n_blocks, swizzle_size=8, cluster_size=cluster_size)
    sched.init(block_id // cluster_size)  # init with cluster id
    while sched.valid():
        bx = sched.m_idx[0] * cluster_size + cta_rank
        by = sched.n_idx[0]
        # ... compute tile (bx, by) ...
        sched.next_tile()
```
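The bx formula in the cluster example can be sanity-checked in plain Python. The hypothetical helper block_rows (not an API) expands each cluster-row into cluster_size consecutive block rows. There is one arithmetic consequence. When num_m_tiles is not a multiple of cluster_size, the last cluster yields block rows past num_m_tiles - 1, so a bounds check on bx (for example bx * block_M < M) may be needed. How the library handles this is not stated here.

```python
def ceildiv(a, b):
    return -(-a // b)


def block_rows(num_m_tiles, cluster_size):
    """Hypothetical: expand every cluster-row into its per-CTA block rows."""
    num_cluster_rows = ceildiv(num_m_tiles, cluster_size)  # M' in the docs
    rows = []
    for cluster_row in range(num_cluster_rows):  # plays the role of sched.m_idx[0]
        for cta_rank in range(cluster_size):  # plays the role of the in-cluster rank
            rows.append(cluster_row * cluster_size + cta_rank)
    return rows


rows = block_rows(num_m_tiles=5, cluster_size=2)
print(rows)  # [0, 1, 2, 3, 4, 5]; row 5 is past the last real tile row (4)
out_of_range = [bx for bx in rows if bx >= 5]
print(out_of_range)  # [5]
```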
Stateless form (stateful=False) for explicit wave loops / auto-WS, where the scheduler is only a tile-coordinate decoder:
```python
sched = T.PersistentTileScheduler(
    "sched", m_blocks, n_blocks, swizzle_size=8, stateful=False)
# waves, num_workers and worker_id are defined by the enclosing kernel (not shown)
for w in range(waves):
    bx, by = sched.coord(num_workers * w + worker_id)
    if bx * block_M < M and by * block_N < N:
        # ... pipelined inner loop (T.Pipelined / T.copy / T.gemm) ...
        pass
```
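In the stateless example, the caller supplies the wave count and the per-wave tile id. The sketch below shows how they line up with the stateful loop. It assumes waves = ceildiv(total_tiles, num_workers), which the docs do not fix. It also guards directly on tile_id rather than on the decoded coordinates. Under those assumptions the for-loop index w plays exactly the role of current_iter.

```python
def ceildiv(a, b):
    return -(-a // b)


def stateless_tiles(worker_id, num_workers, total_tiles):
    """(wave, tile_id) pairs one worker handles with a for-w loop."""
    waves = ceildiv(total_tiles, num_workers)
    out = []
    for w in range(waves):
        tile_id = num_workers * w + worker_id
        if tile_id < total_tiles:  # skip the ragged last wave
            out.append((w, tile_id))
    return out


def stateful_tiles(worker_id, num_workers, total_tiles):
    """Same worker driven like init / valid / next_tile, tracking current_iter."""
    linear_idx, current_iter, out = worker_id, 0, []
    while linear_idx < total_tiles:
        out.append((current_iter, linear_idx))
        linear_idx += num_workers
        current_iter += 1
    return out


total, workers = 10, 4
for wid in range(workers):
    assert stateless_tiles(wid, workers, total) == stateful_tiles(wid, workers, total)
print(stateless_tiles(1, workers, total))  # [(0, 1), (1, 5), (2, 9)]
print(stateless_tiles(3, workers, total))  # [(0, 3), (1, 7)]
```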
Manual warp-specialized kernels use current_iter as the single iteration clock (no separate for w loop). Each warp role runs its own while sched.valid() loop, reads sched.current_iter[0] for pipeline/double-buffer state (sched.current_iter[0] & 1 etc.), and reads sched.m_idx[0] / sched.n_idx[0] for the tile:
```python
sched.init(block_id)
while sched.valid():
    bx, by = sched.m_idx[0], sched.n_idx[0]
    # ... use sched.current_iter[0] for barrier phase / double-buffering ...
    sched.next_tile()
```
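Here is a plain-Python sketch of how current_iter can drive double-buffering. The buffer slot = current_iter & 1 rule comes from the text above. Deriving a barrier parity as (current_iter // num_stages) & 1 follows a common convention for ring-buffered pipelines and is an assumption here, not something this module prescribes. pipeline_state is a hypothetical helper.

```python
NUM_STAGES = 2  # double buffering


def pipeline_state(current_iter, num_stages=NUM_STAGES):
    """Hypothetical helper: (buffer slot, barrier parity) for one iteration."""
    slot = current_iter % num_stages  # equals current_iter & 1 for 2 stages
    # Each slot's barrier completes once per pass over the ring, so the
    # expected phase flips every num_stages iterations.
    parity = (current_iter // num_stages) & 1
    return slot, parity


print([pipeline_state(i) for i in range(6)])
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 0), (1, 0)]
```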
Notes
State is held in single-element T.alloc_var buffers, so read with sched.m_idx[0] / sched.current_iter[0] and (inside methods) write with self.x[0] = ...; the [0] is required for the write to lower to a BufferStore.
- cluster_size = 1
- num_workers = None
- swizzle_size = 1
- coord(tile_id)
Decode a linear tile_id into (m_idx, n_idx) as pure PrimExpr values.
Stateless: it builds expressions only and touches no state buffer, so it is usable on a stateful=False instance and can be called independently from any loop or warp role. m_idx is the cluster row when cluster_size > 1 (the caller adds the in-cluster rank).
- update_current_idx(linear_idx)
- init(worker_id)
- next_tile()