tilelang.utils.tensor
=====================

.. py:module:: tilelang.utils.tensor


Classes
-------

.. autoapisummary::

   tilelang.utils.tensor.TensorSupplyType


Functions
---------

.. autoapisummary::

   tilelang.utils.tensor.map_torch_type
   tilelang.utils.tensor.adapt_torch2tvm
   tilelang.utils.tensor.get_tensor_supply
   tilelang.utils.tensor.torch_assert_close


Module Contents
---------------

.. py:class:: TensorSupplyType

   Bases: :py:obj:`enum.Enum`


   Generic enumeration.

   Derive from this class to define new enumerations.


   .. py:attribute:: Integer
      :value: 1


   .. py:attribute:: Uniform
      :value: 2


   .. py:attribute:: Normal
      :value: 3


   .. py:attribute:: Randn
      :value: 4


   .. py:attribute:: Zero
      :value: 5


   .. py:attribute:: One
      :value: 6


   .. py:attribute:: Auto
      :value: 7


.. py:function:: map_torch_type(intype)

.. py:function:: adapt_torch2tvm(arg)

.. py:function:: get_tensor_supply(supply_type = TensorSupplyType.Integer)

.. py:function:: torch_assert_close(tensor_a, tensor_b, rtol=0.01, atol=0.001, max_mismatched_ratio=0.001, verbose = False, equal_nan = True, check_device = True, check_dtype = True, check_layout = True, check_stride = False, base_name = 'LHS', ref_name = 'RHS')

   Custom function to assert that two tensors are "close enough," allowing a specified percentage of mismatched elements.

   Parameters:
   ----------
   tensor_a : torch.Tensor
       The first tensor to compare.
   tensor_b : torch.Tensor
       The second tensor to compare.
   rtol : float, optional
       Relative tolerance for comparison. Default is 1e-2.
   atol : float, optional
       Absolute tolerance for comparison. Default is 1e-3.
   max_mismatched_ratio : float, optional
       Maximum ratio of mismatched elements allowed (relative to the total number of elements).
       Default is 0.001 (0.1% of total elements).

   Raises:
   -------
   AssertionError:
       If the ratio of mismatched elements exceeds ``max_mismatched_ratio``.
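
The mismatch-ratio check described for ``torch_assert_close`` can be illustrated with a minimal pure-Python sketch. This is not the tilelang implementation. The helper name ``assert_close_with_ratio`` is hypothetical, it works on flat lists of floats rather than ``torch.Tensor`` objects, and it assumes the per-element test is ``|a - b| <= atol + rtol * |b|`` (the convention used by ``torch.isclose``), which the documentation above does not confirm.

```python
import math


def assert_close_with_ratio(a, b, rtol=1e-2, atol=1e-3,
                            max_mismatched_ratio=0.001, equal_nan=True):
    """Hypothetical sketch: tolerate a bounded fraction of mismatched elements.

    Returns the observed mismatch ratio, or raises AssertionError when it
    exceeds max_mismatched_ratio.
    """
    if len(a) != len(b):
        raise AssertionError(f"length mismatch: {len(a)} vs {len(b)}")

    mismatched = 0
    for x, y in zip(a, b):
        if math.isnan(x) or math.isnan(y):
            # Two NaNs count as equal only when equal_nan is set.
            ok = equal_nan and math.isnan(x) and math.isnan(y)
        else:
            # Assumed closeness rule (same form as torch.isclose);
            # x == y also covers matching infinities.
            ok = x == y or abs(x - y) <= atol + rtol * abs(y)
        if not ok:
            mismatched += 1

    total = len(a)
    ratio = mismatched / total if total else 0.0
    if ratio > max_mismatched_ratio:
        raise AssertionError(
            f"{mismatched} of {total} elements mismatched "
            f"(ratio {ratio:.6f} > allowed {max_mismatched_ratio})"
        )
    return ratio
```

With the default ``max_mismatched_ratio`` of 0.001, a comparison over 2000 elements would tolerate at most 2 mismatched elements under this sketch; a third mismatch raises ``AssertionError``.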