eformer.ops.quantization.implicit_array_8bit

8-bit Integer Quantization Module.

This module provides 8-bit integer quantization for neural network weights. Storing one byte per element instead of four gives approximately a 4x memory reduction compared to float32 (slightly less once the per-axis scales are counted), with minimal accuracy loss for most applications.
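As a rough check of the 4x figure, the sketch below counts bytes for a (128, 256) float32 matrix against an int8 copy plus its scales. It assumes one float32 scale per row (the layout produced by reducing over the last axis); `int8_footprint_bytes` is a hypothetical helper, not part of eformer.

```python
import math

def int8_footprint_bytes(shape, reduced_axis=-1, scale_itemsize=4):
    """Hypothetical helper: bytes for an int8 payload plus one float32 scale
    per slice left after reducing `reduced_axis` (assumed scale layout)."""
    payload = math.prod(shape) * 1            # one byte per int8 element
    scale_shape = list(shape)
    scale_shape[reduced_axis] = 1             # e.g. (128, 256) -> (128, 1)
    return payload + math.prod(scale_shape) * scale_itemsize

shape = (128, 256)
fp32_bytes = math.prod(shape) * 4             # four bytes per float32 element
int8_bytes = int8_footprint_bytes(shape)
print(fp32_bytes, int8_bytes, round(fp32_bytes / int8_bytes, 2))  # 131072 33280 3.94
```

The scale overhead is one float32 per row, so longer rows bring the ratio closer to the full 4x.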

The INT8 format uses per-axis scaling to maintain precision, making it suitable for both inference and quantization-aware training. This module includes:

  • Array8B: Implicit array class for INT8 quantized weights

  • Weight-only quantization: Optimized bf16 @ int8 matmul path

  • Sharding support: Distributed computation with JAX sharding

  • JAX primitive handlers: Transparent integration with JAX operations

Key Features:
  • Per-axis quantization with configurable axis

  • Automatic sharding preservation across operations

  • Optimized bf16 @ int8 weight-only matmul

  • Direct transpose without materialization

Example

>>> import jax.numpy as jnp
>>> from eformer.ops.quantization import Array8B
>>>
>>> # Quantize weights
>>> weights = jnp.ones((128, 256), dtype=jnp.float32)
>>> quantized = Array8B.quantize(weights, axis=(-1,))
>>>
>>> # Use transparently in matrix operations
>>> inputs = jnp.ones((32, 128), dtype=jnp.bfloat16)
>>> output = inputs @ quantized  # Uses optimized bf16 @ int8 path
class eformer.ops.quantization.implicit_array_8bit.Array8B(scale: Array, weight: Array, axis: tuple[int, ...] | None = None, sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None = None, *, shape: Optional[Sequence[int]] = None, dtype: dtype = None)

Bases: ImplicitArray

8-bit Quantization Class

This class implements 8-bit quantization for arrays. It quantizes the input array into 8-bit integers and stores the quantization scale factor alongside them. Because quantization is lossy, dequantizing with the stored scale recovers an approximation of the original array rather than an exact copy.
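One common way to realize this is symmetric absmax quantization: each slice along the quantization axis is divided by its largest magnitude over 127, rounded, and stored as int8. The sketch below assumes that scheme with a float32 scale; the helper names are hypothetical, and Array8B's exact rounding and scale dtype may differ.

```python
import jax.numpy as jnp

def quantize_int8(x, axis=(-1,)):
    """Hypothetical helper: symmetric per-axis int8 quantization (assumed scheme)."""
    # One scale per slice: the absmax over `axis`, mapped onto [-127, 127].
    absmax = jnp.max(jnp.abs(x), axis=axis, keepdims=True)
    scale = jnp.where(absmax == 0, 1.0, absmax / 127.0).astype(jnp.float32)
    q = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
    return q, scale

def dequantize_int8(q, scale, dtype=jnp.float32):
    # Approximate reconstruction: each element is off by at most scale / 2.
    return (q.astype(jnp.float32) * scale).astype(dtype)

x = jnp.linspace(-1.0, 1.0, 128 * 256, dtype=jnp.float32).reshape(128, 256)
q, scale = quantize_int8(x)
x_hat = dequantize_int8(q, scale)
print(q.dtype, scale.shape)  # int8 (128, 1)
```

Under this scheme every reconstructed element lies within half a quantization step (scale / 2) of the original value.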

scale

The scale factor used for quantization.

Type

jax.Array

weight

The quantized 8-bit integer array.

Type

jax.Array

axis

The axis used for quantization (static).

Type

tuple[int, …] | None

sharding

The sharding specification to preserve across operations (static).

Type

ShardingType

quantize(array, dtype, axis)

Creates a quantized Array8B from an input array.

materialize()

Reconstructs the original array from the quantized data.

with_sharding(sharding)

Returns a new Array8B with the specified sharding applied.

axis: tuple[int, ...] | None = None
commute_ops: ClassVar[bool] = True
delete()

Delete the underlying weight and scale arrays to free memory.

Explicitly releases the memory held by the quantized representation. Useful for manual memory management in memory-constrained environments.

property is_sharded: bool

Returns True if this array has sharding information.

materialize() → Array

Reconstructs the original array from the quantized data.

Returns

The dequantized array, with the stored sharding constraint applied if one is set.

Return type

jax.Array

classmethod quantize(array: Array, dtype: numpy.dtype | None = None, axis: int | tuple[int] | None = None) → Array8B

Creates a new Array8B by quantizing the input array.

Parameters
  • array (jax.Array) – The input array to be quantized.

  • dtype (jnp.dtype | None) – The dtype used when materializing. Defaults to the dtype of the input array.

  • axis (int | tuple[int] | None) – The axis for quantization. Defaults to (-1,).

Returns

The quantized array, with the sharding of the input array preserved.

Return type

Array8B
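To make the `axis` parameter concrete, the sketch below assumes that `axis` lists the dimension(s) reduced when computing the absmax scale, so the scale keeps a size-1 entry there. Under that assumption, the default (-1,) on a (K, N) weight gives (K, 1) scales and (0,) gives (1, N) scales, the per-row and per-column cases handled by matmul_bf16_int8_weight_only. `normalize_axis` is a hypothetical helper mirroring the documented defaults.

```python
import jax.numpy as jnp

def normalize_axis(axis):
    # Hypothetical: None -> documented default (-1,), int -> 1-tuple.
    if axis is None:
        return (-1,)
    return (axis,) if isinstance(axis, int) else tuple(axis)

w = jnp.ones((128, 256), dtype=jnp.float32)  # a (K, N) weight

# Assumed convention: the absmax is taken over `axis`, keeping size-1 dims.
per_row = jnp.max(jnp.abs(w), axis=normalize_axis(None), keepdims=True)
per_col = jnp.max(jnp.abs(w), axis=normalize_axis(0), keepdims=True)

print(per_row.shape)  # (128, 1)
print(per_col.shape)  # (1, 256)
```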

reshard(sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None) → Array8B

Alias for with_sharding for API consistency.

scale: Array
sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None = None
warn_on_materialize: ClassVar[bool] = True
weight: Array
with_mesh_and_axis(mesh_and_axis: tuple[jax._src.mesh.Mesh, jax.sharding.PartitionSpec | tuple[Any, ...] | None]) → Array8B

Apply sharding using a mesh and axis specification tuple.

Convenience method that creates a NamedSharding from a mesh and axis specification, commonly used with model parallelism utilities.

Parameters

mesh_and_axis – A tuple of (Mesh, axis_spec), where axis_spec can be one of:

  • None: replicated across all devices

  • PartitionSpec: used directly

  • tuple/list: converted to a PartitionSpec

  • str: a single mesh-axis name

Returns

New instance with the specified sharding applied.

Return type

Array8B
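The accepted forms of axis_spec can be normalized into a NamedSharding as in the sketch below. `to_named_sharding` is a hypothetical helper written only to illustrate the documented cases, not eformer's code; the mesh uses a single device and an assumed axis name "tp" so the example runs on any machine.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def to_named_sharding(mesh, axis_spec):
    """Hypothetical helper covering the documented axis_spec forms."""
    if axis_spec is None:
        spec = PartitionSpec()            # replicated across all devices
    elif isinstance(axis_spec, PartitionSpec):
        spec = axis_spec                  # used directly
    elif isinstance(axis_spec, str):
        spec = PartitionSpec(axis_spec)   # single mesh-axis name
    else:
        spec = PartitionSpec(*axis_spec)  # tuple/list: one entry per array dim
    return NamedSharding(mesh, spec)

# A one-device mesh so the example runs anywhere; "tp" is an assumed axis name.
mesh = Mesh(np.array(jax.devices()[:1]), ("tp",))
sharding = to_named_sharding(mesh, ["tp", None])
print(sharding.spec)
```

The PartitionSpec check comes before the generic tuple branch because older JAX releases implement PartitionSpec as a tuple subclass.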

with_sharding(sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None) → Array8B

Returns a new Array8B with the specified sharding applied to component arrays.

This method creates a copy of the quantized array with sharding constraints applied to the underlying weight and scale arrays, ensuring they are properly distributed across devices.

Parameters

sharding – A NamedSharding, PartitionSpec, or None. If PartitionSpec is provided, it will be used directly. For NamedSharding, both the mesh and spec are preserved.

Returns

A new instance with sharding applied to component arrays.

Return type

Array8B

eformer.ops.quantization.implicit_array_8bit.add_8bit_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B)

Custom handler for JAX’s add operation.

Materializes Array8B inputs before performing the operation.

Parameters
  • x (ArrayType) – First array to add.

  • y (ArrayType) – Second array to add.

Returns

The result of lax.add operation.

eformer.ops.quantization.implicit_array_8bit.broadcast_in_dim_8bit_operand(primitive: Primitive, operand: Array8B, *args, **params) → jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B

Custom handler for JAX’s broadcast_in_dim operation.

Broadcasts both weight and scale arrays directly without materialization, preserving the quantized representation. Updates axis mapping for the new shape and preserves sharding.

Parameters
  • primitive (Primitive) – The JAX primitive being handled.

  • operand (Array8B) – The array to broadcast.

  • *args – Positional arguments for the broadcast operation.

  • **params – Keyword parameters including shape and broadcast_dimensions.

Returns

The broadcasted array with updated shape, axis, and preserved sharding.

Return type

Array8B
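Broadcasting can stay in the quantized domain because dequantization is elementwise: broadcasting the int8 weight and its scales separately and then dequantizing gives the same values as dequantizing first. The sketch below checks this with plain jax.lax calls on stand-in arrays, assuming a (K, N) int8 weight with (K, 1) scales; it is an illustration, not eformer's handler.

```python
import jax.numpy as jnp
from jax import lax

# Stand-ins for the stored components: a (K, N) int8 weight with (K, 1) scales.
q = jnp.arange(-6, 6, dtype=jnp.int8).reshape(3, 4)
scale = jnp.array([[0.5], [1.0], [2.0]], dtype=jnp.float32)

# Broadcast each component to a new leading dimension of size 2.
# The scale keeps its size-1 last dimension, so it still lines up with the weight.
q_b = lax.broadcast_in_dim(q, (2, 3, 4), broadcast_dimensions=(1, 2))
s_b = lax.broadcast_in_dim(scale, (2, 3, 1), broadcast_dimensions=(1, 2))

lazy = q_b.astype(jnp.float32) * s_b   # dequantize after broadcasting
eager = lax.broadcast_in_dim(q.astype(jnp.float32) * scale, (2, 3, 4), (1, 2))
print(lazy.shape, bool(jnp.array_equal(lazy, eager)))  # (2, 3, 4) True
```

Here the size-1 scale dimension moves from position 1 to position 2 of the output, which is the kind of axis bookkeeping the handler has to track.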

eformer.ops.quantization.implicit_array_8bit.concatenate_8bit_operands(primitive: Primitive, operands: Sequence[jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B], *args, **kwargs)

Custom handler for JAX’s concatenate operation.

Materializes Array8B inputs before performing the operation.

Parameters
  • operands (Sequence[ArrayType]) – Sequence of arrays to concatenate.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

The result of lax.concatenate operation.

eformer.ops.quantization.implicit_array_8bit.conv_general_dilated_8bit_lhs_rhs(primitive: Primitive, lhs: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B, rhs: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B, *args, **kwargs)

Custom handler for JAX’s conv_general_dilated operation.

Materializes Array8B inputs before performing the operation.

Parameters
  • lhs (ArrayType) – Left-hand side array (input).

  • rhs (ArrayType) – Right-hand side array (kernel).

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

The result of lax.conv operation.

eformer.ops.quantization.implicit_array_8bit.convert_element_type_8bit_operand_kw(primitive: Primitive, operand: Array8B, **kwargs) → jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B

Custom handler for JAX’s convert_element_type operation (keyword args).

For Array8B, updates the stored dtype without actual conversion. The conversion happens lazily during materialization.

Parameters
  • primitive (Primitive) – The JAX primitive being handled.

  • operand (Array8B) – The array to convert.

  • **kwargs – Keyword arguments including ‘new_dtype’.

Returns

The array with updated dtype metadata.

Return type

ArrayType

eformer.ops.quantization.implicit_array_8bit.convert_element_type_8bit_operand_pos(primitive: Primitive, operand: Array8B, new_dtype: Any) → jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B

Custom handler for JAX’s convert_element_type operation (positional args).

For Array8B, updates the stored dtype without actual conversion. The conversion happens lazily during materialization.

Parameters
  • primitive (Primitive) – The JAX primitive being handled.

  • operand (Array8B) – The array to convert.

  • new_dtype (Any) – The target dtype.

Returns

The array with updated dtype metadata.

Return type

ArrayType
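The lazy conversion used by both convert_element_type handlers can be pictured with the toy container below: converting only swaps the dtype field, and the cast happens when the values are materialized. `LazyInt8` is a hypothetical stand-in written for illustration, not the Array8B class.

```python
from dataclasses import dataclass, replace

import jax.numpy as jnp
import numpy as np

@dataclass(frozen=True)
class LazyInt8:
    """Hypothetical stand-in illustrating lazy dtype handling; not Array8B."""
    weight: jnp.ndarray  # int8 payload, never converted
    scale: jnp.ndarray   # float32 scales
    dtype: np.dtype      # dtype reported outward and used at materialization

    def convert_element_type(self, new_dtype):
        # Only the metadata changes; no arithmetic runs here.
        return replace(self, dtype=jnp.dtype(new_dtype))

    def materialize(self):
        # The cast to the stored dtype happens only now.
        return (self.weight.astype(jnp.float32) * self.scale).astype(self.dtype)

x = LazyInt8(
    weight=jnp.array([[1, -2]], dtype=jnp.int8),
    scale=jnp.array([[0.5]], dtype=jnp.float32),
    dtype=jnp.dtype(jnp.float32),
)
y = x.convert_element_type(jnp.bfloat16)
print(y.weight.dtype, y.materialize().dtype)  # int8 bfloat16
```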

eformer.ops.quantization.implicit_array_8bit.div_8bit_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B) → jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B

Custom handler for JAX’s division operation.

Materializes Array8B inputs before performing element-wise division.

Parameters
  • primitive (Primitive) – The JAX primitive being handled.

  • x (ArrayType) – Dividend array.

  • y (ArrayType) – Divisor array.

Returns

Result of element-wise division x / y.

Return type

ArrayType

eformer.ops.quantization.implicit_array_8bit.dot_general_8bit_lhs_rhs(primitive: Primitive, lhs: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B, rhs: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B, *args, **kwargs)

Custom handler for JAX’s dot_general operation with Array8B support.

When the right operand is Array8B and left is a regular bfloat16 array, uses the optimized weight-only matmul path. Otherwise, materializes Array8B inputs before performing the standard operation.

Parameters
  • primitive (Primitive) – The JAX primitive being handled.

  • lhs (ArrayType) – Left-hand side array (activations).

  • rhs (ArrayType) – Right-hand side array (potentially quantized weights).

  • *args – Variable length argument list including dimension_numbers.

  • **kwargs – Arbitrary keyword arguments.

Returns

The result of the dot_general operation.

Return type

jax.Array

Note

The optimized path is triggered only when lhs is a regular bfloat16 Array and rhs is an Array8B. For per-row and per-column scales this avoids fully dequantizing the weights, which makes inference more efficient.

eformer.ops.quantization.implicit_array_8bit.exp_8bit_x(primitive: Primitive, x: Array8B, *args, **kwargs)

Custom handler for JAX’s exp operation.

Materializes Array8B input before performing the operation.

Parameters
  • x (ArrayType) – The array to apply exponential to.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

The result of lax.exp operation.

eformer.ops.quantization.implicit_array_8bit.gather_8bit_operand(primitive: Primitive, operand: Array8B, *args, **kwargs) → jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B

Custom handler for JAX’s gather operation.

Materializes Array8B input before performing index-based gathering. Returns a regular array (not re-quantized) since gather typically selects arbitrary elements.

Parameters
  • primitive (Primitive) – The JAX primitive being handled.

  • operand (Array8B) – The source array to gather from.

  • *args – Positional arguments including start_indices.

  • **kwargs – Keyword arguments for the gather operation.

Returns

The gathered values as a regular JAX array.

Return type

jax.Array

eformer.ops.quantization.implicit_array_8bit.integer_pow_8bit_x(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B, **kwargs) → jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B

Custom handler for JAX’s integer power operation (keyword args).

Materializes Array8B input before performing the power operation.

Parameters
  • primitive (Primitive) – The JAX primitive being handled.

  • x (ArrayType) – Base array.

  • **kwargs – Keyword arguments including ‘y’ for exponent (default: 2).

Returns

Result of x raised to the power y.

Return type

ArrayType

eformer.ops.quantization.implicit_array_8bit.integer_pow_8bit_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B) → jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B

Custom handler for JAX’s integer power operation (positional args).

Materializes Array8B inputs before performing the power operation.

Parameters
  • primitive (Primitive) – The JAX primitive being handled.

  • x (ArrayType) – Base array.

  • y (ArrayType) – Exponent array.

Returns

Result of x raised to the power y.

Return type

ArrayType

eformer.ops.quantization.implicit_array_8bit.log_8bit_x(primitive: Primitive, x: Array8B, *args, **kwargs)

Custom handler for JAX’s log operation.

Materializes Array8B input before performing the operation.

Parameters
  • x (ArrayType) – The array to apply logarithm to.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

The result of lax.log operation.

eformer.ops.quantization.implicit_array_8bit.lt_8bit_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B, **kwargs)

Custom handler for JAX’s less-than comparison operation.

Materializes Array8B inputs before performing the comparison.

Parameters
  • primitive (Primitive) – The JAX primitive being handled.

  • x (ArrayType) – First operand for comparison.

  • y (ArrayType) – Second operand for comparison.

  • **kwargs – Additional keyword arguments for the lt operation.

Returns

Boolean array with element-wise comparison results.

Return type

jax.Array

eformer.ops.quantization.implicit_array_8bit.matmul_bf16_int8_weight_only(lhs_bf16: Array, rhs_q_int8: Array, rhs_scale: Array) → Array

Optimized bfloat16 @ int8 matrix multiplication for weight-only quantization.

This function multiplies bfloat16 activations by int8 quantized weights. It inspects the shape of the scale and applies it either before or after the matmul, whichever placement is mathematically equivalent to dequantize-then-multiply while avoiding multiplying the full K × N weight matrix by its scales.

Parameters
  • lhs_bf16 (jax.Array) – Left operand in bfloat16, typically activations. Shape: (…, M, K).

  • rhs_q_int8 (jax.Array) – Right operand as int8 quantized weights. Shape: (K, N).

  • rhs_scale (jax.Array) – Scale factors for the int8 weights. Typically (K, 1) for per-row or (1, N) for per-column scaling.

Returns

Result in bfloat16 with shape (…, M, N).

Return type

jax.Array

Scale Placement Strategies:
  • per-column (1, N): Y = (lhs @ W_q) * scale (post-multiply)

  • per-row (K, 1): Y = (lhs * scale^T) @ W_q (pre-multiply)

  • other shapes: Fall back to full dequantization before matmul

Note

The computation uses float32 accumulation internally for numerical stability, then casts back to bfloat16 for the output.
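The three placements can be sketched in plain JAX as below, together with a check that the per-row and per-column paths agree with dequantize-then-multiply. This illustrates the documented strategies under assumed shapes ((…, M, K) activations, (K, N) int8 weights); it is not eformer's implementation, and `matmul_weight_only_sketch` and `reference` are hypothetical names.

```python
import jax.numpy as jnp

def matmul_weight_only_sketch(x_bf16, w_q, scale):
    """Hypothetical sketch of the documented scale placements."""
    k, n = w_q.shape
    x = x_bf16.astype(jnp.float32)          # accumulate in float32
    w = w_q.astype(jnp.float32)
    if scale.shape == (1, n):               # per-column: (x @ W_q) * scale
        y = (x @ w) * scale
    elif scale.shape == (k, 1):             # per-row: (x * scale^T) @ W_q
        y = (x * scale.T) @ w
    else:                                   # fallback: full dequantization
        y = x @ (w * scale)
    return y.astype(jnp.bfloat16)           # cast back for the output

def reference(x_bf16, w_q, scale):
    # Dequantize the whole weight first, then multiply.
    w = w_q.astype(jnp.float32) * scale
    return (x_bf16.astype(jnp.float32) @ w).astype(jnp.bfloat16)

x = (jnp.arange(6, dtype=jnp.float32).reshape(2, 3) - 2).astype(jnp.bfloat16)  # (M, K)
w_q = jnp.arange(-6, 6, dtype=jnp.int8).reshape(3, 4)                            # (K, N)
s_row = jnp.array([[0.5], [1.0], [2.0]], dtype=jnp.float32)                      # (K, 1)
s_col = jnp.array([[0.25, 0.5, 1.0, 2.0]], dtype=jnp.float32)                    # (1, N)

y_row = matmul_weight_only_sketch(x, w_q, s_row)
y_col = matmul_weight_only_sketch(x, w_q, s_col)
print(bool(jnp.array_equal(y_row, reference(x, w_q, s_row))))  # True
print(bool(jnp.array_equal(y_col, reference(x, w_q, s_col))))  # True
```

Power-of-two scales keep every product exact in this example, so the paths match bit for bit; with arbitrary scales they agree only up to float32 rounding.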

eformer.ops.quantization.implicit_array_8bit.max_8bit_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B, *args, **kwargs)

Custom handler for JAX’s max operation.

Materializes Array8B inputs before performing the operation.

Parameters
  • x (ArrayType) – First array for max comparison.

  • y (ArrayType) – Second array for max comparison.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

The result of lax.max operation.

eformer.ops.quantization.implicit_array_8bit.mul_8bit_xy(primitive: Primitive, x: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B, y: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B)

Custom handler for JAX’s mul operation.

Materializes Array8B inputs before performing the operation.

Parameters
  • x (ArrayType) – First array to multiply.

  • y (ArrayType) – Second array to multiply.

Returns

The result of lax.mul operation.

eformer.ops.quantization.implicit_array_8bit.reduce_8bit_operand_init_value(primitive: Primitive, operand: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B, init_value: jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B, *args, **kwargs)

Custom handler for JAX’s reduce operation.

Materializes Array8B inputs before performing the operation.

Parameters
  • operand (ArrayType) – The array to be reduced.

  • init_value (ArrayType) – The initial value for the reduction.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

The result of lax.reduce operation.

eformer.ops.quantization.implicit_array_8bit.reshape_8bit_operand(primitive: Primitive, operand: Array8B, *args, **params)

Custom handler for JAX’s reshape operation.

This function handles reshaping for Array8B quantized arrays. It materializes the input, reshapes the dequantized values, and re-quantizes the result, preserving the sharding of the original array.

Parameters
  • primitive (Primitive) – The JAX primitive being handled.

  • operand (Array8B) – The array to be reshaped.

  • *args – Positional arguments for reshape.

  • **params – Keyword arguments/parameters for reshape.

Returns

The reshaped array, re-quantized with sharding preserved.

Return type

Array8B

Raises

ValueError – If the new shape is not compatible with the original array’s size.
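A reshape cannot, in general, just reshape the weight and scale together, because the scales are tied to slices of the old shape. The sketch below assumes symmetric absmax quantization over the last axis: it materializes, reshapes (4, 6) to (3, 8), and re-quantizes, which adds a second rounding step. `quantize_int8` is a hypothetical helper, not eformer's code.

```python
import jax.numpy as jnp

def quantize_int8(x, axis=(-1,)):
    # Hypothetical helper: symmetric absmax int8 quantization over `axis`.
    absmax = jnp.max(jnp.abs(x), axis=axis, keepdims=True)
    scale = jnp.where(absmax == 0, 1.0, absmax / 127.0).astype(jnp.float32)
    q = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
    return q, scale

x = jnp.linspace(-1.0, 1.0, 24, dtype=jnp.float32).reshape(4, 6)
q, s = quantize_int8(x)                    # s: (4, 1), one scale per old row

# Row 0 of the (3, 8) result mixes elements from old rows 0 and 1, which were
# quantized with different scales, so the (4, 1) scales cannot be reused.
x_hat = (q.astype(jnp.float32) * s).reshape(3, 8)   # materialize, then reshape
q2, s2 = quantize_int8(x_hat)                        # re-quantize: s2 is (3, 1)
x_hat2 = q2.astype(jnp.float32) * s2
print(q2.shape, s2.shape)  # (3, 8) (3, 1)
```

Each element's total error stays within half the old step plus half the new step, but it is no longer the single-rounding error of the original quantization.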

eformer.ops.quantization.implicit_array_8bit.sqrt_8bit_x(primitive: Primitive, x: Array8B) → jax.jaxlib._jax.Array | eformer.ops.quantization.implicit_array_8bit.Array8B

Custom handler for JAX’s square root operation.

Materializes Array8B input before computing element-wise square root.

Parameters
  • primitive (Primitive) – The JAX primitive being handled.

  • x (Array8B) – Input array.

Returns

Element-wise square root of the input.

Return type

ArrayType

eformer.ops.quantization.implicit_array_8bit.transpose_8bit_operand(primitive: Primitive, operand: Array8B, *args, **kwargs)

Custom handler for JAX’s transpose operation.

Transposes the underlying weight and scale arrays directly, without materializing the dequantized array. Preserves sharding from the original array.

Parameters
  • primitive – The JAX primitive being handled.

  • operand (Array8B) – The array to be transposed.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

The transposed array with sharding preserved.

Return type

Array8B
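Transposing the components directly works because each scale only has to follow its own dimension. The sketch below, on stand-in arrays with an assumed (K, N) int8 weight and (K, 1) scales, transposes both components and checks the result against transposing the dequantized matrix. It is an illustration, not eformer's handler.

```python
import jax.numpy as jnp

q = jnp.arange(-6, 6, dtype=jnp.int8).reshape(3, 4)          # (K, N) int8 weight
scale = jnp.array([[0.5], [1.0], [2.0]], dtype=jnp.float32)   # (K, 1) per-row scales

# Transpose the components directly: the payload stays int8,
# and the per-row (K, 1) scale becomes a per-column (1, K) scale.
q_t, scale_t = q.T, scale.T
lazy = q_t.astype(jnp.float32) * scale_t
eager = (q.astype(jnp.float32) * scale).T
print(scale_t.shape, bool(jnp.array_equal(lazy, eager)))  # (1, 3) True
```

Because per-row scales become per-column scales, a transposed weight can land in the other branch of the scale placement strategies used by matmul_bf16_int8_weight_only.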