tilelang.contrib.cutedsl.gemm_V1
Classes
- Gemm_SM80
- Gemm_SM90
Functions
- make_aligned_tensor
- gemm_ss | GEMM with both A and B from shared memory
- gemm_rs | GEMM with A from register/fragment and B from shared memory
- gemm_sr | GEMM with A from shared memory and B from register/fragment
- gemm_rr | GEMM with both A and B from register/fragment
Module Contents
- tilelang.contrib.cutedsl.gemm_V1.make_aligned_tensor(ptr, layout, align_bytes, swizzle=False)
- Parameters:
ptr (cutlass.cute.Pointer)
layout (cutlass.cute.Layout)
align_bytes (int)
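make_aligned_tensor takes an align_bytes argument. As a rough illustration only, assuming align_bytes is a power-of-two byte alignment guaranteed for ptr, the pure-Python helpers below show what such a guarantee implies: the address is a multiple of align_bytes, and that bounds how many elements fit in one aligned access. The helper names is_aligned and elements_per_access are hypothetical and not part of this module.

```python
def is_aligned(address: int, align_bytes: int) -> bool:
    """Return True if `address` is a multiple of `align_bytes` (a power of two)."""
    if align_bytes <= 0 or align_bytes & (align_bytes - 1):
        raise ValueError("align_bytes must be a positive power of two")
    return address % align_bytes == 0


def elements_per_access(align_bytes: int, dtype_bits: int) -> int:
    """Number of elements of `dtype_bits` width that fit in one aligned access."""
    return (align_bytes * 8) // dtype_bits


# Hypothetical example: a 16-byte aligned pointer to 16-bit data
# permits 8-element (128-bit) accesses.
print(is_aligned(0x1000, 16))       # True
print(elements_per_access(16, 16))  # 8
```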
- tilelang.contrib.cutedsl.gemm_V1.gemm_ss(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, use_wgmma=None, wg_wait=0, A_ptr=None, B_ptr=None, C_ptr=None)
GEMM with both A and B from shared memory
- Parameters:
A_ptr (cutlass.cute.Pointer)
B_ptr (cutlass.cute.Pointer)
C_ptr (cutlass.cute.Pointer)
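gemm_ss, gemm_rs, gemm_sr and gemm_rr differ in where A and B live; the arithmetic is a tile GEMM over M×N×K. The plain-Python reference model below is a sketch for checking small tiles, not the module's implementation. It assumes (following TileLang's T.gemm convention, not confirmed here) that trans_A=True means A is stored K×M, trans_B=True means B is stored N×K, and clear_accum=True discards the old contents of C. Warps, memory spaces and data types are ignored.

```python
from typing import List

Matrix = List[List[float]]


def reference_gemm(A: Matrix, B: Matrix, C: Matrix, M: int, N: int, K: int,
                   trans_A: bool = False, trans_B: bool = False,
                   clear_accum: bool = False) -> Matrix:
    """C[m][n] = (0 or C[m][n]) + sum_k op(A)[m][k] * op(B)[k][n], in place."""
    for m in range(M):
        for n in range(N):
            acc = 0.0 if clear_accum else C[m][n]
            for k in range(K):
                a = A[k][m] if trans_A else A[m][k]  # assumed: A stored K x M when transposed
                b = B[n][k] if trans_B else B[k][n]  # assumed: B stored N x K when transposed
                acc += a * b
            C[m][n] = acc
    return C


A = [[1.0, 2.0], [3.0, 4.0]]  # M x K
B = [[5.0, 6.0], [7.0, 8.0]]  # K x N
C = [[1.0, 1.0], [1.0, 1.0]]
reference_gemm(A, B, C, 2, 2, 2, clear_accum=False)
print(C)  # [[20.0, 23.0], [44.0, 51.0]]
```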
- tilelang.contrib.cutedsl.gemm_V1.gemm_rs(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, use_wgmma=None, wg_wait=0, A_ptr=None, B_ptr=None, C_ptr=None)
GEMM with A from register/fragment and B from shared memory
- Parameters:
A_ptr (cutlass.cute.Pointer)
B_ptr (cutlass.cute.Pointer)
C_ptr (cutlass.cute.Pointer)
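Each of the free GEMM functions also takes stride_A/stride_B and offset_A/offset_B. A common reading, assumed here rather than confirmed by this page, is that the stride is the leading dimension of the stored buffer (elements between consecutive stored rows) and the offset is an element offset, so a tile can sit inside a larger allocation. The hypothetical helper element_index below computes a flat index under that assumption and honors a transpose flag.

```python
def element_index(row: int, col: int, stride: int, offset: int,
                  transposed: bool = False) -> int:
    """Flat index of logical element (row, col) in a row-major stored buffer.

    Assumptions: `stride` is the leading dimension of the *stored* matrix and
    `offset` is an element offset. When `transposed`, the stored matrix is the
    transpose, so logical (row, col) lives at stored (col, row).
    """
    r, c = (col, row) if transposed else (row, col)
    return offset + r * stride + c


# A tile starting at column 32 of a row-major buffer whose rows hold 64 elements.
print(element_index(1, 0, stride=64, offset=32))  # 96
```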
- tilelang.contrib.cutedsl.gemm_V1.gemm_sr(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, use_wgmma=None, wg_wait=0, A_ptr=None, B_ptr=None, C_ptr=None)
GEMM with A from shared memory and B from register/fragment
- Parameters:
A_ptr (cutlass.cute.Pointer)
B_ptr (cutlass.cute.Pointer)
C_ptr (cutlass.cute.Pointer)
- tilelang.contrib.cutedsl.gemm_V1.gemm_rr(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, use_wgmma=None, wg_wait=0, A_ptr=None, B_ptr=None, C_ptr=None)
GEMM with both A and B from register/fragment
- Parameters:
A_ptr (cutlass.cute.Pointer)
B_ptr (cutlass.cute.Pointer)
C_ptr (cutlass.cute.Pointer)
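warp_m and warp_n describe how the warps that cooperate on one M×N tile are arranged. The sketch below shows only each warp's share of the work, assuming a warp_m × warp_n grid with warp_m varying fastest (as in a default column-major CuTe layout of that shape) and contiguous sub-tiles. Real tiled MMAs may interleave warps at the granularity of the MMA atom instead of assigning contiguous blocks, so treat the hypothetical warp_subtile helper as a simplified model.

```python
from typing import Tuple


def warp_subtile(warp_id: int, M: int, N: int, warp_m: int, warp_n: int
                 ) -> Tuple[int, int, int, int]:
    """Return (row0, col0, rows, cols) of the sub-tile assigned to `warp_id`.

    Assumes a warp_m x warp_n warp grid, warp_m varying fastest, contiguous
    sub-tiles, and M, N divisible by warp_m, warp_n.
    """
    if M % warp_m or N % warp_n:
        raise ValueError("M must divide by warp_m and N by warp_n")
    if not 0 <= warp_id < warp_m * warp_n:
        raise ValueError("warp_id out of range")
    rows, cols = M // warp_m, N // warp_n
    wm, wn = warp_id % warp_m, warp_id // warp_m
    return wm * rows, wn * cols, rows, cols


# 128x128 tile split over 2x2 warps: each warp covers a 64x64 share.
for w in range(4):
    print(w, warp_subtile(w, 128, 128, 2, 2))
```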
- class tilelang.contrib.cutedsl.gemm_V1.Gemm_SM80(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, A_type, B_type, C_type)
- __call__(sA_ptr, sB_ptr, rC_ptr)
GEMM body: both A and B from shared memory
- Parameters:
sA_ptr (cutlass.cute.Pointer)
sB_ptr (cutlass.cute.Pointer)
rC_ptr (cutlass.cute.Pointer)
- body_rs(rA_ptr, sB_ptr, rC_ptr)
GEMM body_rs: A from register, B from shared memory
- Parameters:
rA_ptr (cutlass.cute.Pointer)
sB_ptr (cutlass.cute.Pointer)
rC_ptr (cutlass.cute.Pointer)
- body_sr(sA_ptr, rB_ptr, rC_ptr)
GEMM body_sr: A from shared memory, B from register
- Parameters:
sA_ptr (cutlass.cute.Pointer)
rB_ptr (cutlass.cute.Pointer)
rC_ptr (cutlass.cute.Pointer)
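Gemm_SM80 exposes one entry point per operand placement: __call__ for shared/shared, body_rs for register/shared and body_sr for shared/register; no register/register body is listed on this class. The table-driven sketch below only records that documented mapping so it runs without CUTLASS; OperandSource, SM80_BODIES and select_body are hypothetical names.

```python
from enum import Enum


class OperandSource(Enum):
    SHARED = "shared"
    REGISTER = "register"


# (A source, B source) -> name of the Gemm_SM80 method documented above.
SM80_BODIES = {
    (OperandSource.SHARED, OperandSource.SHARED): "__call__",
    (OperandSource.REGISTER, OperandSource.SHARED): "body_rs",
    (OperandSource.SHARED, OperandSource.REGISTER): "body_sr",
}


def select_body(a_src: OperandSource, b_src: OperandSource) -> str:
    """Return the method name for an operand placement, or raise if none is listed."""
    try:
        return SM80_BODIES[(a_src, b_src)]
    except KeyError:
        raise NotImplementedError(
            f"no Gemm_SM80 body for A={a_src.value}, B={b_src.value}") from None


print(select_body(OperandSource.REGISTER, OperandSource.SHARED))  # body_rs
```

With a real Gemm_SM80 instance, getattr(instance, select_body(...)) would fetch the matching bound method.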
- class tilelang.contrib.cutedsl.gemm_V1.Gemm_SM90(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, A_type, B_type, C_type)
- static make_tma_atom(tensor, smem_layout_staged, smem_tile, mcast_dim)
- static get_tma_atom(tensor, tiler_mk, stages=1)
- static make_smem_layout_AB(dtype, major_mode, tiler_mk, stages=1)
- Parameters:
major_mode (cutlass.utils.LayoutEnum)
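make_smem_layout_AB builds shared-memory layouts for A and B, and make_aligned_tensor has a swizzle flag, but this page does not say which swizzle is used. For orientation only, the sketch below reimplements the generic CuTe Swizzle<B, M, S> mapping on integer offsets for the S > 0 case: XOR the B bits starting at bit M+S into the B bits starting at bit M. CuTe requires |S| >= B, which keeps source and target bits disjoint. The parameters in the example are illustrative and are not taken from this module.

```python
def swizzle(offset: int, B: int, M: int, S: int) -> int:
    """CuTe-style Swizzle<B, M, S> for S > 0: XOR bits [M+S, M+S+B) into [M, M+B)."""
    if B < 0 or M < 0 or S <= 0 or S < B:
        raise ValueError("this sketch expects B >= 0, M >= 0 and S >= B with S > 0")
    mask = ((1 << B) - 1) << (M + S)  # source bits, copied unchanged
    return offset ^ ((offset & mask) >> S)


# Because S >= B, the source bits are never modified, so the map is its own
# inverse and permutes every aligned block of 2**(B + M + S) offsets.
perm = [swizzle(i, 3, 4, 3) for i in range(1 << 10)]
print(perm[128])  # 144
```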
- __call__(sA_ptr, sB_ptr, rC_ptr, wg_wait=0, clear_accum=False)
- Parameters:
sA_ptr (cutlass.cute.Pointer)
sB_ptr (cutlass.cute.Pointer)
rC_ptr (cutlass.cute.Pointer)
wg_wait (cutlass.Constexpr)
clear_accum (cutlass.Constexpr)
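On SM90, WGMMA instructions execute asynchronously and are committed in groups. The wg_wait parameter presumably selects how many committed groups may still be pending when the body returns, in the manner of PTX wgmma.wait_group N (so the default wg_wait=0 would wait for all of them). That mapping is an assumption based on the name and default. The toy model below simulates only the counting rule; WgmmaGroupModel is a hypothetical name.

```python
from collections import deque
from typing import Deque, List


class WgmmaGroupModel:
    """Toy model of commit / wait_group bookkeeping for asynchronous MMA groups."""

    def __init__(self) -> None:
        self.pending: Deque[str] = deque()
        self.completed: List[str] = []

    def commit(self, name: str) -> None:
        """Record a newly committed group as in flight."""
        self.pending.append(name)

    def wait_group(self, n: int) -> None:
        """Retire the oldest groups until at most `n` remain pending."""
        while len(self.pending) > n:
            self.completed.append(self.pending.popleft())


model = WgmmaGroupModel()
for step in range(3):
    model.commit(f"group_{step}")
model.wait_group(1)          # like wg_wait=1: the newest group may still be in flight
print(model.completed)       # ['group_0', 'group_1']
print(list(model.pending))   # ['group_2']
model.wait_group(0)          # like wg_wait=0: every committed group has completed
```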
- body_rs(rA_ptr, sB_ptr, rC_ptr, wg_wait=0, clear_accum=False)
GEMM body_rs for SM90/Hopper: A from register, B from shared memory. Based on cute::tl_wgmma::GemmTensorOp::body_rs in gemm_sm90.h.
- Parameters:
rA_ptr (cutlass.cute.Pointer)
sB_ptr (cutlass.cute.Pointer)
rC_ptr (cutlass.cute.Pointer)
wg_wait (cutlass.Constexpr)
clear_accum (cutlass.Constexpr)
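The clear_accum flag on the SM90 bodies interacts with WGMMA's scale-d input, which selects between D = A·B and D = A·B + D. A common Hopper pattern, assumed here rather than confirmed for this module, is to issue the first k-block with scale-d = 0 when clear_accum is set and with scale-d = 1 afterwards, so stale accumulator contents are dropped without a separate zeroing pass. The scalar model below illustrates that rule; mma_step and k_loop are hypothetical names.

```python
from typing import Sequence


def mma_step(d: float, a: float, b: float, scale_d: int) -> float:
    """One MMA with a scale-d input: d = a*b + d if scale_d else a*b."""
    if scale_d not in (0, 1):
        raise ValueError("scale_d must be 0 or 1")
    return a * b + (d if scale_d else 0.0)


def k_loop(d: float, a_blocks: Sequence[float], b_blocks: Sequence[float],
           clear_accum: bool) -> float:
    """Accumulate over k-blocks; with clear_accum only the first step ignores old d."""
    for i, (a, b) in enumerate(zip(a_blocks, b_blocks)):
        scale_d = 0 if (clear_accum and i == 0) else 1
        d = mma_step(d, a, b, scale_d)
    return d


stale = 100.0  # leftover value in the accumulator
print(k_loop(stale, [1.0, 2.0], [3.0, 4.0], clear_accum=True))   # 11.0
print(k_loop(stale, [1.0, 2.0], [3.0, 4.0], clear_accum=False))  # 111.0
```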