eformer.pytree._pytree
- class eformer.pytree._pytree.FrozenPyTree
Bases: _PyTreeNodeBase
Base class for immutable (frozen) PyTree dataclasses.
Inheriting from this class automatically applies the auto_pytree decorator with frozen=True to the subclass, registering it as a frozen JAX PyTree.
- class eformer.pytree._pytree.PyTree
Bases: _PyTreeNodeBase
Base class for mutable PyTree dataclasses.
Inheriting from this class automatically applies the auto_pytree decorator to the subclass, registering it as a JAX PyTree.
- class eformer.pytree._pytree.PyTreeClassInfo(data_fields: tuple[str, ...], meta_fields: tuple[str, ...], frozen: bool, type_hints: dict[str, type])
Bases: object
Stores metadata about a class registered as a PyTree.
- data_fields
- frozen
- meta_fields
- type_hints
- eformer.pytree._pytree.auto_pytree(cls: type[T] | None = None, meta_fields: tuple[str, ...] | None = None, json_serializable: bool = True, frozen: bool = False, max_print_length: int = 500)
A class decorator that automatically registers a dataclass as a JAX PyTree.
It uses dataclasses.dataclass to make the class a dataclass if it isn’t already, determines which fields are data (PyTree children) and which are metadata, and registers the class with jax.tree_util.register_dataclass.
Fields are considered metadata if any of the following apply:
- They are explicitly listed in meta_fields.
- They are marked with field(pytree_node=False).
- Their type hint suggests they are non-JAX types (checked by _is_non_jax_type).
- Parameters
cls – The class to be decorated.
meta_fields – A tuple of field names to always treat as metadata.
json_serializable – If True (default), adds to_dict, from_dict, to_json, and from_json methods to the class.
frozen – If True, makes the dataclass frozen (immutable). Defaults to False.
- Returns
The decorated class, registered as a PyTree.
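The registration step described above can be approximated with plain `dataclasses` and `jax.tree_util.register_dataclass`. The helper `register_as_pytree` below is hypothetical and is not eformer's implementation. It covers the `meta_fields` rule and a `pytree_node` flag stored in field metadata. It omits the type-hint heuristic, and the `"pytree_node"` metadata key is an assumption.

```python
import dataclasses
import functools

import jax
import jax.numpy as jnp


def register_as_pytree(cls, meta_fields=()):
    """Hypothetical, simplified stand-in for auto_pytree (not eformer's code)."""
    # Make the class a dataclass if it is not one already.
    if not dataclasses.is_dataclass(cls):
        cls = dataclasses.dataclass(cls)
    data, meta = [], []
    for f in dataclasses.fields(cls):
        # Metadata if listed explicitly, or if flagged pytree_node=False
        # (the metadata key name is an assumption for this sketch).
        if f.name in meta_fields or not f.metadata.get("pytree_node", True):
            meta.append(f.name)
        else:
            data.append(f.name)
    # register_dataclass returns the class unchanged after registering it.
    return jax.tree_util.register_dataclass(cls, data_fields=data, meta_fields=meta)


@functools.partial(register_as_pytree, meta_fields=("activation",))
class Layer:
    w: jax.Array
    b: jax.Array
    activation: str = "relu"
    scale: float = dataclasses.field(default=1.0, metadata={"pytree_node": False})


layer = Layer(w=jnp.ones((2, 2)), b=jnp.zeros(2))
leaves = jax.tree_util.tree_leaves(layer)  # only w and b
doubled = jax.tree_util.tree_map(lambda x: 2 * x, layer)
```

Only `w` and `b` become leaves. `activation` and `scale` are stored in the tree structure, so `tree_map` passes them through unchanged.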
- eformer.pytree._pytree.field(pytree_node: bool = True, *, metadata: dict | None = None, **kwargs) → Field
A replacement for dataclasses.field that specifies whether the field is treated as a PyTree node (data) or as metadata.
- Parameters
pytree_node – If True (default), the field is treated as a PyTree leaf/node. If False, the field is treated as metadata.
metadata – Optional dictionary of metadata to pass to dataclasses.field.
**kwargs – Additional keyword arguments passed to dataclasses.field.
- Returns
A dataclasses.Field object.
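The following is a stdlib-only sketch of how such a field helper can work. `pytree_field` is a hypothetical stand-in, not eformer's code. It merges the flag into the field's `metadata` under a `"pytree_node"` key, which is the convention used by `flax.struct.field`; whether eformer uses the same key is an assumption. All other keyword arguments are forwarded to `dataclasses.field`.

```python
from __future__ import annotations

import dataclasses


def pytree_field(pytree_node: bool = True, *, metadata: dict | None = None, **kwargs):
    """Hypothetical re-creation of a field(pytree_node=...) helper."""
    # Copy any user-supplied metadata, then record the pytree_node flag.
    merged = dict(metadata or {})
    merged["pytree_node"] = pytree_node
    # Everything else (default, default_factory, repr, ...) goes to dataclasses.field.
    return dataclasses.field(metadata=merged, **kwargs)


@dataclasses.dataclass
class Block:
    weights: list = pytree_field(metadata={"doc": "trainable"}, default_factory=list)
    name: str = pytree_field(pytree_node=False, default="block")


# A registration step can later read the flag back from each field.
flags = {f.name: f.metadata["pytree_node"] for f in dataclasses.fields(Block)}
doc = dataclasses.fields(Block)[0].metadata["doc"]
```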