tilelang.transform¶

Wrapping transformations.
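Each factory listed below returns a `tvm.transform.Pass`, and a TVM pass can be called on an `IRModule` to obtain a transformed module. The sketch below shows only that composition pattern, using plain Python stand-ins so it runs without TVM installed. `run_passes`, `tag_pass`, and the dictionary used as a "module" are hypothetical illustrations and are not part of tilelang.

```python
from functools import reduce
from typing import Callable, Dict, Iterable, List

# Stand-in for an IRModule: maps a function name to a list of applied tags.
Module = Dict[str, List[str]]
Pass = Callable[[Module], Module]


def run_passes(mod: Module, passes: Iterable[Pass]) -> Module:
    """Apply each pass to the module in order and return the final module."""
    return reduce(lambda m, p: p(m), passes, mod)


def tag_pass(tag: str) -> Pass:
    """Hypothetical pass factory: the returned pass appends `tag` to every body."""
    def _pass(mod: Module) -> Module:
        return {name: body + [tag] for name, body in mod.items()}
    return _pass


# With TVM and tilelang installed, the list would hold real pass objects instead,
# for example tilelang.transform.LayoutInference() (ordering here is illustrative).
result = run_passes({"main": []}, [tag_pass("layout"), tag_pass("lower")])
print(result)
```

Keeping passes as an ordered list makes it easy to insert, remove, or reorder individual steps while experimenting.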

Functions¶

get_pass_context()

Get the current pass context

ClusterPlanning()

ClusterPlanning

PipelinePlanning()

PipelinePlanning

LayoutInference()

Infer the fragment/shared memory layout.

LowerTileOp()

LowerTileOp

InjectSoftwarePipeline()

InjectSoftwarePipeline

OptimizeCPAsyncSync()

Optimize explicit cp.async commit/wait synchronization intrinsics.

FrontendLegalize()

FrontendLegalize

LegalizeNegativeIndex()

Legalize negative indices in buffer loads.

InjectAssumes()

Inject assumes for natural shape boundary conditions, and convert assumes in Evaluate(Call(...)) form (the TVM builtin assume call) to AttrNode form.

VerifyParallelLoop()

VerifyParallelLoop

LowerHopperIntrin()

LowerHopperIntrin

WarpSpecializedPipeline()

WarpSpecializedPipeline

RewriteWgmmaSync()

RewriteWgmmaSync

ThreadSync(storage_scope)

Insert sync between parallel read/write of shared buffers.

ThreadPartialSync(storage_scope)

Insert partial sync.

IfStmtBinding()

IfStmtBinding

MergeIfStmt()

MergeIfStmt

LoopUnswitching()

LoopUnswitching: Hoist loop-invariant if statements out of loops.

MultiVersionBuffer()

MultiVersionBuffer

ProducerConsumerWarpSpecialized()

Producer-Consumer Warp Specialization for TMA pipelines.

AnnotateWarpGroupRegAlloc()

Inject set_max_nreg calls into warp-specialized functions.

FuseMBarrierArriveExpectTx()

Fuse simple expect_tx -> TMA issue -> arrive back into arrive_and_expect_tx.

InjectFenceProxy()

InjectFenceProxy

LegalizeVectorizedLoop()

LegalizeVectorizedLoop

LegalizeSafeMemoryAccess()

LegalizeSafeMemoryAccess

LowerAccessPtr()

Lower TileLang frontend tl.access_ptr to tir.builtin.tvm_access_ptr.

MakePackedAPI()

MakePackedAPI

AnnotateDeviceRegions()

AnnotateDeviceRegions

SplitHostDevice()

Split host/device functions even for empty kernels.

AnnotateReadOnlyParams()

Annotate read-only handle parameters for PrimFuncs.

VectorizeLoop([enable_vectorize])

VectorizeLoop

LowerPTXAsyncCopy()

Lower eligible global->shared copies into PTX cp.async on CUDA.

InjectPTXAsyncCopy()

Deprecated alias of LowerPTXAsyncCopy.

LowerDeviceStorageAccessInfo()

Lower attached storage access information on device.

ConfigIndexBitwidth()

Configure the index bitwidth.

FlattenBuffer()

FlattenBuffer

EliminateStorageSyncForMBarrier()

EliminateStorageSyncForMBarrier

MergeSharedMemoryAllocations([...])

MergeSharedMemoryAllocations

LowerL2Persistent()

LowerL2Persistent

MarkCudaSyncCalls([have_pdl])

MarkCudaSyncCalls

PersistThreadblock()

PersistThreadblock

AlignDynamicSharedMemoryAllocations([align_bytes])

AlignDynamicSharedMemoryAllocations

LowerSharedBarrier()

LowerSharedBarrier

PlanAndUpdateBufferAllocationLocation()

Plan and update buffer allocation locations within PrimFuncs.

HoistNonRestrictParams()

StorageRewrite()

StorageRewrite

LowerOpaqueBlock()

LowerOpaqueBlock

LowerThreadAllreduce()

LowerThreadAllreduce

LowerIntrin()

LowerIntrin

LowerDeviceKernelLaunch()

Create and return a transform pass that lowers device kernel launch constructs to target-specific IR.

LowerSharedTmem()

LowerSharedTmem

LayoutReducer()

Return a TVM transform pass that performs layout reduction/normalization.

UnrollLoop()

Unroll loops as in the Halide pipeline.

LowerLDGSTG()

Lower Ramp-based global memory load/store to ldg/stg intrinsics.

Package Contents¶

tilelang.transform.get_pass_context()¶

Get the current pass context

tilelang.transform.ClusterPlanning()¶

ClusterPlanning

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.PipelinePlanning()¶

PipelinePlanning

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.LayoutInference()¶

Infer the fragment/shared memory layout.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.LowerTileOp()¶

LowerTileOp

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.InjectSoftwarePipeline()¶

InjectSoftwarePipeline

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.OptimizeCPAsyncSync()¶

Optimize explicit cp.async commit/wait synchronization intrinsics.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.FrontendLegalize()¶

FrontendLegalize

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.LegalizeNegativeIndex()¶

Legalize negative indices in buffer loads.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
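As a minimal sketch of what legalizing a negative index can mean, the helper below assumes Python-style wraparound, where index `-k` refers to element `extent - k`. The function name `legalize_index` is hypothetical, and how the pass treats symbolic or out-of-range indices is not stated here.

```python
def legalize_index(index: int, extent: int) -> int:
    """Map a negative index to its non-negative equivalent for a dimension of size `extent`."""
    if index < 0:
        index += extent  # assumed wraparound: -1 becomes extent - 1
    if not 0 <= index < extent:
        raise IndexError(f"index out of range for extent {extent}")
    return index


print(legalize_index(-1, 8))  # last element of an 8-element dimension
print(legalize_index(3, 8))   # non-negative indices pass through unchanged
```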

tilelang.transform.InjectAssumes()¶

Inject assumes for natural shape boundary conditions, and convert assumes in Evaluate(Call(…)) form (the TVM builtin assume call) to AttrNode form.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.VerifyParallelLoop()¶

VerifyParallelLoop

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.LowerHopperIntrin()¶

LowerHopperIntrin

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.WarpSpecializedPipeline()¶

WarpSpecializedPipeline

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.RewriteWgmmaSync()¶

RewriteWgmmaSync

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.ThreadSync(storage_scope)¶

Insert sync between parallel read/write of shared buffers.

Parameters:

storage_scope (str) – The target storage scope.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
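The idea of inserting a barrier between conflicting accesses can be illustrated with a toy model over a flat list of accesses. This is a sketch under simplifying assumptions (one straight-line sequence, whole-buffer granularity, only read-after-write and write-after-read hazards); `insert_syncs` and the tuple encoding are hypothetical and do not reflect the pass's internals.

```python
from typing import List, Set, Tuple

Op = Tuple[str, str]  # (kind, buffer) with kind in {"read", "write"}


def insert_syncs(ops: List[Op]) -> List[Op]:
    """Insert ("sync", "") before an access that conflicts with an earlier unsynchronized one."""
    out: List[Op] = []
    reads: Set[str] = set()   # buffers read since the last sync
    writes: Set[str] = set()  # buffers written since the last sync
    for kind, buf in ops:
        conflict = (kind == "read" and buf in writes) or (
            kind == "write" and buf in reads
        )
        if conflict:
            out.append(("sync", ""))
            reads.clear()
            writes.clear()
        (reads if kind == "read" else writes).add(buf)
        out.append((kind, buf))
    return out


ops = [("write", "A"), ("read", "A"), ("read", "A"), ("write", "A")]
print(insert_syncs(ops))
```

Only two barriers are emitted for the four accesses: one before the first read of `A`, and one before `A` is overwritten.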

tilelang.transform.ThreadPartialSync(storage_scope)¶

Insert partial sync.

Parameters:

storage_scope (str) – The target storage scope.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.IfStmtBinding()¶

IfStmtBinding

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.MergeIfStmt()¶

MergeIfStmt

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.LoopUnswitching()¶

LoopUnswitching: Hoist loop-invariant if statements out of loops.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
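Loop unswitching is easiest to see on ordinary code. The two hypothetical functions below compute the same result; the second shows the shape of the code after a loop-invariant condition has been hoisted. This illustrates the general technique only, not the IR the pass operates on.

```python
from typing import List


def scale_or_copy(xs: List[float], scale: bool) -> List[float]:
    """Before: the loop-invariant test on `scale` is evaluated on every iteration."""
    out = []
    for x in xs:
        if scale:
            out.append(2.0 * x)
        else:
            out.append(x)
    return out


def scale_or_copy_unswitched(xs: List[float], scale: bool) -> List[float]:
    """After: the test is hoisted, leaving one branch-free loop per case."""
    if scale:
        return [2.0 * x for x in xs]
    return [x for x in xs]


data = [1.0, 2.0, 3.0]
print(scale_or_copy(data, True) == scale_or_copy_unswitched(data, True))
```

The trade-off is code size: the loop body is duplicated once per branch in exchange for removing the per-iteration test.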

tilelang.transform.MultiVersionBuffer()¶

MultiVersionBuffer

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.ProducerConsumerWarpSpecialized()¶

Producer-Consumer Warp Specialization for TMA pipelines.

Splits pipelined loops with TMA loads into producer (TMA copy) and consumer (compute) warp groups with mbarrier-based synchronization.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.AnnotateWarpGroupRegAlloc()¶

Inject set_max_nreg calls into warp-specialized functions.

This pass analyzes the function to collect register hints from set_max_nreg and no_set_max_nreg calls, then injects appropriate set_max_nreg calls into producer and consumer branches of warp-specialized code.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.FuseMBarrierArriveExpectTx()¶

Fuse simple expect_tx -> TMA issue -> arrive back into arrive_and_expect_tx.

tilelang.transform.InjectFenceProxy()¶

InjectFenceProxy

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.LegalizeVectorizedLoop()¶

LegalizeVectorizedLoop

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.LegalizeSafeMemoryAccess()¶

LegalizeSafeMemoryAccess

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.LowerAccessPtr()¶

Lower TileLang frontend tl.access_ptr to tir.builtin.tvm_access_ptr.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.MakePackedAPI()¶

MakePackedAPI

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.AnnotateDeviceRegions()¶

AnnotateDeviceRegions

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.SplitHostDevice()¶

Split host/device functions even for empty kernels.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.AnnotateReadOnlyParams()¶

Annotate read-only handle parameters for PrimFuncs.

Adds attribute tl.readonly_param_indices listing param indices that are never written, enabling CUDA codegen to emit const qualifiers to unlock read-only cache loads.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
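The analysis behind `tl.readonly_param_indices` can be pictured as a set difference: collect every parameter that is a store target, then list the indices of the rest. The helper below is a hypothetical sketch over parameter names and assumes the set of written buffers is already known; the real pass derives it from the PrimFunc body.

```python
from typing import Iterable, List, Sequence


def readonly_param_indices(params: Sequence[str], written: Iterable[str]) -> List[int]:
    """Return indices of parameters that never appear as a store target."""
    written_set = set(written)
    return [i for i, name in enumerate(params) if name not in written_set]


# A kernel C = f(A, B): only C is written, so A and B are read-only.
indices = readonly_param_indices(["A", "B", "C"], written=["C"])
print(indices)
```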

tilelang.transform.VectorizeLoop(enable_vectorize=True)¶

VectorizeLoop

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

Parameters:

enable_vectorize (bool)

tilelang.transform.LowerPTXAsyncCopy()¶

Lower eligible global->shared copies into PTX cp.async on CUDA.

When enabled (pass config tl.enable_async_copy, default True), this pass may rewrite plain user-written global->shared BufferStore patterns (e.g. SIMT copies in T.Parallel) into tir.ptx_cp_async, and insert tir.ptx_commit_group + tir.ptx_wait_group(0) to preserve synchronous semantics for normal stores. If explicit commit/wait intrinsics already exist, the pass avoids duplicating them (and may insert a missing commit immediately before an existing wait to cover injected cp.async).

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
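The commit/wait bookkeeping described above can be sketched as a toy rewrite over a list of operation names. This is an illustration under strong assumptions: `"g2s_store"` stands for an eligible global->shared store, `"wait"` stands for `tir.ptx_wait_group(0)`, and the function only inspects the next operation rather than analysing the whole program. `lower_async_copies` is a hypothetical name, not part of tilelang.

```python
from typing import List


def lower_async_copies(ops: List[str]) -> List[str]:
    """Toy rewrite over op names: "g2s_store", "commit", "wait", or any other op."""
    out: List[str] = []
    uncommitted = False  # True while injected cp_async ops lack a commit
    for op in ops:
        if op == "g2s_store":
            out.append("cp_async")
            uncommitted = True
            continue
        if op == "commit":
            uncommitted = False  # the explicit commit covers the injected copies
        elif op == "wait":
            if uncommitted:
                out.append("commit")  # add the missing commit before the existing wait
            uncommitted = False
        elif uncommitted:
            # An ordinary op follows injected copies: keep synchronous semantics.
            out += ["commit", "wait"]
            uncommitted = False
        out.append(op)
    if uncommitted:
        out += ["commit", "wait"]
    return out


print(lower_async_copies(["g2s_store", "compute"]))
print(lower_async_copies(["g2s_store", "commit", "wait", "compute"]))
print(lower_async_copies(["g2s_store", "wait"]))
```

The three calls cover the three cases in the description: synchronization is injected for a plain store, existing commit/wait pairs are left alone, and a lone wait gains the commit it was missing.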

tilelang.transform.InjectPTXAsyncCopy()¶

Deprecated alias of LowerPTXAsyncCopy.

tilelang.transform.LowerDeviceStorageAccessInfo()¶

Lower attached storage access information on device.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

Note

Run this pass after all storage access analyses have finished.

tilelang.transform.ConfigIndexBitwidth()¶

Configure the index bitwidth.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.FlattenBuffer()¶

FlattenBuffer

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.EliminateStorageSyncForMBarrier()¶

EliminateStorageSyncForMBarrier

tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=False, align_bytes=16)¶

MergeSharedMemoryAllocations

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

Parameters:

  • enable_aggressive_merge (bool)

  • align_bytes (int)
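To make the role of `align_bytes` concrete, the sketch below packs several buffers into one allocation and rounds each starting offset up to the alignment. It is a simplified model with hypothetical helper names: it ignores liveness-based reuse and does not model `enable_aggressive_merge`.

```python
from typing import Dict, List, Tuple


def align_up(offset: int, align_bytes: int) -> int:
    """Round `offset` up to the next multiple of `align_bytes`."""
    return (offset + align_bytes - 1) // align_bytes * align_bytes


def plan_offsets(
    sizes: List[Tuple[str, int]], align_bytes: int = 16
) -> Tuple[Dict[str, int], int]:
    """Pack buffers back to back; return per-buffer offsets and the total size in bytes."""
    offsets: Dict[str, int] = {}
    cursor = 0
    for name, nbytes in sizes:
        cursor = align_up(cursor, align_bytes)
        offsets[name] = cursor
        cursor += nbytes
    return offsets, cursor


offsets, total = plan_offsets([("A_shared", 100), ("B_shared", 40)])
print(offsets, total)
```

With the default 16-byte alignment, the second buffer starts at byte 112 rather than 100, so the merged allocation needs 152 bytes instead of 140.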

tilelang.transform.LowerL2Persistent()¶

LowerL2Persistent

tilelang.transform.MarkCudaSyncCalls(have_pdl=False)¶

MarkCudaSyncCalls

Parameters:

have_pdl (bool)

tilelang.transform.PersistThreadblock()¶

PersistThreadblock

tilelang.transform.AlignDynamicSharedMemoryAllocations(align_bytes=16)¶

AlignDynamicSharedMemoryAllocations

Parameters:

align_bytes (int) – The alignment bytes.

tilelang.transform.LowerSharedBarrier()¶

LowerSharedBarrier

tilelang.transform.PlanAndUpdateBufferAllocationLocation()¶

Plan and update buffer allocation locations within PrimFuncs.

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.HoistNonRestrictParams()¶

tilelang.transform.StorageRewrite()¶

StorageRewrite

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass

tilelang.transform.LowerOpaqueBlock()¶

LowerOpaqueBlock

tilelang.transform.LowerThreadAllreduce()¶

LowerThreadAllreduce

tilelang.transform.LowerIntrin()¶

LowerIntrin

tilelang.transform.LowerDeviceKernelLaunch()¶

Create and return a transform pass that lowers device kernel launch constructs to target-specific IR.

This pass transforms high-level device kernel launch and related intrinsics into lower-level IR suitable for backend code generation and device-side lowering.

Returns:

The transform pass that performs device kernel launch lowering.

Return type:

tvm.transform.Pass

tilelang.transform.LowerSharedTmem()¶

LowerSharedTmem

tilelang.transform.LayoutReducer()¶

Return a TVM transform pass that performs layout reduction/normalization.

This wrapper delegates to the underlying FFI implementation and returns a pass object suitable for use in a PassContext or pass pipeline. The pass is intended to simplify or reduce tensor/layout-related representations during relay/tile transformations.

Returns:

The transform pass object produced by the FFI backend.

tilelang.transform.UnrollLoop()¶

Unroll loops as in the Halide pipeline.

This pass unrolls loops based on configuration options including:

  • auto_max_step: Threshold of number of steps to be automatically unrolled

  • auto_max_depth: Maximum nested level of loops that can be automatically unrolled

  • auto_max_extent: Maximum extent of loop that will be unrolled

  • explicit_unroll: Whether to explicitly unroll instead of setting a pragma

  • unroll_local_access: Whether to always unroll local access

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
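The difference between explicit unrolling and setting a pragma can be shown on generated source text. The helper below is a hypothetical toy that emits C-like lines. It uses naive string substitution for the loop variable, so it only works when the variable name does not occur elsewhere in the body, and it does not model the automatic thresholds.

```python
from typing import List


def unroll_loop(var: str, extent: int, body: str, explicit_unroll: bool = True) -> List[str]:
    """Emit C-like source lines for a loop of `extent` iterations over `body`."""
    if explicit_unroll:
        # Replicate the body once per iteration with the loop variable substituted.
        return [body.replace(var, str(i)) for i in range(extent)]
    # Otherwise keep the loop and leave unrolling to the downstream compiler.
    return [
        "#pragma unroll",
        f"for (int {var} = 0; {var} < {extent}; ++{var}) {{ {body} }}",
    ]


for line in unroll_loop("i", 2, "C[i] = A[i] + B[i];"):
    print(line)
for line in unroll_loop("i", 2, "C[i] = A[i] + B[i];", explicit_unroll=False):
    print(line)
```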

tilelang.transform.LowerLDGSTG()¶

Lower Ramp-based global memory load/store to ldg/stg intrinsics.

This pass transforms vectorized global memory loads and stores (using Ramp indices) into explicit ldg32/64/128/256 and stg32/64/128/256 intrinsics for better codegen.

Key behaviors:

  • Converts Ramp-based global BufferLoad to ldg intrinsics

  • Converts Ramp-based global BufferStore to stg intrinsics

  • Supports predicated loads (if_then_else with else=0)

  • Supports predicated stores (if in then case)

  • Skips loads in async scope (will be lowered to cp.async)

  • Only enabled for CUDA targets

Returns:

fpass – The result pass

Return type:

tvm.transform.Pass
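The 32/64/128/256 suffixes correspond to the total width of the vectorized access in bits. As a rough sketch, the hypothetical helper below picks an intrinsic name from the lane count and element width, and returns `None` when the total width is not one of the four listed sizes. The actual eligibility rules of the pass (alignment, target checks, predication) are not modelled.

```python
from typing import Optional

SUPPORTED_BITS = (32, 64, 128, 256)


def select_intrinsic(kind: str, lanes: int, elem_bits: int) -> Optional[str]:
    """Return a name such as "ldg128" for a vector access, or None if the width is unsupported."""
    if kind not in ("ldg", "stg"):
        raise ValueError("kind must be 'ldg' or 'stg'")
    total_bits = lanes * elem_bits
    return f"{kind}{total_bits}" if total_bits in SUPPORTED_BITS else None


print(select_intrinsic("ldg", 8, 16))  # 8 x float16 load
print(select_intrinsic("stg", 4, 32))  # 4 x float32 store
print(select_intrinsic("ldg", 3, 32))  # 96 bits: no matching intrinsic
```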