tilelang.carver.template.flashattention
=======================================

.. py:module:: tilelang.carver.template.flashattention


Classes
-------

.. autoapisummary::

   tilelang.carver.template.flashattention.FlashAttentionTemplate


Module Contents
---------------

.. py:class:: FlashAttentionTemplate

   Bases: :py:obj:`tilelang.carver.template.base.BaseTemplate`


   .. py:attribute:: batch_size
      :type: int
      :value: 1


   .. py:attribute:: num_heads
      :type: int
      :value: 1


   .. py:attribute:: head_dim
      :type: int
      :value: 1


   .. py:attribute:: seq_length
      :type: int
      :value: 1


   .. py:attribute:: seq_kv_length
      :type: int
      :value: 1


   .. py:attribute:: is_causal
      :type: bool
      :value: False


   .. py:attribute:: in_dtype
      :type: str
      :value: 'float16'


   .. py:attribute:: out_dtype
      :type: str
      :value: 'float16'


   .. py:attribute:: accum_dtype
      :type: str
      :value: 'float16'


   .. py:method:: get_hardware_aware_configs(arch = None, topk = 10)

      Retrieves optimized hardware-aware configurations.

      :param arch: The target hardware architecture.
      :type arch: TileDevice, optional
      :param topk: Number of top configurations to consider.
      :type topk: int, optional

      :returns: A list of optimization hints for hardware acceleration.
      :rtype: List[Hint]


   .. py:method:: initialize_function()

      Defines and initializes the matrix multiplication computation.

      This method sets up placeholders for input matrices, computes
      the matrix multiplication using TVM's compute API,
      and optionally applies bias and type casting.

      :raises AssertionError: If M, N, or K are not positive integers.


   .. py:method:: params_as_dict()

      Returns the template parameters as a dictionary.

      :returns: Dictionary containing template parameter values.
      :rtype: dict


   .. py:property:: class_attributes

      Returns the class attributes in dictionary form.

      :returns: Dictionary of class attributes.
      :rtype: dict


   .. py:method:: __repr__()

      Returns a string representation of the class instance.

      :returns: A formatted string representation of the class.
      :rtype: str
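
The attributes listed above behave like the fields of a parameter record with defaults. The following is a minimal sketch of that shape, assuming a plain dataclass is an adequate model. ``FlashAttentionParams`` is a hypothetical stand-in written for illustration. It is not the tilelang class, it does not import tilelang, and its ``params_as_dict`` only mimics the documented return type.

```python
from dataclasses import asdict, dataclass


@dataclass
class FlashAttentionParams:
    """Hypothetical stand-in mirroring the documented attributes and defaults.

    This is NOT tilelang's FlashAttentionTemplate; it only models the
    parameter record described on this page.
    """

    batch_size: int = 1
    num_heads: int = 1
    head_dim: int = 1
    seq_length: int = 1
    seq_kv_length: int = 1
    is_causal: bool = False
    in_dtype: str = "float16"
    out_dtype: str = "float16"
    accum_dtype: str = "float16"

    def params_as_dict(self) -> dict:
        # Assumption: the real method returns the field values keyed by name.
        return asdict(self)


# Describe a causal self-attention problem; unspecified fields keep their defaults.
params = FlashAttentionParams(
    batch_size=4,
    num_heads=16,
    head_dim=64,
    seq_length=2048,
    seq_kv_length=2048,
    is_causal=True,
)
config = params.params_as_dict()
```

With the real class, the same keyword arguments would presumably be passed to ``FlashAttentionTemplate``, after which ``get_hardware_aware_configs(arch, topk=10)`` returns the list of hints documented above. How the ``arch`` object is constructed is outside the scope of this page.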