tilelang.transform.decouple_type_cast¶

Decouple type cast vectorization constraints.

When a vectorized loop mixes precisions between local and memory buffers, the vectorization length is constrained by the GCD of the vector lengths that each involved dtype permits.

This pass decouples the constraints by inserting a local buffer as an intermediate stage, allowing optimal vectorization for both computation and memory access.

Mixed-precision is detected by the presence of Cast nodes in the loop body.
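The GCD constraint can be illustrated with a little arithmetic. The sketch below is not taken from the pass. It assumes a 128-bit maximum vector access, and the helper names `max_lanes` and `coupled_lanes` are hypothetical; the limits the pass actually applies may differ.

```python
from functools import reduce
from math import gcd

# Assumption for illustration only: the widest single access is 128 bits.
MAX_VECTOR_BITS = 128

# Bit widths of dtypes that appear in the examples on this page.
DTYPE_BITS = {"float32": 32, "float16": 16, "float4_e2m1fn": 4}


def max_lanes(dtype: str) -> int:
    """Most elements of `dtype` that fit in one vector access."""
    return MAX_VECTOR_BITS // DTYPE_BITS[dtype]


def coupled_lanes(dtypes: list[str]) -> int:
    """Vector length when a single loop must satisfy every dtype at once."""
    return reduce(gcd, (max_lanes(d) for d in dtypes))


# One loop that touches both dtypes is limited by the GCD of their lane counts.
coupled = coupled_lanes(["float32", "float4_e2m1fn"])

# A copy loop that only moves float4_e2m1fn data has no such coupling.
copy_only = coupled_lanes(["float4_e2m1fn"])

print(coupled, copy_only)
```

Under these assumptions the coupled loop is limited to 4 lanes, while a loop that only copies `float4_e2m1fn` data could use up to 32.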

Two cases are handled:

Case 1: local → memory (store to memory with mixed types)¶

Before:

    for vec in T.vectorized(16):
        b[vec] = T.cast(a_frag[vec], "float4_e2m1fn")

After:

    for vec in T.vectorized(16):
        cast_buf[vec] = T.cast(a_frag[vec], "float4_e2m1fn")  # compute
    for vec_copy in T.vectorized(16):
        b[vec_copy] = cast_buf[vec_copy]  # copy to memory

Case 2: memory → local (load from memory with different dtype)¶

Before:

    for vec in T.vectorized(16):
        a_frag[vec] = T.cast(b[vec], "float32")

After:

    for vec_copy in T.vectorized(16):
        cast_buf[vec_copy] = b[vec_copy]  # copy from memory
    for vec in T.vectorized(16):
        a_frag[vec] = T.cast(cast_buf[vec], "float32")  # compute
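Both rewrites only reroute data through an intermediate buffer, so the values produced should be unchanged. The plain-Python model below mirrors the Before/After loops of the two cases using lists. The functions `narrow` and `widen` are hypothetical stand-ins for `T.cast`, and nothing here calls TileLang or TVM.

```python
def narrow(x: float) -> int:
    """Hypothetical stand-in for T.cast to a narrower dtype (rounds to int)."""
    return round(x)


def widen(x: int) -> float:
    """Hypothetical stand-in for T.cast to float32 (converts to float)."""
    return float(x)


def case1_before(a_frag):
    # Cast and store to memory in a single loop.
    b = [0] * len(a_frag)
    for vec in range(len(a_frag)):
        b[vec] = narrow(a_frag[vec])
    return b


def case1_after(a_frag):
    n = len(a_frag)
    cast_buf = [0] * n
    for vec in range(n):
        cast_buf[vec] = narrow(a_frag[vec])  # compute
    b = [0] * n
    for vec_copy in range(n):
        b[vec_copy] = cast_buf[vec_copy]     # copy to memory
    return b


def case2_before(b):
    # Load from memory and cast in a single loop.
    return [widen(b[vec]) for vec in range(len(b))]


def case2_after(b):
    n = len(b)
    cast_buf = [0] * n
    for vec_copy in range(n):
        cast_buf[vec_copy] = b[vec_copy]     # copy from memory
    return [widen(cast_buf[vec]) for vec in range(n)]  # compute


a_frag = [0.25 * i for i in range(16)]
b_mem = list(range(16))
same_store = case1_before(a_frag) == case1_after(a_frag)
same_load = case2_before(b_mem) == case2_after(b_mem)
print(same_store, same_load)
```

The model says nothing about performance. It only shows that splitting each loop into a compute loop and a copy loop leaves the stored and loaded values the same.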

Attributes¶

BindEnv

CastEntry

Classes¶

MemoryAccessCollector

Collect shared/global BufferStore and BufferLoad nodes.

DecoupleTypeCastMutator

Mutator that decouples type cast vectorization constraints.

AccessReplacer

Mutator to replace memory BufferStores/BufferLoads with cast buffer accesses.

Functions¶

is_local_buffer(buffer)

Check if a buffer is local (register-level), including local.var and metal.simdgroup.

is_global_or_shared_buffer(buffer)

Check if a buffer is a global or shared buffer.

normalize_flat_binds(stmt)

Return stmt with dominating flat Bind values substituted into uses.

extract_if_condition(stmt)

Extract IfThenElse condition from statement if present.

DecoupleTypeCast()

Create a TVM pass that decouples type cast vectorization constraints.

Module Contents¶

tilelang.transform.decouple_type_cast.is_local_buffer(buffer)¶

Check if a buffer is local (register-level), including local.var and metal.simdgroup.

Parameters:

buffer (tvm.tirx.Buffer)

Return type:

bool

tilelang.transform.decouple_type_cast.is_global_or_shared_buffer(buffer)¶

Check if a buffer is a global or shared buffer.

Parameters:

buffer (tvm.tirx.Buffer)

Return type:

bool

class tilelang.transform.decouple_type_cast.MemoryAccessCollector(loop_var)¶

Bases: tvm.tirx.PyStmtExprVisitor

Collect shared/global BufferStore and BufferLoad nodes.

Skips traversal of indices so that index expressions (which may contain BufferLoads to index buffers) do not pollute the result.

BufferLoads in if_then_else conditions are skipped because conditions don’t participate in the type-cast compute path.

BufferLoads whose indices do not depend on loop_var are skipped because they are scalar accesses (e.g. b[0]) that should remain in the compute loop as broadcasts.

Parameters:

loop_var (tvm.tirx.Var)

loop_var¶
stores: list[tvm.tirx.BufferStore] = []¶
loads: list[tvm.tirx.BufferLoad] = []¶
visit_buffer_store_(op)¶
Parameters:

op (tvm.tirx.BufferStore)

Return type:

None

visit_buffer_load_(op)¶
Parameters:

op (tvm.tirx.BufferLoad)

Return type:

None

visit_call_(op)¶
Parameters:

op (tvm.tirx.Call)

Return type:

None
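The filtering rules described for MemoryAccessCollector can be sketched on a toy expression tree. Everything below is hypothetical: the `Var`, `Load`, and `IfThenElse` classes and the `collect_loads` function are illustrative stand-ins rather than tvm.tirx types, and the scope strings are simplified.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Load:
    buffer: str
    scope: str     # simplified: "global", "shared", or "local"
    index: object  # a Var, an int, or another Load


@dataclass(frozen=True)
class IfThenElse:  # models an if_then_else(cond, a, b) expression
    cond: object
    then_value: object
    else_value: object


def uses_var(expr, var):
    """True if `var` appears anywhere inside `expr`."""
    if isinstance(expr, Var):
        return expr == var
    if isinstance(expr, Load):
        return uses_var(expr.index, var)
    if isinstance(expr, IfThenElse):
        parts = (expr.cond, expr.then_value, expr.else_value)
        return any(uses_var(p, var) for p in parts)
    return False


def collect_loads(expr, loop_var, out):
    """Collect memory loads from a value expression under the rules above."""
    if isinstance(expr, IfThenElse):
        # The condition is skipped: it is not on the cast compute path.
        collect_loads(expr.then_value, loop_var, out)
        collect_loads(expr.else_value, loop_var, out)
    elif isinstance(expr, Load):
        # Indices are never traversed, so loads used only as indices are ignored.
        is_memory = expr.scope in ("global", "shared")
        if is_memory and uses_var(expr.index, loop_var):
            out.append(expr)  # loads such as b[0] fail uses_var and are skipped
    return out


vec = Var("vec")
idx_load = Load("idx", "global", vec)     # idx[vec], used only as an index
gathered = Load("b", "global", idx_load)  # b[idx[vec]]
scalar = Load("b", "global", 0)           # b[0], a broadcast
mask = Load("mask", "global", vec)        # mask[vec], used only as a condition
value = IfThenElse(mask, gathered, scalar)

loads = collect_loads(value, vec, [])
print([load.buffer for load in loads])
```

In this toy tree only `b[idx[vec]]` is collected: `idx[vec]` sits inside an index, `mask[vec]` sits inside a condition, and `b[0]` does not depend on the loop variable.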

tilelang.transform.decouple_type_cast.BindEnv¶
tilelang.transform.decouple_type_cast.normalize_flat_binds(stmt)¶

Return stmt with dominating flat Bind values substituted into uses.

Parameters:

stmt (tvm.tirx.Stmt)

Return type:

tvm.tirx.Stmt

tilelang.transform.decouple_type_cast.extract_if_condition(stmt)¶

Extract IfThenElse condition from statement if present.

Returns:

A tuple of (condition, inner_body). If stmt is not an IfThenElse, returns (None, stmt).

Parameters:

stmt (tvm.tirx.Stmt)

Return type:

tuple[tvm.tirx.PrimExpr | None, tvm.tirx.Stmt]
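The return contract of extract_if_condition can be modeled with toy statement classes. This is a minimal sketch, assuming an IfThenElse with no else branch. The classes `Store` and `IfThenElse` and the function name `extract_if_condition_sketch` are hypothetical, and how the real function treats an else branch is not covered here.

```python
from dataclasses import dataclass


@dataclass
class Store:
    text: str  # toy statement, described by a string


@dataclass
class IfThenElse:
    condition: str  # toy condition, described by a string
    then_case: Store


def extract_if_condition_sketch(stmt):
    """Return (condition, inner_body), or (None, stmt) when there is no guard."""
    if isinstance(stmt, IfThenElse):
        return stmt.condition, stmt.then_case
    return None, stmt


guarded = IfThenElse("vec < n", Store("b[vec] = cast(a_frag[vec])"))
plain = Store("b[vec] = cast(a_frag[vec])")

cond_g, body_g = extract_if_condition_sketch(guarded)
cond_p, body_p = extract_if_condition_sketch(plain)
print(cond_g, cond_p)
```

Returning the pair in both situations lets a caller handle guarded and unguarded loop bodies the same way and test the condition against `None` only where the guard matters.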

tilelang.transform.decouple_type_cast.CastEntry¶
class tilelang.transform.decouple_type_cast.DecoupleTypeCastMutator¶

Bases: tvm.tirx.PyStmtExprMutator

Mutator that decouples type cast vectorization constraints.

This mutator transforms vectorized loops that have mixed-precision operations (detected by the presence of Cast nodes) by inserting local cache buffers as intermediate stages.

visit_for_(op)¶

Visit For nodes, transforming vectorized loops with mixed-type stores or loads.

Parameters:

op (tvm.tirx.For)

Return type:

tvm.tirx.Stmt

class tilelang.transform.decouple_type_cast.AccessReplacer(store_entries, load_entries, loop_var)¶

Bases: tvm.tirx.PyStmtExprMutator

Mutator to replace memory BufferStores/BufferLoads with cast buffer accesses.

Matches by both buffer and indices (structural equality) so that accesses like a[i] and a[i+32] from the same buffer map to different cast buffers.

Parameters:
  • store_entries (list[CastEntry])

  • load_entries (list[CastEntry])

  • loop_var (tvm.tirx.Var)

store_entries¶
load_entries¶
loop_var¶
visit_buffer_store_(op)¶
Parameters:

op (tvm.tirx.BufferStore)

Return type:

tvm.tirx.Stmt

visit_buffer_load_(op)¶
Parameters:

op (tvm.tirx.BufferLoad)

Return type:

tvm.tirx.PrimExpr
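Matching on both buffer and indices, as AccessReplacer is described as doing, can be sketched with a dictionary keyed by the pair. In this hypothetical model index expressions are nested tuples, so `==` compares them by structure. The names `cast_buffers`, `lookup_cast_buffer`, `cast_buf_0`, and `cast_buf_1` are illustrative and are not part of the module.

```python
# Toy index expressions as nested tuples, so equality is structural.
i = ("var", "i")
i_plus_32 = ("add", i, ("const", 32))

# Each (buffer, index) pair gets its own cast buffer.
cast_buffers = {
    ("a", i): "cast_buf_0",
    ("a", i_plus_32): "cast_buf_1",
}


def lookup_cast_buffer(buffer, index):
    """Return the cast buffer registered for this exact access, or None."""
    return cast_buffers.get((buffer, index))


# An index rebuilt from scratch still matches, because the structure is equal.
rebuilt = ("add", ("var", "i"), ("const", 32))

first = lookup_cast_buffer("a", i)
second = lookup_cast_buffer("a", rebuilt)
other = lookup_cast_buffer("c", i)
print(first, second, other)
```

Keying on the buffer alone would send `a[i]` and `a[i+32]` to the same cast buffer. Including the index in the key keeps them separate.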

tilelang.transform.decouple_type_cast.DecoupleTypeCast()¶

Create a TVM pass that decouples type cast vectorization constraints.

This pass inserts a local buffer as an intermediate stage for vectorized loops where the body contains Cast nodes (mixed-precision operations).

This allows optimal vectorization for both computation and memory access.

Note

This pass must be applied before the VectorizeLoop and StorageRewrite passes, while the IR still uses BufferLoad/BufferStore rather than tvm_access_ptr.

Returns:

A TVM PrimFunc pass.