eformer.pytree._pytree

class eformer.pytree._pytree.FrozenPyTree

Bases: _PyTreeNodeBase

Base class for immutable (frozen) PyTree dataclasses.

Inheriting from this class automatically applies the auto_pytree decorator with frozen=True to the subclass, registering it as a frozen JAX PyTree.
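
A base class that decorates its subclasses when they are defined is commonly built on Python's __init_subclass__ hook. The sketch below assumes that mechanism. FrozenBase and Point are hypothetical names, and the JAX registration step is left out, so the sketch only shows the frozen-dataclass behavior a subclass picks up. It is not eformer's implementation.

```python
import dataclasses


class FrozenBase:
    """Hypothetical stand-in for FrozenPyTree: every subclass becomes a frozen dataclass."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # dataclass() modifies the class in place and returns the same object,
        # so the return value can be ignored. JAX PyTree registration is omitted.
        dataclasses.dataclass(frozen=True)(cls)


class Point(FrozenBase):
    x: float
    y: float


p = Point(1.0, 2.0)
try:
    p.x = 3.0  # frozen: attribute assignment raises FrozenInstanceError
except dataclasses.FrozenInstanceError:
    print("Point is immutable")

# Build a modified copy instead of mutating in place.
q = dataclasses.replace(p, x=3.0)
print(p, q)
```

Updates to a frozen instance are therefore written functionally, by constructing a modified copy. Frozen dataclasses with the default eq=True are also hashable.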

class eformer.pytree._pytree.PyTree

Bases: _PyTreeNodeBase

Base class for mutable PyTree dataclasses.

Inheriting from this class automatically applies the auto_pytree decorator to the subclass, registering it as a JAX PyTree.
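
For the mutable variant, the same hook would apply a plain dataclass decorator. This is again a hedged sketch: MutableBase and TrainState are hypothetical names, JAX registration is omitted, and it highlights the standard dataclasses rule that a mutable dataclass with eq=True is not hashable.

```python
import dataclasses


class MutableBase:
    """Hypothetical stand-in for PyTree: every subclass becomes a regular (mutable) dataclass."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Same idea as the frozen case, but without frozen=True.
        # JAX PyTree registration is omitted from this sketch.
        dataclasses.dataclass(cls)


class TrainState(MutableBase):
    step: int
    loss: float


state = TrainState(step=0, loss=1.5)
state.step += 1  # mutable: in-place updates are allowed
print(state)  # TrainState(step=1, loss=1.5)

# With the default eq=True and frozen=False, dataclasses sets __hash__ to None,
# so instances cannot be used where a hashable value is required.
print(TrainState.__hash__ is None)  # True
```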

class eformer.pytree._pytree.PyTreeClassInfo(data_fields: tuple[str, ...], meta_fields: tuple[str, ...], frozen: bool, type_hints: dict[str, type])

Bases: object

Stores metadata about a class registered as a PyTree.

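data_fields – Names of the fields treated as data (PyTree children).
frozen – Whether the dataclass is frozen (immutable).
meta_fields – Names of the fields treated as static metadata.
type_hints – Mapping from each field name to its type hint.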
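
The reason for keeping data_fields and meta_fields separate can be seen in a pure-Python model of the flatten/unflatten pair that a dataclass PyTree registration provides. Data fields become the children that JAX transforms. Metadata fields travel in the auxiliary data, which JAX expects to be hashable. ClassInfo, flatten, and unflatten are hypothetical names, and this sketch does not call JAX or reproduce eformer's code.

```python
import dataclasses
from typing import Any, NamedTuple


class ClassInfo(NamedTuple):
    """Hypothetical mirror of PyTreeClassInfo's four fields."""
    data_fields: tuple[str, ...]
    meta_fields: tuple[str, ...]
    frozen: bool
    type_hints: dict[str, type]


@dataclasses.dataclass(frozen=True)
class Layer:
    weight: list[float]
    bias: float
    name: str  # treated as metadata in this example


INFO = ClassInfo(
    data_fields=("weight", "bias"),
    meta_fields=("name",),
    frozen=True,
    type_hints={"weight": list[float], "bias": float, "name": str},
)


def flatten(obj: Any, info: ClassInfo) -> tuple[list[Any], tuple[Any, ...]]:
    # Data fields become children; metadata goes into the auxiliary tuple.
    children = [getattr(obj, name) for name in info.data_fields]
    aux = tuple(getattr(obj, name) for name in info.meta_fields)
    return children, aux


def unflatten(cls: type, info: ClassInfo, aux: tuple[Any, ...], children: list[Any]) -> Any:
    # Rebuild the instance from keyword arguments for both field groups.
    kwargs = dict(zip(info.data_fields, children))
    kwargs.update(zip(info.meta_fields, aux))
    return cls(**kwargs)


layer = Layer(weight=[0.5, -1.0], bias=0.1, name="dense")
children, aux = flatten(layer, INFO)
print(children, aux)  # [[0.5, -1.0], 0.1] ('dense',)

# Transform only the children, then rebuild with the untouched metadata.
new_children = [[w * 2 for w in children[0]], children[1] * 2]
doubled = unflatten(Layer, INFO, aux, new_children)
print(doubled)
```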
eformer.pytree._pytree.auto_pytree(cls: type[T] | None = None, meta_fields: tuple[str, ...] | None = None, json_serializable: bool = True, frozen: bool = False, max_print_length: int = 500)

A class decorator that automatically registers a dataclass as a JAX PyTree.

It uses dataclasses.dataclass to make the class a dataclass if it isn’t already, determines which fields are data (PyTree children) and which are metadata, and registers the class with jax.tree_util.register_dataclass.

Fields are considered metadata if any of the following applies:
  • They are explicitly listed in meta_fields.
  • They are marked with field(pytree_node=False).
  • Their type hint indicates a non-JAX type (as checked by _is_non_jax_type).

All remaining fields are treated as data (PyTree children).
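
The three metadata rules can be sketched as a single partition over dataclasses.fields. This is an illustration under stated assumptions: field, is_non_jax_type, and split_fields are hypothetical helpers, the "pytree_node" metadata key is assumed, and is_non_jax_type is a placeholder because the real _is_non_jax_type heuristic is not documented here.

```python
import dataclasses
import typing


def field(pytree_node: bool = True, **kwargs):
    # Hypothetical minimal field(): stores the flag under an assumed "pytree_node" key.
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata["pytree_node"] = pytree_node
    return dataclasses.field(metadata=metadata, **kwargs)


def is_non_jax_type(tp) -> bool:
    # Hypothetical placeholder for _is_non_jax_type: only str and bytes count here.
    return tp in (str, bytes)


def split_fields(cls, meta_fields=()):
    """Partition dataclass field names into (data, meta) tuples using the three rules."""
    hints = typing.get_type_hints(cls)
    data, meta = [], []
    for f in dataclasses.fields(cls):
        is_meta = (
            f.name in meta_fields  # rule 1: listed in meta_fields
            or not f.metadata.get("pytree_node", True)  # rule 2: field(pytree_node=False)
            or is_non_jax_type(hints.get(f.name))  # rule 3: non-JAX type hint
        )
        (meta if is_meta else data).append(f.name)
    return tuple(data), tuple(meta)


@dataclasses.dataclass
class Model:
    weights: list
    bias: float
    activation: str  # rule 3 (str hint)
    num_layers: int = field(pytree_node=False, default=2)  # rule 2
    dropout: float = 0.1  # rule 1 when listed in meta_fields


print(split_fields(Model, meta_fields=("dropout",)))
# (('weights', 'bias'), ('activation', 'num_layers', 'dropout'))
```

Without meta_fields=("dropout",), dropout has a float hint and no pytree_node=False marker, so it falls back to being data.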

Parameters
  • cls – The class to be decorated.

  • meta_fields – A tuple of field names to always treat as metadata.

  • json_serializable – If True (default), adds to_dict, from_dict, to_json, and from_json methods to the class.

  • frozen – If True, makes the dataclass frozen (immutable). Defaults to False.

Returns

The decorated class, registered as a PyTree.
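
The json_serializable parameter is described as adding to_dict, from_dict, to_json, and from_json. One plausible way to attach such methods uses dataclasses.asdict and the json module, as sketched below. The helper add_json_methods and the method signatures are assumptions. The sketch handles flat fields only, and eformer's versions may treat arrays and nested PyTrees differently.

```python
import dataclasses
import json


def add_json_methods(cls):
    """Hypothetical helper that attaches dict/JSON round-trip methods to a dataclass."""

    def to_dict(self) -> dict:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(klass, data: dict):
        return klass(**data)

    def to_json(self, **json_kwargs) -> str:
        return json.dumps(self.to_dict(), **json_kwargs)

    @classmethod
    def from_json(klass, text: str):
        return klass.from_dict(json.loads(text))

    cls.to_dict = to_dict
    cls.from_dict = from_dict
    cls.to_json = to_json
    cls.from_json = from_json
    return cls


@add_json_methods
@dataclasses.dataclass
class Config:
    hidden_size: int
    activation: str = "gelu"


cfg = Config(hidden_size=128)
text = cfg.to_json()
print(text)  # {"hidden_size": 128, "activation": "gelu"}
print(Config.from_json(text) == cfg)  # True
```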

eformer.pytree._pytree.field(pytree_node: bool = True, *, metadata: dict | None = None, **kwargs) → Field

A replacement for dataclasses.field that additionally lets you specify whether the field is treated as a PyTree node.

Parameters
  • pytree_node – If True (default), the field is treated as a PyTree leaf/node. If False, the field is treated as metadata.

  • metadata – Optional dictionary of metadata to pass to dataclasses.field.

  • **kwargs – Additional keyword arguments passed to dataclasses.field.

Returns

A dataclasses.Field object.
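
A function with this signature can be sketched by merging the flag into the caller's metadata and forwarding everything else to dataclasses.field. This is a hedged sketch, not eformer's source: the "pytree_node" metadata key and the Block example class are assumptions.

```python
from __future__ import annotations

import dataclasses


def field(pytree_node: bool = True, *, metadata: dict | None = None, **kwargs) -> dataclasses.Field:
    # Hypothetical sketch: copy the caller's metadata (never mutate it), add the
    # flag under an assumed "pytree_node" key, and defer the rest to dataclasses.field.
    merged = {**(metadata or {}), "pytree_node": pytree_node}
    return dataclasses.field(metadata=merged, **kwargs)


@dataclasses.dataclass
class Block:
    params: dict = field(default_factory=dict)
    name: str = field(pytree_node=False, default="block", metadata={"doc": "label"})


for f in dataclasses.fields(Block):
    print(f.name, dict(f.metadata))
# params {'pytree_node': True}
# name {'doc': 'label', 'pytree_node': False}
```

Because the flag lives in standard dataclass field metadata, a decorator can later read it back through dataclasses.fields without needing any extra bookkeeping.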