eformer.jaximus._imus
Implicit Array System and JAX Primitive Interception.
This module provides a framework for creating and manipulating “implicit arrays” - custom array-like objects that defer expensive operations (like materialization from quantized formats) until absolutely necessary. It enables transparent integration with JAX by intercepting primitive operations and routing them to custom handlers.
- Core Components:
ImplicitArray: Abstract base class for lazy/deferred array representations
ste (Straight-Through Estimator): Decorator for gradient pass-through in quantization
use_implicit/implicit: Decorator (also shown in the examples as a context manager) for enabling implicit array dispatch
register: Decorator for registering custom primitive handlers
_CustomTrace: JAX trace implementation for intercepting operations
The system uses JAX’s tracing infrastructure to intercept operations on ImplicitArray instances and dispatch them to registered handlers, allowing custom behavior while maintaining compatibility with JAX transformations (jit, grad, vmap, etc.).
Example
>>> import jax.numpy as jnp
>>> from eformer.jaximus import ImplicitArray, register, implicit
>>> from eformer.ops.quantization import ArrayNF4
>>>
>>> # Create a quantized weight array
>>> weight = jnp.ones((128, 64), dtype=jnp.float32)
>>> nf4_weight = ArrayNF4.quantize(weight, block_size=64)
>>>
>>> # Use implicit dispatch to avoid premature materialization
>>> @implicit
... def linear(x, w):
...     return x @ w  # Uses custom dot_general handler for NF4
>>>
>>> inputs = jnp.ones((8, 128), dtype=jnp.float32)
>>> output = linear(inputs, nf4_weight)  # NF4 kernel used, no materialization
- class eformer.jaximus._imus.ImplicitArray(*, shape: Optional[Sequence[int]] = None, dtype: dtype = None)
Bases: _ArrayBase
Abstract base class for implicit (lazy/deferred) array representations.
ImplicitArray enables custom array-like types that defer expensive operations until materialization is required. This is particularly useful for quantized arrays (NF4, INT4, INT8), where you want to:
1. Store data in a compressed format to save memory
2. Perform operations directly on the compressed data when possible (via custom kernels)
3. Materialize to full precision only when necessary
The ImplicitArray system integrates with JAX’s tracing infrastructure to intercept operations and dispatch them to custom handlers registered via the @register decorator.
- shape
Tuple representing the logical shape of the array. Uses _AvalDescriptor for lazy initialization.
- Type
Optional[Sequence[int]]
- dtype
JAX dtype of the materialized array. Uses _AvalDescriptor for lazy initialization.
- Type
- Subclass Requirements:
Must be a dataclass
Must implement materialize() method
Should register custom handlers for primitives via @register
Non-array fields should use aux_field() to mark them as auxiliary
- Example Subclass:
>>> from dataclasses import dataclass
>>> import jax
>>> from eformer.jaximus import ImplicitArray, aux_field, register
>>>
>>> @dataclass
... class ArrayNF4(ImplicitArray):
...     packed: jax.Array  # 4-bit packed data
...     absmax: jax.Array  # Scale factors
...     block_size: int = aux_field()  # Static metadata
...
...     def materialize(self):
...         # Dequantize back to float32
...         return dequantize_nf4(self.packed, self.absmax, self.block_size)
>>>
>>> # Register custom handler for matrix multiplication
>>> @register("dot_general")
... def nf4_matmul(lhs, rhs: ArrayNF4, **kwargs):
...     # Use optimized NF4 kernel instead of materializing
...     return nf4_kernel(lhs, rhs.packed, rhs.absmax)
- Usage:
>>> # Create quantized weight
>>> weight = jax.random.normal(jax.random.PRNGKey(0), (128, 64))
>>> nf4_weight = ArrayNF4.quantize(weight, block_size=64)
>>>
>>> # Use with implicit dispatch
>>> @implicit
... def linear(x, w):
...     return x @ w  # Automatically uses nf4_matmul handler
>>>
>>> output = linear(inputs, nf4_weight)  # No materialization!
- Technical Details:
Registered as a JAX pytree for compatibility with transformations
Integrates with JAX’s abstract value (aval) system for shape inference
Uses _AvalDescriptor for lazy shape/dtype discovery
Supports nested implicit arrays via tree_flatten_with_keys
Automatic materialization fallback when no custom handler exists
See also
register: Decorator for registering custom primitive handlers
use_implicit/implicit: Enable implicit array dispatch
ste: Straight-through estimator for quantization-aware training
- property aval
- shape: Optional[Sequence[int]] = None
- class eformer.jaximus._imus.OrginArray
Bases: ABC
Abstract base class for array-like types in the implicit array system.
This class serves as a registry for types that should be treated as “original” array types. jax.Array is automatically registered.
The class is used in type checking and dispatch logic to determine whether a value is an array-like type that can participate in implicit array operations.
- exception eformer.jaximus._imus.UninitializedAval
Bases: Exception
Exception raised when accessing uninitialized abstract value attributes.
This is raised when trying to access shape or dtype on an ImplicitArray before these values have been computed or set.
- eformer.jaximus._imus.aux_field(metadata=None, **kwargs)
Create a dataclass field marked as auxiliary (non-pytree) data.
In ImplicitArray subclasses, fields can be either:
1. Pytree children: Arrays and nested structures that should be traced by JAX
2. Auxiliary data: Static metadata (ints, strings, dtypes) that does not participate in tracing
This function creates fields marked as auxiliary, which means they are:
- Not included as children in pytree flattening/unflattening
- Not traced by JAX transformations (jit, grad, vmap)
- Typically used for static configuration (block_size, dtype, mesh info)
- Parameters
metadata – Optional existing metadata dict to extend
**kwargs – Additional arguments passed to dataclasses.field()
- Returns
A dataclass field with auxiliary metadata set
Example
>>> from dataclasses import dataclass
>>> import jax
>>> import jax.numpy as jnp
>>> from eformer.jaximus import ImplicitArray, aux_field
>>>
>>> @dataclass
... class ArrayNF4(ImplicitArray):
...     # These are pytree children (traced by JAX)
...     packed: jax.Array
...     absmax: jax.Array
...
...     # These are auxiliary (static metadata)
...     block_size: int = aux_field()
...     dtype: jnp.dtype = aux_field(default=jnp.float32)
...     mesh_config: tuple | None = aux_field(default=None)
- Technical Details:
Sets metadata["implicit_array_aux"] = True
Used by tree_flatten_with_keys and tree_unflatten
Auxiliary fields are passed as aux_data in pytree registration
Must use this for non-array fields to avoid JAX tracing errors
See also
ImplicitArray.tree_flatten_with_keys: Uses aux_field metadata
_get_names_and_aux: Helper that reads this metadata
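The metadata-tagging mechanism described above can be illustrated with the standard library alone. The following is a minimal sketch, not eformer's implementation: `aux_field_sketch`, `split_fields`, and `ToyQuantized` are hypothetical names, and plain lists stand in for JAX arrays. Only the metadata key `"implicit_array_aux"` is taken from the Technical Details.

```python
import dataclasses
from dataclasses import dataclass

AUX_KEY = "implicit_array_aux"  # metadata key named in the Technical Details


def aux_field_sketch(metadata=None, **kwargs):
    """Hypothetical stand-in for aux_field: tag a dataclass field as auxiliary."""
    metadata = dict(metadata) if metadata else {}
    metadata[AUX_KEY] = True
    return dataclasses.field(metadata=metadata, **kwargs)


def split_fields(obj):
    """Separate field names into pytree children and auxiliary (static) data."""
    children, aux = [], []
    for f in dataclasses.fields(obj):
        (aux if f.metadata.get(AUX_KEY, False) else children).append(f.name)
    return children, aux


@dataclass
class ToyQuantized:
    packed: list  # would be a jax.Array in practice
    absmax: list  # would be a jax.Array in practice
    block_size: int = aux_field_sketch(default=64)  # static metadata


children, aux = split_fields(ToyQuantized(packed=[1, 2], absmax=[0.5]))
print(children, aux)
```

A pytree flatten function could then return the `children` values as leaves and the `aux` values as aux_data, which is the split the Technical Details describe.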
- eformer.jaximus._imus.combine_leaf_predicate(base_fn, is_leaf)
Wrap a tree function to include ImplicitArray as a leaf type.
Creates a new function that combines the given is_leaf predicate with an additional predicate, allowing custom leaf detection.
- Parameters
base_fn – Original tree function (e.g., tu.tree_map).
is_leaf – Predicate function to treat values as leaves.
- Returns
Wrapped function that uses the combined leaf predicate.
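One plausible reading of this wrapper is sketched below in plain Python. It assumes that the predicate given at wrap time is OR-ed with any `is_leaf` passed at call time; that behaviour is an assumption, and `combine_leaf_predicate_sketch` and `toy_leaves` are hypothetical names. Tuples play the role that ImplicitArray plays in the library.

```python
def combine_leaf_predicate_sketch(base_fn, extra_is_leaf):
    """Hypothetical re-implementation: OR a fixed predicate into `is_leaf`."""

    def wrapped(*args, is_leaf=None, **kwargs):
        if is_leaf is None:
            combined = extra_is_leaf
        else:
            def combined(x):
                return is_leaf(x) or extra_is_leaf(x)
        return base_fn(*args, is_leaf=combined, **kwargs)

    return wrapped


def toy_leaves(tree, is_leaf=None):
    """Toy stand-in for a tree function: flattens nested lists and tuples."""
    if is_leaf is not None and is_leaf(tree):
        return [tree]
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in toy_leaves(child, is_leaf)]
    return [tree]


# Treat tuples as opaque leaves instead of descending into them.
leaves_keep_tuples = combine_leaf_predicate_sketch(
    toy_leaves, lambda x: isinstance(x, tuple)
)
result = leaves_keep_tuples([1, (2, 3), [4, (5,)]])
print(result)
```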
- eformer.jaximus._imus.default_handler(primitive, *args, **params)
Default handler that executes a JAX primitive normally.
- Parameters
primitive – JAX primitive to execute.
*args – Arguments to the primitive.
**params – Parameters for the primitive.
- Returns
Result of executing the primitive with the given arguments.
- eformer.jaximus._imus.flatten_one_implicit_layer(tree)
Flatten one layer of nested ImplicitArrays in a tree.
For nested ImplicitArray structures, this function flattens just one level, treating nested ImplicitArrays as leaves.
- Parameters
tree – Pytree potentially containing nested ImplicitArrays.
- Returns
Tuple of (leaves, structure) where leaves are flattened one level and structure is the pytree structure.
- eformer.jaximus._imus.implicit(fn)
Enable implicit array dispatch for a function.
This decorator/wrapper sets up a custom JAX trace that intercepts operations on ImplicitArray instances and routes them to registered handlers. This allows transparent use of quantized or lazy arrays without manual materialization.
- Parameters
fn – Function to wrap with implicit array support.
- Returns
Wrapped function that handles ImplicitArray instances transparently.
Example
>>> @use_implicit
... def matmul(x, w):
...     return x @ w  # Automatically uses custom dot_general for NF4
>>>
>>> # Or use as context:
>>> with implicit:
...     output = inputs @ nf4_weights  # Custom handler dispatched
- Technical Details:
Creates a custom JAX trace (_CustomTrace) that intercepts primitive operations
Wraps ImplicitArray instances in _CustomTracer for operation interception
Dispatches to registered handlers via the @register decorator
Falls back to materialization if no handler is registered
Maintains compatibility with JAX transformations (jit, grad, vmap)
See also
register: Decorator for registering custom primitive handlers
ImplicitArray: Base class for implicit array implementations
_CustomTrace: The trace implementation that handles dispatch
- eformer.jaximus._imus.implicit_depth(tree)
Calculate the maximum nesting depth of ImplicitArrays in a tree.
- Parameters
tree – Pytree potentially containing nested ImplicitArrays.
- Returns
Integer depth of nesting. 0 means no ImplicitArrays, 1 means ImplicitArrays with no nesting, etc.
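The depth convention in the Returns entry can be made concrete with a small plain-Python sketch. `ToyImplicit` and `implicit_depth_sketch` are hypothetical stand-ins that only illustrate the counting rule on lists, tuples, and dicts; the library function operates on JAX pytrees.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class ToyImplicit:
    """Toy stand-in for an ImplicitArray that wraps another value."""
    inner: Any


def implicit_depth_sketch(tree):
    """Maximum nesting depth of ToyImplicit objects inside lists/tuples/dicts."""
    if isinstance(tree, ToyImplicit):
        return 1 + implicit_depth_sketch(tree.inner)
    if isinstance(tree, dict):
        tree = list(tree.values())
    if isinstance(tree, (list, tuple)):
        return max((implicit_depth_sketch(t) for t in tree), default=0)
    return 0


depth_plain = implicit_depth_sketch([1.0, 2.0])
depth_one = implicit_depth_sketch({"w": ToyImplicit(3.0)})
depth_two = implicit_depth_sketch([ToyImplicit(ToyImplicit(3.0)), ToyImplicit(1.0)])
print(depth_plain, depth_one, depth_two)
```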
- eformer.jaximus._imus.is_array(element: Any) -> bool
Check if an element is an array type (NumPy or JAX).
- Parameters
element – Value to check.
- Returns
True if element is a NumPy array, NumPy scalar, or JAX array.
- eformer.jaximus._imus.leaf_predicate(x)
Predicate for identifying ImplicitArray leaves in tree operations.
- Parameters
x – Value to check.
- Returns
True if x is an ImplicitArray instance.
- eformer.jaximus._imus.materialize_handler(primitive, *vals, params)
Handler that materializes all ImplicitArrays before executing primitive.
This is the fallback handler used when no custom handler is registered for a primitive operation involving ImplicitArrays.
- Parameters
primitive – JAX primitive to execute.
*vals – Values that may include ImplicitArrays.
params – Parameters for the primitive.
- Returns
Result of executing the primitive after materializing all values.
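The fallback behaviour can be pictured with a plain-Python sketch that mirrors the documented signature. `ToyLazy`, `toy_add`, and `materialize_handler_sketch` are hypothetical names; in the library the primitive would be a JAX primitive rather than an ordinary function.

```python
class ToyLazy:
    """Toy stand-in for an ImplicitArray: stores an integer code and a scale."""

    def __init__(self, code, scale):
        self.code, self.scale = code, scale

    def materialize(self):
        return self.code * self.scale


def materialize_handler_sketch(primitive, *vals, params):
    """Materialize every lazy operand, then run the primitive on plain values."""
    dense = [v.materialize() if isinstance(v, ToyLazy) else v for v in vals]
    return primitive(*dense, **params)


def toy_add(x, y, *, offset=0.0):
    """Toy stand-in for a primitive; a real one would be bound through JAX."""
    return x + y + offset


out = materialize_handler_sketch(
    toy_add, ToyLazy(4, 0.5), 1.0, params={"offset": 0.25}
)
print(out)
```

The cost of this path is that the compressed operand is expanded first, which is what registering a custom handler is meant to avoid.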
- eformer.jaximus._imus.materialize_nested(implicit_arr, full=False)
Recursively materialize nested ImplicitArray structures.
- Parameters
implicit_arr – ImplicitArray or nested ImplicitArray to materialize.
full – If True, recursively materialize until reaching a regular array. If False, only materialize one level.
- Returns
Materialized array. If materialization fails, returns an array of ones with the expected shape and dtype.
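The effect of the `full` flag can be sketched in plain Python as follows. `ToyLazy` and `materialize_nested_sketch` are hypothetical names, and the sketch deliberately omits the array-of-ones fallback mentioned in the Returns entry.

```python
class ToyLazy:
    """Toy stand-in for an ImplicitArray whose payload may itself be lazy."""

    def __init__(self, inner):
        self.inner = inner

    def materialize(self):
        return self.inner


def materialize_nested_sketch(value, full=False):
    """Peel one lazy layer, or keep peeling until a plain value when full=True."""
    value = value.materialize()
    while full and isinstance(value, ToyLazy):
        value = value.materialize()
    return value


nested = ToyLazy(ToyLazy([1.0, 2.0]))
one_level = materialize_nested_sketch(nested)          # still a ToyLazy
fully = materialize_nested_sketch(nested, full=True)   # the plain list
print(type(one_level).__name__, fully)
```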
- eformer.jaximus._imus.register(primitive: jax._src.core.Primitive | str, *, precedence: int = 0) -> Callable[[CT], CT]
Register a custom handler for a JAX primitive operation.
This decorator allows you to define custom behavior for JAX primitives (like dot_general, add, mul, etc.) when operating on ImplicitArray instances. Handlers are dispatched via multiple dispatch based on argument types.
- Parameters
primitive – JAX primitive to register the handler for. Can be:
- A core.Primitive object (e.g., jax.lax.dot_general_p)
- A string name (e.g., “dot_general”), which is resolved to the primitive automatically
precedence – Handler precedence for multiple dispatch (higher = higher priority). Default is 0.
- Returns
Decorator that registers the handler function.
- Example - Basic Usage:
>>> import jax
>>> from eformer.jaximus import register
>>>
>>> @register("dot_general")
... def nf4_matmul(lhs: jax.Array, rhs: ArrayNF4, **kwargs):
...     # Custom matmul for dense @ NF4
...     return nf4_kernel(lhs, rhs.packed, rhs.absmax)
- Example - Multiple Dispatch:
>>> @register("dot_general")
... def nf4_lhs_matmul(lhs: ArrayNF4, rhs: jax.Array, **kwargs):
...     # Handle NF4 @ dense (different from above)
...     return nf4_transpose_kernel(rhs, lhs.packed, lhs.absmax)
>>>
>>> # Both handlers registered; dispatched based on argument types
>>> dense @ nf4_weight  # Uses first handler
>>> nf4_weight @ dense  # Uses second handler
- Example - Precedence:
>>> @register("add", precedence=10)
... def high_priority_add(x: ArrayNF4, y: ArrayNF4, **kwargs):
...     # This handler takes precedence over lower-priority ones
...     return optimized_nf4_add(x, y)
- Technical Details:
Uses plum for multiple dispatch based on type signatures
Handlers are called within _CustomTrace.process_primitive
If no handler matches, falls back to materialization
Supports string primitive names (auto-resolved from jax.lax)
Multiple handlers per primitive allowed (dispatched by type)
- Handler Signature:
Handlers receive:
- primitive: The JAX primitive being executed (only in some cases)
- *args: Operands (ImplicitArray or regular arrays)
- **kwargs: Primitive parameters (dimension_numbers, precision, etc.)
And should return:
- Result array (can be an ImplicitArray or a regular array)
- Common Primitives to Register:
“dot_general”: Matrix multiplication (x @ y)
“add”, “sub”, “mul”, “div”: Arithmetic operations
“reshape”, “transpose”: Shape operations
“reduce_sum”, “reduce_max”: Reductions (sum, max, etc.)
“convert_element_type”: Type conversions
See also
ImplicitArray: Base class for custom array types
use_implicit/implicit: Enable handler dispatch
_CustomTrace.process_primitive: Dispatch implementation
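The interplay of type-based dispatch, precedence, and the materialization fallback can be illustrated with a deliberately simplified plain-Python registry. This is a sketch and not how eformer or plum resolve handlers: `register_sketch`, `dispatch`, and `ToyNF4` are hypothetical, and operand types are passed explicitly instead of being read from annotations.

```python
from collections import defaultdict

_HANDLERS = defaultdict(list)  # primitive name -> [(precedence, types, fn)]


def register_sketch(primitive, *types, precedence=0):
    """Hypothetical decorator: record `fn` for operands matching `types`."""

    def decorator(fn):
        _HANDLERS[primitive].append((precedence, types, fn))
        return fn

    return decorator


def dispatch(primitive, *args, fallback, **params):
    """Call the highest-precedence matching handler, else the fallback."""
    matches = [
        (prec, fn)
        for prec, types, fn in _HANDLERS[primitive]
        if len(types) == len(args)
        and all(isinstance(a, t) for a, t in zip(args, types))
    ]
    if not matches:
        return fallback(*args, **params)
    _, best = max(matches, key=lambda m: m[0])
    return best(*args, **params)


class ToyNF4:
    """Toy stand-in for a quantized array type."""

    def __init__(self, value):
        self.value = value


@register_sketch("add", ToyNF4, ToyNF4)
def generic_add(x, y):
    return "generic"


@register_sketch("add", ToyNF4, ToyNF4, precedence=10)
def fast_add(x, y):
    return "fast"


picked = dispatch("add", ToyNF4(1), ToyNF4(2), fallback=lambda x, y: "materialized")
fell_back = dispatch("add", 1.0, ToyNF4(2), fallback=lambda x, y: "materialized")
print(picked, fell_back)
```

Both handlers match the first call, so the one registered with `precedence=10` wins; the second call matches no registered signature and takes the fallback path.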
- eformer.jaximus._imus.ste(func)
Straight-Through Estimator (STE) decorator for quantization-aware training.
This decorator enables gradient flow through non-differentiable quantization operations by using a custom VJP (vector-Jacobian product) that passes gradients straight through to the input, ignoring the quantization step in the backward pass.
The STE is critical for training with quantized weights:
- Forward pass: Uses the quantized representation (e.g., NF4, INT4)
- Backward pass: Gradients flow to the full-precision master weights
This allows training with quantization awareness while maintaining gradient-based optimization of the underlying float32 parameters.
- Parameters
func – A function that performs quantization or other non-differentiable operations. Signature: func(x: Array, *args, **kwargs) -> Array | ImplicitArray
- Returns
Wrapped function with straight-through gradient behavior.
Example
>>> @ste
... def quantize_nf4(weights, block_size=64):
...     return ArrayNF4.quantize(weights, block_size)
>>>
>>> # Forward: uses quantized weights
>>> # Backward: gradients flow to original float32 weights
>>> quantized = quantize_nf4(fp32_weights)
- Technical Details:
Uses jax.custom_vjp to define custom backward pass behavior
Automatically materializes ImplicitArray cotangents to ensure valid gradients
Supports arbitrary positional and keyword arguments
Returns None for cotangents of non-differentiable arguments
Note
The STE is a biased gradient estimator: it assumes the gradient of the quantization function is the identity. Despite this theoretical bias, it works well in practice for quantization-aware training.
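The straight-through idea can be shown directly with jax.custom_vjp, independent of eformer. The sketch below is not the `ste` decorator itself: it uses rounding as a stand-in quantizer (its true derivative is zero almost everywhere), and `round_ste`, `_round_fwd`, and `_round_bwd` are hypothetical names.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def round_ste(x):
    """Round to the nearest integer; a stand-in for a real quantizer."""
    return jnp.round(x)


def _round_fwd(x):
    # No residuals are needed because the backward pass ignores x.
    return jnp.round(x), None


def _round_bwd(_residuals, cotangent):
    # Straight-through: hand the incoming cotangent to the input unchanged.
    return (cotangent,)


round_ste.defvjp(_round_fwd, _round_bwd)

weights = jnp.array([0.3, 1.7])


def loss(w):
    return jnp.sum(2.0 * round_ste(w))


quantized = round_ste(weights)   # forward pass sees rounded values
grads = jax.grad(loss)(weights)  # backward pass treats rounding as identity
print(quantized, grads)
```

Because the backward rule returns the cotangent as is, the gradient of the loss with respect to `weights` equals the gradient with respect to the rounded values, which is the identity assumption described in the note above.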
- eformer.jaximus._imus.tree_flatten_with_implicit(tree: Any, is_leaf: collections.abc.Callable[[Any], bool] | None = None) -> tuple[list[Any], jaxlib._jax.pytree.PyTreeDef]
Like jax.tree_util.tree_flatten but treats ImplicitArray as leaves.
- eformer.jaximus._imus.tree_flatten_with_path_with_implicit(tree: Any, is_leaf: collections.abc.Callable[[...], bool] | None = None, is_leaf_takes_path: bool = False) -> tuple[list[tuple[tuple[KeyEntry, ...], Any]], jaxlib._jax.pytree.PyTreeDef]
Like jax.tree_util.tree_flatten_with_path but treats ImplicitArray as leaves.
- eformer.jaximus._imus.tree_leaves_with_implicit(tree: Any, is_leaf: collections.abc.Callable[[Any], bool] | None = None) -> list[Any]
Like jax.tree_util.tree_leaves but treats ImplicitArray as leaves.
- eformer.jaximus._imus.tree_map_with_implicit(f: Callable[[...], Any], tree: Any, *rest: Any, is_leaf: collections.abc.Callable[[Any], bool] | None = None) -> Any
Like jax.tree_util.tree_map but treats ImplicitArray as leaves.
- eformer.jaximus._imus.tree_map_with_path_with_implicit(f: Callable[[...], Any], tree: Any, *rest: Any, is_leaf: collections.abc.Callable[[...], bool] | None = None, is_leaf_takes_path: bool = False) -> Any
Like jax.tree_util.tree_map_with_path but treats ImplicitArray as leaves.
- eformer.jaximus._imus.tree_structure_with_implicit(tree: Any, is_leaf: None | collections.abc.Callable[[Any], bool] = None) -> PyTreeDef
Like jax.tree_util.tree_structure but treats ImplicitArray as leaves.
- eformer.jaximus._imus.use_implicit(fn)
Enable implicit array dispatch for a function.
This decorator/wrapper sets up a custom JAX trace that intercepts operations on ImplicitArray instances and routes them to registered handlers. This allows transparent use of quantized or lazy arrays without manual materialization.
- Parameters
fn – Function to wrap with implicit array support.
- Returns
Wrapped function that handles ImplicitArray instances transparently.
Example
>>> @use_implicit
... def matmul(x, w):
...     return x @ w  # Automatically uses custom dot_general for NF4
>>>
>>> # Or use as context:
>>> with implicit:
...     output = inputs @ nf4_weights  # Custom handler dispatched
- Technical Details:
Creates a custom JAX trace (_CustomTrace) that intercepts primitive operations
Wraps ImplicitArray instances in _CustomTracer for operation interception
Dispatches to registered handlers via the @register decorator
Falls back to materialization if no handler is registered
Maintains compatibility with JAX transformations (jit, grad, vmap)
See also
register: Decorator for registering custom primitive handlers
ImplicitArray: Base class for implicit array implementations
_CustomTrace: The trace implementation that handles dispatch