tilelang.language.scan_op

Scan operations exposed on the TileLang language surface.

Functions

cumsum_fragment(src, dst, dim, reverse)

Compute cumulative sum for fragment buffers by copying to shared memory first.

cumsum(src[, dst, dim, reverse])

Compute the cumulative sum of src along dim, writing results to dst.

cummax_fragment(src, dst, dim, reverse)

Compute cumulative maximum for fragment buffers by staging through shared memory.

cummax(src[, dst, dim, reverse])

Compute the cumulative maximum of src along dim, writing results to dst.

Module Contents

tilelang.language.scan_op.cumsum_fragment(src, dst, dim, reverse)

Compute cumulative sum for fragment buffers by copying to shared memory first.

This macro handles cumulative sum operations on fragment buffers by first copying the data to shared memory, performing the cumsum operation, and then copying back.
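The copy, scan, copy-back sequence can be pictured with a small host-side model. This is plain Python rather than TileLang: the lists merely stand in for the fragment and shared memory scopes, and staged_cumsum is a hypothetical name that is not part of the library.

```python
from itertools import accumulate


def staged_cumsum(src_fragment, dst_fragment):
    """Model the three steps, with lists standing in for memory scopes."""
    shared = list(src_fragment)        # 1. copy the fragment into a staging buffer
    shared = list(accumulate(shared))  # 2. run the inclusive scan on the staged copy
    dst_fragment[:] = shared           # 3. copy the scanned values to the destination


src = [1, 2, 3, 4]
dst = [0, 0, 0, 0]
staged_cumsum(src, dst)
print(dst)  # [1, 3, 6, 10]
```

The model only illustrates the data flow; it says nothing about how threads cooperate on the real scan.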

Parameters:
  • src (tilelang._typing.BufferLikeType) – Source buffer (Buffer, BufferRegion, or BufferLoad) containing input data.

  • dst (tilelang._typing.BufferLikeType) – Destination buffer (Buffer, BufferRegion, or BufferLoad) for output data.

  • dim (int) – Dimension along which to compute cumulative sum.

  • reverse (bool) – If True, compute cumulative sum in reverse order.

Return type:

None

tilelang.language.scan_op.cumsum(src, dst=None, dim=0, reverse=False)

Compute the cumulative sum of src along dim, writing results to dst.

Negative dim indices are normalized (Python-style). If dst is None, the operation is performed in place, overwriting src. Raises ValueError when dim is out of bounds for src.shape. When src.scope() == "local.fragment", this delegates to cumsum_fragment; otherwise it emits the tl.cumsum intrinsic.

Supports Buffer, BufferRegion, and BufferLoad inputs, allowing operations on buffer slices/regions.
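The behavior described above (dim normalization, the bounds check, and an inclusive scan along one axis) can be modeled on the host in plain Python. The sketch below is not TileLang code and does not call the tl.cumsum intrinsic. The names normalize_dim, scan_1d, and reference_cumsum are hypothetical, nested lists stand in for a 2D buffer, and the treatment of reverse=True (accumulating from the last element toward the first) is an assumption about the intended semantics.

```python
from itertools import accumulate


def normalize_dim(dim, ndim):
    """Map a possibly negative dim into [0, ndim), Python-style."""
    if not -ndim <= dim < ndim:
        raise ValueError(f"dim {dim} is out of bounds for a {ndim}-D shape")
    return dim % ndim


def scan_1d(values, reverse=False):
    """Inclusive cumulative sum of a 1D sequence."""
    if reverse:
        # Assumed meaning of reverse=True: accumulate from the last element
        # toward the first, so out[i] = values[i] + ... + values[-1].
        return list(accumulate(reversed(values)))[::-1]
    return list(accumulate(values))


def reference_cumsum(src, dim=0, reverse=False):
    """Host-side model of a cumulative sum over a 2D nested list."""
    dim = normalize_dim(dim, 2)
    if dim == 1:
        # Scan each row independently.
        return [scan_1d(row, reverse) for row in src]
    # dim == 0: scan each column, then transpose back to row-major order.
    columns = [scan_1d(list(col), reverse) for col in zip(*src)]
    return [list(row) for row in zip(*columns)]


tile = [[1, 2, 3],
        [4, 5, 6]]
print(reference_cumsum(tile, dim=-1))               # [[1, 3, 6], [4, 9, 15]]
print(reference_cumsum(tile, dim=0, reverse=True))  # [[5, 7, 9], [4, 5, 6]]
```

A model like this can serve as a reference when checking kernel output on small inputs, provided the reverse assumption is first confirmed against the actual kernel.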

Examples

A 1D inclusive scan performed in place on a shared-memory buffer (dst is the same buffer as src):

>>> import tilelang.language as T
>>> @T.prim_func
... def kernel(A: T.Tensor((128,), "float32"), B: T.Tensor((128,), "float32")):
...     with T.Kernel(1, threads=128):
...         A_shared = T.alloc_shared((128,), "float32")
...         T.copy(A, A_shared)
...         T.cumsum(src=A_shared, dst=A_shared, dim=0)
...         T.copy(A_shared, B)

A 2D prefix sum along the last dimension with reverse accumulation:

>>> import tilelang.language as T
>>> @T.prim_func
... def kernel2d(A: T.Tensor((64, 64), "float16"), B: T.Tensor((64, 64), "float16")):
...     with T.Kernel(1, 1, threads=256):
...         tile = T.alloc_shared((64, 64), "float16")
...         T.copy(A, tile)
...         T.cumsum(src=tile, dim=1, reverse=True)
...         T.copy(tile, B)

Operating on a buffer region (slice):

>>> import tilelang.language as T
>>> @T.prim_func
... def kernel_region(InputG_fragment: T.Tensor((128,), "float32"), chunk_size: T.int32):
...     with T.Kernel(1, threads=128):
...         i = T.int32(0)
...         T.cumsum(InputG_fragment[i * chunk_size:(i + 1) * chunk_size], dim=0)

Returns:

A handle to the emitted cumulative-sum operation.

Return type:

tirx.Call

Parameters:
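  • src (tilelang._typing.BufferLikeType) – Source buffer (Buffer, BufferRegion, or BufferLoad) containing input data.

  • dst (tilelang._typing.BufferLikeType | None) – Destination buffer for output data. If None, the result overwrites src.

  • dim (int) – Dimension along which to compute the cumulative sum. Negative values are normalized Python-style.

  • reverse (bool) – If True, compute the cumulative sum in reverse order.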

tilelang.language.scan_op.cummax_fragment(src, dst, dim, reverse)

Compute cumulative maximum for fragment buffers by staging through shared memory.

Parameters:
  • src (tilelang._typing.BufferLikeType) – Source buffer (Buffer, BufferRegion, or BufferLoad) containing input data.

  • dst (tilelang._typing.BufferLikeType) – Destination buffer (Buffer, BufferRegion, or BufferLoad) for output data.

  • dim (int) – Dimension along which to compute cumulative maximum.

  • reverse (bool) – If True, compute cumulative maximum in reverse order.

Return type:

None

tilelang.language.scan_op.cummax(src, dst=None, dim=0, reverse=False)

Compute the cumulative maximum of src along dim, writing results to dst.

Negative dim indices are normalized (Python-style). If dst is None, the operation is performed in place, overwriting src. When src.scope() is "local.fragment", this delegates to cummax_fragment; otherwise it emits the tl.cummax intrinsic.
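A cumulative maximum is an inclusive scan in which each output element holds the largest value seen so far. The following host-side sketch in plain Python illustrates that for a 1D sequence; it is not TileLang code, reference_cummax is a hypothetical name, and the reverse=True behavior shown (scanning from the last element toward the first) is an assumption about the intended semantics.

```python
from itertools import accumulate


def reference_cummax(values, reverse=False):
    """Inclusive running maximum of a 1D sequence."""
    if reverse:
        # Assumed meaning of reverse=True: out[i] = max(values[i:]).
        return list(accumulate(reversed(values), max))[::-1]
    return list(accumulate(values, max))


data = [3, 1, 4, 1, 5, 2]
print(reference_cummax(data))                # [3, 3, 4, 4, 5, 5]
print(reference_cummax(data, reverse=True))  # [5, 5, 5, 5, 5, 2]
```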

Parameters:
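  • src (tilelang._typing.BufferLikeType) – Source buffer (Buffer, BufferRegion, or BufferLoad) containing input data.

  • dst (tilelang._typing.BufferLikeType | None) – Destination buffer for output data. If None, the result overwrites src.

  • dim (int) – Dimension along which to compute the cumulative maximum. Negative values are normalized Python-style.

  • reverse (bool) – If True, compute the cumulative maximum in reverse order.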

Return type:

tvm.tirx.PrimExpr | None