eformer.ops.quantization.implicit_array_1bit

Quantization Module

This module provides functionality for quantizing and dequantizing arrays using 1-bit quantization (Array1B). For comparison, the usage example below also uses ArrayNF4, an NF4 quantizer importable from the same package.

Array1B is designed to reduce memory usage and computational overhead while maintaining reasonable accuracy for machine learning models. It is built on top of JAX, a high-performance numerical computing library.

Classes:
  • Array1B: Implements 1-bit quantization for arrays.

Usage Example:

```python
import jax
from eformer.ops.quantization import Array1B, ArrayNF4
from eformer.jaximus import implicit

array = jax.random.normal(jax.random.key(0), (256, 64), "f2")

qarray = Array1B(array)
n4array = ArrayNF4(array)


def power(x):
    return x**2


print(jax.jit(implicit(power))(qarray))
print(qarray)

print(jax.jit(implicit(power))(n4array))
print(n4array)
```

class eformer.ops.quantization.implicit_array_1bit.Array1B(weight: Array, *, shape: Optional[Sequence[int]] = None, dtype: dtype = None)

Bases: ImplicitArray

1-bit Quantization Class

This class implements 1-bit quantization for arrays. It packs binary ({-1, 1}) or ternary ({-1, 0, 1}) values into a compact integer array; see quantize() for the packing format.

weight

The quantized 1-bit integer array.

Type

jax.Array

__init__(self, array: jax.Array): Initializes the Array1B object by quantizing the input array.


commute_ops: ClassVar[bool] = True

delete()

Delete the underlying weight array to free memory.

Note

This method attempts to delete both weight and scale attributes, but Array1B only has weight. The scale.delete() call may raise an AttributeError as Array1B doesn’t have a scale attribute.
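A defensive delete() can check for each attribute before freeing it, which avoids the AttributeError described in the note. The sketch below uses a hypothetical `PackedHolder` class that is not part of eformer. Only `jax.Array.delete()` is a real JAX API here.

```python
import jax.numpy as jnp


class PackedHolder:
    """Hypothetical container with a `weight` array and no `scale` attribute."""

    def __init__(self, weight):
        self.weight = weight

    def delete(self):
        # Free each buffer only if the attribute exists, so a missing
        # `scale` is skipped instead of raising AttributeError.
        for name in ("weight", "scale"):
            arr = getattr(self, name, None)
            if arr is not None:
                arr.delete()


holder = PackedHolder(jnp.zeros((4, 4), dtype=jnp.uint8))
holder.delete()  # no AttributeError even though `scale` is absent
```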

materialize()

Reconstructs the original array from the quantized data.

Returns

The dequantized array.

Return type

jax.Array

classmethod quantize(array: Array, dtype: numpy.dtype | None = None, axis: int | tuple[int] | None = None)

Create an Array1B by quantizing the input array.

Packs the input array containing values {-1, 0, 1} into a compact 2-bit per value format, storing 4 values per byte.

Parameters
  • array (jax.Array) – The input array to be quantized. Should contain integer values in {-1, 0, 1} for ternary or {-1, 1} for binary. The first dimension must be a multiple of 4.

  • dtype (jnp.dtype | None) – The dtype to use for materialization. Defaults to the input array’s dtype.

  • axis (int | tuple[int] | None) – The quantization axis. Defaults to -1 (last axis).

Returns

A new Array1B instance containing the packed weights.

Return type

Array1B

Raises

ValueError – If the first dimension is not a multiple of 4.

Example

>>> import jax.numpy as jnp
>>> from eformer.ops.quantization import Array1B
>>> # Create ternary weights; 4 rows so the first dimension is a multiple of 4
>>> weights = jnp.array([[-1, 0, 1, 1],
...                      [1, -1, 0, 0],
...                      [0, 1, -1, 1],
...                      [1, 0, 0, -1]])
>>> quantized = Array1B.quantize(weights.astype(jnp.int8))
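The packing described above stores values in {-1, 0, 1} at 2 bits each, 4 values per byte, grouped along the first dimension. It can be sketched in plain `jax.numpy`. The functions `pack_ternary` and `unpack_ternary` are hypothetical. In particular, the +1 offset encoding and the bit order within each byte are assumptions and may not match Array1B's internal layout.

```python
import jax.numpy as jnp


def pack_ternary(x):
    """Hypothetical sketch: pack {-1, 0, 1} values along axis 0, 4 per uint8."""
    if x.shape[0] % 4 != 0:
        raise ValueError("first dimension must be a multiple of 4")
    # Offset {-1, 0, 1} -> {0, 1, 2} so each value fits in 2 unsigned bits.
    codes = (x.astype(jnp.int8) + 1).astype(jnp.uint8)
    # Group every 4 consecutive rows: shape (n // 4, 4, *rest).
    codes = codes.reshape((x.shape[0] // 4, 4) + x.shape[1:])
    # Slot j of each group occupies bits [2j, 2j + 2) of the output byte.
    return (codes[:, 0]
            | (codes[:, 1] << 2)
            | (codes[:, 2] << 4)
            | (codes[:, 3] << 6))


def unpack_ternary(packed, dtype=jnp.float32):
    """Inverse of pack_ternary: recover the {-1, 0, 1} values as `dtype`."""
    # Shift amounts for the 4 slots, placed on a new axis 1.
    shifts = jnp.array([0, 2, 4, 6], dtype=jnp.uint8)
    shifts = shifts.reshape((1, 4) + (1,) * (packed.ndim - 1))
    codes = (packed[:, None] >> shifts) & jnp.uint8(3)
    # Undo the +1 offset, then merge the group and slot axes back together.
    values = codes.astype(jnp.int8) - 1
    return values.reshape((packed.shape[0] * 4,) + packed.shape[1:]).astype(dtype)


weights = jnp.array([[-1, 0, 1, 1],
                     [1, -1, 0, 0],
                     [0, 1, -1, 1],
                     [1, 0, 0, -1]], dtype=jnp.int8)
packed = pack_ternary(weights)  # shape (1, 4), dtype uint8
restored = unpack_ternary(packed, jnp.int8)
```

With this (4, 4) input, 16 bytes of int8 codes become 4 bytes. The offset encoding leaves code 3 unused, which is the cost of storing three states in two bits.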

warn_on_materialize: ClassVar[bool] = True
weight: Array
eformer.ops.quantization.implicit_array_1bit.add_1bit_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool)

Custom handler for JAX’s add operation.

Materializes Array1B inputs before performing the operation.

Parameters
  • x (ArrayType) – First array to add.

  • y (ArrayType) – Second array to add.

Returns

The result of lax.add operation.
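Most handlers in this module follow the same pattern: materialize any Array1B operand to a dense array, then call the ordinary lax operation. The sketch below reproduces that pattern with a hypothetical `DenseOnDemand` stand-in class. It does not use eformer's primitive registration, it omits the `primitive` argument, and its `materialize` simply casts stored int8 codes instead of unpacking bits.

```python
import dataclasses
from typing import Any

import jax.numpy as jnp
from jax import lax


@dataclasses.dataclass
class DenseOnDemand:
    """Hypothetical stand-in for Array1B: int8 codes plus a target dtype."""

    codes: jnp.ndarray
    dtype: Any = jnp.float32

    def materialize(self):
        # A real packed format would unpack bits here.
        return self.codes.astype(self.dtype)


def _dense(v):
    # Materialize quantized operands; pass regular arrays through unchanged.
    return v.materialize() if isinstance(v, DenseOnDemand) else v


def add_handler(x, y):
    # Same shape as add_1bit_xy: materialize both sides, then call lax.add.
    return lax.add(_dense(x), _dense(y))


q = DenseOnDemand(jnp.array([-1, 0, 1], dtype=jnp.int8))
out = add_handler(q, jnp.ones(3, dtype=jnp.float32))
```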

eformer.ops.quantization.implicit_array_1bit.broadcast_in_dim_1bit_operand(primitive: Primitive, operand: Array1B, *args, **params) → jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool

Custom handler for JAX’s broadcast_in_dim operation.

Materializes Array1B input, performs broadcasting, and re-quantizes the result.

Parameters
  • primitive (Primitive) – The JAX primitive being handled.

  • operand (Array1B) – The array to broadcast.

  • *args – Positional arguments for the broadcast operation.

  • **params – Keyword parameters including shape and broadcast_dimensions.

Returns

The broadcasted array, re-quantized as Array1B.

Return type

ArrayType
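Structure-changing operations such as broadcast_in_dim re-quantize after the dense operation so the result stays compact. Below is a minimal sketch. It assumes a hypothetical `TernaryCodes` class whose `quantize` rounds and clips to {-1, 0, 1}; Array1B's real quantizer may differ.

```python
import dataclasses

import jax.numpy as jnp
from jax import lax


@dataclasses.dataclass
class TernaryCodes:
    """Hypothetical quantized container storing values in {-1, 0, 1} as int8."""

    codes: jnp.ndarray

    @classmethod
    def quantize(cls, x):
        # Round and clip to the ternary set (an assumption for illustration).
        return cls(jnp.clip(jnp.round(x), -1, 1).astype(jnp.int8))

    def materialize(self):
        return self.codes.astype(jnp.float32)


def broadcast_in_dim_handler(operand, shape, broadcast_dimensions):
    # 1) materialize, 2) run the dense lax op, 3) re-quantize the result.
    dense = lax.broadcast_in_dim(operand.materialize(), shape, broadcast_dimensions)
    return TernaryCodes.quantize(dense)


row = TernaryCodes(jnp.array([-1, 0, 1], dtype=jnp.int8))
tiled = broadcast_in_dim_handler(row, shape=(2, 3), broadcast_dimensions=(1,))
```

Broadcasting only copies existing values, so here re-quantization recovers exactly the same codes.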

eformer.ops.quantization.implicit_array_1bit.concatenate_1bit_operands(primitive: Primitive, operands: Sequence[jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool], *args, **kwargs)

Custom handler for JAX’s concatenate operation.

Materializes Array1B inputs before performing the operation.

Parameters
  • operands (Sequence[ArrayType]) – Sequence of arrays to concatenate.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

The result of lax.concatenate operation.

eformer.ops.quantization.implicit_array_1bit.conv_general_dilated_1bit_lhs_rhs(primitive: Primitive, lhs: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, rhs: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, *args, **kwargs)

Custom handler for JAX’s conv_general_dilated operation.

Materializes Array1B inputs before performing the operation.

Parameters
  • lhs (ArrayType) – Left-hand side array (input).

  • rhs (ArrayType) – Right-hand side array (kernel).

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

The result of lax.conv operation.

eformer.ops.quantization.implicit_array_1bit.convert_element_type_1bit_operand_kw(primitive: Primitive, operand: Array1B, **kwargs) → jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool

Custom handler for JAX’s convert_element_type operation (keyword args).

For Array1B, updates the stored dtype without actual conversion. The conversion happens lazily during materialization.

Parameters
  • primitive (Primitive) – The JAX primitive being handled.

  • operand (Array1B) – The array to convert.

  • **kwargs – Keyword arguments including ‘new_dtype’.

Returns

The array with updated dtype metadata.

Return type

ArrayType
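Lazy conversion amounts to replacing the dtype metadata and deferring the actual cast until materialization. The sketch below uses a hypothetical frozen dataclass, `LazyDtypeArray`. Its fields are an assumption, not Array1B's actual layout.

```python
import dataclasses
from typing import Any

import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class LazyDtypeArray:
    """Hypothetical quantized container: int8 codes plus a logical dtype."""

    codes: jnp.ndarray
    dtype: Any

    def materialize(self):
        # The cast happens only here, when a dense array is requested.
        return self.codes.astype(self.dtype)


def convert_element_type_handler(operand, new_dtype):
    # No data is touched: return a copy with only the dtype field changed.
    return dataclasses.replace(operand, dtype=new_dtype)


q = LazyDtypeArray(jnp.array([-1, 0, 1], dtype=jnp.int8), jnp.float32)
q16 = convert_element_type_handler(q, jnp.bfloat16)
```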

eformer.ops.quantization.implicit_array_1bit.convert_element_type_1bit_operand_pos(primitive: Primitive, operand: Array1B, new_dtype: Any) → jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool

Custom handler for JAX’s convert_element_type operation (positional args).

For Array1B, updates the stored dtype without actual conversion. The conversion happens lazily during materialization.

Parameters
  • primitive (Primitive) – The JAX primitive being handled.

  • operand (Array1B) – The array to convert.

  • new_dtype (Any) – The target dtype.

Returns

The array with updated dtype metadata.

Return type

ArrayType

eformer.ops.quantization.implicit_array_1bit.div_1bit_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool) → jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool

Custom handler for JAX’s division operation.

Materializes Array1B inputs before performing element-wise division.

Parameters
  • primitive (Primitive) – The JAX primitive being handled.

  • x (ArrayType) – Dividend array.

  • y (ArrayType) – Divisor array.

Returns

Result of element-wise division x / y.

Return type

ArrayType

eformer.ops.quantization.implicit_array_1bit.div_1bit_xy_fallback(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool) → Any

Fallback handler for JAX’s division operation.

Materializes Array1B inputs before performing element-wise division. This is a duplicate registration to handle additional dispatch cases.

Parameters
  • primitive (Primitive) – The JAX primitive being handled.

  • x (ArrayType) – Dividend array.

  • y (ArrayType) – Divisor array.

Returns

Result of element-wise division x / y.

Return type

Any

eformer.ops.quantization.implicit_array_1bit.dot_general_1bit_lhs_rhs(primitive: Primitive, lhs: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, rhs: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, *args, **kwargs)

Custom handler for JAX’s dot_general operation.

Materializes Array1B inputs before performing the operation.

Parameters
  • lhs (ArrayType) – Left-hand side array.

  • rhs (ArrayType) – Right-hand side array.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

The result of lax.dot_general operation.
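In a linear layer the quantized operand is usually the weight. Once it has been materialized, the handler's work reduces to a plain lax.dot_general call. The sketch below shows that dense call for x @ W. The int8 cast stands in for Array1B.materialize() and is an assumption, not the library's unpacking.

```python
import jax.numpy as jnp
from jax import lax

# Activations (batch=2, features=4) and a ternary weight (4 in, 3 out).
x = jnp.arange(8, dtype=jnp.float32).reshape(2, 4)
w_codes = jnp.array([[-1, 0, 1],
                     [1, 1, 0],
                     [0, -1, 1],
                     [1, 0, -1]], dtype=jnp.int8)

# Stand-in for materialization: cast the codes to the activation dtype.
w = w_codes.astype(x.dtype)

# dimension_numbers = ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)).
# Contract x's axis 1 with w's axis 0; there are no batch dimensions.
y = lax.dot_general(x, w, dimension_numbers=(((1,), (0,)), ((), ())))
```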

eformer.ops.quantization.implicit_array_1bit.exp_1bit_x(primitive: Primitive, x: Array1B, *args, **kwargs)

Custom handler for JAX’s exp operation.

Materializes Array1B input before performing the operation.

Parameters
  • x (ArrayType) – The array to apply exponential to.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

The result of lax.exp operation.

eformer.ops.quantization.implicit_array_1bit.gather_1bit_operand(primitive: Primitive, operand: Array1B, *args, **kwargs) → jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool

Custom handler for JAX’s gather operation.

Materializes Array1B input before performing index-based gathering.

Parameters
  • primitive (Primitive) – The JAX primitive being handled.

  • operand (Array1B) – The source array to gather from.

  • *args – Positional arguments including start_indices.

  • **kwargs – Keyword arguments for the gather operation.

Returns

The gathered values as a regular JAX array.

Return type

ArrayType

eformer.ops.quantization.implicit_array_1bit.integer_pow_1bit_x(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, **kwargs) → jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool

Custom handler for JAX’s integer power operation (keyword args).

Materializes Array1B input before performing the power operation.

Parameters
  • primitive (Primitive) – The JAX primitive being handled.

  • x (ArrayType) – Base array.

  • **kwargs – Keyword arguments including ‘y’ for exponent (default: 2).

Returns

Result of x raised to the power y.

Return type

ArrayType
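The `power` function from the module's usage example is one likely way these integer_pow handlers get exercised. With a concrete Python integer exponent, `x ** 2` traces to the `integer_pow` primitive, and the exponent travels in the primitive's `y` parameter. This can be checked on a plain float16 array with `jax.make_jaxpr`, without involving Array1B at all:

```python
import jax
import jax.numpy as jnp


def power(x):
    return x ** 2


closed = jax.make_jaxpr(power)(jnp.ones((2, 2), dtype=jnp.float16))
eqns = closed.jaxpr.eqns
# Collect primitive names and the exponent parameter of integer_pow.
prim_names = [eqn.primitive.name for eqn in eqns]
exponents = [eqn.params["y"] for eqn in eqns if eqn.primitive.name == "integer_pow"]
```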

eformer.ops.quantization.implicit_array_1bit.integer_pow_1bit_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool) → jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool

Custom handler for JAX’s integer power operation (positional args).

Materializes Array1B inputs before performing the power operation.

Parameters
  • primitive (Primitive) – The JAX primitive being handled.

  • x (ArrayType) – Base array.

  • y (ArrayType) – Exponent array.

Returns

Result of x raised to the power y.

Return type

ArrayType

eformer.ops.quantization.implicit_array_1bit.log_1bit_x(primitive: Primitive, x: Array1B, *args, **kwargs)

Custom handler for JAX’s log operation.

Materializes Array1B input before performing the operation.

Parameters
  • x (ArrayType) – The array to apply logarithm to.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

The result of lax.log operation.

eformer.ops.quantization.implicit_array_1bit.lt_1bit_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, **kwargs)

Custom handler for JAX’s less-than comparison operation.

Materializes Array1B inputs before performing the comparison.

Parameters
  • primitive (Primitive) – The JAX primitive being handled.

  • x (ArrayType) – First operand for comparison.

  • y (ArrayType) – Second operand for comparison.

  • **kwargs – Additional keyword arguments for the lt operation.

Returns

Boolean array with element-wise comparison results.

Return type

jax.Array

eformer.ops.quantization.implicit_array_1bit.max_1bit_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, *args, **kwargs)

Custom handler for JAX’s max operation.

Materializes Array1B inputs before performing the operation.

Parameters
  • x (ArrayType) – First array for max comparison.

  • y (ArrayType) – Second array for max comparison.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

The result of lax.max operation.

eformer.ops.quantization.implicit_array_1bit.mul_1bit_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool)

Custom handler for JAX’s mul operation.

Materializes Array1B inputs before performing the operation.

Parameters
  • x (ArrayType) – First array to multiply.

  • y (ArrayType) – Second array to multiply.

Returns

The result of lax.mul operation.

eformer.ops.quantization.implicit_array_1bit.reduce_1bit_operand_init_value(primitive: Primitive, operand: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, init_value: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, *args, **kwargs)

Custom handler for JAX’s reduce operation.

Materializes Array1B inputs before performing the operation.

Parameters
  • operand (ArrayType) – The array to be reduced.

  • init_value (ArrayType) – The initial value for the reduction.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

The result of lax.reduce operation.

eformer.ops.quantization.implicit_array_1bit.reshape_1bit_operand(primitive: Primitive, operand: Array1B, *args, **params)

Custom handler for JAX’s reshape operation.

This function handles reshaping for both regular arrays and Array1B quantized arrays. It materializes an Array1B input before reshaping and re-quantizes the result if the input was Array1B.

Parameters
  • primitive (Primitive) – The JAX primitive being handled.

  • operand (ArrayType) – The array to be reshaped.

  • new_sizes (Tuple[int, ...]) – The desired new shape of the array.

  • dimensions (Tuple[int, ...], optional) – The order in which dimensions should be permuted before reshaping.

  • **params – Additional keyword parameters for the reshape operation.

Returns

The reshaped array, potentially re-quantized if the input was Array1B.

Return type

ArrayType

Raises

ValueError – If the new shape is not compatible with the original array’s size.
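The `dimensions` parameter listed above follows lax.reshape: the operand's axes are permuted first, and then the result is reshaped. The dense step a handler performs after materialization can be shown on a plain array. The 2x3 input is an arbitrary example.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([[1, 2, 3],
               [4, 5, 6]], dtype=jnp.float32)

# Plain row-major reshape to (3, 2).
flat = lax.reshape(x, (3, 2))

# With dimensions=(1, 0) the axes are transposed before reshaping,
# so this equals jnp.transpose(x).reshape(3, 2).
permuted = lax.reshape(x, (3, 2), dimensions=(1, 0))
```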

eformer.ops.quantization.implicit_array_1bit.sqrt_1bit_x(primitive: Primitive, x: Array1B) → jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool

Custom handler for JAX’s square root operation.

Materializes Array1B input before computing element-wise square root.

Parameters
  • primitive (Primitive) – The JAX primitive being handled.

  • x (Array1B) – Input array.

Returns

Element-wise square root of the input.

Return type

ArrayType

eformer.ops.quantization.implicit_array_1bit.transpose_1bit_operand(primitive: Primitive, operand: Array1B, *args, **kwargs)

Custom handler for JAX’s transpose operation.

Materializes Array1B input before performing the operation. Re-quantizes the result if the input was Array1B.

Parameters
  • operand (ArrayType) – The array to be transposed.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

The result of lax.transpose operation, potentially re-quantized.