tilelang.contrib.cutedsl.quantize
=================================

.. py:module:: tilelang.contrib.cutedsl.quantize

.. autoapi-nested-parse::

   Quantization/dequantization functions for CuTeDSL backend.

   These implement the same functionality as the CUDA templates in
   tilelang/quantize/lop3.py using inline PTX via llvm.inline_asm.


Attributes
----------

.. autoapisummary::

   tilelang.contrib.cutedsl.quantize.BOTTOM_MASK
   tilelang.contrib.cutedsl.quantize.FP16_TOP_MAGIC_NUM
   tilelang.contrib.cutedsl.quantize.IMMLUT
   tilelang.contrib.cutedsl.quantize.MEDIAN_NUM_UNSIGNED
   tilelang.contrib.cutedsl.quantize.MEDIAN_NUM_SIGNED


Functions
---------

.. autoapisummary::

   tilelang.contrib.cutedsl.quantize.decode_i4u_to_f16
   tilelang.contrib.cutedsl.quantize.decode_i4s_to_f16
   tilelang.contrib.cutedsl.quantize.decode_fp4_to_bf16_twiddling


Module Contents
---------------

.. py:data:: BOTTOM_MASK
   :value: 983055


.. py:data:: FP16_TOP_MAGIC_NUM
   :value: 1677747200


.. py:data:: IMMLUT
   :value: 234


.. py:data:: MEDIAN_NUM_UNSIGNED
   :value: 1677747200


.. py:data:: MEDIAN_NUM_SIGNED
   :value: 1678271496


.. py:function:: decode_i4u_to_f16(src_ptr, dst_ptr, N = 8)

   Decode unsigned INT4 to FP16.

   Equivalent to CUDA template: decode_i4b_to_f16(_i4u, B_local_decode, N);

   :param src_ptr: Pointer to packed INT4 data (4 bytes for 8 elements)
   :param dst_ptr: Pointer to FP16 output (16 bytes for 8 elements)
   :param N: Number of elements to decode (default 8, must be even)


.. py:function:: decode_i4s_to_f16(src_ptr, dst_ptr, N = 8)

   Decode signed INT4 to FP16. N must be even.


.. py:function:: decode_fp4_to_bf16_twiddling(src_ptr, dst_ptr, N = 8)

   Decode FP4 to BF16 using twiddling technique.

   Reference: triton/tensor_details/layout_details/hopper_value.py

   For each iteration:

   - Input: 4 bytes (uint32) = 8 FP4 values
   - Output: 8 BF16 values (16 bytes)

   C code output layout::

      B_local_decode[(i << 3) + j] = vec[j].high  (j=0..3)
      B_local_decode[(i << 3) + j + 4] = vec[j].low  (j=0..3)

   So output as uint32::

      dst[i*4 + 0] = {r1.high, r0.high}
      dst[i*4 + 1] = {r3.high, r2.high}
      dst[i*4 + 2] = {r1.low, r0.low}
      dst[i*4 + 3] = {r3.low, r2.low}

   :param src_ptr: Pointer to packed FP4 data
   :param dst_ptr: Pointer to BF16 output
   :param N: Number of iterations (default 8, processing 64 FP4 -> 64 BF16)
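
The output layout documented for ``decode_fp4_to_bf16_twiddling`` can be checked with a small index calculation. The sketch below only reproduces the indexing formulas from the docstring; it does not decode any FP4 values. The helper names ``fp4_output_layout`` and ``fp4_output_words`` are hypothetical, and the word grouping assumes that the lower-indexed BF16 element sits in the low half of each ``uint32`` and that the braces list the upper half first.

.. code-block:: python

   def fp4_output_layout(i):
       """Map each BF16 output index of iteration i to (register j, half)."""
       layout = {}
       for j in range(4):
           layout[(i << 3) + j] = (j, "high")      # vec[j].high
           layout[(i << 3) + j + 4] = (j, "low")   # vec[j].low
       return layout


   def fp4_output_words(i):
       """Group the 8 outputs of iteration i into 4 uint32 words.

       Returns {word index: (upper half-word, lower half-word)}.
       Assumption: element 2*w is the low half of word w (little-endian).
       """
       layout = fp4_output_layout(i)
       base = i << 3
       return {
           i * 4 + w: (layout[base + 2 * w + 1], layout[base + 2 * w])
           for w in range(4)
       }


   if __name__ == "__main__":
       for word, (upper, lower) in fp4_output_words(0).items():
           print(f"dst[{word}] = {{r{upper[0]}.{upper[1]}, r{lower[0]}.{lower[1]}}}")

Under these assumptions, iteration 0 reproduces the four ``dst[i*4 + k]`` entries listed in the docstring, and iteration ``i`` covers BF16 indices ``8*i`` through ``8*i + 7``.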
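
The module constants match the usual magic-number INT4 decode: ``BOTTOM_MASK`` is ``0x000F000F``, ``FP16_TOP_MAGIC_NUM`` is ``0x64006400`` (two FP16 lanes of 1024.0), ``IMMLUT`` is ``0xEA``, and ``MEDIAN_NUM_SIGNED`` is ``0x64086408`` (two lanes of 1032.0). The following is a host-side emulation in plain Python, not the module's implementation. It assumes the CuTeDSL functions mirror a loop of ``lop3`` computing ``(x & BOTTOM_MASK) | FP16_TOP_MAGIC_NUM`` followed by a packed FP16 subtract of the median constant; the helper names ``as_f16_pair`` and ``emulate_decode_i4_to_f16`` are hypothetical.

.. code-block:: python

   import struct

   BOTTOM_MASK = 0x000F000F          # 983055: low nibble of each 16-bit lane
   FP16_TOP_MAGIC_NUM = 0x64006400   # 1677747200: FP16 1024.0 in both lanes
   IMMLUT = 0xEA                     # 234: lop3 truth table for (a & b) | c
   MEDIAN_NUM_UNSIGNED = 0x64006400  # subtract 1024.0 -> values 0..15
   MEDIAN_NUM_SIGNED = 0x64086408    # subtract 1032.0 -> values -8..7

   # lop3 lookup tables are derived by applying the desired function
   # to the canonical operands 0xF0, 0xCC and 0xAA.
   assert (0xF0 & 0xCC) | 0xAA == IMMLUT


   def as_f16_pair(word):
       """Reinterpret a 32-bit word as (low lane, high lane) FP16 values."""
       return struct.unpack("<2e", struct.pack("<I", word))


   def emulate_decode_i4_to_f16(packed, n=8, signed=False):
       """Emulate the assumed decode loop on one packed 32-bit word."""
       median = as_f16_pair(MEDIAN_NUM_SIGNED if signed else MEDIAN_NUM_UNSIGNED)
       out = []
       for i in range(n // 2):
           # Keep one nibble per 16-bit lane and OR in the 1024.0 exponent,
           # so each lane holds the FP16 value 1024 + nibble.
           word = ((packed >> (4 * i)) & BOTTOM_MASK) | FP16_TOP_MAGIC_NUM
           lo, hi = as_f16_pair(word)
           out += [lo - median[0], hi - median[1]]
       return out


   if __name__ == "__main__":
       print(emulate_decode_i4_to_f16(0x76543210))
       print(emulate_decode_i4_to_f16(0x76543210, signed=True))

In this emulation each output pair combines nibble ``i`` with nibble ``i + 4``, so the packed word ``0x76543210`` decodes to ``[0, 4, 1, 5, 2, 6, 3, 7]``. With ``signed=True`` every result is shifted by 8, which gives the range -8 to 7.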