tilelang.quantize.mxfp
======================

.. py:module:: tilelang.quantize.mxfp


Attributes
----------

.. autoapisummary::

   tilelang.quantize.mxfp.decode_f4_to_bf16_twiddling


Functions
---------

.. autoapisummary::

   tilelang.quantize.mxfp.get_mxfp_intrin_group


Module Contents
---------------

.. py:data:: decode_f4_to_bf16_twiddling
   :value: Multiline-String

   .. raw:: html

      <details><summary>Show Value</summary>

   .. code-block:: python

      """
      // N should be the number of elements processed by one thread
      template <typename T1, typename T2>
      __device__ void decode_fp4_to_bf16_twiddling(T1 *B_local, T2 *B_local_decode, const int N = 8) {
        #pragma unroll
        for (int i = 0; i < N; ++i) {
          uint B_dequantize_local_vec[4];
          uint tmp, bias, d0, d1, d2, d3, d4, d5, d6;
          asm volatile(
            // To handle the endianness issue
            "prmt.b32 %13, %4, 0, 0x0123;"
            "mov.b32 %12, 0x7e807e80;"
            "and.b32 %0, %13, 0b10000001110000001000000111000000;"
            "mul.bf16x2 %0, %0, %12;"
            "shl.b32 %1, %13, 3;"
            "and.b32 %1, %1, 0b10000001110000001000000111000000;"
            "mul.bf16x2 %1, %1, %12;"
            "shl.b32 %2, %13, 6;"
            "and.b32 %2, %2, 0b10000001110000001000000111000000;"
            "mul.bf16x2 %2, %2, %12;"
            "shl.b32 %5, %13, 1;"
            "and.b32 %6, %5, 0b10000000000000001000000000000000;"
            "shr.b32 %7, %13, 3;"
            "and.b32 %8, %7, 0b00000001100000000000000110000000;"
            "or.b32 %9, %6, %8;"
            "shr.b32 %10, %13, 7;"
            "and.b32 %11, %10, 0b00000000010000000000000001000000;"
            "or.b32 %3, %9, %11;"
            "mul.bf16x2 %3, %3, %12;"
            :"=r"(B_dequantize_local_vec[0])
            ,"=r"(B_dequantize_local_vec[1])
            ,"=r"(B_dequantize_local_vec[2])
            ,"=r"(B_dequantize_local_vec[3])
            :"r"(*(uint*)&B_local[i << 2]), "r"(d0), "r"(d1), "r"(d2), "r"(d3), "r"(d4), "r"(d5), "r"(d6), "r"(bias), "r"(tmp)
          );
          for (int j = 0; j < 4; ++j) {
            // Pay attention to the big-endianness issue
            B_local_decode[(i << 3) + j] = reinterpret_cast<T2 *>(&B_dequantize_local_vec[j])[1];
            B_local_decode[(i << 3) + j + 4] = reinterpret_cast<T2 *>(&B_dequantize_local_vec[j])[0];
          }
        }
        // Check if the synchronization is needed
      }
      """

   .. raw:: html

      </details>


.. py:function:: get_mxfp_intrin_group(out_dtype = 'bfloat16', source_format = 'uint', source_bit = 4, storage_dtype = 'uint8', use_twiddling = False)

   Return metadata for an MXFP decoding intrinsic: its function name and its C source string.

   Validates the requested output dtype, source format, and storage dtype, then builds a lookup key of the form ``fp{source_bit}_to_{f16|bf16}`` (with ``_twiddling`` appended when ``use_twiddling`` is True). The key selects the corresponding C source snippet. The returned function name follows the same pattern, ``decode_fp{source_bit}_to_{f16|bf16}``, with the same optional ``_twiddling`` suffix.

   :param out_dtype: Target floating-point type for decoded values; either "float16" or "bfloat16".
   :param source_format: Integer source representation; "int" or "uint".
   :param source_bit: Bit width of the packed source format (e.g., 4).
   :param storage_dtype: Underlying storage integer dtype (one of "int32", "int8", "uint8").
   :param use_twiddling: When True, select the twiddling variant of the decoding intrinsic.
   :returns: A dict with the following keys:

      - ``"func_name"``: the generated C function name string for the requested decode intrinsic.
      - ``"c_source"``: the C source string for that intrinsic.
   :rtype: dict
   :raises AssertionError: if out_dtype, source_format, or storage_dtype is not supported.
   :raises KeyError: if the constructed key does not match any available C source implementation.