tilelang.transform.hoist_broadcast_values
Classes
HoistBroadcastValuesMutator
Functions
HoistBroadcastValues() | TVM Pass: HoistBroadcastValues.
Module Contents
- class tilelang.transform.hoist_broadcast_values.HoistBroadcastValuesMutator
Bases: tvm.tir.PyStmtExprMutator
- pending_defs = []
- hoist_enabled = False
- visit_broadcast_(op)
- visit_buffer_store_(op)
- Parameters:
op (tvm.tir.BufferStore)
- visit_let_stmt_(op)
- Parameters:
op (tvm.tir.LetStmt)
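To make the roles of pending_defs and hoist_enabled concrete, here is a minimal sketch on a hypothetical toy IR. The classes Imm, Var, Broadcast, Add, BufferStore, LetStmt and ToyHoistMutator are illustrative stand-ins, not TVM APIs. The sketch assumes, from the attribute names and the pass description only, that visit_buffer_store_ switches hoisting on, visit_broadcast_ swaps each immediate constant for a fresh variable and records it in pending_defs, and the store is then wrapped in one LetStmt per recorded definition. The real mutator also overrides visit_let_stmt_, whose behavior is not documented here, so the sketch leaves it out. It also uses simple counter names (bv_0, bv_1) instead of the value-derived names shown in the example.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union


# Hypothetical toy IR nodes, standing in for tvm.tir.IntImm/FloatImm, Var, etc.
@dataclass(frozen=True)
class Imm:
    value: Union[int, float]


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Broadcast:
    value: object  # Imm or Var
    lanes: int


@dataclass(frozen=True)
class Add:
    a: object
    b: object


@dataclass(frozen=True)
class BufferStore:
    buffer: str
    index: str
    value: object


@dataclass(frozen=True)
class LetStmt:
    var: Var
    value: Imm
    body: object


class ToyHoistMutator:
    """Hypothetical sketch of the pending_defs / hoist_enabled pattern."""

    def __init__(self):
        self.pending_defs: List[Tuple[Var, Imm]] = []
        self.hoist_enabled = False
        self._counter = 0

    def visit_expr(self, e):
        # Minimal recursive expression visitor for the toy IR.
        if isinstance(e, Broadcast):
            return self.visit_broadcast_(e)
        if isinstance(e, Add):
            return Add(self.visit_expr(e.a), self.visit_expr(e.b))
        return e

    def visit_broadcast_(self, op):
        # Hoist only immediate constants, and only while inside a store.
        if self.hoist_enabled and isinstance(op.value, Imm):
            var = Var(f"bv_{self._counter}")
            self._counter += 1
            self.pending_defs.append((var, op.value))
            return Broadcast(var, op.lanes)
        return op

    def visit_buffer_store_(self, op):
        # Collect definitions for this statement only, then restore outer state.
        saved_defs, saved_flag = self.pending_defs, self.hoist_enabled
        self.pending_defs, self.hoist_enabled = [], True
        stmt = BufferStore(op.buffer, op.index, self.visit_expr(op.value))
        # Wrap in reverse so the first recorded definition becomes the outermost LetStmt.
        for var, value in reversed(self.pending_defs):
            stmt = LetStmt(var, value, stmt)
        self.pending_defs, self.hoist_enabled = saved_defs, saved_flag
        return stmt
```

Saving and restoring the two attributes around each store keeps definitions from leaking between statements; whether the real class does the same is not stated in this reference.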
- tilelang.transform.hoist_broadcast_values.HoistBroadcastValues()
TVM Pass: HoistBroadcastValues.
This pass scans the TIR for Broadcast operations whose value is an immediate constant (IntImm, FloatImm). It hoists each such constant into a fresh variable bound by a LetStmt that immediately wraps the statement containing the broadcast. Each occurrence gets its own variable; identical constants are not merged.
Example Transformation:
- Before:
A[i] = B[i] + T.Broadcast(3.14, 4) + T.Broadcast(3.14, 4)
- After:
bv_3_14 = 3.14
bv_3_14_1 = 3.14
A[i] = B[i] + T.Broadcast(bv_3_14, 4) + T.Broadcast(bv_3_14_1, 4)
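The variable names in the After listing suggest a scheme of "bv_" plus the constant's text with "." replaced by "_", and a numeric suffix for repeats. This is an inference from the example alone; the helper below (hoisted_name is a hypothetical name, not part of tilelang) is a minimal sketch that reproduces those two names, and the real pass may instead rely on TVM's own name uniquification.

```python
import re
from typing import Dict, Union


def hoisted_name(value: Union[int, float], used: Dict[str, int]) -> str:
    """Hypothetical: derive a hoisted-variable name such as bv_3_14 from a constant."""
    # Replace characters that cannot appear in an identifier ('.', '-', '+') with '_'.
    base = "bv_" + re.sub(r"[^0-9A-Za-z_]", "_", repr(value))
    count = used.get(base, 0)
    used[base] = count + 1
    # The first use keeps the base name; later uses get _1, _2, ...
    # Collisions between different literals are not handled in this sketch.
    return base if count == 0 else f"{base}_{count}"
```

Calling it twice with 3.14 and a shared dictionary yields bv_3_14 and then bv_3_14_1, matching the two definitions above.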