tilelang.language.utils

Functions

index_to_coordinates(index, shape)

Convert a flat (linear) index into multi-dimensional coordinates for a given shape.

linear_index(*args)

Compute a flat (linear) index from multi-dimensional coordinates and strides.

Module Contents

tilelang.language.utils.index_to_coordinates(index, shape)

Convert a flat (linear) index into multi-dimensional coordinates for a given shape.

Given a linear index and a shape (a sequence of dimension extents), this function returns a list of coordinates, one per dimension, such that converting them back to a linear index with the usual row-major (C-order) formula yields the original index. The computation iterates from the last dimension to the first, taking the index modulo each extent and then integer-dividing by it, and finally reverses the collected coordinates.

Parameters:
  • index (int or PrimExpr) – The flat index to convert.

  • shape (Sequence[int]) – The extents of each dimension (length >= 1).

Returns:

Coordinates for each dimension in the same order as shape.

Return type:

list[PrimExpr]
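The algorithm described above can be sketched in plain Python on integers. This is a minimal illustration, assuming concrete int inputs rather than PrimExpr; the name index_to_coordinates_int is hypothetical and is not part of tilelang.language.utils.

```python
from typing import List, Sequence


def index_to_coordinates_int(index: int, shape: Sequence[int]) -> List[int]:
    """Hypothetical int-only version of the row-major unflattening described above."""
    coords = []
    # Walk from the last (fastest-varying) dimension to the first.
    for extent in reversed(shape):
        coords.append(index % extent)  # position within this dimension
        index //= extent  # what remains belongs to the outer dimensions
    # Coordinates were collected innermost-first, so restore shape order.
    coords.reverse()
    return coords


# For shape (2, 3, 4): 23 = 1*12 + 2*4 + 3
print(index_to_coordinates_int(23, (2, 3, 4)))  # [1, 2, 3]
```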

tilelang.language.utils.linear_index(*args)

Compute a flat (linear) index from multi-dimensional coordinates and strides.

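The function accepts a sequence of PrimExpr arguments: the leading arguments are the coordinates and the trailing arguments are the corresponding strides. The number of strides must be exactly one less than the number of coordinates, so a valid call always has an odd number of arguments. The linear index is computed as:

    linear = coords[0]
    for coord, stride in zip(coords[1:], strides):
        linear = linear * stride + coord

Because each stride scales the running index before the next coordinate is added, each "stride" acts as the extent of the corresponding dimension, and the result is the row-major offset of the coordinates.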
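A minimal plain-integer sketch of the argument split and the fold above. The split point ((n + 1) // 2 leading coordinates) follows from the count rule; the function name linear_index_int and the error messages are hypothetical, and this is not the tilelang source.

```python
def linear_index_int(*args: int) -> int:
    """Hypothetical int-only version of the coordinate/stride fold described above."""
    if not args:
        raise ValueError("at least one coordinate is required")
    # n coordinates followed by n - 1 strides gives 2n - 1 arguments in total.
    num_coords = (len(args) + 1) // 2
    coords, strides = args[:num_coords], args[num_coords:]
    if len(strides) != len(coords) - 1:
        raise ValueError("number of strides must be one less than number of coordinates")
    linear = coords[0]
    # Horner-style fold: scale the running index, then add the next coordinate.
    for coord, stride in zip(coords[1:], strides):
        linear = linear * stride + coord
    return linear


print(linear_index_int(1, 2, 3, 3, 4))  # 1*3*4 + 2*4 + 3 = 23
```

Written this way, the fold costs one multiply and one add per extra dimension (Horner's rule) instead of precomputing products of strides.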

Examples

  • linear_index(i) -> i

  • linear_index(i, j) -> raises ValueError: two arguments cannot be split into n coordinates and n - 1 strides, so the stride for j must be passed explicitly, as in the next example

  • linear_index(i, j, stride_j) -> i * stride_j + j

  • linear_index(i, j, k, stride_j, stride_k) -> i*stride_j*stride_k + j*stride_k + k

  • linear_index(i, tx, v, threads, local_size) -> i*threads*local_size + tx*local_size + v
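The last example can be checked numerically. The sketch below uses small made-up sizes (iters, threads and local_size are illustrative values, not anything defined by tilelang) and confirms in plain Python that the expansion i*threads*local_size + tx*local_size + v visits every offset in [0, iters * threads * local_size) exactly once, with v varying fastest.

```python
from itertools import product

# Small illustrative sizes, chosen only for this check.
iters, threads, local_size = 2, 4, 3

offsets = [
    # Expansion of linear_index(i, tx, v, threads, local_size).
    i * threads * local_size + tx * local_size + v
    # product() varies its last argument fastest: v, then tx, then i.
    for i, tx, v in product(range(iters), range(threads), range(local_size))
]

# Every offset appears exactly once, in increasing order.
print(offsets == list(range(iters * threads * local_size)))  # True
```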

Raises:

ValueError – If called with no arguments, or if the number of strides is not one less than the number of coordinates.

Returns:

The computed linear index expression.

Return type:

PrimExpr

Parameters:
  • args (tvm.tir.PrimExpr) – The coordinates, followed by the strides.
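index_to_coordinates and linear_index are inverses when the strides passed to linear_index are the trailing extents shape[1:]. The self-contained sketch below checks this for every index of a small shape; to_coords and to_linear are hypothetical int-only re-implementations of the behavior documented above, not the tilelang functions themselves.

```python
import math
from typing import List, Sequence


def to_coords(index: int, shape: Sequence[int]) -> List[int]:
    # Row-major unflattening: peel dimensions off starting from the last one.
    coords = []
    for extent in reversed(shape):
        coords.append(index % extent)
        index //= extent
    return coords[::-1]


def to_linear(*args: int) -> int:
    # The first (n + 1) // 2 arguments are coordinates, the rest are strides.
    n = (len(args) + 1) // 2
    coords, strides = args[:n], args[n:]
    linear = coords[0]
    for coord, stride in zip(coords[1:], strides):
        linear = linear * stride + coord
    return linear


shape = (2, 3, 4)
# Passing shape[1:] as the strides undoes index -> coordinates.
round_trip_ok = all(
    to_linear(*to_coords(idx, shape), *shape[1:]) == idx
    for idx in range(math.prod(shape))
)
print(round_trip_ok)  # True
```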