tilelang.contrib.cutedsl.gemm_V1
================================

.. py:module:: tilelang.contrib.cutedsl.gemm_V1


Classes
-------

.. autoapisummary::

   tilelang.contrib.cutedsl.gemm_V1.Gemm_SM80
   tilelang.contrib.cutedsl.gemm_V1.Gemm_SM90


Functions
---------

.. autoapisummary::

   tilelang.contrib.cutedsl.gemm_V1.make_aligned_tensor
   tilelang.contrib.cutedsl.gemm_V1.gemm_ss
   tilelang.contrib.cutedsl.gemm_V1.gemm_rs
   tilelang.contrib.cutedsl.gemm_V1.gemm_sr
   tilelang.contrib.cutedsl.gemm_V1.gemm_rr


Module Contents
---------------

.. py:function:: make_aligned_tensor(ptr, layout, align_bytes, swizzle=False)


.. py:function:: gemm_ss(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, use_wgmma=None, wg_wait=0, A_ptr = None, B_ptr = None, C_ptr = None)

   GEMM with both A and B from shared memory.


.. py:function:: gemm_rs(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, use_wgmma=None, wg_wait=0, A_ptr = None, B_ptr = None, C_ptr = None)

   GEMM with A from register/fragment and B from shared memory.


.. py:function:: gemm_sr(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, use_wgmma=None, wg_wait=0, A_ptr = None, B_ptr = None, C_ptr = None)

   GEMM with A from shared memory and B from register/fragment.


.. py:function:: gemm_rr(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, use_wgmma=None, wg_wait=0, A_ptr = None, B_ptr = None, C_ptr = None)

   GEMM with both A and B from register/fragment.


.. py:class:: Gemm_SM80(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, A_type, B_type, C_type)

   .. py:method:: __call__(sA_ptr, sB_ptr, rC_ptr)

      GEMM body: both A and B from shared memory.


   .. py:method:: body_rs(rA_ptr, sB_ptr, rC_ptr)

      GEMM body_rs: A from register, B from shared memory.


   .. py:method:: body_sr(sA_ptr, rB_ptr, rC_ptr)

      GEMM body_sr: A from shared memory, B from register.


.. py:class:: Gemm_SM90(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, A_type, B_type, C_type)

   .. py:method:: make_tma_atom(tensor, smem_layout_staged, smem_tile, mcast_dim)
      :staticmethod:


   .. py:method:: get_tma_atom(tensor, tiler_mk, stages=1)
      :staticmethod:


   .. py:method:: make_smem_layout_AB(dtype, major_mode, tiler_mk, stages=1)
      :staticmethod:


   .. py:method:: __call__(sA_ptr, sB_ptr, rC_ptr, wg_wait = 0, clear_accum = False)


   .. py:method:: body_rs(rA_ptr, sB_ptr, rC_ptr, wg_wait = 0, clear_accum = False)

      GEMM body_rs for SM90/Hopper: A from register, B from shared memory.

      Based on ``cute::tl_wgmma::GemmTensorOp::body_rs`` from ``gemm_sm90.h``.
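
The ``gemm_ss``/``gemm_rs``/``gemm_sr``/``gemm_rr`` entry points share one signature and, per their docstrings, differ only in where A and B live (``s`` = shared memory, ``r`` = register/fragment, A first, then B). The following is a minimal pure-Python sketch of the arithmetic they appear to perform, not part of ``tilelang``. It assumes TileLang's usual conventions: A is ``M x K`` (stored ``K x M`` when ``trans_A``), B is ``K x N`` (stored ``N x K`` when ``trans_B``), C is ``M x N``, and ``clear_accum=True`` zeroes the accumulator before accumulating. ``reference_gemm`` and ``operand_sources`` are hypothetical helpers for illustration only.

.. code-block:: python

   from typing import List, Tuple

   Matrix = List[List[float]]


   def reference_gemm(A: Matrix, B: Matrix, C: Matrix, M: int, N: int, K: int,
                      trans_A: bool = False, trans_B: bool = False,
                      clear_accum: bool = False) -> Matrix:
       """Hypothetical host-side model of C (+)= op(A) @ op(B), updated in place."""
       for i in range(M):
           for j in range(N):
               # clear_accum (assumed semantics): start from zero instead of C.
               acc = 0 if clear_accum else C[i][j]
               for k in range(K):
                   # trans_A / trans_B (assumed semantics): operand stored transposed.
                   a = A[k][i] if trans_A else A[i][k]
                   b = B[j][k] if trans_B else B[k][j]
                   acc += a * b
               C[i][j] = acc
       return C


   def operand_sources(variant: str) -> Tuple[str, str]:
       """Decode a name such as 'gemm_rs' into (A source, B source)."""
       names = {"s": "shared", "r": "register"}
       suffix = variant.split("_")[-1]
       return names[suffix[0]], names[suffix[1]]

For example, ``operand_sources("gemm_rs")`` returns ``("register", "shared")``, matching the description of ``gemm_rs``. Since the operand location presumably changes only how data is fed to the MMA/WGMMA instructions, a single reference model like this one can serve as a host-side check for all four variants.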