tilelang.carver.template.flashattention module
- class tilelang.carver.template.flashattention.FlashAttentionTemplate(_output_nodes: List[tilelang.carver.roller.node.OutputNode] = None, batch_size: int = 1, num_heads: int = 1, head_dim: int = 1, seq_length: int = 1, seq_kv_length: int = 1, is_causal: bool = False, in_dtype: str = 'float16', out_dtype: str = 'float16', accum_dtype: str = 'float16')
Bases: BaseTemplate
- accum_dtype: str = 'float16'
- batch_size: int = 1
- property class_attributes
Returns the class attributes in dictionary form.
- Returns:
Dictionary of class attributes.
- Return type:
dict
- get_hardware_aware_configs(arch: Optional[TileDevice] = None, topk: int = 10) → List[Hint]
Retrieves optimized hardware-aware configurations.
- Parameters:
arch (TileDevice, optional) – The target hardware architecture.
topk (int, optional) – Number of top configurations to consider.
- Returns:
A list of optimization hints for hardware acceleration.
- Return type:
List[Hint]
- head_dim: int = 1
- in_dtype: str = 'float16'
- initialize_function() → None
Defines and initializes the FlashAttention computation.
This method sets up placeholders for the query, key, and value tensors and defines the attention computation using TVM’s compute API.
- Raises:
AssertionError – If the shape parameters (batch_size, num_heads, head_dim, seq_length, seq_kv_length) are not positive integers.
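For orientation, the workload described by the template's shape parameters can be written as plain scaled dot-product attention. The NumPy sketch that follows is a numerical reference under stated assumptions, not the template's TVM compute definition. The function name `reference_attention`, the (batch_size, num_heads, seq, head_dim) layout, the 1/sqrt(head_dim) scaling, and the top-left aligned causal mask are all assumptions made for illustration.

```python
import math

import numpy as np


def reference_attention(q, k, v, is_causal=False,
                        accum_dtype="float32", out_dtype="float16"):
    """Scaled dot-product attention used as a numerical reference (hypothetical helper).

    Assumed layout (not confirmed by the template docs):
      q:    (batch_size, num_heads, seq_length, head_dim)
      k, v: (batch_size, num_heads, seq_kv_length, head_dim)
    """
    head_dim = q.shape[-1]
    scale = 1.0 / math.sqrt(head_dim)
    # Upcast inputs so the matmuls and the softmax accumulate in accum_dtype.
    qf, kf, vf = (x.astype(accum_dtype) for x in (q, k, v))
    scores = (qf @ kf.swapaxes(-1, -2)) * scale  # (B, H, Sq, Skv)
    if is_causal:
        sq, skv = scores.shape[-2:]
        # Top-left aligned mask (an assumption): query i may attend to keys 0..i.
        future = np.triu(np.ones((sq, skv), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Weighted sum of the values, then cast to the output dtype.
    return (weights @ vf).astype(out_dtype)


rng = np.random.default_rng(0)
B, H, S, D = 1, 2, 4, 8
q, k, v = (rng.standard_normal((B, H, S, D)).astype(np.float16) for _ in range(3))
out = reference_attention(q, k, v, is_causal=True)
print(out.shape, out.dtype)  # (1, 2, 4, 8) float16
```

The sketch accumulates in float32 by default even though the template's documented default accum_dtype is 'float16'; a wider accumulator keeps the reference stable when comparing against a tuned kernel. With the causal mask, the first query row can only see key 0, so its output equals the first value row.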
- is_causal: bool = False
- num_heads: int = 1
- out_dtype: str = 'float16'
- params_as_dict()
Returns the template parameters as a dictionary.
- Returns:
Dictionary containing template parameter values.
- Return type:
dict
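The exact keys returned by params_as_dict (and how they differ from class_attributes) are not spelled out here. As a rough illustration only, the sketch uses a hypothetical stand-in dataclass, `FlashAttentionParams`, that mirrors the documented constructor fields and defaults, and the standard-library `dataclasses.asdict` to build the kind of dictionary such a method might return. The real method may include or omit different keys.

```python
from dataclasses import asdict, dataclass


@dataclass
class FlashAttentionParams:
    """Hypothetical stand-in mirroring the documented FlashAttentionTemplate
    fields and defaults (the private _output_nodes field is omitted)."""
    batch_size: int = 1
    num_heads: int = 1
    head_dim: int = 1
    seq_length: int = 1
    seq_kv_length: int = 1
    is_causal: bool = False
    in_dtype: str = "float16"
    out_dtype: str = "float16"
    accum_dtype: str = "float16"


# Override a few fields; the rest keep their documented defaults.
params = FlashAttentionParams(batch_size=2, num_heads=8, head_dim=64,
                              seq_length=512, seq_kv_length=512, is_causal=True)
as_dict = asdict(params)
print(as_dict["head_dim"], as_dict["accum_dtype"])  # 64 float16

# With tilelang installed, these keyword arguments match the documented
# constructor signature: FlashAttentionTemplate(**as_dict)
```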
- seq_kv_length: int = 1
- seq_length: int = 1