tilelang.contrib.cutedsl.gemm_v2

Classes

GmmaDescriptor(desc_64=None)

Functions

initialize_wgmma_descriptor(layout_type, ...)

increase_descriptor_offset(desc, offset)

warpgroup_fence_operand(*args)

warpgroup_arrive()

warpgroup_commit_batch()

warpgroup_wait(N)

wgmma_ss(A_dtype, B_dtype, C_dtype, M, N, K, tnspA, ...)

wgmma_rs(A_dtype, B_dtype, C_dtype, M, N, K, tnspB, ...)

WGMMA register-shared variant using PTX inline asm.

Module Contents

class tilelang.contrib.cutedsl.gemm_v2.GmmaDescriptor(desc_64=None)
Parameters:

desc_64 (cutlass.cute.Int64)

desc
desc_i64
__add__(offset)
tilelang.contrib.cutedsl.gemm_v2.initialize_wgmma_descriptor(layout_type, leading_byte_offset, stride_byte_offset, desc, start_address)
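
The arguments of initialize_wgmma_descriptor line up with the fields of the 64-bit shared-memory matrix descriptor that the PTX ISA documents for wgmma. The following is a minimal plain-Python sketch of that bit layout, not the tilelang implementation. The helper names `encode_field` and `pack_wgmma_descriptor` are hypothetical, and the sketch assumes the offsets and the start address are given as raw byte values; this page does not say whether the real function expects them pre-encoded.

```python
# Plain-Python model of the wgmma matrix descriptor layout from the PTX ISA.
# Assumption: all inputs are raw byte values; tilelang may expect otherwise.


def encode_field(byte_value: int) -> int:
    """Matrix-descriptor encoding: keep 18 bits, then drop the low 4."""
    return (byte_value & 0x3FFFF) >> 4


def pack_wgmma_descriptor(layout_type: int,
                          leading_byte_offset: int,
                          stride_byte_offset: int,
                          start_address: int) -> int:
    """Hypothetical helper: pack the descriptor fields into one 64-bit int."""
    desc = 0
    desc |= encode_field(start_address)               # bits 0-13
    desc |= encode_field(leading_byte_offset) << 16   # bits 16-29
    desc |= encode_field(stride_byte_offset) << 32    # bits 32-45
    desc |= (layout_type & 0x3) << 62                 # bits 62-63, swizzle mode
    return desc


desc = pack_wgmma_descriptor(layout_type=1,           # 1 = 128-byte swizzle
                             leading_byte_offset=16,
                             stride_byte_offset=1024,
                             start_address=0x400)
print(hex(desc))  # 0x4000004000010040
```

The base-offset field (bits 49-51) is left at zero here, which is only valid under the alignment conditions described in the PTX documentation.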
tilelang.contrib.cutedsl.gemm_v2.increase_descriptor_offset(desc, offset)
Parameters:

desc (GmmaDescriptor)
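
Advancing a descriptor is typically how a kernel steps through successive K-slices of a shared-memory tile. The sketch below models that on a plain 64-bit integer. It is an illustration only: `advance_descriptor` is a hypothetical name, and the sketch assumes the offset is a byte count that is a multiple of 16 and that the start-address field (bits 0-13) does not overflow. Whether increase_descriptor_offset takes bytes or pre-encoded 16-byte units is not stated on this page.

```python
MASK64 = (1 << 64) - 1


def advance_descriptor(desc_64: int, byte_offset: int) -> int:
    """Hypothetical helper: move the start address forward by byte_offset.

    The start address lives in bits 0-13 in units of 16 bytes, so the
    byte offset is shifted right by 4 before it is added.
    """
    if byte_offset % 16 != 0:
        raise ValueError("byte_offset must be a multiple of 16")
    return (desc_64 + (byte_offset >> 4)) & MASK64


base = 0x4000004000010040          # example descriptor, start field = 0x40
stepped = advance_descriptor(base, 32)
print(hex(stepped))                # 0x4000004000010042
```

Only the low field changes; the leading and stride offsets and the swizzle mode in the upper bits are carried over unchanged.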

tilelang.contrib.cutedsl.gemm_v2.warpgroup_fence_operand(*args)
tilelang.contrib.cutedsl.gemm_v2.warpgroup_arrive()
tilelang.contrib.cutedsl.gemm_v2.warpgroup_commit_batch()
tilelang.contrib.cutedsl.gemm_v2.warpgroup_wait(N)
tilelang.contrib.cutedsl.gemm_v2.wgmma_ss(A_dtype, B_dtype, C_dtype, M, N, K, tnspA, tnspB, scaleA, scaleB, desc_a, desc_b, C_ptr, scale_out)
Parameters:
  • A_dtype (str)

  • B_dtype (str)

  • C_dtype (str)

  • M (cutlass.cutlass_dsl.Constexpr[int])

  • N (cutlass.cutlass_dsl.Constexpr[int])

  • K (cutlass.cutlass_dsl.Constexpr[int])

  • tnspA (bool)

  • tnspB (bool)

  • scaleA (int)

  • scaleB (int)

  • desc_a (GmmaDescriptor)

  • desc_b (GmmaDescriptor)

  • C_ptr (cutlass.cute.Pointer)

  • scale_out (cutlass.cutlass_dsl.Constexpr[int])

tilelang.contrib.cutedsl.gemm_v2.wgmma_rs(A_dtype, B_dtype, C_dtype, M, N, K, tnspB, scaleA, scaleB, A_ptr, desc_b, C_ptr, scale_out)

WGMMA register-shared variant using PTX inline asm.

The A operand is read from registers; the B operand is read from shared memory through a descriptor. M is always 64, and A is always K-major (not transposed), so unlike wgmma_ss this function has no tnspA parameter.

Parameters:
  • A_dtype (str)

  • B_dtype (str)

  • C_dtype (str)

  • M (cutlass.cutlass_dsl.Constexpr[int])

  • N (cutlass.cutlass_dsl.Constexpr[int])

  • K (cutlass.cutlass_dsl.Constexpr[int])

  • tnspB (cutlass.cutlass_dsl.Constexpr[bool])

  • scaleA (cutlass.cutlass_dsl.Constexpr[int])

  • scaleB (cutlass.cutlass_dsl.Constexpr[int])

  • A_ptr (cutlass.cute.Pointer)

  • desc_b (GmmaDescriptor)

  • C_ptr (cutlass.cute.Pointer)

  • scale_out (cutlass.cutlass_dsl.Constexpr[int])
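
The shape constraints above (M fixed at 64, with N and K tied to the operand type) can be checked before a call is emitted. The following is a hedged sketch of such a check for floating-point operands, based on the wgmma shapes listed in the PTX ISA. The function name `check_wgmma_rs_shape` and the dtype strings in `K_BY_DTYPE` are assumptions made for illustration; this page does not list the dtype names that wgmma_rs accepts.

```python
# K per instruction for floating-point A/B types, per the PTX ISA wgmma shapes.
# The dtype strings are hypothetical placeholders, not tilelang's spelling.
K_BY_DTYPE = {
    "float16": 16,
    "bfloat16": 16,
    "tf32": 8,
    "e4m3": 32,
    "e5m2": 32,
}


def check_wgmma_rs_shape(a_dtype: str, m: int, n: int, k: int) -> None:
    """Raise ValueError if (m, n, k) is not a valid floating-point wgmma shape."""
    if m != 64:
        raise ValueError(f"M must be 64, got {m}")
    if n % 8 != 0 or not 8 <= n <= 256:
        raise ValueError(f"N must be a multiple of 8 in [8, 256], got {n}")
    expected_k = K_BY_DTYPE.get(a_dtype)
    if expected_k is None:
        raise ValueError(f"unknown dtype {a_dtype!r}")
    if k != expected_k:
        raise ValueError(f"K must be {expected_k} for {a_dtype}, got {k}")


check_wgmma_rs_shape("float16", 64, 128, 16)   # valid: m64n128k16
```

Integer and single-bit operand types allow a different set of N and K values and are left out of this sketch.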