eformer.serialization.serialization

eformer.serialization.serialization.is_array_like(x: Any) -> bool

Check if an object is array-like.

A minimal check, similar to equinox.is_array_like, that tests only for the presence of shape and dtype attributes.

Parameters

x – Object to check.

Returns

True if object has both shape and dtype attributes, False otherwise.
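The check described above amounts to two attribute lookups. The following is a minimal sketch based only on that description; the names `is_array_like_sketch` and `FakeArray` are illustrative and not part of eformer.

```python
from typing import Any


def is_array_like_sketch(x: Any) -> bool:
    """Return True when x exposes both `shape` and `dtype` attributes."""
    return hasattr(x, "shape") and hasattr(x, "dtype")


class FakeArray:
    """Illustrative stand-in for an array; not a real array type."""

    shape = (2, 3)
    dtype = "float32"


print(is_array_like_sketch(FakeArray()))  # True
print(is_array_like_sketch([1, 2, 3]))    # False: lists have neither attribute
print(is_array_like_sketch(3.0))          # False
```

Because the test is purely structural, any object carrying both attributes passes, whether or not it is a real array.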

eformer.serialization.serialization.leaf_key_paths(pytree: Any, prefix: str | None = '', *, is_leaf: collections.abc.Callable[[Any], bool] | None = None)

Create dotted key paths for each leaf in a pytree.

Returns a pytree of the same structure where each leaf is replaced by its key path (prefixed by prefix if provided). Uses jax.tree_util.tree_flatten_with_path for robust handling of dicts, sequences, dataclasses, namedtuples, and custom PyTree nodes.

Parameters
  • pytree – The pytree to create key paths for.

  • prefix – Optional prefix to add to all key paths.

  • is_leaf – Optional function to determine if a node is a leaf.

Returns

PyTree with same structure where leaves are replaced by their dotted key paths.
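To make the "leaf replaced by its key path" idea concrete, here is a simplified sketch. It does not use jax.tree_util.tree_flatten_with_path; it recurses over plain dicts, lists, and tuples with the standard library only. The function name `leaf_key_paths_sketch` and the choice to render sequence indices as `.0`, `.1` are assumptions for illustration and may not match the library's exact path format.

```python
from typing import Any


def leaf_key_paths_sketch(tree: Any, prefix: str = "") -> Any:
    """Replace every leaf with its dotted path (dicts, lists, tuples only)."""

    def join(base: str, key: Any) -> str:
        # Avoid a leading dot when no prefix has been accumulated yet.
        return f"{base}.{key}" if base else str(key)

    if isinstance(tree, dict):
        return {k: leaf_key_paths_sketch(v, join(prefix, k)) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        items = [leaf_key_paths_sketch(v, join(prefix, i)) for i, v in enumerate(tree)]
        return items if isinstance(tree, list) else tuple(items)
    # Anything else is treated as a leaf and replaced by its path.
    return prefix


params = {"dense": {"kernel": 1.0, "bias": 2.0}, "layers": [3.0, 4.0]}
paths = leaf_key_paths_sketch(params, prefix="model")
print(paths["dense"]["kernel"])  # model.dense.kernel
print(paths["layers"])           # ['model.layers.0', 'model.layers.1']
```

The result has the same nesting as the input, so it can be zipped against the original tree to pair each array with a stable string key.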

eformer.serialization.serialization.tree_deserialize_leaves(checkpoint_dir, mesh: Mesh, manager: jax.experimental.array_serialization.serialization.GlobalAsyncCheckpointManager | None = None, *, prefix: str | None = None, partition_rules: tuple[tuple[str, jax.sharding.PartitionSpec]] | None = None, shardings: jaxtyping.PyTree | dict[collections.abc.Callable] | None = None, callback: collections.abc.Callable[[jax.jaxlib._jax.Array, str], jax.jaxlib._jax.Array] | None = None, chunk_size: int | None = None)

Deserialize a PyTree of arrays from a TensorStore checkpoint.

If pytree is provided, returns a pytree with the same structure as the template. If pytree is None, discovers the structure from the checkpoint directory.

Parameters
  • checkpoint_dir – Directory containing the TensorStore checkpoint.

  • mesh – JAX mesh for distributed arrays.

  • manager – Optional GlobalAsyncCheckpointManager. If None, creates a new one.

  • prefix – Optional prefix to filter/load specific tree (e.g., 'model', 'optimizer').

  • partition_rules – Optional tuple of (str, PartitionSpec) pairs.

  • shardings – Optional sharding specifications matching the checkpoint structure.

  • callback – Optional callback to process each loaded array by key.

  • chunk_size – Optional number of arrays to load per batch. If set, loads in chunks and waits between batches to reduce peak memory usage.
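The batching behaviour described for chunk_size can be pictured as follows. This is a sketch of the general pattern only, not the library's loader: `iter_chunks` and the key names are hypothetical, and the actual reads are replaced by a comment.

```python
from __future__ import annotations

from collections.abc import Iterator, Sequence


def iter_chunks(keys: Sequence[str], chunk_size: int | None) -> Iterator[list[str]]:
    """Yield keys in batches of chunk_size; a single batch if chunk_size is None."""
    if chunk_size is None:
        yield list(keys)
        return
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    for start in range(0, len(keys), chunk_size):
        yield list(keys[start : start + chunk_size])


array_keys = ["embed", "layer_0.kernel", "layer_0.bias", "layer_1.kernel", "head"]
for batch in iter_chunks(array_keys, chunk_size=2):
    # A real loader would start the reads for `batch` here and wait for them
    # to finish before moving on, so only one batch is in flight at a time.
    print(batch)
```

Smaller batches lower the number of arrays in flight at once, at the cost of more wait points.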

Returns

Deserialized pytree structure with loaded arrays.

Raises

ValueError – If checkpoint format is unsupported or prefix not found.

Note

Supports both v1.0 (single prefix) and v2.0 (multi-prefix) index formats. When using v2.0 format with multiple prefixes, you must specify which prefix to load or an error will be raised listing available prefixes.
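The prefix rule in the note can be sketched as a small resolution function. The index layout used here (a "version" field, a "prefix" field for v1.0, and a "prefixes" mapping for v2.0) is a hypothetical stand-in, not the library's actual schema, and returning the sole prefix when only one exists is likewise an assumption.

```python
from __future__ import annotations


def resolve_prefix(index: dict, prefix: str | None) -> str | None:
    """Pick which tree to load from a HYPOTHETICAL index dict."""
    version = index.get("version")
    if version == "1.0":
        # Assumed v1.0 layout: a single tree, optionally recorded under "prefix".
        return index.get("prefix")
    if version != "2.0":
        raise ValueError(f"Unsupported index version: {version!r}")

    available = sorted(index.get("prefixes", {}))
    if prefix is None:
        if len(available) == 1:
            # Assumption: an unambiguous single prefix is selected implicitly.
            return available[0]
        raise ValueError(f"No prefix given; available prefixes: {available}")
    if prefix not in available:
        raise ValueError(f"Prefix {prefix!r} not found; available: {available}")
    return prefix


index_v2 = {"version": "2.0", "prefixes": {"model": {}, "optimizer": {}}}
print(resolve_prefix(index_v2, "model"))  # model
try:
    resolve_prefix(index_v2, None)
except ValueError as err:
    print(err)
```

Listing the available prefixes in the error message mirrors the behaviour the note describes for multi-prefix checkpoints.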

eformer.serialization.serialization.tree_serialize_leaves(checkpoint_dir, pytree, manager: jax.experimental.array_serialization.serialization.GlobalAsyncCheckpointManager | None = None, *, prefix: str | None = None, commit_callback: collections.abc.Callable | None = None, write_index: bool = True)

Serialize a pytree’s leaves to TensorStore format.

Serializes arrays in a pytree to TensorStore format with optional prefixing for organizing multiple trees in the same checkpoint directory.

Parameters
  • checkpoint_dir – Directory to save the checkpoint.

  • pytree – PyTree structure containing arrays to serialize.

  • manager – Optional GlobalAsyncCheckpointManager. If None, creates a new one.

  • prefix – Optional prefix for organizing arrays (e.g., 'model', 'optimizer').

  • commit_callback – Optional callback to run after committing the checkpoint.

  • write_index – Whether to write an index file for the checkpoint.

Returns

None

Note

Uses a unified index file (tensorstore_index.json) that supports multiple prefixes in version 2.0 format.
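One way to picture a unified index that accumulates several prefixes is sketched below. Only the file name tensorstore_index.json comes from the note above; the JSON layout, the `update_index` helper, and the key names are hypothetical and are not the library's actual format.

```python
from __future__ import annotations

import json
import tempfile
from pathlib import Path


def update_index(checkpoint_dir: Path, prefix: str, array_keys: list[str]) -> dict:
    """Merge one prefix's array keys into a HYPOTHETICAL v2.0-style index."""
    index_path = checkpoint_dir / "tensorstore_index.json"
    if index_path.exists():
        # Reuse the existing index so earlier prefixes are preserved.
        index = json.loads(index_path.read_text())
    else:
        index = {"version": "2.0", "prefixes": {}}
    index["prefixes"][prefix] = sorted(array_keys)
    index_path.write_text(json.dumps(index, indent=2))
    return index


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp)
    update_index(ckpt, "model", ["dense.kernel", "dense.bias"])
    final_index = update_index(ckpt, "optimizer", ["mu.dense.kernel"])
    print(sorted(final_index["prefixes"]))  # ['model', 'optimizer']
```

Reading the existing file before writing is what lets two separate serialize calls with different prefixes share one checkpoint directory in this sketch.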