tilelang.tileop.gemm.gemm_mma

Classes

GemmMMA

GEMM tile operator lowered to MMA (tensor-core) instructions.

Module Contents

class tilelang.tileop.gemm.gemm_mma.GemmMMA

Bases: tilelang.tileop.gemm.gemm_base.GemmBase

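GEMM tile operator lowered to MMA (tensor-core) instructions.

The parent class GemmBase classifies the GEMM variant by the memory scopes of operands A and B (SS, SR, RS, TS, RR) and provides common property accessors for the underlying gemm_node IR node. In each code the first letter describes A and the second describes B; S denotes shared memory and R denotes registers.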
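To make the two-letter naming concrete, here is a minimal, hypothetical sketch of how a variant code could be derived from where each operand lives. The Scope enum and classify_gemm function are illustrative names, not TileLang APIs, and reading T as tensor memory is an assumption; the real classification works on the memory scopes of the A and B buffers.

```python
from enum import Enum


class Scope(Enum):
    """Hypothetical coarse placement of an operand (not a TileLang type)."""
    SHARED = "S"    # shared memory
    REGISTER = "R"  # registers / per-thread fragments
    TENSOR = "T"    # ASSUMED meaning of "T" in the TS variant


# Variant codes listed for GemmBase: first letter describes A, second describes B.
KNOWN_VARIANTS = {"SS", "SR", "RS", "TS", "RR"}


def classify_gemm(a: Scope, b: Scope) -> str:
    """Return the two-letter variant code for operands A and B."""
    code = a.value + b.value
    if code not in KNOWN_VARIANTS:
        raise ValueError(f"unsupported operand placement: {code}")
    return code


if __name__ == "__main__":
    print(classify_gemm(Scope.SHARED, Scope.REGISTER))  # SR
    print(classify_gemm(Scope.TENSOR, Scope.SHARED))    # TS
```

Combinations outside the five listed codes are rejected rather than guessed at.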

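infer_layout(target, thread_nums)

Infer buffer layouts for this operator on the given target and thread count.

Parameters:
  • target (tvm.target.Target) – the compilation target.

  • thread_nums (int) – the number of threads executing the operator.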
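For MMA-based GEMMs, operand and accumulator layouts generally depend on how the block's threads are split into warps across the M and N dimensions of the output tile. The helper below is a hypothetical sketch of one such split, not TileLang's warp policy: it assumes 32-thread warps and the 16x8 output shape of the m16n8k16 MMA instruction, and it prefers per-warp tiles that are as close to square as possible.

```python
def partition_warps(m: int, n: int, thread_nums: int, warp_size: int = 32,
                    mma_m: int = 16, mma_n: int = 8) -> tuple[int, int]:
    """Split thread_nums threads into an (m_warp, n_warp) grid over an m x n tile.

    Hypothetical sketch: every warp must own a whole number of mma_m x mma_n
    output blocks; the most nearly square per-warp tile wins, and ties go to
    the smallest m_warp.
    """
    if thread_nums % warp_size != 0:
        raise ValueError("thread_nums must be a multiple of the warp size")
    num_warps = thread_nums // warp_size
    best = None  # (score, m_warp, n_warp)
    for m_warp in range(1, num_warps + 1):
        if num_warps % m_warp != 0:
            continue  # m_warp * n_warp must use every warp exactly once
        n_warp = num_warps // m_warp
        if m % (m_warp * mma_m) != 0 or n % (n_warp * mma_n) != 0:
            continue  # per-warp tile would not be a whole number of MMA blocks
        score = abs(m // m_warp - n // n_warp)
        if best is None or score < best[0]:
            best = (score, m_warp, n_warp)
    if best is None:
        raise ValueError("no warp partition fits this tile")
    return best[1], best[2]


if __name__ == "__main__":
    print(partition_warps(128, 128, 128))  # (2, 2)
```

A real policy may weigh other factors (for example, operand reuse), so treat the squareness heuristic as one plausible choice.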

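lower(layout_map, target, thread_bounds, thread_var, mbar_phase_expr=None)

Lower the operator to TIR for the given target, using the supplied buffer layouts.

Parameters:
  • layout_map (dict) – the buffer layouts, as produced by layout inference.

  • target (tvm.target.Target) – the compilation target.

  • thread_bounds (tvm.ir.Range) – the range of thread indices executing the operator.

  • thread_var (tvm.tir.Var) – the thread index variable.

  • mbar_phase_expr (tvm.tir.PrimExpr | None) – optional mbarrier phase expression; defaults to None.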
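Whatever instructions lower emits, a GEMM tile operator accumulates the product of A and B into the accumulator C. As a CPU reference for that semantics (a sketch, not the generated code), the hypothetical tiled_gemm below walks the output in 16x8 blocks and the reduction in steps of 16, mirroring the m16n8k16 MMA shape; the function name and default tile sizes are assumptions.

```python
def tiled_gemm(a, b, c, tm=16, tn=8, tk=16):
    """Accumulate a @ b into c (lists of lists), one tm x tn x tk block at a time.

    Hypothetical CPU reference: each innermost block corresponds to one MMA
    instruction in a lowered kernel, but here it runs as scalar Python.
    """
    m, k = len(a), len(a[0])
    n = len(b[0])
    assert len(b) == k and len(c) == m and len(c[0]) == n, "shape mismatch"
    assert m % tm == 0 and n % tn == 0 and k % tk == 0, "tiles must divide the shape"
    for i0 in range(0, m, tm):          # block row of C
        for j0 in range(0, n, tn):      # block column of C
            for k0 in range(0, k, tk):  # reduction step
                # One "MMA": C[i0:i0+tm, j0:j0+tn] += A[i0:i0+tm, k0:k0+tk] @ B[k0:k0+tk, j0:j0+tn]
                for i in range(i0, i0 + tm):
                    for j in range(j0, j0 + tn):
                        acc = c[i][j]
                        for kk in range(k0, k0 + tk):
                            acc += a[i][kk] * b[kk][j]
                        c[i][j] = acc
    return c
```

Comparing a lowered kernel's output against a plain reference like this is a simple way to check that a tiling preserves the result.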

is_gemm_ss()

Return True if both A and B are in shared memory (SS variant).

Return type: bool

is_gemm_sr()

Return True if A is in shared memory and B is in registers (SR variant).

Return type: bool

is_gemm_rs()

Return True if A is in registers and B is in shared memory (RS variant).

Return type: bool

is_gemm_rr()

Return True if both A and B are in registers (RR variant).

Return type: bool
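Because warp-level MMA instructions read their A and B operands from registers, the variant determines which operands must first be staged from shared memory into registers. The sketch below dispatches on the four predicates documented above; FakeGemm and operands_to_stage are hypothetical stand-ins for illustration, not part of TileLang.

```python
from dataclasses import dataclass


@dataclass
class FakeGemm:
    """Hypothetical stand-in exposing the same is_gemm_* predicates."""
    a_in_shared: bool
    b_in_shared: bool

    def is_gemm_ss(self) -> bool:
        return self.a_in_shared and self.b_in_shared

    def is_gemm_sr(self) -> bool:
        return self.a_in_shared and not self.b_in_shared

    def is_gemm_rs(self) -> bool:
        return not self.a_in_shared and self.b_in_shared

    def is_gemm_rr(self) -> bool:
        return not self.a_in_shared and not self.b_in_shared


def operands_to_stage(gemm) -> list[str]:
    """Return the operands that must be loaded from shared memory into registers."""
    if gemm.is_gemm_ss():
        return ["A", "B"]
    if gemm.is_gemm_sr():
        return ["A"]
    if gemm.is_gemm_rs():
        return ["B"]
    if gemm.is_gemm_rr():
        return []  # both operands already live in registers
    raise ValueError("variant not handled by this sketch (e.g. TS)")
```

Any object providing the four predicates works with operands_to_stage, which keeps the dispatch independent of how the variant was detected.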