tilelang.transform.decouple_type_cast
Decouple type cast vectorization constraints.
When a vectorized loop has mixed-precision operations between local and memory buffers, the vectorization length would be constrained by the GCD of all involved dtypes.
This pass decouples the constraints by inserting a local buffer as an intermediate stage, allowing optimal vectorization for both computation and memory access.
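The following toy model makes the GCD constraint concrete. It is not tilelang's vectorizer. It assumes, hypothetically, that every access prefers one 128-bit transaction, so its preferred lane count is 128 divided by the dtype's bit width. `DTYPE_BITS`, `preferred_lanes`, `fused_vector_len`, and `decoupled_vector_lens` are illustrative names.

```python
from functools import reduce
from math import gcd

# Hypothetical bit widths for the dtypes used in this page's examples.
DTYPE_BITS = {"float32": 32, "float16": 16, "float4_e2m1fn": 4}

# Assumption: each access prefers a single 128-bit vector transaction.
VECTOR_BITS = 128


def preferred_lanes(dtype: str) -> int:
    """Number of elements of `dtype` that fill one 128-bit transaction."""
    return VECTOR_BITS // DTYPE_BITS[dtype]


def fused_vector_len(dtypes: list, extent: int) -> int:
    """One fused loop must use a lane count every access accepts: the GCD of
    all preferred lane counts, further limited so it divides the loop extent."""
    return gcd(extent, reduce(gcd, (preferred_lanes(d) for d in dtypes)))


def decoupled_vector_lens(dtypes: list, extent: int) -> dict:
    """After decoupling, each stage is vectorized for its own dtype."""
    return {d: gcd(extent, preferred_lanes(d)) for d in dtypes}


# Case 1 below: a float32 fragment cast to float4_e2m1fn and stored to memory.
dtypes = ["float32", "float4_e2m1fn"]
print(fused_vector_len(dtypes, 16))       # 4
print(decoupled_vector_lens(dtypes, 16))  # {'float32': 4, 'float4_e2m1fn': 16}
```

Under this model, the float4_e2m1fn access gets 4 lanes (16 bits) in the fused loop. Once it has a loop of its own, it gets 16 lanes (64 bits, capped by the loop extent of 16). Actual lane counts depend on the target and on TVM's vectorizer.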
Two cases are handled:
Case 1: local → memory (store to memory with mixed types)
- Before:
      for vec in T.vectorized(16):
          b[vec] = T.cast(a_frag[vec], "float4_e2m1fn")
- After:
      for vec in T.vectorized(16):
          cast_buf[vec] = T.cast(a_frag[vec], "float4_e2m1fn")  # compute
      for vec_copy in T.vectorized(16):
          b[vec_copy] = cast_buf[vec_copy]  # copy to memory
Case 2: memory → local (load from memory with different dtype)
- Before:
      for vec in T.vectorized(16):
          a_frag[vec] = T.cast(b[vec], "float32")
- After:
      for vec_copy in T.vectorized(16):
          cast_buf[vec_copy] = b[vec_copy]  # copy from memory
      for vec in T.vectorized(16):
          a_frag[vec] = T.cast(cast_buf[vec], "float32")  # compute
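Both rewrites change only where the cast result is staged, not the values that are stored. The pure-Python sketch below checks this, using lists as stand-ins for `a_frag`, `b`, and `cast_buf`. The standard library has no float4_e2m1fn encoding, so two substitutions are made. Case 1's narrowing cast is replaced by float16 rounding via `struct`'s `'e'` format. Case 2's cast is replaced by float32 rounding via `'f'`. These stand-ins are assumptions, and the loops run sequentially rather than vectorized.

```python
import struct


def to_f16(x):
    """Stand-in for a narrowing T.cast: round to IEEE half precision."""
    return struct.unpack("e", struct.pack("e", x))[0]


def to_f32(x):
    """Stand-in for T.cast(..., "float32"): round to IEEE single precision."""
    return struct.unpack("f", struct.pack("f", x))[0]


# Case 1: local -> memory. a_frag is the local fragment, b the memory buffer.
def case1_before(a_frag):
    b = [0.0] * len(a_frag)
    for vec in range(len(a_frag)):
        b[vec] = to_f16(a_frag[vec])             # cast and store in one loop
    return b


def case1_after(a_frag):
    n = len(a_frag)
    cast_buf, b = [0.0] * n, [0.0] * n
    for vec in range(n):
        cast_buf[vec] = to_f16(a_frag[vec])      # compute into local cast_buf
    for vec_copy in range(n):
        b[vec_copy] = cast_buf[vec_copy]         # same-dtype copy to memory
    return b


# Case 2: memory -> local. b is the memory buffer, a_frag the local fragment.
def case2_before(b):
    a_frag = [0.0] * len(b)
    for vec in range(len(b)):
        a_frag[vec] = to_f32(b[vec])             # load and cast in one loop
    return a_frag


def case2_after(b):
    n = len(b)
    cast_buf, a_frag = [0.0] * n, [0.0] * n
    for vec_copy in range(n):
        cast_buf[vec_copy] = b[vec_copy]         # same-dtype copy from memory
    for vec in range(n):
        a_frag[vec] = to_f32(cast_buf[vec])      # compute from local cast_buf
    return a_frag


a_frag = [i * 0.1 + 1.0 / 3.0 for i in range(16)]
b = case1_before(a_frag)
assert case1_after(a_frag) == b
assert case2_after(b) == case2_before(b)
```

The equality holds because each copy loop moves values between buffers of the same dtype. Rounding therefore happens only in the compute loop, exactly as in the fused original.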
Attributes
- CastBufferMap
Classes
- MixedTypeChecker: Check if expression contains BufferLoads with different dtypes, skipping indices.
- Collect BufferLoads from global/shared buffers, skipping if_then_else conditions.
- StoreCollector: Collect BufferStore nodes that need transformation, skipping indices traversal.
- DecoupleTypeCastMutator: Mutator that decouples type cast vectorization constraints.
- StoreReplacer: Mutator to replace memory BufferStores with cast buffer BufferStores.
- LoadReplacer: Mutator to replace memory BufferLoads with cast buffer BufferLoads.
Functions
- is_local_buffer: Check if a buffer is local (register-level), including local.var.
- Check if a buffer is a global or shared buffer.
- validate_buffer_scope: Validate that buffer has a known scope.
- has_mixed_types: Check if expression contains BufferLoads with different dtypes than target.
- Get BufferLoads from global/shared buffers in the expression.
- Check if expression has global/shared BufferLoad with different dtype than target.
- contains_seq_stmt: Check if statement contains SeqStmt (multiple statements).
- extract_if_condition: Extract IfThenElse condition from statement if present.
- DecoupleTypeCast: Create a TVM pass that decouples type cast vectorization constraints.
Module Contents
- tilelang.transform.decouple_type_cast.is_local_buffer(buffer)
Check if a buffer is local (register-level), including local.var.
- Parameters:
buffer (tvm.tir.Buffer)
- Return type:
bool
Check if a buffer is a global or shared buffer.
- Parameters:
buffer (tvm.tir.Buffer)
- Return type:
bool
- tilelang.transform.decouple_type_cast.validate_buffer_scope(buffer)
Validate that buffer has a known scope.
- Raises:
ValueError – If buffer scope is unknown or empty.
- Parameters:
buffer (tvm.tir.Buffer)
- Return type:
None
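The three scope helpers above classify a buffer by its storage scope. Below is a minimal sketch over plain scope strings. The scope sets are assumptions, since only local.var is named above. The real functions take a `tvm.tir.Buffer`, not a string. `is_local_scope`, `is_global_or_shared_scope`, and `validate_scope` are hypothetical names.

```python
# Hypothetical scope sets; the module's real checks may accept other scopes.
LOCAL_SCOPES = {"local", "local.var", "local.fragment"}
MEMORY_SCOPES = {"global", "shared", "shared.dyn"}


def is_local_scope(scope: str) -> bool:
    """Register-level storage, including local.var."""
    return scope in LOCAL_SCOPES


def is_global_or_shared_scope(scope: str) -> bool:
    """Global or shared memory."""
    return scope in MEMORY_SCOPES


def validate_scope(scope: str) -> None:
    """Raise ValueError for an empty or unrecognized scope."""
    if not scope:
        raise ValueError("buffer scope is empty")
    if not (is_local_scope(scope) or is_global_or_shared_scope(scope)):
        raise ValueError(f"unknown buffer scope: {scope!r}")


validate_scope("shared.dyn")        # accepted, returns None
for bad in ("", "texture"):
    try:
        validate_scope(bad)
    except ValueError as err:
        print(err)
```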
- class tilelang.transform.decouple_type_cast.MixedTypeChecker(target_dtype)
Bases: tvm.tir.PyStmtExprVisitor
Check if expression contains BufferLoads with different dtypes, skipping indices.
- Parameters:
target_dtype (str)
- target_dtype = ''
- found_different = False
- visit_buffer_load_(op)
- Parameters:
op (tvm.tir.BufferLoad)
- Return type:
None
- tilelang.transform.decouple_type_cast.has_mixed_types(expr, target_dtype)
Check if expression contains BufferLoads with different dtypes than target.
If any BufferLoad in the expression has a different dtype than the target (the store buffer's dtype), vectorization may be constrained by the GCD of all dtypes.
- Parameters:
expr (tvm.tir.PrimExpr)
target_dtype (str)
- Return type:
bool
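MixedTypeChecker and has_mixed_types amount to gathering the dtypes of the BufferLoads in a stored value while ignoring index expressions. The sketch below shows this on a tiny hypothetical expression tree. `Load`, `Cast`, and `Add` are toy stand-ins for `tvm.tir` nodes, and `has_mixed_types_toy` is not the module's function.

```python
from dataclasses import dataclass, field


@dataclass
class Load:
    """Toy stand-in for tvm.tir.BufferLoad."""
    buffer: str
    dtype: str
    indices: list = field(default_factory=list)


@dataclass
class Cast:
    """Toy stand-in for a cast node."""
    dtype: str
    value: object


@dataclass
class Add:
    """Toy stand-in for a binary arithmetic node."""
    dtype: str
    a: object
    b: object


def value_load_dtypes(expr) -> set:
    """Dtypes of all loads in `expr`, without descending into load indices."""
    if isinstance(expr, Load):
        return {expr.dtype}                  # expr.indices is skipped on purpose
    if isinstance(expr, Cast):
        return value_load_dtypes(expr.value)
    if isinstance(expr, Add):
        return value_load_dtypes(expr.a) | value_load_dtypes(expr.b)
    return set()                             # constants contribute no loads


def has_mixed_types_toy(expr, target_dtype: str) -> bool:
    """True if any load in the stored value differs from the store's dtype."""
    return any(d != target_dtype for d in value_load_dtypes(expr))


# b[vec] = T.cast(a_frag[vec], "float4_e2m1fn"): float32 load, fp4 store.
case1 = Cast("float4_e2m1fn", Load("a_frag", "float32", ["vec"]))
print(has_mixed_types_toy(case1, "float4_e2m1fn"))   # True

# c[vec] = b[idx[vec]]: the int32 load lives only in the index, so it is ignored.
gather = Load("b", "float16", [Load("idx", "int32", ["vec"])])
print(has_mixed_types_toy(gather, "float16"))        # False
```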
Bases: tvm.tir.PyStmtExprVisitor
Collect BufferLoads from global/shared buffers, skipping if_then_else conditions.
The condition part of if_then_else doesn’t participate in type casting, so we skip collecting BufferLoads from there.
- Parameters:
skip_if_then_else_cond (bool)
- Parameters:
op (tvm.tir.BufferLoad)
- Return type:
None
- Parameters:
op (tvm.tir.Call)
- Return type:
None
Get BufferLoads from global/shared buffers in the expression.
- Parameters:
expr (tvm.tir.PrimExpr) – The expression to search.
skip_if_then_else_cond (bool) – If True, skip BufferLoads in if_then_else conditions, since they don’t participate in type casting.
- Return type:
list[tvm.tir.BufferLoad]
Check if expression has global/shared BufferLoad with different dtype than target.
Used to detect memory→local cases where a cast buffer must be inserted. Skips the if_then_else condition, since it doesn't participate in type casting.
- Parameters:
expr (tvm.tir.PrimExpr)
target_dtype (str)
- Return type:
bool
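The memory → local check differs from has_mixed_types in two ways. It counts only loads from global/shared buffers, and it skips the condition of if_then_else. The sketch below uses hypothetical toy nodes and an assumed set of memory scopes. `memory_loads` and `has_memory_mixed_types` are illustrative names, not the module's functions.

```python
from dataclasses import dataclass

# Assumption: which scopes count as "global/shared" memory.
MEMORY_SCOPES = {"global", "shared", "shared.dyn"}


@dataclass
class Load:
    buffer: str
    scope: str
    dtype: str


@dataclass
class Cast:
    dtype: str
    value: object


@dataclass
class LessThan:
    a: object
    b: object


@dataclass
class IfThenElse:
    """Toy stand-in for the if_then_else intrinsic."""
    cond: object
    then_value: object
    else_value: object


def memory_loads(expr, skip_if_then_else_cond: bool = True) -> list:
    """Loads from global/shared buffers in `expr`, optionally skipping conditions."""
    if isinstance(expr, Load):
        return [expr] if expr.scope in MEMORY_SCOPES else []
    if isinstance(expr, IfThenElse):
        cond = [] if skip_if_then_else_cond else memory_loads(expr.cond, False)
        return (cond + memory_loads(expr.then_value, skip_if_then_else_cond)
                + memory_loads(expr.else_value, skip_if_then_else_cond))
    if isinstance(expr, Cast):
        return memory_loads(expr.value, skip_if_then_else_cond)
    if isinstance(expr, LessThan):
        return (memory_loads(expr.a, skip_if_then_else_cond)
                + memory_loads(expr.b, skip_if_then_else_cond))
    return []                                # constants


def has_memory_mixed_types(expr, target_dtype: str) -> bool:
    """Memory -> local check: a global/shared load whose dtype differs from target."""
    return any(ld.dtype != target_dtype for ld in memory_loads(expr))


mask = Load("mask", "global", "int8")
# a_frag[vec] = if_then_else(mask[vec] < 1, T.cast(b[vec], "float32"), 0.0)
guarded = IfThenElse(LessThan(mask, 1), Cast("float32", Load("b", "global", "float16")), 0.0)
print([ld.buffer for ld in memory_loads(guarded)])         # ['b']
print([ld.buffer for ld in memory_loads(guarded, False)])  # ['mask', 'b']
print(has_memory_mixed_types(guarded, "float32"))          # True
```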
- class tilelang.transform.decouple_type_cast.StoreCollector
Bases: tvm.tir.PyStmtExprVisitor
Collect BufferStore nodes that need transformation, skipping indices traversal.
This avoids visiting BufferLoad/BufferStore nodes inside indices, which don’t participate in the type casting transformation.
- local_to_memory: list[tvm.tir.BufferStore] = []
- memory_to_local: list[tvm.tir.BufferStore] = []
- visit_buffer_store_(op)
- Parameters:
op (tvm.tir.BufferStore)
- Return type:
None
- visit_buffer_load_(op)
- Parameters:
op (tvm.tir.BufferLoad)
- Return type:
None
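The sketch below shows one way StoreCollector might sort stores into its two lists. The classification rule is an assumption inferred from the helper descriptions above. A memory store with a mixed-dtype value is treated as local → memory. A local store that reads a global/shared buffer of another dtype is treated as memory → local. `ToyStoreCollector`, `Load`, and `Store` are hypothetical.

```python
from dataclasses import dataclass, field

# Assumed scope sets (hypothetical; the real helpers inspect tvm.tir.Buffer scopes).
MEMORY_SCOPES = {"global", "shared", "shared.dyn"}
LOCAL_SCOPES = {"local", "local.var", "local.fragment"}


@dataclass
class Load:
    buffer: str
    scope: str
    dtype: str


@dataclass
class Store:
    buffer: str
    scope: str
    dtype: str
    value_loads: list   # loads in the stored value; loads inside indices excluded


@dataclass
class ToyStoreCollector:
    # Per-instance lists, so separate collectors never share results.
    local_to_memory: list = field(default_factory=list)
    memory_to_local: list = field(default_factory=list)

    def visit_store(self, st: Store) -> None:
        if st.scope in MEMORY_SCOPES:
            # Case 1: memory store whose value reads some other dtype.
            if any(ld.dtype != st.dtype for ld in st.value_loads):
                self.local_to_memory.append(st)
        elif st.scope in LOCAL_SCOPES:
            # Case 2: local store that reads a global/shared buffer of another dtype.
            if any(ld.scope in MEMORY_SCOPES and ld.dtype != st.dtype
                   for ld in st.value_loads):
                self.memory_to_local.append(st)


case1 = Store("b", "global", "float4_e2m1fn", [Load("a_frag", "local.fragment", "float32")])
case2 = Store("a_frag", "local.fragment", "float32", [Load("b", "global", "float16")])
plain = Store("b", "global", "float32", [Load("a_frag", "local.fragment", "float32")])

collector = ToyStoreCollector()
for st in (case1, case2, plain):
    collector.visit_store(st)
print([s.buffer for s in collector.local_to_memory])   # ['b']
print([s.buffer for s in collector.memory_to_local])   # ['a_frag']
```

The `plain` store is left alone because its value already matches the store dtype, so there is no GCD constraint to decouple.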
- tilelang.transform.decouple_type_cast.contains_seq_stmt(stmt)
Check if statement contains SeqStmt (multiple statements).
When the For body contains a SeqStmt, the transformation is more complex, so the optimization is currently skipped.
- Parameters:
stmt (tvm.tir.Stmt)
- Return type:
bool
- tilelang.transform.decouple_type_cast.extract_if_condition(stmt)
Extract IfThenElse condition from statement if present.
- Returns:
A tuple of (condition, inner_body). If no IfThenElse, returns (None, stmt).
- Parameters:
stmt (tvm.tir.Stmt)
- Return type:
tuple[tvm.tir.PrimExpr | None, tvm.tir.Stmt]
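Taken together, contains_seq_stmt and extract_if_condition suggest that the pass accepts a loop body consisting of a single store, optionally behind one guard. The sketch below uses hypothetical statement classes. Treating a guard with an else branch as not extractable is an assumption, since the description above does not cover that case.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Store:
    text: str                       # toy store, kept as source text


@dataclass
class SeqStmt:
    seq: list


@dataclass
class IfThenElse:
    condition: str
    then_case: object
    else_case: Optional[object] = None


def contains_seq_stmt_toy(stmt) -> bool:
    """True if a SeqStmt appears anywhere in the statement."""
    if isinstance(stmt, SeqStmt):
        return True
    if isinstance(stmt, IfThenElse):
        return contains_seq_stmt_toy(stmt.then_case) or (
            stmt.else_case is not None and contains_seq_stmt_toy(stmt.else_case)
        )
    return False


def extract_if_condition_toy(stmt):
    """Return (condition, inner_body), or (None, stmt) when there is no guard.
    Assumption: only an IfThenElse without an else branch is peeled off."""
    if isinstance(stmt, IfThenElse) and stmt.else_case is None:
        return stmt.condition, stmt.then_case
    return None, stmt


body = IfThenElse("i * 16 + vec < n", Store('b[vec] = T.cast(a_frag[vec], "float4_e2m1fn")'))
print(contains_seq_stmt_toy(body))       # False: a single guarded store
cond, inner = extract_if_condition_toy(body)
print(cond)
print(extract_if_condition_toy(inner)[0])
```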
- tilelang.transform.decouple_type_cast.CastBufferMap
- class tilelang.transform.decouple_type_cast.DecoupleTypeCastMutator
Bases: tvm.tir.PyStmtExprMutator
Mutator that decouples type cast vectorization constraints.
This mutator transforms vectorized loops that store to memory buffers (global/shared) with mixed-precision expressions by inserting local cache buffers as intermediate stages.
- visit_for_(op)
Visit For nodes, transforming vectorized loops with mixed-type stores.
- Parameters:
op (tvm.tir.For)
- Return type:
tvm.tir.Stmt
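The Case 1 split performed in visit_for_ can be sketched on a toy loop IR. This is a minimal illustration that assumes the loop body is a single store whose index is exactly the loop variable. `VectorizedFor`, `Store`, `Load`, `Cast`, `show`, and `decouple_store` are hypothetical. The real mutator works on `tvm.tir.For` and presumably also has to allocate the cast buffer, which this sketch omits.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Load:
    buffer: str
    index: str


@dataclass(frozen=True)
class Cast:
    dtype: str
    value: object


@dataclass(frozen=True)
class Store:
    buffer: str
    index: str
    value: object


@dataclass(frozen=True)
class VectorizedFor:
    loop_var: str
    extent: int
    body: Store


def show(node) -> str:
    """Print a toy node in the T.vectorized style used on this page."""
    if isinstance(node, Load):
        return f"{node.buffer}[{node.index}]"
    if isinstance(node, Cast):
        return f'T.cast({show(node.value)}, "{node.dtype}")'
    if isinstance(node, Store):
        return f"{node.buffer}[{node.index}] = {show(node.value)}"
    if isinstance(node, VectorizedFor):
        return f"for {node.loop_var} in T.vectorized({node.extent}):\n    {show(node.body)}"
    raise TypeError(f"unsupported node: {node!r}")


def decouple_store(loop: VectorizedFor, cast_buf: str, copy_var: str) -> list:
    """Case 1 split: compute into a local cast buffer, then copy it to memory."""
    store = loop.body
    assert store.index == loop.loop_var, "toy only handles b[vec] = ... stores"
    # Compute stage: same loop, but the store is redirected to the cast buffer.
    compute = replace(loop, body=replace(store, buffer=cast_buf))
    # Copy stage: a new loop that moves cast_buf into the original memory buffer.
    copy = VectorizedFor(copy_var, loop.extent,
                         Store(store.buffer, copy_var, Load(cast_buf, copy_var)))
    return [compute, copy]


before = VectorizedFor("vec", 16,
                       Store("b", "vec", Cast("float4_e2m1fn", Load("a_frag", "vec"))))
for stmt in decouple_store(before, "cast_buf", "vec_copy"):
    print(show(stmt))
```

The printed result reproduces the Case 1 "After" listing at the top of this page.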
- class tilelang.transform.decouple_type_cast.StoreReplacer(cast_buffers, loop_var)
Bases: tvm.tir.PyStmtExprMutator
Mutator to replace memory BufferStores with cast buffer BufferStores.
- Parameters:
cast_buffers (CastBufferMap)
loop_var (tvm.tir.Var)
- cast_buffers
- loop_var
- visit_buffer_store_(op)
- Parameters:
op (tvm.tir.BufferStore)
- Return type:
tvm.tir.Stmt
- class tilelang.transform.decouple_type_cast.LoadReplacer(cast_buffers, loop_var)
Bases: tvm.tir.PyStmtExprMutator
Mutator to replace memory BufferLoads with cast buffer BufferLoads.
- Parameters:
cast_buffers (CastBufferMap)
loop_var (tvm.tir.Var)
- cast_buffers
- loop_var
- visit_buffer_load_(op)
- Parameters:
op (tvm.tir.BufferLoad)
- Return type:
tvm.tir.PrimExpr
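For Case 2, LoadReplacer redirects memory loads to cast buffer loads. This can be sketched as a recursive rebuild of a toy expression. Here a plain dict mapping buffer names plays the role of CastBufferMap, whose definition is not shown on this page. `replace_loads` and the node classes are hypothetical.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Load:
    buffer: str
    index: str


@dataclass(frozen=True)
class Cast:
    dtype: str
    value: object


@dataclass(frozen=True)
class Add:
    a: object
    b: object


def replace_loads(expr, cast_buffers: dict, loop_var: str):
    """Redirect loads of mapped memory buffers to their local cast buffers,
    indexed by the loop variable; everything else is rebuilt unchanged."""
    if isinstance(expr, Load):
        if expr.buffer in cast_buffers:
            return Load(cast_buffers[expr.buffer], loop_var)
        return expr
    if isinstance(expr, Cast):
        return replace(expr, value=replace_loads(expr.value, cast_buffers, loop_var))
    if isinstance(expr, Add):
        return Add(replace_loads(expr.a, cast_buffers, loop_var),
                   replace_loads(expr.b, cast_buffers, loop_var))
    return expr                                   # constants


# a_frag[vec] = T.cast(b[vec], "float32") + c_frag[vec]
value = Add(Cast("float32", Load("b", "vec")), Load("c_frag", "vec"))
print(replace_loads(value, {"b": "cast_buf"}, "vec"))
```

Only `b`, which has an entry in the mapping, is redirected. The local `c_frag` load is left as-is.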
- tilelang.transform.decouple_type_cast.DecoupleTypeCast()
Create a TVM pass that decouples type cast vectorization constraints.
This pass inserts a local buffer as an intermediate stage for vectorized stores to non-local buffers (global/shared) where the store value contains expressions with different dtypes.
This allows optimal vectorization for both computation and memory access.
Note
This pass must be applied before VectorizeLoop and StorageRewrite passes, while the IR still uses BufferLoad/BufferStore (not tvm_access_ptr).
- Returns:
A TVM PrimFunc pass.