eformer.ops.quantization.implicit_array_1bit
Quantization Module
This module provides functionality for quantizing and dequantizing arrays using 1-bit quantization (Array1B).
The class is designed to reduce memory usage and computational overhead while maintaining reasonable accuracy for machine learning models. It is built on top of JAX, a high-performance numerical computing library.
- Classes:
Array1B: Implements 1-bit quantization for arrays.
- Usage Example:
```python
import jax
from eformer.ops.quantization import Array1B, ArrayNF4
from eformer.jaximus import implicit

array = jax.random.normal(jax.random.key(0), (256, 64), "f2")

qarray = Array1B(array)
n4array = ArrayNF4(array)

def power(x):
    return x**2

print(jax.jit(implicit(power))(qarray))
print(qarray)
```
- class eformer.ops.quantization.implicit_array_1bit.Array1B(weight: Array, *, shape: Optional[Sequence[int]] = None, dtype: dtype = None)
Bases: ImplicitArray
1-bit Quantization Class
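This class implements 1-bit quantization for arrays. It stores the input array as packed low-bit integer codes; the quantize classmethod packs ternary or binary values at 2 bits per value, 4 values per byte.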
- __init__(self, array: jax.Array): Initializes the Array1B object by quantizing the input array.
- commute_ops: ClassVar[bool] = True
- delete()
Delete the underlying weight array to free memory.
Note
This method attempts to delete both weight and scale attributes, but Array1B only has weight. The scale.delete() call may raise an AttributeError as Array1B doesn’t have a scale attribute.
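A defensive variant of this cleanup could skip attributes the object does not have, avoiding the AttributeError described in the note. The sketch below is hypothetical: `SafeDeleteMixin`, `safe_delete`, and `Toy1Bit` are illustrative names, not part of eformer. It relies only on the standard `jax.Array.delete()` and `jax.Array.is_deleted()` methods.

```python
import jax
import jax.numpy as jnp


class SafeDeleteMixin:
    """Hypothetical mixin: delete only the buffers that actually exist."""

    _buffer_names = ("weight", "scale")

    def safe_delete(self):
        for name in self._buffer_names:
            # getattr with a default avoids AttributeError for a missing `scale`.
            buf = getattr(self, name, None)
            if isinstance(buf, jax.Array) and not buf.is_deleted():
                buf.delete()


class Toy1Bit(SafeDeleteMixin):
    """Hypothetical Array1B-like object that only has a `weight` buffer."""

    def __init__(self, weight):
        self.weight = weight


toy = Toy1Bit(jnp.array([1, 0, -1], dtype=jnp.int8))
toy.safe_delete()  # no AttributeError even though `scale` is missing
```

Calling `safe_delete()` a second time is also harmless, because already-deleted buffers are skipped.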
- materialize()
Reconstructs the original array from the quantized data.
- Returns
The dequantized array.
- Return type
- classmethod quantize(array: Array, dtype: numpy.dtype | None = None, axis: int | tuple[int] | None = None)
Create an Array1B by quantizing the input array.
Packs the input array containing values {-1, 0, 1} into a compact 2-bit per value format, storing 4 values per byte.
- Parameters
array (jax.Array) – The input array to be quantized. Should contain integer values in {-1, 0, 1} for ternary or {-1, 1} for binary. The first dimension must be a multiple of 4.
dtype (jnp.dtype | None) – The dtype to use for materialization. Defaults to the input array’s dtype.
axis (int | tuple[int] | None) – The quantization axis. Defaults to -1 (last axis).
- Returns
A new Array1B instance containing the packed weights.
- Return type
- Raises
ValueError – If the first dimension is not a multiple of 4.
Example
>>> import jax.numpy as jnp
>>> from eformer.ops.quantization import Array1B
>>> # Create ternary weights; the first dimension (4) is a multiple of 4
>>> weights = jnp.array([[-1, 0, 1, 1], [1, -1, 0, 0], [0, 1, -1, 1], [1, 1, 0, -1]])
>>> quantized = Array1B.quantize(weights.astype(jnp.int8))
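To make the "2 bits per value, 4 values per byte" layout concrete, here is a minimal packing and unpacking sketch in plain JAX. It is not eformer's implementation: the code table (-1 → 0b00, 0 → 0b01, 1 → 0b10), the choice of packing along axis 0, and the names `pack_ternary` and `unpack_ternary` are all assumptions made for illustration.

```python
import jax.numpy as jnp


def pack_ternary(w):
    """Pack integers in {-1, 0, 1} into uint8, four 2-bit codes per byte along axis 0.

    Assumed (hypothetical) code table: -1 -> 0b00, 0 -> 0b01, 1 -> 0b10.
    """
    if w.shape[0] % 4 != 0:
        raise ValueError("the first dimension must be a multiple of 4")
    codes = (w.astype(jnp.int8) + 1).astype(jnp.uint8)          # {-1, 0, 1} -> {0, 1, 2}
    codes = codes.reshape((w.shape[0] // 4, 4) + w.shape[1:])   # group 4 rows per byte
    return codes[:, 0] | (codes[:, 1] << 2) | (codes[:, 2] << 4) | (codes[:, 3] << 6)


def unpack_ternary(packed, dtype=jnp.float32):
    """Invert pack_ternary: extract the four 2-bit codes and map them back to {-1, 0, 1}."""
    shifts = jnp.array([0, 2, 4, 6], dtype=jnp.uint8).reshape((1, 4) + (1,) * (packed.ndim - 1))
    codes = (packed[:, None] >> shifts) & 0b11                  # shape (rows // 4, 4, ...)
    codes = codes.reshape((packed.shape[0] * 4,) + packed.shape[1:])
    return codes.astype(dtype) - 1


weights = jnp.array([[-1, 0, 1, 1], [1, -1, 0, 0], [0, 1, -1, 1], [1, 1, 0, -1]], dtype=jnp.int8)
packed = pack_ternary(weights)
print(packed.shape, packed.dtype)  # (1, 4) uint8: 16 values stored in 4 bytes
restored = unpack_ternary(packed)
```

Packing along the first axis is what makes the "multiple of 4" requirement natural in this sketch: each output byte consumes four consecutive rows.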
- warn_on_materialize: ClassVar[bool] = True
- weight: Array
- eformer.ops.quantization.implicit_array_1bit.add_1bit_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool)
Custom handler for JAX’s add operation.
Materializes Array1B inputs before performing the operation.
- Parameters
x (ArrayType) – First array to add.
y (ArrayType) – Second array to add.
- Returns
The result of lax.add operation.
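Most handlers in this module follow the same "materialize, then apply the lax primitive" pattern. The sketch below illustrates it with a hypothetical `ToyQuantized` class standing in for Array1B (its `materialize` simply casts stored int8 codes); `maybe_materialize` and `add_handler` are illustrative names, not eformer's API.

```python
import jax.numpy as jnp
from jax import lax


class ToyQuantized:
    """Hypothetical stand-in for Array1B: int8 values in {-1, 0, 1} plus a target dtype."""

    def __init__(self, values, dtype=jnp.float32):
        self.values = values
        self.dtype = dtype

    def materialize(self):
        # Dequantize by casting the stored codes to the target dtype.
        return self.values.astype(self.dtype)


def maybe_materialize(x):
    # Quantized operands become dense arrays; anything else passes through unchanged.
    return x.materialize() if isinstance(x, ToyQuantized) else x


def add_handler(x, y):
    # Materialize both operands, then apply the lax primitive to ordinary arrays.
    return lax.add(maybe_materialize(x), maybe_materialize(y))


q = ToyQuantized(jnp.array([-1, 0, 1], dtype=jnp.int8))
dense = jnp.array([0.5, 0.5, 0.5], dtype=jnp.float32)
result = add_handler(q, dense)
```

The trade-off of this design is simplicity over efficiency: the operation itself always runs on full-precision values, so only storage benefits from quantization.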
- eformer.ops.quantization.implicit_array_1bit.broadcast_in_dim_1bit_operand(primitive: Primitive, operand: Array1B, *args, **params) → jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool
Custom handler for JAX’s broadcast_in_dim operation.
Materializes Array1B input, performs broadcasting, and re-quantizes the result.
- Parameters
primitive (Primitive) – The JAX primitive being handled.
operand (Array1B) – The array to broadcast.
*args – Positional arguments for the broadcast operation.
**params – Keyword parameters including shape and broadcast_dimensions.
- Returns
The broadcasted array, re-quantized as Array1B.
- Return type
ArrayType
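The three steps above (materialize, broadcast, re-quantize) can be sketched in plain JAX as follows. The int8 input stands in for a materialized Array1B, and the round-and-clip re-quantization in the hypothetical `requantize_ternary` helper is an assumption, not necessarily what Array1B does.

```python
import jax.numpy as jnp
from jax import lax


def requantize_ternary(x):
    # Hypothetical re-quantization: round to the nearest value in {-1, 0, 1}.
    return jnp.clip(jnp.round(x), -1, 1).astype(jnp.int8)


def broadcast_in_dim_handler(values, shape, broadcast_dimensions, dtype=jnp.float32):
    dense = values.astype(dtype)                                    # 1. materialize
    out = lax.broadcast_in_dim(dense, shape, broadcast_dimensions)  # 2. broadcast
    return requantize_ternary(out)                                  # 3. re-quantize


row = jnp.array([-1, 0, 1, 1], dtype=jnp.int8)
# Map the input's only axis onto output axis 1, repeating the row 4 times.
tiled = broadcast_in_dim_handler(row, shape=(4, 4), broadcast_dimensions=(1,))
```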
- eformer.ops.quantization.implicit_array_1bit.concatenate_1bit_operands(primitive: Primitive, operands: Sequence[jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool], *args, **kwargs)
Custom handler for JAX’s concatenate operation.
Materializes Array1B inputs before performing the operation.
- Parameters
operands (Sequence[ArrayType]) – Sequence of arrays to concatenate.
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns
The result of lax.concatenate operation.
- eformer.ops.quantization.implicit_array_1bit.conv_general_dilated_1bit_lhs_rhs(primitive: Primitive, lhs: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, rhs: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, *args, **kwargs)
Custom handler for JAX’s conv_general_dilated operation.
Materializes Array1B inputs before performing the operation.
- Parameters
lhs (ArrayType) – Left-hand side array (input).
rhs (ArrayType) – Right-hand side array (kernel).
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns
The result of lax.conv operation.
- eformer.ops.quantization.implicit_array_1bit.convert_element_type_1bit_operand_kw(primitive: Primitive, operand: Array1B, **kwargs) → jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool
Custom handler for JAX’s convert_element_type operation (keyword args).
For Array1B, updates the stored dtype without actual conversion. The conversion happens lazily during materialization.
- Parameters
primitive (Primitive) – The JAX primitive being handled.
operand (Array1B) – The array to convert.
**kwargs – Keyword arguments including ‘new_dtype’.
- Returns
The array with updated dtype metadata.
- Return type
ArrayType
- eformer.ops.quantization.implicit_array_1bit.convert_element_type_1bit_operand_pos(primitive: Primitive, operand: Array1B, new_dtype: Any) → jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool
Custom handler for JAX’s convert_element_type operation (positional args).
For Array1B, updates the stored dtype without actual conversion. The conversion happens lazily during materialization.
- Parameters
primitive (Primitive) – The JAX primitive being handled.
operand (Array1B) – The array to convert.
new_dtype (Any) – The target dtype.
- Returns
The array with updated dtype metadata.
- Return type
ArrayType
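Both convert_element_type handlers defer the cast: only the dtype metadata changes, and the stored data is cast when the array is materialized. A minimal sketch of that idea, using a hypothetical frozen dataclass `LazyDtypeArray` rather than eformer's actual class:

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class LazyDtypeArray:
    """Hypothetical quantized container: `data` is stored once, `dtype` is applied on demand."""

    data: jax.Array
    dtype: object

    def materialize(self):
        # The cast happens here, lazily, not at conversion time.
        return self.data.astype(self.dtype)


def convert_element_type_handler(operand, new_dtype):
    # Only the metadata changes; the stored data is reused as-is.
    return dataclasses.replace(operand, dtype=new_dtype)


a = LazyDtypeArray(jnp.array([-1, 0, 1], dtype=jnp.int8), jnp.float32)
b = convert_element_type_handler(a, jnp.bfloat16)
```

Because `b` shares `a`'s stored data, a chain of dtype conversions costs nothing until the final materialization.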
- eformer.ops.quantization.implicit_array_1bit.div_1bit_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool) → jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool
Custom handler for JAX’s division operation.
Materializes Array1B inputs before performing element-wise division.
- Parameters
primitive (Primitive) – The JAX primitive being handled.
x (ArrayType) – Dividend array.
y (ArrayType) – Divisor array.
- Returns
Result of element-wise division x / y.
- Return type
ArrayType
- eformer.ops.quantization.implicit_array_1bit.div_1bit_xy_fallback(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool) → Any
Fallback handler for JAX’s division operation.
Materializes Array1B inputs before performing element-wise division. This is a duplicate registration to handle additional dispatch cases.
- Parameters
primitive (Primitive) – The JAX primitive being handled.
x (ArrayType) – Dividend array.
y (ArrayType) – Divisor array.
- Returns
Result of element-wise division x / y.
- Return type
Any
- eformer.ops.quantization.implicit_array_1bit.dot_general_1bit_lhs_rhs(primitive: Primitive, lhs: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, rhs: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, *args, **kwargs)
Custom handler for JAX’s dot_general operation.
Materializes Array1B inputs before performing the operation.
- Parameters
lhs (ArrayType) – Left-hand side array.
rhs (ArrayType) – Right-hand side array.
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns
The result of lax.dot_general operation.
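For dot_general, materializing first means the result should equal an ordinary dense matrix product. The sketch below checks this with a hypothetical `Ternary` stand-in for Array1B; `dimension_numbers` uses the standard `lax.dot_general` format of ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)).

```python
import jax.numpy as jnp
from jax import lax


class Ternary:
    """Hypothetical stand-in for Array1B holding int8 values in {-1, 0, 1}."""

    def __init__(self, values, dtype=jnp.float32):
        self.values = values
        self.dtype = dtype

    def materialize(self):
        return self.values.astype(self.dtype)


def _dense(x):
    # Materialize quantized operands; pass dense arrays through.
    return x.materialize() if isinstance(x, Ternary) else x


def dot_general_handler(lhs, rhs, dimension_numbers):
    return lax.dot_general(_dense(lhs), _dense(rhs), dimension_numbers)


w = jnp.array([[-1, 0, 1, 1], [1, -1, 0, 0], [0, 1, -1, 1], [1, 1, 0, -1]], dtype=jnp.int8)
x = jnp.arange(8, dtype=jnp.float32).reshape(2, 4)
# Contract x's axis 1 with w's axis 0 and use no batch dimensions: equivalent to x @ w.
y = dot_general_handler(x, Ternary(w), dimension_numbers=(((1,), (0,)), ((), ())))
```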
- eformer.ops.quantization.implicit_array_1bit.exp_1bit_x(primitive: Primitive, x: Array1B, *args, **kwargs)
Custom handler for JAX’s exp operation.
Materializes Array1B input before performing the operation.
- Parameters
x (ArrayType) – The array to apply exponential to.
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns
The result of lax.exp operation.
- eformer.ops.quantization.implicit_array_1bit.gather_1bit_operand(primitive: Primitive, operand: Array1B, *args, **kwargs) → jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool
Custom handler for JAX’s gather operation.
Materializes Array1B input before performing index-based gathering.
- Parameters
primitive (Primitive) – The JAX primitive being handled.
operand (Array1B) – The source array to gather from.
*args – Positional arguments including start_indices.
**kwargs – Keyword arguments for the gather operation.
- Returns
The gathered values as a regular JAX array.
- Return type
ArrayType
- eformer.ops.quantization.implicit_array_1bit.integer_pow_1bit_x(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, **kwargs) → jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool
Custom handler for JAX’s integer power operation (keyword args).
Materializes Array1B input before performing the power operation.
- Parameters
primitive (Primitive) – The JAX primitive being handled.
x (ArrayType) – Base array.
**kwargs – Keyword arguments including ‘y’ for exponent (default: 2).
- Returns
Result of x raised to the power y.
- Return type
ArrayType
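A sketch of the keyword-argument variant in plain JAX, assuming an int8 array stands in for a materialized Array1B; `integer_pow_handler` is a hypothetical name. Note that `lax.integer_pow` expects the exponent as a Python int.

```python
import jax.numpy as jnp
from jax import lax


def integer_pow_handler(values, dtype=jnp.float32, **kwargs):
    y = kwargs.get("y", 2)                           # exponent, defaulting to 2 as described above
    return lax.integer_pow(values.astype(dtype), y)  # materialize, then apply the primitive


w = jnp.array([-1, 0, 1], dtype=jnp.int8)
squared = integer_pow_handler(w)     # squares of ternary values lie in {0, 1}
cubed = integer_pow_handler(w, y=3)  # odd powers leave ternary values unchanged
```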
- eformer.ops.quantization.implicit_array_1bit.integer_pow_1bit_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool) → jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool
Custom handler for JAX’s integer power operation (positional args).
Materializes Array1B inputs before performing the power operation.
- Parameters
primitive (Primitive) – The JAX primitive being handled.
x (ArrayType) – Base array.
y (ArrayType) – Exponent array.
- Returns
Result of x raised to the power y.
- Return type
ArrayType
- eformer.ops.quantization.implicit_array_1bit.log_1bit_x(primitive: Primitive, x: Array1B, *args, **kwargs)
Custom handler for JAX’s log operation.
Materializes Array1B input before performing the operation.
- Parameters
x (ArrayType) – The array to apply logarithm to.
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns
The result of lax.log operation.
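Because the packed values described for quantize lie in {-1, 0, 1}, taking the logarithm of a materialized ternary array directly yields non-finite entries. The plain-JAX sketch below illustrates this, with an int8 array standing in for a materialized Array1B.

```python
import jax.numpy as jnp
from jax import lax

w = jnp.array([-1, 0, 1], dtype=jnp.int8)
dense = w.astype(jnp.float32)  # materialize first, as the handler does
logged = lax.log(dense)        # log(-1) -> nan, log(0) -> -inf, log(1) -> 0
```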
- eformer.ops.quantization.implicit_array_1bit.lt_1bit_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, **kwargs)
Custom handler for JAX’s less-than comparison operation.
Materializes Array1B inputs before performing the comparison.
- Parameters
primitive (Primitive) – The JAX primitive being handled.
x (ArrayType) – First operand for comparison.
y (ArrayType) – Second operand for comparison.
**kwargs – Additional keyword arguments for the lt operation.
- Returns
Boolean array with element-wise comparison results.
- Return type
- eformer.ops.quantization.implicit_array_1bit.max_1bit_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, *args, **kwargs)
Custom handler for JAX’s max operation.
Materializes Array1B inputs before performing the operation.
- Parameters
x (ArrayType) – First array for max comparison.
y (ArrayType) – Second array for max comparison.
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns
The result of lax.max operation.
- eformer.ops.quantization.implicit_array_1bit.mul_1bit_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool)
Custom handler for JAX’s mul operation.
Materializes Array1B inputs before performing the operation.
- Parameters
x (ArrayType) – First array to multiply.
y (ArrayType) – Second array to multiply.
- Returns
The result of lax.mul operation.
- eformer.ops.quantization.implicit_array_1bit.reduce_1bit_operand_init_value(primitive: Primitive, operand: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, init_value: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool, *args, **kwargs)
Custom handler for JAX’s reduce operation.
Materializes Array1B inputs before performing the operation.
- Parameters
operand (ArrayType) – The array to be reduced.
init_value (ArrayType) – The initial value for the reduction.
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns
The result of lax.reduce operation.
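A plain-JAX sketch of a reduction over a materialized ternary array, assuming an int8 array stands in for Array1B. With `lax.reduce`, the init value should match the operand's dtype, which is why it is built from `dense.dtype` here.

```python
import jax.numpy as jnp
from jax import lax

w = jnp.array([[-1, 0, 1, 1], [1, -1, 0, 0], [0, 1, -1, 1], [1, 1, 0, -1]], dtype=jnp.int8)
dense = w.astype(jnp.float32)                      # materialize the stand-in first
init = jnp.array(0, dtype=dense.dtype)             # init_value with the operand's dtype
col_sums = lax.reduce(dense, init, lax.add, (0,))  # sum over axis 0
```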
- eformer.ops.quantization.implicit_array_1bit.reshape_1bit_operand(primitive: Primitive, operand: Array1B, *args, **params)
Custom handler for JAX’s reshape operation.
This function handles reshaping for both regular arrays and Array1B quantized arrays. It materializes Array1B input before reshaping and re-quantizes the result if the input was Array1B.
- Parameters
primitive (Primitive) – The JAX primitive being handled.
operand (ArrayType) – The array to be reshaped.
new_sizes (Tuple[int, ...]) – The desired new shape of the array.
dimensions (Tuple[int, ...], optional) – The order in which dimensions should be permuted before reshaping.
**kwargs – Additional keyword arguments for the reshape operation.
- Returns
The reshaped array, potentially re-quantized if the input was Array1B.
- Return type
ArrayType
- Raises
ValueError – If the new shape is not compatible with the original array’s size.
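A minimal sketch of the reshape path: validate the element count, materialize, reshape, and re-quantize. It ignores the optional `dimensions` permutation for brevity; `reshape_handler` and the round-and-clip re-quantization are hypothetical, and an int8 array stands in for Array1B.

```python
import math

import jax.numpy as jnp
from jax import lax


def reshape_handler(values, new_sizes, dtype=jnp.float32):
    # Reject shapes whose element count differs from the input's.
    if math.prod(new_sizes) != values.size:
        raise ValueError(f"cannot reshape array of size {values.size} into shape {tuple(new_sizes)}")
    dense = values.astype(dtype)                                # materialize
    out = lax.reshape(dense, tuple(new_sizes))                  # reshape the dense values
    return jnp.clip(jnp.round(out), -1, 1).astype(jnp.int8)     # hypothetical re-quantization


w = jnp.array([-1, 0, 1, 1, 1, -1, 0, 0], dtype=jnp.int8)
r = reshape_handler(w, (4, 2))
```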
- eformer.ops.quantization.implicit_array_1bit.sqrt_1bit_x(primitive: Primitive, x: Array1B) → jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_1bit.Array1B | int | float | bool
Custom handler for JAX’s square root operation.
Materializes Array1B input before computing element-wise square root.
- Parameters
primitive (Primitive) – The JAX primitive being handled.
x (Array1B) – Input array.
- Returns
Element-wise square root of the input.
- Return type
ArrayType
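Since ternary weights include -1, the element-wise square root of a materialized array contains NaN wherever the input is negative. A plain-JAX illustration, with an int8 array standing in for a materialized Array1B:

```python
import jax.numpy as jnp
from jax import lax

w = jnp.array([-1, 0, 1], dtype=jnp.int8)  # stand-in for materialized ternary weights
roots = lax.sqrt(w.astype(jnp.float32))    # sqrt(-1) is nan for real floating-point inputs
```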
- eformer.ops.quantization.implicit_array_1bit.transpose_1bit_operand(primitive: Primitive, operand: Array1B, *args, **kwargs)
Custom handler for JAX’s transpose operation.
Materializes Array1B input before performing the operation. Re-quantizes the result if the input was Array1B.
- Parameters
operand (ArrayType) – The array to be transposed.
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns
The result of lax.transpose operation, potentially re-quantized.
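If re-quantization after a transpose is subject to the same first-dimension constraint documented for quantize, a transposed result may not be re-packable even though the input was. The sketch below is hypothetical (the `transpose_handler` name, the round-and-clip re-quantization, and the `repackable` flag are all illustrative) and simply reports whether the new first dimension is still a multiple of 4.

```python
import jax.numpy as jnp
from jax import lax


def transpose_handler(values, permutation, dtype=jnp.float32):
    dense = lax.transpose(values.astype(dtype), permutation)           # materialize, then transpose
    requantized = jnp.clip(jnp.round(dense), -1, 1).astype(jnp.int8)  # hypothetical re-quantization
    # Under the documented packing constraint, the result can only be
    # re-packed when its new first dimension is a multiple of 4.
    repackable = dense.shape[0] % 4 == 0
    return requantized, repackable


w = jnp.ones((4, 6), dtype=jnp.int8)
t, ok = transpose_handler(w, (1, 0))  # t has shape (6, 4); 6 is not a multiple of 4
```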