eformer.ops.quantization._config

Quantization configuration and unified interface.

class eformer.ops.quantization._config.QuantizationConfig(dtype: eformer.ops.quantization._config.QuantizationType | str = QuantizationType.NF4, block_size: int = 64, simulate: bool = False, use_kernel: bool = True)

Bases: object

Configuration for quantization behavior.

This config controls how weights are quantized during training and inference.

dtype

The quantization type to use (NF4, INT8, BINARY, or TERNARY).

Type

eformer.ops.quantization._config.QuantizationType | str

block_size

Block size for block-wise quantization (default: 64). Only applicable to block-quantized formats such as NF4.

Type

int

simulate

If True, uses straight-through estimation without actual bit packing. Useful for QAT (quantization-aware training) simulation.

Type

bool

use_kernel

If True, use optimized TPU/GPU kernels when they are available. Availability is auto-detected from the device type.

Type

bool

Example

>>> # NF4 quantization with 64-element blocks
>>> config = QuantizationConfig(dtype=QuantizationType.NF4, block_size=64)
>>>
>>> # INT8 quantization
>>> config = QuantizationConfig(dtype=QuantizationType.INT8, block_size=64)
>>>
>>> # Binary quantization
>>> config = QuantizationConfig(dtype=QuantizationType.BINARY)
>>>
>>> # Simulation mode (no actual bit packing)
>>> config = QuantizationConfig(
...     dtype=QuantizationType.NF4,
...     simulate=True  # QAT mode
... )
block_size: int = 64
dtype: eformer.ops.quantization._config.QuantizationType | str = 'nf4'
simulate: bool = False
use_kernel: bool = True
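
To illustrate what block_size controls, here is a minimal sketch of block-wise absmax quantization to int8 in JAX. It is not eformer's implementation: the helper names blockwise_absmax_int8 and dequantize are hypothetical, and the input length is assumed to be a multiple of block_size. Each block of block_size elements gets its own scale, so one outlier only degrades precision inside its own block.

```python
import jax.numpy as jnp


def blockwise_absmax_int8(x, block_size=64):
    """Illustrative only: quantize to int8 with one absmax scale per block."""
    if x.size % block_size != 0:
        raise ValueError("input size must be a multiple of block_size")
    blocks = x.reshape(-1, block_size)
    # One scale per block, chosen so the largest magnitude maps to 127.
    absmax = jnp.max(jnp.abs(blocks), axis=1, keepdims=True)
    scale = jnp.where(absmax == 0, 1.0, absmax / 127.0)
    q = jnp.clip(jnp.round(blocks / scale), -127, 127).astype(jnp.int8)
    return q, scale


def dequantize(q, scale, shape):
    """Illustrative only: map int8 codes back to float32."""
    return (q.astype(jnp.float32) * scale).reshape(shape)


x = jnp.linspace(-1.0, 1.0, 128, dtype=jnp.float32)
q, scale = blockwise_absmax_int8(x, block_size=64)
x_hat = dequantize(q, scale, x.shape)
# Rounding error is bounded by half a quantization step in each block.
max_err = float(jnp.max(jnp.abs(x - x_hat)))
```

A smaller block_size stores more scales but tracks local magnitude more closely; a larger one does the opposite.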
class eformer.ops.quantization._config.QuantizationType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: StrEnum

Supported quantization types.

BINARY = 'binary'
INT8 = 'int8'
NF4 = 'nf4'
TERNARY = 'ternary'
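
Because the enum members are strings, an API typed as QuantizationType | str can normalize either form with a single enum lookup. The sketch below uses a hypothetical stand-in class, QuantTypeSketch, built with the (str, Enum) mixin so that it also runs on Python versions without StrEnum; how eformer itself normalizes the argument is an assumption here.

```python
from enum import Enum


class QuantTypeSketch(str, Enum):
    """Hypothetical stand-in mirroring the four documented members."""

    BINARY = "binary"
    INT8 = "int8"
    NF4 = "nf4"
    TERNARY = "ternary"


def normalize(dtype):
    """Accept a member or its string value; unknown names raise ValueError."""
    return QuantTypeSketch(dtype)


from_string = normalize("nf4")
from_member = normalize(QuantTypeSketch.INT8)
```

With this pattern a string that is not a member value, such as "int4", fails with ValueError at lookup time instead of silently selecting a format.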
eformer.ops.quantization._config.quantize(array: Array, config: eformer.ops.quantization._config.QuantizationConfig | None = None, dtype: eformer.ops.quantization._config.QuantizationType | str | None = None, block_size: int = 64, simulate: bool = False) → eformer.ops.quantization.implicit_array_1bit.Array1B | eformer.ops.quantization.implicit_array_8bit.Array8B | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | jax.jaxlib._jax.Array

Quantize an array using the specified configuration.

This is the unified quantization interface that dispatches to the appropriate quantization implementation based on the dtype.

Parameters
  • array – Input array to quantize (typically float32/bfloat16)

  • config – QuantizationConfig object (if provided, it takes precedence over the dtype, block_size, and simulate arguments)

  • dtype – Quantization type (NF4, INT8, BINARY, TERNARY)

  • block_size – Block size for blockwise quantization

  • simulate – If True, use simulation mode (STE without bit packing)

Returns

Quantized array as an implicit array (Array1B, Array8B, or ArrayNF4), or a regular jax.Array if simulate=True.

Example

>>> # Using config
>>> config = QuantizationConfig(dtype=QuantizationType.NF4, block_size=64)
>>> quantized = quantize(weights, config=config)
>>>
>>> # Direct parameters
>>> quantized = quantize(weights, dtype=QuantizationType.NF4, block_size=64)
>>>
>>> # Simulation mode (for QAT)
>>> quantized = quantize(weights, dtype=QuantizationType.NF4, simulate=True)

See also

  • straight_through: Unified STE wrapper for all quantization types

  • QuantizationConfig: Configuration dataclass

  • QuantizationType: Enum of supported types
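
The precedence rule for config versus the direct arguments can be sketched as follows. This is an illustration under stated assumptions, not the library's code: SketchConfig and resolve_settings are hypothetical names, and falling back to "nf4" when neither config nor dtype is given is an assumption based on the documented QuantizationConfig default.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SketchConfig:
    """Hypothetical stand-in for QuantizationConfig (use_kernel omitted)."""

    dtype: str = "nf4"
    block_size: int = 64
    simulate: bool = False


def resolve_settings(config=None, dtype=None, block_size=64, simulate=False):
    """Return the effective settings; a supplied config wins outright."""
    if config is not None:
        # The direct arguments are ignored entirely when config is given.
        return config
    # Assumption: an unspecified dtype falls back to the config default.
    return SketchConfig(
        dtype=dtype if dtype is not None else "nf4",
        block_size=block_size,
        simulate=simulate,
    )


cfg = SketchConfig(dtype="int8", block_size=32)
from_config = resolve_settings(config=cfg, dtype="nf4", block_size=128)
from_args = resolve_settings(dtype="binary", simulate=True)
```

Passing both a config and conflicting direct arguments is therefore not an error in this sketch; the direct arguments are simply dropped.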

eformer.ops.quantization._config.straight_through(array: Array, config: eformer.ops.quantization._config.QuantizationConfig | None = None, dtype: eformer.ops.quantization._config.QuantizationType | str | None = None, block_size: int = 64) → Array

Unified straight-through estimator for all quantization types.

This function quantizes in the forward pass but passes gradients straight through to the original float32 weights in the backward pass. Use this for training with quantization awareness.

Parameters
  • array – Input array to quantize (typically trainable weights)

  • config – QuantizationConfig object (if provided, it takes precedence over the dtype and block_size arguments)

  • dtype – Quantization type (NF4, INT8, BINARY, TERNARY)

  • block_size – Block size for blockwise quantization

Returns

Materialized quantized array with straight-through gradients

Example

>>> # In training loop
>>> @jax.jit
... def train_step(params, inputs, targets):
...     def loss_fn(params):
...         # Quantize weights with STE
...         quant_w = straight_through(params['weight'], dtype=QuantizationType.NF4)
...         preds = inputs @ quant_w
...         return jnp.mean((preds - targets) ** 2)
...     loss, grads = jax.value_and_grad(loss_fn)(params)
...     # grads flow to float32 params, not quantized weights
...     return loss, grads
Technical Details:
  • Forward: Uses quantized representation (memory efficient)

  • Backward: Gradients bypass quantization (grad_input = grad_output)

  • Always materializes the result as a regular array to ensure compatibility with standard ops

  • Underlying float32 params are updated during optimization

See also

  • quantize: Unified quantization interface

  • ste: Low-level STE decorator in jaximus
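
The forward and backward behaviour described under Technical Details can be reproduced with the common stop_gradient formulation of a straight-through estimator. This is a minimal sketch, not the mechanism eformer uses internally: fake_quantize is a toy rounding quantizer standing in for NF4 or INT8, and both function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def fake_quantize(w):
    """Toy quantizer (stand-in for NF4/INT8): snap to a grid with step 0.25."""
    return jnp.round(w * 4.0) / 4.0


def ste_quantize(w):
    """Forward value is fake_quantize(w); the backward pass sees identity."""
    # stop_gradient hides the rounding term from autodiff, so d(out)/dw = 1.
    return w + jax.lax.stop_gradient(fake_quantize(w) - w)


w = jnp.array([0.1, 0.4, -0.7], dtype=jnp.float32)
forward = ste_quantize(w)
# With STE the gradient of sum(out) w.r.t. w is all ones.
ste_grad = jax.grad(lambda v: jnp.sum(ste_quantize(v)))(w)
# Without STE, rounding has zero gradient, so no signal reaches w.
naive_grad = jax.grad(lambda v: jnp.sum(fake_quantize(v)))(w)
```

Comparing ste_grad with naive_grad shows why the estimator is needed: differentiating through the rounding step directly yields zeros, which would leave the float32 parameters unchanged during optimization.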