Source code for eformer.ops.quantization.implicit_array_8bit

# Copyright 2025 The EasyDeL/eFormer Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
8-bit Integer Quantization Module.

This module provides 8-bit integer quantization for neural network weights,
offering approximately 4x memory reduction compared to float32 with minimal
accuracy loss for most applications.

The INT8 format uses per-axis scaling to maintain precision, making it suitable
for both inference and quantization-aware training. This module includes:

- Array8B: Implicit array class for INT8 quantized weights
- Weight-only quantization: Optimized bf16 @ int8 matmul path
- Sharding support: Distributed computation with JAX sharding
- JAX primitive handlers: Transparent integration with JAX operations

Key Features:
    - Per-axis quantization with configurable axis
    - Automatic sharding preservation across operations
    - Optimized bf16 @ int8 weight-only matmul
    - Direct transpose without materialization

Example:
    >>> import jax.numpy as jnp
    >>> from eformer.ops.quantization import Array8B
    >>>
    >>> # Quantize weights
    >>> weights = jnp.ones((128, 256), dtype=jnp.float32)
    >>> quantized = Array8B.quantize(weights, axis=(-1,))
    >>>
    >>> # Use transparently in matrix operations
    >>> inputs = jnp.ones((32, 128), dtype=jnp.bfloat16)
    >>> output = inputs @ quantized  # Uses optimized bf16 @ int8 path
"""

from __future__ import annotations

import dataclasses
import typing as tp

import jax
from jax import lax
from jax import numpy as jnp
from jax.extend.core import Primitive
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from eformer.jaximus import ImplicitArray, aux_field, register, ste

from .quantization_functions import dequantize_int8, quantize_int8

# Short alias for the JAX array type, used in annotations throughout this module.
Array = jax.Array
# Sharding descriptions accepted by this module: a NamedSharding, a bare
# PartitionSpec, or None when no sharding is requested.
ShardingType = NamedSharding | PartitionSpec | None


@dataclasses.dataclass
class Array8B(ImplicitArray):
    """
    8-bit quantized implicit array.

    Stores an int8 ``weight`` together with the ``scale`` produced by
    ``quantize_int8``. The dense array is rebuilt on demand by
    ``materialize`` via ``dequantize_int8``.

    Attributes:
        scale (jax.Array): The scale factor returned by ``quantize_int8``.
        weight (jax.Array): The quantized 8-bit integer array.
        axis (tuple[int, ...] | None): The axis used for quantization (static).
        sharding (ShardingType): The sharding specification carried along
            with the array and re-applied on materialization (static).

    Methods:
        quantize(array, dtype, axis): Build an Array8B from a dense array.
        materialize(): Rebuild the dense array from the quantized data.
        with_sharding(sharding): Return a copy with ``sharding`` applied.
        with_mesh_and_axis(mesh_and_axis): Same, from a ``(mesh, spec)`` pair.
        reshard(sharding): Alias of ``with_sharding``.
        delete(): Call ``delete()`` on the underlying arrays.
    """

    scale: Array
    weight: Array
    axis: tuple[int, ...] | None = aux_field(default=None)
    sharding: ShardingType = aux_field(default=None)  # noqa: RUF009

    @classmethod
    def quantize(
        cls,
        array: Array,
        dtype: jnp.dtype | None = None,
        axis: int | tuple[int, ...] | list[int] | None = None,
    ) -> Array8B:
        """
        Quantize ``array`` and wrap the result in an ``Array8B``.

        Args:
            array (jax.Array): The input array to be quantized.
            dtype (jnp.dtype | None): The dtype used on materialization.
                Defaults to the input dtype.
            axis (int | tuple[int, ...] | list[int] | None): The axis (or
                axes) for quantization. Defaults to ``(-1,)``. A list is
                converted to a tuple; a single int is wrapped in a tuple.

        Returns:
            Array8B: The quantized array. If the input carries a
            ``NamedSharding`` it is recorded on the result.
        """
        if axis is None:
            axis = (-1,)
        elif isinstance(axis, (tuple, list)):
            # A list used to be wrapped as ``([..],)``, producing a nested and
            # unhashable static field; flatten it into a plain tuple instead.
            axis = tuple(axis)
        else:
            axis = (axis,)

        # Capture sharding from input array (only NamedSharding, not SingleDeviceSharding)
        candidate = getattr(array, "sharding", None)
        input_sharding = candidate if isinstance(candidate, NamedSharding) else None

        weight, scale = quantize_int8(x=array, axis=axis)
        return cls(
            weight=weight,
            scale=scale,
            shape=array.shape,
            dtype=dtype or array.dtype,
            axis=axis,
            sharding=input_sharding,
        )

    def materialize(self) -> Array:
        """
        Rebuild the dense array from the quantized data.

        Returns:
            jax.Array: ``dequantize_int8(weight, scale)`` reshaped to
            ``self.shape`` and cast to ``self.dtype``, with the stored
            sharding constraint applied when one is set.
        """
        dense = dequantize_int8(self.weight, self.scale)
        dense = dense.reshape(self.shape).astype(self.dtype)
        if self.sharding is None:
            return dense
        return _apply_sharding(dense, self.sharding)

    def with_sharding(self, sharding: ShardingType) -> Array8B:
        """
        Return a copy with ``sharding`` applied to the component arrays.

        The same ``sharding`` is passed to ``_apply_sharding`` for both
        ``weight`` and ``scale`` and is also stored on the returned instance.

        Args:
            sharding: A NamedSharding, PartitionSpec, or None. With None the
                component arrays are left as they are.

        Returns:
            Array8B: A new instance built with ``dataclasses.replace``.
        """
        return dataclasses.replace(
            self,
            weight=_apply_sharding(self.weight, sharding),
            scale=_apply_sharding(self.scale, sharding),
            sharding=sharding,
        )

    def with_mesh_and_axis(
        self,
        mesh_and_axis: tuple[Mesh, PartitionSpec | tuple[tp.Any, ...] | None],
    ) -> Array8B:
        """
        Apply sharding described by a ``(mesh, axis_spec)`` pair.

        Args:
            mesh_and_axis: A tuple of (Mesh, axis_spec) where axis_spec is:
                - None: an empty ``PartitionSpec()`` is used
                - PartitionSpec: used directly
                - tuple/list: unpacked into ``PartitionSpec(*axis_spec)``
                - anything else (e.g. a str): ``PartitionSpec(axis_spec)``

        Returns:
            Array8B: New instance with ``NamedSharding(mesh, spec)`` applied
            through ``with_sharding``.
        """
        mesh, axis = mesh_and_axis
        if axis is None:
            spec = PartitionSpec()
        elif isinstance(axis, PartitionSpec):
            # Checked before the tuple/list case so an existing spec is never re-wrapped.
            spec = axis
        elif isinstance(axis, (tuple, list)):
            spec = PartitionSpec(*axis)
        else:
            spec = PartitionSpec(axis)
        return self.with_sharding(NamedSharding(mesh, spec))

    def reshard(self, sharding: ShardingType) -> Array8B:
        """Alias for with_sharding for API consistency."""
        return self.with_sharding(sharding)

    @property
    def is_sharded(self) -> bool:
        """Returns True if this array has sharding information."""
        return self.sharding is not None

    def delete(self):
        """
        Call ``delete()`` on the underlying weight and scale arrays.

        Intended for manual memory management; the instance should not be
        used afterwards.
        """
        self.weight.delete()
        self.scale.delete()
def _apply_sharding(array: Array, sharding: ShardingType) -> Array:
    """
    Constrain ``array`` to ``sharding`` when a sharding is given.

    Args:
        array (jax.Array): The array to constrain.
        sharding (ShardingType): NamedSharding, PartitionSpec, or None.

    Returns:
        jax.Array: The result of ``with_sharding_constraint(array, sharding)``,
        or ``array`` itself when ``sharding`` is None.
    """
    if sharding is not None:
        # Imported inside the function, only on the path that needs it.
        from eformer.escale import with_sharding_constraint

        return with_sharding_constraint(array, sharding)
    return array


ArrayType = Array | Array8B
@register("lt")
def lt_8bit_xy(primitive: Primitive, x: ArrayType, y: ArrayType, **kwargs):
    """
    Handler for the ``lt`` primitive.

    Any Array8B operand is materialized, then ``jax.lax.lt`` is called.

    Args:
        primitive (Primitive): The JAX primitive being handled.
        x (ArrayType): Left operand of the comparison.
        y (ArrayType): Right operand of the comparison.
        **kwargs: Forwarded unchanged to ``jax.lax.lt``.

    Returns:
        jax.Array: The result of ``jax.lax.lt`` on the dense operands.
    """
    left = x.materialize() if isinstance(x, Array8B) else x
    right = y.materialize() if isinstance(y, Array8B) else y
    return jax.lax.lt(left, right, **kwargs)
@register("convert_element_type")
def convert_element_type_8bit_operand_pos(primitive: Primitive, operand: Array8B, new_dtype: tp.Any) -> ArrayType:
    """
    Handler for ``convert_element_type`` (positional ``new_dtype``).

    For an Array8B no data is converted: a copy carrying the new ``dtype``
    metadata is returned, and the cast itself happens when the array is
    materialized. The input array is left untouched. The previous
    implementation assigned ``operand.dtype`` in place, which changed the
    dtype seen through every other reference to the same array.

    Args:
        primitive (Primitive): The JAX primitive being handled.
        operand (Array8B): The array to convert.
        new_dtype (Any): The target dtype.

    Returns:
        ArrayType: A new Array8B with updated dtype metadata, or the result
        of ``jax.lax.convert_element_type`` for any other operand.
    """
    if not isinstance(operand, Array8B):
        return jax.lax.convert_element_type(operand=operand, new_dtype=new_dtype)
    # Copy instead of mutating, mirroring how Array8B.with_sharding builds its result.
    return dataclasses.replace(operand, dtype=new_dtype)
@register("convert_element_type")
def convert_element_type_8bit_operand_kw(primitive: Primitive, operand: Array8B, **kwargs) -> ArrayType:
    """
    Handler for ``convert_element_type`` (``new_dtype`` passed by keyword).

    For an Array8B no data is converted: a copy carrying the new ``dtype``
    metadata is returned, and the cast itself happens when the array is
    materialized. The input array is left untouched. The previous
    implementation assigned ``operand.dtype`` in place, which changed the
    dtype seen through every other reference to the same array.

    Args:
        primitive (Primitive): The JAX primitive being handled.
        operand (Array8B): The array to convert.
        **kwargs: Keyword arguments; only ``new_dtype`` is read, falling
            back to ``jnp.bfloat16`` when it is absent.

    Returns:
        ArrayType: A new Array8B with updated dtype metadata, or the result
        of ``jax.lax.convert_element_type`` for any other operand.
    """
    new_dtype = kwargs.get("new_dtype", jnp.bfloat16)
    if not isinstance(operand, Array8B):
        return jax.lax.convert_element_type(operand=operand, new_dtype=new_dtype)
    # Copy instead of mutating, mirroring how Array8B.with_sharding builds its result.
    return dataclasses.replace(operand, dtype=new_dtype)
@register("integer_pow")
def integer_pow_8bit_xy(primitive: Primitive, x: ArrayType, y: ArrayType) -> ArrayType:
    """
    Handler for ``integer_pow`` when the exponent arrives positionally.

    Any Array8B argument is materialized, then ``lax.pow`` is called.

    Args:
        primitive (Primitive): The JAX primitive being handled.
        x (ArrayType): Base.
        y (ArrayType): Exponent.

    Returns:
        ArrayType: The result of ``lax.pow(x, y)`` on the dense arguments.
    """
    base = x.materialize() if isinstance(x, Array8B) else x
    exponent = y.materialize() if isinstance(y, Array8B) else y
    return lax.pow(base, exponent)
@register("integer_pow")
def integer_pow_8bit_x(primitive: Primitive, x: ArrayType, **kwargs) -> ArrayType:
    """
    Handler for ``integer_pow`` when the exponent arrives as keyword ``y``.

    Args:
        primitive (Primitive): The JAX primitive being handled.
        x (ArrayType): Base; materialized first when it is an Array8B.
        **kwargs: Only ``y`` is read; it defaults to 2 when absent.

    Returns:
        ArrayType: The result of ``lax.pow(x, y)``.
    """
    exponent = kwargs.get("y", 2)
    base = x.materialize() if isinstance(x, Array8B) else x
    return lax.pow(base, exponent)
@register("div")
def div_8bit_xy(primitive: Primitive, x: ArrayType, y: ArrayType) -> ArrayType:
    """
    Handler for the ``div`` primitive.

    Any Array8B operand is materialized, then ``lax.div`` is called.

    Args:
        primitive (Primitive): The JAX primitive being handled.
        x (ArrayType): Dividend.
        y (ArrayType): Divisor.

    Returns:
        ArrayType: The result of ``lax.div`` on the dense operands.
    """
    dividend = x.materialize() if isinstance(x, Array8B) else x
    divisor = y.materialize() if isinstance(y, Array8B) else y
    return lax.div(dividend, divisor)
@register("sqrt")
def sqrt_8bit_x(primitive: Primitive, x: Array8B) -> ArrayType:
    """
    Handler for the ``sqrt`` primitive.

    Args:
        primitive (Primitive): The JAX primitive being handled.
        x (Array8B): Input array; always materialized.

    Returns:
        ArrayType: ``lax.sqrt`` of the materialized input.
    """
    return lax.sqrt(x.materialize())
def matmul_bf16_int8_weight_only(
    lhs_bf16: Array,  # (..., M, K), bfloat16
    rhs_q_int8: Array,  # (K, N), int8
    rhs_scale: Array,  # typically (K, 1) or (1, N) from quantize_int8
) -> Array:
    """
    Weight-only matmul of activations against int8 weights with a scale.

    The int8 weights are cast to bfloat16 and multiplied with ``jnp.matmul``
    using ``preferred_element_type=float32``; the result is cast to bfloat16.
    Where the scale is applied depends on its shape.

    Args:
        lhs_bf16 (jax.Array): Left operand (activations), shape (..., M, K).
        rhs_q_int8 (jax.Array): Quantized weights, shape (K, N).
        rhs_scale (jax.Array): Scale for the weights, typically (K, 1) or
            (1, N).

    Returns:
        jax.Array: bfloat16 result.

    Scale Placement Strategies (2-D weights and 2-D scale only):
        - scale shape (1, N): Y = (lhs @ W_q) * scale (post-multiply)
        - scale shape (K, 1): Y = (lhs * scale^T) @ W_q (pre-multiply)
        - anything else: the weights are dequantized in bfloat16 first
    """
    weights = rhs_q_int8.astype(jnp.bfloat16)

    if rhs_q_int8.ndim == 2 and rhs_scale.ndim == 2:
        rows, cols = rhs_q_int8.shape
        scale_shape = rhs_scale.shape

        if scale_shape == (1, cols):
            # One scale per output column: scale the float32 accumulator after the matmul.
            acc = jnp.matmul(lhs_bf16, weights, preferred_element_type=jnp.float32)
            leading = (1,) * (acc.ndim - 2)
            acc = acc * rhs_scale.reshape(leading + (1, cols))
            return acc.astype(jnp.bfloat16)

        if scale_shape == (rows, 1):
            # One scale per contraction index: fold it into the activations before the matmul.
            per_row = rhs_scale.reshape((rows,))
            per_row = per_row.reshape((1,) * (lhs_bf16.ndim - 1) + (rows,))
            acc = jnp.matmul(lhs_bf16 * per_row, weights, preferred_element_type=jnp.float32)
            return acc.astype(jnp.bfloat16)

    # General case: dequantize the weights, then multiply.
    dequantized = weights * rhs_scale.astype(jnp.bfloat16)
    acc = jnp.matmul(lhs_bf16, dequantized, preferred_element_type=jnp.float32)
    return acc.astype(jnp.bfloat16)
def _is_weight_only_matmul(lhs_ndim: int, rhs_ndim: int, dimension_numbers: tp.Any) -> bool:
    """Return True when a dot_general call is a plain ``(..., K) x (K, N)`` contraction."""
    if dimension_numbers is None:
        # No dimension numbers visible to the handler: keep the historical
        # behaviour of taking the fast path.
        return True
    if rhs_ndim != 2 or lhs_ndim < 1:
        return False
    try:
        (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
    except (TypeError, ValueError):
        return False
    if tuple(lhs_batch) or tuple(rhs_batch):
        return False
    return tuple(lhs_contract) == (lhs_ndim - 1,) and tuple(rhs_contract) == (0,)


@register("dot_general")
def dot_general_8bit_lhs_rhs(primitive: Primitive, lhs: ArrayType, rhs: ArrayType, *args, **kwargs):
    """
    Handler for the ``dot_general`` primitive with Array8B support.

    When ``lhs`` is a regular array, ``rhs`` is an Array8B with a 2-D weight,
    and the requested contraction is the plain matmul pattern (last lhs
    dimension against rhs dimension 0, no batch dimensions), the weight-only
    path ``matmul_bf16_int8_weight_only`` is used. Every other call
    materializes the Array8B operands and defers to ``lax.dot_general``.

    The earlier implementation took the fast path for every
    ``(Array, Array8B)`` pair and ignored ``dimension_numbers``, which gave
    wrong results for any other contraction layout.

    Args:
        primitive (Primitive): The JAX primitive being handled.
        lhs (ArrayType): Left-hand side array (activations).
        rhs (ArrayType): Right-hand side array (potentially quantized weights).
        *args: Extra positional arguments; the first one is treated as
            ``dimension_numbers`` when the keyword is absent.
        **kwargs: Extra keyword arguments, forwarded to ``lax.dot_general``
            on the fallback path.

    Returns:
        jax.Array: The result of the contraction.

    Note:
        NOTE(review): the fast path always returns bfloat16 and ignores
        ``precision`` / ``preferred_element_type``; confirm this is
        acceptable for non-bfloat16 activations.
    """
    if isinstance(lhs, Array) and isinstance(rhs, Array8B):
        dimension_numbers = kwargs.get("dimension_numbers", args[0] if args else None)
        if _is_weight_only_matmul(lhs.ndim, rhs.weight.ndim, dimension_numbers):
            return matmul_bf16_int8_weight_only(
                lhs_bf16=lhs,
                rhs_q_int8=rhs.weight,
                rhs_scale=rhs.scale,
            )

    dense_lhs = lhs.materialize() if isinstance(lhs, Array8B) else lhs
    dense_rhs = rhs.materialize() if isinstance(rhs, Array8B) else rhs
    return lax.dot_general(dense_lhs, dense_rhs, *args, **kwargs)
@register("add")
def add_8bit_xy(primitive: Primitive, x: ArrayType, y: ArrayType):
    """
    Handler for the ``add`` primitive.

    Any Array8B operand is materialized, then ``lax.add`` is called.

    Args:
        primitive (Primitive): The JAX primitive being handled.
        x (ArrayType): First addend.
        y (ArrayType): Second addend.

    Returns:
        The result of ``lax.add`` on the dense operands.
    """
    first = x.materialize() if isinstance(x, Array8B) else x
    second = y.materialize() if isinstance(y, Array8B) else y
    return lax.add(first, second)
@register("reduce")
def reduce_8bit_operand_init_value(primitive: Primitive, operand: ArrayType, init_value: ArrayType, *args, **kwargs):
    """
    Handler for the ``reduce`` primitive.

    Any Array8B among ``operand`` and ``init_value`` is materialized, then
    ``lax.reduce`` is called.

    Args:
        primitive (Primitive): The JAX primitive being handled.
        operand (ArrayType): The array to be reduced.
        init_value (ArrayType): The initial value for the reduction.
        *args: Forwarded unchanged to ``lax.reduce``.
        **kwargs: Forwarded unchanged to ``lax.reduce``.

    Returns:
        The result of ``lax.reduce``.
    """
    dense_operand = operand.materialize() if isinstance(operand, Array8B) else operand
    dense_init = init_value.materialize() if isinstance(init_value, Array8B) else init_value
    return lax.reduce(dense_operand, dense_init, *args, **kwargs)
@register("mul")
def mul_8bit_xy(primitive: Primitive, x: ArrayType, y: ArrayType):
    """
    Handler for the ``mul`` primitive.

    Any Array8B operand is materialized, then ``lax.mul`` is called.

    Args:
        primitive (Primitive): The JAX primitive being handled.
        x (ArrayType): First factor.
        y (ArrayType): Second factor.

    Returns:
        The result of ``lax.mul`` on the dense operands.
    """
    first = x.materialize() if isinstance(x, Array8B) else x
    second = y.materialize() if isinstance(y, Array8B) else y
    return lax.mul(first, second)
@register("transpose")
def transpose_8bit_operand(primitive: Primitive, operand: Array8B, *args, **kwargs):
    """
    Handler for the ``transpose`` primitive.

    Transposes the underlying ``weight`` and ``scale`` arrays directly and
    carries the quantization axes and sharding over to the result.

    Args:
        primitive: The JAX primitive being handled.
        operand (Array8B): The array to be transposed.
        *args: Positional arguments; the first one is used as the
            permutation when the ``permutation`` keyword is absent.
        **kwargs: Keyword arguments; ``permutation`` is read from here first.

    Returns:
        Array8B: The transposed array, with ``axis`` remapped to the new
        dimension order and ``sharding`` copied from the operand.

    Raises:
        TypeError: If no permutation is supplied.
    """
    if "permutation" in kwargs:
        perm = kwargs["permutation"]
    elif args:
        perm = args[0]
    else:
        raise TypeError("transpose requires a `permutation` argument")

    weight_t = lax.transpose(operand.weight, perm)
    scale_t = lax.transpose(operand.scale, perm)

    if operand.axis is None:
        new_axis = None
    else:
        # Output dimension i is taken from input dimension perm[i], so input
        # axis `a` ends up at the position i where perm[i] == a. Indexing
        # perm[a] (the previous behaviour) is only right for self-inverse
        # permutations such as a 2-D swap. Negative axes are normalised first.
        order = tuple(int(p) for p in perm)
        rank = len(order)
        new_axis = tuple(order.index(a % rank) for a in operand.axis)

    return Array8B(
        weight=weight_t,
        scale=scale_t,
        shape=weight_t.shape,
        dtype=operand.dtype,
        axis=new_axis,
        sharding=operand.sharding,  # Preserve sharding
    )
@register("conv_general_dilated")
def conv_general_dilated_8bit_lhs_rhs(primitive: Primitive, lhs: ArrayType, rhs: ArrayType, *args, **kwargs):
    """
    Handler for the ``conv_general_dilated`` primitive.

    Any Array8B operand is materialized, then ``lax.conv_general_dilated``
    is called.

    Args:
        primitive (Primitive): The JAX primitive being handled.
        lhs (ArrayType): Left-hand side array (input).
        rhs (ArrayType): Right-hand side array (kernel).
        *args: Forwarded unchanged to ``lax.conv_general_dilated``.
        **kwargs: Forwarded unchanged to ``lax.conv_general_dilated``.

    Returns:
        The result of ``lax.conv_general_dilated``.
    """
    dense_lhs = lhs.materialize() if isinstance(lhs, Array8B) else lhs
    dense_rhs = rhs.materialize() if isinstance(rhs, Array8B) else rhs
    return lax.conv_general_dilated(dense_lhs, dense_rhs, *args, **kwargs)
@register("max")
def max_8bit_xy(primitive: Primitive, x: ArrayType, y: ArrayType, *args, **kwargs):
    """
    Handler for the ``max`` primitive.

    Any Array8B operand is materialized, then ``lax.max`` is called.

    Args:
        primitive (Primitive): The JAX primitive being handled.
        x (ArrayType): First array for the comparison.
        y (ArrayType): Second array for the comparison.
        *args: Forwarded unchanged to ``lax.max``.
        **kwargs: Forwarded unchanged to ``lax.max``.

    Returns:
        The result of ``lax.max``.
    """
    first = x.materialize() if isinstance(x, Array8B) else x
    second = y.materialize() if isinstance(y, Array8B) else y
    return lax.max(first, second, *args, **kwargs)
@register("exp")
def exp_8bit_x(primitive: Primitive, x: Array8B, *args, **kwargs):
    """
    Handler for the ``exp`` primitive.

    Args:
        primitive (Primitive): The JAX primitive being handled.
        x (Array8B): Input array; always materialized.
        *args: Forwarded unchanged to ``lax.exp``.
        **kwargs: Forwarded unchanged to ``lax.exp``.

    Returns:
        The result of ``lax.exp`` on the materialized input.
    """
    return lax.exp(x.materialize(), *args, **kwargs)
@register("log")
def log_8bit_x(primitive: Primitive, x: Array8B, *args, **kwargs):
    """
    Handler for the ``log`` primitive.

    Args:
        primitive (Primitive): The JAX primitive being handled.
        x (Array8B): Input array; always materialized.
        *args: Forwarded unchanged to ``lax.log``.
        **kwargs: Forwarded unchanged to ``lax.log``.

    Returns:
        The result of ``lax.log`` on the materialized input.
    """
    return lax.log(x.materialize(), *args, **kwargs)
@register("reshape")
def reshape_8bit_operand(primitive: Primitive, operand: Array8B, *args, **params):
    """
    Handler for the ``reshape`` primitive.

    The operand is materialized, the primitive is bound on the dense array,
    and the reshaped result is quantized again with ``Array8B.quantize``
    (using its default axis). If the operand carried a sharding, it is
    re-applied to the new array through ``with_sharding``.

    Args:
        primitive (Primitive): The JAX primitive being handled.
        operand (Array8B): The array to be reshaped.
        *args: Extra positional arguments passed on to ``primitive.bind``.
        **params: Primitive parameters, run through
            ``primitive.get_bind_params`` before binding.

    Returns:
        Array8B: The reshaped, re-quantized array.
    """
    dense = operand.materialize()
    subfuns, bind_params = primitive.get_bind_params(params)
    reshaped = primitive.bind(*subfuns, dense, *args, **bind_params)
    requantized = Array8B.quantize(reshaped, dtype=operand.dtype)

    # Preserve sharding from original operand
    if operand.sharding is None:
        return requantized
    return requantized.with_sharding(operand.sharding)
@register("concatenate")
def concatenate_8bit_operands(primitive: Primitive, operands: tp.Sequence[ArrayType], *args, **kwargs):
    """
    Handler for the ``concatenate`` primitive.

    Every Array8B in ``operands`` is materialized; other entries are passed
    through as they are. The resulting list goes to ``lax.concatenate``.

    Args:
        primitive (Primitive): The JAX primitive being handled.
        operands (Sequence[ArrayType]): Arrays to concatenate.
        *args: Forwarded unchanged to ``lax.concatenate``.
        **kwargs: Forwarded unchanged to ``lax.concatenate``.

    Returns:
        The result of ``lax.concatenate``.
    """

    def _dense(item):
        return item.materialize() if isinstance(item, Array8B) else item

    return lax.concatenate([_dense(item) for item in operands], *args, **kwargs)
@register("broadcast_in_dim")
def broadcast_in_dim_8bit_operand(primitive: Primitive, operand: Array8B, *args, **params) -> ArrayType:
    """
    Handler for the ``broadcast_in_dim`` primitive.

    The primitive is bound separately on ``operand.weight`` and
    ``operand.scale`` with the same parameters, so the result stays
    quantized. Each quantization axis ``a`` is mapped to
    ``broadcast_dimensions[a]``; the sharding is copied from the operand.

    Args:
        primitive (Primitive): The JAX primitive being handled.
        operand (Array8B): The array to broadcast.
        *args: Extra positional arguments passed on to ``primitive.bind``.
        **params: Primitive parameters; ``shape`` and
            ``broadcast_dimensions`` are read after
            ``primitive.get_bind_params``.

    Returns:
        Array8B: The broadcast array with the target ``shape``.
    """
    subfuns, bind_params = primitive.get_bind_params(params)
    target_shape = bind_params["shape"]
    dims = bind_params["broadcast_dimensions"]

    def _broadcast(component):
        return primitive.bind(*subfuns, component, *args, **bind_params)

    weight_b = _broadcast(operand.weight)
    scale_b = _broadcast(operand.scale)

    new_axis = None if operand.axis is None else tuple(dims[a] for a in operand.axis)

    return Array8B(
        weight=weight_b,
        scale=scale_b,
        shape=target_shape,
        dtype=operand.dtype,
        axis=new_axis,
        sharding=operand.sharding,  # Preserve sharding
    )
@register("gather")
def gather_8bit_operand(primitive: Primitive, operand: Array8B, *args, **kwargs) -> ArrayType:
    """
    Handler for the ``gather`` primitive.

    The operand is materialized and handed to ``jax.lax.gather``; the result
    is returned as is, without re-quantizing.

    Args:
        primitive (Primitive): The JAX primitive being handled.
        operand (Array8B): The source array to gather from.
        *args: Forwarded unchanged to ``jax.lax.gather``.
        **kwargs: Forwarded unchanged to ``jax.lax.gather``.

    Returns:
        jax.Array: The gathered values.
    """
    return jax.lax.gather(operand.materialize(), *args, **kwargs)
@ste
def straight_through_8bit(weights: jax.Array, axis: int | None = None):
    """
    Quantize ``weights`` to int8 and immediately materialize them again.

    The function is wrapped with the ``ste`` decorator from
    ``eformer.jaximus``, which is expected to supply the straight-through
    gradient used for quantization-aware training (the decorator's
    implementation is not part of this module).

    Args:
        weights (jax.Array): Input weights to quantize.
        axis (int | None): Axis passed to ``Array8B.quantize``. With None,
            ``Array8B.quantize`` falls back to the last axis.

    Returns:
        jax.Array: The dequantized weights, i.e.
        ``Array8B.quantize(weights, axis=axis).materialize()``.

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> from eformer.ops.quantization import straight_through_8bit
        >>>
        >>> @jax.jit
        ... def forward(params, x):
        ...     w = straight_through_8bit(params['weight'], axis=-1)
        ...     return x @ w
    """
    quantized = Array8B.quantize(weights, axis=axis)
    return quantized.materialize()