tilelang.utils.tensor
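Classes
- TensorSupplyType: Enumeration of tensor supply strategies.
Functions
- map_torch_type(intype)
- adapt_torch2tvm(arg)
- get_tensor_supply([supply_type])
- torch_assert_close(tensor_a, tensor_b[, rtol, atol, ...]): Custom function to assert that two tensors are "close enough," allowing a specified ratio of mismatched elements.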
Module Contents
- class tilelang.utils.tensor.TensorSupplyType
Bases: enum.Enum
Enumeration of the tensor supply strategies accepted by get_tensor_supply (its supply_type parameter).
- Integer = 1
- Uniform = 2
- Normal = 3
- Randn = 4
- Zero = 5
- One = 6
- Auto = 7
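TensorSupplyType is a plain enum.Enum, so you can look up a member by name or by integer value. This is handy when the supply type comes from a config string. The sketch below re-declares the documented members locally. That is an assumption made only so the snippet runs without tilelang installed. With tilelang available, importing TensorSupplyType from tilelang.utils.tensor should support the same lookups.

```python
import enum

# Local mirror of the documented members (assumption: declared here only so
# the example runs without tilelang). Names and values match the listing above.
class TensorSupplyType(enum.Enum):
    Integer = 1
    Uniform = 2
    Normal = 3
    Randn = 4
    Zero = 5
    One = 6
    Auto = 7

# Look a member up by name, e.g. from a CLI flag or config file...
by_name = TensorSupplyType["Normal"]
# ...or by its integer value.
by_value = TensorSupplyType(5)

print(by_name, by_name.value)  # TensorSupplyType.Normal 3
print(by_value.name)           # Zero
```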
- tilelang.utils.tensor.map_torch_type(intype)
- Parameters:
intype (str)
- Return type:
torch.dtype
- tilelang.utils.tensor.adapt_torch2tvm(arg)
- tilelang.utils.tensor.get_tensor_supply(supply_type=TensorSupplyType.Integer)
- Parameters:
supply_type (TensorSupplyType)
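This page does not document the return value of get_tensor_supply. The snippet below only illustrates how a supply type can select a fill strategy. It is a hypothetical, stdlib-only sketch and not tilelang's implementation:

- make_supply is an illustrative name.
- It returns flat Python lists instead of torch.Tensor objects.
- It keys on the member names listed above.
- The value ranges chosen for each strategy are assumptions.
- Randn and Auto are deliberately left out.

```python
import random
from typing import Callable, List

# Hypothetical factory (NOT tilelang's API): maps a supply-type name to a
# function that produces `n` values. Ranges below are illustrative assumptions.
def make_supply(kind: str, seed: int = 0) -> Callable[[int], List[float]]:
    rng = random.Random(seed)  # seeded so results are reproducible
    fillers = {
        "Integer": lambda n: [float(rng.randint(-2, 2)) for _ in range(n)],
        "Uniform": lambda n: [rng.uniform(-1.0, 1.0) for _ in range(n)],
        "Normal": lambda n: [rng.gauss(0.0, 1.0) for _ in range(n)],
        "Zero": lambda n: [0.0] * n,
        "One": lambda n: [1.0] * n,
    }
    if kind not in fillers:
        # Randn and Auto are intentionally not modeled in this sketch.
        raise ValueError(f"unsupported supply type in this sketch: {kind}")
    return fillers[kind]

print(make_supply("Zero")(3))  # [0.0, 0.0, 0.0]
```

Returning a closure keeps the strategy choice separate from the size of the data requested. It also lets one seeded generator be reused across calls.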
- tilelang.utils.tensor.torch_assert_close(tensor_a, tensor_b, rtol=0.01, atol=0.001, max_mismatched_ratio=0.001, verbose=False, equal_nan=True, check_device=True, check_dtype=True, check_layout=True, check_stride=False, base_name='LHS', ref_name='RHS')
Custom function to assert that two tensors are "close enough," allowing a specified ratio of mismatched elements.
Parameters:
- tensor_a : torch.Tensor
The first tensor to compare.
- tensor_b : torch.Tensor
The second tensor to compare.
- rtol : float, optional
Relative tolerance for comparison. Default is 1e-2.
- atol : float, optional
Absolute tolerance for comparison. Default is 1e-3.
- max_mismatched_ratio : float, optional
Maximum ratio of mismatched elements allowed (relative to the total number of elements). Default is 0.001 (0.1% of total elements).
Raises:
- AssertionError:
If the ratio of mismatched elements exceeds max_mismatched_ratio.
- Parameters:
verbose (bool)
equal_nan (bool)
check_device (bool)
check_dtype (bool)
check_layout (bool)
check_stride (bool)
base_name (str)
ref_name (str)
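The docstring describes a closeness check that tolerates a bounded fraction of mismatched elements. Below is a minimal pure-Python sketch of that idea. It assumes each element pair is judged with the same criterion as torch.isclose, |a - b| <= atol + rtol * |b|. Exactly equal values, including equal infinities, count as close. NaN pairs count as equal when equal_nan is set. The helper name assert_mostly_close is hypothetical, and it works on flat lists rather than tensors.

```python
import math
from typing import Sequence

def assert_mostly_close(a: Sequence[float], b: Sequence[float],
                        rtol: float = 0.01, atol: float = 0.001,
                        max_mismatched_ratio: float = 0.001,
                        equal_nan: bool = True) -> float:
    """Hypothetical sketch; returns the observed mismatch ratio."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same number of elements")
    mismatched = 0
    for x, y in zip(a, b):
        if math.isnan(x) or math.isnan(y):
            # NaNs match only if both are NaN and equal_nan is set.
            close = equal_nan and math.isnan(x) and math.isnan(y)
        else:
            # Assumed torch.isclose-style test; x == y handles equal infinities.
            close = x == y or abs(x - y) <= atol + rtol * abs(y)
        if not close:
            mismatched += 1
    ratio = mismatched / max(len(a), 1)
    if ratio > max_mismatched_ratio:
        raise AssertionError(
            f"{mismatched}/{len(a)} elements mismatched "
            f"(ratio {ratio:.4%} > {max_mismatched_ratio:.4%})")
    return ratio

ref = [1.0] * 1000
out = list(ref)
out[0] = 1.5                          # one outlier out of 1000 elements
print(assert_mostly_close(out, ref))  # 0.001: exactly at the limit, so it passes
```

With real tensors, the documented call would be torch_assert_close(out, ref, rtol=1e-2, atol=1e-3, max_mismatched_ratio=1e-3). Whether it applies exactly this per-element criterion is an assumption of the sketch.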