tilelang.language.builtin

The language interface for tl programs.

Functions

create_list_of_mbarrier(*args)

Create a list of memory barrier handles.

get_mbarrier(*args)

Retrieve a memory barrier operation.

create_tma_descriptor(*args)

Create a Tensor Memory Accelerator (TMA) descriptor.

tma_load(*args)

Perform a Tensor Memory Accelerator (TMA) load operation.

fence_proxy_async(*args)

Create a fence for asynchronous proxy operations.

tma_store_arrive(*args)

Signal the arrival of a TMA store operation.

tma_store_wait(*args)

Wait for completion of TMA store operations.

set_max_nreg(reg_count, is_inc)

Set the maximum number of registers to use.

inc_max_nreg(reg_count)

Increase the maximum number of registers per thread.

dec_max_nreg(reg_count)

Decrease the maximum number of registers per thread.

no_set_max_nreg()

Disable the maximum register limit setting.

mbarrier_wait_parity(mbarrier, parity)

Wait on a memory barrier until the phase with the given parity has completed.

mbarrier_arrive(mbarrier)

Arrive at a memory barrier.

mbarrier_expect_tx(*args)

Set the expected transaction count for a memory barrier.

wait_wgmma(id)

Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.

barrier_wait(barrier_id[, parity])

Wait for a memory barrier to complete.

barrier_arrive(barrier_id)

Arrive at a memory barrier.

shfl_xor(value, offset)

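Perform a warp shuffle in which each lane reads from lane (lane_id XOR offset).

shfl_down(value, offset)

Perform a warp shuffle in which each lane reads from lane (lane_id + offset).

shfl_up(value, offset)

Perform a warp shuffle in which each lane reads from lane (lane_id - offset).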

sync_threads()

Synchronize all threads in a block.

sync_thread_partial(barrier_id)

Synchronize a subset of the threads in a block at the barrier identified by barrier_id.

sync_global()

Synchronize all threads across thread blocks with a global barrier.

sync_grid()

Synchronize all threads in a grid.

Module Contents

tilelang.language.builtin.create_list_of_mbarrier(*args)

Create a list of memory barrier handles.

Parameters:

*args (list or Any) – Either a single list of arguments, or multiple arguments directly.

Returns:

Handle to the created list of memory barriers.

Return type:

tvm.tir.Call

Raises:

TypeError – If the input is not a list or variadic arguments.

Examples

>>> create_list_of_mbarrier([128, 128])
>>> create_list_of_mbarrier(128, 128)
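
Both calls above are equivalent. The plain-Python sketch below mirrors the documented argument handling (a single list, or the values passed directly). normalize_mbarrier_args is a hypothetical helper, not part of TileLang, and treating a call with no arguments as the TypeError case is an assumption.

```python
from typing import Any, Tuple


def normalize_mbarrier_args(*args: Any) -> Tuple[Any, ...]:
    """Hypothetical helper: reduce both documented calling styles to one tuple."""
    if len(args) == 1 and isinstance(args[0], list):
        # Single list argument: unpack its elements.
        return tuple(args[0])
    if len(args) >= 1:
        # Variadic form: use the arguments exactly as given.
        return args
    # Assumption: an empty call is the "not a list or variadic arguments" case.
    raise TypeError("expected a list or one or more arguments")


assert normalize_mbarrier_args([128, 128]) == normalize_mbarrier_args(128, 128)
```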

tilelang.language.builtin.get_mbarrier(*args)

Retrieve a memory barrier operation.

Parameters:

*args – Variable arguments to specify which memory barrier to retrieve

Returns:

A handle to the requested memory barrier

Return type:

tir.Call

tilelang.language.builtin.create_tma_descriptor(*args)

Create a Tensor Memory Accelerator (TMA) descriptor.

Parameters:

*args – Variable arguments defining the TMA descriptor configuration

Returns:

A handle to the created TMA descriptor

Return type:

tir.Call

tilelang.language.builtin.tma_load(*args)

Perform a Tensor Memory Accelerator (TMA) load operation.

Parameters:

*args – Variable arguments specifying the TMA load parameters

Returns:

A handle to the TMA load operation

Return type:

tir.Call

tilelang.language.builtin.fence_proxy_async(*args)

Create a fence for asynchronous proxy operations.

Parameters:

*args – Variable arguments for fence configuration

Returns:

A handle to the fence operation

Return type:

tir.Call

tilelang.language.builtin.tma_store_arrive(*args)

Signal the arrival of a TMA store operation.

Parameters:

*args – Variable arguments for the store arrival operation

Returns:

A handle to the store arrive operation

Return type:

tir.Call

tilelang.language.builtin.tma_store_wait(*args)

Wait for completion of TMA store operations.

Parameters:

*args – Variable arguments specifying which store operations to wait for

Returns:

A handle to the store wait operation

Return type:

tir.Call

tilelang.language.builtin.set_max_nreg(reg_count, is_inc)

Set the maximum number of registers to use. For details, see the PTX documentation for setmaxnreg: https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-setmaxnreg

Parameters:
  • reg_count (int) – The number of registers to allocate.

  • is_inc (int) – Whether to increase or decrease the register count: 1 to increment, 0 to decrement.

Returns:

A handle to the register setting operation

Return type:

tir.Call
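
In warp-specialized kernels, setmaxnreg typically lets a producer warp group release registers that consumer warp groups then claim. The plain-Python arithmetic below sketches that budget. The 65,536-register file per SM and the particular counts (168, 40, 232) are illustrative assumptions; the rule that a count must lie in [24, 256] and be a multiple of 8 comes from the PTX setmaxnreg specification.

```python
REGISTER_FILE_PER_SM = 65536  # assumption: 64K 32-bit registers per SM
WARP_GROUP_THREADS = 128      # 4 warps x 32 threads


def check_reg_count(reg_count: int) -> None:
    # PTX requires the setmaxnreg count to be in [24, 256] and a multiple of 8.
    if not (24 <= reg_count <= 256) or reg_count % 8 != 0:
        raise ValueError(f"invalid setmaxnreg count: {reg_count}")


def total_registers(per_group_counts: list) -> int:
    """Registers used when each warp group runs with the given per-thread limit."""
    for count in per_group_counts:
        check_reg_count(count)
    return sum(count * WARP_GROUP_THREADS for count in per_group_counts)


# Three warp groups launched with the same per-thread budget ...
before = total_registers([168, 168, 168])
# ... then the producer lowers its limit to 40 and both consumers raise theirs to 232.
after = total_registers([40, 232, 232])
assert before == after <= REGISTER_FILE_PER_SM
```

The redistribution only works because the total stays within the register file: registers freed by the producer are what make the consumers' increase possible.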

tilelang.language.builtin.inc_max_nreg(reg_count)

Increase the maximum number of registers per thread to reg_count.

Parameters:

reg_count (int)

tilelang.language.builtin.dec_max_nreg(reg_count)

Decrease the maximum number of registers per thread to reg_count.

Parameters:

reg_count (int)

tilelang.language.builtin.no_set_max_nreg()

Disable the maximum register limit setting.

tilelang.language.builtin.mbarrier_wait_parity(mbarrier, parity)

Wait on a memory barrier until the phase with the given parity has completed.

Parameters:
  • mbarrier (Union[int, tvm.tir.PrimExpr, tvm.tir.Call]) – The memory barrier to wait on.

  • parity (Union[int, tvm.tir.Var]) – The parity value to wait for.

Examples

# Wait for parity 0 on barrier 0
T.mbarrier_wait_parity(0, 0)

# Wait for parity value in variable ko on barrier 1
T.mbarrier_wait_parity(1, ko)

# Wait using barrier handle
barrier = T.get_mbarrier(0)
T.mbarrier_wait_parity(barrier, 1)

# Common usage in pipelined kernels (producer and consumer normally run in
# different thread groups; they are shown in one loop here for brevity):
for ko in range(num_stages):
    # Producer waits for consumer to finish previous iteration
    T.mbarrier_wait_parity(1, (ko % 2) ^ 1)
    # Producer copies data
    T.copy(A_global, A_shared)
    # Producer signals data ready
    T.mbarrier_arrive(0)

    # Consumer waits for producer data
    T.mbarrier_wait_parity(0, ko % 2)
    # Consumer computes
    T.gemm(A_shared, B_shared, C_local)
    # Consumer signals completion
    T.mbarrier_arrive(1)

Returns:

A handle to the barrier wait operation

Return type:

tir.Call
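
The parities in the pipelined example follow from mbarrier phase semantics: a barrier completes a phase once all expected arrivals have occurred and then flips a 1-bit phase, and a wait on parity p returns once the phase with parity p has completed. A freshly initialized barrier therefore lets a wait on parity 1 through immediately, which is why the producer's first wait does not block. The single-threaded model below is a sketch of those semantics, not TileLang code; MBarrierModel and run_pipeline are hypothetical names, with one barrier in the "data ready" role of barrier 0 and one in the "buffer free" role of barrier 1.

```python
class MBarrierModel:
    """Toy mbarrier: an expected arrival count plus a 1-bit phase."""

    def __init__(self, arrive_count: int):
        self.arrive_count = arrive_count
        self.pending = arrive_count
        self.phase = 0  # parity of the phase currently in progress

    def arrive(self) -> None:
        self.pending -= 1
        if self.pending == 0:
            # All arrivals seen: the phase completes, parity flips, count re-arms.
            self.phase ^= 1
            self.pending = self.arrive_count

    def test_wait_parity(self, parity: int) -> bool:
        # True once the phase with this parity has completed, i.e. the
        # phase now in progress has the other parity.
        return self.phase != (parity & 1)


def run_pipeline(num_iters: int) -> list:
    ready = MBarrierModel(arrive_count=1)  # role of barrier 0: data ready
    free = MBarrierModel(arrive_count=1)   # role of barrier 1: buffer free
    log = []
    for ko in range(num_iters):
        # Producer: wait until the consumer has released the buffer.
        assert free.test_wait_parity((ko % 2) ^ 1)
        # The consumer would still block here: nothing produced yet for ko.
        assert not ready.test_wait_parity(ko % 2)
        log.append(f"produce {ko}")
        ready.arrive()
        # Consumer: data for iteration ko is now visible.
        assert ready.test_wait_parity(ko % 2)
        # The producer would block on the next iteration until we release.
        assert not free.test_wait_parity(((ko + 1) % 2) ^ 1)
        log.append(f"consume {ko}")
        free.arrive()
    return log


log = run_pipeline(4)
assert log[-1] == "consume 3"
```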

tilelang.language.builtin.mbarrier_arrive(mbarrier)

Arrive at a memory barrier.

Parameters:

mbarrier (Union[int, tvm.tir.PrimExpr, tvm.tir.Call]) – The memory barrier to arrive at.

tilelang.language.builtin.mbarrier_expect_tx(*args)

Set the expected transaction count for a memory barrier.

Parameters:

*args – Variable arguments specifying the expected transaction count

Returns:

A handle to the barrier expectation operation

Return type:

tir.Call
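
For asynchronous copies, PTX mbarriers track a transaction count in addition to the arrival count: expect-tx raises the number of bytes the current phase must wait for, the copy engine decrements it as data lands, and the phase completes only when both counters reach zero. The toy model below sketches that rule in plain Python. TxBarrierModel and its methods are hypothetical names, and this page does not specify how the *args of mbarrier_expect_tx map onto the model.

```python
class TxBarrierModel:
    """Toy mbarrier with both an arrival count and a transaction (byte) count."""

    def __init__(self, arrive_count: int):
        self.arrive_count = arrive_count
        self.pending_arrivals = arrive_count
        self.tx_count = 0
        self.phase = 0

    def expect_tx(self, nbytes: int) -> None:
        # More bytes must land before the current phase can complete.
        self.tx_count += nbytes

    def complete_tx(self, nbytes: int) -> None:
        # Done by the async copy engine as data arrives in shared memory.
        self.tx_count -= nbytes
        self._maybe_complete()

    def arrive(self) -> None:
        self.pending_arrivals -= 1
        self._maybe_complete()

    def _maybe_complete(self) -> None:
        # The phase completes only when both counters have drained.
        if self.pending_arrivals == 0 and self.tx_count == 0:
            self.phase ^= 1
            self.pending_arrivals = self.arrive_count


bar = TxBarrierModel(arrive_count=1)
bar.expect_tx(8192)    # e.g. a 64x64 fp16 tile = 64 * 64 * 2 bytes
bar.arrive()           # the issuing thread arrives ...
assert bar.phase == 0  # ... but the phase stays open while bytes are outstanding
bar.complete_tx(4096)
bar.complete_tx(4096)  # the last bytes land
assert bar.phase == 1
```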

tilelang.language.builtin.wait_wgmma(id)

Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.

Parameters:

id (int) – The id of the WGMMA operation to wait for.

Returns:

A handle to the WGMMA wait operation

Return type:

tir.Call

tilelang.language.builtin.barrier_wait(barrier_id, parity=None)

Wait for a memory barrier to complete.

Parameters:
  • barrier_id (Union[int, tvm.tir.PrimExpr, tvm.tir.Call]) – The memory barrier to wait on.

  • parity (Union[int, tvm.tir.Var, None]) – The parity value to wait for.

Returns:

A handle to the barrier wait operation

Return type:

tir.Call

The current implementation is syntactic sugar for mbarrier_wait_parity, since only parity values 0 and 1 are supported.

tilelang.language.builtin.barrier_arrive(barrier_id)

Arrive at a memory barrier.

Parameters:

barrier_id (Union[int, tvm.tir.PrimExpr, tvm.tir.Call]) – The memory barrier to arrive at.

tilelang.language.builtin.shfl_xor(value, offset)

Perform a warp shuffle in which each lane reads from lane (lane_id XOR offset).

Parameters:
  • value (Union[int, tvm.tir.PrimExpr, tvm.tir.Call]) – The value to shuffle.

  • offset (Union[int, tvm.tir.PrimExpr, tvm.tir.Call]) – The offset for the shuffle operation.

Returns:

A handle to the shuffle operation

Return type:

tir.Call
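
A typical use of shfl_xor is a butterfly all-reduce, after which every lane holds the warp-wide result. The plain-Python model below assumes a full 32-lane warp with all lanes active and CUDA __shfl_xor_sync semantics (lane i receives the value from lane i XOR offset); shfl_xor_model and warp_allreduce_sum are hypothetical names.

```python
WARP_SIZE = 32


def shfl_xor_model(lanes: list, offset: int) -> list:
    """Lane i receives the value held by lane i ^ offset (full 32-lane warp)."""
    return [lanes[i ^ offset] for i in range(WARP_SIZE)]


def warp_allreduce_sum(lanes: list) -> list:
    # Butterfly reduction over offsets 16, 8, 4, 2, 1.
    offset = WARP_SIZE // 2
    while offset > 0:
        partner = shfl_xor_model(lanes, offset)
        lanes = [a + b for a, b in zip(lanes, partner)]
        offset //= 2
    return lanes


result = warp_allreduce_sum(list(range(WARP_SIZE)))
assert result == [496] * WARP_SIZE  # every lane holds 0 + 1 + ... + 31
```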

tilelang.language.builtin.shfl_down(value, offset)

Perform a warp shuffle in which each lane reads from lane (lane_id + offset).

Parameters:
  • value (Union[int, tvm.tir.PrimExpr, tvm.tir.Call]) – The value to shuffle.

  • offset (Union[int, tvm.tir.PrimExpr, tvm.tir.Call]) – The offset for the shuffle operation.

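shfl_down is commonly used for a tree reduction in which only lane 0 ends up with the full result. The model below assumes a full 32-lane warp and CUDA __shfl_down_sync semantics, where the top offset lanes, whose source would lie past the end of the warp, keep their own value; shfl_down_model and warp_reduce_sum_to_lane0 are hypothetical names.

```python
WARP_SIZE = 32


def shfl_down_model(lanes: list, offset: int) -> list:
    """Lane i receives lane i + offset; the top `offset` lanes keep their own value."""
    return [lanes[i + offset] if i + offset < WARP_SIZE else lanes[i]
            for i in range(WARP_SIZE)]


def warp_reduce_sum_to_lane0(lanes: list) -> int:
    # Tree reduction over offsets 16, 8, 4, 2, 1.
    offset = WARP_SIZE // 2
    while offset > 0:
        shifted = shfl_down_model(lanes, offset)
        lanes = [a + b for a, b in zip(lanes, shifted)]
        offset //= 2
    # Only lane 0 is guaranteed to hold the full sum.
    return lanes[0]


assert warp_reduce_sum_to_lane0(list(range(WARP_SIZE))) == 496
```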

tilelang.language.builtin.shfl_up(value, offset)

Perform a warp shuffle in which each lane reads from lane (lane_id - offset).

Parameters:
  • value (Union[int, tvm.tir.PrimExpr, tvm.tir.Call]) – The value to shuffle.

  • offset (Union[int, tvm.tir.PrimExpr, tvm.tir.Call]) – The offset for the shuffle operation.

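shfl_up suits inclusive prefix sums (a Hillis-Steele scan). The model below assumes a full 32-lane warp and CUDA __shfl_up_sync semantics, where the lowest offset lanes keep their own value and therefore must not add the shuffled result; shfl_up_model and warp_inclusive_scan are hypothetical names.

```python
WARP_SIZE = 32


def shfl_up_model(lanes: list, offset: int) -> list:
    """Lane i receives lane i - offset; the lowest `offset` lanes keep their own value."""
    return [lanes[i - offset] if i >= offset else lanes[i]
            for i in range(WARP_SIZE)]


def warp_inclusive_scan(lanes: list) -> list:
    # Hillis-Steele scan over offsets 1, 2, 4, 8, 16.
    offset = 1
    while offset < WARP_SIZE:
        shifted = shfl_up_model(lanes, offset)
        # Lanes below `offset` got their own value back, so they must not add it.
        lanes = [lanes[i] + shifted[i] if i >= offset else lanes[i]
                 for i in range(WARP_SIZE)]
        offset *= 2
    return lanes


scan = warp_inclusive_scan([1] * WARP_SIZE)
assert scan == list(range(1, WARP_SIZE + 1))
```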

tilelang.language.builtin.sync_threads()

Synchronize all threads in a block.

tilelang.language.builtin.sync_thread_partial(barrier_id)

Synchronize a subset of the threads in a block at the barrier identified by barrier_id.

Parameters:

barrier_id (Union[int, tvm.tir.PrimExpr, tvm.tir.Call]) – The ID of the barrier used for synchronization.

Returns:

A handle to the synchronization operation

Return type:

tir.Call

tilelang.language.builtin.sync_global()

Synchronize all threads across thread blocks with a global barrier.

tilelang.language.builtin.sync_grid()

Synchronize all threads in a grid.