tilelang.language.utils

Functions

region(buffer, access_type, *args)

Create a tile memory-region descriptor for a BufferLoad.

buffer_to_tile_region(buffer, access_type)

Convert a TVM buffer to a tile region descriptor.

buffer_load_to_tile_region(load, access_type, extents)

Convert a buffer load operation to a tile region descriptor.

buffer_region_to_tile_region(buffer_region, ...)

Convert a buffer region to a tile region descriptor.

index_to_coordinates(index, shape)

Convert a flat (linear) index into multi-dimensional coordinates for a given shape.

linear_index(*args)

Compute a flat (linear) index from multi-dimensional coordinates and strides.

Module Contents

tilelang.language.utils.region(buffer, access_type, *args)

Create a tile memory-region descriptor for a BufferLoad.

Maps access_type ('r', 'w', 'rw') to the numeric codes expected by the tl.region intrinsic (1, 2, 3 respectively) and returns a tir.Call representing the region with the provided extents.

Parameters:
  • buffer (tir.BufferLoad) – The BufferLoad that identifies the underlying buffer and indices.

  • access_type (str) – One of 'r', 'w', or 'rw', indicating read, write, or read-write access.

  • *args (tir.PrimExpr) – Extent expressions for each region dimension.

Returns:

A call to the tl.region intrinsic describing the memory region.

Return type:

tir.Call

Raises:

KeyError – If access_type is not one of 'r', 'w', or 'rw'.
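The access-type mapping described for region can be modelled in plain Python. This is a minimal sketch, not the library's code: the names ACCESS_CODES and access_code are hypothetical, and only the 'r'/'w'/'rw' to 1/2/3 mapping and the KeyError behaviour are taken from the description above.

```python
# Hypothetical stand-in for the lookup that region() is described as performing.
# The codes 1, 2, 3 are the values the text says tl.region expects.
ACCESS_CODES = {"r": 1, "w": 2, "rw": 3}


def access_code(access_type: str) -> int:
    """Return the numeric code for an access type.

    A plain dict lookup raises KeyError for anything other than
    'r', 'w', or 'rw', which matches the documented failure mode.
    """
    return ACCESS_CODES[access_type]
```

A dict lookup is one simple way to get the documented KeyError for free; whether the library does exactly this is an assumption.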

tilelang.language.utils.buffer_to_tile_region(buffer, access_type)

Convert a TVM buffer to a tile region descriptor.

Parameters:
  • buffer (tir.Buffer) – The buffer to convert

  • access_type (str) – Type of access: 'r' for read, 'w' for write, 'rw' for read-write

Returns:

A region descriptor covering the entire buffer

Return type:

tir.Call

tilelang.language.utils.buffer_load_to_tile_region(load, access_type, extents)

Convert a buffer load operation to a tile region descriptor.

Parameters:
  • load (tir.BufferLoad) – The buffer load operation

  • access_type (str) – Type of access: 'r' for read, 'w' for write, 'rw' for read-write

  • extents (List[tir.PrimExpr]) – List of expressions defining the region size

Returns:

A region descriptor for the loaded area

Return type:

tir.Call

tilelang.language.utils.buffer_region_to_tile_region(buffer_region, access_type, extents)

Convert a buffer region to a tile region descriptor.

Parameters:
  • buffer_region (tir.BufferRegion) – The buffer region to convert

  • access_type (str) – Type of access: 'r' for read, 'w' for write, 'rw' for read-write

  • extents (List[tvm.tir.PrimExpr]) – List of expressions defining the region size

Returns:

A region descriptor for the specified buffer region

Return type:

tir.Call

tilelang.language.utils.index_to_coordinates(index, shape)

Convert a flat (linear) index into multi-dimensional coordinates for a given shape.

Given a linear index and a shape (sequence of dimension extents), returns a list of coordinates (one per dimension) such that converting those coordinates back to a linear index using the usual row-major / C-order formula yields the original index. The computation iterates from the last dimension to the first using modulo and integer division, then reverses the collected coordinates.

Parameters:
  • index (int or PrimExpr) – The flat index to convert.

  • shape (Sequence[int]) – The extents of each dimension (length >= 1).

Returns:

Coordinates for each dimension in the same order as shape.

Return type:

List[PrimExpr]
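The modulo-and-divide procedure described for index_to_coordinates can be illustrated on plain integers. This is a sketch under stated assumptions: index_to_coordinates_sketch is a hypothetical name, it works on Python ints only, and it does not build the PrimExpr values the real function returns.

```python
from typing import List, Sequence


def index_to_coordinates_sketch(index: int, shape: Sequence[int]) -> List[int]:
    """Unflatten a row-major linear index into per-dimension coordinates."""
    coords = []
    # Walk from the last (fastest-varying) dimension to the first.
    for extent in reversed(shape):
        coords.append(index % extent)  # position within this dimension
        index //= extent               # carry the remainder to the next one
    # Coordinates were collected last-to-first, so restore shape order.
    coords.reverse()
    return coords


# For shape (2, 3, 4), index 17 = 1*12 + 1*4 + 1.
coords = index_to_coordinates_sketch(17, (2, 3, 4))  # [1, 1, 1]
```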

tilelang.language.utils.linear_index(*args)

Compute a flat (linear) index from multi-dimensional coordinates and strides.

The function accepts a sequence of PrimExpr arguments: the leading arguments are coordinates and the trailing arguments are strides. The number of strides must be one less than the number of coordinates, so that each stride pairs with one of the second and later coordinates. The linear index is computed as:

    linear = coords[0]
    for coord, stride in zip(coords[1:], strides):
        linear = linear * stride + coord

Examples

  • linear_index(i) -> i

  • linear_index(i, j) -> i * j_stride + j is only the intended formula: two coordinates need one stride, so by the rule above the stride must be passed explicitly, as in the next example

  • linear_index(i, j, stride_j) -> i * stride_j + j

  • linear_index(i, j, k, stride_j, stride_k) -> i*stride_j*stride_k + j*stride_k + k

  • linear_index(i, tx, v, threads, local_size) -> i*threads*local_size + tx*local_size + v

Raises:

ValueError – If called with no arguments, or if the number of strides is not one less than the number of coordinates.

Returns:

The computed linear index expression.

Return type:

PrimExpr

Parameters:

args (tvm.tir.PrimExpr) – The coordinates, followed by one stride for each coordinate after the first.
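The recurrence for linear_index can be checked on plain integers. This is a hedged sketch, not the library's implementation: linear_index_sketch is a hypothetical name, it takes Python ints instead of PrimExpr, and the way it splits the arguments into coordinates and strides (k coordinates followed by k - 1 strides) is an assumption inferred from the stated rule.

```python
def linear_index_sketch(*args: int) -> int:
    """Flatten coordinates using trailing strides, on plain integers."""
    if not args:
        raise ValueError("at least one coordinate is required")
    # Assumed split: k coordinates followed by k - 1 strides,
    # which means a valid call has 2k - 1 arguments.
    num_coords = (len(args) + 1) // 2
    coords, strides = args[:num_coords], args[num_coords:]
    if len(strides) != len(coords) - 1:
        raise ValueError("stride count must be one less than coordinate count")
    # Horner-style accumulation, as in the documented recurrence.
    linear = coords[0]
    for coord, stride in zip(coords[1:], strides):
        linear = linear * stride + coord
    return linear


# i=1, j=2, k=3 with stride_j=10, stride_k=100:
# 1*10*100 + 2*100 + 3
flat = linear_index_sketch(1, 2, 3, 10, 100)  # 1203
```

Under this split, an even number of arguments can never satisfy the stride rule, so the sketch raises ValueError for a call such as linear_index_sketch(1, 2).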