tilelang.utils.sparse
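Attributes
- compress_util
Functions
- compress_sm90(A, block_k, transposed)
- compress_sm80(A, transposed)
- compress(A, transposed[, arch]): Compress a tensor using the appropriate method based on the CUDA architecture.
- randn_semi_sparse(M, K[, dtype, device, transposed]): Generate a random semi-sparse tensor. The generated tensor will have 2:4 sparsity along the K dimension.
- arange_semi_sparse(M, K[, dtype, device, transposed]): Generate a semi-sparse tensor with values from 0 to M*K-1. The generated tensor will have 2:4 sparsity along the K dimension.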
Module Contents
- tilelang.utils.sparse.compress_util
- tilelang.utils.sparse.compress_sm90(A, block_k, transposed)
- Parameters:
A (torch.Tensor)
block_k (int)
transposed (bool)
- Return type:
tuple[torch.Tensor, torch.Tensor]
- tilelang.utils.sparse.compress_sm80(A, transposed)
- Parameters:
A (torch.Tensor)
transposed (bool)
- Return type:
tuple[torch.Tensor, torch.Tensor]
- tilelang.utils.sparse.compress(A, transposed, arch=None, **kwargs)
Compress a tensor using the appropriate method based on the CUDA architecture.
- Parameters:
A (torch.Tensor)
transposed (bool)
arch (str | None)
- Return type:
tuple[torch.Tensor, torch.Tensor]
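To make the returned pair more concrete, the following is a minimal pure-Python sketch of 2:4 compression along K. It assumes the two returned tensors play the roles of "kept values" and "per-group position metadata". The helpers compress_2to4_reference and decompress_2to4_reference are hypothetical and not part of tilelang. The real metadata produced by compress_sm80 / compress_sm90 is packed into a hardware-specific bit layout that this sketch does not reproduce. The sketch also does not model the transposed and arch arguments.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def compress_2to4_reference(a: Matrix) -> Tuple[Matrix, List[List[int]]]:
    """Hypothetical reference for 2:4 compression along K (the last axis).

    For every group of 4 consecutive elements in a row, keep 2 entries and
    record their in-group positions (0..3). Real SM80/SM90 metadata packs
    these positions into bit fields; this sketch keeps them as plain ints.
    """
    values: Matrix = []
    meta: List[List[int]] = []
    for row in a:
        if len(row) % 4 != 0:
            raise ValueError("K must be a multiple of 4")
        row_vals: List[float] = []
        row_meta: List[int] = []
        for g in range(0, len(row), 4):
            group = row[g:g + 4]
            nonzero = [i for i, v in enumerate(group) if v != 0]
            if len(nonzero) > 2:
                raise ValueError(f"group at column {g} has more than 2 nonzeros")
            # Pad with zero-valued slots so every group contributes exactly 2 entries.
            keep = sorted((nonzero + [i for i in range(4) if i not in nonzero])[:2])
            row_vals.extend(group[i] for i in keep)
            row_meta.extend(keep)
        values.append(row_vals)
        meta.append(row_meta)
    return values, meta


def decompress_2to4_reference(values: Matrix, meta: List[List[int]]) -> Matrix:
    """Inverse of compress_2to4_reference: scatter kept values back into dense rows."""
    dense: Matrix = []
    for row_vals, row_meta in zip(values, meta):
        row = [0.0] * (len(row_vals) * 2)  # K = 2 * (K / 2)
        for j, (v, pos) in enumerate(zip(row_vals, row_meta)):
            row[(j // 2) * 4 + pos] = v  # j // 2 is the group index
        dense.append(row)
    return dense
```

For a row [0, 1.5, 0, -2.0, 3.0, 0, 0, 4.0], this sketch yields values [1.5, -2.0, 3.0, 4.0] and positions [1, 3, 0, 3]. Padding groups that have fewer than two nonzeros keeps the compressed width fixed at K/2, which matches the fixed 50% value storage of 2:4 formats.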
- tilelang.utils.sparse.randn_semi_sparse(M, K, dtype=torch.float16, device='cuda', transposed=False)
Generate a random semi-sparse tensor. The generated tensor has 2:4 sparsity along the K dimension: in every group of 4 consecutive elements along K, at most 2 are nonzero.
- Parameters:
M (int): Number of rows.
K (int): Number of columns.
dtype: Data type of the tensor.
device: Device to create the tensor on.
transposed (bool): If True, returns a transposed tensor of shape (K, M); otherwise the shape is (M, K).
- Return type:
torch.Tensor
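The 2:4 property described above can be illustrated and checked with a short pure-Python sketch. It uses plain lists instead of torch so that it runs without a GPU. randn_semi_sparse_reference and is_2to4_sparse are hypothetical helpers, not tilelang functions. The real function's choice of which two positions per group to prune is not documented here, so this sketch picks them at random.

```python
import random
from typing import List, Optional


def randn_semi_sparse_reference(m: int, k: int, transposed: bool = False,
                                seed: Optional[int] = None) -> List[List[float]]:
    """Hypothetical stand-in for randn_semi_sparse (not tilelang code).

    Draws standard-normal values, then zeroes 2 randomly chosen entries in
    every group of 4 along K, so each group keeps at most 2 nonzeros.
    """
    if k % 4 != 0:
        raise ValueError("K must be a multiple of 4 for 2:4 sparsity")
    rng = random.Random(seed)
    rows = []
    for _ in range(m):
        row = [rng.gauss(0.0, 1.0) for _ in range(k)]
        for g in range(0, k, 4):
            for i in rng.sample(range(4), 2):
                row[g + i] = 0.0
        rows.append(row)
    if transposed:
        # Shape (K, M): the 2:4 groups now run down each column.
        return [list(col) for col in zip(*rows)]
    return rows


def is_2to4_sparse(a: List[List[float]], along_rows: bool = True) -> bool:
    """Return True if every group of 4 consecutive elements has at most 2 nonzeros.

    along_rows=True checks groups within each row (shape (M, K));
    along_rows=False checks groups within each column (shape (K, M)).
    """
    lines = a if along_rows else [list(col) for col in zip(*a)]
    for line in lines:
        if len(line) % 4 != 0:
            return False
        for g in range(0, len(line), 4):
            if sum(1 for v in line[g:g + 4] if v != 0) > 2:
                return False
    return True
```

A tensor returned by the real randn_semi_sparse could be checked the same way after converting it with `t.tolist()`, passing `along_rows=False` when it was created with `transposed=True`.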
- tilelang.utils.sparse.arange_semi_sparse(M, K, dtype=torch.float16, device='cuda', transposed=False)
Generate a semi-sparse tensor with values from 0 to M*K-1. The generated tensor has 2:4 sparsity along the K dimension: in every group of 4 consecutive elements along K, at most 2 are nonzero.
- Parameters:
M (int): Number of rows.
K (int): Number of columns.
dtype: Data type of the tensor.
device: Device to create the tensor on.
transposed (bool): If True, returns a transposed tensor of shape (K, M); otherwise the shape is (M, K).
- Return type:
torch.Tensor
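Deterministic values are useful for debugging sparse kernels because each surviving element encodes its own original index. The sketch below illustrates this in pure Python. It assumes the values are laid out row-major over an (M, K) grid before masking and that the result is transposed afterwards when `transposed=True`. It also assumes slots 0 and 1 of each group are kept, which is an arbitrary choice for illustration. arange_semi_sparse_reference and values_encode_positions are hypothetical helpers, not tilelang functions.

```python
from typing import List


def arange_semi_sparse_reference(m: int, k: int, transposed: bool = False) -> List[List[int]]:
    """Hypothetical stand-in for arange_semi_sparse (not tilelang code).

    Fills an (M, K) grid with 0..M*K-1 in row-major order, then zeroes slots
    2 and 3 of every group of 4 along K. Which two slots the real function
    keeps is an assumption of this sketch.
    """
    if k % 4 != 0:
        raise ValueError("K must be a multiple of 4 for 2:4 sparsity")
    rows = [[r * k + c if c % 4 < 2 else 0 for c in range(k)] for r in range(m)]
    if transposed:
        # Shape (K, M): element (c, r) holds the value generated at (r, c).
        return [list(col) for col in zip(*rows)]
    return rows


def values_encode_positions(a: List[List[int]], k: int, transposed: bool = False) -> bool:
    """Check that every nonzero value v sits at its original index divmod(v, K)."""
    for i, line in enumerate(a):
        for j, v in enumerate(line):
            if v == 0:
                continue  # pruned slot, or the kept element at (0, 0) whose value is 0
            expected = divmod(v, k)
            actual = (j, i) if transposed else (i, j)
            if expected != actual:
                return False
    return True
```

Because the range starts at 0, the kept element at original index (0, 0) is indistinguishable from a pruned zero. A test that counts nonzeros per group should allow for that one case.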