eformer.serialization.sharding_utils

Utilities for handling sharding in checkpoint serialization.

eformer.serialization.sharding_utils.apply_sharding_tree(arrays: dict, sharding_tree: dict | None, mesh: jax._src.mesh.Mesh | None = None) -> dict

Apply sharding specifications from a sharding tree to arrays.

Takes a dictionary of arrays and applies the corresponding sharding specifications from a sharding tree with a matching structure. It handles callable shardings and direct Sharding objects, and falls back to a default sharding when a leaf is None.

Parameters
  • arrays – Dictionary of arrays (flattened or nested). Can contain JAX arrays or numpy arrays.

  • sharding_tree – PyTree of sharding specifications matching the structure of arrays. Each leaf can be a Sharding object, a callable that returns a Sharding, or None for default sharding.

  • mesh – JAX mesh for creating default shardings when a sharding_tree leaf is None. If mesh is None, SingleDeviceSharding is used.

Returns

Dictionary of arrays with sharding applied via jax.device_put.

Note

  • If sharding_tree is None, returns arrays unchanged.

  • Structure mismatches are logged as warnings, and the original arrays are returned.

  • Non-array leaves (scalars, etc.) are returned unchanged.

Example

>>> from jax.sharding import NamedSharding, PartitionSpec
>>> arrays = {"layer1.weight": weight_arr, "layer1.bias": bias_arr}
>>> shard_tree = {
...     "layer1.weight": NamedSharding(mesh, PartitionSpec("data")),
...     "layer1.bias": NamedSharding(mesh, PartitionSpec()),
... }
>>> sharded = apply_sharding_tree(arrays, shard_tree, mesh)
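The leaf handling described above can be sketched for a flat dictionary as follows. This is a minimal illustration, not the library's implementation: `resolve_leaf` and `apply_flat_sharding_tree` are hypothetical names, the callable leaf is assumed to receive the array it will shard, the mesh-based default is assumed to be fully replicated, and the warning logged on a structure mismatch is omitted.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec, Sharding, SingleDeviceSharding


def resolve_leaf(spec, array, mesh=None):
    """Turn one sharding-tree leaf into a concrete Sharding (hypothetical helper)."""
    if isinstance(spec, Sharding):
        return spec
    if callable(spec):
        # Assumption: the callable receives the array it will shard.
        return spec(array)
    # spec is None: fall back to a default sharding.
    if mesh is not None:
        # Assumption: the mesh-based default is fully replicated.
        return NamedSharding(mesh, PartitionSpec())
    return SingleDeviceSharding(jax.devices()[0])


def apply_flat_sharding_tree(arrays, sharding_tree, mesh=None):
    """Sketch of the documented behavior for a flat dict of arrays."""
    if sharding_tree is None or arrays.keys() != sharding_tree.keys():
        # No tree, or mismatched keys: hand back the original arrays.
        return arrays
    out = {}
    for name, value in arrays.items():
        if isinstance(value, (jax.Array, np.ndarray)):
            target = resolve_leaf(sharding_tree[name], value, mesh)
            out[name] = jax.device_put(value, target)
        else:
            out[name] = value  # non-array leaves pass through unchanged
    return out


devices = jax.devices()
mesh = Mesh(np.array(devices), ("data",))
arrays = {
    # Leading dimension is a multiple of the device count so it can be split.
    "layer1.weight": np.ones((2 * len(devices), 4), dtype=np.float32),
    "layer1.bias": np.zeros((4,), dtype=np.float32),
    "step": 3,
}
shard_tree = {
    "layer1.weight": NamedSharding(mesh, PartitionSpec("data")),
    "layer1.bias": lambda arr: NamedSharding(mesh, PartitionSpec()),
    "step": None,
}
sharded = apply_flat_sharding_tree(arrays, shard_tree, mesh)
```

The sketch checks for a Sharding instance before testing `callable`, so that a sharding object is never invoked by mistake.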
eformer.serialization.sharding_utils.create_sharding_tree_from_index(checkpoint_dir: str, mesh: jax._src.mesh.Mesh | None = None, prefix: str | None = None, default_sharding: jaxlib._jax.Sharding | jax.sharding.PartitionSpec | collections.abc.Callable[[Any], jaxlib._jax.Sharding] | None = None) -> dict

Create a sharding tree from a tensorstore index file.

Creates a PyTree structure that matches the checkpoint structure, where each leaf is a sharding specification or function that can be applied to the corresponding array during deserialization.

Parameters
  • checkpoint_dir – Directory containing the tensorstore checkpoint.

  • mesh – JAX mesh for creating shardings. If None, uses replicated sharding.

  • prefix – Optional prefix that restricts the sharding tree to a specific subtree.

  • default_sharding – Default sharding to use for all arrays. Can be:

      - A Sharding object
      - A PartitionSpec (will be wrapped with NamedSharding using mesh)
      - A callable that takes an array info dict and returns a Sharding
      - None (uses fully replicated sharding)

Returns

Dictionary with the same structure as the checkpoint, where leaves are sharding specifications.

Example

>>> from jax.sharding import NamedSharding, PartitionSpec
>>> # Default: fully replicated sharding for every array
>>> shard_tree = create_sharding_tree_from_index("checkpoint/")
>>> # Custom: choose a sharding per array from its info dict
>>> def custom_shard_fn(info):
...     if "embedding" in info["path"]:
...         return NamedSharding(mesh, PartitionSpec("data", None))
...     return NamedSharding(mesh, PartitionSpec())
>>> shard_tree = create_sharding_tree_from_index(
...     "checkpoint/", mesh=mesh, default_sharding=custom_shard_fn
... )
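The accepted forms of default_sharding can be normalized roughly as sketched below. This is an illustration under stated assumptions rather than the library's code: `normalize_default_sharding` is a hypothetical helper, a mesh is assumed to be available, and the info dict is assumed to carry only the "path" key used in the example above.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def normalize_default_sharding(default_sharding, mesh):
    """Map each accepted form to a Sharding or a callable (hypothetical helper)."""
    if default_sharding is None:
        # Documented default: fully replicated sharding.
        return NamedSharding(mesh, PartitionSpec())
    if isinstance(default_sharding, PartitionSpec):
        # A bare PartitionSpec needs a mesh to become a concrete sharding.
        return NamedSharding(mesh, default_sharding)
    # Sharding objects and callables are passed through unchanged.
    return default_sharding


mesh = Mesh(np.array(jax.devices()), ("data",))


def custom_shard_fn(info):
    # Assumption: the info dict exposes the array's checkpoint path under "path".
    if "embedding" in info["path"]:
        return NamedSharding(mesh, PartitionSpec("data", None))
    return NamedSharding(mesh, PartitionSpec())


replicated = normalize_default_sharding(None, mesh)
wrapped = normalize_default_sharding(PartitionSpec("data"), mesh)
fn = normalize_default_sharding(custom_shard_fn, mesh)
embedding_sharding = fn({"path": "model.embedding.weight"})
```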
eformer.serialization.sharding_utils.make_itsharded(xs, mesh)

Convert a PyTree of arrays to fully replicated shardings on a mesh.

Takes a PyTree and reshards all fully addressable JAX arrays to use a replicated sharding (PartitionSpec()) on the provided mesh. Non-array leaves and arrays that are not fully addressable are left unchanged.

Parameters
  • xs – PyTree containing JAX arrays and potentially other values.

  • mesh – JAX Mesh to use for the replicated sharding.

Returns

PyTree with the same structure, where fully addressable arrays have been resharded to a replicated layout across all devices in the mesh.

Note

  • Uses JIT compilation for efficient device placement.

  • Only processes arrays where is_fully_addressable is True.

  • Useful for preparing data for collective operations or checkpointing.

Example

>>> import jax
>>> from jax.sharding import Mesh
>>> mesh = Mesh(jax.devices(), ("data",))
>>> sharded_state = make_itsharded(model_state, mesh)
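One way to obtain the replicated layout described above is a jitted identity function with `out_shardings` set to a replicated NamedSharding. The sketch below is an assumption about how such resharding can be written, not the function's actual source; `replicate_tree` is a hypothetical name.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def replicate_tree(xs, mesh):
    """Reshard fully addressable jax.Array leaves to PartitionSpec() (sketch)."""
    replicated = NamedSharding(mesh, PartitionSpec())
    # A jitted identity with out_shardings returns the same values
    # placed with the requested layout.
    reshard = jax.jit(lambda x: x, out_shardings=replicated)

    def _leaf(x):
        if isinstance(x, jax.Array) and x.is_fully_addressable:
            return reshard(x)
        return x  # non-array leaves and non-addressable arrays are left alone

    return jax.tree_util.tree_map(_leaf, xs)


mesh = Mesh(np.array(jax.devices()), ("data",))
model_state = {"params": {"w": jnp.ones((4, 2))}, "step": 7}
replicated_state = replicate_tree(model_state, mesh)
```

After the call, `replicated_state["params"]["w"].sharding.is_fully_replicated` can be used to confirm the layout.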
eformer.serialization.sharding_utils.validate_sharding_tree(sharding_tree: dict, expected_structure: dict) -> bool

Validate that a sharding tree structure matches an expected structure.

Compares the tree definitions of two PyTrees to ensure they have compatible structures before applying shardings during checkpoint loading.

Parameters
  • sharding_tree – PyTree of sharding specifications to validate.

  • expected_structure – Expected PyTree structure (e.g., from checkpoint index or a template model state).

Returns

True if the tree structures match (same number of leaves in same positions), False otherwise.

Note

  • Uses JAX tree_util for structure comparison.

  • Logs a warning if validation encounters an error.

  • Only compares structure (treedef), not the actual values.

Example

>>> import jax
>>> shard_tree = create_sharding_tree_from_index("checkpoint/")
>>> # Use a non-None placeholder: JAX treats None as an empty subtree, not a leaf
>>> model_structure = jax.tree_util.tree_map(lambda x: 0, model_state)
>>> if validate_sharding_tree(shard_tree, model_structure):
...     print("Sharding tree is compatible")
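A treedef comparison of the kind described above can be sketched with `jax.tree_util.tree_structure`. The helper name `same_structure` and the placeholder leaf values are illustrative assumptions. The last case shows a JAX detail worth keeping in mind when building a template: None counts as an empty subtree rather than a leaf, so a template filled with None does not share a treedef with a tree whose leaves are shardings.

```python
import jax


def same_structure(tree_a, tree_b):
    """Compare only the treedefs, ignoring leaf values (illustrative helper)."""
    return jax.tree_util.tree_structure(tree_a) == jax.tree_util.tree_structure(tree_b)


# Strings stand in for sharding objects; only the structure matters here.
shard_tree = {"layer1": {"weight": "sharding-a", "bias": "sharding-b"}}
model_state = {"layer1": {"weight": 1.0, "bias": 2.0}}
missing_bias = {"layer1": {"weight": 1.0}}
none_template = {"layer1": {"weight": None, "bias": None}}

matches = same_structure(shard_tree, model_state)
detects_missing = not same_structure(shard_tree, missing_bias)
none_is_not_leaf = not same_structure(shard_tree, none_template)
```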