tilelang.language.utils
Functions
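region(buffer, access_type, *args) | Create a tile memory-region descriptor for a BufferLoad.
buffer_to_tile_region(buffer, access_type) | Convert a TVM buffer to a tile region descriptor.
buffer_load_to_tile_region(load, access_type, extents) | Convert a buffer load operation to a tile region descriptor.
buffer_region_to_tile_region(buffer_region, access_type, extents) | Convert a buffer region to a tile region descriptor.
index_to_coordinates(index, shape) | Convert a flat (linear) index into multi-dimensional coordinates for a given shape.
linear_index(*args) | Compute a flat (linear) index from multi-dimensional coordinates and strides.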
Module Contents
- tilelang.language.utils.region(buffer, access_type, *args)
Create a tile memory-region descriptor for a BufferLoad.
Maps access_type ('r', 'w', 'rw') to the numeric codes expected by the tl.region intrinsic (1, 2, 3 respectively) and returns a tir.Call representing the region with the provided extents.
- Parameters:
buffer (tir.BufferLoad) – The BufferLoad that identifies the underlying buffer and indices.
access_type (str) – One of 'r', 'w', or 'rw', indicating read, write, or read-write access.
*args (tir.PrimExpr) – Extent expressions for each region dimension.
- Returns:
A call to the tl.region intrinsic describing the memory region.
- Return type:
tir.Call
- Raises:
KeyError – If access_type is not one of 'r', 'w', or 'rw'.
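The access-type lookup described above can be sketched in plain Python. This is an illustrative stand-in, not tilelang's implementation: `ACCESS_CODES` and `access_code` are hypothetical names, and the real `region` returns a `tir.Call` rather than a bare integer.

```python
# Hypothetical stand-in for the access_type lookup described above;
# the real function passes the code on to the tl.region intrinsic.
ACCESS_CODES = {"r": 1, "w": 2, "rw": 3}


def access_code(access_type: str) -> int:
    """Return the numeric code for an access type; unknown keys raise KeyError."""
    return ACCESS_CODES[access_type]


print(access_code("rw"))  # 3
try:
    access_code("x")
except KeyError as err:
    print("unsupported access type:", err)
```

A dictionary lookup gives the documented `KeyError` behaviour without any explicit validation code.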
- tilelang.language.utils.buffer_to_tile_region(buffer, access_type)
Convert a TVM buffer to a tile region descriptor.
- Parameters:
buffer (tir.Buffer) – The buffer to convert.
access_type (str) – Type of access: 'r' for read, 'w' for write, 'rw' for read-write.
- Returns:
A region descriptor covering the entire buffer
- Return type:
tir.Call
- tilelang.language.utils.buffer_load_to_tile_region(load, access_type, extents)
Convert a buffer load operation to a tile region descriptor.
- Parameters:
load (tir.BufferLoad) – The buffer load operation.
access_type (str) – Type of access: 'r' for read, 'w' for write, 'rw' for read-write.
extents (List[tir.PrimExpr]) – List of expressions defining the region size.
- Returns:
A region descriptor for the loaded area
- Return type:
tir.Call
- tilelang.language.utils.buffer_region_to_tile_region(buffer_region, access_type, extents)
Convert a buffer region to a tile region descriptor.
- Parameters:
buffer_region (tir.BufferRegion) – The buffer region to convert.
access_type (str) – Type of access: 'r' for read, 'w' for write, 'rw' for read-write.
extents (List[tvm.tir.PrimExpr])
- Returns:
A region descriptor for the specified buffer region
- Return type:
tir.Call
- tilelang.language.utils.index_to_coordinates(index, shape)
Convert a flat (linear) index into multi-dimensional coordinates for a given shape.
Given a linear index and a shape (sequence of dimension extents), returns a list of coordinates (one per dimension) such that converting those coordinates back to a linear index using the usual row-major / C-order formula yields the original index. The computation iterates from the last dimension to the first using modulo and integer division, then reverses the collected coordinates.
- Parameters:
index (int or PrimExpr) – The flat index to convert.
shape (Sequence[int]) – The extents of each dimension (length >= 1).
- Returns:
Coordinates for each dimension in the same order as shape.
- Return type:
List[PrimExpr]
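The modulo and integer-division walk described above can be sketched with plain Python integers. `index_to_coordinates_sketch` is a hypothetical name used only for illustration; the library function operates on `PrimExpr` values and may differ in detail.

```python
from typing import List, Sequence


def index_to_coordinates_sketch(index: int, shape: Sequence[int]) -> List[int]:
    """Plain-integer version of the modulo / floor-division walk described above."""
    coords = []
    for extent in reversed(shape):     # last dimension first
        coords.append(index % extent)  # coordinate within this dimension
        index //= extent               # carry the quotient to the next dimension
    coords.reverse()                   # restore the order of `shape`
    return coords


shape = (2, 3, 4)
coords = index_to_coordinates_sketch(17, shape)
print(coords)  # [1, 1, 1]

# Round trip with the row-major formula recovers the original index.
flat = 0
for c, extent in zip(coords, shape):
    flat = flat * extent + c
assert flat == 17
```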
- tilelang.language.utils.linear_index(*args)
Compute a flat (linear) index from multi-dimensional coordinates and strides.
The function accepts a sequence of PrimExpr arguments in which the leading arguments are coordinates and the trailing arguments are the corresponding strides. The number of strides must equal (number of coordinates - 1). The linear index is computed as:
linear = coords[0]
for each (coord, stride) in zip(coords[1:], strides):
    linear = linear * stride + coord
Examples
linear_index(i) -> i
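linear_index(i, j) -> i * j_stride + j only when j_stride is also passed as the trailing stride argument (see the next example); on its own, a two-argument call cannot satisfy the stride-count rule above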
linear_index(i, j, stride_j) -> i * stride_j + j
linear_index(i, j, k, stride_j, stride_k) -> i*stride_j*stride_k + j*stride_k + k
linear_index(i, tx, v, threads, local_size) -> i*threads*local_size + tx*local_size + v
- Raises:
ValueError – If called with no arguments, or if the number of strides is not one less than the number of coordinates.
- Returns:
The computed linear index expression.
- Return type:
PrimExpr
- Parameters:
args (tvm.tir.PrimExpr)
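The coordinates-then-strides rule for `linear_index` can be sketched with plain Python integers. `linear_index_sketch` is a hypothetical name, and the way the argument list is split (the first half, rounded up, is treated as coordinates) is an assumption inferred from the description above, not a confirmed detail of the library.

```python
def linear_index_sketch(*args: int) -> int:
    """Plain-integer version of the coordinates-then-strides rule described above."""
    n = len(args)
    if n == 0:
        raise ValueError("at least one coordinate is required")
    # k coordinates plus k - 1 strides gives 2k - 1 arguments, so n must be odd.
    if n % 2 == 0:
        raise ValueError("stride count must be one less than coordinate count")
    num_coords = (n + 1) // 2  # assumed split point between coordinates and strides
    coords, strides = args[:num_coords], args[num_coords:]
    linear = coords[0]
    for coord, stride in zip(coords[1:], strides):
        linear = linear * stride + coord
    return linear


print(linear_index_sketch(5))              # 5
print(linear_index_sketch(2, 3, 10))       # 2*10 + 3 = 23
print(linear_index_sketch(1, 2, 3, 4, 5))  # 1*4*5 + 2*5 + 3 = 33
```

Because k coordinates need k - 1 strides, any valid call has an odd number of arguments, which is why the sketch rejects even-length argument lists.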