tilelang.transform.decouple_type_cast
Decouple type cast vectorization constraints.
When a vectorized loop mixes precisions between local and memory buffers, one vector length has to serve every dtype involved, so it is constrained by the GCD of the vector lengths each dtype allows.
This pass decouples the constraints by inserting a local buffer as an intermediate stage, allowing optimal vectorization for both computation and memory access.
Mixed-precision is detected by the presence of Cast nodes in the loop body.
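A simplified model of the GCD constraint follows. It assumes a 128-bit maximum vector access width (an illustrative figure; this page does not specify one) and treats each dtype's lane limit as 128 divided by its bit width. `max_lanes` and `shared_lanes` are hypothetical helpers, not TileLang APIs.

```python
from functools import reduce
from math import gcd

# Assumption for illustration: one vector memory access moves at most 128 bits.
MAX_VECTOR_BITS = 128

# Bit widths of the dtypes used in the examples on this page.
DTYPE_BITS = {"float32": 32, "float16": 16, "float4_e2m1fn": 4}


def max_lanes(dtype: str) -> int:
    """Lanes a single 128-bit access can cover for this dtype on its own."""
    return MAX_VECTOR_BITS // DTYPE_BITS[dtype]


def shared_lanes(dtypes: list[str]) -> int:
    """Vector length one fused loop must use: the GCD of the per-dtype limits."""
    return reduce(gcd, (max_lanes(d) for d in dtypes))


# Fused loop: a float32 operand and a float4 store share one vector length.
print(shared_lanes(["float32", "float4_e2m1fn"]))  # 4
# Decoupled copy loop: float4 on both sides, so it can use the full width.
print(shared_lanes(["float4_e2m1fn"]))  # 32
```

Under this model, splitting the loop lets the float4 copy run at 32 lanes instead of being held to the 4 lanes that float32 permits.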
Two cases are handled:
Case 1: local → memory (store to memory with mixed types)
- Before:
    for vec in T.vectorized(16):
        b[vec] = T.cast(a_frag[vec], "float4_e2m1fn")
- After:
    for vec in T.vectorized(16):
        cast_buf[vec] = T.cast(a_frag[vec], "float4_e2m1fn")  # compute
    for vec_copy in T.vectorized(16):
        b[vec_copy] = cast_buf[vec_copy]  # copy to memory
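To check that the Case 1 rewrite preserves results, this sketch simulates both forms in plain Python, with lists standing in for the buffers. `to_fp4_e2m1` is a hypothetical stand-in for `T.cast(..., "float4_e2m1fn")`: it rounds to the nearest FP4 E2M1 magnitude (0, 0.5, 1, 1.5, 2, 3, 4, 6), saturates at 6, and breaks ties toward the smaller magnitude, which simplifies real hardware rounding.

```python
import math

# Non-negative values representable in FP4 E2M1 (1 sign, 2 exponent, 1 mantissa bit).
FP4_E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def to_fp4_e2m1(x: float) -> float:
    """Hypothetical cast: nearest magnitude (ties -> smaller), saturating at 6."""
    mag = min(FP4_E2M1_MAGNITUDES, key=lambda m: abs(abs(x) - m))
    return math.copysign(mag, x)


def store_before(a_frag: list[float], b: list[float]) -> None:
    # One loop does both the cast and the store to memory.
    for vec in range(16):
        b[vec] = to_fp4_e2m1(a_frag[vec])


def store_after(a_frag: list[float], b: list[float]) -> None:
    cast_buf = [0.0] * 16  # local intermediate holding values in the narrow dtype
    for vec in range(16):  # compute loop: cast only
        cast_buf[vec] = to_fp4_e2m1(a_frag[vec])
    for vec_copy in range(16):  # copy loop: same dtype on both sides
        b[vec_copy] = cast_buf[vec_copy]


a_frag = [0.4 * i - 3.0 for i in range(16)]
b_before, b_after = [0.0] * 16, [0.0] * 16
store_before(a_frag, b_before)
store_after(a_frag, b_after)
assert b_before == b_after
```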
Case 2: memory → local (load from memory with different dtype)
- Before:
    for vec in T.vectorized(16):
        a_frag[vec] = T.cast(b[vec], "float32")
- After:
    for vec_copy in T.vectorized(16):
        cast_buf[vec_copy] = b[vec_copy]  # copy from memory
    for vec in T.vectorized(16):
        a_frag[vec] = T.cast(cast_buf[vec], "float32")  # compute
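A matching simulation of Case 2 follows, assuming `b` is a float16 buffer in memory (this page does not fix its dtype). Python's `struct` format character `"e"` packs IEEE half-precision values, and a Python `float` stands in for `float32`. The copy loop moves raw 2-byte elements without conversion; only the compute loop widens.

```python
import struct

N = 16


def load_before(b: bytes) -> list[float]:
    a_frag = [0.0] * N
    for vec in range(N):
        # Load a float16 element from memory and widen it in the same loop.
        a_frag[vec] = struct.unpack_from("<e", b, 2 * vec)[0]
    return a_frag


def load_after(b: bytes) -> list[float]:
    cast_buf = bytearray(2 * N)  # local buffer, still float16 (2 bytes per element)
    for vec_copy in range(N):  # copy loop: float16 -> float16, raw bytes only
        cast_buf[2 * vec_copy : 2 * vec_copy + 2] = b[2 * vec_copy : 2 * vec_copy + 2]
    a_frag = [0.0] * N
    for vec in range(N):  # compute loop: widen float16 -> wider float
        a_frag[vec] = struct.unpack_from("<e", cast_buf, 2 * vec)[0]
    return a_frag


b = struct.pack(f"<{N}e", *(0.25 * i - 1.0 for i in range(N)))
assert load_before(b) == load_after(b)
```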
Attributes
- BindEnv
- CastEntry
Classes
- MemoryAccessCollector: Collect shared/global BufferStore and BufferLoad nodes.
- DecoupleTypeCastMutator: Mutator that decouples type cast vectorization constraints.
- AccessReplacer: Mutator to replace memory BufferStores/BufferLoads with cast buffer accesses.
Functions
- is_local_buffer: Check if a buffer is local (register-level), including local.var and metal.simdgroup.
- …: Check if a buffer is a global or shared buffer.
- normalize_flat_binds: Return stmt with dominating flat Bind values substituted into uses.
- extract_if_condition: Extract IfThenElse condition from statement if present.
- DecoupleTypeCast: Create a TVM pass that decouples type cast vectorization constraints.
Module Contents
- tilelang.transform.decouple_type_cast.is_local_buffer(buffer)
Check if a buffer is local (register-level), including local.var and metal.simdgroup.
- Parameters:
buffer (tvm.tirx.Buffer)
- Return type:
- tilelang.transform.decouple_type_cast.…(buffer)
Check if a buffer is a global or shared buffer.
- Parameters:
buffer (tvm.tirx.Buffer)
- Return type:
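A string-level sketch of the two scope predicates, plus a hypothetical `classify` helper that maps a destination/source pair onto Case 1 or Case 2. The scope sets are assumptions for illustration (only local.var, metal.simdgroup, global, and shared are named on this page), and the real functions take a `tvm.tirx.Buffer`, not a string.

```python
from typing import Optional

# Assumed scope strings; the real predicates inspect a tvm.tirx.Buffer.
LOCAL_SCOPES = {"local", "local.fragment", "local.var", "metal.simdgroup"}
MEMORY_SCOPES = {"global", "shared", "shared.dyn"}


def is_local_scope(scope: str) -> bool:
    return scope in LOCAL_SCOPES


def is_global_or_shared_scope(scope: str) -> bool:
    return scope in MEMORY_SCOPES


def classify(dst_scope: str, src_scope: str) -> Optional[str]:
    """Which case a `dst[...] = T.cast(src[...], ...)` loop falls into, if any."""
    if is_global_or_shared_scope(dst_scope) and is_local_scope(src_scope):
        return "case 1: local -> memory"
    if is_local_scope(dst_scope) and is_global_or_shared_scope(src_scope):
        return "case 2: memory -> local"
    return None  # memory-to-memory or local-to-local: not handled by this pass
```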
- class tilelang.transform.decouple_type_cast.MemoryAccessCollector(loop_var)
Bases:
tvm.tirx.PyStmtExprVisitor
Collect shared/global BufferStore and BufferLoad nodes.
Skips traversal of indices so that index expressions (which may contain BufferLoads from index buffers) do not pollute the result.
BufferLoads in if_then_else conditions are skipped because conditions don't participate in the type-cast compute path.
BufferLoads whose indices do not depend on loop_var are skipped because they are scalar accesses (e.g. b[0]) that should remain in the compute loop as broadcasts.
- Parameters:
loop_var (tvm.tirx.Var)
- loop_var
- stores: list[tvm.tirx.BufferStore] = []
- loads: list[tvm.tirx.BufferLoad] = []
- visit_buffer_store_(op)
- Parameters:
op (tvm.tirx.BufferStore)
- Return type:
None
- visit_buffer_load_(op)
- Parameters:
op (tvm.tirx.BufferLoad)
- Return type:
None
- visit_call_(op)
- Parameters:
op (tvm.tirx.Call)
- Return type:
None
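The three skip rules above can be sketched as a toy visitor over a hypothetical mini expression tree (`Var`, `Load`, `Cast`, `Select`); none of these classes are TVM APIs. Only the load side is modeled; stores presumably pass the same scope check.

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Load:
    buffer: str
    scope: str
    index: "Expr"


@dataclass(frozen=True)
class Cast:
    dtype: str
    value: "Expr"


@dataclass(frozen=True)
class Select:  # stands in for T.if_then_else(cond, then, other)
    cond: "Expr"
    then: "Expr"
    other: "Expr"


Expr = Union[Var, int, Load, Cast, Select]
MEMORY_SCOPES = {"global", "shared"}


def uses_var(e: Expr, v: Var) -> bool:
    if isinstance(e, Var):
        return e == v
    if isinstance(e, Load):
        return uses_var(e.index, v)
    if isinstance(e, Cast):
        return uses_var(e.value, v)
    if isinstance(e, Select):
        return any(uses_var(x, v) for x in (e.cond, e.then, e.other))
    return False  # integer constant


def collect_loads(e: Expr, loop_var: Var, out: Optional[List[Load]] = None) -> List[Load]:
    out = [] if out is None else out
    if isinstance(e, Load):
        # Keep loop-dependent memory loads; scalar ones (e.g. b[0]) stay as broadcasts.
        if e.scope in MEMORY_SCOPES and uses_var(e.index, loop_var):
            out.append(e)
        # Do not descend into e.index: index-buffer loads must not be collected.
    elif isinstance(e, Select):
        # Skip the condition; only the selected values feed the cast.
        collect_loads(e.then, loop_var, out)
        collect_loads(e.other, loop_var, out)
    elif isinstance(e, Cast):
        collect_loads(e.value, loop_var, out)
    return out


vec = Var("vec")
value = Select(
    Load("mask", "global", vec),  # condition: skipped
    Cast("float32", Load("b", "shared", Load("idx", "global", vec))),  # collected
    Load("c", "global", 0),  # scalar access: skipped
)
loads = collect_loads(value, vec)
assert [ld.buffer for ld in loads] == ["b"]
```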
- tilelang.transform.decouple_type_cast.BindEnv
- tilelang.transform.decouple_type_cast.normalize_flat_binds(stmt)
Return stmt with dominating flat Bind values substituted into uses.
- Parameters:
stmt (tvm.tirx.Stmt)
- Return type:
tvm.tirx.Stmt
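A toy version of the substitution on a flat statement list, where each `("bind", name, value)` binds a name for every later statement. The tuple encoding is hypothetical, as is the choice to keep the bind statements in the output; the page only states that dominating Bind values are substituted into uses.

```python
from typing import Dict, List, Tuple, Union

# Toy expressions: names are str, constants are int, operations are (op, lhs, rhs).
Expr = Union[str, int, tuple]
Stmt = Tuple[str, str, Expr]  # ("bind", name, value) or ("store", buffer, value)


def substitute(e: Expr, env: Dict[str, Expr]) -> Expr:
    if isinstance(e, str):
        return env.get(e, e)  # replace a bound name, leave free names alone
    if isinstance(e, tuple):
        op, *args = e
        return (op, *(substitute(a, env) for a in args))
    return e


def normalize_flat_binds(stmts: List[Stmt]) -> List[Stmt]:
    env: Dict[str, Expr] = {}
    out: List[Stmt] = []
    for kind, name, value in stmts:
        value = substitute(value, env)  # earlier binds dominate this statement
        if kind == "bind":
            env[name] = value
        out.append((kind, name, value))
    return out


stmts = [
    ("bind", "x", ("+", "vec", 1)),
    ("bind", "y", ("*", "x", 2)),
    ("store", "b", ("+", "y", "x")),
]
print(normalize_flat_binds(stmts)[-1])
# ('store', 'b', ('+', ('*', ('+', 'vec', 1), 2), ('+', 'vec', 1)))
```

After substitution the store's value mentions `vec` directly, which is the kind of loop-variable dependence the collector above looks for; whether that is the pass's motivation is an assumption.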
- tilelang.transform.decouple_type_cast.extract_if_condition(stmt)
Extract IfThenElse condition from statement if present.
- Returns:
A tuple of (condition, inner_body). If no IfThenElse, returns (None, stmt).
- Parameters:
stmt (tvm.tirx.Stmt)
- Return type:
tuple[tvm.tirx.PrimExpr | None, tvm.tirx.Stmt]
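A sketch of the peel-and-rewrap pattern this helper supports, on a hypothetical `IfThenElse` dataclass. Two details are assumptions not taken from the page: that an if with an `else` branch is returned untouched, and that a caller re-applies the condition to the rewritten body (the hypothetical `rewrap`).

```python
from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass
class IfThenElse:
    """Toy statement-level if; the real pass works on tvm.tirx statements."""
    condition: Any
    then_case: Any
    else_case: Optional[Any] = None


def extract_if_condition(stmt: Any) -> Tuple[Optional[Any], Any]:
    # Assumption: only an if without an else branch is peeled.
    if isinstance(stmt, IfThenElse) and stmt.else_case is None:
        return stmt.condition, stmt.then_case
    return None, stmt


def rewrap(condition: Optional[Any], body: Any) -> Any:
    """Re-apply a peeled guard to a (rewritten) body."""
    return body if condition is None else IfThenElse(condition, body)


cond, body = extract_if_condition(IfThenElse("vec < n", "b[vec] = x"))
print(cond, "|", body)  # vec < n | b[vec] = x
```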
- tilelang.transform.decouple_type_cast.CastEntry
- class tilelang.transform.decouple_type_cast.DecoupleTypeCastMutator
Bases:
tvm.tirx.PyStmtExprMutator
Mutator that decouples type cast vectorization constraints.
This mutator transforms vectorized loops that have mixed-precision operations (detected by the presence of Cast nodes) by inserting local cache buffers as intermediate stages.
- visit_for_(op)
Visit For nodes, transforming vectorized loops with mixed-type stores.
- Parameters:
op (tvm.tirx.For)
- Return type:
tvm.tirx.Stmt
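The loop shapes that `visit_for_` produces can be pictured as text generation over a hypothetical one-statement loop description (`CastLoop`). This sketch ignores multiple accesses per loop, guards, scalar broadcasts, and buffer allocation; it only reproduces the Case 1 and Case 2 outputs shown at the top of this page.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class CastLoop:
    """Toy model of `for vec in T.vectorized(extent): dst[vec] = T.cast(src[vec], dtype)`."""
    dst: str
    dst_scope: str
    src: str
    src_scope: str
    dtype: str
    extent: int


def is_local(scope: str) -> bool:
    return scope.startswith("local")  # simplified scope test


def header(var: str, extent: int) -> str:
    return f"for {var} in T.vectorized({extent}):"


def decouple(loop: CastLoop) -> List[str]:
    n, cast = loop.extent, f'"{loop.dtype}"'
    if is_local(loop.src_scope) and not is_local(loop.dst_scope):
        # Case 1: cast into a local buffer first, then copy it out to memory.
        return [
            header("vec", n),
            f"    cast_buf[vec] = T.cast({loop.src}[vec], {cast})",
            header("vec_copy", n),
            f"    {loop.dst}[vec_copy] = cast_buf[vec_copy]",
        ]
    if is_local(loop.dst_scope) and not is_local(loop.src_scope):
        # Case 2: copy memory into a local buffer first, then cast from it.
        return [
            header("vec_copy", n),
            f"    cast_buf[vec_copy] = {loop.src}[vec_copy]",
            header("vec", n),
            f"    {loop.dst}[vec] = T.cast(cast_buf[vec], {cast})",
        ]
    # Neither case applies: leave the loop as it was.
    return [header("vec", n), f"    {loop.dst}[vec] = T.cast({loop.src}[vec], {cast})"]


print("\n".join(decouple(CastLoop("b", "global", "a_frag", "local.fragment", "float4_e2m1fn", 16))))
```

The printed lines match the Case 1 "After" form.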
- class tilelang.transform.decouple_type_cast.AccessReplacer(store_entries, load_entries, loop_var)
Bases:
tvm.tirx.PyStmtExprMutator
Mutator to replace memory BufferStores/BufferLoads with cast buffer accesses.
Matches by both buffer and indices (structural equality) so that accesses like a[i] and a[i+32] from the same buffer map to different cast buffers.
- Parameters:
store_entries (list[CastEntry])
load_entries (list[CastEntry])
loop_var (tvm.tirx.Var)
- store_entries
- load_entries
- loop_var
- visit_buffer_store_(op)
- Parameters:
op (tvm.tirx.BufferStore)
- Return type:
tvm.tirx.Stmt
- visit_buffer_load_(op)
- Parameters:
op (tvm.tirx.BufferLoad)
- Return type:
tvm.tirx.PrimExpr
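A minimal sketch of the matching rule. A hypothetical frozen `Access` dataclass stands in for BufferLoad/BufferStore, and an `(access, cast_buffer_name)` pair stands in for `CastEntry`, whose real fields are not shown on this page. Frozen dataclasses compare by value, which plays the role of structural equality.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class Access:
    """Toy buffer access: buffer name plus an index expression as a tuple tree."""
    buffer: str
    index: tuple  # e.g. ("var", "vec") or ("+", ("var", "vec"), ("const", 32))


# Hypothetical stand-in for CastEntry: (original access, name of its cast buffer).
Entry = Tuple[Access, str]


class ToyAccessReplacer:
    def __init__(self, entries: List[Entry], loop_var: str):
        self.entries = entries
        self.loop_var = loop_var

    def visit(self, access: Access) -> Access:
        for original, cast_buf in self.entries:
            # Match on buffer AND index, so a[vec] and a[vec + 32] are distinct keys.
            if original == access:
                return Access(cast_buf, ("var", self.loop_var))
        return access  # not a decoupled access: leave unchanged


i = ("var", "vec")
a_lo = Access("a", i)
a_hi = Access("a", ("+", i, ("const", 32)))
replacer = ToyAccessReplacer([(a_lo, "cast_buf_0"), (a_hi, "cast_buf_1")], "vec")
print(replacer.visit(Access("a", ("+", ("var", "vec"), ("const", 32)))).buffer)  # cast_buf_1
```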
- tilelang.transform.decouple_type_cast.DecoupleTypeCast()
Create a TVM pass that decouples type cast vectorization constraints.
This pass inserts a local buffer as an intermediate stage for vectorized loops where the body contains Cast nodes (mixed-precision operations).
This allows optimal vectorization for both computation and memory access.
Note
This pass must be applied before VectorizeLoop and StorageRewrite passes, while the IR still uses BufferLoad/BufferStore (not tvm_access_ptr).
- Returns:
A TVM PrimFunc pass.
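The ordering requirement in the note can be expressed as a simple check over pass names. `is_valid_order` is a hypothetical helper, and the list of strings is only a stand-in; actual TVM pipelines hold pass objects.

```python
from typing import List

# Passes that, per the note above, must come after DecoupleTypeCast.
MUST_RUN_LATER = ("VectorizeLoop", "StorageRewrite")


def is_valid_order(pipeline: List[str]) -> bool:
    if "DecoupleTypeCast" not in pipeline:
        return True  # nothing to check
    pos = pipeline.index("DecoupleTypeCast")
    return all(
        pipeline.index(name) > pos for name in MUST_RUN_LATER if name in pipeline
    )


print(is_valid_order(["DecoupleTypeCast", "VectorizeLoop", "StorageRewrite"]))  # True
print(is_valid_order(["VectorizeLoop", "DecoupleTypeCast"]))  # False
```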