tilelang.carver.roller.policy.tensorcore module

Policy for Tensor Core schedules.

class tilelang.carver.roller.policy.tensorcore.TensorCorePolicy(arch: TileDevice, tags: Optional[Dict] = None)

Bases: DefaultPolicy

block_reduction_depth: Optional[int] = None
check_tile_shape_isvalid(td: TileDict)

Checks if the tile shapes in the TileDict are valid for the nodes in this context.

Parameters:

td (TileDict) – The TileDict object containing tile shapes and other configurations.

Returns:

True if all tile shapes are valid, False otherwise.

Return type:

bool
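The validity criterion itself is not spelled out here. As a hypothetical sketch, assuming a tile is valid only when each of its M, N and K extents is a positive multiple of the Tensor Core fragment shape (16 × 16 × 16 below, chosen to match the default wmma_k), such a check could look like this. The function name tile_fits_fragment and the "m"/"n"/"k" keys are illustrative only and are not part of tilelang.

```python
from typing import Dict, Tuple


def tile_fits_fragment(tile: Dict[str, int],
                       fragment: Tuple[int, int, int] = (16, 16, 16)) -> bool:
    """Hypothetical rule: every tile extent is a positive multiple of the fragment extent."""
    frag_m, frag_n, frag_k = fragment
    required = {"m": frag_m, "n": frag_n, "k": frag_k}
    # Each axis must be at least one fragment wide and split evenly into fragments.
    return all(tile[axis] >= size and tile[axis] % size == 0
               for axis, size in required.items())


print(tile_fits_fragment({"m": 64, "n": 64, "k": 32}))  # True
print(tile_fits_fragment({"m": 64, "n": 40, "k": 32}))  # False: 40 is not a multiple of 16
```

The real policy operates on a TileDict and its nodes rather than a plain dictionary; this only illustrates the kind of divisibility constraint Tensor Core tiles typically have to satisfy.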

compute_node_stride_map(node: PrimFuncNode, td: TileDict)

Computes the stride map for a given node based on the TileDict configuration.

Parameters:
  • node (PrimFuncNode) – The node for which to compute the stride map.

  • td (TileDict) – The TileDict object containing the tile configuration.

Returns:

A tuple of dictionaries containing the output strides and tensor strides.

Return type:

Tuple[Dict, Dict]
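The contents of the returned stride maps are not detailed here. As a hypothetical illustration of what a single stride entry encodes, the sketch below computes row-major strides for a tile, optionally padding the innermost row by pad elements. Padding the leading dimension is a common way to reduce shared-memory bank conflicts on GPUs; whether and by how much this policy pads is an assumption, and padded_strides is not a tilelang function.

```python
from typing import List, Tuple


def padded_strides(tile_shape: List[int], pad: int = 0) -> Tuple[int, ...]:
    """Row-major strides (in elements) for a tile whose innermost row is padded by `pad`."""
    strides = [1] * len(tile_shape)
    # The innermost dimension is contiguous; its row length is extended by `pad`.
    running = tile_shape[-1] + pad
    # Walk outward: each dimension's stride is the padded size of everything inside it.
    for i in range(len(tile_shape) - 2, -1, -1):
        strides[i] = running
        running *= tile_shape[i]
    return tuple(strides)


print(padded_strides([64, 32]))         # (32, 1)
print(padded_strides([64, 32], pad=8))  # (40, 1)
```

With pad=0 this reduces to ordinary row-major strides; a non-zero pad shifts consecutive rows so they start in different memory banks.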

get_node_reduce_step_candidates(node)

Calculates reduction step candidates for each reduction axis in a PrimFuncNode. The general idea is to prefer factors of the axis extent, since they require no extra boundary checks; for a large prime extent, which is rare, powers of 2 are used instead.

Parameters:

node (PrimFuncNode) – The node for which to calculate reduction step candidates. It contains reduction axes (raxis) with their domains (dom.extent).

Returns:

A dictionary mapping axis variable names to lists of step candidates. For each axis in the node, this function calculates possible step sizes. For axes with a large prime domain, it uses powers of 2 as step candidates; for others, it uses all factors of the domain.

Return type:

Dict[str, List[int]]
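The factor-first, power-of-2-fallback heuristic described above can be sketched as follows. This is a minimal stand-alone version under stated assumptions: a plain dictionary of axis name to extent stands in for node.raxis and dom.extent, and the cutoff of 64 for a "large" prime is a hypothetical value, not one taken from this documentation.

```python
from typing import Dict, List


def all_factors(n: int) -> List[int]:
    """Return every positive divisor of n in ascending order."""
    small, large = [], []
    i = 1
    while i * i <= n:
        if n % i == 0:
            small.append(i)
            if i != n // i:
                large.append(n // i)
        i += 1
    return small + large[::-1]


def reduce_step_candidates(extents: Dict[str, int],
                           prime_threshold: int = 64) -> Dict[str, List[int]]:
    """Map each reduction axis name to its candidate step sizes.

    `extents` stands in for the node's reduction axes (name -> extent);
    `prime_threshold` is an assumed cutoff for what counts as a "large" prime.
    """
    results = {}
    for name, extent in extents.items():
        factors = all_factors(extent)
        if len(factors) == 2 and extent > prime_threshold:
            # Large prime: its only factors are 1 and itself, so fall back to
            # powers of 2 below the extent (these need a boundary check).
            steps = [1]
            while steps[-1] * 2 < extent:
                steps.append(steps[-1] * 2)
            factors = steps
        results[name] = factors
    return results


print(reduce_step_candidates({"k": 12}))   # {'k': [1, 2, 3, 4, 6, 12]}
print(reduce_step_candidates({"k": 127}))  # {'k': [1, 2, 4, 8, 16, 32, 64]}
```

Factors divide the extent exactly, so every step tiles the axis with no remainder; the power-of-2 fallback trades that guarantee for a useful set of step sizes when the only factors are 1 and the extent itself.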

infer_node_smem_usage(td: TileDict, node: PrimFuncNode)

Infers the shared memory usage of a node given a TileDict configuration.

Parameters:
  • td (TileDict) – The TileDict object containing the tile configuration.

  • node (PrimFuncNode) – The node for which to infer the shared memory usage.

Returns:

The estimated amount of shared memory used by the node.

Return type:

int
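How the estimate is derived is not described here. As a back-of-the-envelope sketch for a GEMM-like node, assume the A tile (block_m × block_k) and the B tile (block_k × block_n) are both staged in shared memory and that each is buffered once per pipeline stage (cf. the pipeline_stage attribute). The multiplication by pipeline stages and the function name estimate_gemm_smem_bytes are assumptions for illustration, not the policy's actual formula.

```python
def estimate_gemm_smem_bytes(block_m: int, block_n: int, block_k: int,
                             dtype_bytes: int = 2, pipeline_stage: int = 1) -> int:
    """Hypothetical shared-memory estimate for one GEMM thread block, in bytes."""
    a_tile = block_m * block_k * dtype_bytes  # A tile: block_m x block_k
    b_tile = block_k * block_n * dtype_bytes  # B tile: block_k x block_n
    # Assume each pipeline stage keeps its own copy of both input tiles.
    return (a_tile + b_tile) * pipeline_stage


# 128x128 output tile, K-step 32, fp16 inputs, double buffering.
print(estimate_gemm_smem_bytes(128, 128, 32, dtype_bytes=2, pipeline_stage=2))  # 32768
```

An estimate like this is what lets a policy reject tile configurations that would exceed the device's per-block shared-memory limit before any code is generated.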

pipeline_stage: int = 1
plan_rasterization(td: TileDict)

Plans the rasterization for the given TileDict. This function is not implemented yet.

Parameters:

td (TileDict) – The TileDict object to plan rasterization for.

Raises:

RasterRationPlan – This function is not implemented yet.

use_async_copy: bool = False
wmma_k: int = 16