tilelang.utils.tensor module

class tilelang.utils.tensor.TensorSupplyType(value)

Bases: Enum

An enumeration of the tensor supply types accepted by get_tensor_supply().

Auto = 7
Integer = 1
Normal = 3
One = 6
Randn = 4
Uniform = 2
Zero = 5
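Because TensorSupplyType derives from Enum, the standard Enum lookups by name and by value apply, which is convenient when a supply type comes from a config string or a command-line flag. In real code you would import it with `from tilelang.utils.tensor import TensorSupplyType`. The sketch below instead defines a local stand-in with the documented member values so that it runs without tilelang installed. It does not show how get_tensor_supply() interprets each member.

```python
from enum import Enum


# Local stand-in that mirrors the documented members and values of
# tilelang.utils.tensor.TensorSupplyType (for illustration only).
class TensorSupplyType(Enum):
    Integer = 1
    Uniform = 2
    Normal = 3
    Randn = 4
    Zero = 5
    One = 6
    Auto = 7


# Look up a member by name, e.g. from a config string.
by_name = TensorSupplyType["Randn"]
# Look up the same member by its integer value.
by_value = TensorSupplyType(4)
assert by_name is by_value

# Iteration follows definition order, which here is ascending value.
names = [member.name for member in TensorSupplyType]
```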
tilelang.utils.tensor.adapt_torch2tvm(arg)
tilelang.utils.tensor.get_tensor_supply(supply_type: TensorSupplyType)
tilelang.utils.tensor.map_torch_type(intype: str) → torch.dtype
tilelang.utils.tensor.torch_assert_close(tensor_a, tensor_b, rtol=0.01, atol=0.001, max_mismatched_ratio=0.001, verbose: bool = False, equal_nan: bool = True, check_device: bool = True, check_dtype: bool = True, check_layout: bool = True, check_stride: bool = False)

Assert that two tensors are “close enough,” tolerating a bounded fraction of mismatched elements (set by max_mismatched_ratio).

Parameters:

tensor_a : torch.Tensor

The first tensor to compare.

tensor_b : torch.Tensor

The second tensor to compare.

rtol : float, optional

Relative tolerance for comparison. Default is 1e-2.

atol : float, optional

Absolute tolerance for comparison. Default is 1e-3.

max_mismatched_ratio : float, optional

Maximum ratio of mismatched elements allowed (relative to the total number of elements). Default is 0.001 (0.1% of total elements).

Raises:

AssertionError

If the ratio of mismatched elements exceeds max_mismatched_ratio.
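The mismatch-ratio check described above can be sketched in plain Python over flat sequences. This is a hypothetical stand-in (assert_mostly_close is not a tilelang function), not the library's implementation. It assumes that an element matches when |a - b| <= atol + rtol * |b|, which is the rule torch.isclose uses. Exact equality always matches, infinities match only when equal, and NaNs match each other only when equal_nan is set. The device, dtype, layout, and stride checks are omitted.

```python
import math
from typing import Sequence


def assert_mostly_close(
    a: Sequence[float],
    b: Sequence[float],
    rtol: float = 0.01,
    atol: float = 0.001,
    max_mismatched_ratio: float = 0.001,
    equal_nan: bool = True,
) -> None:
    # Hypothetical helper: same defaults as torch_assert_close, but operates on
    # flat Python sequences instead of torch.Tensor.
    if len(a) != len(b):
        raise AssertionError(f"length mismatch: {len(a)} vs {len(b)}")
    if len(a) == 0:
        return  # nothing to compare

    mismatched = 0
    for x, y in zip(a, b):
        if math.isnan(x) or math.isnan(y):
            # Assumed NaN rule: a NaN pair matches only when equal_nan is True.
            close = equal_nan and math.isnan(x) and math.isnan(y)
        elif x == y:
            # Exact equality (including equal infinities) always matches.
            close = True
        elif math.isinf(x) or math.isinf(y):
            # Unequal infinite values never match.
            close = False
        else:
            # Assumed rule, as in torch.isclose: |x - y| <= atol + rtol * |y|.
            close = abs(x - y) <= atol + rtol * abs(y)
        if not close:
            mismatched += 1

    ratio = mismatched / len(a)
    # Fail only when the mismatch ratio strictly exceeds the allowed ratio.
    if ratio > max_mismatched_ratio:
        raise AssertionError(
            f"{mismatched}/{len(a)} elements mismatched "
            f"(ratio {ratio:.4g} > max_mismatched_ratio {max_mismatched_ratio:.4g})"
        )


# 1000 elements with the default max_mismatched_ratio=0.001.
reference = [1.0] * 1000
result = list(reference)
result[0] = 2.0  # 1 mismatch: ratio 0.001 is not > 0.001, so this passes
assert_mostly_close(result, reference)
```

With the default max_mismatched_ratio of 0.001, a 1,000-element comparison tolerates at most one mismatched element. Changing a second element raises the ratio to 0.002, and the check then raises AssertionError.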