eformer.jaximus._imus

Implicit Array System and JAX Primitive Interception.

This module provides a framework for creating and manipulating “implicit arrays”: custom array-like objects that defer expensive operations (such as materialization from quantized formats) until they are needed. It integrates with JAX by intercepting primitive operations and routing them to custom handlers.

Core Components:
  • ImplicitArray: Abstract base class for lazy/deferred array representations

  • ste (Straight-Through Estimator): Decorator for gradient pass-through in quantization

  • use_implicit/implicit: Decorator for enabling implicit array dispatch

  • register: Decorator for registering custom primitive handlers

  • _CustomTrace: JAX trace implementation for intercepting operations

The system uses JAX’s tracing infrastructure to intercept operations on ImplicitArray instances and dispatch them to registered handlers, allowing custom behavior while maintaining compatibility with JAX transformations (jit, grad, vmap, etc.).

Example

>>> import jax.numpy as jnp
>>> from eformer.jaximus import ImplicitArray, register, implicit
>>> from eformer.ops.quantization import ArrayNF4
>>>
>>> # Create a quantized weight array and a batch of inputs
>>> weight = jnp.ones((128, 64), dtype=jnp.float32)
>>> nf4_weight = ArrayNF4.quantize(weight, block_size=64)
>>> inputs = jnp.ones((8, 128), dtype=jnp.float32)
>>>
>>> # Use implicit dispatch to avoid premature materialization
>>> @implicit
... def linear(x, w):
...     return x @ w  # Uses custom dot_general handler for NF4
>>>
>>> output = linear(inputs, nf4_weight)  # NF4 kernel used, no materialization
class eformer.jaximus._imus.ImplicitArray(*, shape: Optional[Sequence[int]] = None, dtype: dtype = None)

Bases: _ArrayBase

Abstract base class for implicit (lazy/deferred) array representations.

ImplicitArray enables custom array-like types that defer expensive operations until materialization is required. This is particularly useful for quantized arrays (NF4, INT4, INT8) where you want to:
  1. Store data in compressed format to save memory

  2. Perform operations directly on compressed data when possible (via custom kernels)

  3. Only materialize to full precision when necessary

The ImplicitArray system integrates with JAX’s tracing infrastructure to intercept operations and dispatch to custom handlers registered via @register decorator.

shape

Tuple representing the logical shape of the array. Uses _AvalDescriptor for lazy initialization.

Type

Optional[Sequence[int]]

dtype

JAX dtype of the materialized array. Uses _AvalDescriptor for lazy initialization.

Type

numpy.dtype

Subclass Requirements:
  1. Must be a dataclass

  2. Must implement materialize() method

  3. Should register custom handlers for primitives via @register

  4. Non-array fields should use aux_field() to mark them as auxiliary

Example Subclass:
>>> from dataclasses import dataclass
>>> import jax
>>> from eformer.jaximus import ImplicitArray, aux_field, register
>>>
>>> @dataclass
... class ArrayNF4(ImplicitArray):
...     packed: jax.Array           # 4-bit packed data
...     absmax: jax.Array           # Scale factors
...     block_size: int = aux_field()  # Static metadata
...
...     def materialize(self):
...         # Dequantize back to float32
...         return dequantize_nf4(self.packed, self.absmax, self.block_size)
>>>
>>> # Register custom handler for matrix multiplication
>>> @register("dot_general")
... def nf4_matmul(lhs, rhs: ArrayNF4, **kwargs):
...     # Use optimized NF4 kernel instead of materializing
...     return nf4_kernel(lhs, rhs.packed, rhs.absmax)
Usage:
>>> import jax
>>> from eformer.jaximus import implicit
>>> # Create quantized weight
>>> weight = jax.random.normal(jax.random.PRNGKey(0), (128, 64))
>>> nf4_weight = ArrayNF4.quantize(weight, block_size=64)
>>>
>>> # Use with implicit dispatch
>>> @implicit
... def linear(x, w):
...     return x @ w  # Automatically uses nf4_matmul handler
>>>
>>> output = linear(inputs, nf4_weight)  # No materialization!
Technical Details:
  • Registered as a JAX pytree for compatibility with transformations

  • Integrates with JAX’s abstract value (aval) system for shape inference

  • Uses _AvalDescriptor for lazy shape/dtype discovery

  • Supports nested implicit arrays via tree_flatten_with_keys

  • Automatic materialization fallback when no custom handler exists
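The pytree registration and deferred materialization listed above can be illustrated without eformer. The following is a minimal sketch, assuming a hypothetical `ScaledInt8` class (not part of eformer) that keeps int8 values and a scale as pytree children and a target dtype as static auxiliary data. It uses the standard `jax.tree_util.register_pytree_node_class` hook rather than eformer's own registration logic.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclass
class ScaledInt8:
    """Hypothetical lazy array: int8 values plus one float scale."""

    values: jax.Array  # pytree child (traced by JAX)
    scale: jax.Array  # pytree child (traced by JAX)
    out_dtype: str = "float32"  # static metadata, carried as aux_data

    def materialize(self):
        # Expand to full precision only when explicitly requested.
        return self.values.astype(self.out_dtype) * self.scale

    def tree_flatten(self):
        # Children first, then hashable auxiliary data.
        return (self.values, self.scale), (self.out_dtype,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, *aux_data)


q = ScaledInt8(jnp.array([1, -2, 3], dtype=jnp.int8), jnp.float32(0.5))
leaves, treedef = jax.tree_util.tree_flatten(q)
restored = jax.tree_util.tree_unflatten(treedef, leaves)
dense = restored.materialize()
```

Only the two array fields appear as leaves; the dtype string travels in the tree definition, which is the same split that aux_field() is described as providing for ImplicitArray subclasses.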

See also

  • register: Decorator for registering custom primitive handlers

  • use_implicit/implicit: Enable implicit array dispatch

  • ste: Straight-through estimator for quantization-aware training

astype(new_dtype)
property aval
compute_dtype()
compute_shape()
classmethod default_handler(primitive, *args, params=None)
dtype: dtype = None
abstract materialize()
shape: Optional[Sequence[int]] = None
tree_flatten_with_keys()
classmethod tree_unflatten(aux_data, children)
class eformer.jaximus._imus.OrginArray

Bases: ABC

Abstract base class for array-like types in the implicit array system.

This class serves as a registry for types that should be treated as “original” array types. jax.Array is automatically registered.

The class is used in type checking and dispatch logic to determine whether a value is an array-like type that can participate in implicit array operations.

exception eformer.jaximus._imus.UninitializedAval

Bases: Exception

Exception raised when accessing uninitialized abstract value attributes.

This is raised when trying to access shape or dtype on an ImplicitArray before these values have been computed or set.

eformer.jaximus._imus.aux_field(metadata=None, **kwargs)

Create a dataclass field marked as auxiliary (non-pytree) data.

In ImplicitArray subclasses, fields can be either:
  1. Pytree children: Arrays and nested structures that should be traced by JAX

  2. Auxiliary data: Static metadata (ints, strings, dtypes) that does not participate in tracing

This function creates fields marked as auxiliary, which means they are:
  • Not included among the pytree children during flattening/unflattening

  • Not traced by JAX transformations (jit, grad, vmap)

  • Typically used for static configuration (block_size, dtype, mesh info)

Parameters
  • metadata – Optional existing metadata dict to extend

  • **kwargs – Additional arguments passed to dataclasses.field()

Returns

A dataclass field with auxiliary metadata set

Example

>>> from dataclasses import dataclass
>>> import jax
>>> import jax.numpy as jnp
>>> from eformer.jaximus import ImplicitArray, aux_field
>>>
>>> @dataclass
... class ArrayNF4(ImplicitArray):
...     # These are pytree children (traced by JAX)
...     packed: jax.Array
...     absmax: jax.Array
...
...     # These are auxiliary (static metadata)
...     block_size: int = aux_field()
...     dtype: jnp.dtype = aux_field(default=jnp.float32)
...     mesh_config: tuple | None = aux_field(default=None)
Technical Details:
  • Sets metadata["implicit_array_aux"] = True

  • Used by tree_flatten_with_keys and tree_unflatten

  • Auxiliary fields are passed as aux_data in pytree registration

  • Must use this for non-array fields to avoid JAX tracing errors
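The metadata mechanism described above can be sketched with the standard library alone. In this illustration `aux_field_sketch`, `split_fields`, and `Packed` are hypothetical names; only the metadata key is taken from the Technical Details, and the real helper may differ in how it merges arguments.

```python
from dataclasses import dataclass, field, fields

AUX_KEY = "implicit_array_aux"  # key named in the Technical Details above


def aux_field_sketch(metadata=None, **kwargs):
    """Hypothetical stand-in for aux_field: tag a field as static metadata."""
    merged = dict(metadata or {})
    merged[AUX_KEY] = True
    return field(metadata=merged, **kwargs)


@dataclass
class Packed:
    data: list  # would be a pytree child
    scales: list  # would be a pytree child
    block_size: int = aux_field_sketch(default=64)  # static metadata


def split_fields(obj):
    """Separate child fields from auxiliary fields by reading the metadata."""
    children, aux = {}, {}
    for f in fields(obj):
        target = aux if f.metadata.get(AUX_KEY, False) else children
        target[f.name] = getattr(obj, f.name)
    return children, aux


children, aux = split_fields(Packed(data=[1, 2], scales=[0.5]))
```

A flatten implementation could return the first dictionary's values as children and the second as aux_data, which is roughly the role the text assigns to tree_flatten_with_keys.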

See also

  • ImplicitArray.tree_flatten_with_keys: Uses aux_field metadata

  • _get_names_and_aux: Helper that reads this metadata

eformer.jaximus._imus.combine_leaf_predicate(base_fn, is_leaf)

Wrap a tree function to include ImplicitArray as a leaf type.

Creates a new function that combines the given is_leaf predicate with an additional predicate, allowing custom leaf detection.

Parameters
  • base_fn – Original tree function (e.g., tu.tree_map).

  • is_leaf – Predicate function to treat values as leaves.

Returns

Wrapped function that uses the combined leaf predicate.
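One way to picture the wrapping is a closure that ORs the caller's predicate with a fixed one before delegating to the base tree function. The sketch below assumes a hypothetical registered container `Lazy` and a helper `combine_leaf_predicate_sketch`; it is not the eformer implementation and relies only on the documented `is_leaf` argument of `jax.tree_util.tree_leaves`.

```python
import jax


class Lazy:
    """Hypothetical container registered as a pytree node."""

    def __init__(self, payload):
        self.payload = payload


jax.tree_util.register_pytree_node(
    Lazy,
    lambda x: ((x.payload,), None),  # flatten: children, aux_data
    lambda aux, children: Lazy(*children),  # unflatten
)


def is_lazy(x):
    return isinstance(x, Lazy)


def combine_leaf_predicate_sketch(base_fn, extra_is_leaf):
    def wrapped(*args, is_leaf=None, **kwargs):
        if is_leaf is None:
            combined = extra_is_leaf
        else:
            combined = lambda x: is_leaf(x) or extra_is_leaf(x)
        return base_fn(*args, is_leaf=combined, **kwargs)

    return wrapped


tree = {"a": Lazy(1), "b": (2, 3)}
plain_leaves = jax.tree_util.tree_leaves(tree)  # Lazy is opened up
tree_leaves_lazy = combine_leaf_predicate_sketch(jax.tree_util.tree_leaves, is_lazy)
lazy_leaves = tree_leaves_lazy(tree)  # Lazy is kept whole
```

The same wrapper shape would apply to tree_map, tree_flatten, and the other tree functions that accept `is_leaf`.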

eformer.jaximus._imus.default_handler(primitive, *args, **params)

Default handler that executes a JAX primitive normally.

Parameters
  • primitive – JAX primitive to execute.

  • *args – Arguments to the primitive.

  • **params – Parameters for the primitive.

Returns

Result of executing the primitive with the given arguments.
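A plausible reading of "executes a JAX primitive normally" is a direct call to the primitive's `bind` method. The sketch below assumes exactly that; `default_handler_sketch` is a hypothetical name, and the actual function may do more.

```python
import jax
import jax.numpy as jnp


def default_handler_sketch(primitive, *args, **params):
    # Assumption: normal execution means binding the primitive directly.
    return primitive.bind(*args, **params)


x = jnp.arange(3, dtype=jnp.float32)
total = default_handler_sketch(jax.lax.add_p, x, x)
```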

eformer.jaximus._imus.flatten_one_implicit_layer(tree)

Flatten one layer of nested ImplicitArrays in a tree.

For nested ImplicitArray structures, this function flattens just one level, treating nested ImplicitArrays as leaves.

Parameters

tree – Pytree potentially containing nested ImplicitArrays.

Returns

Tuple of (leaves, structure) where leaves are flattened one level and structure is the pytree structure.

eformer.jaximus._imus.implicit(fn)

Enable implicit array dispatch for a function.

This decorator/wrapper sets up a custom JAX trace that intercepts operations on ImplicitArray instances and routes them to registered handlers. This allows transparent use of quantized or lazy arrays without manual materialization.

Parameters

fn – Function to wrap with implicit array support.

Returns

Wrapped function that handles ImplicitArray instances transparently.

Example

>>> @use_implicit
... def matmul(x, w):
...     return x @ w  # Automatically uses custom dot_general for NF4
>>>
>>> # Or wrap a function inline:
>>> output = implicit(lambda x, w: x @ w)(inputs, nf4_weights)  # Custom handler dispatched
Technical Details:
  • Creates a custom JAX trace (_CustomTrace) that intercepts primitive operations

  • Wraps ImplicitArray instances in _CustomTracer for operation interception

  • Dispatches to registered handlers via the @register decorator

  • Falls back to materialization if no handler is registered

  • Maintains compatibility with JAX transformations (jit, grad, vmap)

See also

  • register: Decorator for registering custom primitive handlers

  • ImplicitArray: Base class for implicit array implementations

  • _CustomTrace: The trace implementation that handles dispatch
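The wrap, intercept, and fall-back flow can be mimicked in a toy that uses Python operator overloading in place of JAX tracing. Everything here (`Scaled`, `_Proxy`, `HANDLERS`, `implicit_sketch`) is hypothetical and only mirrors the control flow described above, not the _CustomTrace machinery.

```python
import functools

import numpy as np

HANDLERS = {}  # operation name -> handler(lazy_lhs, rhs)


class Scaled:
    """Hypothetical lazy array representing `base * scale`."""

    def __init__(self, base, scale):
        self.base, self.scale = base, scale

    def materialize(self):
        return self.base * self.scale


class _Proxy:
    """Stands in for a Scaled argument so `@` can be intercepted."""

    def __init__(self, lazy):
        self.lazy = lazy

    def __matmul__(self, rhs):
        handler = HANDLERS.get("matmul")
        if handler is not None:
            return handler(self.lazy, rhs)  # custom path, no materialization
        return self.lazy.materialize() @ rhs  # fallback path


def implicit_sketch(fn):
    @functools.wraps(fn)
    def wrapped(*args):
        # Wrap lazy arguments; leave everything else untouched.
        return fn(*[_Proxy(a) if isinstance(a, Scaled) else a for a in args])

    return wrapped


@implicit_sketch
def apply(w, x):
    return w @ x


w = Scaled(np.eye(2), 3.0)
x = np.array([1.0, 2.0])
fallback = apply(w, x)  # no handler yet: materialize, then multiply
HANDLERS["matmul"] = lambda lazy, rhs: (lazy.base @ rhs) * lazy.scale
custom = apply(w, x)  # handler works on the unmaterialized parts
```

Both paths give the same numbers; the difference is that the handler never builds the scaled matrix.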

eformer.jaximus._imus.implicit_depth(tree)

Calculate the maximum nesting depth of ImplicitArrays in a tree.

Parameters

tree – Pytree potentially containing nested ImplicitArrays.

Returns

Integer depth of nesting. 0 means no ImplicitArrays, 1 means ImplicitArrays with no nesting, etc.
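The depth convention (0, 1, 2, ...) can be illustrated with a small recursive function over plain Python containers. `Lazy` and `depth_sketch` are hypothetical; the real function works on JAX pytrees and ImplicitArray instances.

```python
class Lazy:
    """Hypothetical lazy container whose fields may hold further Lazy values."""

    def __init__(self, *children):
        self.children = children


def depth_sketch(tree):
    if isinstance(tree, Lazy):
        # One level for this container plus the deepest level below it.
        return 1 + max((depth_sketch(c) for c in tree.children), default=0)
    if isinstance(tree, dict):
        return max((depth_sketch(v) for v in tree.values()), default=0)
    if isinstance(tree, (list, tuple)):
        return max((depth_sketch(v) for v in tree), default=0)
    return 0  # ordinary leaf


no_lazy = depth_sketch({"w": 1.0})
flat = depth_sketch([Lazy(1.0)])
nested = depth_sketch({"w": Lazy(Lazy(2.0), 3.0)})
```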

eformer.jaximus._imus.is_array(element: Any) -> bool

Check if an element is an array type (NumPy or JAX).

Parameters

element – Value to check.

Returns

True if element is a NumPy array, NumPy scalar, or JAX array.
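The three accepted categories map naturally onto an `isinstance` check. This is a sketch of that reading, with `is_array_sketch` as a hypothetical name; the actual implementation may test additional types.

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_array_sketch(element):
    # NumPy array, NumPy scalar, or JAX array.
    return isinstance(element, (np.ndarray, np.generic, jax.Array))


checks = {
    "numpy array": is_array_sketch(np.zeros(2)),
    "numpy scalar": is_array_sketch(np.float32(1.0)),
    "jax array": is_array_sketch(jnp.ones(2)),
    "python list": is_array_sketch([1, 2]),
}
```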

eformer.jaximus._imus.leaf_predicate(x)

Predicate for identifying ImplicitArray leaves in tree operations.

Parameters

x – Value to check.

Returns

True if x is an ImplicitArray instance.

eformer.jaximus._imus.materialize_handler(primitive, *vals, params)

Handler that materializes all ImplicitArrays before executing primitive.

This is the fallback handler used when no custom handler is registered for a primitive operation involving ImplicitArrays.

Parameters
  • primitive – JAX primitive to execute.

  • *vals – Values that may include ImplicitArrays.

  • params – Parameters for the primitive.

Returns

Result of executing the primitive after materializing all values.
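The fallback can be pictured as "materialize every lazy operand, then bind". The sketch below assumes a hypothetical lazy type `Doubled` and a helper `materialize_handler_sketch` with the same calling convention as the signature above; it is not the eformer code.

```python
import jax
import jax.numpy as jnp


class Doubled:
    """Hypothetical lazy value representing `base * 2`."""

    def __init__(self, base):
        self.base = base

    def materialize(self):
        return self.base * 2


def materialize_handler_sketch(primitive, *vals, params):
    # Replace each lazy operand with its dense form, then run the primitive.
    dense = [v.materialize() if isinstance(v, Doubled) else v for v in vals]
    return primitive.bind(*dense, **params)


out = materialize_handler_sketch(
    jax.lax.mul_p,
    Doubled(jnp.array([1.0, 2.0])),
    jnp.array([3.0, 3.0]),
    params={},
)
```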

eformer.jaximus._imus.materialize_nested(implicit_arr, full=False)

Recursively materialize nested ImplicitArray structures.

Parameters
  • implicit_arr – ImplicitArray or nested ImplicitArray to materialize.

  • full – If True, recursively materialize until reaching a regular array. If False, only materialize one level.

Returns

Materialized array. If materialization fails, returns an array of ones with the expected shape and dtype.
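The effect of the `full` flag can be shown with a loop that peels off one layer per iteration. `Boxed` and `materialize_nested_sketch` are hypothetical, and the sketch deliberately omits the array-of-ones fallback mentioned above.

```python
import numpy as np


class Boxed:
    """Hypothetical lazy wrapper whose materialize() may return another Boxed."""

    def __init__(self, inner):
        self.inner = inner

    def materialize(self):
        return self.inner


def materialize_nested_sketch(value, full=False):
    while isinstance(value, Boxed):
        value = value.materialize()
        if not full:
            break  # stop after a single layer
    return value


nested = Boxed(Boxed(np.arange(3)))
one_level = materialize_nested_sketch(nested)  # still a Boxed
dense = materialize_nested_sketch(nested, full=True)  # a NumPy array
```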

eformer.jaximus._imus.register(primitive: jax._src.core.Primitive | str, *, precedence: int = 0) -> Callable[[CT], CT]

Register a custom handler for a JAX primitive operation.

This decorator allows you to define custom behavior for JAX primitives (like dot_general, add, mul, etc.) when operating on ImplicitArray instances. Handlers are dispatched via multiple dispatch based on argument types.

Parameters
  • primitive – JAX primitive to register handler for. Can be:
      - A core.Primitive object (e.g., jax.lax.dot_general_p)
      - A string name (e.g., “dot_general”), automatically resolved to a primitive

  • precedence – Handler precedence for multiple dispatch (higher = higher priority). Default is 0.

Returns

Decorator that registers the handler function.

Example - Basic Usage:
>>> import jax
>>> from eformer.jaximus import register
>>> from jax.extend.core import Primitive
>>>
>>> @register("dot_general")
... def nf4_matmul(lhs: jax.Array, rhs: ArrayNF4, **kwargs):
...     # Custom matmul for dense @ NF4
...     return nf4_kernel(lhs, rhs.packed, rhs.absmax)
Example - Multiple Dispatch:
>>> @register("dot_general")
... def nf4_lhs_matmul(lhs: ArrayNF4, rhs: jax.Array, **kwargs):
...     # Handle NF4 @ dense (different from above)
...     return nf4_transpose_kernel(rhs, lhs.packed, lhs.absmax)
>>>
>>> # Both handlers registered; dispatched based on argument types
>>> dense @ nf4_weight  # Uses first handler
>>> nf4_weight @ dense  # Uses second handler
Example - Precedence:
>>> @register("add", precedence=10)
... def high_priority_add(x: ArrayNF4, y: ArrayNF4, **kwargs):
...     # This handler takes precedence over lower-priority ones
...     return optimized_nf4_add(x, y)
Technical Details:
  • Uses plum for multiple dispatch based on type signatures

  • Handlers are called within _CustomTrace.process_primitive

  • If no handler matches, falls back to materialization

  • Supports string primitive names (auto-resolved from jax.lax)

  • Multiple handlers per primitive allowed (dispatched by type)

Handler Signature:

Handlers receive:
  • primitive: The JAX primitive being executed (sometimes)

  • *args: Operands (ImplicitArray or regular arrays)

  • **kwargs: Primitive parameters (dimension_numbers, precision, etc.)

And should return:
  • Result array (can be ImplicitArray or regular array)

Common Primitives to Register:
  • “dot_general”: Matrix multiplication (x @ y)

  • “add”, “sub”, “mul”, “div”: Arithmetic operations

  • “reshape”, “transpose”: Shape operations

  • “reduce”: Reductions (sum, max, etc.)

  • “convert_element_type”: Type conversions
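Type-based handler lookup with a precedence tie-break can be sketched in plain Python. The names `register_sketch`, `dispatch`, and `Packed` are hypothetical, and the matching rule here (first check `isinstance` on each positional argument, then take the highest precedence) is a simplification; plum's own resolution rules are richer.

```python
import inspect
from collections import defaultdict

_HANDLERS = defaultdict(list)  # primitive name -> [(precedence, types, fn)]


def register_sketch(name, *, precedence=0):
    def decorator(fn):
        # Collect annotations of positional parameters; unannotated means "any".
        types = tuple(
            object if p.annotation is inspect.Parameter.empty else p.annotation
            for p in inspect.signature(fn).parameters.values()
            if p.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD
        )
        _HANDLERS[name].append((precedence, types, fn))
        return fn

    return decorator


def dispatch(name, *args, **params):
    matches = [
        (prec, fn)
        for prec, types, fn in _HANDLERS[name]
        if len(types) == len(args)
        and all(isinstance(a, t) for a, t in zip(args, types))
    ]
    if not matches:
        # The documented system would fall back to materialization here.
        raise LookupError(f"no handler for {name}")
    _, best = max(matches, key=lambda m: m[0])
    return best(*args, **params)


class Packed:
    """Hypothetical compressed operand."""


@register_sketch("add")
def dense_add(x: float, y: float, **kwargs):
    return "dense"


@register_sketch("add")
def packed_add(x: Packed, y: Packed, **kwargs):
    return "packed"


@register_sketch("add", precedence=10)
def fast_packed_add(x: Packed, y: Packed, **kwargs):
    return "fast-packed"


dense_result = dispatch("add", 1.0, 2.0)
packed_result = dispatch("add", Packed(), Packed())
```

Two handlers match the `Packed` call, and the one registered with `precedence=10` wins, which is the behavior the Precedence example describes.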

See also

  • ImplicitArray: Base class for custom array types

  • use_implicit/implicit: Enable handler dispatch

  • _CustomTrace.process_primitive: Dispatch implementation

eformer.jaximus._imus.ste(func)

Straight-Through Estimator (STE) decorator for quantization-aware training.

This decorator enables gradient flow through non-differentiable quantization operations by using a custom VJP (vector-Jacobian product) that passes gradients straight through to the input, ignoring the quantization step in the backward pass.

The STE is critical for training with quantized weights:
  • Forward pass: Uses the quantized representation (e.g., NF4, INT4)

  • Backward pass: Gradients flow to the full-precision master weights

This allows training with quantization awareness while maintaining gradient-based optimization of the underlying float32 parameters.

Parameters

func – A function that performs quantization or other non-differentiable operations. Signature: func(x: Array, *args, **kwargs) -> Array | ImplicitArray

Returns

Wrapped function with straight-through gradient behavior.

Example

>>> @ste
... def quantize_nf4(weights, block_size=64):
...     return ArrayNF4.quantize(weights, block_size)
>>>
>>> # Forward: uses quantized weights
>>> # Backward: gradients flow to original float32 weights
>>> quantized = quantize_nf4(fp32_weights)
Technical Details:
  • Uses jax.custom_vjp to define custom backward pass behavior

  • Automatically materializes ImplicitArray cotangents to ensure valid gradients

  • Supports arbitrary positional and keyword arguments

  • Returns None for cotangents of non-differentiable arguments

Note

The STE is a biased gradient estimator - it assumes the gradient of the quantization function is the identity. This works well in practice for quantization-aware training despite the theoretical bias.
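The identity-gradient assumption can be seen in a standalone sketch built directly on `jax.custom_vjp`. Here rounding stands in for quantization, and `round_ste` is a hypothetical function written for illustration; it does not use eformer's `ste` decorator or handle ImplicitArray outputs.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def round_ste(x):
    # Forward pass: a non-differentiable quantization-like step.
    return jnp.round(x)


def _round_fwd(x):
    return jnp.round(x), None  # no residuals are needed


def _round_bwd(_, g):
    return (g,)  # pass the cotangent straight through


round_ste.defvjp(_round_fwd, _round_bwd)

x = jnp.array([0.2, 1.7, -0.6])
forward = round_ste(x)
grads = jax.grad(lambda v: jnp.sum(round_ste(v) * 3.0))(x)
```

The forward values are rounded, while the gradient equals the upstream factor of 3 for every element, as if the rounding step were the identity.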

eformer.jaximus._imus.tree_flatten_with_implicit(tree: Any, is_leaf: collections.abc.Callable[[Any], bool] | None = None) -> tuple[list[Any], jaxlib._jax.pytree.PyTreeDef]

Like jax.tree_util.tree_flatten but treats ImplicitArray as leaves.

eformer.jaximus._imus.tree_flatten_with_path_with_implicit(tree: Any, is_leaf: collections.abc.Callable[[...], bool] | None = None, is_leaf_takes_path: bool = False) -> tuple[list[tuple[tuple[KeyEntry, ...], Any]], jaxlib._jax.pytree.PyTreeDef]

Like jax.tree_util.tree_flatten_with_path but treats ImplicitArray as leaves.

eformer.jaximus._imus.tree_leaves_with_implicit(tree: Any, is_leaf: collections.abc.Callable[[Any], bool] | None = None) -> list[Any]

Like jax.tree_util.tree_leaves but treats ImplicitArray as leaves.

eformer.jaximus._imus.tree_map_with_implicit(f: Callable[[...], Any], tree: Any, *rest: Any, is_leaf: collections.abc.Callable[[Any], bool] | None = None) -> Any

Like jax.tree_util.tree_map but treats ImplicitArray as leaves.

eformer.jaximus._imus.tree_map_with_path_with_implicit(f: Callable[[...], Any], tree: Any, *rest: Any, is_leaf: collections.abc.Callable[[...], bool] | None = None, is_leaf_takes_path: bool = False) -> Any

Like jax.tree_util.tree_map_with_path but treats ImplicitArray as leaves.

eformer.jaximus._imus.tree_structure_with_implicit(tree: Any, is_leaf: None | collections.abc.Callable[[Any], bool] = None) -> PyTreeDef

Like jax.tree_util.tree_structure but treats ImplicitArray as leaves.

eformer.jaximus._imus.use_implicit(fn)

Enable implicit array dispatch for a function.

This decorator/wrapper sets up a custom JAX trace that intercepts operations on ImplicitArray instances and routes them to registered handlers. This allows transparent use of quantized or lazy arrays without manual materialization.

Parameters

fn – Function to wrap with implicit array support.

Returns

Wrapped function that handles ImplicitArray instances transparently.

Example

>>> @use_implicit
... def matmul(x, w):
...     return x @ w  # Automatically uses custom dot_general for NF4
>>>
>>> # Or wrap a function inline:
>>> output = use_implicit(lambda x, w: x @ w)(inputs, nf4_weights)  # Custom handler dispatched
Technical Details:
  • Creates a custom JAX trace (_CustomTrace) that intercepts primitive operations

  • Wraps ImplicitArray instances in _CustomTracer for operation interception

  • Dispatches to registered handlers via the @register decorator

  • Falls back to materialization if no handler is registered

  • Maintains compatibility with JAX transformations (jit, grad, vmap)

See also

  • register: Decorator for registering custom primitive handlers

  • ImplicitArray: Base class for implicit array implementations

  • _CustomTrace: The trace implementation that handles dispatch