tilelang.language.utils
Functions
| index_to_coordinates(index, shape) | Convert a flat (linear) index into multi-dimensional coordinates for a given shape. |
| linear_index(*args) | Compute a flat (linear) index from multi-dimensional coordinates and strides. |
Module Contents
- tilelang.language.utils.index_to_coordinates(index, shape)
Convert a flat (linear) index into multi-dimensional coordinates for a given shape.
Given a linear index and a shape (a sequence of dimension extents), returns a list of coordinates, one per dimension, such that converting those coordinates back to a linear index with the usual row-major (C-order) formula yields the original index. The computation iterates from the last dimension to the first, taking the index modulo each extent and then integer-dividing it by that extent, and finally reverses the collected coordinates.
- Parameters:
index (int or PrimExpr) – The flat index to convert.
shape (Sequence[int]) – The extents of each dimension (length >= 1).
- Returns:
Coordinates for each dimension in the same order as shape.
- Return type:
list[PrimExpr]
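The modulo and integer-division loop described above can be sketched in plain Python. This is a minimal illustration that operates on ordinary int values rather than PrimExpr; the name index_to_coordinates_int is hypothetical and is not part of tilelang.

```python
from typing import List, Sequence


def index_to_coordinates_int(index: int, shape: Sequence[int]) -> List[int]:
    """Hypothetical plain-int sketch of row-major unflattening."""
    coords: List[int] = []
    # Walk from the last (fastest-varying) dimension to the first.
    for extent in reversed(shape):
        coords.append(index % extent)  # coordinate within this dimension
        index = index // extent        # carry the rest to the outer dimensions
    # Coordinates were collected innermost-first, so restore shape order.
    coords.reverse()
    return coords


# Round-trip check against the row-major formula for shape (2, 3, 4).
shape = (2, 3, 4)
for flat in range(2 * 3 * 4):
    i, j, k = index_to_coordinates_int(flat, shape)
    assert i * 3 * 4 + j * 4 + k == flat

print(index_to_coordinates_int(23, shape))  # [1, 2, 3]
```

The round-trip loop checks exactly the property the description promises: re-linearizing the returned coordinates in row-major order gives back the original index.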
- tilelang.language.utils.linear_index(*args)
Compute a flat (linear) index from multi-dimensional coordinates and strides.
The function accepts a sequence of PrimExpr arguments: the leading arguments are the coordinates and the trailing arguments are the corresponding strides. The number of strides must equal the number of coordinates minus one, so a call with n coordinates takes 2n - 1 arguments. The linear index is computed as:
linear = coords[0]
for each (coord, stride) in zip(coords[1:], strides):
    linear = linear * stride + coord
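A minimal plain-Python sketch of this rule, operating on ordinary int values instead of PrimExpr. The name linear_index_int is hypothetical, and the argument split it uses (the first (len(args) + 1) // 2 arguments are coordinates) is an assumption derived from the stride-count rule, not taken from tilelang's source.

```python
def linear_index_int(*args: int) -> int:
    """Hypothetical plain-int sketch: n coordinates followed by n - 1 strides."""
    n = len(args)
    if n == 0:
        raise ValueError("at least one coordinate is required")
    # n coordinates plus (n - 1) strides means the total count must be odd.
    if n % 2 == 0:
        raise ValueError("number of strides must be one less than number of coordinates")
    num_coords = (n + 1) // 2
    coords, strides = args[:num_coords], args[num_coords:]
    # Horner-style accumulation: scale the running index, then add the next coordinate.
    linear = coords[0]
    for coord, stride in zip(coords[1:], strides):
        linear = linear * stride + coord
    return linear


# Mirrors linear_index(i, j, k, stride_j, stride_k) with i=1, j=2, k=3, stride_j=3, stride_k=4.
print(linear_index_int(1, 2, 3, 3, 4))  # 1*3*4 + 2*4 + 3 = 23
```

Accumulating in Horner form means each stride is used once per step, which is why the stride list is one shorter than the coordinate list: the first coordinate never needs a stride of its own.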
Examples
linear_index(i) -> i
linear_index(i, j) -> raises ValueError (two arguments cannot be split into n coordinates and n - 1 strides; pass the stride for j explicitly, as in the next example)
linear_index(i, j, stride_j) -> i * stride_j + j
linear_index(i, j, k, stride_j, stride_k) -> i*stride_j*stride_k + j*stride_k + k
linear_index(i, tx, v, threads, local_size) -> i*threads*local_size + tx*local_size + v
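Unrolling the recurrence gives the closed form behind the expanded examples above. Writing the coordinates as c_0, ..., c_{n-1} and the stride paired with c_m as s_m (for m >= 1, following zip(coords[1:], strides)), a short derivation:

$$
L_0 = c_0, \qquad L_m = L_{m-1}\,s_m + c_m \quad (1 \le m \le n-1)
$$

$$
\text{linear} = L_{n-1} = \sum_{m=0}^{n-1} c_m \prod_{p=m+1}^{n-1} s_p,
\qquad n = 3:\; c_0 s_1 s_2 + c_1 s_2 + c_2
$$

For n = 3 this is i*stride_j*stride_k + j*stride_k + k, matching the example. When each s_m is set to the extent of dimension m, the sum is the row-major formula that index_to_coordinates inverts, so for an index within the shape's bounds, passing shape[1:] as the strides should recover the original flat index.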
- Raises:
ValueError – If called with no arguments, or if the number of strides is not one less than the number of coordinates.
- Returns:
The computed linear index expression.
- Return type:
PrimExpr
- Parameters:
args (tvm.tir.PrimExpr)