eformer.escale.partition.constraints
Sharding constraint utilities and mesh introspection functions.
This module provides the core functionality for applying sharding constraints to JAX arrays with automatic correction based on mesh configuration and array shapes. It also includes utilities for introspecting mesh properties and extracting sharding information from arrays.
- Key Features:
  - Automatic sharding constraint correction for compatibility
  - Mesh introspection (axis names, sizes, device indices)
  - Partition rule matching via regex patterns
  - Sharding extraction from distributed arrays
  - Pattern-based partition specification generation
- Environment Variables:
  - MIN_SHARDING_SIZE: Minimum array size to apply sharding (default: 16384). Arrays smaller than this remain unsharded for efficiency.
  - LOG_SHARDING_MOVE: If "true", logs warnings about sharding corrections and auto-adjustments.
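The two environment variables above can be set from Python before the module is imported. The sketch below is only an illustration: `read_sharding_env` is a hypothetical helper, not part of eformer, and it assumes that LOG_SHARDING_MOVE is off when unset and is compared case-insensitively. Neither detail is stated in this page.

```python
import os

# Assumption: the variables are read when the module is imported, so set them first.
os.environ["MIN_SHARDING_SIZE"] = "4096"
os.environ["LOG_SHARDING_MOVE"] = "true"


def read_sharding_env(environ=os.environ):
    """Hypothetical helper that mirrors the documented default of 16384."""
    min_size = int(environ.get("MIN_SHARDING_SIZE", "16384"))
    # Assumption: anything other than "true" (case-insensitive) disables logging.
    log_moves = environ.get("LOG_SHARDING_MOVE", "false").lower() == "true"
    return min_size, log_moves


min_size, log_moves = read_sharding_env()
print(min_size, log_moves)
```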
Example
>>> from jax.sharding import PartitionSpec
>>> from eformer.escale.partition.constraints import (
...     with_sharding_constraint,
...     get_incontext_mesh,
...     match_partition_rules,
... )
>>> # `mesh`, `array`, `rules`, and `model_params` are assumed to be defined by the caller.
>>> # Apply sharding with automatic correction
>>> with mesh:
...     sharded = with_sharding_constraint(array, PartitionSpec('dp', 'tp'))
>>> # Match rules to parameters
>>> specs = match_partition_rules(rules, model_params)
- eformer.escale.partition.constraints.analyze_sharding_strategy(pytree: Any, partition_specs: dict[str, jax.sharding.PartitionSpec], mesh: jax._src.mesh.Mesh | None = None) -> dict
Analyze the effectiveness of a sharding strategy.
Computes metrics to evaluate how well a sharding strategy distributes computation and memory across devices. Useful for debugging and optimizing distributed training configurations.
- Parameters
pytree – A PyTree of arrays to analyze.
partition_specs – Dictionary mapping paths to PartitionSpecs.
mesh – The JAX mesh to analyze against. If None, uses the current context’s mesh.
- Returns
A dictionary containing analysis metrics:
"total_parameters": Total parameter count across all arrays
"sharded_parameters": Count of parameters that are sharded
"memory_per_device": Per-device memory breakdown (dict)
"balance_score": Score indicating load balance (0.0-1.0)
"partition_stats": Statistics about partition distribution
- Return type
dict
Example
>>> analysis = analyze_sharding_strategy(params, specs, mesh)
>>> print(f"Sharded: {analysis['sharded_parameters']}/{analysis['total_parameters']}")
>>> print(f"Balance score: {analysis['balance_score']:.2f}")
- eformer.escale.partition.constraints.create_pattern_based_partition_spec(pattern: str, mesh: jax._src.mesh.Mesh | None = None, default_spec: jax.sharding.PartitionSpec | None = None) -> Callable[[str, Union[Array, ndarray, bool, number]], PartitionSpec]
Creates a function that returns PartitionSpec based on parameter name patterns.
Example
>>> pattern_fn = create_pattern_based_partition_spec(
...     "attention|mlp->data,hidden->model"
... )
- eformer.escale.partition.constraints.extract_sharding_structure(pytree: Any) -> Any
Extract NamedSharding objects from a PyTree of sharded arrays.
Creates a new PyTree with the same structure as the input, where each leaf contains the NamedSharding of the corresponding array (or None if the leaf has no sharding information).
- Parameters
pytree – A PyTree potentially containing sharded JAX arrays.
- Returns
A PyTree matching the input structure. Each leaf is either:
A NamedSharding object if the original leaf was a sharded array
None if the leaf had no sharding or wasn’t a JAX array
Example
>>> shardings = extract_sharding_structure(sharded_params)
>>> # shardings has same structure as sharded_params
>>> # but leaves are NamedSharding objects or None
- eformer.escale.partition.constraints.extract_shardings(tree, mesh: Mesh = None)
Extracts JAX NamedSharding objects from the leaves of a PyTree.
This function traverses the input PyTree and inspects each leaf:
- If a leaf has a .sharding attribute that is already a NamedSharding, it’s returned directly.
- If a leaf has a .sharding attribute that is a PartitionSpec, it attempts to convert it into a NamedSharding using the provided mesh. If no mesh is provided, it tries to get one from the JAX context (e.g., using get_incontext_mesh). If no mesh is available in either case, a ValueError is raised.
- If a leaf does not have a .sharding attribute, or if its sharding is not a NamedSharding or convertible PartitionSpec, None is returned for that leaf in the output tree.
- Parameters
tree – The input PyTree (e.g., nested dictionary, list, tuple) potentially containing JAX arrays or other objects with sharding information.
mesh – An optional jax.sharding.Mesh. If provided, it’s used to convert PartitionSpec objects to NamedSharding. If None, the function attempts to find a mesh from the current JAX context.
- Returns
A PyTree with the same structure as the input tree. Each leaf will contain either a jax.sharding.NamedSharding object corresponding to the input leaf’s sharding, or None if no valid sharding information was found or could be constructed.
- Raises
ValueError – If a leaf has a PartitionSpec sharding but no mesh is provided or found in the context.
- eformer.escale.partition.constraints.get_axes_size_in_mesh(axis_names: tuple[str, ...] | str | Any | None, mesh: jax._src.mesh.Mesh | None = None) -> int
Calculates the total size of the specified mesh axes.
If a single axis name (string) is provided, it returns the size of that dimension in the mesh. If a sequence (list or tuple) of axis names is provided, it returns the product of the sizes of all specified axes.
If no mesh is explicitly provided, it uses the mesh active in the current context obtained via get_current_mesh().
- Parameters
axis_names – The name of a single mesh axis (str) or a sequence (list/tuple) of axis names whose sizes should be multiplied.
mesh – The mesh object to query. If None, the current context’s mesh is used. Defaults to None.
- Returns
The size of the single specified axis, or the product of the sizes of the sequence of specified axes.
- Return type
int
- Raises
KeyError – If any of the specified axis_names are not found in the mesh’s dimensions.
AssertionError – If mesh is None and no mesh is found in the current context (raised by get_current_mesh()).
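The size calculation described for this function can be sketched in plain Python. Here `axes_size` is a hypothetical stand-in, and `mesh_shape` is an ordinary dict used in place of a real mesh's axis-name-to-size mapping, so the snippet runs without JAX or eformer.

```python
import math


def axes_size(axis_names, mesh_shape):
    """Return the size of one axis, or the product of sizes for several axes.

    `mesh_shape` is a plain dict standing in for a mesh's axis-name -> size map.
    """
    if isinstance(axis_names, str):
        return mesh_shape[axis_names]  # KeyError if the axis is unknown
    return math.prod(mesh_shape[name] for name in axis_names)


mesh_shape = {"dp": 4, "tp": 2}
print(axes_size("dp", mesh_shape))          # 4
print(axes_size(("dp", "tp"), mesh_shape))  # 8
```

As in the documented behaviour, an unknown axis name surfaces as a KeyError.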
- eformer.escale.partition.constraints.get_corrected_named_sharding(shape: tuple[int, ...], partition_spec: PartitionSpec, raise_mesh_error: bool = True) -> NamedSharding
Calculates the corrected PartitionSpec based on shape and mesh, returns NamedSharding.
This function takes an array shape and a desired PartitionSpec. It determines the effective PartitionSpec by correcting the input based on:
- Axis names present in the current mesh.
- Divisibility of array dimensions by the product of corresponding mesh axis sizes.
It does NOT correct based on mesh axes having size 1, allowing such axes to persist in the spec if explicitly provided and divisibility holds.
- Parameters
shape – The shape of the target JAX array.
partition_spec – The desired PartitionSpec.
raise_mesh_error – If True, raises an error if no mesh is active. If False, returns a replicated NamedSharding on an empty mesh if no mesh is found.
- Returns
A NamedSharding object containing the current mesh and the corrected PartitionSpec.
- Raises
AssertionError – If no mesh is active and raise_mesh_error is True.
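The correction rules for this function (axis must exist in the mesh, dimension must be divisible, size-1 axes are kept) can be illustrated with a small pure-Python sketch. It is not eformer's implementation: `correct_spec` is a hypothetical name, specs are modelled as tuples, the mesh as a dict, and it assumes that an entry naming any missing axis is dropped as a whole.

```python
import math


def correct_spec(shape, spec, mesh_shape):
    """Return a corrected spec as a tuple with one entry per array dimension."""
    # Pad the spec with None so every dimension has an entry.
    spec = tuple(spec) + (None,) * (len(shape) - len(spec))
    corrected = []
    for dim, entry in zip(shape, spec):
        if entry is None:
            corrected.append(None)
            continue
        names = (entry,) if isinstance(entry, str) else tuple(entry)
        # Assumption: an entry naming any axis missing from the mesh is dropped entirely.
        if any(name not in mesh_shape for name in names):
            corrected.append(None)
            continue
        # Keep the entry only if the dimension divides evenly; size-1 axes always pass.
        total = math.prod(mesh_shape[name] for name in names)
        corrected.append(entry if dim % total == 0 else None)
    return tuple(corrected)


mesh_shape = {"dp": 4, "tp": 2, "sp": 1}
print(correct_spec((6, 8), ("dp", "tp"), mesh_shape))          # (None, 'tp')
print(correct_spec((8, 8), ("dp", "pp"), mesh_shape))          # ('dp', None)
print(correct_spec((8, 7), (("dp", "tp"), "sp"), mesh_shape))  # (('dp', 'tp'), 'sp')
```

In the first call, 6 is not divisible by the `dp` size of 4, so that entry falls back to None, while the size-1 `sp` axis in the last call survives because every dimension is divisible by 1.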
- eformer.escale.partition.constraints.get_incontext_mesh(raise_error: bool = True) -> Mesh
Retrieve the mesh object active in the current execution context.
This function accesses the physical mesh defined within the thread’s resource environment (pxla.thread_resources.env.physical_mesh). It is commonly used to get the mesh when inside a with mesh: context.
- Parameters
raise_error – If True (default), raises an AssertionError when no mesh is active. If False, returns the empty mesh without error.
- Returns
The active Mesh object for the current context, or an empty mesh if raise_error is False and no mesh is active.
- Raises
AssertionError – If no mesh is found in the current context and raise_error is True.
Example
>>> with mesh:
...     current_mesh = get_incontext_mesh()
...     print(current_mesh.axis_names)
('dp', 'tp')
- eformer.escale.partition.constraints.get_mesh_axis_names(mesh: jax._src.mesh.Mesh | None = None) -> list[str]
Retrieves the names of all axes defined in the mesh.
These names typically correspond to the dimensions used for sharding or parallelism.
If no mesh is explicitly provided, it uses the mesh active in the current context obtained via get_current_mesh().
- Parameters
mesh – The mesh object to query. If None, the current context’s mesh is used. Defaults to None.
- Returns
A list containing the names of all axes in the mesh.
- Return type
List[str]
- Raises
AssertionError – If mesh is None and no mesh is found in the current context (raised by get_current_mesh()).
- eformer.escale.partition.constraints.get_mesh_axis_size(axis_names: tuple[str, ...] | str | Any | None) -> int
Calculates the total number of devices along the specified mesh axis or axes.
- Parameters
axis_names – The name of a single mesh axis (str) or a sequence (list/tuple) of mesh axis names. The order in the sequence does not affect the result (product is commutative).
- Returns
The total number of devices (size) in the submesh defined by the axis/axes. Returns 1 if axis_names is an empty sequence.
- Raises
TypeError – If axis_names is not a str or a sequence of str.
- eformer.escale.partition.constraints.get_names_from_partition_spec(partition_specs: dict[str, jax.sharding.PartitionSpec]) -> list[str]
Extract axis names from a partition specification.
This function recursively iterates through the provided partition_specs dictionary and extracts all unique axis names used in the sharding specifications.
- Parameters
partition_specs – A dictionary mapping parameter names to their respective PartitionSpec.
- Returns
A list of unique axis names used in the partition specs.
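A rough sketch of the recursive extraction described for this function, using plain tuples in place of PartitionSpec objects. `axis_names_from_specs` is a hypothetical name, and returning the names sorted is a choice made here for a deterministic result; the ordering of the real function's output is not documented on this page.

```python
def axis_names_from_specs(partition_specs):
    """Collect unique axis names from a (possibly nested) dict of specs.

    Specs are modelled as tuples whose entries are None, a name, or a tuple of names.
    """
    names = set()
    for value in partition_specs.values():
        if isinstance(value, dict):
            names.update(axis_names_from_specs(value))  # recurse into sub-dicts
            continue
        for entry in value:
            if entry is None:
                continue
            names.update((entry,) if isinstance(entry, str) else entry)
    return sorted(names)


specs = {
    "embed": ("tp", None),
    "block": {"attn": (("dp", "fsdp"), "tp"), "norm": ()},
}
print(axis_names_from_specs(specs))  # ['dp', 'fsdp', 'tp']
```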
- eformer.escale.partition.constraints.get_partition_spec(tree)
Retrieves the PartitionSpec for each leaf in a PyTree.
This function traverses the input PyTree and determines the jax.sharding.PartitionSpec for each leaf based on its type:
- If the leaf is a jax.Array, it returns the PartitionSpec from leaf.sharding.spec.
- If the leaf is a Python scalar (int or float), it returns an empty PartitionSpec(), assuming scalars are typically replicated.
- For any other leaf type, it raises a ValueError.
- Parameters
tree – The input PyTree (e.g., nested dictionary, list, tuple) containing JAX arrays, scalars, or potentially other types.
- Returns
A PyTree with the same structure as the input tree. Each leaf contains the corresponding jax.sharding.PartitionSpec.
- Raises
ValueError – If a leaf in the tree is not a jax.Array, int, or float.
AttributeError – If a jax.Array leaf doesn’t have .sharding.spec (which would be unusual for a properly sharded array).
- eformer.escale.partition.constraints.get_shardings_with_structure(pytree: Any) -> Any
Get shardings from a PyTree while preserving structure.
Alias for extract_sharding_structure. Returns a PyTree matching the input structure where each leaf contains the NamedSharding of the corresponding array (or None if unavailable).
- Parameters
pytree – A PyTree potentially containing sharded JAX arrays.
- Returns
A PyTree matching the input structure with NamedSharding objects or None at each leaf position.
See also
extract_sharding_structure: The underlying implementation.
- eformer.escale.partition.constraints.get_submesh_device_index(axis_names: tuple[str, ...] | str | Any | None) -> int
Calculates the linear index of the current device within the specified mesh axes.
This effectively flattens the multi-dimensional coordinates of the device within the submesh defined by axis_names into a single integer index.
IMPORTANT: It assumes the input axis_names sequence is ordered from most major to most minor dimension. The calculation performs a row-major-like flattening based on this order.
- Parameters
axis_names – The name of a single mesh axis (str) or a sequence (list/tuple) of mesh axis names, ordered from major to minor.
- Returns
The 0-based linear index of the current device within the submesh. Returns 0 if axis_names is an empty sequence.
- Raises
TypeError – If axis_names is not a str or a sequence of str.
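The row-major flattening mentioned above, and why the major-to-minor order of axis_names matters, can be seen in a small sketch. `submesh_device_index`, `mesh_shape`, and `device_coords` are hypothetical stand-ins: the device's per-axis coordinates are passed in explicitly here, and how the real function obtains them is outside the scope of this illustration.

```python
def submesh_device_index(axis_names, mesh_shape, device_coords):
    """Flatten a device's per-axis coordinates into one row-major index."""
    if isinstance(axis_names, str):
        axis_names = (axis_names,)
    index = 0
    for name in axis_names:  # iterate from most major to most minor axis
        index = index * mesh_shape[name] + device_coords[name]
    return index


mesh_shape = {"dp": 2, "tp": 4}
coords = {"dp": 1, "tp": 2}
print(submesh_device_index(("dp", "tp"), mesh_shape, coords))  # 1 * 4 + 2 = 6
print(submesh_device_index(("tp", "dp"), mesh_shape, coords))  # 2 * 2 + 1 = 5
print(submesh_device_index((), mesh_shape, coords))            # 0
```

Swapping the axis order changes the result, which is the reason for the ordering requirement.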
- eformer.escale.partition.constraints.make_shard_and_gather_fns(partition_specs: dict[str, jax.sharding.PartitionSpec], mesh: jax._src.mesh.Mesh | None = None) -> tuple[dict[str, Callable], dict[str, Callable]]
Create shard and gather functions based on given partition specs and mesh.
This function generates dictionaries of shard and gather functions that can be used to distribute and collect arrays across a JAX mesh. The functions are designed for use with tu.tree_map (here tu presumably refers to jax.tree_util).
- Parameters
partition_specs – A dictionary mapping parameter names to their respective PartitionSpec.
mesh – The JAX mesh to use for sharding. If None, the current mesh is used.
- Returns
A tuple containing two dictionaries:
shard_fns: A dictionary mapping parameter names to their corresponding shard functions.
gather_fns: A dictionary mapping parameter names to their corresponding gather functions.
- Return type
tuple[dict[str, Callable], dict[str, Callable]]
- eformer.escale.partition.constraints.match_partition_rules(rules: list[tuple[str, jax.sharding.PartitionSpec]], tree: dict, min_size: int | None = 0, strict: bool = True) -> dict
Match partition rules to parameters based on their names.
This function takes a list of partition rules (regular expressions and corresponding PartitionSpec) and applies them to a dictionary of parameters based on their names. It’s useful for automatically defining sharding strategies. The order of keys in the output dictionary matches the input tree’s key order.
- Parameters
rules – A list of tuples, where each tuple contains:
- A regular expression to match parameter names.
- A PartitionSpec to apply if the name matches.
tree – A dictionary of parameters, where keys are parameter names and values are the parameters (arrays) or indices.
min_size – Minimum size for applying sharding. Parameters smaller than this will use PartitionSpec() for efficiency. Defaults to MIN_SHARDING_SIZE.
strict – If True, validates array shapes and applies min_size checks. If False, applies rules without validation.
- Returns
A dictionary with the same keys as tree, maintaining the original key order, but with values replaced by the corresponding PartitionSpec based on matching rules.
- Raises
ValueError – If no matching rule is found for a parameter.
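The general idea of regex-based rule matching can be sketched without JAX. The helper below is a hypothetical stand-in, not the library's code: it takes parameter shapes instead of arrays, models specs as tuples, and assumes that rules are tried in order with `re.search` and that the first match wins.

```python
import math
import re


def match_rules(rules, shapes, min_size=0):
    """Map each parameter name to the spec of the first rule whose regex matches."""
    result = {}
    for name, shape in shapes.items():  # dicts preserve insertion order
        if math.prod(shape) < min_size:
            result[name] = ()  # too small to shard: replicate
            continue
        for pattern, spec in rules:
            if re.search(pattern, name) is not None:
                result[name] = spec
                break
        else:
            raise ValueError(f"No partition rule matches parameter {name!r}")
    return result


rules = [
    (r"attention/.*kernel", ("fsdp", "tp")),
    (r".*", ()),  # catch-all so every parameter matches some rule
]
shapes = {"attention/q/kernel": (512, 512), "attention/q/bias": (512,), "norm/scale": (512,)}
specs = match_rules(rules, shapes, min_size=1024)
print(specs["attention/q/kernel"])  # ('fsdp', 'tp')
print(specs["attention/q/bias"])    # ()
```

Ending the rule list with a catch-all pattern is a common way to avoid the ValueError for parameters that no specific rule covers.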
Example
>>> rules = [(".*weight", PartitionSpec("model", None))]
>>> tree = {"layer/weight": 0, "layer/bias": 1}
>>> match_partition_rules(rules, tree)
{"layer/weight": PartitionSpec("model", None), "layer/bias": PartitionSpec()}
- eformer.escale.partition.constraints.names_in_current_mesh(*names: str) -> bool
Check if the given names are present in the current JAX mesh.
- Parameters
*names – Variable number of axis names to check.
- Returns
True if all given names are present in the current mesh, False otherwise.
- eformer.escale.partition.constraints.with_sharding_constraint(arr: jax.jaxlib._jax.Array | Any, sharding: jax.sharding.PartitionSpec | jax.sharding.NamedSharding) -> jax.jaxlib._jax.Array | Any
Apply sharding constraints with automatic correction based on array shape and mesh.
This function takes a JAX array (or PyTree) and a sharding specification (PartitionSpec or NamedSharding). It attempts to apply the sharding, but first checks if the specification is compatible with the array’s shape and the current mesh configuration.
If an axis specified in the PartitionSpec:
- does not exist in the mesh, or
- is incompatible with the array’s dimension size (not divisible),
then that part of the PartitionSpec is automatically corrected to None, effectively preventing sharding along that dimension.
Note: Mesh axes of size 1 are allowed if divisibility holds, enabling logical sharding even on single-device axes.
- Parameters
arr – The JAX array or PyTree to apply sharding constraints to.
sharding – The desired sharding specification (PartitionSpec or NamedSharding).
- Returns
The JAX array or PyTree with potentially corrected sharding constraints applied.