tilelang.language.tile_schedule
===============================

.. py:module:: tilelang.language.tile_schedule

.. autoapi-nested-parse::

   Tile schedulers built on TileLang ``meta_class``.

   ``@meta_class`` auto-inlines every method that emits a buffer store, so
   state methods can freely read scalar state via ``self.x[0]`` and write it
   via ``self.x[0] = ...`` (the latter lowers to a ``BufferStore``).
   Store-free methods (``valid``, ``coord``) stay plain Python and simply
   return ``PrimExpr`` values, so they can be used directly in conditions;
   ``coord``, which reads no state, can also be called on a stateless
   instance.

   State lives in ``T.alloc_var`` buffers allocated in ``__init__`` (only
   when ``stateful=True``). Works in both eager (``@tilelang.jit``) and lazy
   (``@T.prim_func``) modes.

   Note on compile-time branching: inside an inlined method, the lazy
   TVMScript parser does not constant-fold a Python ``if`` on a compile-time
   value; it would emit a dead TIR branch instead. All compile-time decisions
   (traversal order, clustering) are therefore made in ``__init__`` (plain
   Python), and the inlined hot methods contain only uniform runtime
   arithmetic. The store-free ``coord`` may use a compile-time ``if``
   because, being plain Python, it is evaluated (not traced) on each call.



Classes
-------

.. autoapisummary::

   tilelang.language.tile_schedule.BaseTileScheduler
   tilelang.language.tile_schedule.PersistentTileScheduler


Module Contents
---------------

.. py:class:: BaseTileScheduler(prefix, stateful = True)

   Common state and tile-traversal skeleton.

   State (single-element ``T.alloc_var`` buffers): ``m_idx`` / ``n_idx`` are
   the current tile coordinates; ``linear_idx`` is the worker's global linear
   cursor; ``current_iter`` is the 0-based iteration count, i.e. how many
   tiles this worker has already advanced past (the "wave" index).
   ``current_iter`` is the single iteration clock the kernel can read for
   pipeline/double-buffer state (e.g. ``sched.current_iter[0] & 1``), which
   removes the need for a separate ``for w`` loop counter.
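
   Assuming ``next_tile(step)`` advances ``linear_idx`` by a constant
   ``step`` and increments ``current_iter`` once per call (an inference from
   the description of the two counters, not something the source states
   explicitly), the linear cursor and the iteration clock stay tied together:

   .. math::

      \text{linear\_idx} = \text{linear\_init} + \text{current\_iter} \cdot \text{step}

   For a grid-strided persistent worker that starts at its worker id with
   ``step = num_workers``, this makes ``current_iter`` exactly the wave index
   ``(linear_idx - worker_id) / num_workers``.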

   Subclasses implement ``update_current_idx`` to decode ``linear_idx`` into
   ``(m_idx, n_idx)`` and to set ``self._total_tiles`` (used by ``valid``).


   .. py:method:: update_current_idx(linear_idx)
      :abstractmethod:



   .. py:method:: init(linear_init)


   .. py:method:: next_tile(step)


   .. py:method:: valid()


.. py:class:: PersistentTileScheduler(prefix, num_m_tiles, num_n_tiles, num_workers=None, swizzle_size = 1, column_major = True, cluster_size = 1, stateful = True)

   Bases: :py:obj:`BaseTileScheduler`


   Persistent, grid-strided tile scheduler with L2 swizzle and clustering.

   The 2D tile grid (``num_m_tiles`` x ``num_n_tiles``) is flattened into a
   linear order and distributed across ``num_workers`` persistent workers:
   worker ``w`` processes linear indices ``w, w + num_workers, ...`` until
   the index runs past the total number of tiles.

   The linear-to-2D decode is controlled by three independently configurable
   factors that interact:

   1. ``column_major`` -- traversal order. With ``swizzle_size == 1`` this is
      pure column-major (``m`` varies fastest) when ``True``, or pure
      row-major (``n`` varies fastest) when ``False``.
   2. ``swizzle_size`` -- L2-locality panel width. The "slow" axis is grouped
      into panels of ``swizzle_size`` tiles, so each sweep of the "fast" axis
      covers a strip ``swizzle_size`` tiles wide; the full fast-axis extent is
      swept for one panel before the traversal advances to the next.
      ``swizzle_size == 1`` disables swizzling. This is the CUTLASS-style
      threadblock swizzle used by the SM100 persistent GEMM examples. If the
      slow-axis extent is not divisible by ``swizzle_size``, the last (tail)
      panel is simply narrower.
   3. ``cluster_size`` -- block clustering along M (x) only. The M tiles are
      grouped into clusters of ``cluster_size`` rows, and the scheduler runs
      at *cluster* granularity: it produces the cluster-row in ``m_idx`` over
      a grid of ``ceildiv(num_m_tiles, cluster_size)`` cluster-rows x
      ``num_n_tiles`` columns.
      The caller turns the cluster-row into a real block row by adding the
      in-cluster rank::

         bx = sched.m_idx[0] * cluster_size + cta_rank_in_cluster

      (For ``cluster_size == 1`` this reduces to ``bx = sched.m_idx[0]``.)

   Interaction: clustering *first* reshapes the grid to ``M' x N'`` (with
   ``M' = ceildiv(num_m_tiles, cluster_size)`` and ``N' = num_n_tiles``);
   the swizzle and traversal order then operate on that reshaped grid.
   ``num_workers`` is the number of resident workers (i.e. the number of
   resident clusters when ``cluster_size > 1``).

   :param prefix: Name prefix for the scheduler's state buffers in the
       generated IR (``{prefix}_m_idx`` / ``{prefix}_n_idx`` /
       ``{prefix}_linear_idx``).
   :type prefix: str
   :param num_m_tiles: Number of tiles along M (``ceildiv(M, block_M)``).
   :type num_m_tiles: int | PrimExpr
   :param num_n_tiles: Number of tiles along N (``ceildiv(N, block_N)``).
   :type num_n_tiles: int | PrimExpr
   :param num_workers: Persistent stride, i.e. the number of resident
       workers (clusters when ``cluster_size > 1``). Defaults to
       ``driver.get_num_sms() // cluster_size`` (one block per SM, grouped
       into clusters).
   :type num_workers: int | PrimExpr, optional
   :param swizzle_size: L2 swizzle panel width (default ``1`` = no swizzle).
   :type swizzle_size: int
   :param column_major: Traversal order (default ``True`` = column-major,
       M fastest).
   :type column_major: bool
   :param cluster_size: Block-cluster size along M only (default ``1`` = no
       clustering).
   :type cluster_size: int
   :param stateful: If ``True`` (default), allocate the ``m_idx`` /
       ``n_idx`` / ``linear_idx`` / ``current_iter`` state buffers and enable
       ``init`` / ``next_tile`` / ``valid`` for a ``while`` persistent loop.
       If ``False``, allocate no state and expose only the pure
       ``coord(tile_id) -> (m, n)`` decode, for ``for w in range(waves)``
       loops and auto warp-specialization, where the loop owns the iteration
       clock and the WS pass owns the pipeline phase.
   :type stateful: bool

   .. rubric:: Examples

   Plain persistent GEMM (one block per SM)::

       m_blocks, n_blocks = T.ceildiv(M, block_M), T.ceildiv(N, block_N)
       with T.Kernel(driver.get_num_sms(), threads=threads) as (block_id):
           sched = T.PersistentTileScheduler("sched", m_blocks, n_blocks)
           sched.init(block_id)
           while sched.valid():
               bx, by = sched.m_idx[0], sched.n_idx[0]
               # ... compute tile (bx, by) ...
               sched.next_tile()

   With L2 swizzle (panel width 8)::

       sched = T.PersistentTileScheduler(
           "sched", m_blocks, n_blocks, swizzle_size=8)

   With a 2-CTA cluster along M (e.g. SM100 2-SM MMA)::

       sm_num = driver.get_num_sms()
       cluster_size = 2
       with T.ClusterKernel(sm_num, threads=256, cluster_dims=cluster_size) as (block_id):
           cta_rank = T.block_rank_in_cluster()
           sched = T.PersistentTileScheduler(
               "sched", m_blocks, n_blocks, swizzle_size=8, cluster_size=cluster_size)
           sched.init(block_id // cluster_size)  # init with the cluster id
           while sched.valid():
               bx = sched.m_idx[0] * cluster_size + cta_rank
               by = sched.n_idx[0]
               # ... compute tile (bx, by) ...
               sched.next_tile()

   Stateless form (``stateful=False``) for ``for w`` loops / auto-WS, where
   the scheduler is only a tile-coordinate decoder::

       sched = T.PersistentTileScheduler(
           "sched", m_blocks, n_blocks, swizzle_size=8, stateful=False)
       for w in range(waves):
           bx, by = sched.coord(num_workers * w + worker_id)
           if bx * block_M < M and by * block_N < N:
               # ... pipelined inner loop (T.Pipelined / T.copy / T.gemm) ...
               pass

   Manual warp-specialized kernels use ``current_iter`` as the single
   iteration clock (no separate ``for w`` loop): each warp role runs its own
   ``while sched.valid()`` loop, takes the tile from ``sched.m_idx[0]`` /
   ``sched.n_idx[0]``, and derives pipeline/double-buffer state from
   ``sched.current_iter[0]`` (e.g. ``sched.current_iter[0] & 1``)::

       sched.init(block_id)
       while sched.valid():
           bx, by = sched.m_idx[0], sched.n_idx[0]
           # ... use sched.current_iter[0] for barrier phase / double-buffering ...
           sched.next_tile()

   .. rubric:: Notes

   State is held in single-element ``T.alloc_var`` buffers: read it with
   ``sched.m_idx[0]`` / ``sched.current_iter[0]`` and (inside methods) write
   it with ``self.x[0] = ...``. The ``[0]`` index is required for the write
   to lower to a ``BufferStore``.


   .. py:attribute:: cluster_size
      :value: 1



   .. py:attribute:: num_workers
      :value: None



   .. py:attribute:: swizzle_size
      :value: 1



   .. py:method:: coord(tile_id)

      Decode a linear ``tile_id`` into ``(m_idx, n_idx)`` using pure
      ``PrimExpr`` arithmetic.

      Stateless: it builds expressions only and touches no state buffer, so
      it is usable on a ``stateful=False`` instance and can be called
      independently from any loop or warp role. ``m_idx`` is the cluster row
      when ``cluster_size > 1`` (the caller adds the in-cluster rank).


   .. py:method:: update_current_idx(linear_idx)


   .. py:method:: init(worker_id)


   .. py:method:: next_tile()
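

The linear-to-2D decode and the grid-strided worker loop described for ``PersistentTileScheduler`` can be modelled host-side in plain Python, for example to check tile coverage or inspect the traversal order before writing a kernel. The sketch below is an illustration under stated assumptions, not the TileLang implementation: ``decode_tile`` and ``worker_tiles`` are hypothetical helpers, and the order inside a swizzle panel (slow-axis index varying fastest within the panel, as in the common grouped-ordering GEMM swizzle) is assumed rather than taken from the source. It applies clustering first, then the panel swizzle and traversal order, and tracks a ``current_iter`` counter alongside the linear cursor.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division for non-negative a and positive b."""
    return -(-a // b)


def decode_tile(tile_id: int, num_m_tiles: int, num_n_tiles: int,
                swizzle_size: int = 1, column_major: bool = True,
                cluster_size: int = 1) -> tuple[int, int]:
    """Hypothetical host-side model of the linear-to-2D tile decode.

    Returns (m_idx, n_idx); m_idx is a cluster row when cluster_size > 1.
    ASSUMPTION: inside a swizzle panel the slow-axis index varies fastest.
    """
    # Clustering reshapes the grid to M' x N' before anything else.
    m_ext = ceildiv(num_m_tiles, cluster_size)
    n_ext = num_n_tiles
    # Traversal order decides which axis is "fast" and which is "slow".
    fast_len, slow_len = (m_ext, n_ext) if column_major else (n_ext, m_ext)
    if not 0 <= tile_id < fast_len * slow_len:
        raise ValueError("tile_id out of range")
    # Every full panel holds swizzle_size * fast_len tiles.
    panel_tiles = swizzle_size * fast_len
    panel_id, local = divmod(tile_id, panel_tiles)
    panel_start = panel_id * swizzle_size
    width = min(swizzle_size, slow_len - panel_start)  # tail panel may be narrower
    slow = panel_start + local % width
    fast = local // width
    return (fast, slow) if column_major else (slow, fast)


def worker_tiles(worker_id: int, num_workers: int, num_m_tiles: int,
                 num_n_tiles: int, **decode_kw):
    """Yield (current_iter, m_idx, n_idx) for one persistent worker."""
    total = ceildiv(num_m_tiles, decode_kw.get("cluster_size", 1)) * num_n_tiles
    linear_idx, current_iter = worker_id, 0  # mirrors init(worker_id)
    while linear_idx < total:                # mirrors valid()
        m, n = decode_tile(linear_idx, num_m_tiles, num_n_tiles, **decode_kw)
        yield current_iter, m, n
        linear_idx += num_workers            # mirrors next_tile(): grid stride
        current_iter += 1                    # wave index of the next tile


# 4 x 6 tile grid, column-major, panels of 4 columns (tail panel of 2), 3 workers.
tiles = [(m, n) for _, m, n in worker_tiles(0, 3, 4, 6, swizzle_size=4)]
print(tiles)  # [(0, 0), (0, 3), (1, 2), (2, 1), (3, 0), (3, 3), (1, 4), (2, 5)]
```

In this model ``it & 1`` for each yielded ``it`` plays the role of the ``sched.current_iter[0] & 1`` double-buffer parity. With ``cluster_size > 1`` the returned ``m`` is a cluster row, so a kernel would still add the in-cluster rank to obtain the block row.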