eformer.pytree._xtree
- eformer.pytree._xtree.dataclass(clz: _T, **kwargs) → _T
- eformer.pytree._xtree.dataclass(**kwargs) → Callable[[_T], _T]
A decorator that makes standard dataclasses JAX PyTree-compatible and adds serialization/deserialization support. It can be applied bare (@dataclass) or called with keyword arguments only (@dataclass(...)), in which case it returns a decorator.
It registers the dataclass with jax.tree_util and defines to_state_dict and from_state_dict methods based on the field types and any explicit pytree_node markings.
- Parameters
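clz – The class to decorate. Omitted when the decorator is called with keyword arguments only.
**kwargs – Additional keyword arguments passed to dataclasses.dataclass. frozen defaults to True unless overridden here.
- Returns
The decorated class or, when clz is omitted, a decorator that takes the class.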
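The central job of such a decorator is deciding which fields become PyTree children, which JAX traverses and transforms, and which are carried as static auxiliary data. The sketch below shows that partition using only the standard library. It is not eformer's implementation. It assumes a field is static only when its metadata explicitly sets pytree_node to False and treats every other field as a child, whereas the documented decorator also consults type annotations. The helper make_flatten_unflatten and the Linear class are hypothetical. With JAX installed, the returned pair could be passed to jax.tree_util.register_pytree_node(Linear, flatten, unflatten).

```python
import dataclasses
from typing import Any, Callable


def make_flatten_unflatten(cls: type) -> tuple[Callable, Callable]:
    """Hypothetical helper: build PyTree flatten/unflatten functions for a dataclass.

    Assumption: a field is static (aux data) only when its metadata sets
    pytree_node=False; every other init field is treated as a child.
    """
    fields = [f for f in dataclasses.fields(cls) if f.init]
    child_names = [f.name for f in fields if f.metadata.get("pytree_node", True) is not False]
    static_names = [f.name for f in fields if f.metadata.get("pytree_node", True) is False]

    def flatten(obj: Any) -> tuple[tuple, tuple]:
        # Children are traversed by JAX; aux data is kept as-is and should be hashable.
        children = tuple(getattr(obj, n) for n in child_names)
        aux = tuple(getattr(obj, n) for n in static_names)
        return children, aux

    def unflatten(aux: tuple, children: tuple) -> Any:
        # (aux, children) is the argument order register_pytree_node expects.
        kwargs = dict(zip(child_names, children))
        kwargs.update(zip(static_names, aux))
        return cls(**kwargs)

    return flatten, unflatten


@dataclasses.dataclass(frozen=True)
class Linear:
    weight: Any
    bias: Any
    name: str = dataclasses.field(default="linear", metadata={"pytree_node": False})


flatten, unflatten = make_flatten_unflatten(Linear)
layer = Linear(weight=[1.0, 2.0], bias=0.5)
children, aux = flatten(layer)
print(children, aux)  # ([1.0, 2.0], 0.5) ('linear',)
restored = unflatten(aux, children)
```

Keeping name in aux data means JAX transformations would treat it as a compile-time constant rather than tracing it.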
- eformer.pytree._xtree.field(*, pytree_node: bool | None = None, metadata: dict | None = None, **kwargs)
Define a dataclass field and optionally mark it explicitly as a PyTree node.
This function is a wrapper around dataclasses.field that adds a pytree_node option to the metadata.
- Parameters
pytree_node – Explicitly mark the field as a PyTree node (True) or leaf (False). If None, the type annotation will be used to infer behavior.
metadata – A dictionary of metadata for the field. The pytree_node key will be added or updated in this dictionary.
**kwargs – Additional keyword arguments passed to dataclasses.field.
- Returns
A dataclasses.Field object.
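A minimal sketch of such a wrapper follows. It assumes the marking is stored under the key "pytree_node", as described above, and that the caller's metadata dictionary is copied rather than mutated. That copying is an assumption of the sketch. The name pytree_field is hypothetical and was chosen so it is not confused with the library function.

```python
import dataclasses
from typing import Any, Optional


def pytree_field(*, pytree_node: Optional[bool] = None, metadata: Optional[dict] = None, **kwargs: Any):
    """Hypothetical wrapper around dataclasses.field that records pytree_node in metadata."""
    # Copy so the caller's dictionary is left untouched (assumption of this sketch).
    merged = dict(metadata or {})
    merged["pytree_node"] = pytree_node
    return dataclasses.field(metadata=merged, **kwargs)


@dataclasses.dataclass(frozen=True)
class Block:
    weights: Any
    activation: str = pytree_field(pytree_node=False, default="gelu", metadata={"doc": "static"})


by_name = {f.name: f for f in dataclasses.fields(Block)}
print(dict(by_name["activation"].metadata))  # {'doc': 'static', 'pytree_node': False}
```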
- eformer.pytree._xtree.register_serialization_state(ty: Any, ty_to_state_dict: Callable[[Any], dict[str, Any]], ty_from_state_dict: Callable[[Any, dict[str, Any]], Any], override: bool = False)
Registers serialization and deserialization functions for a given type.
- Parameters
ty – The type to register handlers for.
ty_to_state_dict – A callable that converts an instance of ty to a state dictionary.
ty_from_state_dict – A callable that takes an existing instance of ty and a state dictionary and returns an instance restored from that state.
override – If True, overrides an existing registration for the type. If False and a registration exists, raises a ValueError.
- Raises
ValueError – If a handler for the type is already registered and override is False.
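The registration contract above (one handler pair per type, duplicates rejected unless override=True) can be sketched as a type-keyed dictionary. The module-level _REGISTRY and the Point class are hypothetical. This mirrors the documented behavior but is not eformer's source.

```python
from typing import Callable, Dict, Tuple

# Hypothetical registry: type -> (to_state_dict, from_state_dict).
_REGISTRY: Dict[type, Tuple[Callable, Callable]] = {}


def register_serialization_state(ty, ty_to_state_dict, ty_from_state_dict, override=False):
    """Sketch of the documented contract: refuse duplicate registrations unless override=True."""
    if ty in _REGISTRY and not override:
        raise ValueError(f"A serialization handler for {ty!r} is already registered.")
    _REGISTRY[ty] = (ty_to_state_dict, ty_from_state_dict)


class Point:
    def __init__(self, x: float, y: float):
        self.x, self.y = x, y


register_serialization_state(
    Point,
    lambda p: {"x": p.x, "y": p.y},
    lambda p, state: Point(state["x"], state["y"]),  # returns a new instance
)

duplicate_rejected = False
try:
    register_serialization_state(Point, lambda p: {}, lambda p, s: p)
except ValueError:
    duplicate_rejected = True  # Point already has handlers and override is False

# With override=True the existing handlers are replaced.
register_serialization_state(
    Point,
    lambda p: {"xy": [p.x, p.y]},
    lambda p, state: Point(*state["xy"]),
    override=True,
)
to_state, from_state = _REGISTRY[Point]
print(to_state(Point(1.0, 2.0)))  # {'xy': [1.0, 2.0]}
```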
- class eformer.pytree._xtree.xTree(*args, **kwargs)
Bases: object
Base class for dataclasses that act as JAX PyTree nodes with built-in serialization support.
Classes inheriting from xTree are automatically processed by the dataclass decorator upon definition, making them JAX PyTree compatible and adding to_state_dict and from_state_dict methods.
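Processing subclasses "upon definition" is commonly done with __init_subclass__, the same pattern Flax's struct.PyTreeNode uses. The sketch below assumes xTree works this way, which the text above does not confirm. AutoFrozen and Params are hypothetical names. The sketch applies only dataclasses.dataclass(frozen=True) and omits the PyTree and serialization registration the real base class performs.

```python
import dataclasses
from typing import Any


class AutoFrozen:
    """Hypothetical stand-in for xTree: each subclass becomes a frozen dataclass when defined."""

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__()
        # Class keywords (e.g. `class P(AutoFrozen, eq=False)`) are forwarded; frozen defaults to True.
        options = {"frozen": True, **kwargs}
        # Without slots=True, dataclasses.dataclass modifies cls in place.
        dataclasses.dataclass(**options)(cls)


class Params(AutoFrozen):
    w: float
    b: float = 0.0


p = Params(w=2.0)
try:
    p.w = 3.0
    mutated = True
except dataclasses.FrozenInstanceError:
    mutated = False  # frozen=True blocks attribute assignment
```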
- classmethod from_dict(data: dict[str, Any]) → __T
Deserializes a dictionary into a PyTree object.
- classmethod from_state_dict(state: Any) → STree
Deserializes state into an instance of this class.
- replace(**overrides) → STree
Returns a new instance of the xTree subclass with specified fields updated.
This method is added dynamically by the dataclass decorator.
- Parameters
**overrides – Keyword arguments where keys are field names and values are the new values for those fields.
- Returns
A new instance of the xTree subclass with the updated fields.
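If replace behaves like the standard library's dataclasses.replace, its effect on a frozen instance can be sketched as follows. That equivalence is an assumption, since the text above only says the method is added by the decorator. The Config class is hypothetical.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Config:
    learning_rate: float = 1e-3
    steps: int = 100

    def replace(self, **overrides):
        # Assumed behavior: build a new instance with the given fields swapped in.
        return dataclasses.replace(self, **overrides)


base = Config()
tuned = base.replace(learning_rate=3e-4)
print(tuned.learning_rate, tuned.steps)  # 0.0003 100
print(base.learning_rate)  # 0.001
```

Because instances are frozen by default, returning a modified copy is the only way to "update" a field.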
- eformer.pytree._xtree.xfrom_state_dict(target: _T, state: dict[str, Any], name: str = '.') → _T
Recursively deserializes the state dictionary into the target object.
Uses the from_state_dict handler registered for the target’s type if one exists; otherwise, returns state unchanged.
- Parameters
target – The object to deserialize into.
state – The state dictionary.
name – The name of the current object in the parent structure (used for error reporting).
- Returns
The deserialized object.
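The dispatch-with-fallback rule can be sketched as a lookup on the target's exact type. Matching the exact type is a simplifying assumption. In this self-contained sketch, FROM_HANDLERS, Pair, and its handler are hypothetical. Handlers take (target, state), matching the registration signature. The sketch uses name only to annotate a missing-key error, which is one possible reading of "used for error reporting".

```python
from typing import Any, Callable, Dict

# Hypothetical registry: type -> from_state_dict handler taking (target, state).
FROM_HANDLERS: Dict[type, Callable[[Any, Dict[str, Any]], Any]] = {}


def xfrom_state_dict(target: Any, state: Any, name: str = ".") -> Any:
    handler = FROM_HANDLERS.get(type(target))
    if handler is None:
        # No handler registered: the state itself becomes the restored value.
        return state
    try:
        return handler(target, state)
    except KeyError as err:
        raise ValueError(f"while restoring {name!r}: missing key {err}") from err


class Pair:
    def __init__(self, left: Any, right: Any):
        self.left, self.right = left, right


def _pair_from_state(target: Pair, state: Dict[str, Any]) -> Pair:
    # Recurse into children, passing each child's field name for error reporting.
    return Pair(
        xfrom_state_dict(target.left, state["left"], "left"),
        xfrom_state_dict(target.right, state["right"], "right"),
    )


FROM_HANDLERS[Pair] = _pair_from_state

template = Pair(Pair(0, 0), 0)
restored = xfrom_state_dict(template, {"left": {"left": 1, "right": 2}, "right": 3})
print(restored.left.left, restored.left.right, restored.right)  # 1 2 3

error_message = ""
try:
    xfrom_state_dict(template, {"left": {"left": 1}, "right": 3})
except ValueError as err:
    error_message = str(err)  # the nested Pair is missing its "right" entry
```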
- eformer.pytree._xtree.xto_state_dict(target: Any) → dict[str, Any]
Recursively converts the target object into a state dictionary.
Uses the to_state_dict handler registered for the target’s type if one exists; otherwise, returns target unchanged.
- Parameters
target – The object to serialize.
- Returns
A dictionary representing the state of the target object, or the target itself if no serialization handler is registered.
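The serialization direction follows the same pattern. In the sketch below, TO_HANDLERS and Pair are hypothetical, and lookup is by exact type, which is a simplifying assumption. The fallback explains why the result is either a dictionary or the target itself: leaves such as numbers pass through untouched, while registered containers recurse into their children.

```python
from typing import Any, Callable, Dict

# Hypothetical registry: type -> to_state_dict handler.
TO_HANDLERS: Dict[type, Callable[[Any], Dict[str, Any]]] = {}


def xto_state_dict(target: Any) -> Any:
    handler = TO_HANDLERS.get(type(target))
    # Unregistered values (numbers, arrays, ...) are returned unchanged.
    return target if handler is None else handler(target)


class Pair:
    def __init__(self, left: Any, right: Any):
        self.left, self.right = left, right


# The handler recurses so nested Pair objects also become dictionaries.
TO_HANDLERS[Pair] = lambda p: {"left": xto_state_dict(p.left), "right": xto_state_dict(p.right)}

state = xto_state_dict(Pair(Pair(1, 2), 3))
print(state)  # {'left': {'left': 1, 'right': 2}, 'right': 3}
print(xto_state_dict(7))  # 7
```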