tilelang.transform.decouple_type_cast

Decouple type cast vectorization constraints.

When a vectorized loop mixes precisions between local and memory buffers, a single vectorization length has to satisfy every dtype involved, so it is constrained to the GCD of the vector lengths that each dtype would allow on its own.

This pass decouples the constraints by inserting a local buffer as an intermediate stage, allowing optimal vectorization for both computation and memory access.
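The effect of the GCD constraint can be illustrated with a toy model. The sketch below is not the pass's actual logic: it assumes a hypothetical 128-bit limit per vector access (`MAX_VECTOR_BITS`), and the helper names `max_lanes`, `coupled_lanes`, and `decoupled_lanes` are invented for illustration.

```python
from math import gcd

# Assumption: one vector memory access moves at most 128 bits.
MAX_VECTOR_BITS = 128

# Bit widths of the dtypes used in the examples on this page.
DTYPE_BITS = {"float32": 32, "float4_e2m1fn": 4}


def max_lanes(dtype: str) -> int:
    """Largest number of `dtype` elements that fit in one vector access."""
    return MAX_VECTOR_BITS // DTYPE_BITS[dtype]


def coupled_lanes(dtypes, extent: int) -> int:
    """Vector length when one loop must satisfy every dtype at once."""
    lanes = extent
    for dtype in dtypes:
        lanes = gcd(lanes, max_lanes(dtype))
    return lanes


def decoupled_lanes(dtypes, extent: int) -> dict:
    """Vector length per dtype when each dtype gets its own loop."""
    return {dtype: gcd(extent, max_lanes(dtype)) for dtype in dtypes}


dtypes = ["float32", "float4_e2m1fn"]
coupled = coupled_lanes(dtypes, 16)      # both dtypes share one length
decoupled = decoupled_lanes(dtypes, 16)  # each dtype picks its own length
print(coupled, decoupled)
```

Under these assumptions the coupled loop is limited by the widest dtype, while the decoupled memory copy of the narrow dtype can use the full loop extent.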

Two cases are handled:

Case 1: local → memory (store to memory with mixed types)

Before:

    for vec in T.vectorized(16):
        b[vec] = T.cast(a_frag[vec], "float4_e2m1fn")

After:

    for vec in T.vectorized(16):
        cast_buf[vec] = T.cast(a_frag[vec], "float4_e2m1fn")  # compute
    for vec_copy in T.vectorized(16):
        b[vec_copy] = cast_buf[vec_copy]  # copy to memory

Case 2: memory → local (load from memory with different dtype)

Before:

    for vec in T.vectorized(16):
        a_frag[vec] = T.cast(b[vec], "float32")

After:

    for vec_copy in T.vectorized(16):
        cast_buf[vec_copy] = b[vec_copy]  # copy from memory
    for vec in T.vectorized(16):
        a_frag[vec] = T.cast(cast_buf[vec], "float32")  # compute
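Both rewrites only add an intermediate buffer, so the values that end up in the destination should not change. The following plain-Python simulation checks this for the two cases. Lists stand in for buffers, and the `cast` helper is a hypothetical stand-in for `T.cast` (it converts to `int` or `float`, not to the dtypes shown above).

```python
def cast(value, dtype):
    """Stand-in for T.cast: 'int' truncates, anything else converts to float."""
    return int(value) if dtype == "int" else float(value)


N = 16

# Case 1: local -> memory
a_frag = [i + 0.5 for i in range(N)]  # local fragment (wide dtype)

b_before = [0] * N
for vec in range(N):  # original single loop
    b_before[vec] = cast(a_frag[vec], "int")

cast_buf = [0] * N
b_after = [0] * N
for vec in range(N):  # compute loop writes to the cast buffer
    cast_buf[vec] = cast(a_frag[vec], "int")
for vec_copy in range(N):  # copy loop writes to memory
    b_after[vec_copy] = cast_buf[vec_copy]

# Case 2: memory -> local
b = list(range(N))  # memory buffer (narrow dtype)

frag_before = [0.0] * N
for vec in range(N):  # original single loop
    frag_before[vec] = cast(b[vec], "float")

cast_buf2 = [0] * N
frag_after = [0.0] * N
for vec_copy in range(N):  # copy loop reads from memory
    cast_buf2[vec_copy] = b[vec_copy]
for vec in range(N):  # compute loop reads from the cast buffer
    frag_after[vec] = cast(cast_buf2[vec], "float")

print(b_before == b_after, frag_before == frag_after)
```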

Attributes

CastBufferMap

Classes

MixedTypeChecker

Check if expression contains BufferLoads with different dtypes, skipping indices.

GlobalSharedBufferLoadCollector

Collect BufferLoads from global/shared buffers, skipping if_then_else conditions.

StoreCollector

Collect BufferStore nodes that need transformation, skipping indices traversal.

DecoupleTypeCastMutator

Mutator that decouples type cast vectorization constraints.

StoreReplacer

Mutator to replace memory BufferStores with cast buffer BufferStores.

LoadReplacer

Mutator to replace memory BufferLoads with cast buffer BufferLoads.

Functions

is_local_buffer(buffer)

Check if a buffer is local (register-level), including local.var.

is_global_or_shared_buffer(buffer)

Check if a buffer is a global or shared buffer.

validate_buffer_scope(buffer)

Validate that buffer has a known scope.

has_mixed_types(expr, target_dtype)

Check if expression contains BufferLoads with different dtypes than target.

get_global_or_shared_buffer_loads(expr[, ...])

Get BufferLoads from global/shared buffers in the expression.

has_global_or_shared_load_with_different_dtype(expr, ...)

Check if expression has global/shared BufferLoad with different dtype than target.

contains_seq_stmt(stmt)

Check if statement contains SeqStmt (multiple statements).

extract_if_condition(stmt)

Extract IfThenElse condition from statement if present.

DecoupleTypeCast()

Create a TVM pass that decouples type cast vectorization constraints.

Module Contents

tilelang.transform.decouple_type_cast.is_local_buffer(buffer)

Check if a buffer is local (register-level), including local.var.

Parameters:

buffer (tvm.tir.Buffer)

Return type:

bool

tilelang.transform.decouple_type_cast.is_global_or_shared_buffer(buffer)

Check if a buffer is a global or shared buffer.

Parameters:

buffer (tvm.tir.Buffer)

Return type:

bool

tilelang.transform.decouple_type_cast.validate_buffer_scope(buffer)

Validate that buffer has a known scope.

Raises:

ValueError – If buffer scope is unknown or empty.

Parameters:

buffer (tvm.tir.Buffer)

Return type:

None
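The three scope helpers above can be pictured as string checks on a buffer's storage scope. The sketch below works on plain scope strings, not on `tvm.tir.Buffer` objects, and the rules it uses (a `local` or `shared` prefix, the name `global`) are assumptions about how scopes are classified, not the module's actual implementation. The function names are hypothetical.

```python
def is_local_scope(scope: str) -> bool:
    """Register-level scopes such as 'local' or 'local.var' (assumed rule)."""
    return scope == "local" or scope.startswith("local.")


def is_global_or_shared_scope(scope: str) -> bool:
    """Memory scopes such as 'global', 'shared', 'shared.dyn' (assumed rule)."""
    return scope == "global" or scope == "shared" or scope.startswith("shared.")


def validate_scope(scope: str) -> None:
    """Raise ValueError for an empty or unrecognized scope string."""
    if not scope or not (is_local_scope(scope) or is_global_or_shared_scope(scope)):
        raise ValueError(f"Unknown buffer scope: {scope!r}")


for scope in ("local.var", "shared.dyn", "global"):
    validate_scope(scope)  # none of these raise
    print(scope, is_local_scope(scope), is_global_or_shared_scope(scope))
```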

class tilelang.transform.decouple_type_cast.MixedTypeChecker(target_dtype)

Bases: tvm.tir.PyStmtExprVisitor

Check if expression contains BufferLoads with different dtypes, skipping indices.

Parameters:

target_dtype (str)

target_dtype = ''

found_different = False

visit_buffer_load_(op)
Parameters:

op (tvm.tir.BufferLoad)

Return type:

None

tilelang.transform.decouple_type_cast.has_mixed_types(expr, target_dtype)

Check if expression contains BufferLoads with different dtypes than target.

If any BufferLoad in the expression has a dtype different from the target (the store buffer’s dtype), vectorization may be constrained by the GCD of the vector lengths of all dtypes involved.

Parameters:
  • expr (tvm.tir.PrimExpr)

  • target_dtype (str)

Return type:

bool
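The idea of checking loads while skipping indices can be sketched on a tiny hand-rolled expression tree. The `Load`, `Cast`, and `Add` classes and the function `has_mixed_types_sketch` below are hypothetical stand-ins for the TIR nodes and are not part of tilelang or TVM.

```python
from dataclasses import dataclass, field


@dataclass
class Load:
    dtype: str
    indices: list = field(default_factory=list)  # index expressions


@dataclass
class Cast:
    dtype: str
    value: object


@dataclass
class Add:
    lhs: object
    rhs: object


def has_mixed_types_sketch(expr, target_dtype: str) -> bool:
    """True if a load in the value expression differs from target_dtype."""
    if isinstance(expr, Load):
        # Indices are deliberately not visited: they do not take part in the cast.
        return expr.dtype != target_dtype
    if isinstance(expr, Cast):
        return has_mixed_types_sketch(expr.value, target_dtype)
    if isinstance(expr, Add):
        return (has_mixed_types_sketch(expr.lhs, target_dtype)
                or has_mixed_types_sketch(expr.rhs, target_dtype))
    return False  # constants and variables carry no buffer dtype here


# Store value: cast a float32 load (indexed through an int32 load) to FP4.
mixed = Cast("float4_e2m1fn", Load("float32", indices=[Load("int32")]))
# Same dtype as the target; only the index load has a different dtype.
uniform = Add(Load("float16", indices=[Load("int32")]), Load("float16"))

print(has_mixed_types_sketch(mixed, "float4_e2m1fn"))
print(has_mixed_types_sketch(uniform, "float16"))
```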

class tilelang.transform.decouple_type_cast.GlobalSharedBufferLoadCollector(skip_if_then_else_cond=False)

Bases: tvm.tir.PyStmtExprVisitor

Collect BufferLoads from global/shared buffers, skipping if_then_else conditions.

The condition part of if_then_else doesn’t participate in type casting, so we skip collecting BufferLoads from there.

Parameters:

skip_if_then_else_cond (bool)

result: list[tvm.tir.BufferLoad] = []

skip_if_then_else_cond = False

visit_buffer_load_(op)
Parameters:

op (tvm.tir.BufferLoad)

Return type:

None

visit_call_(op)
Parameters:

op (tvm.tir.Call)

Return type:

None

tilelang.transform.decouple_type_cast.get_global_or_shared_buffer_loads(expr, skip_if_then_else_cond=False)

Get BufferLoads from global/shared buffers in the expression.

Parameters:
  • expr (tvm.tir.PrimExpr) – The expression to search.

  • skip_if_then_else_cond (bool) – If True, skip BufferLoads in if_then_else conditions, since they don’t participate in type casting.

Return type:

list[tvm.tir.BufferLoad]

tilelang.transform.decouple_type_cast.has_global_or_shared_load_with_different_dtype(expr, target_dtype)

Check if expression has global/shared BufferLoad with different dtype than target.

Used to detect memory→local cases where a cast buffer needs to be inserted. The if_then_else condition is skipped, since it doesn’t participate in type casting.

Parameters:
  • expr (tvm.tir.PrimExpr)

  • target_dtype (str)

Return type:

bool
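How the collector and this check fit together can be sketched with a toy expression tree. Everything below is hypothetical: the node classes, the scope rule in `is_memory_scope`, and the function names only model the behavior described above and are not the module's implementation.

```python
from dataclasses import dataclass


@dataclass
class Load:
    buffer: str
    scope: str
    dtype: str


@dataclass
class Cast:
    dtype: str
    value: object


@dataclass
class IfThenElse:  # models the if_then_else(cond, a, b) expression
    cond: object
    then_value: object
    else_value: object


def is_memory_scope(scope: str) -> bool:
    """Assumed rule: 'global' and anything starting with 'shared'."""
    return scope == "global" or scope.startswith("shared")


def collect_memory_loads(expr, skip_if_then_else_cond=False):
    """Collect loads from global/shared buffers, optionally skipping conditions."""
    if isinstance(expr, Load):
        return [expr] if is_memory_scope(expr.scope) else []
    if isinstance(expr, Cast):
        return collect_memory_loads(expr.value, skip_if_then_else_cond)
    if isinstance(expr, IfThenElse):
        parts = [expr.then_value, expr.else_value]
        if not skip_if_then_else_cond:
            parts.insert(0, expr.cond)  # the condition is visited first
        found = []
        for part in parts:
            found += collect_memory_loads(part, skip_if_then_else_cond)
        return found
    return []  # constants


def has_memory_load_with_different_dtype(expr, target_dtype: str) -> bool:
    loads = collect_memory_loads(expr, skip_if_then_else_cond=True)
    return any(load.dtype != target_dtype for load in loads)


mask = Load("mask", "global", "bool")
guarded = IfThenElse(mask, Cast("float32", Load("b", "global", "float16")), 0.0)
same = IfThenElse(mask, Load("b", "global", "float32"), 0.0)

print([l.buffer for l in collect_memory_loads(guarded)])
print([l.buffer for l in collect_memory_loads(guarded, True)])
```

In this model the `bool` mask load never triggers the check, because only loads in the value branches are compared against the target dtype.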

class tilelang.transform.decouple_type_cast.StoreCollector

Bases: tvm.tir.PyStmtExprVisitor

Collect BufferStore nodes that need transformation, skipping indices traversal.

This avoids visiting BufferLoad/BufferStore nodes inside indices, which don’t participate in the type casting transformation.

local_to_memory: list[tvm.tir.BufferStore] = []

memory_to_local: list[tvm.tir.BufferStore] = []

visit_buffer_store_(op)
Parameters:

op (tvm.tir.BufferStore)

Return type:

None

visit_buffer_load_(op)
Parameters:

op (tvm.tir.BufferLoad)

Return type:

None

tilelang.transform.decouple_type_cast.contains_seq_stmt(stmt)

Check if statement contains SeqStmt (multiple statements).

When the For body contains a SeqStmt, the transformation is more complex, so the optimization is currently skipped.

Parameters:

stmt (tvm.tir.Stmt)

Return type:

bool
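A recursive check of this kind might look like the sketch below. The statement classes are hypothetical stand-ins for the TIR nodes, and which node types the real function descends into is an assumption.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Store:  # a single BufferStore-like statement
    target: str


@dataclass
class SeqStmt:  # several statements in sequence
    stmts: list


@dataclass
class IfThenElse:
    cond: object
    then_case: object
    else_case: Optional[object] = None


@dataclass
class For:
    body: object


def contains_seq_stmt_sketch(stmt) -> bool:
    """True if a SeqStmt appears anywhere inside `stmt`."""
    if isinstance(stmt, SeqStmt):
        return True
    if isinstance(stmt, For):
        return contains_seq_stmt_sketch(stmt.body)
    if isinstance(stmt, IfThenElse):
        if contains_seq_stmt_sketch(stmt.then_case):
            return True
        return stmt.else_case is not None and contains_seq_stmt_sketch(stmt.else_case)
    return False  # leaf statements such as Store


single = For(IfThenElse("i < n", Store("b")))
multi = For(SeqStmt([Store("b"), Store("c")]))
print(contains_seq_stmt_sketch(single), contains_seq_stmt_sketch(multi))
```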

tilelang.transform.decouple_type_cast.extract_if_condition(stmt)

Extract IfThenElse condition from statement if present.

Returns:

A tuple of (condition, inner_body). If no IfThenElse, returns (None, stmt).

Parameters:

stmt (tvm.tir.Stmt)

Return type:

tuple[tvm.tir.PrimExpr | None, tvm.tir.Stmt]
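The documented return convention can be sketched as follows. The `IfThenElse` and `Store` classes are hypothetical stand-ins, and treating the then-branch as the inner body is an assumption. How the real function handles an else-branch is not stated here.

```python
from dataclasses import dataclass


@dataclass
class Store:
    target: str


@dataclass
class IfThenElse:
    cond: str
    then_case: object


def extract_if_condition_sketch(stmt):
    """Return (condition, inner_body), or (None, stmt) when there is no guard."""
    if isinstance(stmt, IfThenElse):
        return stmt.cond, stmt.then_case
    return None, stmt


store = Store("b")
guarded_result = extract_if_condition_sketch(IfThenElse("vec < n", store))
plain_result = extract_if_condition_sketch(store)
print(guarded_result, plain_result)
```

Returning `None` for the condition lets a caller handle guarded and unguarded loop bodies through the same code path.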

tilelang.transform.decouple_type_cast.CastBufferMap

class tilelang.transform.decouple_type_cast.DecoupleTypeCastMutator

Bases: tvm.tir.PyStmtExprMutator

Mutator that decouples type cast vectorization constraints.

This mutator transforms vectorized loops that store to memory buffers (global/shared) with mixed-precision expressions by inserting local cache buffers as intermediate stages.

visit_for_(op)

Visit For nodes, transforming vectorized loops with mixed-type stores.

Parameters:

op (tvm.tir.For)

Return type:

tvm.tir.Stmt

class tilelang.transform.decouple_type_cast.StoreReplacer(cast_buffers, loop_var)

Bases: tvm.tir.PyStmtExprMutator

Mutator to replace memory BufferStores with cast buffer BufferStores.

Parameters:
  • cast_buffers (CastBufferMap)

  • loop_var (tvm.tir.Var)

cast_buffers

loop_var

visit_buffer_store_(op)
Parameters:

op (tvm.tir.BufferStore)

Return type:

tvm.tir.Stmt

class tilelang.transform.decouple_type_cast.LoadReplacer(cast_buffers, loop_var)

Bases: tvm.tir.PyStmtExprMutator

Mutator to replace memory BufferLoads with cast buffer BufferLoads.

Parameters:
  • cast_buffers (CastBufferMap)

  • loop_var (tvm.tir.Var)

cast_buffers

loop_var

visit_buffer_load_(op)
Parameters:

op (tvm.tir.BufferLoad)

Return type:

tvm.tir.PrimExpr
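The replacement step can be pictured as a small tree rewrite. The sketch below is speculative: it assumes the cast-buffer map behaves like a dictionary from memory buffer name to cast buffer name, and that the cast buffer is indexed directly by the loop variable, as in the `cast_buf[vec]` examples at the top of this page. The node classes and `replace_loads` are hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Load:
    buffer: str
    index: str


@dataclass(frozen=True)
class Cast:
    dtype: str
    value: object


def replace_loads(expr, cast_buffers: dict, loop_var: str):
    """Rewrite loads from mapped memory buffers into loads from cast buffers."""
    if isinstance(expr, Load):
        if expr.buffer in cast_buffers:
            # Assumed: the cast buffer is indexed by the loop variable only.
            return Load(cast_buffers[expr.buffer], loop_var)
        return expr  # loads from other buffers are left untouched
    if isinstance(expr, Cast):
        return Cast(expr.dtype, replace_loads(expr.value, cast_buffers, loop_var))
    return expr


before = Cast("float32", Load("b", "i * 16 + vec"))
after = replace_loads(before, {"b": "cast_buf"}, "vec")
print(after)
```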

tilelang.transform.decouple_type_cast.DecoupleTypeCast()

Create a TVM pass that decouples type cast vectorization constraints.

This pass inserts a local buffer as an intermediate stage for vectorized stores to non-local buffers (global/shared) where the store value contains expressions with different dtypes.

This allows optimal vectorization for both computation and memory access.

Note

This pass must be applied before the VectorizeLoop and StorageRewrite passes, while the IR still uses BufferLoad/BufferStore (not tvm_access_ptr).

Returns:

A TVM PrimFunc pass.
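A minimal usage sketch follows, assuming `mod` is a `tvm.IRModule` that has not yet been through VectorizeLoop or StorageRewrite. The wrapper `apply_decouple_type_cast` is a hypothetical helper, and the import is guarded only so that the snippet loads in environments without tilelang.

```python
# Guarded import so the snippet still loads where tilelang is not installed.
try:
    from tilelang.transform.decouple_type_cast import DecoupleTypeCast
except ImportError:
    DecoupleTypeCast = None


def apply_decouple_type_cast(mod):
    """Run the pass on an IRModule and return the transformed module."""
    if DecoupleTypeCast is None:
        raise RuntimeError("tilelang is required to run DecoupleTypeCast")
    decouple = DecoupleTypeCast()  # build the pass object
    return decouple(mod)           # TVM passes are callable on an IRModule
```

In a lowering pipeline this call would sit ahead of the VectorizeLoop and StorageRewrite passes, in line with the note above.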