eformer.ops.quantization._config
Quantization configuration and unified interface.
- class eformer.ops.quantization._config.QuantizationConfig(dtype: eformer.ops.quantization._config.QuantizationType | str = QuantizationType.NF4, block_size: int = 64, simulate: bool = False, use_kernel: bool = True)
Bases: object
Configuration for quantization behavior.
This config controls how weights are quantized during training and inference.
- dtype
The quantization type to use (NF4, INT8, BINARY, or TERNARY).
- Type
QuantizationType | str
- block_size
Block size for block-wise quantization (default: 64). Only applicable to block-quantized formats such as NF4.
- Type
int
- simulate
If True, uses straight-through estimation without actual bit packing. Useful for simulating quantization-aware training (QAT).
- Type
bool
- use_kernel
If True, use optimized TPU/GPU kernels when they are available for the current device. Availability is auto-detected from the device type.
- Type
bool
Example
>>> from eformer.ops.quantization._config import QuantizationConfig, QuantizationType
>>>
>>> # NF4 quantization with 64-element blocks
>>> config = QuantizationConfig(dtype=QuantizationType.NF4, block_size=64)
>>>
>>> # INT8 quantization
>>> config = QuantizationConfig(dtype=QuantizationType.INT8, block_size=64)
>>>
>>> # Binary quantization
>>> config = QuantizationConfig(dtype=QuantizationType.BINARY)
>>>
>>> # Simulation mode (no actual bit packing)
>>> config = QuantizationConfig(
...     dtype=QuantizationType.NF4,
...     simulate=True,  # QAT mode
... )
- block_size: int = 64
- dtype: eformer.ops.quantization._config.QuantizationType | str = 'nf4'
- simulate: bool = False
- use_kernel: bool = True
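As an illustration of what `block_size` controls, the following is a minimal sketch of generic block-wise absmax INT8 quantization in JAX: each run of `block_size` consecutive values shares one scale, so an outlier only degrades precision within its own block. The helpers `blockwise_absmax_int8` and `dequantize` are hypothetical and not part of eformer; the library's own INT8 and NF4 paths may differ, for example in padding, scale dtype, or blocking axis.

```python
import jax.numpy as jnp


def blockwise_absmax_int8(x, block_size=64):
    """Hypothetical helper: int8 codes plus one absmax scale per block."""
    # Assumes x.size is a multiple of block_size (no padding in this sketch).
    blocks = x.reshape(-1, block_size)
    # The largest magnitude in each block maps to 127.
    absmax = jnp.max(jnp.abs(blocks), axis=1, keepdims=True)
    scale = jnp.where(absmax == 0, 1.0, absmax / 127.0)
    q = jnp.clip(jnp.round(blocks / scale), -127, 127).astype(jnp.int8)
    return q, scale


def dequantize(q, scale):
    """Hypothetical helper: map int8 codes back to flat float32 values."""
    return (q.astype(jnp.float32) * scale).reshape(-1)


x = jnp.linspace(-1.0, 1.0, 256)
q, scale = blockwise_absmax_int8(x, block_size=64)
print(q.shape, scale.shape)  # (4, 64) (4, 1)
```

Smaller blocks track local magnitudes more closely at the cost of storing more scales.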
- class eformer.ops.quantization._config.QuantizationType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: StrEnum
Supported quantization types.
- BINARY = 'binary'
- INT8 = 'int8'
- NF4 = 'nf4'
- TERNARY = 'ternary'
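Because `QuantizationType` is a `StrEnum`, each member compares equal to its string value, which is presumably why `dtype` parameters accept either form. The sketch below mirrors the documented members in a standalone enum; `as_quantization_type` is a hypothetical helper, and how eformer itself normalizes string arguments is an assumption.

```python
try:
    from enum import StrEnum  # Python 3.11+
except ImportError:  # fallback for older Python versions
    from enum import Enum

    class StrEnum(str, Enum):
        pass


class QuantizationType(StrEnum):
    """Standalone mirror of the documented members (illustrative only)."""

    BINARY = "binary"
    INT8 = "int8"
    NF4 = "nf4"
    TERNARY = "ternary"


def as_quantization_type(dtype):
    """Hypothetical helper: accept a member or its string value.

    Value lookup returns the member itself; unknown strings raise ValueError.
    """
    return QuantizationType(dtype)


print(as_quantization_type("nf4") is QuantizationType.NF4)  # True
print(QuantizationType.INT8 == "int8")  # True
```

An unsupported string such as `"int4"` raises `ValueError` at lookup time rather than failing later inside a kernel.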
- eformer.ops.quantization._config.quantize(array: Array, config: eformer.ops.quantization._config.QuantizationConfig | None = None, dtype: eformer.ops.quantization._config.QuantizationType | str | None = None, block_size: int = 64, simulate: bool = False) → eformer.ops.quantization.implicit_array_1bit.Array1B | eformer.ops.quantization.implicit_array_8bit.Array8B | eformer.ops.quantization.implicit_array_nf4.ArrayNF4 | jax.jaxlib._jax.Array
Quantize an array using the specified configuration.
This is the unified quantization interface that dispatches to the appropriate quantization implementation based on the dtype.
- Parameters
array – Input array to quantize (typically float32/bfloat16)
config – QuantizationConfig object (if provided, overrides other args)
dtype – Quantization type (NF4, INT8, BINARY, TERNARY)
block_size – Block size for blockwise quantization
simulate – If True, use simulation mode (STE without bit packing)
- Returns
Quantized array as an ImplicitArray subclass (Array1B, Array8B, or ArrayNF4), or a regular jax.Array if simulate=True
Example
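>>> import jax.numpy as jnp
>>> from eformer.ops.quantization._config import QuantizationConfig, QuantizationType, quantize
>>> weights = jnp.ones((256, 256), dtype=jnp.float32)  # example float weights
>>>
>>> # Using a config object
>>> config = QuantizationConfig(dtype=QuantizationType.NF4, block_size=64)
>>> quantized = quantize(weights, config=config)
>>>
>>> # Direct parameters
>>> quantized = quantize(weights, dtype=QuantizationType.NF4, block_size=64)
>>>
>>> # Simulation mode (for QAT)
>>> quantized = quantize(weights, dtype=QuantizationType.NF4, simulate=True)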
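The dtype-based dispatch can be pictured with a minimal simulate-mode sketch in plain JAX: quantize each block, then immediately dequantize, returning an ordinary float array. `fake_quantize` and `NF4_LEVELS` are hypothetical names, not eformer API. The per-block scaling, the XNOR-Net-style binary rule, and the BitNet-b1.58-style ternary rule are assumptions chosen for illustration, and eformer's actual kernels (and its bit-packed ImplicitArray outputs) may work differently.

```python
import jax.numpy as jnp

# The 16 NF4 levels published with QLoRA / bitsandbytes (assumed, not
# confirmed, to match eformer's ArrayNF4 code book).
NF4_LEVELS = jnp.array([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
])


def fake_quantize(x, dtype="nf4", block_size=64):
    """Hypothetical simulate-mode dispatcher: quantize, then dequantize.

    Returns a plain float array shaped like ``x``. Assumes ``x.size`` is a
    multiple of ``block_size`` (no padding).
    """
    blocks = x.reshape(-1, block_size)
    absmax = jnp.max(jnp.abs(blocks), axis=1, keepdims=True)
    absmax = jnp.where(absmax == 0, 1.0, absmax)  # avoid division by zero
    absmean = jnp.mean(jnp.abs(blocks), axis=1, keepdims=True)
    safe_absmean = jnp.where(absmean == 0, 1.0, absmean)

    if dtype == "int8":
        # Symmetric absmax scaling onto the integer grid [-127, 127].
        scale = absmax / 127.0
        out = jnp.round(blocks / scale) * scale
    elif dtype == "nf4":
        # Normalize to [-1, 1], then snap to the nearest NF4 level.
        normed = blocks / absmax
        idx = jnp.argmin(jnp.abs(normed[..., None] - NF4_LEVELS), axis=-1)
        out = NF4_LEVELS[idx] * absmax
    elif dtype == "binary":
        # Sign times the block's mean magnitude.
        out = jnp.where(blocks >= 0, 1.0, -1.0) * absmean
    elif dtype == "ternary":
        # Round to {-1, 0, +1} relative to the block's mean magnitude.
        out = jnp.clip(jnp.round(blocks / safe_absmean), -1, 1) * absmean
    else:
        raise ValueError(f"unsupported dtype: {dtype!r}")
    return out.reshape(x.shape)


weights = jnp.linspace(-1.0, 1.0, 128)
for t in ("int8", "nf4", "binary", "ternary"):
    print(t, fake_quantize(weights, t).shape)
```

Returning dequantized floats keeps the result usable by any standard op, which is the trade-off simulation mode makes in exchange for giving up the memory savings of bit packing.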
See also
straight_through: Unified STE wrapper for all quantization types
QuantizationConfig: Configuration dataclass
QuantizationType: Enum of supported types
- eformer.ops.quantization._config.straight_through(array: Array, config: eformer.ops.quantization._config.QuantizationConfig | None = None, dtype: eformer.ops.quantization._config.QuantizationType | str | None = None, block_size: int = 64) → Array
Unified straight-through estimator for all quantization types.
This function quantizes in the forward pass but passes gradients straight through to the original float32 weights in the backward pass. Use it for quantization-aware training (QAT).
- Parameters
array – Input array to quantize (typically trainable weights)
config – QuantizationConfig object (if provided, overrides other args)
dtype – Quantization type (NF4, INT8, BINARY, TERNARY)
block_size – Block size for blockwise quantization
- Returns
Materialized quantized array with straight-through gradients
Example
>>> import jax
>>> import jax.numpy as jnp
>>> from eformer.ops.quantization._config import QuantizationType, straight_through
>>>
>>> # In a training loop
>>> @jax.jit
... def train_step(params, inputs, targets):
...     def loss_fn(params):
...         # Quantize weights with STE
...         quant_w = straight_through(params['weight'], dtype=QuantizationType.NF4)
...         preds = inputs @ quant_w
...         return jnp.mean((preds - targets) ** 2)
...     loss, grads = jax.value_and_grad(loss_fn)(params)
...     # grads flow to the float32 params, not to the quantized weights
...     return loss, grads
- Technical Details:
Forward: Uses quantized representation (memory efficient)
Backward: Gradients bypass quantization (grad_input = grad_output)
Always materializes to ensure compatibility with standard ops
Underlying float32 params are updated during optimization
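The forward/backward split described above is commonly implemented with the `stop_gradient` trick. Below is a minimal sketch of that pattern in JAX, assuming a simple per-tensor absmax INT8 quantizer; `fake_quant_int8` and `straight_through_sketch` are hypothetical names, and eformer's `straight_through` is documented as building on the `ste` decorator in jaximus rather than necessarily on this exact formulation.

```python
import jax
import jax.numpy as jnp


def fake_quant_int8(x):
    """Hypothetical quantizer: per-tensor absmax INT8 round trip."""
    scale = jnp.maximum(jnp.max(jnp.abs(x)), 1e-8) / 127.0
    return jnp.round(x / scale) * scale


def straight_through_sketch(x, quantizer=fake_quant_int8):
    # Forward value equals quantizer(x). In the backward pass, stop_gradient
    # hides the quantizer's derivative, so d(out)/dx is the identity and
    # grad_input == grad_output.
    return x + jax.lax.stop_gradient(quantizer(x) - x)


w = jnp.array([0.1, -0.5, 0.33, 0.9])
g = jax.grad(lambda w: jnp.sum(straight_through_sketch(w) * 2.0))(w)
print(g)  # [2. 2. 2. 2.]
```

Without the `stop_gradient` wrapper, the gradient through `jnp.round` would be zero almost everywhere, so the float32 weights would receive no useful update signal.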
See also
quantize: Unified quantization interface
ste: Low-level STE decorator in jaximus