eformer.pytree._xtree

eformer.pytree._xtree.dataclass(clz: _T, **kwargs) -> _T
eformer.pytree._xtree.dataclass(**kwargs) -> Callable[[_T], _T]

A decorator that turns a class into a standard dataclass, makes it compatible with JAX PyTrees, and adds serialization and deserialization support.

It registers the resulting class with jax.tree_util and defines to_state_dict and from_state_dict methods. Both the registration and the generated methods are based on each field's explicit pytree_node marking (see field) or, when the field has none, on its type annotation.

Parameters
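  • clz – The class to decorate. Omit it (second signature) to pass only keyword arguments; the call then returns a decorator.

  • **kwargs – Additional keyword arguments passed to dataclasses.dataclass. If frozen is not given, it defaults to True.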

Returns

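The decorated class, or, when clz is omitted, a decorator that applies the given keyword arguments to the class it receives.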
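The two signatures and the frozen=True default can be illustrated with a minimal, standard-library-only sketch. sketch_dataclass, Point, and Counter are hypothetical names; the sketch generates only shallow to_state_dict/from_state_dict methods and skips the jax.tree_util registration, so it is not eformer's implementation.

```python
import dataclasses
from typing import Any


def sketch_dataclass(clz=None, **kwargs):
    """Hypothetical stand-in for the documented decorator (not eformer's code).

    Handles both documented call forms:
      @sketch_dataclass               -> clz is the class itself
      @sketch_dataclass(frozen=False) -> clz is None, so a decorator is returned
    """
    kwargs.setdefault("frozen", True)  # mirror the documented frozen=True default

    def wrap(cls):
        data_cls = dataclasses.dataclass(cls, **kwargs)

        def to_state_dict(self) -> dict[str, Any]:
            # Shallow sketch: one entry per dataclass field, values left as-is.
            return {f.name: getattr(self, f.name) for f in dataclasses.fields(self)}

        def from_state_dict(cls_, state: dict[str, Any]):
            # Rebuild an instance from a mapping of field name -> value.
            return cls_(**state)

        data_cls.to_state_dict = to_state_dict
        data_cls.from_state_dict = classmethod(from_state_dict)
        # The real decorator also registers the class with jax.tree_util;
        # that step is left out so the sketch runs without JAX installed.
        return data_cls

    return wrap if clz is None else wrap(clz)


@sketch_dataclass
class Point:
    x: float
    y: float


@sketch_dataclass(frozen=False)
class Counter:
    n: int = 0
```

Point gets the frozen default, while Counter overrides it by passing frozen=False through the keyword-only form.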

eformer.pytree._xtree.field(*, pytree_node: bool | None = None, metadata: dict | None = None, **kwargs)

Define a dataclass field and optionally mark it explicitly as a PyTree node.

This function is a wrapper around dataclasses.field that adds a pytree_node option to the metadata.

Parameters
  • pytree_node – Explicitly mark the field as a PyTree node (True) or not (False). If None (the default), the field's type annotation is used to infer its behavior.

  • metadata – A dictionary of metadata for the field. The pytree_node key will be added or updated in this dictionary.

  • **kwargs – Additional keyword arguments passed to dataclasses.field.

Returns

A dataclasses.Field object.
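A sketch of what such a wrapper might look like, built only on dataclasses.field. sketch_field and classify_fields are hypothetical helpers. The sketch always writes the pytree_node key, as the metadata description suggests, but copies the caller's dictionary instead of modifying it, which may differ from eformer. classify_fields only groups the markers; it does not implement the annotation-based inference used when the marker is None.

```python
import dataclasses
from typing import Any


def sketch_field(*, pytree_node=None, metadata=None, **kwargs):
    """Hypothetical re-creation of the documented field() wrapper.

    Copies the caller's metadata (a choice of this sketch) and always records
    the pytree_node key, so None can later mean "infer from the annotation".
    """
    meta = dict(metadata or {})
    meta["pytree_node"] = pytree_node
    return dataclasses.field(metadata=meta, **kwargs)


def classify_fields(cls):
    """Group field names by their marker: True, False, or None (to be inferred)."""
    groups = {True: [], False: [], None: []}
    for f in dataclasses.fields(cls):
        # Fields declared without sketch_field have empty metadata -> None.
        groups[f.metadata.get("pytree_node")].append(f.name)
    return groups


@dataclasses.dataclass(frozen=True)
class Layer:
    weight: Any  # no marker at all
    name: str = sketch_field(pytree_node=False, default="dense")
    scale: float = sketch_field(pytree_node=True, metadata={"unit": "none"}, default=1.0)
```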

eformer.pytree._xtree.register_serialization_state(ty: Any, ty_to_state_dict: Callable[[Any], dict[str, Any]], ty_from_state_dict: Callable[[Any, dict[str, Any]], Any], override: bool = False)

Registers serialization and deserialization functions for a given type.

Parameters
  • ty – The type to register handlers for.

  • ty_to_state_dict – A callable that converts an instance of ty to a state dictionary.

  • ty_from_state_dict – A callable that takes an instance of ty (the target) and a state dictionary and returns the instance updated from that state.

  • override – If True, overrides an existing registration for the type. If False and a registration exists, raises a ValueError.

Raises

ValueError – If a handler for the type is already registered and override is False.

class eformer.pytree._xtree.xTree(*args, **kwargs)

Bases: object

Base class for dataclasses acting as JAX PyTree nodes with built-in serialization support.

Classes inheriting from xTree are automatically processed by the dataclass decorator upon definition, making them JAX PyTree compatible and adding to_state_dict and from_state_dict methods.
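Automatic processing at class-definition time can be reproduced with Python's __init_subclass__ hook. Whether xTree uses this hook, a metaclass, or another mechanism is an assumption; SketchTree and Optimizer are hypothetical, and the JAX registration is omitted.

```python
import dataclasses


class SketchTree:
    """Hypothetical base class: every subclass becomes a frozen dataclass
    as soon as its class body has been executed."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # dataclasses.dataclass modifies cls in place (no slots=True), so the
        # name bound by the class statement already refers to the dataclass.
        dataclasses.dataclass(frozen=True)(cls)

    def to_state_dict(self):
        # Shallow serialization: one entry per generated dataclass field.
        return {f.name: getattr(self, f.name) for f in dataclasses.fields(self)}


class Optimizer(SketchTree):
    learning_rate: float
    momentum: float = 0.9
```

Optimizer is written as an ordinary class with annotations, yet it gets a generated __init__, equality, and immutability without an explicit decorator.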

classmethod from_dict(data: dict[str, Any]) -> __T

Deserializes a dictionary into a PyTree object.

classmethod from_state_dict(state: Any) -> STree

Deserializes state into an instance of this class.

replace(**overrides) -> STree

Returns a new instance of the xTree subclass with the specified fields updated.

This method is added dynamically by the dataclass decorator.

Parameters

**overrides – Keyword arguments where keys are field names and values are the new values for those fields.

Returns

A new instance of the xTree subclass with the updated fields.
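Because the decorator defaults to frozen=True, fields cannot be reassigned in place, which is why replace returns a new instance. The sketch below shows the same copy-on-update pattern with the standard library's dataclasses.replace on a plain frozen dataclass; TrainState is hypothetical, and that xTree.replace behaves exactly like dataclasses.replace is an assumption.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class TrainState:
    step: int
    learning_rate: float


state = TrainState(step=0, learning_rate=1e-3)
# Copy with one field changed; the original instance is left untouched.
next_state = dataclasses.replace(state, step=state.step + 1)

try:
    state.step = 5  # frozen instances reject in-place mutation
    mutated = True
except dataclasses.FrozenInstanceError:
    mutated = False
```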

to_dict() -> dict[str, Any]

Serializes the PyTree object to a dictionary.

to_state_dict() -> Any

Serializes this instance to a JSON-compatible state.

eformer.pytree._xtree.xfrom_state_dict(target: _T, state: dict[str, Any], name: str = '.') -> _T

Recursively deserializes the state dictionary into the target object.

Uses the registered from_state_dict function for the target’s type if available, otherwise returns the state directly.

Parameters
  • target – The object to deserialize into.

  • state – The state dictionary.

  • name – The name of the current object in the parent structure (used for error reporting).

Returns

The deserialized object.
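A dependency-free sketch of the recursive lookup described above, assuming handlers take (target, state) as documented for register_serialization_state and that name only labels errors. sketch_from_state_dict, its registry, and the dict handler are hypothetical; eformer's recursion details and error messages may differ.

```python
from typing import Any, Callable

# Hypothetical registry: type -> handler(target, state), the same two-argument
# shape documented for ty_from_state_dict in register_serialization_state.
_FROM_STATE: dict[type, Callable[[Any, Any], Any]] = {}


def sketch_from_state_dict(target: Any, state: Any, name: str = ".") -> Any:
    handler = _FROM_STATE.get(type(target))
    if handler is None:
        return state  # no handler registered: use the state value as-is
    try:
        return handler(target, state)
    except (KeyError, TypeError) as err:
        # Label the failure with this object's name. Errors raised deeper down
        # are already ValueErrors carrying their own name and pass through.
        raise ValueError(f"could not restore {name!r}: {err!r}") from err


def _dict_from_state(target: dict, state: dict) -> dict:
    # Recurse into each entry, passing the key as the child's name.
    return {k: sketch_from_state_dict(v, state[k], name=k) for k, v in target.items()}


_FROM_STATE[dict] = _dict_from_state
```

The target acts as a template: its structure decides which keys are read, while leaf values come from the state.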

eformer.pytree._xtree.xto_state_dict(target: Any) -> dict[str, Any]

Recursively converts the target object into a state dictionary.

Uses the registered to_state_dict function for the target’s type if available, otherwise returns the target directly.

Parameters

target – The object to serialize.

Returns

A dictionary representing the state of the target object, or the target itself if no serialization handler is registered.
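The recursive conversion can be sketched the same way. sketch_to_state_dict, its registry, and the Adapter class are hypothetical, and keying list elements by their string indices is a choice made for this sketch, not documented eformer behavior.

```python
from typing import Any, Callable

# Hypothetical registry: type -> function(obj) -> state dict.
_TO_STATE: dict[type, Callable[[Any], dict[str, Any]]] = {}


def sketch_to_state_dict(target: Any) -> Any:
    handler = _TO_STATE.get(type(target))
    # Unregistered types (numbers, strings, ...) are returned unchanged.
    return target if handler is None else handler(target)


# Containers recurse so that nested registered types are converted too.
_TO_STATE[dict] = lambda d: {str(k): sketch_to_state_dict(v) for k, v in d.items()}
_TO_STATE[list] = lambda xs: {str(i): sketch_to_state_dict(x) for i, x in enumerate(xs)}


class Adapter:
    def __init__(self, rank, alpha):
        self.rank, self.alpha = rank, alpha


_TO_STATE[Adapter] = lambda a: {
    "rank": sketch_to_state_dict(a.rank),
    "alpha": sketch_to_state_dict(a.alpha),
}
```

As the Returns description notes, a value with no registered handler comes back as itself, so the top-level result is a dictionary only when the target's type is registered.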