tilelang.tileop.base
====================

.. py:module:: tilelang.tileop.base


Classes
-------

.. autoapisummary::

   tilelang.tileop.base.GemmWarpPolicy


Module Contents
---------------

.. py:class:: GemmWarpPolicy

   Bases: :py:obj:`enum.IntEnum`


   Enumeration of GEMM warp partitioning policies.


   .. py:attribute:: Square
      :value: 0


   .. py:attribute:: FullRow
      :value: 1


   .. py:attribute:: FullCol
      :value: 2


   .. py:method:: is_square()

      Check whether the policy is a square partitioning.

      :returns: True if the policy is square, False otherwise.
      :rtype: bool


   .. py:method:: is_full_row()

      Check whether the policy is a full-row partitioning.

      :returns: True if the policy is full row, False otherwise.
      :rtype: bool


   .. py:method:: is_full_col()

      Check whether the policy is a full-column partitioning.

      :returns: True if the policy is full column, False otherwise.
      :rtype: bool


   .. py:method:: to_prime_factors(num)
      :staticmethod:


      Compute the prime factorization of a given number.

      :param num: The number to factorize.
      :type num: int

      :returns: A list of prime factors of the number.
      :rtype: list


   .. py:method:: compute_warp_partition(M, N, num_warps)

      Compute the warp partition (m_warp, n_warp) based on the given policy.

      :param M: The number of rows in the GEMM workload.
      :type M: int
      :param N: The number of columns in the GEMM workload.
      :type N: int
      :param num_warps: The total number of warps available.
      :type num_warps: int

      :returns: A tuple (m_warp, n_warp) representing the partitioning of warps.
      :rtype: tuple

      :raises ValueError: If the policy is invalid or the partitioning fails.
      :raises AssertionError: If M or N is not divisible by the required factor for FullRow or FullCol policies.


   .. py:method:: from_warp_partition(m_warp, n_warp)
      :classmethod:


      Determine the warp policy based on the given warp partitioning.

      :param m_warp: Number of warps in the row dimension.
      :type m_warp: int
      :param n_warp: Number of warps in the column dimension.
      :type n_warp: int

      :returns: The corresponding warp policy.
      :rtype: GemmWarpPolicy

      .. rubric:: Examples

      >>> GemmWarpPolicy.from_warp_partition(4, 1)  # All warps in rows
      GemmWarpPolicy.FullRow
      >>> GemmWarpPolicy.from_warp_partition(1, 4)  # All warps in columns
      GemmWarpPolicy.FullCol
      >>> GemmWarpPolicy.from_warp_partition(2, 2)  # Balanced distribution
      GemmWarpPolicy.Square
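
The following is a minimal, self-contained sketch of how the documented behaviour fits together. It is not the tilelang implementation. ``WarpPolicySketch`` is a hypothetical stand-in that copies only the member values, the predicate methods, a trial-division ``to_prime_factors``, and the mapping shown in the examples above. The handling of shapes the examples do not cover (for instance a 1 x 1 or 2 x 4 partition) is an assumption, and ``compute_warp_partition`` is deliberately left out because its algorithm is not described here.

```python
from enum import IntEnum


class WarpPolicySketch(IntEnum):
    """Hypothetical stand-in for GemmWarpPolicy; not the tilelang code."""

    Square = 0
    FullRow = 1
    FullCol = 2

    def is_square(self) -> bool:
        return self == WarpPolicySketch.Square

    def is_full_row(self) -> bool:
        return self == WarpPolicySketch.FullRow

    def is_full_col(self) -> bool:
        return self == WarpPolicySketch.FullCol

    @staticmethod
    def to_prime_factors(num: int) -> list:
        """Trial division: prime factors in ascending order, with multiplicity."""
        factors = []
        divisor = 2
        while divisor * divisor <= num:
            while num % divisor == 0:
                factors.append(divisor)
                num //= divisor
            divisor += 1
        if num > 1:
            # Whatever remains after dividing out small factors is itself prime.
            factors.append(num)
        return factors

    @classmethod
    def from_warp_partition(cls, m_warp: int, n_warp: int) -> "WarpPolicySketch":
        # Assumption: any shape that is not a single column or single row
        # of warps is classified as Square.
        if n_warp == 1 and m_warp > 1:
            return cls.FullRow
        if m_warp == 1 and n_warp > 1:
            return cls.FullCol
        return cls.Square


policy = WarpPolicySketch.from_warp_partition(4, 1)
factors = WarpPolicySketch.to_prime_factors(12)
print(policy.name, policy.is_full_row(), factors)
```

A prime factorization is a natural building block for splitting ``num_warps`` across two dimensions, since each factor can be assigned to either the row or the column side; whether tilelang uses it exactly this way is not stated in this reference.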