tilelang.language.builtin¶

The language interface for TileLang (tl) programs.

Functions¶

create_list_of_mbarrier(*args)

Create a list of memory barrier handles.

get_mbarrier(*args)

Retrieve a memory barrier operation.

create_tma_descriptor(*args)

Create a Tensor Memory Access (TMA) descriptor.

tma_load(*args)

Perform a Tensor Memory Access (TMA) load operation.

fence_proxy_async(*args)

Create a fence for asynchronous proxy operations.

tma_store_arrive(*args)

Signal the arrival of a TMA store operation.

tma_store_wait(*args)

Wait for completion of TMA store operations.

set_max_nreg(reg_count, is_inc)

Set the maximum number of registers to use.

inc_max_nreg(reg_count)

Increment the maximum number of registers to use.

dec_max_nreg(reg_count)

Decrement the maximum number of registers to use.

annotate_producer_reg_dealloc([reg_count])

Annotate the producer register deallocation count.

annotate_consumer_reg_alloc([reg_count])

Annotate the consumer register allocation count.

no_set_max_nreg()

Disable the maximum register limit setting.

disable_warp_group_reg_alloc()

Disable warp group register allocation.

mbarrier_wait_parity(mbarrier, parity)

Wait for memory barrier parity condition.

mbarrier_arrive(mbarrier)

Arrive at memory barrier.

mbarrier_expect_tx(*args)

Set expected transaction count for memory barrier.

warpgroup_arrive()

Signal warpgroup readiness for subsequent WGMMA operations.

warpgroup_commit_batch()

Commit the current warpgroup batch for WGMMA operations.

warpgroup_wait(num_mma)

Wait for completion of the specified warpgroup batch.

get_lane_idx([warp_size])

Return the logical lane index of the calling thread within a warp.

get_warp_idx_sync([warp_size])

Return the canonical warp index, assuming the warp's threads are converged.

get_warp_idx([warp_size])

Return the canonical warp index without synchronizing the warp.

get_warp_group_idx([warp_size, warps_per_group])

Return the canonical warp group index for the calling thread.

shuffle_elect(thread_extent)

Elect exactly one lane within a logical thread group.

wait_wgmma(id)

Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.

barrier_wait(barrier_id[, parity])

Wait for a memory barrier to complete.

barrier_arrive(barrier_id)

Arrive at a memory barrier.

shfl_xor(value, offset)

Perform a shuffle operation with XOR offset.

shfl_down(value, offset)

Perform a shuffle operation with down offset.

shfl_up(value, offset)

Perform a shuffle operation with up offset.

sync_threads([barrier_id, arrive_count])

Synchronize all threads in a block.

sync_global()

Synchronize all threads in the entire grid.

sync_grid()

Synchronize all threads in a grid.

initialize_descriptor(descriptor, start_address[, ...])

Initialize a memory descriptor with the given parameters.

increase_descriptor_offset(descriptor, offset)

Increase the offset of a memory descriptor.

loop_break()

Break out of the innermost loop.

cp_async_barrier_noinc(barrier_id)

Perform a PTX async-copy barrier arrive using cp.async.mbarrier.arrive.noinc.

Module Contents¶

tilelang.language.builtin.create_list_of_mbarrier(*args)¶

Create a list of memory barrier handles.

Parameters:

*args (list or Any) – Either a single list of arguments, or multiple arguments directly.

Returns:

Handle to the created list of memory barriers.

Return type:

tvm.tir.Call

Raises:

TypeError – If the input is not a list or variadic arguments.

Examples

>>> create_list_of_mbarrier([128, 128])
>>> create_list_of_mbarrier(128, 128)
tilelang.language.builtin.get_mbarrier(*args)¶

Retrieve a memory barrier operation.

Parameters:

*args – Variable arguments to specify which memory barrier to retrieve

Returns:

A handle to the requested memory barrier

Return type:

tir.Call

tilelang.language.builtin.create_tma_descriptor(*args)¶

Create a Tensor Memory Access (TMA) descriptor.

Parameters:

*args – Variable arguments defining the TMA descriptor configuration

Returns:

A handle to the created TMA descriptor

Return type:

tir.Call

tilelang.language.builtin.tma_load(*args)¶

Perform a Tensor Memory Access (TMA) load operation.

Parameters:

*args – Variable arguments specifying the TMA load parameters

Returns:

A handle to the TMA load operation

Return type:

tir.Call

tilelang.language.builtin.fence_proxy_async(*args)¶

Create a fence for asynchronous proxy operations.

Parameters:

*args – Variable arguments for fence configuration

Returns:

A handle to the fence operation

Return type:

tir.Call

tilelang.language.builtin.tma_store_arrive(*args)¶

Signal the arrival of a TMA store operation.

Parameters:

*args – Variable arguments for the store arrival operation

Returns:

A handle to the store arrive operation

Return type:

tir.Call

tilelang.language.builtin.tma_store_wait(*args)¶

Wait for completion of TMA store operations.

Parameters:

*args – Variable arguments specifying which store operations to wait for

Returns:

A handle to the store wait operation

Return type:

tir.Call

tilelang.language.builtin.set_max_nreg(reg_count, is_inc)¶

Set the maximum number of registers to use. Detailed Documentation: https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-setmaxnreg

Parameters:
  • reg_count (int) – The number of registers to allocate

  • is_inc (int) – Whether to increment or decrement the register count: 0 to decrement, 1 to increment

Returns:

A handle to the register setting operation

Return type:

tir.Call
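
Example

A hedged sketch of warp-specialized register reconfiguration: producer warps release registers while consumer warps claim a larger budget. The counts 24 and 240 mirror the defaults of annotate_producer_reg_dealloc and annotate_consumer_reg_alloc; tid and the branch structure are illustrative, not part of this API.

```python
# Sketch only: warp-specialized kernel body, tid = block-linear thread index.
if tid < 128:
    # Producer warps: give back registers they will not use
    T.set_max_nreg(24, 0)   # is_inc=0: decrement toward 24 registers
else:
    # Consumer warps: claim a larger register budget for MMA accumulators
    T.set_max_nreg(240, 1)  # is_inc=1: increment toward 240 registers
```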

tilelang.language.builtin.inc_max_nreg(reg_count)¶

Increment the maximum number of registers to use.

Parameters:

reg_count (int)

tilelang.language.builtin.dec_max_nreg(reg_count)¶

Decrement the maximum number of registers to use.

Parameters:

reg_count (int)

tilelang.language.builtin.annotate_producer_reg_dealloc(reg_count=24)¶

Annotate the producer register deallocation count.

Parameters:

reg_count (int)

tilelang.language.builtin.annotate_consumer_reg_alloc(reg_count=240)¶

Annotate the consumer register allocation count.

Parameters:

reg_count (int)
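
Example

A hedged sketch of how the two annotation helpers are typically paired in a warp-specialized pipeline. The counts shown are the documented defaults; where the calls sit in the kernel is illustrative.

```python
# Sketch only: annotate the register split between producer and consumer warps.
T.annotate_producer_reg_dealloc(24)   # producer warps drop to 24 registers
T.annotate_consumer_reg_alloc(240)    # consumer warps grow to 240 registers
```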

tilelang.language.builtin.no_set_max_nreg()¶

Disable the maximum register limit setting.

tilelang.language.builtin.disable_warp_group_reg_alloc()¶

Disable warp group register allocation.

tilelang.language.builtin.mbarrier_wait_parity(mbarrier, parity)¶

Wait for memory barrier parity condition.

Parameters:
  • mbarrier (int | tvm.tir.PrimExpr | tvm.tir.Call) – The memory barrier to wait on

  • parity (int | tvm.tir.Var) – The parity value to wait for

Examples

# Wait for parity 0 on barrier 0
T.mbarrier_wait_parity(0, 0)

# Wait for parity value in variable ko on barrier 1
T.mbarrier_wait_parity(1, ko)

# Wait using barrier handle
barrier = T.get_mbarrier(0)
T.mbarrier_wait_parity(barrier, 1)

# Common usage in pipelined kernels:
for ko in range(num_stages):
    # Producer waits for consumer to finish previous iteration
    T.mbarrier_wait_parity(1, ko ^ 1)
    # Producer copies data
    T.copy(A_global, A_shared)
    # Producer signals data ready
    T.mbarrier_arrive(0)

    # Consumer waits for producer data
    T.mbarrier_wait_parity(0, ko)
    # Consumer computes
    T.gemm(A_shared, B_shared, C_local)
    # Consumer signals completion
    T.mbarrier_arrive(1)
Returns:

A handle to the barrier wait operation

Return type:

tir.Call


tilelang.language.builtin.mbarrier_arrive(mbarrier)¶

Arrive at memory barrier.

Parameters:

mbarrier (int | tvm.tir.PrimExpr | tvm.tir.Call) – The memory barrier to arrive at

tilelang.language.builtin.mbarrier_expect_tx(*args)¶

Set expected transaction count for memory barrier.

Parameters:

*args – Variable arguments specifying the expected transaction count

Returns:

A handle to the barrier expectation operation

Return type:

tir.Call

tilelang.language.builtin.warpgroup_arrive()¶

Signal warpgroup readiness for subsequent WGMMA operations.

Returns:

A handle to the warpgroup arrive operation.

Return type:

tir.Call

tilelang.language.builtin.warpgroup_commit_batch()¶

Commit the current warpgroup batch for WGMMA operations.

Returns:

A handle to the warpgroup commit batch operation.

Return type:

tir.Call

tilelang.language.builtin.warpgroup_wait(num_mma)¶

Wait for completion of the specified warpgroup batch.

Parameters:

num_mma (int) – Identifier of the warpgroup MMA batch to wait on.

Returns:

A handle to the warpgroup wait operation.

Return type:

tir.Call
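
Example

A hedged sketch of the usual WGMMA fencing sequence. The gemm call and its operands are illustrative; the arrive → issue → commit → wait ordering follows the PTX wgmma model.

```python
T.warpgroup_arrive()                  # fence: WGMMA operands are ready
T.gemm(A_shared, B_shared, C_local)   # issue WGMMA operations (illustrative)
T.warpgroup_commit_batch()            # close the current batch of WGMMA ops
T.warpgroup_wait(0)                   # wait until at most 0 batches remain pending
```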

tilelang.language.builtin.get_lane_idx(warp_size=None)¶

Return the logical lane index of the calling thread within a warp.

Parameters:

warp_size (Optional[int, PrimExpr]) – Logical warp (or wavefront) size. Defaults to 32 on NVIDIA and 64 on AMD.

Return type:

tvm.tir.PrimExpr

Example

>>> lane = T.get_lane_idx()
>>> custom_lane = T.get_lane_idx(64)  # override warp size explicitly

Implementation Notes¶

Lowers to the CUDA helper tl::get_lane_idx(warp_size) defined in src/tl_templates/cuda/intrin.h, which computes the lane index from the linear thread id using the provided warp_size.

tilelang.language.builtin.get_warp_idx_sync(warp_size=None)¶

Return the canonical warp index, assuming the warp’s threads are converged.

Parameters:

warp_size (Optional[int, PrimExpr]) – Logical warp size used for the index calculation.

Return type:

tvm.tir.PrimExpr

Example

>>> warp = T.get_warp_idx_sync()
>>> custom_warp = T.get_warp_idx_sync(64)

Implementation Notes¶

Emits tl::get_warp_idx_sync(warp_size) which divides the block-linear thread id by warp_size, matching the semantics of CUTLASS’ canonical helpers.

tilelang.language.builtin.get_warp_idx(warp_size=None)¶

Return the canonical warp index without synchronizing the warp.

Parameters:

warp_size (Optional[int, PrimExpr]) – Logical warp size used for the index calculation.

Return type:

tvm.tir.PrimExpr

Example

>>> warp = T.get_warp_idx()
>>> custom_warp = T.get_warp_idx(64)

Implementation Notes¶

Lowers to tl::get_warp_idx(warp_size) which divides the block-linear thread id by the provided warp_size without requiring warp convergence.

tilelang.language.builtin.get_warp_group_idx(warp_size=None, warps_per_group=None)¶

Return the canonical warp group index for the calling thread.

Parameters:
  • warp_size (Optional[int, PrimExpr]) – Logical warp size to use (defaults to 32 on NVIDIA / 64 on AMD).

  • warps_per_group (Optional[int, PrimExpr]) – Number of warps per warp-group. Defaults to 4 on NVIDIA architectures.

Return type:

tvm.tir.PrimExpr

Example

>>> group = T.get_warp_group_idx()
>>> custom_group = T.get_warp_group_idx(32, 6)  # treat 6 warps as a group

Implementation Notes¶

Generates tl::get_warp_group_idx(warp_size, warps_per_group) which divides the block-linear thread id by warp_size * warps_per_group, matching the canonical ordering while allowing architecture-specific overrides.
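
The index arithmetic behind get_lane_idx, get_warp_idx, and get_warp_group_idx can be reproduced in plain Python. This is a sketch of the documented semantics only, assuming a block-linear thread id; the device helpers live in src/tl_templates/cuda/intrin.h.

```python
def lane_idx(tid, warp_size=32):
    # Lane index of the thread within its warp
    return tid % warp_size

def warp_idx(tid, warp_size=32):
    # Canonical warp index within the block
    return tid // warp_size

def warp_group_idx(tid, warp_size=32, warps_per_group=4):
    # Canonical warp-group index: divide by warp_size * warps_per_group
    return tid // (warp_size * warps_per_group)

# Thread 70 sits in lane 6 of warp 2, inside warp group 0.
print(lane_idx(70), warp_idx(70), warp_group_idx(70))
```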

tilelang.language.builtin.shuffle_elect(thread_extent)¶

Elect exactly one lane within a logical thread group.

Parameters:

thread_extent (int) – Size (in threads) of the group in which a single lane should be elected. Passing 0 elects a single lane in the entire thread block.

Return type:

tvm.tir.PrimExpr

Example

>>> is_leader = T.shuffle_elect(64)
>>> T.if_then_else(is_leader, do_leader_work(), T.evaluate(0))

Implementation Notes¶

Lowered to the CUDA helper tl::tl_shuffle_elect<thread_extent>() defined in src/tl_templates/cuda/intrin.h, which relies on cutlass::canonical_warp_idx_sync() and cute::elect_one_sync() (or __shfl_sync) to pick one lane per group.

tilelang.language.builtin.wait_wgmma(id)¶

Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.

Parameters:

id (int) – The id of the WGMMA operation to wait for

Returns:

A handle to the WGMMA wait operation

Return type:

tir.Call

tilelang.language.builtin.barrier_wait(barrier_id, parity=None)¶

Wait for a memory barrier to complete.

Parameters:
  • barrier_id (int | tvm.tir.PrimExpr | tvm.tir.Call) – The memory barrier to wait on

  • parity (int | tvm.tir.Var | None) – The parity value to wait for

Returns:

A handle to the barrier wait operation

Return type:

tir.Call

The current implementation is syntactic sugar for mbarrier_wait_parity, as only parity values 0 and 1 are supported.
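
Example

A hedged sketch of a consumer-side hand-off using barrier_wait and barrier_arrive. The barrier ids and the pipeline structure are illustrative; the parity alternates between 0 and 1 across iterations, matching the supported values.

```python
# Sketch only: barrier 0 signals "data ready", barrier 1 signals "buffer free".
for ko in range(num_stages):
    T.barrier_wait(0, ko % 2)            # wait until the producer flips parity
    T.gemm(A_shared, B_shared, C_local)  # consume the staged data
    T.barrier_arrive(1)                  # signal that the buffer can be reused
```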

tilelang.language.builtin.barrier_arrive(barrier_id)¶

Arrive at a memory barrier.

Parameters:

barrier_id (int | tvm.tir.PrimExpr | tvm.tir.Call) – The memory barrier to arrive at

tilelang.language.builtin.shfl_xor(value, offset)¶

Perform a shuffle operation with XOR offset.

Parameters:
  • value (int | tvm.tir.PrimExpr | tvm.tir.Call) – The value to shuffle

  • offset (int | tvm.tir.PrimExpr | tvm.tir.Call) – The offset for the shuffle operation

Returns:

A handle to the shuffle operation

Return type:

tir.Call
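
Example

The XOR variant is the building block of butterfly reductions. Below is a plain-Python simulation of a 32-lane warp sum, modeling shfl_xor(value, offset) as each lane reading from the lane whose index differs in the offset bit. This illustrates the exchange pattern only; it is not the device code.

```python
def warp_sum_via_shfl_xor(values):
    # Simulate a butterfly reduction over one 32-lane warp: after log2(32)
    # rounds, every lane holds the sum of all 32 input values.
    lanes = list(values)
    offset = 16
    while offset >= 1:
        # shfl_xor(value, offset): lane i reads the value held by lane i ^ offset
        lanes = [lanes[i] + lanes[i ^ offset] for i in range(32)]
        offset //= 2
    return lanes

result = warp_sum_via_shfl_xor(range(32))  # every lane ends with sum(range(32))
```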

tilelang.language.builtin.shfl_down(value, offset)¶

Perform a shuffle operation with down offset.

Parameters:
  • value (int | tvm.tir.PrimExpr | tvm.tir.Call) – The value to shuffle

  • offset (int | tvm.tir.PrimExpr | tvm.tir.Call) – The offset for the shuffle operation

tilelang.language.builtin.shfl_up(value, offset)¶

Perform a shuffle operation with up offset.

Parameters:
  • value (int | tvm.tir.PrimExpr | tvm.tir.Call) – The value to shuffle

  • offset (int | tvm.tir.PrimExpr | tvm.tir.Call) – The offset for the shuffle operation

tilelang.language.builtin.sync_threads(barrier_id=None, arrive_count=None)¶

Synchronize all threads in a block.

Parameters:
  • barrier_id (int)

  • arrive_count (int)

tilelang.language.builtin.sync_global()¶

Synchronize all threads in the entire grid.

tilelang.language.builtin.sync_grid()¶

Synchronize all threads in a grid.

tilelang.language.builtin.initialize_descriptor(descriptor, start_address, layout_type_=0, leading_byte_offset=0, stride_byte_offset=0)¶

Initialize a memory descriptor with the given parameters.

Parameters:
  • descriptor (Buffer) – The memory descriptor to initialize.

  • start_address (PrimExpr) – The starting address of the memory region.

  • layout_type_ (int, optional) – Layout type identifier. Defaults to 0.

  • leading_byte_offset (int, optional) – Leading byte offset. Defaults to 0.

  • stride_byte_offset (int, optional) – Stride byte offset. Defaults to 0.

Returns:

A handle representing the initialized descriptor.

Return type:

PrimExpr

tilelang.language.builtin.increase_descriptor_offset(descriptor, offset)¶

Increase the offset of a memory descriptor.

Parameters:
  • descriptor (PrimExpr) – The memory descriptor to modify.

  • offset (PrimExpr) – The offset value to increase.

Returns:

A handle representing the modified descriptor.

Return type:

PrimExpr

tilelang.language.builtin.loop_break()¶

Break out of the innermost loop.
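
Example

A hedged sketch of an early-exit scan using loop_break. The predicate, buffers, and loop construct are illustrative.

```python
# Sketch only: stop scanning once the first match is found.
for i in T.serial(N):
    if A_shared[i] == target:
        found[0] = i
        T.loop_break()
```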

tilelang.language.builtin.cp_async_barrier_noinc(barrier_id)¶

Perform a PTX async-copy barrier arrive using cp.async.mbarrier.arrive.noinc.

Parameters:

barrier_id (int | tvm.tir.PrimExpr | tvm.tir.Call)