tilelang.primitives.gemm.base module
- class tilelang.primitives.gemm.base.GemmBaseParams(A: tvm.tir.buffer.Buffer, B: tvm.tir.buffer.Buffer, C: tvm.tir.buffer.Buffer, transpose_A: bool = False, transpose_B: bool = False, block_row_warps: Optional[int] = None, block_col_warps: Optional[int] = None, warp_row_tiles: Optional[int] = None, warp_col_tiles: Optional[int] = None, chunk: Optional[int] = None, policy: tilelang.primitives.gemm.base.GemmWarpPolicy = (<GemmWarpPolicy.Square: 0>,), k_pack: int = 1)
Bases: object
- A: Buffer
- B: Buffer
- C: Buffer
- block_col_warps: Optional[int] = None
- block_row_warps: Optional[int] = None
- chunk: Optional[int] = None
- property class_attributes
- get_warp_size() → int
- infer_block_partition(threads: Optional[int]) → None
Infer and set the block partition parameters (block_row_warps, block_col_warps, warp_row_tiles, warp_col_tiles, chunk) from the shapes of A and B. Any parameter that is not already specified is inferred automatically from threads.
- Parameters:
threads (Optional[int]) – The total number of threads in a block. Must be provided if any block partition parameter is not already set.
- Raises:
AssertionError – If threads is None but any block partition parameter is missing, or if A and B have inconsistent shapes for GEMM.
- k_pack: int = 1
- params_as_dict()
- policy: GemmWarpPolicy = (<GemmWarpPolicy.Square: 0>,)
- transpose_A: bool = False
- transpose_B: bool = False
- warp_col_tiles: Optional[int] = None
- warp_row_tiles: Optional[int] = None
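The inference that infer_block_partition describes can be sketched as a standalone function. This is an illustration, not the library's method: it assumes 32-thread warps, takes the warp split (block_row_warps, block_col_warps) as an input (in the library this presumably comes from the policy), and assumes chunk covers the full K extent. BlockPartition, WARP_SIZE, and the function signature are hypothetical.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

WARP_SIZE = 32  # assumption: 32-thread warps, as on NVIDIA GPUs


@dataclass
class BlockPartition:
    """Hypothetical container for the fields infer_block_partition sets."""
    block_row_warps: int
    block_col_warps: int
    warp_row_tiles: int
    warp_col_tiles: int
    chunk: int


def infer_block_partition(
    a_shape: Tuple[int, int],
    b_shape: Tuple[int, int],
    threads: Optional[int],
    warp_split: Tuple[int, int],
    transpose_A: bool = False,
    transpose_B: bool = False,
) -> BlockPartition:
    # A is (M, K), stored as (K, M) when transposed;
    # B is (K, N), stored as (N, K) when transposed.
    M, K = (a_shape[1], a_shape[0]) if transpose_A else a_shape
    K_b, N = (b_shape[1], b_shape[0]) if transpose_B else b_shape
    assert K == K_b, f"inconsistent GEMM shapes: K={K} vs K={K_b}"
    assert threads is not None, "threads is required to infer the partition"
    num_warps = threads // WARP_SIZE
    block_row_warps, block_col_warps = warp_split
    assert block_row_warps * block_col_warps == num_warps, "split must use every warp"
    assert M % block_row_warps == 0 and N % block_col_warps == 0
    return BlockPartition(
        block_row_warps=block_row_warps,
        block_col_warps=block_col_warps,
        warp_row_tiles=M // block_row_warps,  # rows of C handled by each warp
        warp_col_tiles=N // block_col_warps,  # columns of C handled by each warp
        chunk=K,  # assumption: the whole K extent of the tile is one chunk
    )
```

For a 128×32 A tile, a 32×64 B tile, and 128 threads split 2×2, each warp covers a 64×32 tile of C.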
- class tilelang.primitives.gemm.base.GemmWarpPolicy(value)
Bases: IntEnum
Enumeration for GEMM Warp Partitioning Policies.
- FullCol = 2
- FullRow = 1
- Square = 0
- compute_warp_partition(M, N, num_warps)
Compute the warp partition (m_warp, n_warp) based on the given policy.
- Parameters:
M (int) – The number of rows in the GEMM workload.
N (int) – The number of columns in the GEMM workload.
num_warps (int) – The total number of warps available.
- Returns:
A tuple (m_warp, n_warp) representing the partitioning of warps.
- Return type:
tuple
- Raises:
ValueError – If the policy is invalid or the partitioning fails.
AssertionError – If M or N is not divisible by the required factor for FullRow or FullCol policies.
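A minimal sketch of what compute_warp_partition returns, assuming FullRow and FullCol place every warp on a single dimension and Square aims for per-warp tiles that are as close to square as possible. The Square heuristic here (choose the divisor pair of num_warps that minimizes |M/m_warp − N/n_warp|) is a swapped-in stand-in, and the library's own heuristic may differ. WarpPolicy is a hypothetical mirror of the enum values documented above.

```python
from enum import IntEnum
from typing import Tuple


class WarpPolicy(IntEnum):
    """Hypothetical stand-in with the documented GemmWarpPolicy values."""
    Square = 0
    FullRow = 1
    FullCol = 2


def compute_warp_partition(policy: WarpPolicy, M: int, N: int, num_warps: int) -> Tuple[int, int]:
    if policy is WarpPolicy.FullRow:
        assert M % num_warps == 0, "M must be divisible by num_warps for FullRow"
        return num_warps, 1
    if policy is WarpPolicy.FullCol:
        assert N % num_warps == 0, "N must be divisible by num_warps for FullCol"
        return 1, num_warps
    if policy is WarpPolicy.Square:
        best = None  # (imbalance, m_warp, n_warp)
        for m_warp in range(1, num_warps + 1):
            if num_warps % m_warp:
                continue
            n_warp = num_warps // m_warp
            if M % m_warp or N % n_warp:
                continue  # every warp must receive an equal-sized tile
            # Each warp's tile is (M // m_warp) x (N // n_warp); prefer the squarest.
            imbalance = abs(M // m_warp - N // n_warp)
            if best is None or imbalance < best[0]:
                best = (imbalance, m_warp, n_warp)
        if best is None:
            raise ValueError(f"cannot split {num_warps} warps over a {M}x{N} block")
        return best[1], best[2]
    raise ValueError(f"unknown policy: {policy!r}")
```

With a 128×128 block and 4 warps, Square yields (2, 2), while FullRow yields (4, 1) and FullCol yields (1, 4).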
- classmethod from_warp_partition(m_warp: int, n_warp: int) → GemmWarpPolicy
Determine the warp policy based on the given warp partitioning.
- Parameters:
m_warp (int) – Number of warps in the row dimension.
n_warp (int) – Number of warps in the column dimension.
- Returns:
The corresponding warp policy.
- Return type:
GemmWarpPolicy
Examples
>>> GemmWarpPolicy.from_warp_partition(4, 1)  # All warps in rows
GemmWarpPolicy.FullRow
>>> GemmWarpPolicy.from_warp_partition(1, 4)  # All warps in columns
GemmWarpPolicy.FullCol
>>> GemmWarpPolicy.from_warp_partition(2, 2)  # Balanced distribution
GemmWarpPolicy.Square
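A minimal sketch of from_warp_partition that is consistent with the three documented examples: all warps in the row dimension map to FullRow, all warps in the column dimension map to FullCol, and any other split maps to Square. Treating a 1×1 split as Square is an assumption, and WarpPolicy is a hypothetical stand-in for GemmWarpPolicy.

```python
from enum import IntEnum


class WarpPolicy(IntEnum):
    """Hypothetical stand-in with the documented GemmWarpPolicy values."""
    Square = 0
    FullRow = 1
    FullCol = 2

    @classmethod
    def from_warp_partition(cls, m_warp: int, n_warp: int) -> "WarpPolicy":
        if m_warp > 1 and n_warp == 1:
            return cls.FullRow  # every warp stacked along the rows
        if m_warp == 1 and n_warp > 1:
            return cls.FullCol  # every warp stacked along the columns
        return cls.Square  # mixed split (assumption: 1x1 also counts as Square)
```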
- is_full_col() → bool
Check if the policy is a full column partitioning.
- Returns:
True if the policy is full column, False otherwise.
- Return type:
bool
- is_full_row() → bool
Check if the policy is a full row partitioning.
- Returns:
True if the policy is full row, False otherwise.
- Return type:
bool
- is_square() → bool
Check if the policy is a square partitioning.
- Returns:
True if the policy is square, False otherwise.
- Return type:
bool
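Because GemmWarpPolicy is an IntEnum, each predicate reduces to a member identity check, and members also compare equal to their integer values. A minimal sketch using a hypothetical stand-in class named WarpPolicy:

```python
from enum import IntEnum


class WarpPolicy(IntEnum):
    """Hypothetical stand-in with the documented GemmWarpPolicy values."""
    Square = 0
    FullRow = 1
    FullCol = 2

    def is_square(self) -> bool:
        return self is WarpPolicy.Square

    def is_full_row(self) -> bool:
        return self is WarpPolicy.FullRow

    def is_full_col(self) -> bool:
        return self is WarpPolicy.FullCol
```

Constructing a member from its integer value, as in WarpPolicy(2), returns FullCol, so the predicates also work on policies stored as plain integers.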
- static to_prime_factors(num)
Compute the prime factorization of a given number.
- Parameters:
num (int) – The number to factorize.
- Returns:
A list of prime factors of the number.
- Return type:
list
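A minimal trial-division sketch of to_prime_factors as a standalone function, returning the prime factors in ascending order with repeats. Returning an empty list for num = 1 is an assumption about edge-case behavior.

```python
from typing import List


def to_prime_factors(num: int) -> List[int]:
    factors = []
    divisor = 2
    # Only divisors up to the square root of the remaining value need testing.
    while divisor * divisor <= num:
        while num % divisor == 0:
            factors.append(divisor)
            num //= divisor
        divisor += 1
    # Whatever remains above 1 has no divisor up to its square root, so it is prime.
    if num > 1:
        factors.append(num)
    return factors
```

For example, a warp count of 8 factors as [2, 2, 2] and 12 as [2, 2, 3].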