tilelang.carver.roller.policy.tensorcore
Policy for tensorcore schedule
Attributes
logger
Classes
TensorCorePolicy | Default policy for fastdlight, a heuristic plan that tries to minimize memory traffic and maximize parallelism for the BitBLAS schedule.
Module Contents
- tilelang.carver.roller.policy.tensorcore.logger
- class tilelang.carver.roller.policy.tensorcore.TensorCorePolicy(arch, tags=None)
Bases:
tilelang.carver.roller.policy.default.DefaultPolicy
Default policy for fastdlight, a heuristic plan that tries to minimize memory traffic and maximize parallelism for the BitBLAS schedule.
- Parameters:
arch (tilelang.carver.arch.TileDevice)
tags (Optional[Dict])
- wmma_k: int = 16
- pipeline_stage: int = 1
- use_async_copy: bool = False
- block_reduction_depth: int | None = None
- infer_node_smem_usage(td, node)
Infers the shared memory usage of a node given a TileDict configuration.
- Parameters:
td (TileDict) – The TileDict object containing the tile configuration.
node (PrimFuncNode) – The node for which to infer the shared memory usage.
- Returns:
The estimated amount of shared memory used by the node.
- Return type:
int
- get_node_reduce_step_candidates(node)
Calculates reduction step candidates for each reduction axis in a PrimFuncNode. The general idea is to prefer factors of the axis extent, since a step that divides the extent evenly needs no extra boundary check. For an extent that is a large prime number, which is a rare case, powers of 2 are used instead.
- Parameters:
node (PrimFuncNode) – The node for which to calculate reduction step candidates. It contains reduction axes (raxis) with their domains (dom.extent).
- Returns:
A dictionary mapping axis variable names to lists of step candidates. For each axis in the node, this function calculates possible step sizes. For axes with a large prime domain, it uses powers of 2 as step candidates; for others, it uses all factors of the domain.
- Return type:
Dict[str, List[int]]
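The factor-first heuristic described above can be illustrated with a small standalone sketch. This is not the library's implementation: the function name reduce_step_candidates, the plain dict of axis extents used in place of a PrimFuncNode, and the prime_threshold cutoff for what counts as a "large" prime are all assumptions made for illustration.

```python
from typing import Dict, List


def factors(n: int) -> List[int]:
    """Return all positive divisors of n in ascending order."""
    small, large = [], []
    i = 1
    while i * i <= n:
        if n % i == 0:
            small.append(i)
            if i != n // i:
                large.append(n // i)
        i += 1
    return small + large[::-1]


def is_prime(n: int) -> bool:
    """Trial-division primality test, adequate for axis extents."""
    if n < 2:
        return False
    i = 2
    while i * i <= n:
        if n % i == 0:
            return False
        i += 1
    return True


def reduce_step_candidates(
    extents: Dict[str, int],
    prime_threshold: int = 64,  # hypothetical cutoff for a "large" prime
) -> Dict[str, List[int]]:
    """Map each reduction axis name to its list of step candidates."""
    result: Dict[str, List[int]] = {}
    for name, extent in extents.items():
        if extent > prime_threshold and is_prime(extent):
            # Large prime: the only factors are 1 and the extent itself,
            # so fall back to powers of 2 (these need a boundary check).
            steps, p = [], 1
            while p <= extent:
                steps.append(p)
                p *= 2
        else:
            # Usual case: every factor divides the extent evenly.
            steps = factors(extent)
        result[name] = steps
    return result


print(reduce_step_candidates({"k": 12, "r": 127}))
```

With these assumptions, an extent of 12 yields its six factors, while the prime extent 127 yields the powers of 2 up to 64.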
- check_tile_shape_isvalid(td)
Checks if the tile shapes in the TileDict are valid for the nodes in this context.
- Parameters:
td (TileDict) – The TileDict object containing tile shapes and other configurations.
- Returns:
True if all tile shapes are valid, False otherwise.
- Return type:
bool
- compute_node_stride_map(node, td)
Computes the stride map for a given node based on the TileDict configuration.
- Parameters:
node (PrimFuncNode) – The node for which to compute the stride map.
td (TileDict) – The TileDict object containing the tile configuration.
- Returns:
A tuple of dictionaries containing the output strides and tensor strides.
- Return type:
Tuple[Dict, Dict]