tilelang.contrib.cutedsl.reduce

Reduce operations for the CuTeDSL backend, based on tl_templates/cuda/reduce.h.

Classes

SumOp

Sum reduction operator

MaxOp

Max reduction operator

MinOp

Min reduction operator

BitAndOp

Bitwise AND reduction operator

BitOrOp

Bitwise OR reduction operator

BitXorOp

Bitwise XOR reduction operator

Functions

min(a, b[, c, loc, ip])

max(a, b[, c, loc, ip])

bar_sync(barrier_id, number_of_threads)

bar_sync_ptx(barrier_id, number_of_threads)

AllReduce(reducer, threads, scale, thread_offset[, ...])

AllReduce operation implementing warp/block-level reduction.

Module Contents

tilelang.contrib.cutedsl.reduce.min(a, b, c=None, *, loc=None, ip=None)
Parameters:
  • a (float | cutlass.cute.typing.Float32)

  • b (float | cutlass.cute.typing.Float32)

  • c (float | cutlass.cute.typing.Float32 | None)

Return type:

cutlass.cute.typing.Float32

tilelang.contrib.cutedsl.reduce.max(a, b, c=None, *, loc=None, ip=None)
Parameters:
  • a (float | cutlass.cute.typing.Float32)

  • b (float | cutlass.cute.typing.Float32)

  • c (float | cutlass.cute.typing.Float32 | None)

Return type:

cutlass.cute.typing.Float32
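The signatures above accept two or three operands. As a rough reference for the intended semantics, assuming the optional `c` simply extends the comparison to a third value, a plain-Python sketch might look like the following. `ref_min` and `ref_max` are hypothetical names used only for illustration; the documented functions return `Float32` values and also take the `loc` and `ip` keyword arguments, which this sketch omits.

```python
import builtins


def ref_min(a, b, c=None):
    # Reference semantics only (assumption): smallest of two or three floats.
    operands = (a, b) if c is None else (a, b, c)
    return builtins.min(operands)


def ref_max(a, b, c=None):
    # Reference semantics only (assumption): largest of two or three floats.
    operands = (a, b) if c is None else (a, b, c)
    return builtins.max(operands)
```

Because the module-level names `min` and `max` match Python's builtins, importing the module under a qualified name (rather than with a star import) avoids shadowing them.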

class tilelang.contrib.cutedsl.reduce.SumOp

Sum reduction operator

static __call__(x, y)

class tilelang.contrib.cutedsl.reduce.MaxOp

Max reduction operator

static __call__(x, y)

class tilelang.contrib.cutedsl.reduce.MinOp

Min reduction operator

static __call__(x, y)

class tilelang.contrib.cutedsl.reduce.BitAndOp

Bitwise AND reduction operator

static __call__(x, y)

class tilelang.contrib.cutedsl.reduce.BitOrOp

Bitwise OR reduction operator

static __call__(x, y)

class tilelang.contrib.cutedsl.reduce.BitXorOp

Bitwise XOR reduction operator

static __call__(x, y)
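Each operator class documents a static `__call__(x, y)`, so a reducer is applied by instantiating the class and calling the instance, for example `SumOp()(x, y)`. The sketch below illustrates that calling convention with plain-Python stand-ins. `RefSumOp`, `RefMaxOp`, `RefBitXorOp`, and `fold` are hypothetical names, and the bodies are assumptions about the obvious semantics, not the library's source.

```python
from functools import reduce


class RefSumOp:
    """Stand-in for a sum reducer with a static __call__."""

    @staticmethod
    def __call__(x, y):
        return x + y


class RefMaxOp:
    """Stand-in for a max reducer with a static __call__."""

    @staticmethod
    def __call__(x, y):
        return x if x > y else y


class RefBitXorOp:
    """Stand-in for a bitwise XOR reducer with a static __call__."""

    @staticmethod
    def __call__(x, y):
        return x ^ y


def fold(reducer_cls, values):
    # Instantiate the reducer class, then fold it over the values pairwise.
    op = reducer_cls()
    return reduce(op, values)
```

Passing the class rather than an instance mirrors how `AllReduce` describes its `reducer` argument as a "reducer operator class".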
tilelang.contrib.cutedsl.reduce.bar_sync(barrier_id, number_of_threads)

tilelang.contrib.cutedsl.reduce.bar_sync_ptx(barrier_id, number_of_threads)

tilelang.contrib.cutedsl.reduce.AllReduce(reducer, threads, scale, thread_offset, all_threads=None)

AllReduce operation implementing warp/block-level reduction, based on tl::AllReduce from reduce.h.

Parameters:
  • reducer – Reducer operator class (SumOp, MaxOp, etc.)

  • threads – Number of threads participating in reduction

  • scale – Reduction scale factor

  • thread_offset – Thread ID offset

  • all_threads – Total number of threads in block

Returns:

A callable object with run() and run_hopper() methods
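To make the roles of `threads` and `scale` more concrete, here is a host-side sketch of a butterfly (XOR-exchange) all-reduce. It assumes the CuTeDSL implementation follows the same scheme as `tl::AllReduce`: each thread combines its value with the partner at `index ^ offset`, with `offset` halving from `threads // 2` down to `scale`. `simulate_all_reduce` is a hypothetical helper that models each thread's value as a list entry; it does not call the library and ignores `thread_offset`, `all_threads`, and any shared-memory buffer.

```python
import operator


def simulate_all_reduce(values, combine, threads, scale=1):
    """Model a butterfly all-reduce; values[i] is the value held by thread i."""
    # Assumptions: threads and scale are powers of two and scale < threads.
    assert threads > scale >= 1
    assert len(values) % threads == 0

    out = list(values)
    offset = threads // 2
    while offset >= scale:
        # Every thread combines its value with its XOR partner's value.
        out = [combine(out[i], out[i ^ offset]) for i in range(len(out))]
        offset //= 2
    return out


# Two groups of four threads, each reduced independently with a sum.
sums = simulate_all_reduce(list(range(8)), operator.add, threads=4)

# With scale=2 the exchange stops early, so only threads that share the
# same index modulo 2 within a group are combined.
partial = simulate_all_reduce(list(range(8)), operator.add, threads=4, scale=2)
```

In this model every thread in a group ends up holding the reduced value, which is what distinguishes an all-reduce from a reduction that leaves the result on a single thread.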