tilelang.carver.template.matmul module

class tilelang.carver.template.matmul.MatmulTemplate(M: Optional[int] = None, N: Optional[int] = None, K: Optional[int] = None, trans_A: bool = False, trans_B: bool = True, in_dtype: str = 'float16', out_dtype: str = 'float16', accum_dtype: str = 'float16', with_bias: bool = False)

Bases: BaseTemplate

A template for matrix multiplication (MatMul).

This class defines the computation for a matrix-matrix multiplication with configurable parameters such as transposition, data types, and bias addition.

M

Number of rows in matrix A and matrix C.

Type:

int

N

Number of columns in matrix B and matrix C.

Type:

int

K

Number of columns in matrix A and rows in matrix B (the shared reduction dimension).

Type:

int

trans_A

Whether to transpose matrix A before multiplication.

Type:

bool

trans_B

Whether to transpose matrix B before multiplication.

Type:

bool

in_dtype

Data type of the input matrices A and B.

Type:

str

out_dtype

Data type of the output matrix C.

Type:

str

accum_dtype

Data type used for accumulation.

Type:

str

with_bias

Whether to add a bias term.

Type:

bool
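The dimension and layout attributes above together determine how the operands are stored. The sketch below assumes the common convention that trans_A=True means A is stored as (K, M) and trans_B=True means B is stored as (N, K), and that the bias (when enabled) holds one value per output column. operand_shapes is a hypothetical helper for illustration, not part of tilelang.

```python
def operand_shapes(M, N, K, trans_A=False, trans_B=True, with_bias=False):
    """Hypothetical helper: storage shapes implied by the template parameters.

    Assumes trans_A=True stores A as (K, M) and trans_B=True stores B as (N, K).
    """
    shapes = {
        "A": (K, M) if trans_A else (M, K),
        "B": (N, K) if trans_B else (K, N),
        "C": (M, N),  # the output is always M x N
    }
    if with_bias:
        shapes["Bias"] = (N,)  # assumption: one bias value per output column
    return shapes
```

Under that assumption, the default trans_B=True means B is supplied in (N, K) layout, so operand_shapes(128, 256, 64) reports B as (256, 64).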

K: int = None
M: int = None
N: int = None
accum_dtype: str = 'float16'
property class_attributes

Returns the class attributes in dictionary form.

Returns:

Dictionary of class attributes.

Return type:

dict

get_hardware_aware_configs(arch: Optional[TileDevice] = None, topk: int = 10) → List[Hint]

Retrieves configurations optimized for the target hardware architecture.

Parameters:
  • arch (TileDevice, optional) – The target hardware architecture.

  • topk (int, optional) – Number of top-ranked configurations to consider. Defaults to 10.

Returns:

A list of optimization hints for the target hardware.

Return type:

List[Hint]
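To illustrate what the topk parameter controls, the following toy sketch enumerates candidate block-tile shapes, scores each one by arithmetic intensity, and keeps the topk best with heapq.nlargest. It is not carver's actual search: the candidate sizes, the 48 KiB shared-memory budget, the 2-byte element size, and the scoring rule are all assumptions, and toy_topk_tiles is a hypothetical function that returns plain tuples rather than Hint objects.

```python
import heapq
from itertools import product

# Assumed candidate sizes; carver's real search space is different.
TILE_SIZES = (16, 32, 64, 128, 256)
K_STEPS = (16, 32, 64)


def toy_topk_tiles(M, N, K, topk=10, elem_bytes=2, smem_limit=48 * 1024):
    """Toy stand-in for a hardware-aware search (not carver's cost model).

    Returns up to `topk` (BM, BN, BK) block tiles ranked by arithmetic
    intensity; ties are broken lexicographically in favor of larger sizes.
    """
    candidates = []
    for bm, bn, bk in product(TILE_SIZES, TILE_SIZES, K_STEPS):
        if M % bm or N % bn or K % bk:
            continue  # keep only tiles that evenly divide the problem
        # Bytes of the A tile (bm x bk) and B tile (bn x bk) staged per K-step.
        smem_bytes = (bm * bk + bn * bk) * elem_bytes
        if smem_bytes > smem_limit:
            continue  # would not fit in the assumed shared-memory budget
        flops = 2 * bm * bn * bk  # one multiply and one add per MAC
        candidates.append((flops / smem_bytes, bm, bn, bk))
    # nlargest returns the best candidates in descending order.
    return [(bm, bn, bk) for _, bm, bn, bk in heapq.nlargest(topk, candidates)]
```

For a 256 × 256 × 256 problem, toy_topk_tiles(256, 256, 256, topk=3) returns [(256, 256, 32), (256, 256, 16), (256, 128, 64)]; raising topk keeps more of the same ranked list rather than changing the ranking.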

in_dtype: str = 'float16'
initialize_function() → None

Defines and initializes the matrix multiplication computation.

This method creates placeholders for the input matrices, defines the matrix multiplication with TVM’s compute API, and optionally adds the bias and casts the result to the output data type.

Raises:

AssertionError – If M, N, or K is not a positive integer.
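The computation that initialize_function defines can be mirrored by a slow pure-Python reference, which is handy for checking results on small inputs. This sketch is not the TVM definition: reference_matmul and cast are hypothetical helpers, float16/float32 rounding is emulated with the struct module, the operand layouts assume trans_A=True stores A as (K, M) and trans_B=True stores B as (N, K), and adding the bias in accum_dtype before the final cast is also an assumption.

```python
import struct


def cast(value, dtype):
    """Round a number to the precision of `dtype` (emulated with struct).

    float16 and float32 are rounded; any other dtype name keeps full
    Python float (float64) precision.
    """
    fmt = {"float16": "<e", "float32": "<f"}.get(dtype)
    value = float(value)
    if fmt is None:
        return value
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


def reference_matmul(A, B, M, N, K, trans_A=False, trans_B=True,
                     accum_dtype="float16", out_dtype="float16", bias=None):
    """Hypothetical pure-Python reference for C = op(A) @ op(B) (+ bias)."""
    # Mirrors the documented check: M, N and K must be positive integers.
    for name, dim in (("M", M), ("N", N), ("K", K)):
        assert isinstance(dim, int) and dim > 0, f"{name} must be a positive integer"

    C = [[0.0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            acc = 0.0
            for k in range(K):
                # Assumed layouts: A is (K, M) if trans_A else (M, K);
                # B is (N, K) if trans_B else (K, N).
                a = A[k][i] if trans_A else A[i][k]
                b = B[j][k] if trans_B else B[k][j]
                # Inputs are cast to accum_dtype and each multiply-add is
                # rounded once to accum_dtype.
                acc = cast(acc + cast(a, accum_dtype) * cast(b, accum_dtype), accum_dtype)
            if bias is not None:
                # Assumption: bias is added in accum_dtype before the final cast.
                acc = cast(acc + bias[j], accum_dtype)
            C[i][j] = cast(acc, out_dtype)
    return C
```

For example, with A = [[1, 2], [3, 4]] and B stored as (N, K) = [[1, 0], [0, 1], [1, 1]], reference_matmul(A, B, 2, 3, 2) returns [[1.0, 2.0, 3.0], [3.0, 4.0, 7.0]].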

out_dtype: str = 'float16'
params_as_dict()

Returns the template parameters as a dictionary.

Returns:

Dictionary containing template parameter values.

Return type:

dict
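Assuming the template is a dataclass whose fields match its constructor arguments (the field listing on this page suggests so, but the source does not confirm it), the dictionary form can be sketched with dataclasses.asdict. MatmulParams is a hypothetical stand-in rather than the real MatmulTemplate, and the exact keys returned by params_as_dict may differ.

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class MatmulParams:
    """Hypothetical stand-in mirroring MatmulTemplate's constructor fields."""
    M: Optional[int] = None
    N: Optional[int] = None
    K: Optional[int] = None
    trans_A: bool = False
    trans_B: bool = True
    in_dtype: str = "float16"
    out_dtype: str = "float16"
    accum_dtype: str = "float16"
    with_bias: bool = False

    def params_as_dict(self) -> dict:
        # asdict copies every dataclass field into a plain dict.
        return asdict(self)
```

Because the keys match the constructor arguments in this sketch, MatmulParams(**params.params_as_dict()) rebuilds an equal object, which is convenient when configurations are cached or logged.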

trans_A: bool = False
trans_B: bool = True
with_bias: bool = False