tilelang.carver.template.flashattention module

class tilelang.carver.template.flashattention.FlashAttentionTemplate(_output_nodes: List[tilelang.carver.roller.node.OutputNode] = None, batch_size: int = 1, num_heads: int = 1, head_dim: int = 1, seq_length: int = 1, seq_kv_length: int = 1, is_causal: bool = False, in_dtype: str = 'float16', out_dtype: str = 'float16', accum_dtype: str = 'float16')

Bases: BaseTemplate

accum_dtype: str = 'float16'
batch_size: int = 1
property class_attributes

Returns the class attributes in dictionary form.

Returns:

Dictionary of class attributes.

Return type:

dict

get_hardware_aware_configs(arch: Optional[TileDevice] = None, topk: int = 10) -> List[Hint]

Retrieves optimized hardware-aware configurations.

Parameters:
  • arch (TileDevice, optional) – The target hardware architecture.

  • topk (int, optional) – Number of top configurations to consider.

Returns:

A list of optimization hints for hardware acceleration.

Return type:

List[Hint]
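A minimal usage sketch follows, using only the constructor keywords and methods listed on this page. The helper names `attention_kwargs` and `recommend_attention_hints` are hypothetical, the `arch` object is assumed to be a `TileDevice` supplied by the caller, and calling `initialize_function()` before requesting hints is an assumption about call order that may be redundant.

```python
from typing import Any, Dict, List


def attention_kwargs(seq_length: int, head_dim: int, causal: bool) -> Dict[str, Any]:
    """Hypothetical helper: build constructor keywords for a self-attention shape."""
    return {
        "batch_size": 1,
        "num_heads": 8,
        "head_dim": head_dim,
        "seq_length": seq_length,
        "seq_kv_length": seq_length,  # self-attention: keys/values match queries
        "is_causal": causal,
        "in_dtype": "float16",
        "out_dtype": "float16",
        "accum_dtype": "float16",
    }


def recommend_attention_hints(arch: Any, topk: int = 10) -> List[Any]:
    """Hypothetical helper: ask the template for its top-k hints on `arch`."""
    # Imported lazily so this module can be loaded without tilelang installed.
    from tilelang.carver.template.flashattention import FlashAttentionTemplate

    template = FlashAttentionTemplate(**attention_kwargs(1024, 64, True))
    # Assumption: the computation must be defined before hints are requested.
    template.initialize_function()
    return template.get_hardware_aware_configs(arch=arch, topk=topk)
```

Passing `arch` explicitly avoids relying on the `None` default, whose behaviour is not described here.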

head_dim: int = 1
in_dtype: str = 'float16'
initialize_function() -> None

Defines and initializes the attention computation.

This method sets up placeholders for the input tensors, expresses the matrix multiplications that make up attention using TVM’s compute API, and applies type casting where needed.

Raises:

AssertionError – If the shape parameters are not positive integers.
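For orientation, the result a FlashAttention kernel is expected to produce is standard scaled dot-product attention, softmax(Q Kᵀ / sqrt(head_dim)) V, with keys after the query position masked out when `is_causal` is set. The plain-Python sketch below illustrates that reference semantics for a single head. It is not the template's TVM expression, and `reference_attention` is a hypothetical name.

```python
import math
from typing import List

Matrix = List[List[float]]


def reference_attention(q: Matrix, k: Matrix, v: Matrix, is_causal: bool = False) -> Matrix:
    """Single-head softmax(Q K^T / sqrt(head_dim)) V in plain Python."""
    head_dim = len(q[0])
    scale = 1.0 / math.sqrt(head_dim)
    out = []
    for i, q_row in enumerate(q):
        # Scores of query i against every key; masked keys get -inf.
        scores = []
        for j, k_row in enumerate(k):
            if is_causal and j > i:
                scores.append(float("-inf"))
            else:
                scores.append(scale * sum(a * b for a, b in zip(q_row, k_row)))
        # Numerically stable softmax over the key axis.
        peak = max(scores)
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        # Weighted average of the value rows.
        out.append([
            sum(w * v_row[c] for w, v_row in zip(weights, v)) / total
            for c in range(len(v[0]))
        ])
    return out
```

With causal masking, the first query can only attend to the first key, so its output row equals the first value row.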

is_causal: bool = False
num_heads: int = 1
out_dtype: str = 'float16'
params_as_dict()

Returns the template parameters as a dictionary.

Returns:

Dictionary containing template parameter values.

Return type:

dict
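To illustrate the kind of dictionary such a method can return, here is a small sketch built on a stand-in dataclass with the same fields and defaults as the constructor signature above. `AttentionParams` is hypothetical and is not part of tilelang; the real method may include or omit other entries.

```python
from dataclasses import asdict, dataclass


@dataclass
class AttentionParams:
    """Hypothetical stand-in mirroring the template's documented fields."""

    batch_size: int = 1
    num_heads: int = 1
    head_dim: int = 1
    seq_length: int = 1
    seq_kv_length: int = 1
    is_causal: bool = False
    in_dtype: str = "float16"
    out_dtype: str = "float16"
    accum_dtype: str = "float16"

    def params_as_dict(self) -> dict:
        # asdict maps each dataclass field name to its current value.
        return asdict(self)


params = AttentionParams(
    batch_size=4,
    num_heads=8,
    head_dim=64,
    seq_length=1024,
    seq_kv_length=1024,
    is_causal=True,
)
as_dict = params.params_as_dict()
```

A dictionary of this form is convenient for logging a tuning run or for using the shape as a cache key.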

seq_kv_length: int = 1
seq_length: int = 1