tilelang.carver.roller.hint

Hint definitions for scheduling.

Classes

TensorCoreExtraConfig

Stores extra configuration used when scheduling with Tensor Cores.

Stride

Manages stride information for a given axis of a tensor.

TileDict

Manages tiling information and configurations for computational tasks.

IntrinInfo

Holds information about the Tensor Core intrinsic: data types, operand transposition, and layout transform kinds.

Hint

Central configuration class for managing various parameters of computational tasks.

Module Contents

class tilelang.carver.roller.hint.TensorCoreExtraConfig(AS_shape, BS_shape, AF_shape, BF_shape, tc_axis)

Stores extra configuration used when scheduling with Tensor Cores.

Parameters:
  • AS_shape (Tuple[int])

  • BS_shape (Tuple[int])

  • AF_shape (Tuple[int])

  • BF_shape (Tuple[int])

  • tc_axis (Tuple[int])

AS_shape: Tuple[int]
BS_shape: Tuple[int]
AF_shape: Tuple[int]
BF_shape: Tuple[int]
tc_axis: Tuple[int]
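The five fields are plain integer tuples. The sketch below mirrors the constructor signature with a frozen dataclass so the shapes can be inspected. Reading the prefixes as A/B tiles staged in shared memory (`S`) versus per-warp fragments (`F`) is an assumption based on the names, as are the `tc_axis` values; `TensorCoreExtraConfigSketch` and the helper `fragments_per_tile` are hypothetical and not part of the tilelang API.

```python
from dataclasses import dataclass
from math import prod
from typing import Tuple


@dataclass(frozen=True)
class TensorCoreExtraConfigSketch:
    """Hypothetical stand-in mirroring TensorCoreExtraConfig's fields."""

    AS_shape: Tuple[int, ...]  # assumed: A tile staged in shared memory
    BS_shape: Tuple[int, ...]  # assumed: B tile staged in shared memory
    AF_shape: Tuple[int, ...]  # assumed: A fragment held by one warp
    BF_shape: Tuple[int, ...]  # assumed: B fragment held by one warp
    tc_axis: Tuple[int, ...]   # axes handled by the Tensor Core intrinsic (meaning assumed)


def fragments_per_tile(shared: Tuple[int, ...], fragment: Tuple[int, ...]) -> int:
    """Hypothetical helper: how many fragments tile one shared-memory tile."""
    if len(shared) != len(fragment):
        raise ValueError("shared and fragment shapes must have the same rank")
    if any(s % f for s, f in zip(shared, fragment)):
        raise ValueError("fragment shape must evenly divide the shared shape")
    return prod(s // f for s, f in zip(shared, fragment))


cfg = TensorCoreExtraConfigSketch(
    AS_shape=(128, 32), BS_shape=(32, 128),
    AF_shape=(64, 32), BF_shape=(32, 64),
    tc_axis=(0, 1),
)
print(fragments_per_tile(cfg.AS_shape, cfg.AF_shape))  # 2
```

Freezing the dataclass keeps a configuration immutable once the shapes are chosen, which suits a record that is only read during code generation.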
class tilelang.carver.roller.hint.Stride(stride=1, ax=-1)

Manages stride information for a given axis of a tensor.

Parameters:
  • stride (int)

  • ax (int)

property ax: int
Return type:

int

property stride: int
Return type:

int

compute_strides_from_shape(shape)
Parameters:

shape (List[int])

Return type:

List[int]
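A minimal sketch of what `compute_strides_from_shape` plausibly computes, assuming a row-major layout in which the innermost axis has stride 1 and the axis named by `ax` (when non-negative) has its stride overridden by `stride`, for example to pad the rows of a shared-memory tile. The free function and its extra parameters are illustrative, not the method itself.

```python
from typing import List


def compute_strides_from_shape(shape: List[int], ax: int = -1, stride: int = 1) -> List[int]:
    """Row-major strides, with the stride of axis `ax` replaced by `stride`.

    A negative `ax` leaves the layout dense, mirroring the Stride defaults.
    """
    strides = [1] * len(shape)
    # The innermost axis keeps stride 1; walk outward from the second-to-last axis.
    for i in range(len(shape) - 2, -1, -1):
        if i == ax:
            strides[i] = stride
        else:
            strides[i] = strides[i + 1] * shape[i + 1]
    return strides


print(compute_strides_from_shape([4, 8, 16]))                   # [128, 16, 1]
print(compute_strides_from_shape([4, 8, 16], ax=1, stride=24))  # [192, 24, 1]
```

Because each outer stride is derived from the one inside it, the padding on axis 1 propagates outward: each of the 4 outermost slices spans 192 elements instead of 128.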

compute_elements_from_shape(shape)
Parameters:

shape (List[int])

Return type:

int
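Under the same layout assumptions, a strided buffer holds one "row" per index combination over axes `0..ax`, and each row occupies `stride` elements. When the stride is not valid (assumed here to mean `ax < 0`, matching the default `ax=-1`), the dense element count is returned. The function names and error checks are illustrative.

```python
from math import prod
from typing import List


def is_valid(ax: int) -> bool:
    # Assumption: a negative axis means "no custom stride applied".
    return ax >= 0


def compute_elements_from_shape(shape: List[int], ax: int = -1, stride: int = 1) -> int:
    """Number of elements a buffer of `shape` occupies when axis `ax` uses `stride`."""
    dense = prod(shape)
    if not is_valid(ax):
        return dense
    if ax >= len(shape):
        raise ValueError(f"axis {ax} out of range for rank {len(shape)}")
    # prod(shape[:ax + 1]) rows, each occupying `stride` elements.
    padded = prod(shape[: ax + 1]) * stride
    if padded < dense:
        raise ValueError("stride is smaller than the dense extent of the inner axes")
    return padded


print(compute_elements_from_shape([64, 32]))                   # 2048
print(compute_elements_from_shape([64, 32], ax=0, stride=40))  # 2560
```

The padded count agrees with the outermost stride: for shape `[4, 8, 16]` with `ax=1, stride=24`, the count is 4 * 8 * 24 = 768, which is 4 slices of 192 elements.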

is_valid()
Return type:

bool

__repr__()
Return type:

str

class tilelang.carver.roller.hint.TileDict(output_tile)

Manages tiling information and configurations for computational tasks.

output_tile
tile_map
rstep_map
cached_tensors_map
output_strides_map
tensor_strides_map
traffic = -1
smem_cost = -1
block_per_SM = -1
num_wave = -1
grid_size = -1
valid = True
get_tile(func)
Return type:

List[int]

get_rstep(node)
Return type:

Dict[str, int]

__hash__()
Return type:

int
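One plausible shape for this bookkeeping is sketched below, assuming each map is a plain dict keyed by the node or function it describes, the analysis fields keep their `-1` sentinel until a cost model fills them in, and the hash depends only on `output_tile`. `TileDictSketch` and the `"matmul"` key are hypothetical.

```python
from typing import Dict, Hashable, List


class TileDictSketch:
    """Hypothetical stand-in for TileDict's bookkeeping."""

    def __init__(self, output_tile: List[int]) -> None:
        self.output_tile = output_tile
        # Per-node schedule choices, filled in while a tiling is evaluated.
        self.tile_map: Dict[Hashable, List[int]] = {}
        self.rstep_map: Dict[Hashable, Dict[str, int]] = {}
        self.cached_tensors_map: Dict[Hashable, list] = {}
        self.output_strides_map: Dict[Hashable, dict] = {}
        self.tensor_strides_map: Dict[Hashable, dict] = {}
        # Analysis results; -1 marks "not computed yet".
        self.traffic = -1
        self.smem_cost = -1
        self.block_per_SM = -1
        self.num_wave = -1
        self.grid_size = -1
        self.valid = True

    def get_tile(self, func: Hashable) -> List[int]:
        return self.tile_map[func]

    def get_rstep(self, node: Hashable) -> Dict[str, int]:
        return self.rstep_map[node]

    def __hash__(self) -> int:
        # Lists are unhashable, so hash the output tile as a tuple.
        return hash(tuple(self.output_tile))


td = TileDictSketch([128, 128])
td.tile_map["matmul"] = [128, 128]
td.rstep_map["matmul"] = {"k": 32}
print(td.get_tile("matmul"), td.get_rstep("matmul"))  # [128, 128] {'k': 32}
print(hash(td) == hash(TileDictSketch([128, 128])))  # True
```

Hashing only the output tile means two candidate tilings with the same output tile hash alike, regardless of the analysis results attached to them.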

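class tilelang.carver.roller.hint.IntrinInfo(in_dtype, out_dtype, trans_b, input_transform_kind=0, weight_transform_kind=0)

Holds information about the Tensor Core intrinsic: the input and output data types, whether operand B is transposed, and the layout transform kinds applied to the input and weight operands.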

Parameters:
  • in_dtype (str)

  • out_dtype (str)

  • trans_b (bool)

  • input_transform_kind (int)

  • weight_transform_kind (int)

in_dtype
out_dtype
trans_a = False
trans_b
input_transform_kind = 0
weight_transform_kind = 0
__repr__()
Return type:

str

is_input_8bit()
Return type:

bool

property smooth_a: bool
Return type:

bool

property smooth_b: bool
Return type:

bool

property inter_transform_a: bool
Return type:

bool

property inter_transform_b: bool
Return type:

bool
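The four boolean properties are derived from the two integer transform kinds. A minimal sketch, assuming the kinds are ordered levels where any non-zero kind enables an intermediate layout transform and a kind of 2 or more additionally enables the "smooth" layout, and assuming the 8-bit dtype names shown. `IntrinInfoSketch` is hypothetical, and the thresholds should be checked against the source.

```python
class IntrinInfoSketch:
    """Hypothetical stand-in for IntrinInfo's derived flags."""

    # Assumed set of dtype names treated as 8-bit inputs.
    EIGHT_BIT_DTYPES = ("int8", "e4m3_float8", "e5m2_float8")

    def __init__(self, in_dtype: str, out_dtype: str, trans_b: bool,
                 input_transform_kind: int = 0, weight_transform_kind: int = 0) -> None:
        self.in_dtype = in_dtype
        self.out_dtype = out_dtype
        self.trans_a = False  # fixed, matching the documented default
        self.trans_b = trans_b
        self.input_transform_kind = input_transform_kind
        self.weight_transform_kind = weight_transform_kind

    def is_input_8bit(self) -> bool:
        return self.in_dtype in self.EIGHT_BIT_DTYPES

    @property
    def inter_transform_a(self) -> bool:
        return self.input_transform_kind >= 1  # assumed threshold

    @property
    def inter_transform_b(self) -> bool:
        return self.weight_transform_kind >= 1  # assumed threshold

    @property
    def smooth_a(self) -> bool:
        return self.input_transform_kind >= 2  # assumed threshold

    @property
    def smooth_b(self) -> bool:
        return self.weight_transform_kind >= 2  # assumed threshold


info = IntrinInfoSketch("float16", "float32", trans_b=True, weight_transform_kind=2)
print(info.is_input_8bit(), info.smooth_a, info.smooth_b, info.inter_transform_b)
# False False True True
```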

class tilelang.carver.roller.hint.Hint

Bases: object

Central configuration class for managing various parameters of computational tasks.

arch = None
use_tc = None
block = []
thread = []
warp = []
rstep = []
reduce_thread = []
rasterization_plan
cached_tensors = []
output_strides
schedule_stages = None
block_reduction_depth: int = None
split_k_factor: int = 1
vectorize: Dict[str, int]
pipeline_stage = 1
use_async = False
opt_shapes: Dict[str, int]
intrin_info
shared_scope: str = 'shared'
pass_context: Dict
to_dict()
Return type:

Dict

classmethod from_dict(dic)
Parameters:

dic (Dict)

Return type:

Hint
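A round-trip sketch of `to_dict` and `from_dict`, assuming `to_dict` emits a subset of fields (recording `warp` instead of `thread` when Tensor Cores are used) and `from_dict` assigns each key as an attribute on a fresh instance. `HintSketch` and its field subset are hypothetical.

```python
from typing import Dict, List, Optional


class HintSketch:
    """Hypothetical stand-in covering a few Hint fields."""

    def __init__(self) -> None:
        # Per-instance defaults, so list fields are never shared between hints.
        self.use_tc: Optional[bool] = None
        self.block: List[int] = []
        self.thread: List[int] = []
        self.warp: List[int] = []
        self.rstep: List[int] = []
        self.split_k_factor = 1
        self.pipeline_stage = 1

    def to_dict(self) -> Dict:
        dic: Dict = {"block": self.block, "rstep": self.rstep}
        # Assumption: Tensor Core hints record warp tiles, others record thread tiles.
        if self.use_tc:
            dic["warp"] = self.warp
            dic["use_tc"] = True
        else:
            dic["thread"] = self.thread
        return dic

    @classmethod
    def from_dict(cls, dic: Dict) -> "HintSketch":
        hint = cls()
        # Every key becomes an attribute; fields missing from `dic` keep their defaults.
        for key, value in dic.items():
            setattr(hint, key, value)
        return hint


hint = HintSketch()
hint.use_tc, hint.block, hint.warp, hint.rstep = True, [128, 128], [64, 64], [32]
restored = HintSketch.from_dict(hint.to_dict())
print(restored.block, restored.warp, restored.thread)  # [128, 128] [64, 64] []
```

Because `from_dict` starts from a default instance, any field the dictionary omits (here `thread`, `split_k_factor`, `pipeline_stage`) silently falls back to its default rather than raising.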

tensorcore_legalization()
property raxis_order: List[int]
Return type:

List[int]

property step: List[int]
Return type:

List[int]

__repr__()
Return type:

str

complete_config(node)
Parameters:

node (tilelang.carver.roller.PrimFuncNode)