tilelang.contrib.cutedsl.gemm_v2
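Classes
- GmmaDescriptor
Functions
- initialize_wgmma_descriptor
- increase_descriptor_offset
- warpgroup_fence_operand
- warpgroup_arrive
- warpgroup_commit_batch
- warpgroup_wait
- wgmma_ss
- wgmma_rs: WGMMA register-shared variant using PTX inline asm.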
Module Contents
- class tilelang.contrib.cutedsl.gemm_v2.GmmaDescriptor(desc_64=None)
- Parameters:
desc_64 (cutlass.cute.Int64)
- desc
- desc_i64
- __add__(offset)
- tilelang.contrib.cutedsl.gemm_v2.initialize_wgmma_descriptor(layout_type, leading_byte_offset, stride_byte_offset, desc, start_address)
- Parameters:
desc (GmmaDescriptor)
start_address (cutlass.cute.Pointer)
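The parameters above correspond to fields of the 64-bit shared-memory matrix descriptor that PTX defines for wgmma. The following is a minimal pure-Python sketch of that bit layout, not the tilelang implementation. The helper names `pack_descriptor` and `_field` are hypothetical, and the sketch assumes the two byte offsets are passed already encoded in units of 16 bytes; the page does not say whether the real function expects raw or encoded values.

```python
# Illustrative model only. Field positions follow the PTX wgmma matrix
# descriptor; the helper names are hypothetical and not part of tilelang.

def _field(value, width, shift):
    """Place `value` into a `width`-bit field that starts at bit `shift`."""
    if not 0 <= value < (1 << width):
        raise ValueError(f"{value} does not fit in {width} bits")
    return value << shift


def pack_descriptor(start_address, leading_byte_offset, stride_byte_offset,
                    layout_type, base_offset=0):
    """Pack a 64-bit descriptor from its fields.

    Assumption: the two byte offsets are already encoded (units of 16 bytes).
    The start address is a shared-memory byte address and is encoded here.
    """
    encoded_address = (start_address & 0x3FFFF) >> 4  # PTX address encoding
    return (
        _field(encoded_address, 14, 0)         # bits 0-13: start address
        | _field(leading_byte_offset, 14, 16)  # bits 16-29: leading offset
        | _field(stride_byte_offset, 14, 32)   # bits 32-45: stride offset
        | _field(base_offset, 3, 49)           # bits 49-51: base offset
        | _field(layout_type, 2, 62)           # bits 62-63: swizzle/layout
    )


desc = pack_descriptor(start_address=0x400, leading_byte_offset=1,
                       stride_byte_offset=64, layout_type=1)
print(hex(desc))
```

Reading the fields back with shifts and masks is a quick way to check a descriptor value seen in a debugger against the values that were intended.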
- tilelang.contrib.cutedsl.gemm_v2.increase_descriptor_offset(desc, offset)
- Parameters:
desc (GmmaDescriptor)
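The page does not document what `offset` means. One plausible reading, sketched below in pure Python, is that a byte offset is encoded in 16-byte units and added to the low 32-bit word of the descriptor, which moves the start address while leaving the other fields untouched. The function name `advance_descriptor` and this interpretation are assumptions, not confirmed behaviour of `increase_descriptor_offset` or `GmmaDescriptor.__add__`.

```python
# Hypothetical model of advancing a wgmma descriptor; not tilelang code.
MASK32 = 0xFFFFFFFF


def advance_descriptor(desc_64, byte_offset):
    """Return a descriptor whose start address is moved by `byte_offset`.

    Assumption: the offset is in bytes and must be a multiple of 16, because
    the start-address field stores the address divided by 16.
    """
    if byte_offset % 16:
        raise ValueError("byte_offset must be a multiple of 16")
    low_word = ((desc_64 & MASK32) + (byte_offset >> 4)) & MASK32
    return (desc_64 & ~MASK32) | low_word


# Descriptor with start field 0x40, leading offset 1, stride offset 64.
original = (64 << 32) | (1 << 16) | 0x40
advanced = advance_descriptor(original, 256)
print(hex(advanced))
```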
- tilelang.contrib.cutedsl.gemm_v2.warpgroup_fence_operand(*args)
- tilelang.contrib.cutedsl.gemm_v2.warpgroup_arrive()
- tilelang.contrib.cutedsl.gemm_v2.warpgroup_commit_batch()
- tilelang.contrib.cutedsl.gemm_v2.warpgroup_wait(N)
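These synchronization helpers are undocumented on this page. Assuming they mirror the PTX instructions with similar names (`wgmma.fence`, `wgmma.commit_group`, `wgmma.wait_group`), the commit and wait pair can be modelled as a queue of committed groups, where waiting on N blocks until at most N of the most recent groups are still pending. The class `WgmmaGroupTracker` below is a hypothetical host-side model of that bookkeeping, not device code and not part of tilelang.

```python
# Host-side model of wgmma commit-group bookkeeping (illustrative only).

class WgmmaGroupTracker:
    def __init__(self):
        self.open_ops = []   # issued operations not yet committed to a group
        self.pending = []    # committed groups still in flight, oldest first
        self.completed = []  # groups known to have finished

    def issue(self, name):
        """Record an asynchronous MMA that has been issued."""
        self.open_ops.append(name)

    def commit_batch(self):
        """Close the current batch: all open operations form one group."""
        self.pending.append(tuple(self.open_ops))
        self.open_ops = []

    def wait(self, n):
        """Retire the oldest groups until at most `n` remain pending."""
        while len(self.pending) > n:
            self.completed.append(self.pending.pop(0))


tracker = WgmmaGroupTracker()
tracker.issue("mma_k0")
tracker.issue("mma_k1")
tracker.commit_batch()
tracker.issue("mma_k2")
tracker.commit_batch()
tracker.wait(1)  # the oldest group must finish; the newest may stay in flight
print(tracker.completed, tracker.pending)
```

Under this model, waiting on 0 drains every committed group, which is the usual choice before reading the accumulators.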
- tilelang.contrib.cutedsl.gemm_v2.wgmma_ss(A_dtype, B_dtype, C_dtype, M, N, K, tnspA, tnspB, scaleA, scaleB, desc_a, desc_b, C_ptr, scale_out)
- Parameters:
A_dtype (str)
B_dtype (str)
C_dtype (str)
M (cutlass.cutlass_dsl.Constexpr[int])
N (cutlass.cutlass_dsl.Constexpr[int])
K (cutlass.cutlass_dsl.Constexpr[int])
tnspA (bool)
tnspB (bool)
scaleA (int)
scaleB (int)
desc_a (GmmaDescriptor)
desc_b (GmmaDescriptor)
C_ptr (cutlass.cute.Pointer)
scale_out (cutlass.cutlass_dsl.Constexpr[int])
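The `M`, `N`, `K` and dtype parameters together select one PTX `wgmma.mma_async` instruction. As a rough sketch of that selection for floating-point inputs, the helper below builds the instruction mnemonic and rejects shapes PTX does not define. The name `wgmma_mnemonic` is hypothetical, and the dtype strings are PTX type names; the strings that `wgmma_ss` itself accepts for `A_dtype`, `B_dtype` and `C_dtype` are not listed on this page.

```python
# Illustrative helper, not part of tilelang. Floating-point input types only.

# K is fixed by the input element type: each instruction consumes 32 bytes
# of K per row of A.
K_BY_INPUT_TYPE = {"f16": 16, "bf16": 16, "tf32": 8, "e4m3": 32, "e5m2": 32}


def wgmma_mnemonic(a_type, b_type, c_type, m, n, k):
    """Return the PTX mnemonic for one floating-point wgmma shape."""
    if a_type not in K_BY_INPUT_TYPE or b_type not in K_BY_INPUT_TYPE:
        raise ValueError("unsupported input type in this sketch")
    if c_type not in ("f16", "f32"):
        raise ValueError("accumulator must be f16 or f32 in this sketch")
    if m != 64:
        raise ValueError("wgmma always uses M = 64")
    if n % 8 != 0 or not 8 <= n <= 256:
        raise ValueError("N must be a multiple of 8 between 8 and 256")
    if k != K_BY_INPUT_TYPE[a_type]:
        raise ValueError(f"K must be {K_BY_INPUT_TYPE[a_type]} for {a_type}")
    return (f"wgmma.mma_async.sync.aligned.m{m}n{n}k{k}"
            f".{c_type}.{a_type}.{b_type}")


name = wgmma_mnemonic("f16", "f16", "f32", 64, 128, 16)
print(name)
```

The sketch does not check every type combination PTX restricts (for example, which accumulator types each input type allows), so it is a shape check rather than a full validator.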
- tilelang.contrib.cutedsl.gemm_v2.wgmma_rs(A_dtype, B_dtype, C_dtype, M, N, K, tnspB, scaleA, scaleB, A_ptr, desc_b, C_ptr, scale_out)
WGMMA register-shared variant using PTX inline asm.
The A operand is read from registers and the B operand from a shared-memory descriptor. M is always 64, and A is always K-major (not transposed), so this variant takes no tnspA parameter.
- Parameters:
A_dtype (str)
B_dtype (str)
C_dtype (str)
M (cutlass.cutlass_dsl.Constexpr[int])
N (cutlass.cutlass_dsl.Constexpr[int])
K (cutlass.cutlass_dsl.Constexpr[int])
tnspB (cutlass.cutlass_dsl.Constexpr[bool])
scaleA (cutlass.cutlass_dsl.Constexpr[int])
scaleB (cutlass.cutlass_dsl.Constexpr[int])
A_ptr (cutlass.cute.Pointer)
desc_b (GmmaDescriptor)
C_ptr (cutlass.cute.Pointer)
scale_out (cutlass.cutlass_dsl.Constexpr[int])
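Because A and C live in registers in this variant, it can help to know how many 32-bit registers each thread contributes. The sketch below derives those counts from the tile shape, assuming the tile is split evenly across the 128 threads of a warpgroup. The names `regs_per_thread` and `BITS_PER_ELEMENT` are hypothetical, and the dtype keys are PTX type names rather than the strings `wgmma_rs` accepts.

```python
# Illustrative arithmetic only; not part of tilelang.

THREADS_PER_WARPGROUP = 128
BITS_PER_ELEMENT = {"f16": 16, "bf16": 16, "tf32": 32,
                    "e4m3": 8, "e5m2": 8, "f32": 32}


def regs_per_thread(rows, cols, dtype):
    """32-bit registers each thread holds for a rows x cols fragment."""
    total_bits = rows * cols * BITS_PER_ELEMENT[dtype]
    bits_per_thread = total_bits // THREADS_PER_WARPGROUP
    return bits_per_thread // 32


# A fragment: M x K with M = 64 and K fixed by the element type.
a_regs_f16 = regs_per_thread(64, 16, "f16")
a_regs_fp8 = regs_per_thread(64, 32, "e4m3")
# C fragment: M x N accumulators.
c_regs_f32 = regs_per_thread(64, 128, "f32")
print(a_regs_f16, a_regs_fp8, c_regs_f32)
```

For the shapes used here the A fragment always comes to four registers per thread, since K shrinks as the element gets wider, while the accumulator count scales with N.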