tilelang.contrib.cutedsl.ptx_mma

PTX MMA operations for the CuTeDSL backend, based on tl_templates/cuda/instruction/mma.h.

These functions wrap PTX mma.sync instructions, which perform matrix multiply-accumulate operations on Tensor Cores.

The MMA instructions are emitted directly as inline PTX assembly.

Supported dense configurations (from mma.h):
  • FP16: m16n8k16 -> f16/f32 accumulator
  • BF16: m16n8k16 -> f32 accumulator
  • INT8: m16n8k32 -> i32 accumulator
  • UINT8: m16n8k32 -> i32 accumulator
  • INT4: m16n8k32 -> i32 accumulator (mapped to m16n8k64 in PTX)
  • UINT4: m16n8k32 -> i32 accumulator
  • FP8 (e4m3/e5m2): m16n8k32 -> f16/f32 accumulator
  • TF32: m16n8k4, m16n8k8 -> f32 accumulator
  • FP64: m8n8k4 -> f64 accumulator
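To give a feel for what the shape strings above imply, the following sketch computes how many bits of each operand a single thread holds when an MxNxK tile is split evenly across a 32-thread warp. It is illustrative arithmetic only: `parse_shape`, `fragment_bits_per_thread`, and the `BITS` table are hypothetical helpers, not part of tilelang, and the short dtype names are taken from the function-name suffixes in this module.

```python
# Illustrative arithmetic only; none of these helpers exist in tilelang.
WARP_SIZE = 32  # mma.sync is a warp-wide instruction

# Bit widths of the element types, keyed by the suffixes used in this module.
BITS = {
    "f16": 16, "bf16": 16, "tf32": 32, "f32": 32, "f64": 64,
    "s8": 8, "u8": 8, "s4": 4, "u4": 4, "e4m3": 8, "e5m2": 8, "s32": 32,
}


def parse_shape(shape):
    """Split a shape string such as 'm16n8k16' into (m, n, k)."""
    m_part, rest = shape[1:].split("n")
    n_part, k_part = rest.split("k")
    return int(m_part), int(n_part), int(k_part)


def fragment_bits_per_thread(shape, a_dtype, b_dtype, c_dtype):
    """Bits of A (MxK), B (KxN), and C (MxN) held by each thread of the warp."""
    m, n, k = parse_shape(shape)
    return {
        "A": m * k * BITS[a_dtype] // WARP_SIZE,
        "B": k * n * BITS[b_dtype] // WARP_SIZE,
        "C": m * n * BITS[c_dtype] // WARP_SIZE,
    }


frag = fragment_bits_per_thread("m16n8k16", "f16", "f16", "f32")
print(frag)  # A: 128 bits, B: 64 bits, C: 128 bits per thread
```

Dividing each figure by 32 gives the number of 32-bit registers per thread, for example four registers for the FP16 A fragment of m16n8k16.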

Sparse (mma.sp) variants mirror the dense ones with halved A registers, an extra metadata register, and a sparse_selector literal.
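The halved A operand and the extra metadata can be pictured with a small host-side sketch. It assumes the 2:4 structured-sparsity pattern that PTX mma.sp uses for 16-bit operand types, where each group of four A elements keeps at most two nonzeros plus their positions. `compress_2_to_4` is a hypothetical helper written for this illustration. The real metadata is bit-packed into a register in a layout defined by the PTX ISA, which this sketch does not reproduce.

```python
def compress_2_to_4(row):
    """Compress a row in which every group of 4 holds at most 2 nonzeros.

    Returns (values, metadata): two kept values per group, and the position
    (0..3) of each kept value inside its group.
    """
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    values, metadata = [], []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        keep = [i for i, v in enumerate(group) if v != 0]
        if len(keep) > 2:
            raise ValueError("group has more than 2 nonzeros; not 2:4 sparse")
        # Pad with unused (zero) positions so each group stores exactly two.
        for i in range(4):
            if len(keep) == 2:
                break
            if i not in keep:
                keep.append(i)
        keep.sort()
        values.extend(group[i] for i in keep)
        metadata.extend(keep)
    return values, metadata


values, metadata = compress_2_to_4([0, 5, 0, 7, 1, 0, 0, 0])
print(values)    # [5, 7, 1, 0]
print(metadata)  # [1, 3, 0, 1]
```

The compressed row has half as many values as the original, which corresponds to the halved A registers, and the positions are the information carried by the metadata register.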

Attributes

Functions

ptx_mma(shape, a_layout, b_layout, a_dtype, b_dtype, ...)

Generic PTX MMA dispatcher.

ptx_mma_sp(shape, a_layout, b_layout, a_dtype, ...[, ...])

Generic PTX sparse MMA dispatcher.

Module Contents

tilelang.contrib.cutedsl.ptx_mma.ptx_mma_m16n8k16_f16_f16_f32
tilelang.contrib.cutedsl.ptx_mma.ptx_mma_m16n8k16_f16_f16_f16
tilelang.contrib.cutedsl.ptx_mma.ptx_mma_m16n8k16_bf16_bf16_f32
tilelang.contrib.cutedsl.ptx_mma.ptx_mma_m16n8k32_s8_s8_s32
tilelang.contrib.cutedsl.ptx_mma.ptx_mma_m16n8k32_u8_u8_s32
tilelang.contrib.cutedsl.ptx_mma.ptx_mma_m16n8k32_s4_s4_s32
tilelang.contrib.cutedsl.ptx_mma.ptx_mma_m16n8k32_u4_u4_s32
tilelang.contrib.cutedsl.ptx_mma.ptx_mma_m16n8k4_tf32_tf32_f32
tilelang.contrib.cutedsl.ptx_mma.ptx_mma_m16n8k8_tf32_tf32_f32
tilelang.contrib.cutedsl.ptx_mma.ptx_mma_m8n8k4_f64_f64_f64
tilelang.contrib.cutedsl.ptx_mma.ptx_mma_m16n8k32_e4m3_e4m3_f32
tilelang.contrib.cutedsl.ptx_mma.ptx_mma_m16n8k32_e4m3_e4m3_f16
tilelang.contrib.cutedsl.ptx_mma.ptx_mma_m16n8k32_e5m2_e5m2_f32
tilelang.contrib.cutedsl.ptx_mma.ptx_mma(shape, a_layout, b_layout, a_dtype, b_dtype, c_dtype, a_ptr, a_offset, b_ptr, b_offset, c_ptr, c_offset, saturate=False)

Generic PTX MMA dispatcher.

Dispatches to the appropriate specialized MMA function based on shape and data types.

Parameters:
  • shape (str)

  • a_layout (str)

  • b_layout (str)

  • a_dtype (str)

  • b_dtype (str)

  • c_dtype (str)

  • saturate (bool)
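The dispatch described above can be pictured as a lookup from (shape, a_dtype, b_dtype, c_dtype) to one of the specialized names listed in Module Contents. The sketch below is a standalone model of that idea, not the module's actual implementation: `resolve_mma_name` is a hypothetical helper, and it assumes the dtype strings match the short suffixes in the function names (the real dispatcher may accept different spellings and also consults the layouts).

```python
# Standalone model of name-based dispatch; it does not import tilelang.
# The names below are copied from the Module Contents listing.
SPECIALIZED = {
    "ptx_mma_m16n8k16_f16_f16_f32",
    "ptx_mma_m16n8k16_f16_f16_f16",
    "ptx_mma_m16n8k16_bf16_bf16_f32",
    "ptx_mma_m16n8k32_s8_s8_s32",
    "ptx_mma_m16n8k32_u8_u8_s32",
    "ptx_mma_m16n8k32_s4_s4_s32",
    "ptx_mma_m16n8k32_u4_u4_s32",
    "ptx_mma_m16n8k4_tf32_tf32_f32",
    "ptx_mma_m16n8k8_tf32_tf32_f32",
    "ptx_mma_m8n8k4_f64_f64_f64",
    "ptx_mma_m16n8k32_e4m3_e4m3_f32",
    "ptx_mma_m16n8k32_e4m3_e4m3_f16",
    "ptx_mma_m16n8k32_e5m2_e5m2_f32",
}


def resolve_mma_name(shape, a_dtype, b_dtype, c_dtype):
    """Return the specialized function name for a configuration.

    Hypothetical helper: assumes dtype strings equal the name suffixes.
    """
    name = f"ptx_mma_{shape}_{a_dtype}_{b_dtype}_{c_dtype}"
    if name not in SPECIALIZED:
        raise ValueError(f"unsupported MMA configuration: {name}")
    return name


print(resolve_mma_name("m16n8k16", "f16", "f16", "f32"))
```

A table-driven lookup like this fails loudly on combinations that have no specialized variant in the listing, such as e5m2 inputs with an f16 accumulator.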

tilelang.contrib.cutedsl.ptx_mma.ptx_mma_sp(shape, a_layout, b_layout, a_dtype, b_dtype, c_dtype, a_ptr, a_offset, b_ptr, b_offset, c_ptr, c_offset, meta_ptr, meta_offset, sparse_selector=0, saturate=False)

Generic PTX sparse MMA dispatcher.

Dispatches to the appropriate specialized sparse MMA function based on shape and data types.

Parameters:
  • shape (str)

  • a_layout (str)

  • b_layout (str)

  • a_dtype (str)

  • b_dtype (str)

  • c_dtype (str)

  • sparse_selector (int)

  • saturate (bool)