tilelang.transform.decouple_type_cast

Decouple type cast vectorization constraints.

When a vectorized loop mixes precisions between local buffers and memory (global or shared) buffers, its vectorization length is constrained by the GCD of the vector lengths that each involved dtype permits.

This pass decouples these constraints by inserting a local buffer as an intermediate stage, so that the computation loop and the memory-access loop can each be vectorized at its own optimal length.
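One way to picture the GCD constraint is to assume that each vectorized memory access is 128 bits wide. Under that assumption, a dtype that is `bits` wide allows at most `128 // bits` lanes. The sketch below is illustrative only. `ACCESS_BITS`, `max_lanes`, and `fused_lanes` are hypothetical names, and the 128-bit width is an assumption rather than something this pass defines.

```python
from functools import reduce
from math import gcd

# Assumption: one vectorized memory access moves at most 128 bits.
ACCESS_BITS = 128

# Bit widths of the dtypes used in the examples on this page.
DTYPE_BITS = {"float32": 32, "float16": 16, "float4_e2m1fn": 4}


def max_lanes(dtype: str, access_bits: int = ACCESS_BITS) -> int:
    """Largest lane count of `dtype` that fits in one access (hypothetical helper)."""
    return access_bits // DTYPE_BITS[dtype]


def fused_lanes(dtypes: list) -> int:
    """Lane count shared by a single loop that touches every dtype in `dtypes`."""
    return reduce(gcd, (max_lanes(d) for d in dtypes))


# A fused float32 -> float4_e2m1fn loop is held to the float32 width ...
print(fused_lanes(["float32", "float4_e2m1fn"]))  # 4
# ... while a decoupled float4_e2m1fn copy loop can use its own width.
print(max_lanes("float4_e2m1fn"))  # 32
```

Once the cast sits in its own local stage, the memory-side loop touches only one dtype. Its lane count is then no longer tied to the other side's.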

Mixed-precision is detected by the presence of Cast nodes in the loop body.

Two cases are handled:

Case 1: local → memory (store to memory with mixed types)

Before:

    for vec in T.vectorized(16):
        b[vec] = T.cast(a_frag[vec], "float4_e2m1fn")

After:

    # cast_buf: local float4_e2m1fn buffer inserted by the pass
    for vec in T.vectorized(16):
        cast_buf[vec] = T.cast(a_frag[vec], "float4_e2m1fn")  # compute
    for vec_copy in T.vectorized(16):
        b[vec_copy] = cast_buf[vec_copy]  # copy to memory

Case 2: memory → local (load from memory with different dtype)

Before:

    for vec in T.vectorized(16):
        a_frag[vec] = T.cast(b[vec], "float32")

After:

    # cast_buf: local buffer with b's dtype, inserted by the pass
    for vec_copy in T.vectorized(16):
        cast_buf[vec_copy] = b[vec_copy]  # copy from memory
    for vec in T.vectorized(16):
        a_frag[vec] = T.cast(cast_buf[vec], "float32")  # compute
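The rewrite is meant to preserve semantics. Splitting one fused loop into a compute stage and a copy stage through `cast_buf` should produce the same values. The plain-Python model below checks this for both cases. Lists stand in for buffers, and a hypothetical `toy_cast` (truncate, then clamp to a small integer range) stands in for `T.cast`. The model covers only the data flow, not TIR or vectorization.

```python
N = 16  # matches T.vectorized(16) in the examples


def toy_cast(x: float) -> int:
    """Hypothetical narrowing cast: truncate toward zero, clamp to [-8, 7]."""
    return max(-8, min(7, int(x)))


def case1_fused(a_frag):
    b = [0] * N
    for vec in range(N):
        b[vec] = toy_cast(a_frag[vec])
    return b


def case1_decoupled(a_frag):
    cast_buf = [0] * N  # local intermediate stage
    for vec in range(N):  # compute: local -> local, with the cast
        cast_buf[vec] = toy_cast(a_frag[vec])
    b = [0] * N
    for vec_copy in range(N):  # copy: same dtype on both sides
        b[vec_copy] = cast_buf[vec_copy]
    return b


def case2_fused(b):
    a_frag = [0.0] * N
    for vec in range(N):
        a_frag[vec] = float(b[vec])
    return a_frag


def case2_decoupled(b):
    cast_buf = [0] * N
    for vec_copy in range(N):  # copy from memory, no cast
        cast_buf[vec_copy] = b[vec_copy]
    a_frag = [0.0] * N
    for vec in range(N):  # compute: cast from the local stage
        a_frag[vec] = float(cast_buf[vec])
    return a_frag


a = [1.5 * i - 10.0 for i in range(N)]
assert case1_fused(a) == case1_decoupled(a)
b = list(range(-8, 8))
assert case2_fused(b) == case2_decoupled(b)
```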

Attributes

CastEntry

Classes

MemoryAccessCollector

Collect shared/global BufferStore and BufferLoad nodes.

DecoupleTypeCastMutator

Mutator that decouples type cast vectorization constraints.

AccessReplacer

Mutator to replace memory BufferStores/BufferLoads with cast buffer accesses.

Functions

is_local_buffer(buffer)

Check if a buffer is local (register-level), including local.var.

is_global_or_shared_buffer(buffer)

Check if a buffer is a global or shared buffer.

inline_let_stmts(stmt)

Inline all LetStmt bindings in stmt so that downstream visitors can see the original BufferLoad nodes that were hidden behind Var references.

extract_if_condition(stmt)

Extract IfThenElse condition from statement if present.

DecoupleTypeCast()

Create a TVM pass that decouples type cast vectorization constraints.

Module Contents

tilelang.transform.decouple_type_cast.is_local_buffer(buffer)

Check if a buffer is local (register-level), including local.var.

Parameters:

buffer (tvm.tir.Buffer)

Return type:

bool

tilelang.transform.decouple_type_cast.is_global_or_shared_buffer(buffer)

Check if a buffer is a global or shared buffer.

Parameters:

buffer (tvm.tir.Buffer)

Return type:

bool
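Both predicates can be approximated by classifying a buffer's storage-scope string, which a TVM `Buffer` exposes through `buffer.scope()`. The scope sets below are assumptions for illustration. In particular, treating TileLang's `local.fragment` as register-level and `shared.dyn` as shared is a guess about what the real helpers accept. `crosses_memory_boundary` is a hypothetical helper for the local ↔ memory pattern described at the top of this page.

```python
# Hypothetical scope sets; the real helpers inspect a tvm.tir.Buffer.
LOCAL_SCOPES = {"local", "local.var", "local.fragment"}
MEMORY_SCOPES = {"global", "shared", "shared.dyn"}


def is_local_scope(scope: str) -> bool:
    """Register-level storage, including local.var (sketch of is_local_buffer)."""
    return scope in LOCAL_SCOPES


def is_global_or_shared_scope(scope: str) -> bool:
    """Global or shared storage (sketch of is_global_or_shared_buffer)."""
    return scope in MEMORY_SCOPES


def crosses_memory_boundary(dst_scope: str, src_scope: str) -> bool:
    """True for a local <-> memory pair, the pattern the pass targets."""
    return (is_local_scope(src_scope) and is_global_or_shared_scope(dst_scope)) or (
        is_global_or_shared_scope(src_scope) and is_local_scope(dst_scope)
    )


print(crosses_memory_boundary("global", "local.fragment"))  # True  (Case 1)
print(crosses_memory_boundary("local.fragment", "shared"))  # True  (Case 2)
print(crosses_memory_boundary("local", "local.var"))        # False
```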

class tilelang.transform.decouple_type_cast.MemoryAccessCollector(loop_var)

Bases: tvm.tir.PyStmtExprVisitor

Collect shared/global BufferStore and BufferLoad nodes.

Skips indices traversal so that index expressions (which may contain BufferLoads to index buffers) do not pollute the result.

BufferLoads in if_then_else conditions are skipped because conditions don’t participate in the type-cast compute path.

BufferLoads whose indices do not depend on loop_var are skipped because they are scalar accesses (e.g. b[0]) that should remain in the compute loop as broadcasts.

Parameters:

loop_var (tvm.tir.Var)

loop_var

stores: list[tvm.tir.BufferStore] = []

loads: list[tvm.tir.BufferLoad] = []

visit_buffer_store_(op)
Parameters:

op (tvm.tir.BufferStore)

Return type:

None

visit_buffer_load_(op)
Parameters:

op (tvm.tir.BufferLoad)

Return type:

None

visit_call_(op)
Parameters:

op (tvm.tir.Call)

Return type:

None
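The three skip rules above can be illustrated on a toy expression tree. This is not the `PyStmtExprVisitor` implementation. `Var`, `Load`, `Add`, `Cast`, `IfThenElse`, `depends_on`, and `collect_loads` are hypothetical stand-ins, and only loads are collected, for brevity. The toy `IfThenElse` plays the role of the `if_then_else` call that the real class presumably intercepts in `visit_call_`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Load:
    buffer: str
    scope: str
    index: object  # Var, int, Add, or another Load


@dataclass(frozen=True)
class Add:
    a: object
    b: object


@dataclass(frozen=True)
class Cast:
    dtype: str
    value: object


@dataclass(frozen=True)
class IfThenElse:  # expression form, like T.if_then_else
    cond: object
    then_value: object
    else_value: object


def depends_on(expr, var) -> bool:
    """Does an index expression mention `var`, possibly through an index buffer?"""
    if isinstance(expr, Var):
        return expr == var
    if isinstance(expr, Add):
        return depends_on(expr.a, var) or depends_on(expr.b, var)
    if isinstance(expr, Load):
        return depends_on(expr.index, var)
    return False


def collect_loads(expr, loop_var, out):
    if isinstance(expr, Load):
        # Rule 3: loop-invariant scalar accesses such as s[0] are not collected.
        if expr.scope in ("global", "shared") and depends_on(expr.index, loop_var):
            out.append(expr)
        return  # Rule 1: never descend into the index expression.
    if isinstance(expr, IfThenElse):
        # Rule 2: the condition is not on the cast path; visit branches only.
        collect_loads(expr.then_value, loop_var, out)
        collect_loads(expr.else_value, loop_var, out)
    elif isinstance(expr, Cast):
        collect_loads(expr.value, loop_var, out)
    elif isinstance(expr, Add):
        collect_loads(expr.a, loop_var, out)
        collect_loads(expr.b, loop_var, out)


vec = Var("vec")
body = IfThenElse(
    Load("mask", "global", vec),  # condition: skipped
    Add(
        Cast("float32", Load("b", "global", Load("idx", "global", vec))),
        Load("s", "global", 0),  # loop-invariant scalar: skipped
    ),
    0.0,
)
found = []
collect_loads(body, vec, found)
print([load.buffer for load in found])  # ['b']
```

The index buffer `idx` is not reported even though it depends on `vec`, because rule 1 stops the walk at `b`'s index.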

tilelang.transform.decouple_type_cast.inline_let_stmts(stmt)

Inline all LetStmt bindings in stmt so that downstream visitors can see the original BufferLoad nodes that were hidden behind Var references.

Used before collecting memory accesses so that BufferLoads inside LetStmt values are visible to MemoryAccessCollector.

Parameters:

stmt (tvm.tir.Stmt)

Return type:

tvm.tir.Stmt
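Inlining matters because a load bound by a LetStmt shows up in the loop body only as a `Var`, so a collector that walks the body would miss it. The sketch below performs the same kind of substitution on a toy IR. `Var`, `Load`, `Cast`, `Store`, `Let`, `subst`, and `inline_lets` are hypothetical names, not TVM classes.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Load:
    buffer: str
    index: object


@dataclass(frozen=True)
class Cast:
    dtype: str
    value: object


@dataclass(frozen=True)
class Store:
    buffer: str
    index: object
    value: object


@dataclass(frozen=True)
class Let:  # like LetStmt: bind var = value, then run body
    var: Var
    value: object
    body: object


def subst(expr, env):
    """Replace bound Vars with their (already inlined) values."""
    if isinstance(expr, Var):
        return env.get(expr, expr)
    if isinstance(expr, Load):
        return Load(expr.buffer, subst(expr.index, env))
    if isinstance(expr, Cast):
        return Cast(expr.dtype, subst(expr.value, env))
    return expr


def inline_lets(stmt, env=None):
    env = {} if env is None else env
    if isinstance(stmt, Let):
        # Inline earlier bindings into this value before recording it.
        return inline_lets(stmt.body, {**env, stmt.var: subst(stmt.value, env)})
    if isinstance(stmt, Store):
        return Store(stmt.buffer, subst(stmt.index, env), subst(stmt.value, env))
    return stmt


vec, x = Var("vec"), Var("x")
stmt = Let(x, Load("b", vec), Store("a_frag", vec, Cast("float32", x)))
# After inlining, the BufferLoad of b is visible inside the Cast again.
assert inline_lets(stmt) == Store("a_frag", vec, Cast("float32", Load("b", vec)))
```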

tilelang.transform.decouple_type_cast.extract_if_condition(stmt)

Extract IfThenElse condition from statement if present.

Returns:

A tuple of (condition, inner_body). If stmt is not an IfThenElse, returns (None, stmt).

Parameters:

stmt (tvm.tir.Stmt)

Return type:

tuple[tvm.tir.PrimExpr | None, tvm.tir.Stmt]
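The return contract can be mirrored on a toy node. `IfThenElse` and `extract_if_condition_sketch` below are hypothetical. The sketch unwraps only an `if` without an `else` branch, because an `else` branch leaves no single inner body to return. Whether the real helper makes the same choice is an assumption.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class IfThenElse:
    condition: object
    then_case: object
    else_case: Optional[object] = None


def extract_if_condition_sketch(stmt):
    """Return (condition, inner_body), or (None, stmt) when there is nothing to unwrap."""
    if isinstance(stmt, IfThenElse) and stmt.else_case is None:
        return stmt.condition, stmt.then_case
    return None, stmt


guarded = IfThenElse("i < n", "b[i] = cast(a_frag[i])")
print(extract_if_condition_sketch(guarded))     # ('i < n', 'b[i] = cast(a_frag[i])')
print(extract_if_condition_sketch("b[i] = 0"))  # (None, 'b[i] = 0')
```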

tilelang.transform.decouple_type_cast.CastEntry

class tilelang.transform.decouple_type_cast.DecoupleTypeCastMutator

Bases: tvm.tir.PyStmtExprMutator

Mutator that decouples type cast vectorization constraints.

This mutator transforms vectorized loops that have mixed-precision operations (detected by the presence of Cast nodes) by inserting local cache buffers as intermediate stages.

visit_for_(op)

Visit For nodes, transforming vectorized loops with mixed-type stores.

Parameters:

op (tvm.tir.For)

Return type:

tvm.tir.Stmt
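Based on the module description, `visit_for_` needs to rewrite only loops that are vectorized, contain a Cast, and access global or shared memory. The predicate below is a hypothetical sketch of that decision on a toy IR. `Load`, `Cast`, `Add`, `Store`, `For`, `walk`, and `should_decouple` are illustrative names, and the real mutator's exact checks may differ.

```python
from dataclasses import dataclass

# Assumption: which storage scopes count as "memory" for this sketch.
MEMORY_SCOPES = {"global", "shared", "shared.dyn"}


@dataclass(frozen=True)
class Load:
    buffer: str
    scope: str


@dataclass(frozen=True)
class Cast:
    dtype: str
    value: object


@dataclass(frozen=True)
class Add:
    a: object
    b: object


@dataclass(frozen=True)
class Store:
    buffer: str
    scope: str
    value: object


@dataclass(frozen=True)
class For:
    kind: str  # "serial" or "vectorized"
    extent: int
    body: Store


def walk(expr):
    """Yield every node of a toy expression tree, parents first."""
    yield expr
    if isinstance(expr, Cast):
        yield from walk(expr.value)
    elif isinstance(expr, Add):
        yield from walk(expr.a)
        yield from walk(expr.b)


def should_decouple(loop: For) -> bool:
    store = loop.body
    nodes = list(walk(store.value))
    mixed_precision = any(isinstance(n, Cast) for n in nodes)
    touches_memory = store.scope in MEMORY_SCOPES or any(
        isinstance(n, Load) and n.scope in MEMORY_SCOPES for n in nodes
    )
    return loop.kind == "vectorized" and mixed_precision and touches_memory


case1 = For("vectorized", 16, Store(
    "b", "global", Cast("float4_e2m1fn", Load("a_frag", "local.fragment"))))
serial = For("serial", 16, case1.body)
no_cast = For("vectorized", 16, Store("b", "global", Load("a_frag", "local.fragment")))
registers_only = For("vectorized", 16, Store(
    "c", "local", Cast("float32", Load("a_frag", "local.fragment"))))

print(should_decouple(case1))           # True
print(should_decouple(serial))          # False
print(should_decouple(no_cast))         # False
print(should_decouple(registers_only))  # False
```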

class tilelang.transform.decouple_type_cast.AccessReplacer(store_entries, load_entries, loop_var)

Bases: tvm.tir.PyStmtExprMutator

Mutator to replace memory BufferStores/BufferLoads with cast buffer accesses.

Matches by both buffer and indices (structural equality) so that accesses like a[i] and a[i+32] from the same buffer map to different cast buffers.

Parameters:
  • store_entries (list[CastEntry])

  • load_entries (list[CastEntry])

  • loop_var (tvm.tir.Var)

store_entries

load_entries

loop_var

visit_buffer_store_(op)
Parameters:

op (tvm.tir.BufferStore)

Return type:

tvm.tir.Stmt

visit_buffer_load_(op)
Parameters:

op (tvm.tir.BufferLoad)

Return type:

tvm.tir.PrimExpr
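Keying on both the buffer and the structurally compared indices is what keeps `a[i]` and `a[i+32]` apart. The sketch below shows that lookup with frozen dataclasses, whose generated `__eq__` and `__hash__` compare field by field and so behave like structural equality. `entries`, `replace_load`, and the toy node classes are hypothetical. Rewriting a match to `cast_buf[loop_var]` follows the Case 2 example above rather than a confirmed detail of the implementation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Add:
    a: object
    b: object


@dataclass(frozen=True)
class Load:
    buffer: str
    indices: tuple


def replace_load(load: Load, entries: dict, loop_var: Var) -> Load:
    """Redirect a matched memory load to its cast buffer; leave others unchanged."""
    cast_buf = entries.get((load.buffer, load.indices))
    if cast_buf is None:
        return load
    return Load(cast_buf, (loop_var,))


i = Var("i")
# One hypothetical cast buffer per distinct (buffer, indices) access.
entries = {
    ("a", (i,)): "cast_buf_0",
    ("a", (Add(i, 32),)): "cast_buf_1",
}

# A freshly built a[i + 32] matches structurally, not by object identity.
print(replace_load(Load("a", (Add(Var("i"), 32),)), entries, i))
print(replace_load(Load("a", (i,)), entries, i))
print(replace_load(Load("c", (i,)), entries, i))  # unmatched: returned as-is
```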

tilelang.transform.decouple_type_cast.DecoupleTypeCast()

Create a TVM pass that decouples type cast vectorization constraints.

This pass inserts a local buffer as an intermediate stage for vectorized loops where the body contains Cast nodes (mixed-precision operations).

This allows optimal vectorization for both computation and memory access.

Note

This pass must be applied before VectorizeLoop and StorageRewrite passes, while the IR still uses BufferLoad/BufferStore (not tvm_access_ptr).

Returns:

A TVM PrimFunc pass.