tilelang.transform.decouple_type_cast
Decouple type cast vectorization constraints.
When a vectorized loop mixes precisions between local and memory buffers, a single vectorization length has to suit every dtype in the loop body, so it is constrained by the GCD of the vectorization lengths of all involved dtypes.
This pass decouples the constraints by inserting a local buffer as an intermediate stage, allowing optimal vectorization for both computation and memory access.
Mixed-precision is detected by the presence of Cast nodes in the loop body.
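To make the GCD constraint concrete, the sketch below works through the arithmetic for a float32 to float4_e2m1fn store. It assumes a 128-bit maximum vector access width and that each dtype's preferred lane count is that width divided by the element size. Both are illustrative assumptions, and `preferred_lanes` is a hypothetical helper, not part of tilelang.

```python
from math import gcd

ACCESS_BITS = 128  # assumed maximum vector access width (illustrative only)


def preferred_lanes(dtype_bits: int, access_bits: int = ACCESS_BITS) -> int:
    """Number of lanes that fill one vector access for the given element size."""
    return access_bits // dtype_bits


fp32_lanes = preferred_lanes(32)  # float32 values held in the local fragment
fp4_lanes = preferred_lanes(4)    # float4_e2m1fn values stored to memory

# Coupled loop: a single vector length has to satisfy both dtypes.
coupled_lanes = gcd(fp32_lanes, fp4_lanes)

# Bits written per memory store when the float4 store is tied to that length.
coupled_store_bits = coupled_lanes * 4

print(fp32_lanes, fp4_lanes, coupled_lanes, coupled_store_bits)
```

Under these assumptions the coupled loop is limited by the float32 side, so each float4 store moves far less data than one full vector access. Staging the cast result in a local buffer lets the copy loop choose its own length.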
Two cases are handled:
Case 1: local → memory (store to memory with mixed types)
- Before:
    for vec in T.vectorized(16):
        b[vec] = T.cast(a_frag[vec], "float4_e2m1fn")
- After:
    for vec in T.vectorized(16):
        cast_buf[vec] = T.cast(a_frag[vec], "float4_e2m1fn")  # compute
    for vec_copy in T.vectorized(16):
        b[vec_copy] = cast_buf[vec_copy]  # copy to memory
Case 2: memory → local (load from memory with different dtype)
- Before:
    for vec in T.vectorized(16):
        a_frag[vec] = T.cast(b[vec], "float32")
- After:
    for vec_copy in T.vectorized(16):
        cast_buf[vec_copy] = b[vec_copy]  # copy from memory
    for vec in T.vectorized(16):
        a_frag[vec] = T.cast(cast_buf[vec], "float32")  # compute
Attributes
- CastEntry
Classes
- MemoryAccessCollector: Collect shared/global BufferStore and BufferLoad nodes.
- DecoupleTypeCastMutator: Mutator that decouples type cast vectorization constraints.
- AccessReplacer: Mutator to replace memory BufferStores/BufferLoads with cast buffer accesses.
Functions
- is_local_buffer(buffer): Check if a buffer is local (register-level), including local.var.
- Check if a buffer is a global or shared buffer.
- inline_let_stmts(stmt): Inline all LetStmt bindings in stmt so that downstream visitors can see the original BufferLoad nodes that were hidden behind Var references.
- extract_if_condition(stmt): Extract IfThenElse condition from statement if present.
- DecoupleTypeCast(): Create a TVM pass that decouples type cast vectorization constraints.
Module Contents
- tilelang.transform.decouple_type_cast.is_local_buffer(buffer)
Check if a buffer is local (register-level), including local.var.
- Parameters:
buffer (tvm.tir.Buffer)
- Return type:
Check if a buffer is a global or shared buffer.
- Parameters:
buffer (tvm.tir.Buffer)
- Return type:
- class tilelang.transform.decouple_type_cast.MemoryAccessCollector(loop_var)
Bases: tvm.tir.PyStmtExprVisitor
Collect shared/global BufferStore and BufferLoad nodes.
Skips traversal of indices so that index expressions (which may contain BufferLoads to index buffers) do not pollute the result.
BufferLoads in if_then_else conditions are skipped because conditions don’t participate in the type-cast compute path.
BufferLoads whose indices do not depend on loop_var are skipped because they are scalar accesses (e.g. b[0]) that should remain in the compute loop as broadcasts.
- Parameters:
loop_var (tvm.tir.Var)
- loop_var
- stores: list[tvm.tir.BufferStore] = []
- loads: list[tvm.tir.BufferLoad] = []
- visit_buffer_store_(op)
- Parameters:
op (tvm.tir.BufferStore)
- Return type:
None
- visit_buffer_load_(op)
- Parameters:
op (tvm.tir.BufferLoad)
- Return type:
None
- visit_call_(op)
- Parameters:
op (tvm.tir.Call)
- Return type:
None
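The filtering rules described for MemoryAccessCollector can be sketched on a toy expression tree. The tuple-based IR below is a hypothetical stand-in, not TVM's node classes, and the sketch leaves out the shared/global scope check and the if_then_else condition rule. It only shows two behaviors: loads whose index does not use the loop variable are skipped, and index expressions are not traversed.

```python
# Toy IR (hypothetical, not TVM):
#   ("var", name) | ("const", value) | ("load", buffer, index) | ("add", lhs, rhs)


def uses_var(expr, name):
    """Return True if the expression refers to the variable `name`."""
    if isinstance(expr, tuple):
        if expr[0] == "var":
            return expr[1] == name
        return any(uses_var(child, name) for child in expr[1:])
    return False


def collect_vector_loads(expr, loop_var, found=None):
    """Collect loads indexed by loop_var without descending into indices."""
    found = [] if found is None else found
    if isinstance(expr, tuple):
        if expr[0] == "load":
            _, _buffer, index = expr
            if uses_var(index, loop_var):
                found.append(expr)
            return found  # the index expression is deliberately not visited
        for child in expr[1:]:
            collect_vector_loads(child, loop_var, found)
    return found


# b[vec] + scale[0]: only b[vec] depends on the loop variable.
value = ("add", ("load", "b", ("var", "vec")), ("load", "scale", ("const", 0)))
vector_loads = collect_vector_loads(value, "vec")

# b[idx[vec]]: the inner load of idx sits in an index and is not collected.
indirect = ("load", "b", ("load", "idx", ("var", "vec")))
indirect_loads = collect_vector_loads(indirect, "vec")
```

Here `scale[0]` stays out of the result because it is a scalar access, and `idx[vec]` stays out because it only appears inside an index expression.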
- tilelang.transform.decouple_type_cast.inline_let_stmts(stmt)
Inline all LetStmt bindings in stmt so that downstream visitors can see the original BufferLoad nodes that were hidden behind Var references.
Used before collecting memory accesses so that BufferLoads inside LetStmt values are visible to MemoryAccessCollector.
- Parameters:
stmt (tvm.tir.Stmt)
- Return type:
tvm.tir.Stmt
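The effect of let inlining can be illustrated with a minimal sketch on a toy tuple-based IR. The node shapes and the `inline_lets` helper are hypothetical and are not the TVM implementation. They only show how substituting the bound value exposes a load that was hidden behind a variable.

```python
# Toy IR (hypothetical, not TVM):
#   ("var", name) | ("load", buffer, index) | ("cast", dtype, expr)
#   ("let", name, value, body) | ("store", buffer, index, value)


def inline_lets(node, bindings=None):
    """Replace let-bound variables with their values and drop the let nodes."""
    bindings = bindings or {}
    if not isinstance(node, tuple):
        return node  # tags, buffer names and dtype strings pass through
    if node[0] == "var":
        return bindings.get(node[1], node)
    if node[0] == "let":
        _, name, value, body = node
        inlined_value = inline_lets(value, bindings)
        return inline_lets(body, {**bindings, name: inlined_value})
    return tuple(inline_lets(child, bindings) for child in node)


# let x = b[vec] in a_frag[vec] = cast(float32, x)
stmt = (
    "let", "x", ("load", "b", ("var", "vec")),
    ("store", "a_frag", ("var", "vec"), ("cast", "float32", ("var", "x"))),
)
inlined = inline_lets(stmt)
```

After inlining, the load of `b[vec]` appears directly under the cast, which is the form a collector walking the store value would need in order to see it.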
- tilelang.transform.decouple_type_cast.extract_if_condition(stmt)
Extract IfThenElse condition from statement if present.
- Returns:
A tuple of (condition, inner_body). If no IfThenElse, returns (None, stmt).
- Parameters:
stmt (tvm.tir.Stmt)
- Return type:
tuple[tvm.tir.PrimExpr | None, tvm.tir.Stmt]
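The return contract can be sketched with a small stand-in class. `IfThenElse` below is a hypothetical dataclass rather than `tvm.tir.IfThenElse`, and treating the then-branch as the inner body is an assumption made for illustration. How the real function handles an else branch is not covered here.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class IfThenElse:
    """Hypothetical stand-in for a guarded statement."""
    condition: Any
    then_case: Any


def extract_if_condition_sketch(stmt):
    """Return (condition, inner_body), or (None, stmt) when there is no guard."""
    if isinstance(stmt, IfThenElse):
        return stmt.condition, stmt.then_case
    return None, stmt


guarded = IfThenElse(condition="vec < n", then_case="b[vec] = cast(a_frag[vec])")
guarded_result = extract_if_condition_sketch(guarded)
plain_result = extract_if_condition_sketch("b[vec] = cast(a_frag[vec])")
```

Returning `(None, stmt)` for the unguarded case lets callers unpack the result the same way in both situations.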
- tilelang.transform.decouple_type_cast.CastEntry
- class tilelang.transform.decouple_type_cast.DecoupleTypeCastMutator
Bases: tvm.tir.PyStmtExprMutator
Mutator that decouples type cast vectorization constraints.
This mutator transforms vectorized loops that have mixed-precision operations (detected by the presence of Cast nodes) by inserting local cache buffers as intermediate stages.
- visit_for_(op)
Visit For nodes, transforming vectorized loops with mixed-type stores.
- Parameters:
op (tvm.tir.For)
- Return type:
tvm.tir.Stmt
- class tilelang.transform.decouple_type_cast.AccessReplacer(store_entries, load_entries, loop_var)
Bases: tvm.tir.PyStmtExprMutator
Mutator to replace memory BufferStores/BufferLoads with cast buffer accesses.
Matches by both buffer and indices (structural equality) so that accesses like a[i] and a[i+32] from the same buffer map to different cast buffers.
- Parameters:
store_entries (list[CastEntry])
load_entries (list[CastEntry])
loop_var (tvm.tir.Var)
- store_entries
- load_entries
- loop_var
- visit_buffer_store_(op)
- Parameters:
op (tvm.tir.BufferStore)
- Return type:
tvm.tir.Stmt
- visit_buffer_load_(op)
- Parameters:
op (tvm.tir.BufferLoad)
- Return type:
tvm.tir.PrimExpr
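Matching on both buffer and indices, as described for AccessReplacer, can be pictured as a lookup keyed on the pair. In this sketch an access is modelled as a `(buffer_name, index_string)` tuple, string equality stands in for structural equality, and the `cast_buf_N` names are invented for illustration. None of this reflects the actual CastEntry layout.

```python
def build_cast_buffer_map(accesses):
    """Assign one cast buffer name per distinct (buffer, index) pair."""
    mapping = {}
    for buffer, index in accesses:
        key = (buffer, index)
        if key not in mapping:
            mapping[key] = f"cast_buf_{len(mapping)}"
    return mapping


# a[i] appears twice and a[i + 32] once, all on the same buffer.
accesses = [("a", "i"), ("a", "i + 32"), ("a", "i")]
mapping = build_cast_buffer_map(accesses)

for key, cast_buf in mapping.items():
    print(key, "->", cast_buf)
```

Keying on the buffer alone would fold `a[i]` and `a[i + 32]` into one cast buffer. Including the index keeps them apart while the repeated `a[i]` still reuses its entry.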
- tilelang.transform.decouple_type_cast.DecoupleTypeCast()
Create a TVM pass that decouples type cast vectorization constraints.
This pass inserts a local buffer as an intermediate stage for vectorized loops where the body contains Cast nodes (mixed-precision operations).
This allows optimal vectorization for both computation and memory access.
Note
This pass must be applied before VectorizeLoop and StorageRewrite passes, while the IR still uses BufferLoad/BufferStore (not tvm_access_ptr).
- Returns:
A TVM PrimFunc pass.