eformer.serialization.serialization
- eformer.serialization.serialization.is_array_like(x: Any) → bool
Check if an object is array-like.
Minimal check similar to equinox.is_array_like, checking for shape and dtype attributes.
- Parameters
x – Object to check.
- Returns
True if object has both shape and dtype attributes, False otherwise.
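The check described above can be sketched in a few lines. This is a minimal illustration of the documented behaviour, not the library's source; `FakeArray` is a hypothetical stand-in class used only so the example runs without any array library installed.

```python
from typing import Any


def is_array_like(x: Any) -> bool:
    """Return True when x exposes both `shape` and `dtype` attributes."""
    return hasattr(x, "shape") and hasattr(x, "dtype")


class FakeArray:
    """Hypothetical stand-in for an array type; not part of eformer."""

    shape = (2, 3)
    dtype = "float32"


print(is_array_like(FakeArray()))  # True: has shape and dtype
print(is_array_like([1, 2, 3]))    # False: a list has neither attribute
```

Because the test is purely attribute-based, any object exposing both attributes passes, whatever its actual type.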
- eformer.serialization.serialization.leaf_key_paths(pytree: Any, prefix: str | None = '', *, is_leaf: collections.abc.Callable[[Any], bool] | None = None)
Create dotted key paths for each leaf in a pytree.
Returns a pytree of the same structure where each leaf is replaced by its key path (prefixed by prefix if provided). Uses jax.tree_util.tree_flatten_with_path for robust handling of dicts, sequences, dataclasses, namedtuples, and custom PyTree nodes.
- Parameters
pytree – The pytree to create key paths for.
prefix – Optional prefix to add to all key paths.
is_leaf – Optional function to determine if a node is a leaf.
- Returns
PyTree with same structure where leaves are replaced by their dotted key paths.
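To make the idea of "leaves replaced by dotted key paths" concrete, here is a simplified pure-Python sketch. It is not the library's implementation: it handles only plain dicts, lists, and tuples rather than using jax.tree_util.tree_flatten_with_path, and the name `dotted_key_paths` and the way sequence indices are rendered (`layers.0`) are assumptions made for illustration.

```python
from typing import Any


def dotted_key_paths(tree: Any, prefix: str = "") -> Any:
    """Replace each leaf of a nested dict/list/tuple with its dotted path."""

    def join(key: object) -> str:
        # Prepend the accumulated path, if any, to the current key.
        return f"{prefix}.{key}" if prefix else str(key)

    if isinstance(tree, dict):
        return {k: dotted_key_paths(v, join(k)) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(dotted_key_paths(v, join(i)) for i, v in enumerate(tree))
    return prefix  # a leaf: its value is replaced by the accumulated path


params = {"dense": {"kernel": 1.0, "bias": 2.0}, "layers": [3.0, 4.0]}
paths = dotted_key_paths(params, prefix="model")
print(paths["dense"]["kernel"])  # model.dense.kernel
print(paths["layers"])           # ['model.layers.0', 'model.layers.1']
```

The output has the same nesting as the input, so it can be zipped against the original tree to pair every array with a stable name.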
- eformer.serialization.serialization.tree_deserialize_leaves(checkpoint_dir, mesh: Mesh, manager: jax.experimental.array_serialization.serialization.GlobalAsyncCheckpointManager | None = None, *, prefix: str | None = None, partition_rules: tuple[tuple[str, jax.sharding.PartitionSpec]] | None = None, shardings: jaxtyping.PyTree | dict[collections.abc.Callable] | None = None, callback: collections.abc.Callable[[jax.jaxlib._jax.Array, str], jax.jaxlib._jax.Array] | None = None, chunk_size: int | None = None)
Deserialize a PyTree of arrays from a TensorStore checkpoint.
The structure of the returned pytree is discovered from the checkpoint directory; the signature above takes no template pytree.
- Parameters
checkpoint_dir – Directory containing the TensorStore checkpoint.
mesh – JAX mesh for distributed arrays.
manager – Optional GlobalAsyncCheckpointManager. If None, creates a new one.
prefix – Optional prefix selecting which tree to load (e.g., ‘model’, ‘optimizer’).
partition_rules – Optional tuple of (str, jax.sharding.PartitionSpec) pairs.
shardings – Sharding specifications matching the checkpoint structure.
callback – Optional callback to process each loaded array by key.
chunk_size – Optional number of arrays to load per batch. If set, loads in chunks and waits between batches to reduce peak memory usage.
- Returns
Deserialized pytree structure with loaded arrays.
- Raises
ValueError – If the checkpoint format is unsupported or the prefix is not found.
Note
Supports both v1.0 (single prefix) and v2.0 (multi-prefix) index formats. When using v2.0 format with multiple prefixes, you must specify which prefix to load or an error will be raised listing available prefixes.
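The effect of chunk_size can be illustrated with a small sketch of batched loading. This is an assumption-laden illustration, not the library's code: `chunked`, `load_in_chunks`, and `fake_load_batch` are hypothetical names, and the real function reads through TensorStore rather than a plain callable.

```python
from __future__ import annotations

from collections.abc import Callable, Iterator, Sequence


def chunked(items: Sequence[str], chunk_size: int | None) -> Iterator[Sequence[str]]:
    """Yield items in batches of chunk_size, or all at once when it is None."""
    if chunk_size is None:
        yield items
        return
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    for start in range(0, len(items), chunk_size):
        yield items[start:start + chunk_size]


def load_in_chunks(
    keys: Sequence[str],
    load_batch: Callable[[Sequence[str]], list],
    chunk_size: int | None = None,
) -> dict:
    """Load one batch at a time; each batch completes before the next starts."""
    loaded = {}
    for batch in chunked(keys, chunk_size):
        loaded.update(zip(batch, load_batch(batch)))
    return loaded


batch_sizes = []


def fake_load_batch(batch: Sequence[str]) -> list:
    # Hypothetical loader: records the batch size and returns dummy values.
    batch_sizes.append(len(batch))
    return [key.upper() for key in batch]


result = load_in_chunks(["a", "b", "c", "d", "e"], fake_load_batch, chunk_size=2)
print(batch_sizes)  # [2, 2, 1]
```

Smaller batches mean fewer arrays are in flight at any moment, at the cost of more synchronization points.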
- eformer.serialization.serialization.tree_serialize_leaves(checkpoint_dir, pytree, manager: jax.experimental.array_serialization.serialization.GlobalAsyncCheckpointManager | None = None, *, prefix: str | None = None, commit_callback: collections.abc.Callable | None = None, write_index: bool = True)
Serialize a pytree’s leaves to TensorStore format.
Serializes arrays in a pytree to TensorStore format with optional prefixing for organizing multiple trees in the same checkpoint directory.
- Parameters
checkpoint_dir – Directory to save the checkpoint.
pytree – PyTree structure containing arrays to serialize.
manager – Optional GlobalAsyncCheckpointManager. If None, creates a new one.
prefix – Optional prefix for organizing arrays (e.g., ‘model’, ‘optimizer’).
commit_callback – Optional callback to run after committing the checkpoint.
write_index – Whether to write an index file for the checkpoint.
- Returns
None
Note
Uses a unified index file (tensorstore_index.json) that supports multiple prefixes in version 2.0 format.
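How a multi-prefix index might be consulted at load time can be sketched as follows. The dictionary layout below is hypothetical and is not the actual schema of tensorstore_index.json; `select_prefix` is likewise an illustrative name. Only the documented behaviour is modelled: with several prefixes present, a prefix must be given, and the error lists the available ones. Falling back to the only prefix when exactly one exists is an additional assumption.

```python
from __future__ import annotations

# Hypothetical index layout, for illustration only.
index = {
    "version": "2.0",
    "prefixes": {
        "model": ["dense.kernel", "dense.bias"],
        "optimizer": ["mu.dense.kernel"],
    },
}


def select_prefix(index: dict, prefix: str | None) -> list[str]:
    """Return the array keys stored under prefix, or raise ValueError."""
    prefixes = index["prefixes"]
    available = sorted(prefixes)
    if prefix is None:
        if len(prefixes) == 1:
            # Assumption: a lone prefix can be picked without being named.
            return next(iter(prefixes.values()))
        raise ValueError(f"Multiple prefixes found; specify one of: {available}")
    if prefix not in prefixes:
        raise ValueError(f"Prefix {prefix!r} not found; available: {available}")
    return prefixes[prefix]


print(select_prefix(index, "model"))  # ['dense.kernel', 'dense.bias']

try:
    select_prefix(index, None)
except ValueError as err:
    print(err)  # the message lists 'model' and 'optimizer'
```

Keeping every prefix in one index file lets model and optimizer state share a checkpoint directory while remaining individually loadable.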