tilelang.carver.roller.policy.tensorcore module
Policy for Tensor Core schedules.
- class tilelang.carver.roller.policy.tensorcore.TensorCorePolicy(arch: TileDevice, tags: Optional[Dict] = None)
Bases: DefaultPolicy
- block_reduction_depth: Optional[int] = None
- check_tile_shape_isvalid(td: TileDict)
Checks if the tile shapes in the TileDict are valid for the nodes in this context.
- Parameters:
td (TileDict) – The TileDict object containing tile shapes and other configurations.
- Returns:
True if all tile shapes are valid, False otherwise.
- Return type:
bool
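The page does not say what "valid" means. As a sketch only, the check below assumes two criteria: every tile extent is positive, and every node, once tiled, needs the same number of thread blocks (so the nodes can share one launch grid). The function name `tiles_are_valid` and its plain-dict inputs are hypothetical stand-ins for `TileDict` and `PrimFuncNode`, not the library's API or its actual rule.

```python
from math import prod
from typing import Dict, Sequence


def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division."""
    return (a + b - 1) // b


def tiles_are_valid(tile_map: Dict[str, Sequence[int]],
                    space_dims: Dict[str, Sequence[int]]) -> bool:
    """Hypothetical validity check; the criteria are assumptions, not tilelang's rule."""
    grid_sizes = set()
    for name, tile in tile_map.items():
        extents = space_dims[name]
        if len(tile) != len(extents):
            return False  # tile rank must match the node's spatial rank
        if any(t <= 0 for t in tile):
            return False  # an empty or negative tile extent is never usable
        # number of thread blocks needed to cover this node's iteration space
        grid_sizes.add(prod(ceil_div(e, t) for e, t in zip(extents, tile)))
    # assumption: all nodes must agree on a single grid size
    return len(grid_sizes) <= 1
```

For example, a single 1024×1024 node tiled as 128×128 passes, while two such nodes tiled as 128×128 and 64×128 fail, because they need 64 and 128 blocks respectively.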
- compute_node_stride_map(node: PrimFuncNode, td: TileDict)
Computes the stride map for a given node based on the TileDict configuration.
- Parameters:
node (PrimFuncNode) – The node for which to compute the stride map.
td (TileDict) – The TileDict object containing the tile configuration.
- Returns:
A tuple of dictionaries containing the output strides and tensor strides.
- Return type:
Tuple[Dict, Dict]
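The exact stride layout is not documented here. The sketch below illustrates one common approach: row-major strides in which the innermost (contiguous) dimension is padded by a few elements, a standard way to reduce shared-memory bank conflicts. The padding of 8 elements, the function names, and the use of plain dicts of tile shapes are all hypothetical. Only the tuple-of-two-dicts return shape mirrors the documented return type.

```python
from typing import Dict, List, Sequence, Tuple


def padded_strides(shape: Sequence[int], pad: int = 8) -> List[int]:
    """Row-major strides with the innermost dimension padded by `pad` elements."""
    strides = [1] * len(shape)
    running = 1
    for i in range(len(shape) - 1, -1, -1):
        strides[i] = running
        # only the innermost (contiguous) dimension is padded
        extent = shape[i] + pad if i == len(shape) - 1 else shape[i]
        running *= extent
    return strides


def compute_stride_map(
    output_tiles: Dict[str, Sequence[int]],
    input_tiles: Dict[str, Sequence[int]],
    pad: int = 8,  # hypothetical padding, in elements
) -> Tuple[Dict[str, List[int]], Dict[str, List[int]]]:
    """Hypothetical stand-in returning (output strides, tensor strides)."""
    output_strides = {n: padded_strides(s, pad) for n, s in output_tiles.items()}
    tensor_strides = {n: padded_strides(s, pad) for n, s in input_tiles.items()}
    return output_strides, tensor_strides
```

With `pad=8`, a 128×32 tile gets strides `[40, 1]` instead of `[32, 1]`, so consecutive rows start at staggered offsets.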
- get_node_reduce_step_candidates(node)
Calculates reduction step candidates for each reduction axis in a PrimFuncNode. The general idea is to prefer factors of the axis extent, since they require no extra boundary check; for a large prime extent, which is rare, powers of 2 are used instead.
- Parameters:
node (PrimFuncNode) – The node for which to calculate reduction step candidates. It contains reduction axes (raxis) with their domains (dom.extent).
- Returns:
A dictionary mapping axis variable names to lists of step candidates. For each axis in the node, this function calculates possible step sizes. For axes with a large prime domain, it uses powers of 2 as step candidates; for others, it uses all factors of the domain.
- Return type:
Dict[str, List[int]]
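The factor-first rule can be sketched on plain integers. This is a minimal illustration, assuming that "large prime" means a prime extent above a threshold of 64 and that the power-of-2 candidates stop below the extent. Both choices are assumptions, and `reduce_step_candidates` is a hypothetical name that takes a dict of axis extents instead of a `PrimFuncNode`.

```python
from typing import Dict, List


def all_factors(n: int) -> List[int]:
    """Return every divisor of n in ascending order."""
    small, large = [], []
    i = 1
    while i * i <= n:
        if n % i == 0:
            small.append(i)
            if i != n // i:
                large.append(n // i)
        i += 1
    return small + large[::-1]


def reduce_step_candidates(reduce_extents: Dict[str, int],
                           prime_threshold: int = 64) -> Dict[str, List[int]]:
    """Map each reduction axis name to its candidate step sizes."""
    results = {}
    for name, extent in reduce_extents.items():
        candidates = all_factors(extent)
        # exactly two divisors (1 and itself) means the extent is prime
        if len(candidates) == 2 and extent > prime_threshold:
            # fall back to powers of 2; these steps need a boundary check
            candidates = [1]
            while candidates[-1] * 2 < extent:
                candidates.append(candidates[-1] * 2)
        results[name] = candidates
    return results
```

An extent of 12 yields its divisors `[1, 2, 3, 4, 6, 12]`. A prime extent of 67 would otherwise offer only steps 1 and 67, so it yields `[1, 2, 4, ..., 64]` instead.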
- infer_node_smem_usage(td: TileDict, node: PrimFuncNode)
Infers the shared memory usage of a node given a TileDict configuration.
- Parameters:
td (TileDict) – The TileDict object containing the tile configuration.
node (PrimFuncNode) – The node for which to infer the shared memory usage.
- Returns:
The estimated amount of shared memory used by the node.
- Return type:
int
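How the estimate is computed is not described here. The sketch below shows one plausible model: sum the bytes of every tile staged in shared memory, then multiply by the number of pipeline stages, since each stage keeps its own copy resident. The relation to the `pipeline_stage` attribute is an assumption. The function name and its dict inputs are hypothetical.

```python
from math import prod
from typing import Dict, Sequence


def estimate_smem_bytes(tile_shapes: Dict[str, Sequence[int]],
                        dtype_bytes: Dict[str, int],
                        pipeline_stage: int = 1) -> int:
    """Hypothetical estimate: per-stage tile bytes times the number of stages."""
    per_stage = sum(prod(shape) * dtype_bytes[name]
                    for name, shape in tile_shapes.items())
    # assumption: every pipeline stage holds a full copy of the staged tiles
    return per_stage * pipeline_stage
```

Two float16 tiles of 128×32 and 32×128 take 16384 bytes per stage, so two stages would need 32768 bytes.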
- pipeline_stage: int = 1
- plan_rasterization(td: TileDict)
Plans the rasterization for the given TileDict. This function is not implemented yet.
- Parameters:
td (TileDict) – The TileDict object to plan rasterization for.
- Raises:
RasterRationPlan – This function is not implemented yet.
- use_async_copy: bool = False
- wmma_k: int = 16