eformer.serialization.serialization

eformer.serialization.serialization.is_array_like(x: Any) -> bool

Check if an object is array-like.

Minimal check similar to equinox.is_array_like, checking for shape and dtype attributes.

Parameters

x – Object to check.

Returns

True if object has both shape and dtype attributes, False otherwise.
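The check described above can be sketched in a few lines. This is a minimal illustration rather than the library's source: `is_array_like_sketch`, `FakeArray`, and `ShapeOnly` are hypothetical names introduced only for this example.

```python
from typing import Any


def is_array_like_sketch(x: Any) -> bool:
    """Return True when x exposes both `shape` and `dtype` attributes."""
    return hasattr(x, "shape") and hasattr(x, "dtype")


class FakeArray:
    """Hypothetical stand-in for an array type; not part of eformer."""

    shape = (2, 3)
    dtype = "float32"


class ShapeOnly:
    """Has a shape but no dtype, so the check rejects it."""

    shape = (2, 3)


print(is_array_like_sketch(FakeArray()))  # both attributes present
print(is_array_like_sketch(ShapeOnly()))  # dtype missing
print(is_array_like_sketch([1, 2, 3]))    # plain list has neither attribute
```

Because the check is purely attribute-based, any object that carries both attributes passes, whether or not it is a real array.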

eformer.serialization.serialization.leaf_key_paths(pytree: Any, prefix: str | None = '', *, is_leaf: collections.abc.Callable[[Any], bool] | None = None)

Create dotted key paths for each leaf in a pytree.

Returns a pytree of the same structure where each leaf is replaced by its key path (prefixed by prefix if provided). Uses jax.tree_util.tree_flatten_with_path for robust handling of dicts, sequences, dataclasses, namedtuples, and custom PyTree nodes.

Parameters
  • pytree – The pytree to create key paths for.

  • prefix – Optional prefix to add to all key paths.

  • is_leaf – Optional function to determine if a node is a leaf.

Returns

PyTree with same structure where leaves are replaced by their dotted key paths.
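To illustrate the shape of the result, here is a simplified pure-Python sketch. It is not the library's implementation: it handles only plain dicts, lists, and tuples instead of going through jax.tree_util.tree_flatten_with_path, and the exact path text the real function produces (for example, how sequence indices are rendered) is an assumption. `dotted_key_paths` and `_join` are hypothetical names.

```python
from typing import Any


def _join(prefix: str, key: str) -> str:
    """Append key to prefix with a dot, skipping the dot for an empty prefix."""
    return f"{prefix}.{key}" if prefix else key


def dotted_key_paths(tree: Any, prefix: str = "") -> Any:
    """Replace every leaf of a nested dict/list/tuple with its dotted path."""
    if isinstance(tree, dict):
        return {k: dotted_key_paths(v, _join(prefix, str(k))) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        items = [dotted_key_paths(v, _join(prefix, str(i))) for i, v in enumerate(tree)]
        # Plain lists stay lists; tuples (including namedtuples) become plain tuples.
        return items if isinstance(tree, list) else tuple(items)
    # Anything else is treated as a leaf and replaced by the path that led to it.
    return prefix


params = {"model": {"w": 1.0, "layers": [2.0, 3.0]}, "step": 0}
paths = dotted_key_paths(params, prefix="ckpt")
print(paths["model"]["w"])       # ckpt.model.w
print(paths["model"]["layers"])  # ['ckpt.model.layers.0', 'ckpt.model.layers.1']
print(paths["step"])             # ckpt.step
```

The output mirrors the input structure, so a path can be looked up with the same keys used to reach the original leaf.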

eformer.serialization.serialization.tree_deserialize_leaves(checkpoint_dir, mesh: Mesh, manager: jax.experimental.array_serialization.serialization.GlobalAsyncCheckpointManager | None = None, *, prefix: str | None = None, partition_rules: tuple[tuple[str, jax.sharding.PartitionSpec]] | None = None, shardings: jaxtyping.PyTree | dict[collections.abc.Callable] | None = None)

Deserialize a PyTree of arrays from a TensorStore checkpoint.

If pytree is provided, returns a pytree with the same structure as the template. If pytree is None, discovers the structure from the checkpoint directory.

Parameters
  • checkpoint_dir – Directory containing the TensorStore checkpoint.

  • mesh – JAX mesh for distributed arrays. The signature declares it without a default, so it must be passed.

  • manager – Optional GlobalAsyncCheckpointManager. If None, creates a new one.

  • prefix – Optional prefix selecting which tree to load (e.g., ‘model’, ‘optimizer’).

  • partition_rules – Optional tuple of (str, PartitionSpec) pairs, as typed in the signature.

  • shardings – Sharding specifications matching the checkpoint structure.

Returns

Deserialized pytree structure with loaded arrays.

Raises

ValueError – If checkpoint format is unsupported or prefix not found.

Note

Supports both v1.0 (single prefix) and v2.0 (multi-prefix) index formats. When using v2.0 format with multiple prefixes, you must specify which prefix to load or an error will be raised listing available prefixes.
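The prefix-selection behavior described in this note can be sketched as follows. The index layout is an assumption made for illustration: the keys "version", "arrays", and "prefixes" are hypothetical, as is the `select_prefix` helper. The real index file may be structured differently. Returning the only tree when a v2.0 index holds a single prefix is also an assumption.

```python
from typing import Any, Optional


def select_prefix(index: dict, prefix: Optional[str] = None) -> Any:
    """Pick the entry to load from a (hypothetical) checkpoint index dict."""
    version = str(index.get("version"))
    if version.startswith("1."):
        # Assumed v1.0 layout: a single tree, so there is nothing to choose.
        return index["arrays"]
    if version.startswith("2."):
        available = sorted(index["prefixes"])
        if prefix is None:
            if len(available) == 1:
                # Assumption: an unambiguous single prefix needs no explicit choice.
                return index["prefixes"][available[0]]
            raise ValueError(f"Multiple prefixes found; specify one of {available}")
        if prefix not in index["prefixes"]:
            raise ValueError(f"Prefix {prefix!r} not found; available: {available}")
        return index["prefixes"][prefix]
    raise ValueError(f"Unsupported index version: {version}")


# Hypothetical v2.0 index with two trees stored side by side.
index_v2 = {"version": "2.0", "prefixes": {"model": ["w"], "optimizer": ["mu"]}}
print(select_prefix(index_v2, "model"))

try:
    select_prefix(index_v2)  # ambiguous: no prefix given, two available
except ValueError as err:
    print(err)
```

Listing the available prefixes in the error message lets the caller correct the call without opening the index file by hand.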

eformer.serialization.serialization.tree_serialize_leaves(checkpoint_dir, pytree, manager: jax.experimental.array_serialization.serialization.GlobalAsyncCheckpointManager | None = None, *, prefix: str | None = None, commit_callback: collections.abc.Callable | None = None, write_index: bool = True)

Serialize a pytree’s leaves to TensorStore format.

Serializes arrays in a pytree to TensorStore format with optional prefixing for organizing multiple trees in the same checkpoint directory.

Parameters
  • checkpoint_dir – Directory to save the checkpoint.

  • pytree – PyTree structure containing arrays to serialize.

  • manager – Optional GlobalAsyncCheckpointManager. If None, creates a new one.

  • prefix – Optional prefix for organizing arrays (e.g., ‘model’, ‘optimizer’).

  • commit_callback – Optional callback to run after committing the checkpoint.

  • write_index – Whether to write an index file for the checkpoint.

Returns

None

Note

Uses a unified index file (tensorstore_index.json) that supports multiple prefixes in version 2.0 format.