tilelang.language.builtin
=========================

.. py:module:: tilelang.language.builtin

.. autoapi-nested-parse::

   The language interface for tl programs.


Functions
---------

.. autoapisummary::

   tilelang.language.builtin.create_list_of_mbarrier
   tilelang.language.builtin.get_mbarrier
   tilelang.language.builtin.create_tma_descriptor
   tilelang.language.builtin.tma_load
   tilelang.language.builtin.fence_proxy_async
   tilelang.language.builtin.tma_store_arrive
   tilelang.language.builtin.tma_store_wait
   tilelang.language.builtin.set_max_nreg
   tilelang.language.builtin.inc_max_nreg
   tilelang.language.builtin.dec_max_nreg
   tilelang.language.builtin.no_set_max_nreg
   tilelang.language.builtin.mbarrier_wait_parity
   tilelang.language.builtin.mbarrier_arrive
   tilelang.language.builtin.mbarrier_expect_tx
   tilelang.language.builtin.wait_wgmma
   tilelang.language.builtin.barrier_wait
   tilelang.language.builtin.barrier_arrive
   tilelang.language.builtin.shfl_xor
   tilelang.language.builtin.shfl_down
   tilelang.language.builtin.shfl_up
   tilelang.language.builtin.sync_threads
   tilelang.language.builtin.sync_thread_partial
   tilelang.language.builtin.sync_global
   tilelang.language.builtin.sync_grid


Module Contents
---------------

.. py:function:: create_list_of_mbarrier(*args)

   Create a list of memory barrier handles.

   :param \*args: Either a single list of arguments, or multiple arguments passed directly.
   :type \*args: list or Any

   :returns: Handle to the created list of memory barriers.
   :rtype: tvm.tir.Call

   :raises TypeError: If the input is neither a list nor variadic arguments.

   .. rubric:: Examples

   >>> create_list_of_mbarrier([128, 128])
   >>> create_list_of_mbarrier(128, 128)


.. py:function:: get_mbarrier(*args)

   Retrieve a memory barrier handle.

   :param \*args: Variable arguments specifying which memory barrier to retrieve

   :returns: A handle to the requested memory barrier
   :rtype: tir.Call


.. py:function:: create_tma_descriptor(*args)

   Create a Tensor Memory Accelerator (TMA) descriptor.

   :param \*args: Variable arguments defining the TMA descriptor configuration

   :returns: A handle to the created TMA descriptor
   :rtype: tir.Call


.. py:function:: tma_load(*args)

   Perform a Tensor Memory Accelerator (TMA) load operation.

   :param \*args: Variable arguments specifying the TMA load parameters

   :returns: A handle to the TMA load operation
   :rtype: tir.Call


.. py:function:: fence_proxy_async(*args)

   Create a fence for asynchronous proxy operations.

   :param \*args: Variable arguments for fence configuration

   :returns: A handle to the fence operation
   :rtype: tir.Call


.. py:function:: tma_store_arrive(*args)

   Signal the arrival of a TMA store operation.

   :param \*args: Variable arguments for the store arrival operation

   :returns: A handle to the store arrive operation
   :rtype: tir.Call


.. py:function:: tma_store_wait(*args)

   Wait for completion of TMA store operations.

   :param \*args: Variable arguments specifying which store operations to wait for

   :returns: A handle to the store wait operation
   :rtype: tir.Call


.. py:function:: set_max_nreg(reg_count, is_inc)

   Set the maximum number of registers to use.

   Detailed documentation:
   https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-setmaxnreg

   :param reg_count: The number of registers to allocate
   :type reg_count: int
   :param is_inc: Whether to increment or decrement the register count: 0 to decrement, 1 to increment
   :type is_inc: int

   :returns: A handle to the register setting operation
   :rtype: tir.Call


.. py:function:: inc_max_nreg(reg_count)

   Increment the maximum number of registers to use.


.. py:function:: dec_max_nreg(reg_count)

   Decrement the maximum number of registers to use.


.. py:function:: no_set_max_nreg()

   Disable the maximum register limit setting.


.. py:function:: mbarrier_wait_parity(mbarrier, parity)

   Wait for a memory barrier parity condition.

   :param mbarrier: The memory barrier to wait on
   :type mbarrier: int or PrimExpr
   :param parity: The parity value to wait for
   :type parity: int or Var

   .. rubric:: Examples

   .. code-block:: python

      # Wait for parity 0 on barrier 0
      T.mbarrier_wait_parity(0, 0)

      # Wait for parity value in variable ko on barrier 1
      T.mbarrier_wait_parity(1, ko)

      # Wait using barrier handle
      barrier = T.get_mbarrier(0)
      T.mbarrier_wait_parity(barrier, 1)

      # Common usage in pipelined kernels:
      for ko in range(num_stages):
          # Producer waits for consumer to finish previous iteration
          T.mbarrier_wait_parity(1, ko ^ 1)
          # Producer copies data
          T.copy(A_global, A_shared)
          # Producer signals data ready
          T.mbarrier_arrive(0)

          # Consumer waits for producer data
          T.mbarrier_wait_parity(0, ko)
          # Consumer computes
          T.gemm(A_shared, B_shared, C_local)
          # Consumer signals completion
          T.mbarrier_arrive(1)

   :returns: A handle to the barrier wait operation
   :rtype: tir.Call


.. py:function:: mbarrier_arrive(mbarrier)

   Arrive at a memory barrier.

   :param mbarrier: The memory barrier to arrive at
   :type mbarrier: int or PrimExpr


.. py:function:: mbarrier_expect_tx(*args)

   Set the expected transaction count for a memory barrier.

   :param \*args: Variable arguments specifying the expected transaction count

   :returns: A handle to the barrier expectation operation
   :rtype: tir.Call


.. py:function:: wait_wgmma(id)

   Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.

   :param id: The id of the WGMMA operation to wait for
   :type id: int

   :returns: A handle to the WGMMA wait operation
   :rtype: tir.Call


.. py:function:: barrier_wait(barrier_id, parity = None)

   Wait for a memory barrier to complete.

   :param barrier_id: The memory barrier to wait on
   :type barrier_id: int or PrimExpr
   :param parity: The parity value to wait for
   :type parity: int or Var, optional

   :returns: A handle to the barrier wait operation
   :rtype: tir.Call

   The current implementation is syntactic sugar for ``mbarrier_wait_parity``,
   since only parity values 0 and 1 are supported.


.. py:function:: barrier_arrive(barrier_id)

   Arrive at a memory barrier.

   :param barrier_id: The memory barrier to arrive at
   :type barrier_id: int or PrimExpr


.. py:function:: shfl_xor(value, offset)

   Perform a shuffle operation with an XOR offset.

   :param value: The value to shuffle
   :type value: int or PrimExpr
   :param offset: The offset for the shuffle operation
   :type offset: int or PrimExpr

   :returns: A handle to the shuffle operation
   :rtype: tir.Call


.. py:function:: shfl_down(value, offset)

   Perform a shuffle operation with a down offset.

   :param value: The value to shuffle
   :type value: int or PrimExpr


.. py:function:: shfl_up(value, offset)

   Perform a shuffle operation with an up offset.

   :param value: The value to shuffle
   :type value: int or PrimExpr


.. py:function:: sync_threads()

   Synchronize all threads in a thread block.


.. py:function:: sync_thread_partial(barrier_id)

   Synchronize threads within a warp.

   :param barrier_id: The memory barrier to synchronize
   :type barrier_id: int or PrimExpr

   :returns: A handle to the synchronization operation
   :rtype: tir.Call


.. py:function:: sync_global()

   Synchronize all threads in a block.


.. py:function:: sync_grid()

   Synchronize all threads in a grid.
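
The shuffle builtins documented on this page return ``tir.Call`` handles, and the entries do not spell out how lanes exchange data. As a rough illustration of the pattern an XOR shuffle is commonly used for, the pure-Python sketch below models a warp as a list of per-lane values and runs a butterfly all-reduce. It does not call TileLang. The helper names ``model_shfl_xor`` and ``model_warp_all_reduce_sum`` are hypothetical. The assumed semantics (lane ``i`` reads the value held by lane ``i ^ offset``) are borrowed from CUDA's ``__shfl_xor_sync``, not taken from this reference.

```python
def model_shfl_xor(values, offset):
    """Hypothetical model of an XOR shuffle over one warp.

    Assumption: lane i receives the value currently held by lane (i ^ offset).
    `values` holds one entry per lane; its length must be a power of two.
    """
    return [values[lane ^ offset] for lane in range(len(values))]


def model_warp_all_reduce_sum(values):
    """Butterfly all-reduce: after log2(width) steps every lane holds the total."""
    width = len(values)
    if width == 0 or width & (width - 1) != 0:
        raise ValueError("warp width must be a non-zero power of two")

    offset = width // 2
    while offset > 0:
        # Each lane adds the value held by its XOR partner at this distance.
        partner = model_shfl_xor(values, offset)
        values = [own + other for own, other in zip(values, partner)]
        offset //= 2
    return values


# One value per lane in a 32-lane warp: lane i starts with the value i.
lanes = list(range(32))
reduced = model_warp_all_reduce_sum(lanes)
print(reduced[0], reduced[31])
```

In this model every lane ends up with the same total, which is why the XOR variant is often preferred over a down-shuffle reduction when all lanes need the result. Whether a real kernel built with ``shfl_xor`` behaves the same way depends on the target's warp width and on how TileLang lowers the call.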