tilelang.utils.tensor module
- class tilelang.utils.tensor.TensorSupplyType(value)
Bases: Enum
An enumeration of tensor supply types, accepted by get_tensor_supply.
- Auto = 7
- Integer = 1
- Normal = 3
- One = 6
- Randn = 4
- Uniform = 2
- Zero = 5
- tilelang.utils.tensor.adapt_torch2tvm(arg)
- tilelang.utils.tensor.get_tensor_supply(supply_type: TensorSupplyType)
- tilelang.utils.tensor.map_torch_type(intype: str) -> torch.dtype
- tilelang.utils.tensor.torch_assert_close(tensor_a, tensor_b, rtol=0.01, atol=0.001, max_mismatched_ratio=0.001, verbose: bool = False, equal_nan: bool = True, check_device: bool = True, check_dtype: bool = True, check_layout: bool = True, check_stride: bool = False)
Assert that two tensors are close within the given tolerances, allowing up to a specified ratio of elements to mismatch.
Parameters:
- tensor_a : torch.Tensor
The first tensor to compare.
- tensor_b : torch.Tensor
The second tensor to compare.
- rtol : float, optional
Relative tolerance for comparison. Default is 1e-2.
- atol : float, optional
Absolute tolerance for comparison. Default is 1e-3.
- max_mismatched_ratio : float, optional
Maximum ratio of mismatched elements allowed (relative to the total number of elements). Default is 0.001 (0.1% of total elements).
Raises:
- AssertionError
If the ratio of mismatched elements exceeds max_mismatched_ratio.
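The mismatch-ratio rule described for torch_assert_close can be illustrated with a small pure-Python sketch. This is not tilelang's implementation: the helper name `mismatch_ratio_check` is hypothetical, it works on flat lists of floats rather than tensors, and it assumes the elementwise criterion `|a - b| <= atol + rtol * |b|` (the convention used by `torch.isclose`), which the documentation above does not confirm.

```python
import math


def mismatch_ratio_check(a, b, rtol=1e-2, atol=1e-3,
                         max_mismatched_ratio=0.001, equal_nan=True):
    """Hypothetical helper: return the mismatch ratio, or raise if it is too high.

    Assumes the closeness test |x - y| <= atol + rtol * |y| per element.
    """
    if len(a) != len(b):
        raise AssertionError(f"length mismatch: {len(a)} vs {len(b)}")

    mismatched = 0
    for x, y in zip(a, b):
        if math.isnan(x) or math.isnan(y):
            # Two NaNs count as equal only when equal_nan is set.
            close = equal_nan and math.isnan(x) and math.isnan(y)
        else:
            # x == y covers equal infinities, where x - y would be NaN.
            close = x == y or abs(x - y) <= atol + rtol * abs(y)
        if not close:
            mismatched += 1

    total = len(a)
    ratio = mismatched / total if total else 0.0
    if ratio > max_mismatched_ratio:
        raise AssertionError(
            f"{mismatched}/{total} elements mismatched "
            f"(ratio {ratio:.6f} > allowed {max_mismatched_ratio})"
        )
    return ratio


# Usage: one outlier in 2000 elements stays within the default 0.1% budget.
reference = [1.0] * 2000
candidate = list(reference)
candidate[0] = 2.0
mismatch_ratio_check(reference, candidate)
```

With the defaults, a comparison over 2000 elements tolerates at most two mismatched elements; a third pushes the ratio to 0.0015 and triggers the AssertionError.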