tilelang.quantize.lop3
======================
.. py:module:: tilelang.quantize.lop3
Attributes
----------
.. autoapisummary::
tilelang.quantize.lop3.decode_i4_to_f16
tilelang.quantize.lop3.decode_i4_to_f16_scale
tilelang.quantize.lop3.decode_i4_to_f16_scale_offset
tilelang.quantize.lop3.decode_i4_to_f16_scale_zeros_original
tilelang.quantize.lop3.decode_i4_to_f16_scale_zeros_original_offset
tilelang.quantize.lop3.decode_i4_to_f16_scale_zeros_rescale
tilelang.quantize.lop3.decode_i4_to_f16_scale_zeros_rescale_offset
tilelang.quantize.lop3.decode_i4_to_f16_scale_zeros_quantized
tilelang.quantize.lop3.decode_i4_to_f16_scale_zeros_quantized_offset
tilelang.quantize.lop3.decode_i2_to_f16
tilelang.quantize.lop3.decode_i2_to_f16_scale
tilelang.quantize.lop3.decode_i2_to_f16_scale_zeros_original_offset
tilelang.quantize.lop3.decode_i2_to_f16_scale_zeros_original
tilelang.quantize.lop3.decode_i2_to_f16_scale_zeros_rescale
tilelang.quantize.lop3.decode_i2_to_f16_scale_zeros_quantized
tilelang.quantize.lop3.decode_i1_to_f16
tilelang.quantize.lop3.decode_i1_to_f16_scale
tilelang.quantize.lop3.decode_i1_to_f16_scale_zeros_original
tilelang.quantize.lop3.decode_i1_to_f16_scale_zeros_rescale
tilelang.quantize.lop3.decode_i1s_to_i8s
tilelang.quantize.lop3.decode_i2s_to_i8s
tilelang.quantize.lop3.decode_i4s_to_i8s
tilelang.quantize.lop3.decode_i2s_to_i4s
Functions
---------
.. autoapisummary::
tilelang.quantize.lop3.get_lop3_intrin_group
Module Contents
---------------
.. py:data:: decode_i4_to_f16
:value: Multiline-String
.. raw:: html
Show Value
.. code-block:: python
"""
// Decode packed 4-bit integers into fp16 pairs with one LOP3 and one f16x2 subtract
// per output word.
//
//   _i4s           : source; a single 32-bit word is read from it.
//   B_local_decode : destination; written as N/2 32-bit words, two fp16 values each.
//   N              : number of elements to produce (two per loop step).
//   isSigned       : selects the bias: 1032 (nibble v -> v - 8) when true,
//                    1024 (nibble v -> v) when false.
//
// NOTE(review): the template parameter list and the reinterpret_cast target types
// were missing from this listing and are reconstructed from how the body uses
// them - confirm against the module that defines this string.
template <typename T1, typename T2, bool isSigned = false>
__device__ void decode_i4b_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8)
{
    uint *h = reinterpret_cast<uint *>(B_local_decode);
    // 0xEA is the LOP3 truth table for (a & b) | c.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
    // Keeps the low nibble of each 16-bit half.
    static constexpr uint BOTTOM_MASK = 0x000f000f;
    // fp16 1024.0 in both halves; OR-ing a nibble into the mantissa yields 1024 + v.
    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
    // Bias removed afterwards: 1032.0 (signed) or 1024.0 (unsigned) in both halves.
    static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
    uint const i4s = *reinterpret_cast<uint *>(_i4s);
#pragma unroll
    for (int i = 0; i < (N / 2); i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[i])
                     : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
    }
}
// Signed 4-bit entry point: forwards to decode_i4b_to_f16 with isSigned = true.
// NOTE(review): template list and explicit template arguments were missing from
// this listing and are reconstructed - confirm against the defining module.
template <typename T1, typename T2>
__device__ void decode_i4s_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8)
{
    decode_i4b_to_f16<T1, T2, true>(_i4s, B_local_decode, N);
}
// Unsigned 4-bit entry point: forwards to decode_i4b_to_f16 with isSigned = false.
// NOTE(review): template list and explicit template arguments were missing from
// this listing and are reconstructed - confirm against the defining module.
template <typename T1, typename T2>
__device__ void decode_i4u_to_f16(T1 *_i4u, T2 *B_local_decode, const int N = 8)
{
    decode_i4b_to_f16<T1, T2, false>(_i4u, B_local_decode, N);
}
"""
.. raw:: html
.. py:data:: decode_i4_to_f16_scale
:value: Multiline-String
.. raw:: html
Show Value
.. code-block:: python
"""
// Decode packed 4-bit integers to fp16 and multiply every element by *scale.
//
//   _i4s           : source; a single 32-bit word is read from it.
//   B_local_decode : destination; written as N/2 32-bit words, two fp16 values each.
//   N              : number of elements to produce (two per loop step).
//   scale          : dereferenced unconditionally; must not be null despite the default.
//   isSigned       : true -> nibble v decodes to v - 8, false -> v.
//
// NOTE(review): template list and reinterpret_cast target types reconstructed
// from usage - confirm against the defining module.
template <typename T1, typename T2, typename T3, bool isSigned = false>
__device__ void decode_i4b_to_f16_scale(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr)
{
    uint *h = reinterpret_cast<uint *>(B_local_decode);
    // 0xEA is the LOP3 truth table for (a & b) | c.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint BOTTOM_MASK = 0x000f000f;
    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
    // Signed path subtracts 1032 (= 1024 + 8), mapping the nibble to v - 8;
    // unsigned path subtracts only the 1024 magic offset.
    static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
    uint const i4s = *reinterpret_cast<uint *>(_i4s);
    T3 const scale_r = *scale;
    uint const packed_scales = __pack_half2(scale_r, scale_r);
    // Decode 2 elements per iteration.
#pragma unroll
    for (int i = 0; i < (N / 2); i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[i])
                     : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
        // h = h * scale + 0
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
    }
}
// Signed 4-bit + scale entry point: forwards to decode_i4b_to_f16_scale with
// isSigned = true. Note the argument reorder (scale, N) -> (N, scale).
// NOTE(review): template list / arguments reconstructed - confirm.
template <typename T1, typename T2, typename T3>
__device__ void decode_i4s_to_f16_scale(T1 *_i4s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8)
{
    decode_i4b_to_f16_scale<T1, T2, T3, true>(_i4s, B_local_decode, N, scale);
}
// Unsigned 4-bit + scale entry point: forwards to decode_i4b_to_f16_scale with
// isSigned = false. Note the argument reorder (scale, N) -> (N, scale).
// NOTE(review): template list / arguments reconstructed - confirm.
template <typename T1, typename T2, typename T3>
__device__ void decode_i4u_to_f16_scale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8)
{
    decode_i4b_to_f16_scale<T1, T2, T3, false>(_i4u, B_local_decode, N, scale);
}
"""
.. raw:: html
.. py:data:: decode_i4_to_f16_scale_offset
:value: Multiline-String
.. raw:: html
Show Value
.. code-block:: python
"""
// Decode packed 4-bit integers to fp16 and scale them with two different scales:
// output words [0, N/4) use scale[0], output words [N/4, N/2) use scale[offset].
//
//   _i4s           : source; a single 32-bit word is read from it.
//   B_local_decode : destination; written as N/2 32-bit words, two fp16 values each.
//   N              : number of elements to produce.
//   scale          : dereferenced at index 0 and at index `offset`; must not be null.
//   isSigned       : true -> nibble v decodes to v - 8, false -> v.
//
// NOTE(review): template list and reinterpret_cast target types reconstructed
// from usage - confirm against the defining module.
template <typename T1, typename T2, typename T3, bool isSigned = false>
__device__ void decode_i4b_to_f16_scale_offset(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const int offset = 0)
{
    uint *h = reinterpret_cast<uint *>(B_local_decode);
    // 0xEA is the LOP3 truth table for (a & b) | c.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint BOTTOM_MASK = 0x000f000f;
    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
    // Signed path subtracts 1032 (= 1024 + 8); unsigned subtracts 1024.
    static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
    uint const i4s = *reinterpret_cast<uint *>(_i4s);
    T3 const scale_l = *scale;
    T3 const scale_r = *(scale + offset);
    uint const packed_scales_l = __pack_half2(scale_l, scale_l);
    uint const packed_scales_r = __pack_half2(scale_r, scale_r);
    // Decode 2 elements per iteration.
#pragma unroll
    for (int i = 0; i < (N / 2); i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[i])
                     : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
    }
    // First half of the output words: multiply by scale[0].
#pragma unroll
    for (int i = 0; i < (N / 4); i++)
    {
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_l), "r"(0));
    }
    // Second half of the output words: multiply by scale[offset].
#pragma unroll
    for (int i = (N / 4); i < (N / 2); i++)
    {
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_r), "r"(0));
    }
}
// Signed entry point: forwards to decode_i4b_to_f16_scale_offset with isSigned = true.
// NOTE(review): template list / arguments reconstructed - confirm.
template <typename T1, typename T2, typename T3>
__device__ void decode_i4s_to_f16_scale_offset(T1 *_i4s, T2 *B_local_decode, T3 *scale = nullptr, const int offset = 0, const int N = 8)
{
    decode_i4b_to_f16_scale_offset<T1, T2, T3, true>(_i4s, B_local_decode, N, scale, offset);
}
// Unsigned entry point: forwards to decode_i4b_to_f16_scale_offset with isSigned = false.
// NOTE(review): template list / arguments reconstructed - confirm.
template <typename T1, typename T2, typename T3>
__device__ void decode_i4u_to_f16_scale_offset(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, const int offset = 0, const int N = 8)
{
    decode_i4b_to_f16_scale_offset<T1, T2, T3, false>(_i4u, B_local_decode, N, scale, offset);
}
"""
.. raw:: html
.. py:data:: decode_i4_to_f16_scale_zeros_original
:value: Multiline-String
.. raw:: html
Show Value
.. code-block:: python
"""
// Decode packed 4-bit integers to fp16 in "original" zero-point mode:
//   out = (decoded - *zeros) * *scale
//
//   _i4s           : source; a single 32-bit word is read from it.
//   B_local_decode : destination; written as N/2 32-bit words, two fp16 values each.
//   N              : number of elements to produce (two per loop step).
//   scale, zeros   : both dereferenced unconditionally; must not be null.
//   isSigned       : true -> nibble v decodes to v - 8, false -> v.
//
// NOTE(review): template list and reinterpret_cast target types reconstructed
// from usage - confirm against the defining module.
template <typename T1, typename T2, typename T3, typename T4, bool isSigned = false>
__device__ void decode_i4b_to_f16_zeros_original(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr)
{
    uint *h = reinterpret_cast<uint *>(B_local_decode);
    // 0xEA is the LOP3 truth table for (a & b) | c.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint BOTTOM_MASK = 0x000f000f;
    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
    // Signed path subtracts 1032 (= 1024 + 8); unsigned subtracts 1024.
    static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
    uint const i4s = *reinterpret_cast<uint *>(_i4s);
    T3 const scale_r = *scale;
    uint const packed_scales = __pack_half2(scale_r, scale_r);
    // input zeros maybe int32(qzeros) or half format
    T4 const zero_r = *zeros;
    uint const packed_zeros = __pack_half2(zero_r, zero_r);
    // Decode 2 elements per iteration.
#pragma unroll
    for (int i = 0; i < (N / 2); i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[i])
                     : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
        // Subtract the zero point, then scale.
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros));
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
    }
}
// Unsigned entry point for "original" zero-point mode: forwards to
// decode_i4b_to_f16_zeros_original with isSigned = false.
// NOTE(review): template list / arguments reconstructed - confirm.
template <typename T1, typename T2, typename T3, typename T4>
__device__ void decode_i4u_to_f16_scale_zeros_original(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8)
{
    decode_i4b_to_f16_zeros_original<T1, T2, T3, T4, false>(_i4u, B_local_decode, N, scale, zeros);
}
"""
.. raw:: html
.. py:data:: decode_i4_to_f16_scale_zeros_original_offset
:value: Multiline-String
.. raw:: html
Show Value
.. code-block:: python
"""
// "Original" zero-point mode with two parameter groups:
//   output words [0, N/4)   : out = (decoded - zeros[0])      * scale[0]
//   output words [N/4, N/2) : out = (decoded - zeros[offset]) * scale[offset]
//
//   _i4s           : source; a single 32-bit word is read from it.
//   B_local_decode : destination; written as N/2 32-bit words, two fp16 values each.
//   scale, zeros   : dereferenced at index 0 and `offset`; must not be null.
//   isSigned       : true -> nibble v decodes to v - 8, false -> v.
//
// NOTE(review): template list and reinterpret_cast target types reconstructed
// from usage - confirm against the defining module.
template <typename T1, typename T2, typename T3, typename T4, bool isSigned = false>
__device__ void decode_i4b_to_f16_zeros_original_offset(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr, const int offset = 0)
{
    uint *h = reinterpret_cast<uint *>(B_local_decode);
    // 0xEA is the LOP3 truth table for (a & b) | c.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint BOTTOM_MASK = 0x000f000f;
    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
    // Signed path subtracts 1032 (= 1024 + 8); unsigned subtracts 1024.
    static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
    uint const i4s = *reinterpret_cast<uint *>(_i4s);
    T3 const scale_l = *scale;
    T3 const scale_r = *(scale + offset);
    uint const packed_scales_l = __pack_half2(scale_l, scale_l);
    uint const packed_scales_r = __pack_half2(scale_r, scale_r);
    // input zeros maybe int32(qzeros) or half format.
    // Loaded with the pointer's own element type (T4), as the non-offset variant does.
    T4 const zeros_l = *zeros;
    T4 const zeros_r = *(zeros + offset);
    uint const packed_zeros_l = __pack_half2(zeros_l, zeros_l);
    uint const packed_zeros_r = __pack_half2(zeros_r, zeros_r);
    // Decode 2 elements per iteration.
#pragma unroll
    for (int i = 0; i < (N / 2); i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[i])
                     : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
    }
    // First group.
#pragma unroll
    for (int i = 0; i < (N / 4); i++)
    {
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros_l));
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_l), "r"(0));
    }
    // Second group.
#pragma unroll
    for (int i = (N / 4); i < (N / 2); i++)
    {
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros_r));
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_r), "r"(0));
    }
}
// Unsigned entry point: forwards to decode_i4b_to_f16_zeros_original_offset
// with isSigned = false.
// NOTE(review): template list / arguments reconstructed - confirm.
template <typename T1, typename T2, typename T3, typename T4>
__device__ void decode_i4u_to_f16_scale_zeros_original_offset(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int offset = 0, const int N = 8)
{
    decode_i4b_to_f16_zeros_original_offset<T1, T2, T3, T4, false>(_i4u, B_local_decode, N, scale, zeros, offset);
}
"""
.. raw:: html
.. py:data:: decode_i4_to_f16_scale_zeros_rescale
:value: Multiline-String
.. raw:: html
Show Value
.. code-block:: python
"""
// Decode packed 4-bit integers to fp16 in "rescale" zero-point mode, using a
// single fused multiply-add: out = decoded * *scale + addend, where the addend
// is *zeros with its sign bit forced on (i.e. -|*zeros|).
//
//   _i4s           : source; a single 32-bit word is read from it.
//   B_local_decode : destination; written as N/2 32-bit words, two fp16 values each.
//   scale, zeros   : both dereferenced unconditionally; must not be null.
//   isSigned       : true -> nibble v decodes to v - 8, false -> v.
//
// NOTE(review): template list and reinterpret_cast target types reconstructed
// from usage - confirm against the defining module.
template <typename T1, typename T2, typename T3, typename T4, bool isSigned = false>
__device__ void decode_i4b_to_f16_scale_zeros_rescale(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr)
{
    uint *h = reinterpret_cast<uint *>(B_local_decode);
    // 0xEA is the LOP3 truth table for (a & b) | c.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint BOTTOM_MASK = 0x000f000f;
    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
    // Signed path subtracts 1032 (= 1024 + 8); unsigned subtracts 1024.
    static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
    uint const i4s = *reinterpret_cast<uint *>(_i4s);
    T3 const scale_r = *scale;
    uint const packed_scales = __pack_half2(scale_r, scale_r);
    T4 const zero_r = *zeros;
    // OR-ing 0x8000 into each half sets the fp16 sign bit of the zero point.
    uint const packed_zeros = 0x80008000 | __pack_half2(zero_r, zero_r);
    // Decode 2 elements per iteration.
#pragma unroll
    for (int i = 0; i < (N / 2); i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[i])
                     : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(packed_zeros));
    }
}
// Unsigned entry point for "rescale" mode: forwards to
// decode_i4b_to_f16_scale_zeros_rescale with isSigned = false.
// NOTE(review): template list / arguments reconstructed - confirm.
template <typename T1, typename T2, typename T3, typename T4>
__device__ void decode_i4u_to_f16_scale_zeros_rescale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8)
{
    decode_i4b_to_f16_scale_zeros_rescale<T1, T2, T3, T4, false>(_i4u, B_local_decode, N, scale, zeros);
}
"""
.. raw:: html
.. py:data:: decode_i4_to_f16_scale_zeros_rescale_offset
:value: Multiline-String
.. raw:: html
Show Value
.. code-block:: python
"""
// "Rescale" zero-point mode with two parameter groups, one fused multiply-add each:
//   output words [0, N/4)   : out = decoded * scale[0]      + (zeros[0] with sign bit set)
//   output words [N/4, N/2) : out = decoded * scale[offset] + (zeros[offset] with sign bit set)
//
//   _i4s           : source; a single 32-bit word is read from it.
//   B_local_decode : destination; written as N/2 32-bit words, two fp16 values each.
//   scale, zeros   : dereferenced at index 0 and `offset`; must not be null.
//   isSigned       : true -> nibble v decodes to v - 8, false -> v.
//
// NOTE(review): template list and reinterpret_cast target types reconstructed
// from usage - confirm against the defining module.
template <typename T1, typename T2, typename T3, typename T4, bool isSigned = false>
__device__ void decode_i4b_to_f16_scale_zeros_rescale_offset(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr, const int offset = 0)
{
    uint *h = reinterpret_cast<uint *>(B_local_decode);
    // 0xEA is the LOP3 truth table for (a & b) | c.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint BOTTOM_MASK = 0x000f000f;
    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
    // Signed path subtracts 1032 (= 1024 + 8); unsigned subtracts 1024.
    static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
    uint const i4s = *reinterpret_cast<uint *>(_i4s);
    T3 const scale_l = *scale;
    T3 const scale_r = *(scale + offset);
    uint const packed_scales_l = __pack_half2(scale_l, scale_l);
    uint const packed_scales_r = __pack_half2(scale_r, scale_r);
    // input zeros maybe int32(qzeros) or half format.
    // Loaded with the pointer's own element type (T4), as the non-offset variant does.
    T4 const zeros_l = *zeros;
    T4 const zeros_r = *(zeros + offset);
    // OR-ing 0x8000 into each half sets the fp16 sign bit of the zero point.
    uint const packed_zeros_l = 0x80008000 | __pack_half2(zeros_l, zeros_l);
    uint const packed_zeros_r = 0x80008000 | __pack_half2(zeros_r, zeros_r);
    // Decode 2 elements per iteration.
#pragma unroll
    for (int i = 0; i < (N / 2); i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[i])
                     : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
    }
    // First group.
#pragma unroll
    for (int i = 0; i < (N / 4); i++)
    {
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_l), "r"(packed_zeros_l));
    }
    // Second group.
#pragma unroll
    for (int i = (N / 4); i < (N / 2); i++)
    {
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_r), "r"(packed_zeros_r));
    }
}
// Unsigned entry point: forwards to decode_i4b_to_f16_scale_zeros_rescale_offset
// with isSigned = false.
// NOTE(review): template list / arguments reconstructed - confirm.
template <typename T1, typename T2, typename T3, typename T4>
__device__ void decode_i4u_to_f16_scale_zeros_rescale_offset(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int offset = 0, const int N = 8)
{
    decode_i4b_to_f16_scale_zeros_rescale_offset<T1, T2, T3, T4, false>(_i4u, B_local_decode, N, scale, zeros, offset);
}
"""
.. raw:: html
.. py:data:: decode_i4_to_f16_scale_zeros_quantized
:value: Multiline-String
.. raw:: html
Show Value
.. code-block:: python
"""
// Decode packed 4-bit integers to fp16 in "quantized" zero-point mode:
//   out = (nibble - zero) * *scale
// where `zero` is an integer read as the first 16 bits at `zeros`.
//
// The LOP3 step produces 1024 + nibble; adding the fp16 bit pattern
// 0xE400 | zero (= -(1024 + zero) while zero fits in the mantissa bits)
// removes both the magic offset and the zero point in one instruction.
//
//   _i4s           : source; a single 32-bit word is read from it.
//   B_local_decode : destination; written as N/2 32-bit words, two fp16 values each.
//   scale, zeros   : both dereferenced unconditionally; must not be null.
//
// NOTE(review): template list and reinterpret_cast target types reconstructed
// from usage - confirm against the defining module (it may declare further
// defaulted template parameters that this body does not use).
template <typename T1, typename T2, typename T3, typename T4>
__device__ void decode_i4b_to_f16_scale_zeros_quantized(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr)
{
    uint *h = reinterpret_cast<uint *>(B_local_decode);
    // 0xEA is the LOP3 truth table for (a & b) | c.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint BOTTOM_MASK = 0x000f000f;
    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
    uint const i4s = *reinterpret_cast<uint *>(_i4s);
    T3 const scale_r = *scale;
    uint const packed_scales = __pack_half2(scale_r, scale_r);
    // input zeros maybe int32(qzeros) or half format; only its low 16 bits are read.
    int16_t const zero_r = *reinterpret_cast<const int16_t *>(zeros);
    // Build the pattern in unsigned 16-bit space so a negative zero_r cannot
    // sign-extend into the upper fp16 lane and no signed shift reaches bit 31.
    uint const biased_zero = 0xe400u | static_cast<uint16_t>(zero_r);
    uint const median_num = (biased_zero << 16) | biased_zero;
    // Decode 2 elements per iteration.
#pragma unroll
    for (int i = 0; i < (N / 2); i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[i])
                     : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
        asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num));
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
    }
}
// Unsigned entry point for "quantized" mode: forwards to
// decode_i4b_to_f16_scale_zeros_quantized, reordering (scale, zeros, N) -> (N, scale, zeros).
// NOTE(review): template list / arguments reconstructed from the parameter
// type names - confirm against the defining module.
template <typename storage_dtype, typename target_dtype, typename scale_dtype, typename zero_dtype>
__device__ void decode_i4u_to_f16_scale_zeros_quantized(storage_dtype *_i4u, target_dtype *B_local_decode, scale_dtype *scale = nullptr, zero_dtype *zeros = nullptr, const int N = 8)
{
    decode_i4b_to_f16_scale_zeros_quantized<storage_dtype, target_dtype, scale_dtype, zero_dtype>(_i4u, B_local_decode, N, scale, zeros);
}
"""
.. raw:: html
.. py:data:: decode_i4_to_f16_scale_zeros_quantized_offset
:value: Multiline-String
.. raw:: html
Show Value
.. code-block:: python
"""
// "Quantized" zero-point mode with two parameter groups. The zero points are
// 4-bit fields packed in storage words of type T1; field `group_offset` is
// extracted from qzeros[0] and from qzeros[qzeros_offset]:
//   output words [0, N/4)   : out = (nibble - zero_l) * scale[0]
//   output words [N/4, N/2) : out = (nibble - zero_r) * scale[scale_offset]
//
//   _i4s           : source; a single 32-bit word is read from it.
//   B_local_decode : destination; written as N/2 32-bit words, two fp16 values each.
//   scale, qzeros  : dereferenced unconditionally; must not be null.
//
// NOTE(review): template list and reinterpret_cast target types reconstructed
// from usage - confirm against the defining module (it may declare further
// defaulted template parameters that this body does not use).
template <typename T1, typename T2, typename T3>
__device__ void decode_i4b_to_f16_scale_zeros_quantized_offset(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T1 *qzeros = nullptr, const int scale_offset = 0, const int qzeros_offset = 0, const int group_offset = 0)
{
    uint *h = reinterpret_cast<uint *>(B_local_decode);
    // 0xEA is the LOP3 truth table for (a & b) | c.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint BOTTOM_MASK = 0x000f000f;
    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
    uint const i4s = *reinterpret_cast<uint *>(_i4s);
    T3 const scale_l = *scale;
    T3 const scale_r = *(scale + scale_offset);
    uint const packed_scales_l = __pack_half2(scale_l, scale_l);
    uint const packed_scales_r = __pack_half2(scale_r, scale_r);
    T1 const qzeros_l = *qzeros;
    T1 const qzeros_r = *(qzeros + qzeros_offset);
    // Pick the 4-bit zero point selected by group_offset.
    int16_t const zero_l = ((qzeros_l >> (group_offset * 4)) & 0xf);
    int16_t const zero_r = ((qzeros_r >> (group_offset * 4)) & 0xf);
    // fp16 pattern 0xE400 | zero = -(1024 + zero); adding it cancels the magic
    // offset and the zero point together. Built unsigned to keep the shift
    // away from the sign bit.
    uint const biased_zero_l = 0xe400u | static_cast<uint16_t>(zero_l);
    uint const biased_zero_r = 0xe400u | static_cast<uint16_t>(zero_r);
    uint const median_num_l = (biased_zero_l << 16) | biased_zero_l;
    uint const median_num_r = (biased_zero_r << 16) | biased_zero_r;
    // Decode 2 elements per iteration.
#pragma unroll
    for (int i = 0; i < (N / 2); i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[i])
                     : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
    }
    // First group.
#pragma unroll
    for (int i = 0; i < (N / 4); i++)
    {
        asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num_l));
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_l), "r"(0));
    }
    // Second group.
#pragma unroll
    for (int i = (N / 4); i < (N / 2); i++)
    {
        asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num_r));
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_r), "r"(0));
    }
}
// Unsigned entry point: forwards to decode_i4b_to_f16_scale_zeros_quantized_offset,
// moving N ahead of the scale/zero arguments.
// NOTE(review): template list / arguments reconstructed from the parameter
// type names - confirm against the defining module.
template <typename storage_dtype, typename target_dtype, typename scale_dtype>
__device__ void decode_i4u_to_f16_scale_zeros_quantized_offset(storage_dtype *_i4u, target_dtype *B_local_decode, scale_dtype *scale = nullptr, storage_dtype *qzeros = nullptr, const int scale_offset = 0, const int zero_offset = 0, const int group_offset = 0, const int N = 8)
{
    decode_i4b_to_f16_scale_zeros_quantized_offset<storage_dtype, target_dtype, scale_dtype>(_i4u, B_local_decode, N, scale, qzeros, scale_offset, zero_offset, group_offset);
}
"""
.. raw:: html
.. py:data:: decode_i2_to_f16
:value: Multiline-String
.. raw:: html
Show Value
.. code-block:: python
"""
// Decode packed 2-bit integers into fp16 pairs with one LOP3 and one f16x2
// subtract per output word.
//
//   _i2s           : source; a single 16-bit word is read from it.
//   B_local_decode : destination; written as N/2 32-bit words, two fp16 values each.
//   N              : number of elements to produce (two per loop step).
//   isSigned       : true -> field v decodes to v - 2, false -> v.
//
// NOTE(review): template list and reinterpret_cast target types reconstructed
// from usage - confirm against the defining module.
template <typename T1, typename T2, bool isSigned = false>
__device__ void decode_i2b_to_f16(T1 *_i2s, T2 *B_local_decode, const int N = 8)
{
    uint *h = reinterpret_cast<uint *>(B_local_decode);
    // 0xEA is the LOP3 truth table for (a & b) | c.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
    // Keeps the low 2 bits of each 16-bit half.
    static constexpr uint BOTTOM_MASK = 0x00030003;
    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
    // Signed path subtracts 1026 (= 1024 + 2); unsigned subtracts 1024.
    static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400;
    int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
    // decode 2 elems at one time.
    // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
    // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0}
    // Spread the 16-bit word over a 32-bit one: low byte stays in bits 0-7,
    // high byte moves to bits 16-23, so each shift-by-2 step exposes one field
    // to each 16-bit lane of the mask.
    int i2s = (i2s_i16 & 0x00ff);
    i2s |= ((i2s_i16 & 0xff00) << 8);
#pragma unroll
    for (int i = 0; i < (N / 2); i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[i])
                     : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
    }
}
// Signed 2-bit entry point: forwards to decode_i2b_to_f16 with isSigned = true.
// NOTE(review): template list / arguments reconstructed - confirm.
template <typename T1, typename T2>
__device__ void decode_i2s_to_f16(T1 *_i2s, T2 *B_local_decode, const int N = 8)
{
    decode_i2b_to_f16<T1, T2, true>(_i2s, B_local_decode, N);
}
// Unsigned 2-bit entry point: forwards to decode_i2b_to_f16 with isSigned = false.
// NOTE(review): template list / arguments reconstructed - confirm.
template <typename T1, typename T2>
__device__ void decode_i2u_to_f16(T1 *_i2u, T2 *B_local_decode, const int N = 8)
{
    decode_i2b_to_f16<T1, T2, false>(_i2u, B_local_decode, N);
}
"""
.. raw:: html
.. py:data:: decode_i2_to_f16_scale
:value: Multiline-String
.. raw:: html
Show Value
.. code-block:: python
"""
// Decode packed 2-bit integers to fp16 and multiply every element by *scale.
//
//   _i2s           : source; a single 16-bit word is read from it.
//   B_local_decode : destination; written as N/2 32-bit words, two fp16 values each.
//   scale          : dereferenced; must not be null despite the default.
//   N              : number of elements to produce (two per loop step).
//   isSigned       : true -> field v decodes to v - 2, false -> v.
//
// NOTE(review): template list and reinterpret_cast target types reconstructed
// from usage - confirm against the defining module.
template <typename T1, typename T2, typename T3, bool isSigned = false>
__device__ void decode_i2b_to_f16_scale(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8)
{
    uint *h = reinterpret_cast<uint *>(B_local_decode);
    // 0xEA is the LOP3 truth table for (a & b) | c.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint BOTTOM_MASK = 0x00030003;
    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
    // Signed path subtracts 1026 (= 1024 + 2); unsigned subtracts 1024.
    static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400;
    int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
    // decode 2 elems at one time.
    // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
    // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0}
    // Low byte stays in bits 0-7, high byte moves to bits 16-23.
    int i2s = (i2s_i16 & 0x00ff);
    i2s |= ((i2s_i16 & 0xff00) << 8);
    // Loop-invariant: pack the scale once, as the 4-bit variants do.
    uint const packed_scales = __pack_half2(*scale, *scale);
#pragma unroll
    for (int i = 0; i < (N / 2); i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[i])
                     : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
    }
}
// Signed 2-bit + scale entry point: forwards to decode_i2b_to_f16_scale with isSigned = true.
// NOTE(review): template list / arguments reconstructed - confirm.
template <typename T1, typename T2, typename T3>
__device__ void decode_i2s_to_f16_scale(T1 *_i2s, T2 *B_local_decode, T3 *scale, const int N = 8)
{
    decode_i2b_to_f16_scale<T1, T2, T3, true>(_i2s, B_local_decode, scale, N);
}
// Unsigned 2-bit + scale entry point: forwards to decode_i2b_to_f16_scale with isSigned = false.
// NOTE(review): template list / arguments reconstructed - confirm.
template <typename T1, typename T2, typename T3>
__device__ void decode_i2u_to_f16_scale(T1 *_i2u, T2 *B_local_decode, T3 *scale, const int N = 8)
{
    decode_i2b_to_f16_scale<T1, T2, T3, false>(_i2u, B_local_decode, scale, N);
}
"""
.. raw:: html
.. py:data:: decode_i2_to_f16_scale_zeros_original_offset
:value: Multiline-String
.. raw:: html
Show Value
.. code-block:: python
"""
// 2-bit "original" zero-point mode with two parameter groups:
//   output words [0, N/4)   : out = (decoded - zeros[0])      * scale[0]
//   output words [N/4, N/2) : out = (decoded - zeros[offset]) * scale[offset]
//
//   _i2s           : source; a single 16-bit word is read from it.
//   B_local_decode : destination; written as N/2 32-bit words, two fp16 values each.
//   scale, zeros   : dereferenced at index 0 and `offset`; must not be null.
//   isSigned       : true -> field v decodes to v - 2, false -> v.
//
// NOTE(review): template list and reinterpret_cast target types reconstructed
// from usage - confirm against the defining module.
template <typename T1, typename T2, typename T3, bool isSigned = false>
__device__ void decode_i2b_to_f16_scale_zeros_original_offset(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, T3 *zeros = nullptr, const int offset = 0, const int N = 8)
{
    uint *h = reinterpret_cast<uint *>(B_local_decode);
    // 0xEA is the LOP3 truth table for (a & b) | c.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint BOTTOM_MASK = 0x00030003;
    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
    // Signed path subtracts 1026 (= 1024 + 2); unsigned subtracts 1024.
    static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400;
    int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
    // decode 2 elems at one time.
    // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
    // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0}
    // Low byte stays in bits 0-7, high byte moves to bits 16-23.
    int i2s = (i2s_i16 & 0x00ff);
    i2s |= ((i2s_i16 & 0xff00) << 8);
    T3 const zeros_l = *zeros;
    T3 const zeros_r = *(zeros + offset);
    uint const packed_zeros_l = __pack_half2(zeros_l, zeros_l);
    uint const packed_zeros_r = __pack_half2(zeros_r, zeros_r);
    T3 const scale_l = *scale;
    T3 const scale_r = *(scale + offset);
    uint const packed_scales_l = __pack_half2(scale_l, scale_l);
    uint const packed_scales_r = __pack_half2(scale_r, scale_r);
#pragma unroll
    for (int i = 0; i < (N / 2); i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[i])
                     : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
    }
    // First group.
#pragma unroll
    for (int i = 0; i < (N / 4); i++)
    {
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros_l));
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_l), "r"(0));
    }
    // Second group.
#pragma unroll
    for (int i = (N / 4); i < (N / 2); i++)
    {
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros_r));
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales_r), "r"(0));
    }
}
// Unsigned entry point: forwards to decode_i2b_to_f16_scale_zeros_original_offset
// with isSigned = false.
//
// The callee must be the *_offset implementation: the non-offset
// decode_i2b_to_f16_scale_zeros_original takes (…, scale, zeros, N) and has no
// `offset` parameter, so passing six arguments to it does not match its signature.
// NOTE(review): template list / arguments reconstructed - confirm.
template <typename T1, typename T2, typename T3>
__device__ void decode_i2u_to_f16_scale_zeros_original_offset(T1 *_i2u, T2 *B_local_decode, T3 *scale, T3 *zeros, const int offset = 0, const int N = 8)
{
    decode_i2b_to_f16_scale_zeros_original_offset<T1, T2, T3, false>(_i2u, B_local_decode, scale, zeros, offset, N);
}
"""
.. raw:: html
.. py:data:: decode_i2_to_f16_scale_zeros_original
:value: Multiline-String
.. raw:: html
Show Value
.. code-block:: python
"""
// Decode packed 2-bit integers to fp16 in "original" zero-point mode:
//   out = (decoded - *zeros) * *scale
//
//   _i2s           : source; a single 16-bit word is read from it.
//   B_local_decode : destination; written as N/2 32-bit words, two fp16 values each.
//   scale, zeros   : both dereferenced; must not be null despite the defaults.
//   isSigned       : true -> field v decodes to v - 2, false -> v.
//
// NOTE(review): template list and reinterpret_cast target types reconstructed
// from usage - confirm against the defining module.
template <typename T1, typename T2, typename T3, bool isSigned = false>
__device__ void decode_i2b_to_f16_scale_zeros_original(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, T3 *zeros = nullptr, const int N = 8)
{
    uint *h = reinterpret_cast<uint *>(B_local_decode);
    // 0xEA is the LOP3 truth table for (a & b) | c.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint BOTTOM_MASK = 0x00030003;
    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
    // Signed path subtracts 1026 (= 1024 + 2); unsigned subtracts 1024.
    static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400;
    int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
    // decode 2 elems at one time.
    // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
    // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0}
    // Low byte stays in bits 0-7, high byte moves to bits 16-23.
    int i2s = (i2s_i16 & 0x00ff);
    i2s |= ((i2s_i16 & 0xff00) << 8);
    // Loop-invariant: pack zero point and scale once.
    uint const packed_zeros = __pack_half2(*zeros, *zeros);
    uint const packed_scales = __pack_half2(*scale, *scale);
#pragma unroll
    for (int i = 0; i < (N / 2); i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[i])
                     : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
        // Subtract the zero point, then scale.
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros));
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
    }
}
// Unsigned entry point: forwards to decode_i2b_to_f16_scale_zeros_original
// with isSigned = false.
// NOTE(review): template list / arguments reconstructed - confirm.
template <typename T1, typename T2, typename T3>
__device__ void decode_i2u_to_f16_scale_zeros_original(T1 *_i2u, T2 *B_local_decode, T3 *scale, T3 *zeros, const int N = 8)
{
    decode_i2b_to_f16_scale_zeros_original<T1, T2, T3, false>(_i2u, B_local_decode, scale, zeros, N);
}
"""
.. raw:: html
.. py:data:: decode_i2_to_f16_scale_zeros_rescale
:value: Multiline-String
.. raw:: html
Show Value
.. code-block:: python
"""
// Decode packed 2-bit integers to fp16 in "rescale" zero-point mode:
//   out = decoded * *scale - *zeros
//
//   _i2s           : source; a single 16-bit word is read from it.
//   B_local_decode : destination; written as N/2 32-bit words, two fp16 values each.
//   scale, zeros   : both dereferenced; must not be null despite the defaults.
//   isSigned       : true -> field v decodes to v - 2, false -> v.
//
// NOTE(review): template list and reinterpret_cast target types reconstructed
// from usage - confirm against the defining module.
template <typename T1, typename T2, typename T3, bool isSigned = false>
__device__ void decode_i2b_to_f16_scale_zeros_rescale(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, T3 *zeros = nullptr, const int N = 8)
{
    uint *h = reinterpret_cast<uint *>(B_local_decode);
    // 0xEA is the LOP3 truth table for (a & b) | c.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint BOTTOM_MASK = 0x00030003;
    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
    // Signed path subtracts 1026 (= 1024 + 2); unsigned subtracts 1024.
    static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400;
    int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
    // decode 2 elems at one time.
    // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
    // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0}
    // Low byte stays in bits 0-7, high byte moves to bits 16-23.
    int i2s = (i2s_i16 & 0x00ff);
    i2s |= ((i2s_i16 & 0xff00) << 8);
    // Loop-invariant: pack scale and zero point once.
    uint const packed_scales = __pack_half2(*scale, *scale);
    uint const packed_zeros = __pack_half2(*zeros, *zeros);
#pragma unroll
    for (int i = 0; i < (N / 2); i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[i])
                     : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
        // Scale first, then subtract the zero point.
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros));
    }
}
// Unsigned entry point: forwards to decode_i2b_to_f16_scale_zeros_rescale
// with isSigned = false.
// NOTE(review): template list / arguments reconstructed - confirm.
template <typename T1, typename T2, typename T3>
__device__ void decode_i2u_to_f16_scale_zeros_rescale(T1 *_i2u, T2 *B_local_decode, T3 *scale, T3 *zeros, const int N = 8)
{
    decode_i2b_to_f16_scale_zeros_rescale<T1, T2, T3, false>(_i2u, B_local_decode, scale, zeros, N);
}
"""
.. raw:: html
.. py:data:: decode_i2_to_f16_scale_zeros_quantized
:value: Multiline-String
.. raw:: html
Show Value
.. code-block:: python
"""
// Decode packed 2-bit integers to fp16 in "quantized" zero-point mode:
//   out = (field - zero) * *scale
// where `zero` is an integer read as the first 16 bits at `zeros`.
//
// The LOP3 step produces 1024 + field; adding the fp16 bit pattern
// 0xE400 | zero (= -(1024 + zero) while zero fits in the mantissa bits)
// removes the magic offset and the zero point together.
//
//   _i2s           : source; a single 16-bit word is read from it.
//   B_local_decode : destination; written as N/2 32-bit words, two fp16 values each.
//   scale, zeros   : both dereferenced unconditionally; must not be null.
//   isSigned       : has no effect on the result here; the whole bias comes from
//                    `zeros` (the former MEDIAN_NUM constant was never read).
//
// NOTE(review): template list and reinterpret_cast target types reconstructed
// from usage - confirm against the defining module.
template <typename T1, typename T2, typename T3, typename T4, bool isSigned = false>
__device__ void decode_i2b_to_f16_scale_zeros_quantized(T1 *_i2s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr)
{
    uint *h = reinterpret_cast<uint *>(B_local_decode);
    // 0xEA is the LOP3 truth table for (a & b) | c.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint BOTTOM_MASK = 0x00030003;
    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
    int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
    T3 const scale_r = *scale;
    uint const packed_scales = __pack_half2(scale_r, scale_r);
    int16_t const zero_r = *reinterpret_cast<const int16_t *>(zeros);
    // Built in unsigned 16-bit space so a negative zero_r cannot sign-extend
    // into the upper fp16 lane and no signed shift reaches bit 31.
    uint const biased_zero = 0xe400u | static_cast<uint16_t>(zero_r);
    uint const median_num = (biased_zero << 16) | biased_zero;
    // decode 2 elems at one time.
    // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
    // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0}
    // Low byte stays in bits 0-7, high byte moves to bits 16-23.
    int i2s = (i2s_i16 & 0x00ff);
    i2s |= ((i2s_i16 & 0xff00) << 8);
#pragma unroll
    for (int i = 0; i < (N / 2); i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[i])
                     : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
        asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num));
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
    }
}
// Unsigned entry point: forwards to decode_i2b_to_f16_scale_zeros_quantized,
// reordering (scale, zeros, N) -> (N, scale, zeros).
// NOTE(review): template list / arguments reconstructed - confirm.
template <typename T1, typename T2, typename T3, typename T4>
__device__ void decode_i2u_to_f16_scale_zeros_quantized(T1 *_i2u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8)
{
    decode_i2b_to_f16_scale_zeros_quantized<T1, T2, T3, T4, false>(_i2u, B_local_decode, N, scale, zeros);
}
"""
.. raw:: html
.. py:data:: decode_i1_to_f16
:value: Multiline-String
.. raw:: html
Show Value
.. code-block:: python
"""
/*
Zero-point handling is selected by ZerosKind:
    Kind 0: original   target = (dequantize_weight - zero_point) * scale
    Kind 1: rescale    target = dequantize_weight * scale - zero_point
    Kind 2: quantized  target = (dequantize_weight - dequantize_zeros) * scale
Notice: only "original" and "rescale" are supported now.
zeros_mode: Literal["original", "rescale", "quantized"] = "original"
*/
// Decode packed 1-bit values into fp16 pairs.
//
//   _i1s           : source; a single byte is read from it.
//   B_local_decode : destination; written as N/2 32-bit words, two fp16 values each.
//   N              : number of elements to produce (two per loop step).
//   scale, zeros   : dereferenced only when withScaling / withZeros are set.
//   isSigned       : when true, bit b is mapped to 2*b - 1 (i.e. -1 or +1).
//
// NOTE(review): the template parameter list (including the defaults chosen here)
// and the reinterpret_cast target types were missing from this listing and are
// reconstructed from the names the body uses - confirm against the defining module.
template <typename T1, typename T2, bool isSigned = false, bool withScaling = false, bool withZeros = false, int ZerosKind = 1>
__device__ void decode_i1b_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8, half *scale = nullptr, half *zeros = nullptr)
{
    uint *h = reinterpret_cast<uint *>(B_local_decode);
    // 0xEA is the LOP3 truth table for (a & b) | c.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
    // Keeps bit 0 of each 16-bit half.
    static constexpr uint BOTTOM_MASK = 0x00010001;
    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
    // Only the 1024 magic offset is removed here for both signed and unsigned;
    // the signed remap happens below.
    static constexpr uint MEDIAN_NUM = 0x64006400;
    static constexpr uint TRANSFORM_SUBTRACT = 0xbc00bc00; // fp16 -1.0 pair, for signed int 2x - 1
    // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0}
    // only decode e7,e5,e3,e1,e6,e4,e2,e0
    int8_t const i1s_i8 = *reinterpret_cast<int8_t *>(_i1s);
    // Low nibble stays in bits 0-3, high nibble moves to bits 16-19.
    int i1s = (i1s_i8 & 0x0f);
    i1s |= ((i1s_i8 & 0xf0) << 12);
    // Decode 2 elements per iteration.
#pragma unroll
    for (int i = 0; i < (N / 2); i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[i])
                     : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
        if constexpr (isSigned)
        {
            // h = h + h, then h = h + (-1): maps {0, 1} to {-1, +1}.
            asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(h[i]));
            asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(TRANSFORM_SUBTRACT));
        }
        if constexpr (withZeros && ZerosKind == 0)
        {
            // "original": subtract the zero point before scaling.
            asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros)));
        }
        if constexpr (withScaling)
        {
            asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0));
        }
        if constexpr (withZeros && ZerosKind == 1)
        {
            // "rescale": subtract the zero point after scaling.
            asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros)));
        }
    }
}
// Signed 1-bit entry point: forwards to decode_i1b_to_f16 with isSigned = true,
// so each bit decodes to -1 or +1.
// NOTE(review): template list / arguments reconstructed - confirm.
template <typename T1, typename T2>
__device__ void decode_i1s_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8)
{
    decode_i1b_to_f16<T1, T2, true>(_i1s, B_local_decode, N);
}
// Unsigned 1-bit entry point: forwards to decode_i1b_to_f16 with isSigned = false,
// so each bit decodes to 0 or 1.
// NOTE(review): template list / arguments reconstructed - confirm.
template <typename T1, typename T2>
__device__ void decode_i1u_to_f16(T1 *_i1u, T2 *B_local_decode, const int N = 8)
{
    decode_i1b_to_f16<T1, T2, false>(_i1u, B_local_decode, N);
}
"""
.. raw:: html
.. py:data:: decode_i1_to_f16_scale
:value: Multiline-String
.. raw:: html
Show Value
.. code-block:: python
"""
// Decode packed unsigned 1-bit values to fp16 and multiply every element by *scale:
//   out = bit * *scale
//
//   _i1s           : source; a single byte is read from it.
//   B_local_decode : destination; written as N/2 32-bit words, two fp16 values each.
//   scale          : dereferenced unconditionally; must not be null despite the default.
//   N              : number of elements to produce (two per loop step).
//
// NOTE(review): template list and reinterpret_cast target types reconstructed
// from usage - confirm against the defining module.
template <typename T1, typename T2, typename T3>
__device__ void decode_i1u_to_f16_scale(T1 *_i1s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8)
{
    uint *h = reinterpret_cast<uint *>(B_local_decode);
    // 0xEA is the LOP3 truth table for (a & b) | c.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint BOTTOM_MASK = 0x00010001;
    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
    static constexpr uint MEDIAN_NUM = 0x64006400;
    // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0}
    // only decode e7,e5,e3,e1,e6,e4,e2,e0
    int8_t const i1s_i8 = *reinterpret_cast<int8_t *>(_i1s);
    // Low nibble stays in bits 0-3, high nibble moves to bits 16-19.
    int i1s = (i1s_i8 & 0x0f);
    i1s |= ((i1s_i8 & 0xf0) << 12);
    T3 const scale_r = *scale;
    uint const packed_scales = __pack_half2(scale_r, scale_r);
    // Decode 2 elements per iteration.
#pragma unroll
    for (int i = 0; i < (N / 2); i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[i])
                     : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
    }
}
// Decode packed signed 1-bit values to fp16 and multiply every element by *scale:
//   out = (2 * bit - 1) * *scale      (each bit becomes -scale or +scale)
//
//   _i1s           : source; a single byte is read from it.
//   B_local_decode : destination; written as N/2 32-bit words, two fp16 values each.
//   scale          : dereferenced unconditionally; must not be null despite the default.
//   N              : number of elements to produce (two per loop step).
//
// NOTE(review): template list and reinterpret_cast target types reconstructed
// from usage - confirm against the defining module.
template <typename T1, typename T2, typename T3>
__device__ void decode_i1s_to_f16_scale(T1 *_i1s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8)
{
    uint *h = reinterpret_cast<uint *>(B_local_decode);
    // 0xEA is the LOP3 truth table for (a & b) | c.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint BOTTOM_MASK = 0x00010001;
    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
    static constexpr uint MEDIAN_NUM = 0x64006400;
    static constexpr uint TRANSFORM_SUBTRACT = 0xbc00bc00; // fp16 -1.0 pair, for signed int 2x - 1
    // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0}
    // only decode e7,e5,e3,e1,e6,e4,e2,e0
    int8_t const i1s_i8 = *reinterpret_cast<int8_t *>(_i1s);
    // Low nibble stays in bits 0-3, high nibble moves to bits 16-19.
    int i1s = (i1s_i8 & 0x0f);
    i1s |= ((i1s_i8 & 0xf0) << 12);
    T3 const scale_r = *scale;
    uint const packed_scales = __pack_half2(scale_r, scale_r);
    // Decode 2 elements per iteration.
#pragma unroll
    for (int i = 0; i < (N / 2); i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[i])
                     : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
        // h = h + h, then h = h + (-1): maps {0, 1} to {-1, +1}.
        asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(h[i]));
        asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(TRANSFORM_SUBTRACT));
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
    }
}
"""
.. raw:: html
.. py:data:: decode_i1_to_f16_scale_zeros_original
:value: Multiline-String
.. raw:: html
Show Value
.. code-block:: python
"""
// Decodes packed 1-bit values into fp16 and applies (x - *zeros) * *scale.
//   _i1s            pointer to the packed 1-bit source; one byte is read from it.
//   B_local_decode  output buffer, written as N / 2 packed 32-bit (2 x fp16) words.
//   N               number of elements to decode (default 8).
//   scale, zeros    pointers to the scale and zero point. Both are dereferenced
//                   unconditionally, so valid pointers are required even though the
//                   parameters default to nullptr.
// NOTE(review): the template parameter lists and the reinterpret_cast target types were
// reconstructed from the signatures and the declared variable types; the original header
// may carry additional (defaulted) template parameters - confirm against the
// tilelang.quantize.lop3 source.
template <typename T1, typename T2, typename T3, typename T4>
__device__ void decode_i1b_to_f16_zeros_original(T1 *_i1s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr)
{
    uint *h = reinterpret_cast<uint *>(B_local_decode);
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;   // lop3 truth table for (a & b) | c
    static constexpr uint BOTTOM_MASK = 0x00010001;        // bit 0 of each 16-bit half
    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; // fp16 1024.0 in both halves
    static constexpr uint MEDIAN_NUM = 0x64006400;         // subtracting 1024.0 leaves the raw bit value
    // Source bits 0-3 stay in place and bits 4-7 are moved to bits 16-19, so after a right
    // shift by i the low 16-bit half exposes source bit i and the high half source bit i + 4.
    int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
    int i1s = (i1s_i16 & 0x0f);
    i1s |= ((i1s_i16 & 0xf0) << 12);
    T3 const scale_r = *scale;
    uint const packed_scales = __pack_half2(scale_r, scale_r);
    // input zeros maybe int32(qzeros) or half format
    T4 const zero_r = *zeros;
    uint const packed_zeros = __pack_half2(zero_r, zero_r);
    // decode 2 elems at one time.
#pragma unroll
    for (int i = 0; i < (N / 2); i++)
    {
        // h[i] = ((i1s >> i) & BOTTOM_MASK) | FP16_TOP_MAGIC_NUM, i.e. fp16 (1024 + bit) per half.
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[i])
                     : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
        // Remove the 1024 bias.
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
        // Subtract the zero point.
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros));
        // Multiply by the scale (the fma addend is 0).
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
    }
}
// Unsigned entry point: forwards to decode_i1b_to_f16_zeros_original. Note the argument
// order differs from the base function (N comes last here, third there).
template <typename T1, typename T2, typename T3, typename T4>
__device__ void decode_i1u_to_f16_scale_zeros_original(T1 *_i1u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8)
{
    decode_i1b_to_f16_zeros_original(_i1u, B_local_decode, N, scale, zeros);
}
"""
.. raw:: html
.. py:data:: decode_i1_to_f16_scale_zeros_rescale
:value: Multiline-String
.. raw:: html
Show Value
.. code-block:: python
"""
// Decodes packed 1-bit values into fp16 and applies x * *scale + z, where z is *zeros with
// its fp16 sign bit forced on (so a non-negative zero point is subtracted).
//   _i1s            pointer to the packed 1-bit source; one byte is read from it.
//   B_local_decode  output buffer, written as N / 2 packed 32-bit (2 x fp16) words.
//   N               number of elements to decode (default 8).
//   scale, zeros    pointers to the scale and zero point. Both are dereferenced
//                   unconditionally, so valid pointers are required even though the
//                   parameters default to nullptr.
// NOTE(review): the template parameter lists and the reinterpret_cast target types were
// reconstructed from the signatures and the declared variable types; the original header
// may carry additional (defaulted) template parameters - confirm against the
// tilelang.quantize.lop3 source.
template <typename T1, typename T2, typename T3, typename T4>
__device__ void decode_i1b_to_f16_scale_zeros_rescale(T1 *_i1s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr)
{
    uint *h = reinterpret_cast<uint *>(B_local_decode);
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;   // lop3 truth table for (a & b) | c
    static constexpr uint BOTTOM_MASK = 0x00010001;        // bit 0 of each 16-bit half
    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; // fp16 1024.0 in both halves
    static constexpr uint MEDIAN_NUM = 0x64006400;         // subtracting 1024.0 leaves the raw bit value
    // Source bits 0-3 stay in place and bits 4-7 are moved to bits 16-19, so after a right
    // shift by i the low 16-bit half exposes source bit i and the high half source bit i + 4.
    int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
    int i1s = (i1s_i16 & 0x0f);
    i1s |= ((i1s_i16 & 0xf0) << 12);
    T3 const scale_r = *scale;
    uint const packed_scales = __pack_half2(scale_r, scale_r);
    T4 const zero_r = *zeros;
    // OR-ing 0x8000 into each half sets the fp16 sign bit of the packed zero point.
    uint const packed_zeros = 0x80008000 | __pack_half2(zero_r, zero_r);
    // decode 2 elems at one time.
#pragma unroll
    for (int i = 0; i < (N / 2); i++)
    {
        // h[i] = ((i1s >> i) & BOTTOM_MASK) | FP16_TOP_MAGIC_NUM, i.e. fp16 (1024 + bit) per half.
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[i])
                     : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
        // Remove the 1024 bias.
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
        // x * scale + (sign-forced zero point).
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(packed_zeros));
    }
}
// Unsigned entry point: forwards to decode_i1b_to_f16_scale_zeros_rescale. Note the argument
// order differs from the base function (N comes last here, third there).
// NOTE(review): the source pointer is named _i4u although it carries 1-bit data; the name is
// kept as-is.
template <typename T1, typename T2, typename T3, typename T4>
__device__ void decode_i1u_to_f16_scale_zeros_rescale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8)
{
    decode_i1b_to_f16_scale_zeros_rescale(_i4u, B_local_decode, N, scale, zeros);
}
"""
.. raw:: html
.. py:data:: decode_i1s_to_i8s
:value: Multiline-String
.. raw:: html
Show Value
.. code-block:: python
"""template <typename T1, typename T2>
__device__ void decode_i1s_to_i8s(T1 *_i1b, T2 *_i8s, const int N = 16)
{
    // Decodes packed 1-bit values into signed int8 lanes (bit b becomes 2 * b - 1, i.e. -1
    // or +1). Reads 16 bits from _i1b and writes N / 4 32-bit words (4 int8 lanes each)
    // through a 4-word local buffer that is loaded from and stored back to _i8s.
    // NOTE(review): the template parameter list and the reinterpret_cast target types
    // (int4 for the 16-byte vector copies, int16_t for the source) were reconstructed from
    // the signature and the declared variable types - confirm against the
    // tilelang.quantize.lop3 source.
    int i8s[4];
    // vector load
    *reinterpret_cast<int4 *>(i8s) = *reinterpret_cast<int4 *>(_i8s);
    int16_t i1b_i16 = *reinterpret_cast<int16_t *>(_i1b);
    // Spread the 16 source bits so that one bit lands at the bottom of each byte after the
    // per-iteration shift: bits 0-3 and 8-11 stay in place, bits 4-7 move to 16-19 and
    // bits 12-15 move to 24-27. After shifting right by i, bytes 0..3 of the word expose
    // source bits i, 8 + i, 4 + i and 12 + i respectively.
    int i1b = (i1b_i16 & 0x0f0f);
    i1b |= ((i1b_i16 & 0xf0f0) << 12);
    // Extract one bit per byte with a single lop3.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;        // 0b11101010, (a & b) | c
    static constexpr uint BOTTOM_MASK = 0x01010101;             // bit 0 of each byte
    static constexpr uint I8s_MAGIC_NUM = 0x00000000;           // nothing is OR-ed in
    static constexpr uint TRANSFORM_SUBTRACT = 0xffffffff;      // -1 per byte; for signed int 2x - 1
    for (int i = 0; i < N / 4; i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(i8s[i])
                     : "r"(i1b >> i), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut));
        // Per-byte x + x -> 2x, then 2x + (-1).
        i8s[i] = __vadd4(i8s[i], i8s[i]);
        i8s[i] = __vadd4(i8s[i], TRANSFORM_SUBTRACT);
    }
    // vector store
    *reinterpret_cast<int4 *>(_i8s) = *reinterpret_cast<int4 *>(i8s);
}
// Decodes packed unsigned 1-bit values into int8 lanes holding 0 or 1. Reads 16 bits from
// _i1b and writes N / 4 32-bit words (4 int8 lanes each) directly to _i8s.
// NOTE(review): the template parameter list and the reinterpret_cast target types were
// reconstructed from the signature and the declared variable types - confirm against the
// tilelang.quantize.lop3 source.
template <typename T1, typename T2>
__device__ void decode_i1u_to_i8s(T1 *_i1b, T2 *_i8s, const int N = 16)
{
    int *i8s = reinterpret_cast<int *>(_i8s);
    int16_t i1b_i16 = *reinterpret_cast<int16_t *>(_i1b);
    // Spread the 16 source bits so that one bit lands at the bottom of each byte after the
    // per-iteration shift: bits 0-3 and 8-11 stay in place, bits 4-7 move to 16-19 and
    // bits 12-15 move to 24-27. After shifting right by i, bytes 0..3 of the word expose
    // source bits i, 8 + i, 4 + i and 12 + i respectively.
    int i1b = (i1b_i16 & 0x0f0f);
    i1b |= ((i1b_i16 & 0xf0f0) << 12);
    // Extract one bit per byte with a single lop3.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010, (a & b) | c
    static constexpr uint BOTTOM_MASK = 0x01010101;      // bit 0 of each byte
    static constexpr uint I8s_MAGIC_NUM = 0x00000000;    // nothing is OR-ed in
    static constexpr uint MEDIAN_NUM = 0x00000000;       // not used by this function
    for (int i = 0; i < N / 4; i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(i8s[i])
                     : "r"(i1b >> i), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut));
    }
}
"""
.. raw:: html
.. py:data:: decode_i2s_to_i8s
:value: Multiline-String
.. raw:: html
Show Value
.. code-block:: python
"""template <typename T1, typename T2>
__device__ void decode_i2s_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16)
{
    // Decodes packed 2-bit values into signed int8 lanes by subtracting 2 from each value.
    // One 32-bit word is read from _i2b; with the default N = 16 all sixteen 2-bit fields
    // are unpacked into four 32-bit output words (4 int8 lanes each).
    // NOTE(review): the template parameter list and the reinterpret_cast target types were
    // reconstructed from the signature and the declared variable types - confirm against
    // the tilelang.quantize.lop3 source.
    uint *i8s = reinterpret_cast<uint *>(_i8s);
    // Output word i, byte k receives the 2-bit field at source bit offset 8 * k + 2 * i,
    // so the packed input is expected to be interleaved accordingly.
    uint const i2b = *reinterpret_cast<uint *>(_i2b);
    // Extract one 2-bit field per byte with a single lop3.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010, (a & b) | c
    static constexpr uint BOTTOM_MASK = 0x03030303;      // low 2 bits of each byte
    static constexpr uint I8s_MAGIC_NUM = 0x00000000;    // nothing is OR-ed in
    static constexpr uint MEDIAN_NUM = 0x02020202;       // 2 per byte, subtracted to centre the range
#pragma unroll
    for (int i = 0; i < (N / 4); i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(i8s[i])
                     : "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut));
        // Per-byte subtract: {0,1,2,3} -> {-2,-1,0,1}.
        i8s[i] = __vsub4(i8s[i], MEDIAN_NUM);
    }
}
// Decodes packed unsigned 2-bit values into int8 lanes holding 0..3.
// One 32-bit word is read from _i2b; with the default N = 16 all sixteen 2-bit fields are
// unpacked into four 32-bit output words (4 int8 lanes each).
// NOTE(review): the template parameter list and the reinterpret_cast target types were
// reconstructed from the signature and the declared variable types - confirm against the
// tilelang.quantize.lop3 source.
template <typename T1, typename T2>
__device__ void decode_i2u_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16)
{
    uint *i8s = reinterpret_cast<uint *>(_i8s);
    // Output word i, byte k receives the 2-bit field at source bit offset 8 * k + 2 * i,
    // so the packed input is expected to be interleaved accordingly.
    uint const i2b = *reinterpret_cast<uint *>(_i2b);
    // Extract one 2-bit field per byte with a single lop3.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010, (a & b) | c
    static constexpr uint BOTTOM_MASK = 0x03030303;      // low 2 bits of each byte
    static constexpr uint I8s_MAGIC_NUM = 0x00000000;    // nothing is OR-ed in
#pragma unroll
    for (int i = 0; i < (N / 4); i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(i8s[i])
                     : "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut));
    }
}
"""
.. raw:: html
.. py:data:: decode_i4s_to_i8s
:value: Multiline-String
.. raw:: html
Show Value
.. code-block:: python
"""template <typename T1, typename T2>
__device__ void decode_i4s_to_i8s(T1 *_i4b, T2 *_i8s, const int N = 16)
{
    // Decodes packed 4-bit values into signed int8 lanes by subtracting 7 from each value
    // (saturating). Two 32-bit source words are read (i4b[0], i4b[1]); with the default
    // N = 16 four 32-bit output words are written.
    // NOTE(review): the template parameter list and the reinterpret_cast target types were
    // reconstructed from the signature and the declared variable types - confirm against
    // the tilelang.quantize.lop3 source.
    uint *i8s = reinterpret_cast<uint *>(_i8s);
    uint *i4b = reinterpret_cast<uint *>(_i4b);
    // Extract one nibble per byte with a single lop3.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;     // (a & b) | c
    static constexpr uint BOTTOM_MASK = 0x0f0f0f0f;          // low nibble of each byte
    static constexpr uint I4b_TO_I8s_MAGIC_NUM = 0x00000000; // 0, nothing is OR-ed in
    // NOTE(review): the offset is 7 per byte here, whereas decode_i4b_to_f16 in this module
    // uses 8 for the signed case (0x64086408) - confirm which is intended.
    static constexpr uint MEDIAN_NUM = 0x07070707;
#pragma unroll
    for (int i = 0; i < (N / 8); i++)
    {
        // Word i takes the nibbles at bit offset 4 * i of each byte of i4b[0]; word i + 2
        // does the same for i4b[1]. The fixed "+ 2" output offset matches N / 8 only for
        // the default N = 16.
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(i8s[i])
                     : "r"(i4b[0] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut));
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(i8s[i + 2])
                     : "r"(i4b[1] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut));
        // Per-byte signed saturating subtract of the median.
        i8s[i] = __vsubss4(i8s[i], MEDIAN_NUM);
        i8s[i + 2] = __vsubss4(i8s[i + 2], MEDIAN_NUM);
    }
}
// Decodes packed unsigned 4-bit values into int8 lanes holding 0..15.
// Two 32-bit source words are read (i4b[0], i4b[1]); with the default N = 16 four 32-bit
// output words are written.
// NOTE(review): the template parameter list and the reinterpret_cast target types were
// reconstructed from the signature and the declared variable types - confirm against the
// tilelang.quantize.lop3 source.
template <typename T1, typename T2>
__device__ void decode_i4u_to_i8s(T1 *_i4b, T2 *_i8s, const int N = 16)
{
    uint *i8s = reinterpret_cast<uint *>(_i8s);
    uint *i4b = reinterpret_cast<uint *>(_i4b);
    // Extract one nibble per byte with a single lop3.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;     // (a & b) | c
    static constexpr uint BOTTOM_MASK = 0x0f0f0f0f;          // low nibble of each byte
    static constexpr uint I4b_TO_I8s_MAGIC_NUM = 0x00000000; // 0, nothing is OR-ed in
#pragma unroll
    for (int i = 0; i < (N / 8); i++)
    {
        // Word i takes the nibbles at bit offset 4 * i of each byte of i4b[0]; word i + 2
        // does the same for i4b[1]. The fixed "+ 2" output offset matches N / 8 only for
        // the default N = 16.
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(i8s[i])
                     : "r"(i4b[0] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut));
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(i8s[i + 2])
                     : "r"(i4b[1] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut));
    }
}
"""
.. raw:: html
.. py:data:: decode_i2s_to_i4s
:value: Multiline-String
.. raw:: html
Show Value
.. code-block:: python
"""
// Widens packed 2-bit values to 4-bit fields: each output nibble holds one 2-bit value in
// its low two bits. With the default N = 16 one 32-bit source word is expanded into two
// 32-bit output words.
// The signed adjustment is not implemented yet (see the TODO below), so isSigned currently
// has no effect on the produced values and MEDIAN_NUM is unused.
// NOTE(review): the template parameter lists, the explicit template arguments in the two
// wrappers and the reinterpret_cast target types were reconstructed (isSigned is used in
// the body and cannot be deduced from the call arguments) - confirm against the
// tilelang.quantize.lop3 source.
template <typename T1, typename T2, bool isSigned>
__device__ void decode_i2b_to_i4s(T1 *_i2b, T2 *_i4s, const int N = 16)
{
    uint *i4s = reinterpret_cast<uint *>(_i4s);
    uint *i2b = reinterpret_cast<uint *>(_i2b);
    // Extract one 2-bit field per nibble with a single lop3.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;     // (a & b) | c
    static constexpr uint BOTTOM_MASK = 0x33333333;          // low 2 bits of each nibble
    static constexpr uint I4b_TO_I8s_MAGIC_NUM = 0x00000000; // 0, nothing is OR-ed in
    static constexpr uint MEDIAN_NUM = isSigned ? 0x33333333 : 0x00000000;
#pragma unroll
    for (int i = 0; i < (N / 8); i++)
    {
        // Output word i, nibble k receives the 2-bit field at bit offset 4 * k + 2 * (i % 2)
        // of source word i / 2.
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(i4s[i])
                     : "r"(i2b[i / 2] >> (2 * (i % 2))), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut));
        if constexpr (isSigned)
        {
            // TODO(lei): uint4 sub should be enhanced.
            // 0x03 0x03 0x03 0x03
            // i4s[i] = (((i4s[i] << 1) | i4s[i]) << 1) | i4s[i];
        }
    }
}
// Signed entry point: forwards to decode_i2b_to_i4s with isSigned = true.
template <typename T1, typename T2>
__device__ void decode_i2s_to_i4s(T1 *_i4s, T2 *B_local_decode, const int N = 16)
{
    decode_i2b_to_i4s<T1, T2, true>(_i4s, B_local_decode, N);
}
// Unsigned entry point: forwards to decode_i2b_to_i4s with isSigned = false.
template <typename T1, typename T2>
__device__ void decode_i2u_to_i4s(T1 *_i4u, T2 *B_local_decode, const int N = 16)
{
    decode_i2b_to_i4s<T1, T2, false>(_i4u, B_local_decode, N);
}
"""
.. raw:: html
.. py:function:: get_lop3_intrin_group(out_dtype, source_format = 'uint', source_bit = 4, storage_dtype = 'int8', with_scaling = False, with_zeros = False, zeros_mode = 'original', storage_scope = 'local')
This function is used to get the intrinsic group of the LOP3 operation to avoid the overhead of fast decoding.
LOP3 is a type of logic operation that takes three inputs. The intrinsic group refers to the set of
intrinsic operations that can be performed on these inputs. This function retrieves and returns this group.
:param storage_dtype: The data type of the packed input storage (documented as ``in_dtype`` in the original docstring). By default, it is "int8".
:type storage_dtype: str, optional
:param out_dtype: The data type of the output. It can be either "float16" or "int8" or "int4".
:type out_dtype: Literal["float16", "int8", "int4"]
:param source_bit: The number of bits used to store each source element. By default, it is 4.
:type source_bit: int, optional
:param with_scaling: A boolean parameter that indicates whether scaling should be applied. By default, it is False.
:type with_scaling: bool, optional
:param with_zeros: A boolean parameter that indicates whether zeros should be used. By default, it is False.
:type with_zeros: bool, optional
:param zeros_mode: The mode of zeros. It can be either "original", "rescale", or "quantized". By default, it is "original".
:type zeros_mode: Literal["original", "rescale", "quantized"], optional
:param storage_scope: The scope of the storage. It can be either "local" or "warp". By default, it is "local".
:type storage_scope: Literal["local", "warp"], optional
:returns: A dictionary mapping the names of the intrinsics to their corresponding implementations.
:rtype: Dict[str, str]