tilelang.tileop.gemm.gemm_base¶
Classes¶
GemmBase: Base class for GEMM tile operators.
Module Contents¶
- class tilelang.tileop.gemm.gemm_base.GemmBase¶
Base class for GEMM tile operators.
Classifies the GEMM variant by the memory scopes of operands A and B (SS, SR, RS, TS, RR) and provides common property accessors for the underlying gemm_node IR node.
- gemm_node: tvm.ir.base.Node¶
- __post_init__()¶
- Return type:
None
- abstract infer_layout(target, thread_nums)¶
- Parameters:
target (tvm.target.Target)
thread_nums (int)
- abstract lower(layout_map, target, thread_bounds, thread_var, mbar_phase_expr=None)¶
- Parameters:
layout_map (dict)
target (tvm.target.Target)
thread_bounds (tvm.ir.Range)
thread_var (tvm.tirx.Var)
mbar_phase_expr (tvm.tirx.PrimExpr | None)
- is_gemm_sr()¶
Return True if A is in shared memory and B is in registers (SR variant).
- Return type:
bool
- is_gemm_rs()¶
Return True if A is in registers and B is in shared memory (RS variant).
- Return type:
bool
- is_gemm_ts()¶
Return True if A is in tensor memory and B is in shared memory (TS variant).
- Return type:
bool
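The two-letter variant names used by the is_gemm_* predicates can be read as "scope of A, then scope of B". The following is a minimal, self-contained sketch of that naming scheme only. The Scope enum and the classify_gemm helper are hypothetical names introduced for illustration; they are not part of TileLang, and the real class inspects the buffers held by gemm_node rather than an enum.

```python
from enum import Enum


class Scope(Enum):
    """Hypothetical operand locations; not TileLang's real scope strings."""
    SHARED = "S"    # shared memory
    REGISTER = "R"  # registers
    TENSOR = "T"    # tensor memory


# The five variants named in the class description.
KNOWN_VARIANTS = {"SS", "SR", "RS", "TS", "RR"}


def classify_gemm(a_scope: Scope, b_scope: Scope) -> str:
    """Return the two-letter tag: first letter for A, second for B."""
    tag = a_scope.value + b_scope.value
    if tag not in KNOWN_VARIANTS:
        raise ValueError(f"unsupported operand scopes: {tag}")
    return tag


print(classify_gemm(Scope.SHARED, Scope.REGISTER))  # SR
print(classify_gemm(Scope.TENSOR, Scope.SHARED))    # TS
```

In this sketch a combination outside the five listed variants (for example A in shared memory and B in tensor memory) raises an error. How the actual class treats such combinations is not stated in this reference.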
- property M: int¶
- Return type:
int
- property N: int¶
- Return type:
int
- property K: int¶
- Return type:
int
- property a_dtype¶
A operand dtype.
- property b_dtype¶
B operand dtype.
- property accum_dtype: str¶
- Return type:
str
- property chunk: int¶
- Return type:
int
- property A: tvm.tirx.Buffer¶
- Return type:
tvm.tirx.Buffer
- property B: tvm.tirx.Buffer¶
- Return type:
tvm.tirx.Buffer
- property C: tvm.tirx.Buffer¶
- Return type:
tvm.tirx.Buffer
- property ARegion¶
- property BRegion¶
- property CRegion¶
- property stride_A: int¶
- Return type:
int
- property stride_B: int¶
- Return type:
int
- property offset_A: int¶
- Return type:
int
- property offset_B: int¶
- Return type:
int
- property clear_accum: tvm.ir.PrimExpr¶
- Return type:
tvm.ir.PrimExpr
- property k_pack: int¶
- Return type:
int
- property wg_wait: int¶
- Return type:
int
- property policy: tilelang.tileop.base.GemmWarpPolicy¶
- Return type:
tilelang.tileop.base.GemmWarpPolicy
- property mbarptr: tvm.ir.PrimExpr¶
- Return type:
tvm.ir.PrimExpr
- property mbar: tvm.tirx.BufferLoad | None¶
- Return type:
tvm.tirx.BufferLoad | None
- property C_coords¶
- property SFARegion¶
- property SFBRegion¶
- property sf_k_start: tvm.ir.PrimExpr¶
- Return type:
tvm.ir.PrimExpr
- get_region_base_offsets(region)¶
Get the base offset (start index) for each dimension from a BufferRegion.
For example, if region is A_shared[ko % 2, 0:128, 0:64], this returns [ko % 2, 0, 0].
- Parameters:
region – BufferRegion object
- Returns:
List of PrimExpr representing the base offset for each dimension
- property A_base_offsets¶
Get base offsets for each dimension of A region.
- property B_base_offsets¶
Get base offsets for each dimension of B region.
- property C_base_offsets¶
Get base offsets for each dimension of C region.
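The base-offset extraction described for get_region_base_offsets can be sketched without TVM installed. The Range and BufferRegion dataclasses below are hypothetical stand-ins that mirror only the fields this example needs (a per-dimension min and extent), and the string "ko % 2" stands in for a symbolic PrimExpr. Taking the min of each range is an assumption about how the method works, chosen because it reproduces the documented A_shared example.

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass
class Range:
    """Stand-in for a per-dimension range: start index and length."""
    min: Any
    extent: Any


@dataclass
class BufferRegion:
    """Stand-in for a buffer region: a buffer name plus one Range per dimension."""
    buffer: str
    region: List[Range]


def get_region_base_offsets(region: BufferRegion) -> List[Any]:
    """Collect the start index of every dimension (assumed behaviour)."""
    return [r.min for r in region.region]


# Models A_shared[ko % 2, 0:128, 0:64]; a single index is a range of extent 1.
a_region = BufferRegion(
    buffer="A_shared",
    region=[Range("ko % 2", 1), Range(0, 128), Range(0, 64)],
)

print(get_region_base_offsets(a_region))  # ['ko % 2', 0, 0]
```

Under the same assumption, A_base_offsets, B_base_offsets, and C_base_offsets would correspond to applying this helper to ARegion, BRegion, and CRegion respectively.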