eformer.ops.quantization.implicit_array_nf4#

4-bit NormalFloat (NF4) Quantization Module.

This module provides NF4 quantization for neural network weights. Storing 4 bits per weight gives close to an 8x memory reduction relative to float32 (slightly less once the per-block absmax scales are counted) while maintaining good accuracy for weights that are approximately Gaussian-distributed.
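As a rough check on the memory figure, the arithmetic can be sketched in plain Python. The helper name `nf4_bits_per_weight` is hypothetical, and the sketch assumes one 32-bit absmax scale per block, which this page does not state.

```python
def nf4_bits_per_weight(block_size: int = 64, absmax_bits: int = 32) -> float:
    """Storage cost per weight: one 4-bit code plus an amortized absmax scale.

    Assumption: each block stores a single absmax of `absmax_bits` bits.
    """
    return 4 + absmax_bits / block_size


bits = nf4_bits_per_weight(block_size=64)   # 4 bits of code + 0.5 bits of scale
ratio = 32 / bits                           # reduction relative to float32
print(f"{bits} bits/weight, about {ratio:.2f}x smaller than float32")
```

Under these assumptions, larger blocks push the ratio closer to 8x, at the cost of a coarser scale per block.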

The NF4 format uses a 16-value codebook optimized for normally distributed data, making it particularly effective for transformer model weights. This module includes:

  • ArrayNF4: Implicit array class for NF4 quantized weights

  • TPU kernel support: Optimized Pallas kernels for direct matrix operations

  • Sharding support: Distributed computation with JAX sharding

  • JAX primitive handlers: Transparent integration with JAX operations

Key Features:
  • Block-wise quantization with configurable block sizes

  • Automatic sharding preservation across operations

  • TPU kernel dispatch for efficient matrix multiplication

  • Fallback to materialization for unsupported operations

Example

>>> import jax.numpy as jnp
>>> from eformer.ops.quantization import ArrayNF4
>>>
>>> # Quantize weights
>>> weights = jnp.ones((128, 256), dtype=jnp.float32)
>>> quantized = ArrayNF4.quantize(weights, block_size=64)
>>>
>>> # Use transparently in matrix operations
>>> inputs = jnp.ones((32, 128))
>>> # On TPU: uses kernel, otherwise materializes
>>> output = inputs @ quantized  # Works via registered primitives
class eformer.ops.quantization.implicit_array_nf4.ArrayNF4(packed: Array, absmax: Array, block_size: int, sharding: ShardingType = None, *, shape: tp.Sequence[int] | None = None, dtype: jnp.dtype = None)[source]#

Bases: ImplicitArray

4-bit NormalFloat Quantization Class

This class implements 4-bit NormalFloat (NF4) quantization for arrays. The input is split into blocks; each value is stored as a 4-bit code, and each block keeps its absolute maximum as a scale. Quantization is lossy, so materializing the array reconstructs an approximation of the original from the packed codes and the per-block absolute maxima.
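The block-wise scheme described above can be sketched in JAX as follows. This is an illustration, not eformer's implementation: the function names are hypothetical, the codebook is the NF4 table published with QLoRA/bitsandbytes (eformer's table may differ), and the high-nibble-first packing order is an assumption.

```python
import jax.numpy as jnp

# NF4 levels as published with QLoRA / bitsandbytes (assumed, see lead-in).
NF4_CODEBOOK = jnp.array([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
], dtype=jnp.float32)


def quantize_nf4(x, block_size=64):
    """Return (packed uint8 codes, per-block absmax).

    Requires x.size to be a multiple of block_size and block_size to be even.
    """
    blocks = x.reshape(-1, block_size)               # one row per block
    absmax = jnp.max(jnp.abs(blocks), axis=1)        # per-block scale
    scale = jnp.where(absmax == 0, 1.0, absmax)      # guard all-zero blocks
    normed = blocks / scale[:, None]                 # values now in [-1, 1]
    # Index of the nearest codebook level for every element (0..15).
    codes = jnp.argmin(jnp.abs(normed[..., None] - NF4_CODEBOOK), axis=-1)
    flat = codes.astype(jnp.uint8).reshape(-1)
    packed = (flat[0::2] << 4) | flat[1::2]          # two 4-bit codes per byte
    return packed, absmax


def dequantize_nf4(packed, absmax, shape, block_size=64):
    """Unpack the codes, look them up, and rescale each block by its absmax."""
    hi = (packed >> 4).astype(jnp.int32)
    lo = (packed & 0x0F).astype(jnp.int32)
    codes = jnp.stack([hi, lo], axis=-1).reshape(-1, block_size)
    return (NF4_CODEBOOK[codes] * absmax[:, None]).reshape(shape)


weights = jnp.linspace(-2.0, 2.0, 128, dtype=jnp.float32).reshape(2, 64)
packed, absmax = quantize_nf4(weights, block_size=64)
restored = dequantize_nf4(packed, absmax, weights.shape, block_size=64)
```

The round trip is approximate: each restored value is the nearest codebook level times its block's absmax, so the error depends on how widely the values in a block are spread.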

packed#

The packed 4-bit integer array.

Type

jax.Array

absmax#

The absolute maximum values for each block.

Type

jax.Array

block_size#

The size of each quantization block (static).

Type

int

sharding#

The sharding specification to preserve across operations (static).

Type

ShardingType

quantize(array, block_size)[source]#

Creates a quantized ArrayNF4 from input array.

materialize()[source]#

Reconstructs (dequantizes) an approximation of the original array from the quantized data.

with_sharding(sharding)[source]#

Returns a new ArrayNF4 with the specified sharding applied.

absmax: Array#
block_size: int#
commute_ops: tp.ClassVar[bool] = True#
delete()[source]#

Delete the underlying packed and absmax arrays to free memory.

Explicitly releases the memory held by the quantized representation. Useful for manual memory management in memory-constrained environments.

dequantize() Array[source]#

Alias for materialize() for compatibility.

property is_sharded: bool#

Returns True if this array has sharding information.

materialize() Array[source]#

Reconstructs (dequantizes) an approximation of the original array from the quantized data.

Returns

The dequantized array, with the sharding constraint applied if one is set.

Return type

jax.Array

packed: Array#
classmethod quantize(array: Array, block_size: int = 64) ArrayNF4[source]#

Creates an ArrayNF4 by quantizing the input array.

Parameters
  • array (jax.Array) – The input array to be quantized.

  • block_size (int) – The size of each quantization block. Defaults to 64.

Returns

The quantized array with sharding preserved from input.

Return type

ArrayNF4

reshard(sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None) ArrayNF4[source]#

Alias for with_sharding for API consistency.

sharding: ShardingType = None#
warn_on_materialize: tp.ClassVar[bool] = True#
with_mesh_and_axis(mesh_and_axis: tuple[jax._src.mesh.Mesh, jax.sharding.PartitionSpec | tuple[Any, ...] | None]) ArrayNF4[source]#

Apply sharding using a mesh and axis specification tuple.

Convenience method that creates a NamedSharding from a mesh and axis specification, commonly used with model parallelism utilities.

Parameters

mesh_and_axis – A tuple of (Mesh, axis_spec), where axis_spec can be:
  • None: Replicated across all devices

  • PartitionSpec: Use directly

  • tuple/list: Convert to PartitionSpec

  • str: Single axis name

Returns

New instance with the specified sharding applied.

Return type

ArrayNF4
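One way the documented axis_spec cases could map onto a NamedSharding is sketched below. `to_named_sharding` is a hypothetical helper written for illustration and is not part of eformer; the single-device mesh and the axis name "tp" are likewise assumptions chosen so the snippet runs anywhere.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def to_named_sharding(mesh_and_axis):
    """Hypothetical helper mirroring the documented axis_spec cases."""
    mesh, axis_spec = mesh_and_axis
    if axis_spec is None:
        spec = PartitionSpec()               # replicated across all devices
    elif isinstance(axis_spec, PartitionSpec):
        spec = axis_spec                     # used directly
    elif isinstance(axis_spec, str):
        spec = PartitionSpec(axis_spec)      # single axis name
    else:
        spec = PartitionSpec(*axis_spec)     # tuple/list of per-dimension axes
    return NamedSharding(mesh, spec)


# A one-device mesh with a single named axis, enough to build shardings.
mesh = Mesh(np.array(jax.devices()[:1]), ("tp",))
sharding = to_named_sharding((mesh, ("tp", None)))
```

The PartitionSpec check comes before the generic tuple/list branch because some JAX versions implement PartitionSpec as a tuple subclass.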

with_sharding(sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None) ArrayNF4[source]#

Returns a new ArrayNF4 with the specified sharding applied to component arrays.

This method creates a copy of the quantized array with sharding constraints applied to the underlying packed and absmax arrays, ensuring they are properly distributed across devices.

Parameters

sharding – A NamedSharding, PartitionSpec, or None. If PartitionSpec is provided, it will be used directly. For NamedSharding, both the mesh and spec are preserved.

Returns

A new instance with sharding applied to component arrays.

Return type

ArrayNF4

eformer.ops.quantization.implicit_array_nf4.add_nf4_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#

Custom handler for JAX’s add operation.

Materializes ArrayNF4 inputs before performing the operation.

Parameters
  • primitive – The JAX primitive being handled.

  • x – First array to add.

  • y – Second array to add.

Returns

The result of lax.add operation.

eformer.ops.quantization.implicit_array_nf4.broadcast_in_dim_nf4_operand(primitive: Primitive, operand: ArrayNF4, *args, **params) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#

Custom handler for JAX’s broadcast_in_dim operation.

Materializes ArrayNF4 input, performs broadcasting, and re-quantizes the result. Preserves sharding from the original array.

Parameters
  • primitive (Primitive) – The JAX primitive being handled.

  • operand (ArrayNF4) – The array to broadcast.

  • *args – Positional arguments for the broadcast operation.

  • **params – Keyword parameters including shape and broadcast_dimensions.

Returns

The broadcasted array, re-quantized as ArrayNF4 with sharding preserved.

Return type

ArrayType

eformer.ops.quantization.implicit_array_nf4.concatenate_nf4_operands(primitive: Primitive, operands: Sequence[jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any], *args: Any, **kwargs: Any) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#

Custom handler for JAX’s concatenate operation.

Materializes ArrayNF4 inputs before performing the operation.

Parameters
  • primitive – The JAX primitive being handled.

  • operands – Sequence of arrays to concatenate.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

The result of lax.concatenate operation.

eformer.ops.quantization.implicit_array_nf4.conv_general_dilated_nf4_lhs_rhs(primitive: Primitive, lhs: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, rhs: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, *args: Any, **kwargs: Any) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#

Custom handler for JAX’s conv_general_dilated operation.

Materializes ArrayNF4 inputs before performing the operation.

Parameters
  • primitive – The JAX primitive being handled.

  • lhs – Left-hand side array (input).

  • rhs – Right-hand side array (kernel).

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

The result of lax.conv_general_dilated operation.

eformer.ops.quantization.implicit_array_nf4.convert_element_type_nf4_operand_kw(primitive: Primitive, operand: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, **kwargs) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#
eformer.ops.quantization.implicit_array_nf4.convert_element_type_nf4_operand_pos(primitive: Primitive, operand: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, new_dtype: Any) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#
eformer.ops.quantization.implicit_array_nf4.div_nf4_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any) Any[source]#
eformer.ops.quantization.implicit_array_nf4.dot_general_nf4_lhs_rhs(primitive: Primitive, lhs: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, rhs: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, *args: Any, **kwargs: Any) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#

Custom handler for JAX’s dot_general operation.

Supports both kernel-based and materialization-based execution.

Parameters
  • primitive – The JAX primitive being handled.

  • lhs – Left-hand side array.

  • rhs – Right-hand side array.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

The result of lax.dot_general operation.

eformer.ops.quantization.implicit_array_nf4.exp_nf4_x(primitive: Primitive, x: ArrayNF4, *args: Any, **kwargs: Any) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#

Custom handler for JAX’s exp operation.

Materializes ArrayNF4 input before performing the operation.

Parameters
  • primitive – The JAX primitive being handled.

  • x – The array to apply exponential to.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

The result of lax.exp operation.

eformer.ops.quantization.implicit_array_nf4.gather_nf4_operand(primitive: Primitive, operand: ArrayNF4, *args: Any, **kwargs: Any) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#

Custom handler for JAX’s gather operation.

Materializes ArrayNF4 input before performing index-based gathering. Returns a regular array (not re-quantized) since gather typically selects arbitrary elements.

Parameters
  • primitive (Primitive) – The JAX primitive being handled.

  • operand (ArrayNF4) – The source array to gather from.

  • *args – Positional arguments including start_indices.

  • **kwargs – Keyword arguments for the gather operation.

Returns

The gathered values as a regular JAX array.

Return type

ArrayType

eformer.ops.quantization.implicit_array_nf4.integer_pow_nf4_x(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, **kwargs) Any[source]#
eformer.ops.quantization.implicit_array_nf4.integer_pow_nf4_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any) Any[source]#
eformer.ops.quantization.implicit_array_nf4.log_nf4_x(primitive: Primitive, x: ArrayNF4, **kwargs: Any) Array[source]#

Custom handler for JAX’s log operation.

Computes the natural logarithm of an ArrayNF4 input and returns a regular array.

Parameters
  • primitive – The JAX primitive being handled.

  • x – The array to apply logarithm to. (Must be ArrayNF4 for this registration)

  • **kwargs – Additional keyword arguments for the log operation.

Returns

The result of the natural logarithm operation.

Raises

RuntimeError – If the log operation fails.

eformer.ops.quantization.implicit_array_nf4.lt_nf4_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, **kwargs)[source]#
eformer.ops.quantization.implicit_array_nf4.max_nf4_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, *args: Any, **kwargs: Any) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#

Custom handler for JAX’s max operation.

Materializes ArrayNF4 inputs before performing the operation.

Parameters
  • primitive – The JAX primitive being handled.

  • x – First array for max comparison.

  • y – Second array for max comparison.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

The result of lax.max operation.

eformer.ops.quantization.implicit_array_nf4.mul_nf4_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#

Custom handler for JAX’s mul operation.

Materializes ArrayNF4 inputs before performing the operation.

Parameters
  • primitive – The JAX primitive being handled.

  • x – First array to multiply.

  • y – Second array to multiply.

Returns

The result of lax.mul operation.

eformer.ops.quantization.implicit_array_nf4.reduce_nf4_operand_init_value(primitive: Primitive, operand: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, init_value: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, *args: Any, **kwargs: Any) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#

Custom handler for JAX’s reduce operation.

Materializes ArrayNF4 inputs before performing the operation.

Parameters
  • primitive – The JAX primitive being handled.

  • operand – The array to be reduced.

  • init_value – The initial value for the reduction.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

The result of lax.reduce operation.

eformer.ops.quantization.implicit_array_nf4.reshape_nf4_operand(primitive: Primitive, operand: ArrayNF4, *args, **params) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#

Custom handler for JAX’s reshape operation.

This function handles reshaping for ArrayNF4 quantized arrays. It materializes ArrayNF4 input before reshaping and re-quantizes the result. Preserves sharding from the original array.

Parameters
  • primitive – The JAX primitive being handled.

  • operand – The ArrayNF4 array to be reshaped.

  • *args – Positional arguments for reshape (e.g., new_sizes, dimensions).

  • **params – Keyword arguments/parameters for reshape.

Returns

The reshaped array, re-quantized as ArrayNF4 with sharding preserved.

Raises

ValueError – If the new shape is not compatible with the original array’s size.

eformer.ops.quantization.implicit_array_nf4.safe_delete(arr: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, materialized: bool) None[source]#

Placeholder for safe array deletion after materialization.

This function is provided for API completeness but currently does nothing. JAX arrays are garbage collected automatically.

Parameters
  • arr (ArrayType) – The array to potentially delete.

  • materialized (bool) – Whether the array was materialized.

eformer.ops.quantization.implicit_array_nf4.safe_materialize(arr: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any) tuple[jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, bool][source]#

Safely materialize an array if it’s an ArrayNF4 quantized array.

This helper function handles the common pattern of conditionally materializing quantized arrays for operations that don’t support implicit arrays.

Parameters

arr (ArrayType) – Input that may be ArrayNF4 or a regular array.

Returns

A tuple containing:
  • The materialized array (or original if not ArrayNF4)

  • Boolean flag indicating if materialization occurred

Return type

tuple[ArrayType, bool]
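The conditional-materialization pattern, and how a fallback handler might use it, can be sketched in plain Python. Every name here (`FakeQuantized`, `safe_materialize_sketch`, `add_handler_sketch`) is a stand-in invented for illustration; the real helpers operate on ArrayNF4 and JAX arrays.

```python
class FakeQuantized:
    """Stand-in for a quantized array; only materialize() matters here."""

    def __init__(self, values):
        self._values = list(values)

    def materialize(self):
        # A real implementation would dequantize; this just returns the data.
        return list(self._values)


def safe_materialize_sketch(arr):
    """Return (dense_value, was_materialized), mirroring the documented contract."""
    if isinstance(arr, FakeQuantized):
        return arr.materialize(), True
    return arr, False


def add_handler_sketch(x, y):
    """Fallback handler shape: densify both operands, then apply the plain op."""
    x, _ = safe_materialize_sketch(x)
    y, _ = safe_materialize_sketch(y)
    return [a + b for a, b in zip(x, y)]


result = add_handler_sketch(FakeQuantized([1.0, 2.0]), [10.0, 20.0])
```

The boolean flag lets a caller know whether a dense temporary was created, which is the information safe_delete receives.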

eformer.ops.quantization.implicit_array_nf4.sqrt_nf4_x(primitive: Primitive, x: ArrayNF4) Any[source]#
eformer.ops.quantization.implicit_array_nf4.transpose_nf4_operand(primitive: Primitive, operand: ArrayNF4, *args: Any, **kwargs: Any) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#

Custom handler for JAX’s transpose operation.

Materializes ArrayNF4 input before performing the operation. Re-quantizes the result if the input was ArrayNF4. Preserves sharding from the original array.

Parameters
  • primitive – The JAX primitive being handled.

  • operand – The array to be transposed.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

The result of lax.transpose operation, potentially re-quantized with sharding preserved.