tilelang.transform.add_bufstore_wrapper

Classes

FindVarUse

A Python StmtExprVisitor for defining a custom visitor over both Stmt and PrimExpr nodes.

AddWrapperForSingleStoreMutator

Adds a dummy parallel for loop to wrap a single buffer store.

Functions

AddWrapperForSingleBufStore()

Module Contents

class tilelang.transform.add_bufstore_wrapper.FindVarUse

Bases: tvm.tir.PyStmtExprVisitor

A Python StmtExprVisitor for defining a custom visitor over both Stmt and PrimExpr nodes.

Users can customize any of the visit functions.

used_var

visit_var_(op)

Visit Var.

Users can customize this function to override VisitVar_(const VarNode* op) on the C++ side.

Parameters:

op (Var) – The Var to be visited.
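The visitor pattern that FindVarUse follows can be sketched in plain Python without TVM. This is a toy model, not the TVM API: the node classes (Var, Add, BufferStore) and the base class ToyStmtExprVisitor are hypothetical stand-ins, and the assumption that FindVarUse records every visited Var in used_var is inferred from the attribute name rather than stated in this reference. The point it shows is that overriding a single hook (visit_var_) is enough, because the base class already recurses through the other node kinds.

```python
from dataclasses import dataclass


# Hypothetical toy IR nodes standing in for tvm.tir.Var, Add and BufferStore.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Add:
    a: object
    b: object


@dataclass(frozen=True)
class BufferStore:
    buffer: str
    value: object
    indices: tuple


class ToyStmtExprVisitor:
    """Dispatch each node to a visit_<kind>_ hook; default hooks just recurse."""

    def visit(self, node):
        if isinstance(node, Var):
            self.visit_var_(node)
        elif isinstance(node, Add):
            self.visit_add_(node)
        elif isinstance(node, BufferStore):
            self.visit_buffer_store_(node)
        # Leaves such as int constants have no children and are skipped.

    def visit_var_(self, op):
        pass

    def visit_add_(self, op):
        self.visit(op.a)
        self.visit(op.b)

    def visit_buffer_store_(self, op):
        self.visit(op.value)
        for index in op.indices:
            self.visit(index)


class ToyFindVarUse(ToyStmtExprVisitor):
    """Assumed behavior: record every Var reached during the traversal."""

    def __init__(self):
        self.used_var = set()

    def visit_var_(self, op):
        # Only this hook is overridden; traversal of other nodes is inherited.
        self.used_var.add(op)


# Usage: models the store A[threadIdx.x] = i + 1
i, tx = Var("i"), Var("threadIdx.x")
finder = ToyFindVarUse()
finder.visit(BufferStore("A", Add(i, 1), (tx,)))
print(sorted(v.name for v in finder.used_var))  # ['i', 'threadIdx.x']
```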

class tilelang.transform.add_bufstore_wrapper.AddWrapperForSingleStoreMutator

Bases: tvm.tir.PyStmtExprMutator

Adds a dummy parallel for loop to wrap a single buffer store. The wrapper is added only when both conditions hold:

  1. The store is not inside a parallel for loop.

  2. There is no custom thread binding (e.g. threadIdx.x, blockIdx.x).
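The two conditions can be sketched as a small pure-Python mutator. This is an illustration under stated assumptions, not TVM code and not necessarily how this module is written: the node classes are hypothetical stand-ins for TIR nodes, "thread_extent" is the TVM attribute key that introduces thread bindings, treating inside_pfor as a nesting counter is inferred from this class's inside_pfor = 0 attribute, and reading condition 2 as "the store does not reference a thread-binding variable" (leaving such stores unwrapped) is an assumption.

```python
from dataclasses import dataclass, replace


# Hypothetical toy IR; real TIR nodes carry more fields (dtype, annotations, ...).
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class BufferStore:
    buffer: str
    value: object
    indices: tuple


@dataclass(frozen=True)
class For:
    loop_var: Var
    extent: int
    kind: str  # "serial" or "parallel"
    body: object


@dataclass(frozen=True)
class AttrStmt:
    attr_key: str
    node: Var  # simplified: in TVM a thread_extent AttrStmt holds an IterVar
    body: object


def vars_in(expr):
    """Collect the Vars in a toy expression or index tuple (FindVarUse's role)."""
    if isinstance(expr, Var):
        return {expr}
    if isinstance(expr, tuple):
        return set().union(*(vars_in(e) for e in expr))
    return set()


class ToyAddWrapperMutator:
    def __init__(self):
        self.inside_pfor = 0  # number of enclosing parallel loops
        self.thread_binding_var = set()  # vars bound by thread_extent attrs

    def visit(self, stmt):
        if isinstance(stmt, AttrStmt):
            return self.visit_attr_stmt_(stmt)
        if isinstance(stmt, For):
            return self.visit_for_(stmt)
        if isinstance(stmt, BufferStore):
            return self.visit_buffer_store_(stmt)
        return stmt

    def visit_attr_stmt_(self, op):
        if op.attr_key == "thread_extent":
            self.thread_binding_var.add(op.node)
        return replace(op, body=self.visit(op.body))

    def visit_for_(self, op):
        is_parallel = int(op.kind == "parallel")
        self.inside_pfor += is_parallel
        body = self.visit(op.body)
        self.inside_pfor -= is_parallel  # restore depth on the way out
        return replace(op, body=body)

    def visit_buffer_store_(self, op):
        used = vars_in(op.value) | vars_in(op.indices)
        if self.inside_pfor == 0 and not (used & self.thread_binding_var):
            # Both conditions hold: wrap in a dummy parallel loop of extent 1.
            return For(Var("_"), 1, "parallel", op)
        return op


# Usage: a bare store gets wrapped; one already under a parallel loop does not.
store = BufferStore("A", 1, (Var("k"),))
print(ToyAddWrapperMutator().visit(store))
print(ToyAddWrapperMutator().visit(For(Var("i"), 8, "parallel", store)))
```

Using a counter rather than a boolean keeps the state correct for nested parallel loops: leaving an inner loop decrements the depth without clearing the outer one.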

inside_pfor = 0

thread_binding_var

visit_block_(op)

Visit Block. Users can customize this function to override VisitStmt_(const BlockNode* op) on the C++ side.

Parameters:

op (Block) – The Block to be visited.

Returns:

result – The mutated Stmt.

Return type:

Stmt

visit_attr_stmt_(op)

Visit AttrStmt. Users can customize this function to override VisitStmt_(const AttrStmtNode* op) on the C++ side.

Parameters:

op (AttrStmt) – The AttrStmt to be visited.

Returns:

result – The mutated Stmt.

Return type:

Stmt

visit_for_(op)

Visit For. Users can customize this function to override VisitStmt_(const ForNode* op) on the C++ side.

Parameters:

op (For) – The For to be visited.

Returns:

result – The mutated Stmt.

Return type:

Stmt

visit_buffer_store_(op)

Visit BufferStore. Users can customize this function to override VisitStmt_(const BufferStoreNode* op) on the C++ side.

Parameters:

op (BufferStore) – The BufferStore to be visited.

Returns:

result – The mutated Stmt.

Return type:

Stmt

tilelang.transform.add_bufstore_wrapper.AddWrapperForSingleBufStore()