tilelang.contrib.cutedsl.gemm_V1

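Classes

Gemm_SM80(M, N, K, warp_m, warp_n, trans_A, trans_B, ...)

Gemm_SM90(M, N, K, warp_m, warp_n, trans_A, trans_B, ...)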

Functions

make_aligned_tensor(ptr, layout, align_bytes[, swizzle])

gemm_ss(M, N, K, warp_m, warp_n, trans_A, trans_B, ...)

GEMM with both A and B from shared memory

gemm_rs(M, N, K, warp_m, warp_n, trans_A, trans_B, ...)

GEMM with A from register/fragment and B from shared memory

gemm_sr(M, N, K, warp_m, warp_n, trans_A, trans_B, ...)

GEMM with A from shared memory and B from register/fragment

gemm_rr(M, N, K, warp_m, warp_n, trans_A, trans_B, ...)

GEMM with both A and B from register/fragment
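
The four entry points differ only in where A and B are read from, so the naming scheme can be captured in a small lookup. The sketch below is illustrative only: select_gemm_variant is a hypothetical helper, not part of tilelang, and it returns just the name of the matching function from the list above.

```python
# Hypothetical helper (not part of tilelang): maps operand locations to the
# name of the matching entry point. "s" = shared memory, "r" = register/fragment.
def select_gemm_variant(a_in_shared: bool, b_in_shared: bool) -> str:
    a = "s" if a_in_shared else "r"
    b = "s" if b_in_shared else "r"
    return f"gemm_{a}{b}"


# A from register/fragment, B from shared memory
print(select_gemm_variant(a_in_shared=False, b_in_shared=True))
```

The first letter of the suffix describes A and the second describes B, matching the summaries above.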

Module Contents

tilelang.contrib.cutedsl.gemm_V1.make_aligned_tensor(ptr, layout, align_bytes, swizzle=False)
Parameters:
  • ptr (cutlass.cute.Pointer)

  • layout (cutlass.cute.Layout)

  • align_bytes (int)

tilelang.contrib.cutedsl.gemm_V1.gemm_ss(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, use_wgmma=None, wg_wait=0, A_ptr=None, B_ptr=None, C_ptr=None)

GEMM with both A and B from shared memory

Parameters:
  • A_ptr (cutlass.cute.Pointer)

  • B_ptr (cutlass.cute.Pointer)

  • C_ptr (cutlass.cute.Pointer)
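
As a rough mental model of the arithmetic these entry points perform, the following pure-Python sketch computes C (+)= op(A) * op(B) on nested lists. It assumes that trans_A means A is stored as K x M, that trans_B means B is stored as N x K, and that clear_accum zeroes the accumulator before the product; none of these conventions is stated on this page. gemm_reference is a hypothetical name, and the sketch ignores stride_A, stride_B, the offsets, and all hardware details.

```python
# Reference semantics only; a hypothetical helper, not the tilelang implementation.
def gemm_reference(A, B, C, M, N, K, trans_A=False, trans_B=False, clear_accum=False):
    for i in range(M):
        for j in range(N):
            # Assumption: clear_accum discards the previous contents of C.
            acc = 0.0 if clear_accum else C[i][j]
            for k in range(K):
                a = A[k][i] if trans_A else A[i][k]  # assumed: trans_A -> A is K x M
                b = B[j][k] if trans_B else B[k][j]  # assumed: trans_B -> B is N x K
                acc += a * b
            C[i][j] = acc
    return C


A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
C = [[1.0, 1.0], [1.0, 1.0]]
print(gemm_reference(A, B, C, M=2, N=2, K=2))  # accumulates into the existing C
```

A reference like this can serve as a correctness check for small tiles when comparing against a kernel's output.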

tilelang.contrib.cutedsl.gemm_V1.gemm_rs(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, use_wgmma=None, wg_wait=0, A_ptr=None, B_ptr=None, C_ptr=None)

GEMM with A from register/fragment and B from shared memory

Parameters:
  • A_ptr (cutlass.cute.Pointer)

  • B_ptr (cutlass.cute.Pointer)

  • C_ptr (cutlass.cute.Pointer)

tilelang.contrib.cutedsl.gemm_V1.gemm_sr(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, use_wgmma=None, wg_wait=0, A_ptr=None, B_ptr=None, C_ptr=None)

GEMM with A from shared memory and B from register/fragment

Parameters:
  • A_ptr (cutlass.cute.Pointer)

  • B_ptr (cutlass.cute.Pointer)

  • C_ptr (cutlass.cute.Pointer)

tilelang.contrib.cutedsl.gemm_V1.gemm_rr(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, use_wgmma=None, wg_wait=0, A_ptr=None, B_ptr=None, C_ptr=None)

GEMM with both A and B from register/fragment

Parameters:
  • A_ptr (cutlass.cute.Pointer)

  • B_ptr (cutlass.cute.Pointer)

  • C_ptr (cutlass.cute.Pointer)

class tilelang.contrib.cutedsl.gemm_V1.Gemm_SM80(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, A_type, B_type, C_type)

__call__(sA_ptr, sB_ptr, rC_ptr)

GEMM body: both A and B from shared memory

Parameters:
  • sA_ptr (cutlass.cute.Pointer)

  • sB_ptr (cutlass.cute.Pointer)

  • rC_ptr (cutlass.cute.Pointer)

body_rs(rA_ptr, sB_ptr, rC_ptr)

GEMM body_rs: A from register, B from shared memory

Parameters:
  • rA_ptr (cutlass.cute.Pointer)

  • sB_ptr (cutlass.cute.Pointer)

  • rC_ptr (cutlass.cute.Pointer)

body_sr(sA_ptr, rB_ptr, rC_ptr)

GEMM body_sr: A from shared memory, B from register

Parameters:
  • sA_ptr (cutlass.cute.Pointer)

  • rB_ptr (cutlass.cute.Pointer)

  • rC_ptr (cutlass.cute.Pointer)
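
The constructor takes the block tile shape (M, N, K) together with warp_m and warp_n. A plausible reading, stated here as an assumption rather than something this page confirms, is that warp_m x warp_n warps split the M x N output tile evenly. Under that assumption, the share of the tile owned by each warp can be sketched as follows; per_warp_tile is a hypothetical helper, not part of tilelang.

```python
# Hypothetical helper: per-warp output tile under the assumption that
# warp_m x warp_n warps evenly partition the M x N block tile.
def per_warp_tile(M: int, N: int, warp_m: int, warp_n: int) -> tuple:
    if M % warp_m != 0 or N % warp_n != 0:
        raise ValueError("M and N must be divisible by warp_m and warp_n")
    return M // warp_m, N // warp_n


# A 128 x 128 block tile split across a 2 x 2 warp grid
print(per_warp_tile(128, 128, 2, 2))  # (64, 64)
```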

class tilelang.contrib.cutedsl.gemm_V1.Gemm_SM90(M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, A_type, B_type, C_type)

static make_tma_atom(tensor, smem_layout_staged, smem_tile, mcast_dim)

static get_tma_atom(tensor, tiler_mk, stages=1)

static make_smem_layout_AB(dtype, major_mode, tiler_mk, stages=1)
Parameters:
  • major_mode (cutlass.utils.LayoutEnum)

__call__(sA_ptr, sB_ptr, rC_ptr, wg_wait=0, clear_accum=False)
Parameters:
  • sA_ptr (cutlass.cute.Pointer)

  • sB_ptr (cutlass.cute.Pointer)

  • rC_ptr (cutlass.cute.Pointer)

  • wg_wait (cutlass.Constexpr)

  • clear_accum (cutlass.Constexpr)

body_rs(rA_ptr, sB_ptr, rC_ptr, wg_wait=0, clear_accum=False)

GEMM body_rs for SM90/Hopper: A from register, B from shared memory. Based on cute::tl_wgmma::GemmTensorOp::body_rs from gemm_sm90.h

Parameters:
  • rA_ptr (cutlass.cute.Pointer)

  • sB_ptr (cutlass.cute.Pointer)

  • rC_ptr (cutlass.cute.Pointer)

  • wg_wait (cutlass.Constexpr)

  • clear_accum (cutlass.Constexpr)