eformer.serialization.sharding_utils
Utilities for handling sharding in checkpoint serialization.
- eformer.serialization.sharding_utils.apply_sharding_tree(arrays: dict, sharding_tree: dict | None, mesh: jax._src.mesh.Mesh | None = None) → dict
Apply sharding specifications from a sharding tree to arrays.
Takes a dictionary of arrays and applies the corresponding sharding specifications from a matching sharding tree structure. Handles both callable shardings and direct Sharding objects, and falls back to a default sharding for leaves that specify none.
- Parameters
arrays – Dictionary of arrays (flattened or nested). Can contain JAX arrays or numpy arrays.
sharding_tree – PyTree of sharding specifications matching the structure of arrays. Each leaf can be a Sharding object, a callable that returns a Sharding, or None for default sharding.
mesh – JAX mesh for creating default shardings when sharding_tree leaf is None. If mesh is None, uses SingleDeviceSharding.
- Returns
Dictionary of arrays with sharding applied via jax.device_put.
Note
If sharding_tree is None, returns arrays unchanged.
Structure mismatches are logged as warnings, and the original arrays are returned.
Non-array leaves (scalars, etc.) are returned unchanged.
Example
>>> arrays = {"layer1.weight": weight_arr, "layer1.bias": bias_arr}
>>> shard_tree = {
...     "layer1.weight": NamedSharding(mesh, PartitionSpec("data")),
...     "layer1.bias": NamedSharding(mesh, PartitionSpec()),
... }
>>> sharded = apply_sharding_tree(arrays, shard_tree, mesh)
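The leaf-handling rules described above can be sketched roughly as follows. This is an illustrative approximation using only public JAX calls, not eformer's actual implementation: `resolve_leaf_sharding` and `apply_flat` are hypothetical names, the callable is assumed to receive the array itself, and the fallback when a mesh is supplied is assumed to be a fully replicated NamedSharding.

```python
import jax
import numpy as np
from jax.sharding import NamedSharding, PartitionSpec, Sharding, SingleDeviceSharding


def resolve_leaf_sharding(spec, array, mesh=None):
    """Hypothetical helper: turn one sharding-tree leaf into a concrete Sharding."""
    if isinstance(spec, Sharding):
        return spec  # direct Sharding object: use as-is
    if callable(spec):
        return spec(array)  # assumption: the callable receives the array
    if mesh is not None:
        # Assumption: the mesh-based default is fully replicated.
        return NamedSharding(mesh, PartitionSpec())
    return SingleDeviceSharding(jax.devices()[0])


def apply_flat(arrays, sharding_tree, mesh=None):
    """Hypothetical stand-in for the documented behavior, flat dicts only."""
    if sharding_tree is None:
        return arrays  # documented: arrays are returned unchanged
    out = {}
    for key, value in arrays.items():
        if not isinstance(value, (jax.Array, np.ndarray)):
            out[key] = value  # non-array leaves pass through untouched
            continue
        sharding = resolve_leaf_sharding(sharding_tree.get(key), value, mesh)
        out[key] = jax.device_put(value, sharding)
    return out


device = jax.devices()[0]
arrays = {"layer1.weight": np.ones((4, 2), dtype=np.float32), "step": 3}
shard_tree = {
    "layer1.weight": lambda arr: SingleDeviceSharding(device),
    "step": None,
}
sharded = apply_flat(arrays, shard_tree)
```

In this sketch the NumPy weight comes back as a `jax.Array` placed on the chosen device, while the integer `step` is returned as given.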
- eformer.serialization.sharding_utils.create_sharding_tree_from_index(checkpoint_dir: str, mesh: jax._src.mesh.Mesh | None = None, prefix: str | None = None, default_sharding: jaxlib._jax.Sharding | jax.sharding.PartitionSpec | collections.abc.Callable[[Any], jaxlib._jax.Sharding] | None = None) → dict
Create a sharding tree from a tensorstore index file.
Creates a PyTree structure that matches the checkpoint structure, where each leaf is a sharding specification or function that can be applied to the corresponding array during deserialization.
- Parameters
checkpoint_dir – Directory containing the tensorstore checkpoint.
mesh – JAX mesh for creating shardings. If None, uses replicated sharding.
prefix – Optional prefix that restricts the sharding tree to a specific subtree.
default_sharding – Default sharding to use for all arrays. Can be:
  - A Sharding object
  - A PartitionSpec (will be wrapped with NamedSharding using mesh)
  - A callable that takes an array info dict and returns a Sharding
  - None (uses fully replicated sharding)
- Returns
Dictionary with same structure as checkpoint, where leaves are sharding specifications.
Example
>>> shard_tree = create_sharding_tree_from_index("checkpoint/")

>>> def custom_shard_fn(info):
...     if "embedding" in info["path"]:
...         return NamedSharding(mesh, PartitionSpec("data", None))
...     return NamedSharding(mesh, PartitionSpec())
>>> shard_tree = create_sharding_tree_from_index(
...     "checkpoint/", mesh=mesh, default_sharding=custom_shard_fn
... )
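The four accepted forms of default_sharding might be normalized along these lines. This is a minimal sketch under stated assumptions, not the library's code: `normalize_default_sharding` and `build_flat_tree` are hypothetical names, a mesh is assumed to be available, and a hard-coded list of array paths stands in for whatever the index file actually records.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def normalize_default_sharding(default_sharding, mesh):
    """Hypothetical helper mapping each accepted form to a usable leaf."""
    if default_sharding is None:
        # None: fully replicated across the mesh.
        return NamedSharding(mesh, PartitionSpec())
    if isinstance(default_sharding, PartitionSpec):
        # PartitionSpec: wrap with NamedSharding using the mesh.
        return NamedSharding(mesh, default_sharding)
    # Sharding objects and callables are kept as given.
    return default_sharding


def build_flat_tree(paths, default_sharding, mesh):
    """Hypothetical: one leaf per array path (the real index format is not shown here)."""
    leaf = normalize_default_sharding(default_sharding, mesh)
    return {path: leaf for path in paths}


# A one-axis mesh over all local devices; works with a single CPU device too.
mesh = Mesh(np.array(jax.devices()), ("data",))
paths = ["embedding.weight", "layer1.bias"]  # illustrative names only

replicated_tree = build_flat_tree(paths, None, mesh)
data_tree = build_flat_tree(paths, PartitionSpec("data"), mesh)
```

Keeping callables as leaves, rather than calling them eagerly, matches the description that a leaf may be a function applied to the corresponding array during deserialization.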
- eformer.serialization.sharding_utils.make_itsharded(xs, mesh)
Convert a PyTree of arrays to fully replicated shardings on a mesh.
Takes a PyTree and reshards all fully addressable JAX arrays to use a replicated sharding (PartitionSpec()) on the provided mesh. Non-array leaves and arrays that are not fully addressable are left unchanged.
- Parameters
xs – PyTree containing JAX arrays and potentially other values.
mesh – JAX Mesh to use for the replicated sharding.
- Returns
PyTree with same structure where fully addressable arrays have been resharded to replicated layout across all devices in the mesh.
Note
Uses JIT compilation for efficient device placement.
Only processes arrays where is_fully_addressable is True.
Useful for preparing data for collective operations or checkpointing.
Example
>>> from jax.sharding import Mesh
>>> mesh = Mesh(jax.devices(), ("data",))
>>> sharded_state = make_itsharded(model_state, mesh)
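The replication behavior described above can be approximated with plain `jax.device_put`. Note the substitution: the documented function is said to use JIT compilation, whereas this sketch uses `device_put` for simplicity, and `replicate_tree` is a hypothetical name rather than part of eformer.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def replicate_tree(xs, mesh):
    """Hypothetical sketch: replicate fully addressable jax.Arrays on the mesh."""
    replicated = NamedSharding(mesh, PartitionSpec())

    def _maybe_replicate(x):
        # Only touch JAX arrays that this process can fully address.
        if isinstance(x, jax.Array) and x.is_fully_addressable:
            return jax.device_put(x, replicated)
        return x  # non-array leaves are left unchanged

    return jax.tree_util.tree_map(_maybe_replicate, xs)


mesh = Mesh(np.array(jax.devices()), ("data",))
model_state = {"params": {"w": jnp.ones((2, 2))}, "name": "demo", "step": 1}
replicated_state = replicate_tree(model_state, mesh)
```

The tree structure is preserved, so `replicated_state["name"]` and `replicated_state["step"]` are the original Python values.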
- eformer.serialization.sharding_utils.validate_sharding_tree(sharding_tree: dict, expected_structure: dict) → bool
Validate that a sharding tree structure matches an expected structure.
Compares the tree definitions of two PyTrees to ensure they have compatible structures before applying shardings during checkpoint loading.
- Parameters
sharding_tree – PyTree of sharding specifications to validate.
expected_structure – Expected PyTree structure (e.g., from checkpoint index or a template model state).
- Returns
True if the tree structures match (same number of leaves in same positions), False otherwise.
Note
Uses JAX tree_util for structure comparison.
Logs a warning if validation encounters an error.
Only compares structure (treedef), not the actual values.
Example
>>> shard_tree = create_sharding_tree_from_index("checkpoint/")
>>> model_structure = jax.tree_util.tree_map(lambda x: None, model_state)
>>> if validate_sharding_tree(shard_tree, model_structure):
...     print("Sharding tree is compatible")
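A treedef comparison of the kind described above might look like the following. This is a sketch built on `jax.tree_util`, not the library's implementation, and `same_structure` is a hypothetical name. One detail worth knowing when building a template of None values: JAX treats None as an empty subtree by default, so the sketch passes `is_leaf` to count None as a leaf. Whether `validate_sharding_tree` does the same is not stated in this reference.

```python
import jax


def same_structure(tree_a, tree_b):
    """Hypothetical sketch: compare treedefs, counting None as a leaf."""
    is_none = lambda x: x is None
    _, def_a = jax.tree_util.tree_flatten(tree_a, is_leaf=is_none)
    _, def_b = jax.tree_util.tree_flatten(tree_b, is_leaf=is_none)
    return def_a == def_b


# Strings stand in for real Sharding objects; only the structure matters here.
shardings = {"layer1": {"weight": "sharding_a", "bias": "sharding_b"}}
template = {"layer1": {"weight": None, "bias": None}}
missing_bias = {"layer1": {"weight": None}}

matches = same_structure(shardings, template)
mismatch = same_structure(shardings, missing_bias)

# Without is_leaf, the None entries are empty subtrees rather than leaves,
# so the default treedefs of these two trees differ.
naive_match = (
    jax.tree_util.tree_structure(shardings)
    == jax.tree_util.tree_structure(template)
)
```

Only the container layout is compared, so the actual leaf values never affect the result.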