eformer.ops.quantization.implicit_array_nf4#
4-bit NormalFloat (NF4) Quantization Module.
This module provides NF4 quantization for neural network weights. Storing each weight as a 4-bit code gives close to an 8x memory reduction compared to float32 (slightly less once the per-block absmax scales are counted) while maintaining good accuracy for weights that are approximately Gaussian.
The NF4 format uses a 16-value codebook optimized for normally distributed data, making it particularly effective for transformer model weights. This module includes:
- ArrayNF4: Implicit array class for NF4 quantized weights
- TPU kernel support: Optimized Pallas kernels for direct matrix operations
- Sharding support: Distributed computation with JAX sharding
- JAX primitive handlers: Transparent integration with JAX operations
Key Features:
- Block-wise quantization with configurable block sizes
- Automatic sharding preservation across operations
- TPU kernel dispatch for efficient matrix multiplication
- Fallback to materialization for unsupported operations
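The memory reduction quoted above can be checked with a little arithmetic. The sketch below assumes one float32 absmax scale per block and two 4-bit codes per byte; the storage dtypes actually used by ArrayNF4 are not stated on this page, so the function names and numbers are illustrative only.

```python
def nf4_bits_per_weight(block_size: int, scale_bits: int = 32) -> float:
    """Effective bits per weight: a 4-bit code plus the amortized block scale."""
    return 4 + scale_bits / block_size


def compression_vs_float32(block_size: int, scale_bits: int = 32) -> float:
    """Ratio of float32 storage to NF4 storage for the same number of weights."""
    return 32 / nf4_bits_per_weight(block_size, scale_bits)


for bs in (32, 64, 128):
    # Larger blocks amortize the scale better and approach the ideal 8x.
    print(bs, nf4_bits_per_weight(bs), round(compression_vs_float32(bs), 2))
```

Under these assumptions the default block size of 64 costs 4.5 bits per weight, so the ratio is a little above 7x rather than exactly 8x.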
Example
>>> import jax.numpy as jnp
>>> from eformer.ops.quantization import ArrayNF4
>>>
>>> # Quantize weights
>>> weights = jnp.ones((128, 256), dtype=jnp.float32)
>>> quantized = ArrayNF4.quantize(weights, block_size=64)
>>>
>>> # Use transparently in matrix operations
>>> inputs = jnp.ones((32, 128))
>>> # On TPU: uses kernel, otherwise materializes
>>> output = inputs @ quantized # Works via registered primitives
- class eformer.ops.quantization.implicit_array_nf4.ArrayNF4(packed: Array, absmax: Array, block_size: int, sharding: ShardingType = None, *, shape: tp.Sequence[int] | None = None, dtype: jnp.dtype = None)[source]#
Bases: ImplicitArray
4-bit NormalFloat Quantization Class
This class implements 4-bit NormalFloat (NF4) quantization for arrays. It maps each element of the input array to a 4-bit codebook index, packs the indices, and stores the absolute maximum value (absmax) of each block as that block's scale. An approximation of the original array can be reconstructed from the stored packed data and absmax values.
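The block-wise scheme described above can be sketched in plain NumPy. This is a reference sketch and not the ArrayNF4 implementation: the codebook values are the NF4 levels as published with QLoRA / bitsandbytes, rounded to four decimals, and the nibble order, the flattened block layout, and the helper names `quantize_nf4` and `dequantize_nf4` are assumptions made for illustration.

```python
import numpy as np

# Rounded NF4 levels (assumption: eformer's internal table may differ).
NF4_CODEBOOK = np.array([
    -1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
    0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0,
], dtype=np.float32)


def quantize_nf4(x: np.ndarray, block_size: int = 64):
    """Return (packed uint8 codes, per-block absmax); x.size must divide evenly."""
    blocks = x.astype(np.float32).reshape(-1, block_size)
    absmax = np.abs(blocks).max(axis=1, keepdims=True)
    scaled = blocks / np.where(absmax == 0, 1.0, absmax)  # values in [-1, 1]
    # Nearest codebook entry for every element gives an index in [0, 15].
    codes = np.abs(scaled[..., None] - NF4_CODEBOOK).argmin(axis=-1)
    flat = codes.astype(np.uint8).reshape(-1)
    packed = (flat[0::2] << 4) | flat[1::2]  # two 4-bit codes per byte
    return packed, absmax.squeeze(1)


def dequantize_nf4(packed, absmax, shape, block_size: int = 64):
    """Unpack the codes, look them up in the codebook, and rescale per block."""
    codes = np.stack([packed >> 4, packed & 0x0F], axis=-1)
    codes = codes.reshape(-1, block_size)
    return (NF4_CODEBOOK[codes] * absmax[:, None]).reshape(shape)


rng = np.random.default_rng(0)
w = rng.standard_normal((128, 256)).astype(np.float32)
packed, absmax = quantize_nf4(w)
w_hat = dequantize_nf4(packed, absmax, w.shape)
print(packed.nbytes + absmax.nbytes, "bytes instead of", w.nbytes)
print("max abs error:", float(np.abs(w - w_hat).max()))
```

The reconstruction error of each element is bounded by half the widest codebook gap times that block's absmax, which is why smaller blocks tend to be more accurate at the cost of storing more scales.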
- block_size#
The size of each quantization block (static).
- Type
int
- sharding#
The sharding specification to preserve across operations (static).
- Type
ShardingType
- absmax: Array#
- block_size: int#
- commute_ops: tp.ClassVar[bool] = True#
- delete()[source]#
Delete the underlying packed and absmax arrays to free memory.
Explicitly releases the memory held by the quantized representation. Useful for manual memory management in memory-constrained environments.
- property is_sharded: bool#
Returns True if this array has sharding information.
- materialize() Array[source]#
Reconstructs a dense (dequantized) approximation of the original array from the quantized data.
- Returns
The dequantized array with sharding constraint applied if available.
- Return type
Array
- packed: Array#
- classmethod quantize(array: Array, block_size: int = 64) ArrayNF4[source]#
Creates an ArrayNF4 by quantizing the input array in blocks of block_size elements.
- reshard(sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None) ArrayNF4[source]#
Alias for with_sharding for API consistency.
- sharding: ShardingType = None#
- warn_on_materialize: tp.ClassVar[bool] = True#
- with_mesh_and_axis(mesh_and_axis: tuple[jax._src.mesh.Mesh, jax.sharding.PartitionSpec | tuple[Any, ...] | None]) ArrayNF4[source]#
Apply sharding using a mesh and axis specification tuple.
Convenience method that creates a NamedSharding from a mesh and axis specification, commonly used with model parallelism utilities.
- Parameters
mesh_and_axis – A tuple of (Mesh, axis_spec), where axis_spec can be:
- None: Replicated across all devices
- PartitionSpec: Use directly
- tuple/list: Convert to PartitionSpec
- str: Single axis name
- Returns
New instance with the specified sharding applied.
- Return type
ArrayNF4
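The axis_spec cases accepted by with_mesh_and_axis can be illustrated with public JAX sharding types. The helper `to_named_sharding` below is hypothetical and only mirrors the documented cases; it is not the code used inside ArrayNF4, and the one-axis mesh named "dp" is an assumption chosen so the example runs on any device count.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def to_named_sharding(mesh_and_axis):
    """Hypothetical helper: turn (Mesh, axis_spec) into a NamedSharding."""
    mesh, axis_spec = mesh_and_axis
    if axis_spec is None:
        spec = PartitionSpec()            # replicated across all devices
    elif isinstance(axis_spec, PartitionSpec):
        spec = axis_spec                  # used directly
    elif isinstance(axis_spec, str):
        spec = PartitionSpec(axis_spec)   # single axis name
    else:
        spec = PartitionSpec(*axis_spec)  # tuple or list of axis names
    return NamedSharding(mesh, spec)


# A 1-D mesh over whatever devices are available (a single CPU also works).
mesh = Mesh(np.array(jax.devices()), ("dp",))
sharding = to_named_sharding((mesh, ("dp", None)))
print(sharding.spec)
```

With a quantized array in hand, the documented call would then be `quantized.with_mesh_and_axis((mesh, ("dp", None)))`.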
- with_sharding(sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None) ArrayNF4[source]#
Returns a new ArrayNF4 with the specified sharding applied to component arrays.
This method creates a copy of the quantized array with sharding constraints applied to the underlying packed and absmax arrays, ensuring they are properly distributed across devices.
- Parameters
sharding – A NamedSharding, PartitionSpec, or None. If PartitionSpec is provided, it will be used directly. For NamedSharding, both the mesh and spec are preserved.
- Returns
A new instance with sharding applied to component arrays.
- Return type
ArrayNF4
- eformer.ops.quantization.implicit_array_nf4.add_nf4_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#
Custom handler for JAX’s add operation.
Materializes ArrayNF4 inputs before performing the operation.
- Parameters
primitive – The JAX primitive being handled.
x – First array to add.
y – Second array to add.
- Returns
The result of lax.add operation.
- eformer.ops.quantization.implicit_array_nf4.broadcast_in_dim_nf4_operand(primitive: Primitive, operand: ArrayNF4, *args, **params) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#
Custom handler for JAX’s broadcast_in_dim operation.
Materializes ArrayNF4 input, performs broadcasting, and re-quantizes the result. Preserves sharding from the original array.
- Parameters
primitive (Primitive) – The JAX primitive being handled.
operand (ArrayNF4) – The array to broadcast.
*args – Positional arguments for the broadcast operation.
**params – Keyword parameters including shape and broadcast_dimensions.
- Returns
The broadcasted array, re-quantized as ArrayNF4 with sharding preserved.
- Return type
ArrayType
- eformer.ops.quantization.implicit_array_nf4.concatenate_nf4_operands(primitive: Primitive, operands: Sequence[jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any], *args: Any, **kwargs: Any) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#
Custom handler for JAX’s concatenate operation.
Materializes ArrayNF4 inputs before performing the operation.
- Parameters
primitive – The JAX primitive being handled.
operands – Sequence of arrays to concatenate.
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns
The result of lax.concatenate operation.
- eformer.ops.quantization.implicit_array_nf4.conv_general_dilated_nf4_lhs_rhs(primitive: Primitive, lhs: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, rhs: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, *args: Any, **kwargs: Any) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#
Custom handler for JAX’s conv_general_dilated operation.
Materializes ArrayNF4 inputs before performing the operation.
- Parameters
primitive – The JAX primitive being handled.
lhs – Left-hand side array (input).
rhs – Right-hand side array (kernel).
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns
The result of lax.conv_general_dilated operation.
- eformer.ops.quantization.implicit_array_nf4.convert_element_type_nf4_operand_kw(primitive: Primitive, operand: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, **kwargs) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#
- eformer.ops.quantization.implicit_array_nf4.convert_element_type_nf4_operand_pos(primitive: Primitive, operand: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, new_dtype: Any) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#
- eformer.ops.quantization.implicit_array_nf4.div_nf4_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any) Any[source]#
- eformer.ops.quantization.implicit_array_nf4.dot_general_nf4_lhs_rhs(primitive: Primitive, lhs: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, rhs: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, *args: Any, **kwargs: Any) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#
Custom handler for JAX’s dot_general operation.
Supports both kernel-based and materialization-based execution.
- Parameters
primitive – The JAX primitive being handled.
lhs – Left-hand side array.
rhs – Right-hand side array.
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns
The result of lax.dot_general operation.
- eformer.ops.quantization.implicit_array_nf4.exp_nf4_x(primitive: Primitive, x: ArrayNF4, *args: Any, **kwargs: Any) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#
Custom handler for JAX’s exp operation.
Materializes ArrayNF4 input before performing the operation.
- Parameters
primitive – The JAX primitive being handled.
x – The array to apply exponential to.
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns
The result of lax.exp operation.
- eformer.ops.quantization.implicit_array_nf4.gather_nf4_operand(primitive: Primitive, operand: ArrayNF4, *args: Any, **kwargs: Any) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#
Custom handler for JAX’s gather operation.
Materializes ArrayNF4 input before performing index-based gathering. Returns a regular array (not re-quantized) since gather typically selects arbitrary elements.
- Parameters
primitive (Primitive) – The JAX primitive being handled.
operand (ArrayNF4) – The source array to gather from.
*args – Positional arguments including start_indices.
**kwargs – Keyword arguments for the gather operation.
- Returns
The gathered values as a regular JAX array.
- Return type
ArrayType
- eformer.ops.quantization.implicit_array_nf4.integer_pow_nf4_x(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, **kwargs) Any[source]#
- eformer.ops.quantization.implicit_array_nf4.integer_pow_nf4_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any) Any[source]#
- eformer.ops.quantization.implicit_array_nf4.log_nf4_x(primitive: Primitive, x: ArrayNF4, **kwargs: Any) Array[source]#
Custom handler for JAX’s log operation.
This function computes the natural logarithm of the input, handling both regular arrays and ArrayNF4 quantized arrays.
- Parameters
primitive – The JAX primitive being handled.
x – The array to apply the logarithm to (must be an ArrayNF4 for this registration).
**kwargs – Additional keyword arguments for the log operation.
- Returns
The result of the natural logarithm operation.
- Raises
RuntimeError – If the log operation fails.
- eformer.ops.quantization.implicit_array_nf4.lt_nf4_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, **kwargs)[source]#
- eformer.ops.quantization.implicit_array_nf4.max_nf4_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, *args: Any, **kwargs: Any) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#
Custom handler for JAX’s max operation.
Materializes ArrayNF4 inputs before performing the operation.
- Parameters
primitive – The JAX primitive being handled.
x – First array for max comparison.
y – Second array for max comparison.
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns
The result of lax.max operation.
- eformer.ops.quantization.implicit_array_nf4.mul_nf4_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#
Custom handler for JAX’s mul operation.
Materializes ArrayNF4 inputs before performing the operation.
- Parameters
primitive – The JAX primitive being handled.
x – First array to multiply.
y – Second array to multiply.
- Returns
The result of lax.mul operation.
- eformer.ops.quantization.implicit_array_nf4.reduce_nf4_operand_init_value(primitive: Primitive, operand: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, init_value: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, *args: Any, **kwargs: Any) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#
Custom handler for JAX’s reduce operation.
Materializes ArrayNF4 inputs before performing the operation.
- Parameters
primitive – The JAX primitive being handled.
operand – The array to be reduced.
init_value – The initial value for the reduction.
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns
The result of lax.reduce operation.
- eformer.ops.quantization.implicit_array_nf4.reshape_nf4_operand(primitive: Primitive, operand: ArrayNF4, *args, **params) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#
Custom handler for JAX’s reshape operation.
Materializes the ArrayNF4 input, reshapes it, and re-quantizes the result. Preserves sharding from the original array.
- Parameters
primitive – The JAX primitive being handled.
operand – The ArrayNF4 array to be reshaped.
*args – Positional arguments for reshape (e.g., new_sizes, dimensions).
**params – Keyword arguments/parameters for reshape.
- Returns
The reshaped array, re-quantized as ArrayNF4 with sharding preserved.
- Raises
ValueError – If the new shape is not compatible with the original array’s size.
- eformer.ops.quantization.implicit_array_nf4.safe_delete(arr: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, materialized: bool) None[source]#
Placeholder for safe array deletion after materialization.
This function is provided for API completeness but currently does nothing. JAX arrays are garbage collected automatically.
- Parameters
arr (ArrayType) – The array to potentially delete.
materialized (bool) – Whether the array was materialized.
- eformer.ops.quantization.implicit_array_nf4.safe_materialize(arr: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any) tuple[jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any, bool][source]#
Safely materialize an array if it’s an ArrayNF4 quantized array.
This helper function handles the common pattern of conditionally materializing quantized arrays for operations that don’t support implicit arrays.
- Parameters
arr (ArrayType) – Input that may be ArrayNF4 or a regular array.
- Returns
A tuple containing:
- The materialized array (or the original if not ArrayNF4)
- A boolean flag indicating whether materialization occurred
- Return type
tuple[ArrayType, bool]
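The conditional-materialization pattern described above, and the way the materializing handlers (add, mul, max and similar) build on it, can be sketched without JAX. `FakeQuantized`, `safe_materialize_sketch`, and `add_handler_sketch` are hypothetical stand-ins written for this illustration, not names from eformer.

```python
from typing import Any


class FakeQuantized:
    """Hypothetical stand-in for ArrayNF4: anything exposing materialize()."""

    def __init__(self, dense: Any) -> None:
        self._dense = dense

    def materialize(self) -> Any:
        return self._dense


def safe_materialize_sketch(arr: Any) -> tuple[Any, bool]:
    """Return (dense, True) for quantized inputs and (arr, False) otherwise."""
    if isinstance(arr, FakeQuantized):
        return arr.materialize(), True
    return arr, False


def add_handler_sketch(x: Any, y: Any) -> Any:
    """Materialize-then-compute fallback: dequantize both sides, then add."""
    x, _ = safe_materialize_sketch(x)
    y, _ = safe_materialize_sketch(y)
    return x + y


print(add_handler_sketch(FakeQuantized(2.0), 3.0))
```

Returning the flag alongside the array lets a caller decide afterwards whether a temporary dense copy was created and could be released.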
- eformer.ops.quantization.implicit_array_nf4.sqrt_nf4_x(primitive: Primitive, x: ArrayNF4) Any[source]#
- eformer.ops.quantization.implicit_array_nf4.transpose_nf4_operand(primitive: Primitive, operand: ArrayNF4, *args: Any, **kwargs: Any) jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | Any[source]#
Custom handler for JAX’s transpose operation.
Materializes ArrayNF4 input before performing the operation. Re-quantizes the result if the input was ArrayNF4. Preserves sharding from the original array.
- Parameters
primitive – The JAX primitive being handled.
operand – The array to be transposed.
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns
The result of lax.transpose operation, potentially re-quantized with sharding preserved.
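Handlers that materialize, transform, and re-quantize (transpose, reshape, broadcast_in_dim) can regroup elements into different blocks, so the second quantization pass is not necessarily lossless. The NumPy sketch below illustrates this under stated assumptions: blocks are taken over the flattened row-major array, the codebook is the rounded NF4 table published with QLoRA / bitsandbytes, and `roundtrip` is a hypothetical helper, not an eformer function.

```python
import numpy as np

# Rounded NF4 levels (assumption: eformer's internal table may differ).
LEVELS = np.array([
    -1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
    0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0,
], dtype=np.float32)


def roundtrip(x: np.ndarray, block_size: int = 64) -> np.ndarray:
    """Quantize then dequantize, with blocks taken over the flattened array."""
    blocks = x.reshape(-1, block_size)
    absmax = np.abs(blocks).max(axis=1, keepdims=True)
    scaled = blocks / np.where(absmax == 0, 1.0, absmax)
    codes = np.abs(scaled[..., None] - LEVELS).argmin(axis=-1)
    return (LEVELS[codes] * absmax).reshape(x.shape)


rng = np.random.default_rng(0)
w = rng.standard_normal((128, 256)).astype(np.float32)

w_q = roundtrip(w)             # stands in for a materialized ArrayNF4
same_blocks = roundtrip(w_q)   # re-quantize with an unchanged block layout
transposed = roundtrip(w_q.T)  # re-quantize after the blocks are regrouped

print("same layout is lossless:", np.array_equal(same_blocks, w_q))
print("extra error after transpose:", float(np.abs(transposed - w_q.T).max()))
```

If the real block layout follows a different axis, the size of the effect changes, but the general point stands: repeated layout-changing operations on a quantized array may accumulate small additional errors.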