tilelang.quantize.mxfp
Attributes
- decode_f4_to_bf16_twiddling
Functions
- get_mxfp_intrin_group(...): Return metadata for an MXFP decoding intrinsic: function name and C source string.
Module Contents
- tilelang.quantize.mxfp.decode_f4_to_bf16_twiddling = Multiline-String
"""
// N should be the number of elements processed by one thread
template<typename T1, typename T2>
__device__ void decode_fp4_to_bf16_twiddling(T1 *B_local, T2 *B_local_decode, const int N = 8) {
  #pragma unroll
  for (int i = 0; i < N; ++i) {
    uint B_dequantize_local_vec[4];
    uint tmp, bias, d0, d1, d2, d3, d4, d5, d6;
    asm volatile(
      // To handle the endianness issue
      "prmt.b32 %13, %4, 0, 0x0123;"
      "mov.b32 %12, 0x7e807e80;"
      "and.b32 %0, %13, 0b10000001110000001000000111000000;"
      "mul.bf16x2 %0, %0, %12;"
      "shl.b32 %1, %13, 3;"
      "and.b32 %1, %1, 0b10000001110000001000000111000000;"
      "mul.bf16x2 %1, %1, %12;"
      "shl.b32 %2, %13, 6;"
      "and.b32 %2, %2, 0b10000001110000001000000111000000;"
      "mul.bf16x2 %2, %2, %12;"
      "shl.b32 %5, %13, 1;"
      "and.b32 %6, %5, 0b10000000000000001000000000000000;"
      "shr.b32 %7, %13, 3;"
      "and.b32 %8, %7, 0b00000001100000000000000110000000;"
      "or.b32 %9, %6, %8;"
      "shr.b32 %10, %13, 7;"
      "and.b32 %11, %10, 0b00000000010000000000000001000000;"
      "or.b32 %3, %9, %11;"
      "mul.bf16x2 %3, %3, %12;"
      :"=r"(B_dequantize_local_vec[0])
      ,"=r"(B_dequantize_local_vec[1])
      ,"=r"(B_dequantize_local_vec[2])
      ,"=r"(B_dequantize_local_vec[3])
      :"r"(*(uint*)&B_local[i << 2]), "r"(d0), "r"(d1), "r"(d2), "r"(d3), "r"(d4), "r"(d5), "r"(d6), "r"(bias), "r"(tmp)
    );
    for (int j = 0; j < 4; ++j) {
      // Pay attention to the big-endianness issue
      B_local_decode[(i << 3) + j] = reinterpret_cast<T2*>(&B_dequantize_local_vec[j])[1];
      B_local_decode[(i << 3) + j + 4] = reinterpret_cast<T2*>(&B_dequantize_local_vec[j])[0];
    }
  }
  // Check if the synchronization is needed
}
"""
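The mask-and-multiply step in the assembly above can be illustrated for a single 16-bit lane. The following is a minimal Python sketch, not part of tilelang, and it assumes the 4-bit codes use the e2m1 layout (1 sign bit, 2 exponent bits, 1 mantissa bit). It places the sign at bit 15 and the exponent and mantissa bits at bits 8..6, which are the positions kept by the mask 0b1000000111000000, and then multiplies by the bfloat16 value whose bit pattern is 0x7e80 (one half of the constant 0x7e807e80). The helper names `bf16_bits_to_float` and `decode_fp4_lane` are hypothetical, and the sketch ignores the byte permutation and the twiddled packing of several codes into one 32-bit word.

```python
import struct

# One 16-bit lane of the constant 0x7e807e80 used by the kernel.
BF16_SCALE_BITS = 0x7E80


def bf16_bits_to_float(bits: int) -> float:
    """Interpret a 16-bit pattern as bfloat16, i.e. the upper half of a float32."""
    return struct.unpack(">f", struct.pack(">I", (bits & 0xFFFF) << 16))[0]


def decode_fp4_lane(nibble: int) -> float:
    """Decode one 4-bit code, assuming e2m1: sign, 2 exponent bits, 1 mantissa bit."""
    sign = (nibble >> 3) & 0x1
    exp_and_mantissa = nibble & 0x7
    # Sign goes to bit 15; exponent and mantissa bits go to bits 8..6,
    # the positions selected by the mask 0b1000000111000000.
    lane = (sign << 15) | (exp_and_mantissa << 6)
    # The raw lane is a tiny bfloat16 value; scaling re-biases the exponent.
    return bf16_bits_to_float(lane) * bf16_bits_to_float(BF16_SCALE_BITS)


if __name__ == "__main__":
    for code in range(16):
        print(f"{code:04b} -> {decode_fp4_lane(code)}")
```

Under these assumptions the scale factor equals 2**126, so a lane whose exponent field holds e decodes to magnitude 2**(e - 1) times (1 + m/2) for e > 0, and to m/2 for e = 0. No lookup table is needed, which is presumably why the kernel can decode with only bitwise operations and one multiply per lane pair.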
- tilelang.quantize.mxfp.get_mxfp_intrin_group(out_dtype='bfloat16', source_format='uint', source_bit=4, storage_dtype='uint8', use_twiddling=False)
Return metadata for an MXFP decoding intrinsic: function name and C source string.
Validates the requested output dtype, source format, and storage dtype, then builds a lookup key of the form fp{source_bit}_to_{f16|bf16}, appending _twiddling when use_twiddling is True. The key selects the corresponding C source snippet, and the returned function name is the key prefixed with decode_, i.e. decode_fp{source_bit}_to_{f16|bf16} with the same optional _twiddling suffix.
- Parameters:
out_dtype (Literal['float16', 'bfloat16']) – Target floating-point type for decoded values; either “float16” or “bfloat16”.
source_format (Literal['int', 'uint']) – Integer source representation; “int” or “uint”.
source_bit (int) – Bit width of the packed source format (e.g., 4).
storage_dtype (Literal['int32', 'int8', 'uint8']) – Underlying storage integer dtype (one of “int32”, “int8”, “uint8”).
use_twiddling (bool) – When True, select the twiddling variant of the decoding intrinsic.
- Returns:
A dict with two entries:
“func_name”: the generated C function name string for the requested decode intrinsic.
“c_source”: the C source string for that intrinsic.
- Return type:
dict
- Raises:
AssertionError – if out_dtype, source_format, or storage_dtype is not supported.
KeyError – if the constructed key does not match any available C source implementation.
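The behaviour documented above (validation, key construction, lookup, and the two error cases) can be sketched as follows. This is an illustrative stand-in written from the description on this page, not tilelang's implementation: the function `sketch_mxfp_intrin_group` and the table `PLACEHOLDER_SOURCES` are hypothetical, and the table holds a placeholder string rather than real C source.

```python
# Hypothetical stand-in for the module's table of C source strings.
PLACEHOLDER_SOURCES = {
    "fp4_to_bf16_twiddling": "/* C source of decode_fp4_to_bf16_twiddling */",
}


def sketch_mxfp_intrin_group(
    out_dtype="bfloat16",
    source_format="uint",
    source_bit=4,
    storage_dtype="uint8",
    use_twiddling=False,
    sources=PLACEHOLDER_SOURCES,
):
    """Mimic the documented key construction; not the library's own code."""
    # Unsupported argument values fail with AssertionError, as documented.
    assert out_dtype in ("float16", "bfloat16"), f"unsupported out_dtype: {out_dtype}"
    assert source_format in ("int", "uint"), f"unsupported source_format: {source_format}"
    assert storage_dtype in ("int32", "int8", "uint8"), f"unsupported storage_dtype: {storage_dtype}"

    # Key form: fp{source_bit}_to_{f16|bf16}, optionally suffixed with _twiddling.
    short_dtype = {"float16": "f16", "bfloat16": "bf16"}[out_dtype]
    key = f"fp{source_bit}_to_{short_dtype}"
    if use_twiddling:
        key += "_twiddling"

    # A key with no registered source raises KeyError on this lookup.
    return {"func_name": f"decode_{key}", "c_source": sources[key]}


if __name__ == "__main__":
    group = sketch_mxfp_intrin_group(use_twiddling=True)
    print(group["func_name"])
```

In this sketch the function name and the lookup key share one construction path, so the two cannot drift apart. Which keys are actually registered in tilelang is not stated on this page, so the placeholder table lists only the twiddling variant shown above.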