eformer.ops.quantization.quantization_functions

Low-level Quantization Functions.

This module provides core quantization and dequantization functions for several formats: NF4 (4-bit NormalFloat), INT8, and ternary {-1, 0, 1} weights packed at 2 bits per value (the *_1bit functions). It also includes Pallas kernels for TPU that multiply activations by NF4-quantized weights without first materializing the full-precision weights.

Key Functions:
Quantization/Dequantization:
  • quantize_int8 / dequantize_int8: 8-bit integer quantization with scaling

  • quantize_and_pack_nf4 / dequantize_nf4: NF4 block-wise quantization

  • pack_weights_1bit / unpack_weights_1bit: ternary {-1, 0, 1} weight packing (2 bits per value)

TPU Kernels:
  • bmm_nf4: Pallas kernel for NF4 batch matrix multiplication

  • bmm_nf4_transpose: Transposed variant for backward passes

  • nf4_matmul: High-level wrapper for TPU-optimized matmul

Configuration:
  • is_kernel_available(): Check if TPU kernels are available

  • nf4_use_kernel(): Context manager for kernel mode control

Utilities:
  • get_nf4(): Get NF4 lookup table

  • nf4xf32_to_f32(): Polynomial approximation for NF4 dequantization

  • i8tou8, u4toi4, i4tou4: Bit conversion utilities

Environment Variables:
USE_NF4_KERNEL_TPU: Set to “0”, “false”, or “off” to disable TPU kernels. Default is enabled (“1”).

Example

>>> import jax.numpy as jnp
>>> from eformer.ops.quantization.quantization_functions import (
...     quantize_int8, dequantize_int8,
...     quantize_and_pack_nf4, dequantize_nf4
... )
>>>
>>> # INT8 quantization
>>> weights = jnp.ones((128, 256))
>>> quants, scales = quantize_int8(weights, axis=-1)
>>> reconstructed = dequantize_int8(quants, scales)
>>>
>>> # NF4 quantization
>>> packed, absmax = quantize_and_pack_nf4(weights, block_size=64)
>>> reconstructed = dequantize_nf4(packed, absmax, block_size=64)

eformer.ops.quantization.quantization_functions.bmm_nf4(inputs_ref, quants_ref, scale_ref, outputs_ref, accum_ref, *, block_k)

Pallas kernel for NF4 matrix multiplication with on-the-fly dequantization.

This kernel performs efficient matrix multiplication by dequantizing weights on-the-fly during computation, avoiding materialization of full-precision weights.

Parameters
  • inputs_ref – Reference to input activation tensor

  • quants_ref – Reference to quantized weight tensor (int4)

  • scale_ref – Reference to scale factors

  • outputs_ref – Reference to output tensor

  • accum_ref – Reference to accumulator tensor

  • block_k – Block size for K dimension (static)
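The on-the-fly idea can be illustrated in plain JAX, without Pallas. This is a reference sketch, not the kernel: it assumes unpacked 4-bit codes `idx` of shape (K, N), one scale per (K-block, column) in `scales`, and the standard NF4 codebook hard-coded and rounded to 4 decimals. Only one `block_k` slice of the weights is dequantized at a time, and the partial products are summed into an accumulator, which is the role `accum_ref` plays in the kernel. The function name `blocked_nf4_matmul` is hypothetical.

```python
import jax
import jax.numpy as jnp

# Standard NF4 codebook, rounded; assumed to match get_nf4().
NF4 = jnp.array([
    -1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
    0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0,
], dtype=jnp.float32)


def blocked_nf4_matmul(x, idx, scales, block_k):
    """x: (M, K); idx: (K, N) codes in 0..15; scales: (K // block_k, N). Hypothetical layout."""
    m, k = x.shape
    n = idx.shape[1]
    acc = jnp.zeros((m, n), jnp.float32)  # accumulator, like accum_ref
    for kb in range(k // block_k):
        rows = slice(kb * block_k, (kb + 1) * block_k)
        # Dequantize only this K-block; the full weight matrix is never built.
        w_block = NF4[idx[rows]] * scales[kb][None, :]
        acc = acc + x[:, rows].astype(jnp.float32) @ w_block
    return acc


kx, ki, ks = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (8, 128))
idx = jax.random.randint(ki, (128, 16), 0, 16)
scales = jax.random.uniform(ks, (2, 16))
out = blocked_nf4_matmul(x, idx, scales, block_k=64)
```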

eformer.ops.quantization.quantization_functions.bmm_nf4_transpose(inputs_ref, quants_ref, scale_ref, outputs_ref, accum_ref, *, block_k)

Pallas kernel for transposed NF4 matrix multiplication.

This kernel handles the case where the weights must be read in transposed order, as required by the backward pass during training.

Parameters
  • inputs_ref – Reference to input activation tensor

  • quants_ref – Reference to quantized weight tensor (int4)

  • scale_ref – Reference to scale factors

  • outputs_ref – Reference to output tensor

  • accum_ref – Reference to accumulator tensor

  • block_k – Block size for K dimension (static)
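Why a transposed variant is needed can be checked with dense weights: for y = x @ W, the gradient with respect to x is dy @ W.T, so the backward pass reads the same weights in transposed order. The sketch below uses plain jnp arrays in place of quantized ones and shows only this relationship, not the kernel; the names `forward` and `input_grad` are illustrative.

```python
import jax
import jax.numpy as jnp


def forward(x, w):
    return x @ w  # (M, K) @ (K, N) -> (M, N)


def input_grad(dy, w):
    # Backward w.r.t. x: (M, N) @ (N, K) -> (M, K); the weights are used transposed.
    return dy @ w.T


kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (4, 32))
w = jax.random.normal(kw, (32, 16))
y, vjp_fn = jax.vjp(forward, x, w)
dx, _ = vjp_fn(jnp.ones_like(y))  # JAX's own gradient for x
```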

eformer.ops.quantization.quantization_functions.dequantize_int8(quants, scales)

Dequantize 8-bit integers back to floating-point values.

Multiplies the quantized int8 values by their corresponding scale factors to reconstruct approximate original values.

Parameters
  • quants (jax.Array) – int8 array containing quantized values in range [-127, 127].

  • scales (jax.Array) – Float array of scale factors, must be broadcastable with quants. Typically has shape with size-1 dimensions where quantization was performed.

Returns

Dequantized array with the same shape as quants; its dtype is determined by the scales array.

Return type

jax.Array

Example

>>> quants = jnp.array([[42, 85, 127], [32, 64, 96]], dtype=jnp.int8)
>>> scales = jnp.array([[0.024], [0.047]])  # Shape (2, 1)
>>> result = dequantize_int8(quants, scales)
>>> # result has shape (2, 3) with float values

Note

This is the inverse operation of quantize_int8. Due to rounding during quantization, the reconstructed values are approximate.

eformer.ops.quantization.quantization_functions.dequantize_nf4(packed_values, absmax, block_size)

Dequantize an array from NF4 (4-bit NormalFloat) format.

High-level API for NF4 dequantization. Unpacks 4-bit values and scales them by per-block absmax values to reconstruct approximate original values.

Parameters
  • packed_values (jax.Array) – uint8 array with packed 4-bit values as produced by quantize_and_pack_nf4. Shape: (…, num_blocks, block_size // 2).

  • absmax (jax.Array) – Per-block scale factors from quantization. Shape: (…, num_blocks).

  • block_size (int) – Number of elements per quantization block. Must match the value used during quantization.

Returns

Dequantized float32 array with shape (…, features), where features = num_blocks * block_size.

Return type

jax.Array

Example

>>> # Given packed data from quantize_and_pack_nf4
>>> reconstructed = dequantize_nf4(packed, absmax, block_size=64)
>>> # reconstructed approximates original values

See also

  • quantize_and_pack_nf4: Forward quantization operation.

  • single_dequantize_nf4: Internal implementation.

Note

Due to quantization, the reconstructed values are approximate. NF4 provides better accuracy than uniform 4-bit quantization for normally distributed data (typical of neural network weights).

eformer.ops.quantization.quantization_functions.get_nf4()

Get the NF4 (4-bit NormalFloat) lookup table.

Creates the lookup table lazily to avoid triggering JAX initialization at import time. The NF4 format uses 16 values optimized for Gaussian distributions, providing better accuracy than uniform quantization for neural network weights.

Returns

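A 16-element float32 array containing the NF4 codebook values, ranging from -1.0 to 1.0. The values include an exact zero; because a 16-entry codebook cannot both contain zero and be symmetric around it, the standard NF4 table has 7 negative and 8 positive non-zero values. The spacing is optimized for normally distributed data.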

Return type

jax.Array

Note

The codebook values are derived from the quantiles of a standard normal distribution, making NF4 particularly effective for neural network weights, which tend to be approximately normally distributed.
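The quantile construction can be sketched with jax.scipy.stats.norm.ppf. This follows, to our understanding, the recipe used by bitsandbytes for QLoRA (offset 0.9677083, 8 positive and 7 negative quantiles plus an exact zero, normalized by the largest value); whether get_nf4() computes its table this way or hard-codes the resulting constants is not stated here. The name `nf4_codebook_sketch` is hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def nf4_codebook_sketch(offset: float = 0.9677083) -> jnp.ndarray:
    # 8 positive quantiles: linspace(offset, 0.5, 9) without the final 0.5 (which maps to 0).
    pos = norm.ppf(jnp.linspace(offset, 0.5, 9)[:-1])
    # 7 negative quantiles, mirrored from a slightly coarser grid.
    neg = -norm.ppf(jnp.linspace(offset, 0.5, 8)[:-1])
    values = jnp.sort(jnp.concatenate([pos, jnp.zeros(1), neg]))
    # Normalize so the codebook spans exactly [-1, 1].
    return (values / values.max()).astype(jnp.float32)


codebook = nf4_codebook_sketch()
print(codebook.shape)  # (16,)
```

The uneven grids (9 points on the positive side, 8 on the negative side) are what produce the asymmetric 7/8 split around the exact zero.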

eformer.ops.quantization.quantization_functions.i4tou4(x)

Convert signed 4-bit integer to unsigned 4-bit integer.

Maps signed values -8 to 7 to unsigned range 0-15. Negative values -8 to -1 become 8 to 15.

Parameters

x (jax.Array) – Input array with signed int4 values (-8 to 7).

Returns

Output array with uint4 values (0 to 15).

Return type

jax.Array

Example

>>> i4tou4(jnp.array([-8, -1, 0, 7]))
Array([8, 15, 0, 7], dtype=int32)

eformer.ops.quantization.quantization_functions.i8tou8(x)

Convert signed int8 to unsigned uint8.

Handles the two’s complement representation by adding 256 to negative values.

Parameters

x (jax.Array) – Input array with int8 values (-128 to 127).

Returns

Output array with the equivalent unsigned values (0 to 255). In the example below, the result dtype is int32 rather than uint8.

Return type

jax.Array

Example

>>> i8tou8(jnp.array([-1, 0, 127], dtype=jnp.int8))
Array([255, 0, 127], dtype=int32)
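A minimal equivalent, assuming the conversion widens to int32 first (which would be consistent with the int32 result in the example): adding 256 directly in int8 could not represent values above 127. The name `i8_to_u8_sketch` is hypothetical.

```python
import jax.numpy as jnp


def i8_to_u8_sketch(x):
    # Widen first so that x + 256 cannot overflow the int8 range.
    x32 = x.astype(jnp.int32)
    # Two's complement: a negative int8 value v has the same bits as the uint8 value v + 256.
    return jnp.where(x32 < 0, x32 + 256, x32)


print(i8_to_u8_sketch(jnp.array([-128, -1, 0, 127], dtype=jnp.int8)).tolist())
# [128, 255, 0, 127]
```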

eformer.ops.quantization.quantization_functions.is_kernel_available()

Check if NF4 kernels are available on the current device.

Returns

True if running on TPU (where kernels are supported), False otherwise

Return type

bool

eformer.ops.quantization.quantization_functions.nf4_matmul(inputs, *tensors, kernel, backward=False, blocks=None)

Matrix multiplication against NF4-quantized weights using Pallas TPU kernels.

Block sizes are selected automatically unless blocks is given, and the operands are padded as needed for efficient execution on TPU.

Parameters
  • inputs – Input activation tensor

  • *tensors – Quantized weight tensors (quants, scales)

  • kernel – Pallas kernel function to use

  • backward – Whether this is a backward pass

  • blocks – Optional manual block size specification (block_x, block_y, block_k)

Returns

Result of matrix multiplication
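The padding step can be illustrated generically. The sketch below uses dense weights and block sizes chosen purely for illustration (128 by 128). It zero-pads M and N up to whole blocks, multiplies, and slices the padding off again; zero rows and columns contribute nothing to the real outputs. It does not reproduce nf4_matmul's block-size selection, and the helper names are hypothetical.

```python
import jax
import jax.numpy as jnp


def pad_axis_to_multiple(x, multiple, axis):
    # Zero-pad one axis up to the next multiple of `multiple`.
    pad = (-x.shape[axis]) % multiple
    widths = [(0, 0)] * x.ndim
    widths[axis] = (0, pad)
    return jnp.pad(x, widths)


def padded_matmul(x, w, block_m=128, block_n=128):
    m, n = x.shape[0], w.shape[1]
    xp = pad_axis_to_multiple(x, block_m, axis=0)  # (M_pad, K)
    wp = pad_axis_to_multiple(w, block_n, axis=1)  # (K, N_pad)
    out = xp @ wp  # M_pad and N_pad are whole numbers of blocks
    return out[:m, :n]  # drop rows and columns that exist only because of padding


x = jax.random.normal(jax.random.PRNGKey(0), (100, 64))
w = jax.random.normal(jax.random.PRNGKey(1), (64, 200))
print(padded_matmul(x, w).shape)  # (100, 200)
```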

eformer.ops.quantization.quantization_functions.nf4_use_kernel(value: bool)

Context manager to enable/disable NF4 kernel mode.

Note: Kernels are used only on TPU devices. On other devices this setting has no effect, and the computation falls back to materializing (dequantizing) the full-precision weights before the matrix multiplication.

Parameters

value – Whether to enable kernel mode (only effective on TPU)

Example

>>> with nf4_use_kernel(True):
...     result = x @ quantized_weight  # Uses the kernel on TPU

eformer.ops.quantization.quantization_functions.nf4xf32_to_f32(x)

Fast polynomial approximation for NF4 dequantization.

Evaluating a polynomial avoids a table lookup (gather), which is typically slow on accelerators, while closely approximating the NF4 codebook values.

Parameters

x – Integer array (0-15) representing NF4 quantized values

Returns

Float32 array with approximated NF4 values

eformer.ops.quantization.quantization_functions.pack_weights_1bit(quantized_weights: Array) → Array

Packs a JAX array of quantized weights into a compact format using 2 bits per value.

Parameters

quantized_weights (jnp.ndarray) – An array containing ternary quantized weights {-1, 0, 1}. The first dimension must be a multiple of 4.

Returns

A packed jnp.uint8 array.

Return type

jnp.ndarray
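A minimal sketch of 2-bit ternary packing. It assumes a row-block layout: values are shifted from {-1, 0, 1} to {0, 1, 2}, and the i-th quarter of the rows is stored in bits 2i and 2i+1 of each byte, so that unpacking concatenates the bit groups. Both this layout and the name `pack_ternary_sketch` are assumptions, not the library's code.

```python
import jax.numpy as jnp

VALUES_PER_BYTE = 4  # 2 bits per ternary value


def pack_ternary_sketch(w):
    rows = w.shape[0]
    if rows % VALUES_PER_BYTE != 0:
        raise ValueError("first dimension must be a multiple of 4")
    r = rows // VALUES_PER_BYTE
    # Shift {-1, 0, 1} to {0, 1, 2} so each value fits in 2 unsigned bits.
    u = (w + 1).astype(jnp.uint8)
    packed = jnp.zeros((r, *w.shape[1:]), dtype=jnp.uint8)
    for i in range(VALUES_PER_BYTE):
        # Rows [i*r, (i+1)*r) go into bits 2i and 2i+1 of each byte.
        packed = packed | (u[i * r:(i + 1) * r] << (2 * i))
    return packed


w = jnp.array([[-1], [0], [1], [1]], dtype=jnp.int8)
print(pack_ternary_sketch(w).tolist())  # [[164]]  (0b10100100)
```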

eformer.ops.quantization.quantization_functions.quantize_and_pack_nf4(blocks, block_size=64)

Quantize and pack an array using NF4 (4-bit NormalFloat) quantization.

High-level API for NF4 quantization. Quantizes the input array into 4-bit NormalFloat format with block-wise scaling and packs two values per byte.

Parameters
  • blocks (jax.Array) – Input array to quantize. The last dimension must be divisible by block_size. Shape: (…, features).

  • block_size (int) – Number of elements per quantization block. Defaults to 64. Larger blocks use less memory for scales but may have higher quantization error.

Returns

A tuple containing:
  • packed (jax.Array): uint8 array with packed 4-bit values. Shape: (…, num_blocks, block_size // 2).

  • absmax (jax.Array): Per-block scale factors for dequantization. Shape: (…, num_blocks).

Return type

tuple[jax.Array, jax.Array]

Example

>>> weights = jnp.ones((128, 256), dtype=jnp.float32)
>>> packed, absmax = quantize_and_pack_nf4(weights, block_size=64)
>>> # Reconstruct with dequantize_nf4
>>> reconstructed = dequantize_nf4(packed, absmax, block_size=64)

See also

  • dequantize_nf4: Reverse operation to reconstruct values.

  • single_quantize_and_pack_nf4: Internal implementation.
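The block_size trade-off can be put in numbers. Assuming one float32 absmax per block (the scale dtype is not stated above), each weight costs 4 bits for its code plus its share of the block's 32-bit scale. The helper name is hypothetical.

```python
def nf4_bits_per_weight(block_size: int, scale_bits: int = 32) -> float:
    # 4 bits for the code plus the per-block scale amortized over block_size weights.
    return 4 + scale_bits / block_size


for bs in (32, 64, 128, 256):
    print(bs, nf4_bits_per_weight(bs))
# 32 5.0
# 64 4.5
# 128 4.25
# 256 4.125
```

Doubling block_size halves the scale overhead, but each scale then has to cover more values, so a single outlier affects a larger block.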

eformer.ops.quantization.quantization_functions.quantize_int8(x: Array, axis: int | tuple = -1)

Quantize floating-point values to 8-bit integers with per-axis scaling.

Computes a scale factor based on the maximum absolute value along the specified axis, then quantizes values to the int8 range [-127, 127].

Parameters
  • x (jax.Array) – Input array to quantize. Can be any floating-point dtype.

  • axis (int | tuple) – Axis or axes along which to compute the scale factor. Defaults to -1 (last axis). Can be a single int or tuple of ints.

Returns

A tuple containing:
  • quant (jax.Array): int8 array with quantized values in range [-127, 127].

  • scale (jax.Array): Float array of scale factors with shape broadcastable to the input (dimensions along axis are kept as size 1).

Return type

tuple[jax.Array, jax.Array]

Example

>>> x = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> quant, scale = quantize_int8(x, axis=-1)
>>> # quant is int8, scale has shape (2, 1)
>>> reconstructed = dequantize_int8(quant, scale)

Note

A tiny epsilon is added to the scale to avoid division by zero when a slice is all zeros. The scale has the same dtype as the input array.
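A minimal sketch of absmax INT8 quantization as described above, together with the matching dequantization. The epsilon value (1e-8) and round-to-nearest are assumptions, since the documentation only says that a tiny epsilon is added; the function names are hypothetical.

```python
import jax.numpy as jnp


def int8_quantize_sketch(x, axis=-1, eps=1e-8):
    # Per-axis absmax; keepdims so the scale broadcasts back against x.
    absmax = jnp.max(jnp.abs(x), axis=axis, keepdims=True)
    # Map absmax to 127; eps avoids division by zero for all-zero slices.
    scale = (absmax / 127.0 + eps).astype(x.dtype)
    quant = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
    return quant, scale


def int8_dequantize_sketch(quant, scale):
    return quant.astype(scale.dtype) * scale


x = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
quant, scale = int8_quantize_sketch(x, axis=-1)
print(quant.dtype, scale.shape)  # int8 (2, 1)
```

Each reconstructed value is within half a quantization step (scale / 2) of the original.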

eformer.ops.quantization.quantization_functions.single_dequantize_nf4(packed_values, absmax, block_size)

Dequantize packed NF4 values back to floating-point.

This function reverses the NF4 quantization by:

  1. Unpacking two 4-bit values from each uint8 byte
  2. Converting 4-bit indices to NF4 codebook values using a polynomial approximation
  3. Scaling by per-block absmax values
  4. Flattening back to the original feature dimension

Parameters
  • packed_values (jax.Array) – uint8 array with packed 4-bit values. Shape (…, num_blocks, block_size // 2).

  • absmax (jax.Array) – Absolute maximum per block used during quantization. Shape (…, num_blocks).

  • block_size (int) – Number of elements per quantization block. Must match the value used during quantization.

Returns

Dequantized float32 array with shape (…, num_blocks * block_size), which equals the original feature dimension.

Return type

jax.Array

Example

>>> # Given packed data from single_quantize_and_pack_nf4
>>> reconstructed = single_dequantize_nf4(packed, absmax, block_size=64)
>>> reconstructed.shape  # Original shape restored

Note

Uses a polynomial approximation (nf4xf32_to_f32) instead of table lookup for faster dequantization on accelerators.
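The four dequantization steps can be sketched in plain JAX. This is not the library's implementation: it uses a table lookup into the standard NF4 codebook (hard-coded and rounded) where the library uses nf4xf32_to_f32, and it assumes the first value of each pair sits in the high nibble. The name `nf4_dequantize_sketch` is hypothetical.

```python
import jax.numpy as jnp

# Standard NF4 codebook, rounded; assumed to match get_nf4().
NF4 = jnp.array([
    -1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
    0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0,
], dtype=jnp.float32)


def nf4_dequantize_sketch(packed, absmax, block_size):
    # 1. Unpack two 4-bit indices per byte (assumed: high nibble first).
    high = (packed >> 4) & 0x0F
    low = packed & 0x0F
    idx = jnp.stack([high, low], axis=-1).reshape(*packed.shape[:-1], block_size)
    # 2. Map indices to codebook values (table lookup in this sketch).
    values = NF4[idx]
    # 3. Scale each block by its absmax.
    scaled = values * absmax[..., None]
    # 4. Flatten (num_blocks, block_size) back into the feature dimension.
    return scaled.reshape(*packed.shape[:-2], packed.shape[-2] * block_size)


packed = jnp.full((2, 4, 32), 0xF7, dtype=jnp.uint8)  # high nibble 15, low nibble 7
absmax = jnp.full((2, 4), 2.0)
out = nf4_dequantize_sketch(packed, absmax, block_size=64)
print(out.shape)  # (2, 256)
```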

eformer.ops.quantization.quantization_functions.single_quantize_and_pack_nf4(blocks, block_size=64)

Quantize and pack an array to NF4 format in a single pass.

This function performs block-wise NF4 quantization by:

  1. Reshaping the input into blocks along the last dimension
  2. Computing per-block absolute maximum values for scaling
  3. Normalizing values by absmax
  4. Finding the nearest NF4 codebook values
  5. Packing two 4-bit values into each uint8 byte

Parameters
  • blocks (jax.Array) – Input array with shape (…, features) where features must be divisible by block_size.

  • block_size (int) – Number of elements per quantization block. Defaults to 64. Must be even for packing.

Returns

A tuple containing:
  • packed (jax.Array): uint8 array with packed 4-bit values. Shape (…, num_blocks, block_size // 2).

  • absmax (jax.Array): Absolute maximum per block for dequantization. Shape (…, num_blocks).

Return type

tuple[jax.Array, jax.Array]

Example

>>> x = jnp.ones((128, 256))  # 256 features, 4 blocks of 64
>>> packed, absmax = single_quantize_and_pack_nf4(x, block_size=64)
>>> packed.shape  # (128, 4, 32) - 4 blocks, 32 bytes each
>>> absmax.shape  # (128, 4) - one scale per block

Note

This is optimized for JAX JIT compilation. For the high-level API, use quantize_and_pack_nf4 instead.
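The five quantization steps can be sketched in plain JAX. This is not the library's implementation: the codebook is the standard NF4 table hard-coded and rounded, the nearest-neighbour search is a brute-force comparison against all 16 entries, and placing the first value of each pair in the high nibble is an assumption about the packing order. The name `nf4_quantize_pack_sketch` is hypothetical.

```python
import jax.numpy as jnp

# Standard NF4 codebook, rounded; assumed to match get_nf4().
NF4 = jnp.array([
    -1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
    0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0,
], dtype=jnp.float32)


def nf4_quantize_pack_sketch(x, block_size=64):
    *lead, features = x.shape
    if features % block_size != 0 or block_size % 2 != 0:
        raise ValueError("features must be divisible by an even block_size")
    num_blocks = features // block_size
    # 1. Reshape the last dimension into (num_blocks, block_size).
    blocks = x.astype(jnp.float32).reshape(*lead, num_blocks, block_size)
    # 2. Per-block absolute maximum, used as the scale.
    absmax = jnp.max(jnp.abs(blocks), axis=-1)
    # 3. Normalize into [-1, 1]; all-zero blocks are divided by 1 instead of 0.
    normalized = blocks / jnp.where(absmax == 0, 1.0, absmax)[..., None]
    # 4. Index of the nearest codebook entry for every element.
    idx = jnp.argmin(jnp.abs(normalized[..., None] - NF4), axis=-1).astype(jnp.uint8)
    # 5. Pack pairs of 4-bit indices into one byte (first of each pair in the high nibble).
    pairs = idx.reshape(*lead, num_blocks, block_size // 2, 2)
    packed = (pairs[..., 0] << 4) | pairs[..., 1]
    return packed, absmax


packed, absmax = nf4_quantize_pack_sketch(jnp.ones((128, 256)), block_size=64)
print(packed.shape, absmax.shape)  # (128, 4, 32) (128, 4)
```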

eformer.ops.quantization.quantization_functions.u4toi4(x)

Convert unsigned 4-bit integer to signed 4-bit integer.

Maps unsigned values 0-15 to signed range -8 to 7. Values 8-15 become negative values -8 to -1.

Parameters

x (jax.Array) – Input array with uint4 values (0 to 15).

Returns

Output array with signed int4 values (-8 to 7).

Return type

jax.Array

Example

>>> u4toi4(jnp.array([0, 7, 8, 15]))
Array([0, 7, -8, -1], dtype=int32)
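Both 4-bit conversions reduce to a single wrap by 16, as the examples for i4tou4 and u4toi4 show. A minimal pair of equivalents (hypothetical names), with a round trip over all 16 codes:

```python
import jax.numpy as jnp


def i4_to_u4_sketch(x):
    # -8..-1 -> 8..15 (add 16); 0..7 unchanged.
    return jnp.where(x < 0, x + 16, x)


def u4_to_i4_sketch(x):
    # 8..15 -> -8..-1 (subtract 16); 0..7 unchanged.
    return jnp.where(x >= 8, x - 16, x)


codes = jnp.arange(16)
print(u4_to_i4_sketch(codes).tolist())
# [0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1]
```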

eformer.ops.quantization.quantization_functions.unpack_weights_1bit(packed: Array, dtype: dtype) → Array

Unpacks a packed uint8 array of ternary weights, following the logic of the original PyTorch implementation: the 2-bit groups extracted from each byte are concatenated rather than interleaved.

Parameters
  • packed (jnp.ndarray) – A packed jnp.uint8 array.

  • dtype (jnp.dtype) – The dtype of the returned array (e.g., jnp.int8). This is a static argument for JIT.

Returns

An unpacked array with ternary values {-1, 0, 1}.

Return type

jnp.ndarray
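A minimal sketch of a concatenating unpack, assuming that the i-th quarter of the rows was stored in bits 2i and 2i+1 of each byte and that values were offset by +1 before packing. The name `unpack_ternary_sketch` and this layout are assumptions, not the library's code.

```python
import jax.numpy as jnp


def unpack_ternary_sketch(packed, dtype=jnp.int8):
    # Extract each 2-bit group and stack the groups along the first axis (concatenation).
    groups = [(packed >> (2 * i)) & 0b11 for i in range(4)]
    unpacked = jnp.concatenate(groups, axis=0)
    # Undo the +1 offset applied at packing time: {0, 1, 2} -> {-1, 0, 1}.
    return unpacked.astype(dtype) - 1


packed = jnp.array([[0b10100100]], dtype=jnp.uint8)  # 164
print(unpack_ternary_sketch(packed)[:, 0].tolist())  # [-1, 0, 1, 1]
```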