eformer.ops.quantization.quantization_functions
Low-level Quantization Functions.
This module provides core quantization and dequantization functions for NF4 (4-bit NormalFloat), INT8, and "1-bit" ternary ({-1, 0, 1}) weights, which are packed at 2 bits per value. It also includes TPU-optimized Pallas kernels for efficient matrix multiplication with quantized weights.
- Key Functions:
  - Quantization/Dequantization:
    - quantize_int8 / dequantize_int8: 8-bit integer quantization with scaling
    - quantize_and_pack_nf4 / dequantize_nf4: NF4 block-wise quantization
    - pack_weights_1bit / unpack_weights_1bit: ternary weight packing (2 bits per value)
  - TPU Kernels:
    - bmm_nf4: Pallas kernel for NF4 batch matrix multiplication
    - bmm_nf4_transpose: Transposed variant for backward passes
    - nf4_matmul: High-level wrapper for TPU-optimized matmul
  - Configuration:
    - is_kernel_available(): Check if TPU kernels are available
    - nf4_use_kernel(): Context manager for kernel mode control
  - Utilities:
    - get_nf4(): Get NF4 lookup table
    - nf4xf32_to_f32(): Polynomial approximation for NF4 dequantization
    - i8tou8, u4toi4, i4tou4: Bit conversion utilities
- Environment Variables:
  - USE_NF4_KERNEL_TPU: Set to “0”, “false”, or “off” to disable TPU kernels. Default is enabled (“1”).
Example
>>> import jax.numpy as jnp
>>> from eformer.ops.quantization.quantization_functions import (
... quantize_int8, dequantize_int8,
... quantize_and_pack_nf4, dequantize_nf4
... )
>>>
>>> # INT8 quantization
>>> weights = jnp.ones((128, 256))
>>> quants, scales = quantize_int8(weights, axis=-1)
>>> reconstructed = dequantize_int8(quants, scales)
>>>
>>> # NF4 quantization
>>> packed, absmax = quantize_and_pack_nf4(weights, block_size=64)
>>> reconstructed = dequantize_nf4(packed, absmax, block_size=64)
- eformer.ops.quantization.quantization_functions.bmm_nf4(inputs_ref, quants_ref, scale_ref, outputs_ref, accum_ref, *, block_k)
Pallas kernel for NF4 matrix multiplication with on-the-fly dequantization.
This kernel performs efficient matrix multiplication by dequantizing weights on-the-fly during computation, avoiding materialization of full-precision weights.
- Parameters
inputs_ref – Reference to input activation tensor
quants_ref – Reference to quantized weight tensor (int4)
scale_ref – Reference to scale factors
outputs_ref – Reference to output tensor
accum_ref – Reference to accumulator tensor
block_k – Block size for K dimension (static)
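The Pallas kernel's exact memory layout is not documented here. As a hypothetical pure-JAX reference for what "on-the-fly dequantization" computes, the sketch below walks the K dimension in chunks of block_k, dequantizes only the current chunk of weights, and adds its partial product into a float32 accumulator (the role accum_ref plays). The layout is an assumption: one unpacked 4-bit code per element in a (K, N) array and one scale per K-block per column. The NF4 values are the standard QLoRA codebook, assumed to match get_nf4().

```python
import jax.numpy as jnp

# Standard NF4 codebook (QLoRA); assumed to match get_nf4().
NF4 = jnp.array([
    -1.0, -0.6961928, -0.5250731, -0.3949175, -0.2844414, -0.1847734,
    -0.0910500, 0.0, 0.0795803, 0.1609302, 0.2461123, 0.3379152,
    0.4407098, 0.5626170, 0.7229568, 1.0,
], dtype=jnp.float32)


def blocked_nf4_matmul(x, codes, scales, block_k):
    """Hypothetical reference for x @ dequant(codes, scales), one K-block at a time.

    Assumed layout: codes is (K, N) with values 0..15 and
    scales is (K // block_k, N), one scale per K-block and column.
    """
    m, k = x.shape
    n = codes.shape[1]
    acc = jnp.zeros((m, n), dtype=jnp.float32)  # plays the role of accum_ref
    for b in range(k // block_k):
        rows = slice(b * block_k, (b + 1) * block_k)
        # Dequantize only this block: codebook lookup, then per-column scale.
        w_block = NF4[codes[rows]] * scales[b]
        acc = acc + x[:, rows].astype(jnp.float32) @ w_block
    return acc
```

The full-precision weight matrix never exists at once; only a (block_k, N) slice is dequantized per step.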
- eformer.ops.quantization.quantization_functions.bmm_nf4_transpose(inputs_ref, quants_ref, scale_ref, outputs_ref, accum_ref, *, block_k)
Pallas kernel for transposed NF4 matrix multiplication.
This kernel handles the transpose case, where the weights must be read in a different order. It is used for backward passes during training.
- Parameters
inputs_ref – Reference to input activation tensor
quants_ref – Reference to quantized weight tensor (int4)
scale_ref – Reference to scale factors
outputs_ref – Reference to output tensor
accum_ref – Reference to accumulator tensor
block_k – Block size for K dimension (static)
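Why a backward pass needs a transposed kernel: for y = x @ W, the gradient with respect to x is dy @ Wᵀ, so the same weights must be read in transposed order. The sketch below checks that identity with jax.vjp on a plain float matmul (no quantization involved); it illustrates the motivation and is not the library's code.

```python
import jax
import jax.numpy as jnp


def forward(x, w):
    return x @ w


x = jnp.arange(6.0).reshape(2, 3)
w = jnp.arange(12.0).reshape(3, 4) / 10.0
y, vjp_fn = jax.vjp(forward, x, w)

dy = jnp.ones_like(y)   # upstream gradient, shape (2, 4)
dx, dw = vjp_fn(dy)     # gradients with respect to x and w

# The input gradient equals the upstream gradient times the transposed weights.
dx_manual = dy @ w.T
```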
- eformer.ops.quantization.quantization_functions.dequantize_int8(quants, scales)
Dequantize 8-bit integers back to floating-point values.
Multiplies the quantized int8 values by their corresponding scale factors to reconstruct approximate original values.
- Parameters
quants (jax.Array) – int8 array of quantized values, as produced by quantize_int8.
scales (jax.Array) – Scale factors broadcastable to the shape of quants, as produced by quantize_int8.
- Returns
Dequantized array with the same shape as quants and a dtype determined by the scales array.
- Return type
jax.Array
Example
>>> quants = jnp.array([[42, 85, 127], [32, 64, 96]], dtype=jnp.int8)
>>> scales = jnp.array([[0.024], [0.047]])  # Shape (2, 1)
>>> result = dequantize_int8(quants, scales)
>>> # result has shape (2, 3) with float values
Note
This is the inverse operation of quantize_int8. Due to rounding during quantization, the reconstructed values are approximate.
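Because quantization rounds x / scale to the nearest integer, each reconstructed value is off by at most half a scale step. The sketch below writes the dequantize step out explicitly (an elementwise product broadcast over the size-1 axis) and checks that bound. Both helpers are hypothetical; the quantize side assumes the symmetric absmax / 127 scheme described under quantize_int8.

```python
import jax.numpy as jnp


def quantize_sketch(x, axis=-1):
    # Symmetric per-axis scale: the largest magnitude maps to 127.
    scale = jnp.max(jnp.abs(x), axis=axis, keepdims=True) / 127.0 + 1e-12
    quant = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
    return quant, scale


def dequantize_sketch(quant, scale):
    # int8 values times a (..., 1) scale, broadcast back to the full shape.
    return quant.astype(scale.dtype) * scale


x = jnp.array([[0.3, -1.2, 2.5], [10.0, -7.5, 0.01]])
q, s = quantize_sketch(x)
x_hat = dequantize_sketch(q, s)
# Error measured in units of the scale step; rounding keeps it <= 0.5.
max_err = jnp.max(jnp.abs(x - x_hat) / s)
```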
- eformer.ops.quantization.quantization_functions.dequantize_nf4(packed_values, absmax, block_size)
Dequantize an array from NF4 (4-bit NormalFloat) format.
High-level API for NF4 dequantization. Unpacks 4-bit values and scales them by per-block absmax values to reconstruct approximate original values.
- Parameters
packed_values (jax.Array) – uint8 array with packed 4-bit values as produced by quantize_and_pack_nf4. Shape: (…, num_blocks, block_size // 2).
absmax (jax.Array) – Per-block scale factors from quantization. Shape: (…, num_blocks).
block_size (int) – Number of elements per quantization block. Must match the value used during quantization.
- Returns
Dequantized float32 array with shape (…, features), where features = num_blocks * block_size.
- Return type
jax.Array
Example
>>> # Given packed data from quantize_and_pack_nf4
>>> reconstructed = dequantize_nf4(packed, absmax, block_size=64)
>>> # reconstructed approximates original values
See also
quantize_and_pack_nf4: Forward quantization operation.
single_dequantize_nf4: Internal implementation.
Note
Due to quantization, the reconstructed values are approximate. NF4 provides better accuracy than uniform 4-bit quantization for normally distributed data (typical of neural network weights).
- eformer.ops.quantization.quantization_functions.get_nf4()
Get the NF4 (4-bit NormalFloat) lookup table.
Creates the lookup table lazily to avoid triggering JAX initialization at import time. The NF4 format uses 16 values optimized for Gaussian distributions, providing better accuracy than uniform quantization for neural network weights.
- Returns
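A 16-element float32 array containing the NF4 codebook values, ranging from -1.0 to 1.0. The values include an exact zero and are roughly symmetric around it (the standard NF4 codebook has 7 negative and 8 positive levels), and they are optimized for normally distributed data.
- Return type
jax.Array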
Note
The codebook values are derived from the quantiles of a standard normal distribution, making NF4 particularly effective for neural network weights which tend to follow Gaussian distributions.
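The quantile construction can be sketched with jax.scipy.stats.norm.ppf. The sketch follows the recipe published with QLoRA and bitsandbytes (8 positive levels from upper-tail quantiles, 7 negative levels, an exact zero, then normalization by the maximum); the offset 0.9677083 comes from that recipe. Whether get_nf4() computes or hardcodes its table is not stated above, so treat this as an illustration of the derivation only.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def nf4_codebook(offset=0.9677083):
    # 8 positive levels: normal quantiles from `offset` down toward 0.5 (exclusive).
    pos = norm.ppf(jnp.linspace(offset, 0.5, 9)[:-1])
    # 7 negative levels built the same way, mirrored below zero.
    neg = -norm.ppf(jnp.linspace(offset, 0.5, 8)[:-1])
    values = jnp.sort(jnp.concatenate([pos, jnp.zeros(1), neg]))
    # Normalize so the largest level is exactly 1.0.
    return values / values.max()
```

The extra positive level is why the table is asymmetric: 4 bits give 16 slots, and reserving one for an exact zero leaves 15 to split between the two halves.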
- eformer.ops.quantization.quantization_functions.i4tou4(x)
Convert signed 4-bit integer to unsigned 4-bit integer.
Maps signed values -8 to 7 to unsigned range 0-15. Negative values -8 to -1 become 8 to 15.
- Parameters
x (jax.Array) – Input array with signed int4 values (-8 to 7).
- Returns
Output array with values in the uint4 range (0 to 15).
- Return type
jax.Array
Example
>>> i4tou4(jnp.array([-8, -1, 0, 7]))
Array([8, 15, 0, 7], dtype=int32)
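The mapping is two's-complement wrapping modulo 16: negative values gain 16, non-negative values are unchanged. A minimal sketch with a hypothetical name, not necessarily how i4tou4 is implemented:

```python
import jax.numpy as jnp


def i4tou4_sketch(x):
    x = jnp.asarray(x, dtype=jnp.int32)
    # Negative int4 values (-8..-1) wrap to 8..15; 0..7 stay the same.
    return jnp.where(x < 0, x + 16, x)
```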
- eformer.ops.quantization.quantization_functions.i8tou8(x)
Convert signed int8 to unsigned uint8.
Handles the two’s complement representation by adding 256 to negative values.
- Parameters
x (jax.Array) – Input array with int8 values (-128 to 127).
- Returns
Output array with values in the uint8 range (0 to 255).
- Return type
jax.Array
Example
>>> i8tou8(jnp.array([-1, 0, 127], dtype=jnp.int8))
Array([255, 0, 127], dtype=int32)
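Values such as 255 do not fit in int8, so the conversion has to happen in a wider integer type, which is consistent with the int32 result in the example. A minimal sketch (hypothetical name) that widens first and keeps the low 8 bits, which is equivalent to adding 256 to negative values:

```python
import jax.numpy as jnp


def i8tou8_sketch(x):
    # Widen to int32 first, since int8 cannot hold values above 127.
    wide = jnp.asarray(x).astype(jnp.int32)
    # Keeping the low 8 bits equals adding 256 to negatives in two's complement.
    return wide & 0xFF
```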
- eformer.ops.quantization.quantization_functions.is_kernel_available()
Check if NF4 kernels are available on the current device.
- Returns
True if running on TPU (where kernels are supported), False otherwise
- Return type
bool
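Combining a TPU check with the USE_NF4_KERNEL_TPU variable from the module overview, a gate could look like the sketch below. The helper name kernels_enabled is hypothetical. jax.default_backend() is a real JAX function that returns the platform name, such as "cpu", "gpu", or "tpu".

```python
import os

import jax


def kernels_enabled() -> bool:
    # Hypothetical helper: the env var can switch kernels off (default "1").
    flag = os.environ.get("USE_NF4_KERNEL_TPU", "1").strip().lower()
    if flag in ("0", "false", "off"):
        return False
    # The Pallas NF4 kernels are only supported on TPU.
    return jax.default_backend() == "tpu"
```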
- eformer.ops.quantization.quantization_functions.nf4_matmul(inputs, *tensors, kernel, backward=False, blocks=None)
Fast matrix multiplication with NF4-quantized weights using Pallas TPU kernels.
This function selects block sizes automatically (unless blocks is given) and pads the operands as needed for efficient TPU execution.
- Parameters
inputs – Input activation tensor
*tensors – Quantized weight tensors (quants, scales)
kernel – Pallas kernel function to use (for example, bmm_nf4 or bmm_nf4_transpose)
backward – Whether this is a backward pass
blocks – Optional manual block size specification (block_x, block_y, block_k)
- Returns
Result of matrix multiplication
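Block-based kernels generally need each dimension to be a multiple of its block size, so operands are zero-padded before the call and the result is sliced back afterwards. The sketch below shows only that padding step around a plain float matmul; pad_to_multiple, padded_matmul, and the default block sizes are hypothetical, and nf4_matmul's actual block-selection rules are not described here.

```python
import jax.numpy as jnp


def pad_to_multiple(a, multiples):
    # Zero-pad each axis of `a` up to the next multiple of its block size.
    pads = [(0, (-dim) % m) for dim, m in zip(a.shape, multiples)]
    return jnp.pad(a, pads)


def padded_matmul(x, w, block_x=8, block_y=128, block_k=128):
    m, n = x.shape[0], w.shape[1]
    x_p = pad_to_multiple(x, (block_x, block_k))
    w_p = pad_to_multiple(w, (block_k, block_y))
    # Zeros added along K contribute nothing to the dot products.
    out = x_p @ w_p
    return out[:m, :n]  # drop padded rows and columns
```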
- eformer.ops.quantization.quantization_functions.nf4_use_kernel(value: bool)
Context manager to enable/disable NF4 kernel mode.
Note: Kernels are only enabled on TPU devices. On other devices, this setting has no effect, and the code falls back to materializing the full-precision weights.
- Parameters
value – Whether to enable kernel mode (only effective on TPU)
Example
>>> with nf4_use_kernel(True):
...     result = inputs @ quantized_weight  # Uses kernel on TPU
- eformer.ops.quantization.quantization_functions.nf4xf32_to_f32(x)
Fast polynomial approximation for NF4 dequantization.
Evaluating a polynomial avoids table lookups, which makes it significantly faster on accelerators, while closely approximating the NF4 codebook values.
- Parameters
x – Integer array (0-15) representing NF4 quantized values
- Returns
Float32 array with approximated NF4 values
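The polynomial used by nf4xf32_to_f32 is not given here. To illustrate the idea, the sketch fits a degree-15 polynomial through the 16 (code, value) pairs of the standard QLoRA NF4 codebook (assumed to match get_nf4()) and evaluates it with Horner's rule, which needs only multiply-adds instead of a table gather. With 16 points and degree 15 the fit is an exact interpolation; a production version would likely use a lower degree and accept a small error. The fit runs in float64 with NumPy for numerical safety.

```python
import numpy as np

NF4 = np.array([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
])

# Rescale code indices 0..15 to [-1, 1] to keep the fit well conditioned.
t = (np.arange(16) - 7.5) / 7.5
coeffs = np.polyfit(t, NF4, deg=15)  # highest power first


def nf4_poly(codes):
    """Evaluate the fitted polynomial at integer codes with Horner's rule."""
    x = (np.asarray(codes, dtype=np.float64) - 7.5) / 7.5
    acc = np.zeros_like(x)
    for c in coeffs:
        acc = acc * x + c
    return acc.astype(np.float32)
```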
- eformer.ops.quantization.quantization_functions.pack_weights_1bit(quantized_weights: Array) → Array
Packs a JAX array of quantized weights into a compact format using 2 bits per value.
Parameters:
- quantized_weights : jnp.ndarray
An array containing ternary quantized weights {-1, 0, 1}. The first dimension must be a multiple of 4.
Returns:
- jnp.ndarray
A packed jnp.uint8 array.
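The bit layout is not spelled out above. The sketch assumes one layout consistent with "first dimension must be a multiple of 4" and with unpack_weights_1bit concatenating bit groups: each value is shifted from {-1, 0, 1} to {0, 1, 2}, the first dimension is split into four contiguous chunks, and chunk i occupies bits 2i and 2i+1 of each byte. The function name is hypothetical.

```python
import jax.numpy as jnp


def pack_ternary_2bit(w):
    # Assumed layout: chunk i of the first axis goes to bits 2i..2i+1.
    rows = w.shape[0] // 4
    codes = (w + 1).astype(jnp.uint8)  # {-1, 0, 1} -> {0, 1, 2}
    packed = jnp.zeros((rows,) + w.shape[1:], dtype=jnp.uint8)
    for i in range(4):
        chunk = codes[i * rows:(i + 1) * rows]
        packed = packed | (chunk << (2 * i))
    return packed
```

Four values share each byte, so the first dimension shrinks by a factor of 4.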
- eformer.ops.quantization.quantization_functions.quantize_and_pack_nf4(blocks, block_size=64)
Quantize and pack an array using NF4 (4-bit NormalFloat) quantization.
High-level API for NF4 quantization. Quantizes the input array into 4-bit NormalFloat format with block-wise scaling and packs two values per byte.
- Parameters
blocks (jax.Array) – Input array to quantize. The last dimension must be divisible by block_size. Shape: (…, features).
block_size (int) – Number of elements per quantization block. Defaults to 64. Larger blocks use less memory for scales but may have higher quantization error.
- Returns
- A tuple containing:
packed (jax.Array): uint8 array with packed 4-bit values. Shape: (…, num_blocks, block_size // 2).
absmax (jax.Array): Per-block scale factors for dequantization. Shape: (…, num_blocks).
- Return type
tuple[jax.Array, jax.Array]
Example
>>> weights = jnp.ones((128, 256), dtype=jnp.float32)
>>> packed, absmax = quantize_and_pack_nf4(weights, block_size=64)
>>> # Reconstruct with dequantize_nf4
>>> reconstructed = dequantize_nf4(packed, absmax, block_size=64)
See also
dequantize_nf4: Reverse operation to reconstruct values.
single_quantize_and_pack_nf4: Internal implementation.
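The block_size trade-off can be made concrete with a little arithmetic: each element costs 4 bits, and each block adds one scale. Assuming absmax is stored as float32 (its dtype is not stated above), the cost per element is 4 + 32 / block_size bits. The helper name is hypothetical.

```python
def nf4_bits_per_element(block_size: int, scale_bits: int = 32) -> float:
    # 4 bits for the code plus the block's scale spread over its elements.
    return 4 + scale_bits / block_size


for bs in (32, 64, 128, 256):
    print(bs, nf4_bits_per_element(bs))
```

For the 128 x 256 example with block_size=64, that is 16,384 bytes of packed codes plus 512 float32 scales (2,048 bytes), compared with 131,072 bytes for the original float32 array.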
- eformer.ops.quantization.quantization_functions.quantize_int8(x: Array, axis: int | tuple = -1)
Quantize floating-point values to 8-bit integers with per-axis scaling.
Computes a scale factor based on the maximum absolute value along the specified axis, then quantizes values to the int8 range [-127, 127].
- Parameters
x (jax.Array) – Input array to quantize. Can be any floating-point dtype.
axis (int | tuple) – Axis or axes along which to compute the scale factor. Defaults to -1 (last axis). Can be a single int or tuple of ints.
- Returns
- A tuple containing:
quant (jax.Array): int8 array with quantized values in range [-127, 127].
scale (jax.Array): Float array of scale factors with shape broadcastable to the input (dimensions along axis are kept as size 1).
- Return type
tuple[jax.Array, jax.Array]
Example
>>> x = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> quant, scale = quantize_int8(x, axis=-1)
>>> # quant is int8, scale has shape (2, 1)
>>> reconstructed = dequantize_int8(quant, scale)
Note
A tiny epsilon is added to the scale to avoid division by zero when a slice along axis is entirely zero. The scale preserves the original dtype of the input array.
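A minimal sketch of the scheme described above, written independently of the library (the function name and epsilon value are assumptions). It shows how a tuple axis yields a scale with size-1 dimensions along every reduced axis, and how the epsilon keeps an all-zero slice from producing NaNs.

```python
import jax.numpy as jnp


def quantize_int8_sketch(x, axis=-1, eps=1e-12):
    # keepdims=True keeps reduced axes as size 1, so scale broadcasts against x.
    absmax = jnp.max(jnp.abs(x), axis=axis, keepdims=True)
    scale = absmax / 127.0 + eps  # eps guards against all-zero slices
    quant = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
    return quant, scale.astype(x.dtype)


# Slice 0 is all zeros; slice 1 holds 0..11.
x = jnp.zeros((2, 3, 4)).at[1].set(jnp.arange(12.0).reshape(3, 4))
quant, scale = quantize_int8_sketch(x, axis=(1, 2))
```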
- eformer.ops.quantization.quantization_functions.single_dequantize_nf4(packed_values, absmax, block_size)
Dequantize packed NF4 values back to floating-point.
This function reverses the NF4 quantization by:
1. Unpacking two 4-bit values from each uint8 byte
2. Converting 4-bit indices to NF4 codebook values using polynomial approximation
3. Scaling by per-block absmax values
4. Flattening back to the original feature dimension
- Parameters
packed_values (jax.Array) – uint8 array with packed 4-bit values. Shape (…, num_blocks, block_size // 2).
absmax (jax.Array) – Absolute maximum per block used during quantization. Shape (…, num_blocks).
block_size (int) – Number of elements per quantization block. Must match the value used during quantization.
- Returns
Dequantized float32 array with shape (…, num_blocks * block_size), which equals the original feature dimension.
- Return type
jax.Array
Example
>>> # Given packed data from single_quantize_and_pack_nf4
>>> reconstructed = single_dequantize_nf4(packed, absmax, block_size=64)
>>> reconstructed.shape  # Original shape restored
Note
Uses a polynomial approximation (nf4xf32_to_f32) instead of table lookup for faster dequantization on accelerators.
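The four dequantization steps can be sketched in plain JAX. This is an illustration under assumptions, not the library's code: it uses a table lookup instead of the polynomial, assumes the first element of each pair sits in the high nibble, and uses the standard QLoRA NF4 values. The function name is hypothetical.

```python
import jax.numpy as jnp

NF4 = jnp.array([
    -1.0, -0.6961928, -0.5250731, -0.3949175, -0.2844414, -0.1847734,
    -0.0910500, 0.0, 0.0795803, 0.1609302, 0.2461123, 0.3379152,
    0.4407098, 0.5626170, 0.7229568, 1.0,
], dtype=jnp.float32)


def dequantize_nf4_sketch(packed, absmax, block_size):
    # 1. Unpack two 4-bit codes per byte (assumed order: high nibble first).
    high = (packed >> 4) & 0xF
    low = packed & 0xF
    codes = jnp.stack([high, low], axis=-1).reshape(*packed.shape[:-1], block_size)
    # 2. Map codes to codebook values (table lookup in this sketch).
    values = NF4[codes]
    # 3. Scale each block by its absmax.
    values = values * absmax[..., None]
    # 4. Merge (num_blocks, block_size) back into one feature axis.
    return values.reshape(*packed.shape[:-2], -1)
```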
- eformer.ops.quantization.quantization_functions.single_quantize_and_pack_nf4(blocks, block_size=64)
Quantize and pack an array to NF4 format in a single pass.
This function performs block-wise NF4 quantization by:
1. Reshaping the input into blocks along the last dimension
2. Computing per-block absolute maximum values for scaling
3. Normalizing values by absmax
4. Finding nearest NF4 codebook values
5. Packing two 4-bit values into each uint8 byte
- Parameters
blocks (jax.Array) – Input array with shape (…, features) where features must be divisible by block_size.
block_size (int) – Number of elements per quantization block. Defaults to 64. Must be even for packing.
- Returns
- A tuple containing:
packed (jax.Array): uint8 array with packed 4-bit values. Shape (…, num_blocks, block_size // 2).
absmax (jax.Array): Absolute maximum per block for dequantization. Shape (…, num_blocks).
- Return type
tuple[jax.Array, jax.Array]
Example
>>> x = jnp.ones((128, 256))  # 256 features, 4 blocks of 64
>>> packed, absmax = single_quantize_and_pack_nf4(x, block_size=64)
>>> packed.shape  # (128, 4, 32): 4 blocks, 32 bytes each
>>> absmax.shape  # (128, 4): one scale per block
Note
This is optimized for JAX JIT compilation. For the high-level API, use quantize_and_pack_nf4 instead.
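The five quantization steps can be sketched in plain JAX. Assumptions, not taken from the library: the standard QLoRA NF4 values, the first element of each pair stored in the high nibble, and a guard that treats all-zero blocks as having scale 1 when normalizing. The function name is hypothetical.

```python
import jax.numpy as jnp

NF4 = jnp.array([
    -1.0, -0.6961928, -0.5250731, -0.3949175, -0.2844414, -0.1847734,
    -0.0910500, 0.0, 0.0795803, 0.1609302, 0.2461123, 0.3379152,
    0.4407098, 0.5626170, 0.7229568, 1.0,
], dtype=jnp.float32)


def quantize_nf4_sketch(x, block_size=64):
    # 1. Split the last axis into blocks: (..., num_blocks, block_size).
    blocks = x.reshape(*x.shape[:-1], -1, block_size)
    # 2. Per-block absolute maximum (the scale).
    absmax = jnp.max(jnp.abs(blocks), axis=-1)
    # 3. Normalize into [-1, 1]; this sketch guards all-zero blocks.
    normed = blocks / jnp.where(absmax == 0, 1.0, absmax)[..., None]
    # 4. Index of the nearest codebook value for every element.
    codes = jnp.argmin(jnp.abs(normed[..., None] - NF4), axis=-1).astype(jnp.uint8)
    # 5. Pack pairs into one byte (assumed: first element in the high nibble).
    pairs = codes.reshape(*codes.shape[:-1], block_size // 2, 2)
    packed = (pairs[..., 0] << 4) | pairs[..., 1]
    return packed, absmax
```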
- eformer.ops.quantization.quantization_functions.u4toi4(x)
Convert unsigned 4-bit integer to signed 4-bit integer.
Maps unsigned values 0-15 to signed range -8 to 7. Values 8-15 become negative values -8 to -1.
- Parameters
x (jax.Array) – Input array with uint4 values (0 to 15).
- Returns
Output array with signed int4 values (-8 to 7).
- Return type
jax.Array
Example
>>> u4toi4(jnp.array([0, 7, 8, 15]))
Array([0, 7, -8, -1], dtype=int32)
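Sign-extending a 4-bit value can also be written without a branch: flip bit 3, then subtract 8. This yields the same mapping as subtracting 16 from values 8 to 15. A minimal sketch with a hypothetical name, not necessarily how u4toi4 is implemented:

```python
import jax.numpy as jnp


def u4toi4_sketch(x):
    x = jnp.asarray(x, dtype=jnp.int32)
    # XOR with 8 flips the sign bit; subtracting 8 re-centers the range.
    # 0..7 -> 8..15 -> 0..7, and 8..15 -> 0..7 -> -8..-1.
    return (x ^ 8) - 8
```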
- eformer.ops.quantization.quantization_functions.unpack_weights_1bit(packed: Array, dtype: dtype) → Array
Unpacks a JAX array of packed ternary weights, matching the logic of the original PyTorch implementation. The unpacked 2-bit groups are concatenated to form the output.
Parameters:
- packed : jnp.ndarray
A packed jnp.uint8 array.
- dtype : jnp.dtype
The dtype of the returned array (e.g., jnp.int8). This is a static argument for JIT.
Returns:
- jnp.ndarray
An unpacked array with ternary values {-1, 0, 1}.
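The sketch below assumes a layout in which each value was stored as value + 1 in a 2-bit field and the original first dimension was split into four contiguous chunks, with chunk i in bits 2i and 2i+1. Under that layout, unpacking shifts each field down, concatenates the four groups along the first axis, and subtracts 1. The function name and layout are assumptions, not taken from the library.

```python
import jax.numpy as jnp


def unpack_ternary_2bit(packed, dtype=jnp.int8):
    # Extract the 2-bit field for each of the four chunks.
    groups = [(packed >> (2 * i)) & 3 for i in range(4)]
    # Concatenate the groups along the first axis, then map {0, 1, 2} -> {-1, 0, 1}.
    return jnp.concatenate(groups, axis=0).astype(dtype) - 1
```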