eformer.escale.partition.auto_spec
- eformer.escale.partition.auto_spec.auto_namedsharding(mesh: jax._src.mesh.Mesh | None = None, names: list[str | tuple[str, ...]] | None = None, min_sharding_size: int | None = None, reverse: bool = False)
Returns a function that creates a NamedSharding for an array based on the provided parameters.
- Parameters
mesh – The device mesh to shard across. If None, uses the current thread’s mesh.
names – List of mesh axis names to use for sharding. If None, derives from mesh shape.
min_sharding_size – Minimum size of array to shard. If None, uses the product of the mesh's device shape (the total number of devices).
reverse – If True, reverses the sorting order of array dimensions.
- Returns
A function that takes an array as input and returns a NamedSharding object.
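auto_namedsharding follows a factory pattern: it captures the sharding parameters once and returns a callable that can be applied to many arrays. The sketch below shows that pattern in plain Python, without JAX. The name make_spec_fn, the mesh modeled as a dict of axis sizes, and the toy rule (all mesh axes placed on the largest evenly divisible dimension) are hypothetical simplifications, not eformer's implementation. With JAX, the returned tuple would be wrapped as jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*spec)).

```python
from math import prod
from typing import Callable, Optional


def make_spec_fn(
    axis_sizes: dict[str, int],
    min_sharding_size: Optional[int] = None,
    reverse: bool = False,
) -> Callable[[tuple[int, ...]], tuple]:
    """Hypothetical factory: capture the config once, return a shape -> spec callable."""
    n_devices = prod(axis_sizes.values())
    # Mirror the documented default: None means "use the total device count".
    threshold = n_devices if min_sharding_size is None else min_sharding_size
    all_axes = tuple(axis_sizes)

    def spec_fn(shape: tuple[int, ...]) -> tuple:
        spec = [None] * len(shape)
        if prod(shape) < threshold:
            return tuple(spec)  # small array: replicate on every device
        # Largest dimension first by default; smallest first when reverse=True.
        order = sorted(range(len(shape)), key=lambda d: shape[d], reverse=not reverse)
        for d in order:
            if shape[d] % n_devices == 0:
                spec[d] = all_axes  # toy rule: every mesh axis on one dimension
                break
        return tuple(spec)

    return spec_fn


# Usage: build the callable once, then apply it to many arrays (here, just shapes).
spec_fn = make_spec_fn({"dp": 2, "tp": 4})
shapes = {"w": (1024, 512), "b": (3,)}
specs = {path: spec_fn(shape) for path, shape in shapes.items()}
```

Building the callable once keeps the mesh lookup and threshold computation out of the per-array path, which matters when it is mapped over a large parameter tree.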
- eformer.escale.partition.auto_spec.auto_partition_spec(x: Union[Array, ndarray, bool, number], mesh: jax._src.mesh.Mesh | None = None, names: list[str | tuple[str, ...]] | None = None, min_sharding_size: int | None = None, reverse: bool = False) → PartitionSpec
Create an optimized PartitionSpec to shard an array across a device mesh.
- Parameters
x – The input array to be sharded.
mesh – The device mesh to shard across. If None, uses the current thread’s mesh.
names – List of mesh axis names to use for sharding. If None, derives from mesh shape.
min_sharding_size – Minimum size of array to shard. If None, uses the mesh's device count.
reverse – If True, reverses dimension sorting order for sharding assignment.
- Returns
Optimized sharding specification for the input array.
- Return type
PartitionSpec
- Raises
ValueError – If mesh is unavailable or invalid names are provided.
TypeError – If input types are incorrect.
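One way to picture an "optimized" spec is a greedy assignment: skip arrays below the size threshold, then give each mesh axis (largest first) to the largest still-free array dimension it divides evenly. The sketch below is a hypothetical, framework-free heuristic (toy_partition_spec, the dict-of-axis-sizes mesh, and the largest-first default ordering are all assumptions), not the eformer algorithm; it also accepts only plain string axis names, whereas the real names parameter allows tuples.

```python
from math import prod
from typing import Optional


def toy_partition_spec(
    shape: tuple[int, ...],
    axis_sizes: dict[str, int],
    names: Optional[list[str]] = None,
    min_sharding_size: Optional[int] = None,
    reverse: bool = False,
) -> tuple:
    """Hypothetical greedy heuristic: at most one mesh axis per array dimension."""
    names = list(axis_sizes) if names is None else names
    unknown = [n for n in names if n not in axis_sizes]
    if unknown:
        raise ValueError(f"unknown mesh axis names: {unknown}")
    threshold = prod(axis_sizes.values()) if min_sharding_size is None else min_sharding_size
    spec: list = [None] * len(shape)
    if prod(shape) < threshold:
        return tuple(spec)  # too small: replicate
    # Assumed default: largest dimensions first; smallest first when reverse=True.
    free_dims = sorted(range(len(shape)), key=lambda d: shape[d], reverse=not reverse)
    # Place bigger mesh axes first so they tend to land on bigger dimensions.
    for name in sorted(names, key=lambda n: axis_sizes[n], reverse=True):
        for d in free_dims:
            if shape[d] % axis_sizes[name] == 0:
                spec[d] = name
                free_dims.remove(d)  # safe: we break right after removing
                break
    return tuple(spec)


spec = toy_partition_spec((1024, 512), {"dp": 2, "tp": 4})
```

In JAX, the resulting tuple maps directly onto jax.sharding.PartitionSpec(*spec).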
- eformer.escale.partition.auto_spec.auto_shard_array(x: Union[Array, ndarray, bool, number], mesh: jax._src.mesh.Mesh | None = None, names: list[str | tuple[str, ...]] | None = None, min_sharding_size: int | None = None, reverse: bool = False)
Shards an array across a device mesh according to an automatically derived PartitionSpec.
This function acts as a wrapper around pjit(x, in_axis_resources=…).
- Parameters
x – The input array to be sharded.
mesh – The device mesh to shard across. If None, uses the current thread’s mesh.
names – List of mesh axis names to use for sharding. If None, derives from mesh shape.
min_sharding_size – Minimum size of array to shard. If None, uses the product of the mesh's device shape (the total number of devices).
reverse – If True, reverses the sorting order of array dimensions.
- Returns
The sharded array.
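After sharding, each device holds a block whose extent along a dimension is that dimension divided by the product of the mesh axes mapped to it. The helper below (local_shard_shape, a hypothetical name) computes that per-device shape in plain Python, which is useful for predicting what ends up on each device. In JAX, the equivalent query on a real sharding object is Sharding.shard_shape(global_shape).

```python
from math import prod
from typing import Optional, Union

Entry = Optional[Union[str, tuple[str, ...]]]


def local_shard_shape(
    shape: tuple[int, ...],
    spec: tuple[Entry, ...],
    axis_sizes: dict[str, int],
) -> tuple[int, ...]:
    """Shape of the block each device holds under `spec` (hypothetical helper)."""
    if len(spec) > len(shape):
        raise ValueError("spec has more entries than the array has dimensions")
    # Trailing dimensions without an entry are replicated, as with PartitionSpec.
    padded = tuple(spec) + (None,) * (len(shape) - len(spec))
    out = []
    for dim, entry in zip(shape, padded):
        if entry is None:
            out.append(dim)  # replicated along this dimension
            continue
        axes = (entry,) if isinstance(entry, str) else entry
        factor = prod(axis_sizes[a] for a in axes)
        if dim % factor != 0:
            raise ValueError(f"dimension {dim} not divisible by {factor}")
        out.append(dim // factor)
    return tuple(out)
```

For example, a (1024, 512) array with spec ("tp", "dp") on a mesh with dp=2 and tp=4 leaves a (256, 256) block on each of the 8 devices.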
- eformer.escale.partition.auto_spec.convert_sharding_strategy(array: Union[Array, ndarray, bool, number], old_partition_specs: dict[str, jax.sharding.PartitionSpec], old_mesh: Mesh, new_mesh: Mesh, strategy: str = 'preserve_balance') → dict[str, jax.sharding.PartitionSpec]
Convert sharding strategy between different mesh configurations.
When migrating models between different mesh topologies (e.g., from 8 to 16 devices), this function adapts partition specifications to maintain similar parallelism characteristics.
- Parameters
array – Reference array used to determine valid new partition specs.
old_partition_specs – Dictionary of current partition specifications.
old_mesh – The original mesh configuration.
new_mesh – The target mesh configuration to convert to.
strategy – Conversion strategy. Currently supports:
  - “preserve_balance”: maintains a similar parallelization factor by using the old spec’s total split factor as the minimum sharding size for the new spec.
- Returns
A dictionary of new partition specifications adapted to new_mesh.
Example
>>> # Convert from an 8-device to a 16-device mesh
>>> old_mesh = create_mesh((2, 4), ('dp', 'tp'))
>>> new_mesh = create_mesh((4, 4), ('dp', 'tp'))
>>> new_specs = convert_sharding_strategy(
...     array, old_specs, old_mesh, new_mesh
... )
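The "preserve_balance" strategy is described in terms of the old spec's total split factor. A minimal sketch of that quantity, assuming it is the product of the sizes of every mesh axis named in the spec (split_factor is a hypothetical helper, and meshes are modeled as dicts of axis sizes):

```python
from math import prod
from typing import Optional, Union

Entry = Optional[Union[str, tuple[str, ...]]]


def split_factor(spec: tuple[Entry, ...], axis_sizes: dict[str, int]) -> int:
    """Total number of pieces an array is cut into under `spec` (hypothetical helper)."""
    factor = 1
    for entry in spec:
        if entry is None:
            continue  # replicated dimension contributes no split
        axes = (entry,) if isinstance(entry, str) else entry
        factor *= prod(axis_sizes[a] for a in axes)
    return factor


old_axes = {"dp": 2, "tp": 4}  # the 8-device mesh from the example above
new_axes = {"dp": 4, "tp": 4}  # the 16-device target mesh
old_spec = ("dp", "tp")
min_size = split_factor(old_spec, old_axes)
# Assumed next step, per the docstring: derive the new spec with
# auto_partition_spec(array, new_mesh, min_sharding_size=min_size).
```

Here min_size is 8, so the new spec is chosen with the same minimum sharding size the old layout implied rather than the new mesh's 16 devices.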
- eformer.escale.partition.auto_spec.optimize_sharding_for_memory(pytree: Any, mesh: jax._src.mesh.Mesh | None = None, max_memory_per_device: int | None = None, names: list[str] | None = None) → dict[str, jax.sharding.PartitionSpec]
Optimize sharding strategy to fit within per-device memory constraints.
Generates partition specifications that ensure each array’s per-device memory footprint stays within the specified limit. Arrays smaller than the limit remain unsharded for efficiency.
- Parameters
pytree – A PyTree of arrays to generate partition specs for.
mesh – The JAX mesh to shard across. If None, uses the current context’s mesh.
max_memory_per_device – Maximum bytes per device. Arrays larger than this are sharded to fit. If None, no memory constraint is applied and the behavior falls back to auto_partition_spec.
names – List of mesh axis names to consider for sharding. If None, uses all axis names from the mesh.
- Returns
A dictionary mapping paths to PartitionSpecs optimized for the memory constraint.
Example
>>> # Optimize for 8 GiB per device
>>> specs = optimize_sharding_for_memory(
...     params,
...     mesh=mesh,
...     max_memory_per_device=8 * 1024**3
... )
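The memory constraint reduces to simple arithmetic: an array of nbytes needs at least ceil(nbytes / max_memory_per_device) equal pieces, and the mesh axes used must multiply to at least that count. The sketch below is hypothetical (shards_needed and pick_axes are illustrative names, and the smallest-axis-first greedy choice is an assumption); it ignores whether the chosen dimension divides evenly, which a real spec would also have to satisfy.

```python
from typing import Optional


def shards_needed(nbytes: int, max_memory_per_device: Optional[int]) -> int:
    """Smallest number of equal pieces that keeps one piece within the budget."""
    if max_memory_per_device is None or nbytes <= max_memory_per_device:
        return 1  # fits already: leave unsharded
    return -(-nbytes // max_memory_per_device)  # integer ceiling division


def pick_axes(axis_sizes: dict[str, int], needed: int) -> Optional[tuple[str, ...]]:
    """Greedily combine mesh axes (smallest first) until their product reaches `needed`."""
    if needed <= 1:
        return ()
    chosen: list[str] = []
    product = 1
    for name in sorted(axis_sizes, key=axis_sizes.get):
        chosen.append(name)
        product *= axis_sizes[name]
        if product >= needed:
            return tuple(chosen)
    return None  # even the full mesh cannot meet the budget


GiB = 1024**3
needed = shards_needed(16 * GiB, 8 * GiB)  # a 16 GiB array under an 8 GiB limit
axes = pick_axes({"dp": 2, "tp": 4}, needed)
```

Preferring the smallest sufficient split keeps arrays as whole as the budget allows, which limits cross-device communication.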
- eformer.escale.partition.auto_spec.validate_sharding_config(pytree: Any, partition_specs: dict[str, jax.sharding.PartitionSpec], mesh: jax._src.mesh.Mesh | None = None) → list[str]
Validate sharding configuration and return any issues found.
Checks that partition specifications are compatible with array shapes and mesh configuration. Identifies potential problems like:
  - Array dimensions not divisible by mesh axis sizes
  - Small arrays that might not benefit from sharding
- Parameters
pytree – A PyTree of arrays to validate.
partition_specs – Dictionary mapping paths to PartitionSpecs.
mesh – The JAX mesh to validate against. If None, uses the current context’s mesh.
- Returns
A list of issue descriptions. Empty list means no issues found.
Example
>>> issues = validate_sharding_config(params, specs, mesh)
>>> if issues:
...     for issue in issues:
...         print(f"Warning: {issue}")
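The two checks listed above can be expressed over plain shapes and specs. This is a hypothetical sketch, not the library's validator: find_issues, the flat dict of paths to shapes, and the small_threshold cutoff of 1024 elements are illustrative assumptions.

```python
from math import prod
from typing import Optional, Union

Entry = Optional[Union[str, tuple[str, ...]]]


def find_issues(
    shapes: dict[str, tuple[int, ...]],
    specs: dict[str, tuple[Entry, ...]],
    axis_sizes: dict[str, int],
    small_threshold: int = 1024,
) -> list[str]:
    """Return human-readable issues; an empty list means none were found."""
    issues: list[str] = []
    for path, shape in shapes.items():
        spec = specs.get(path)
        if spec is None:
            continue  # no spec given: nothing to check
        if len(spec) > len(shape):
            issues.append(f"{path}: spec has {len(spec)} entries for a {len(shape)}-D array")
            continue
        sharded = False
        for i, entry in enumerate(spec):
            if entry is None:
                continue
            sharded = True
            axes = (entry,) if isinstance(entry, str) else entry
            missing = [a for a in axes if a not in axis_sizes]
            if missing:
                issues.append(f"{path}: unknown mesh axes {missing}")
                continue
            factor = prod(axis_sizes[a] for a in axes)
            if shape[i] % factor:
                issues.append(f"{path}: dim {i} (size {shape[i]}) not divisible by {factor}")
        if sharded and prod(shape) < small_threshold:
            issues.append(f"{path}: only {prod(shape)} elements; sharding may not pay off")
    return issues


issues = find_issues(
    shapes={"w": (1024, 512), "b": (6,), "s": (4,)},
    specs={"w": ("tp", "dp"), "b": ("tp",), "s": (None,)},
    axis_sizes={"dp": 2, "tp": 4},
)
```

Here "w" passes, "s" is replicated and therefore not flagged, and "b" is reported twice: its size 6 does not divide by tp=4, and it is too small to benefit from sharding.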
- eformer.escale.partition.auto_spec.vrn_auto_partition_spec(x: Union[Array, ndarray, bool, number], mesh: jax._src.mesh.Mesh | None = None, names: list[str | tuple[str, ...]] | None = None, min_sharding_size: int | None = None, reverse: bool = False) → PartitionSpec
Create an optimized PartitionSpec to shard an array across a device mesh.
- Parameters
x – The input array to be sharded.
mesh – The device mesh to shard across. If None, uses the current thread’s mesh.
names – List of mesh axis names to use for sharding. If None, derives from mesh shape.
min_sharding_size – Minimum size of array to shard. If None, uses the product of the mesh's device shape (the total number of devices).
reverse – If True, reverses the sorting order of array dimensions.
- Returns
A PartitionSpec describing optimal array sharding.
- Raises
ValueError – If mesh is unavailable or invalid names are provided.
TypeError – If input types are incorrect.
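The reverse flag changes which array dimension a mesh axis lands on. A small sketch of the effect, assuming the default visits larger dimensions first (the docstrings only say reverse flips the sorting order); dim_order and first_divisible are hypothetical helpers, not part of eformer.

```python
from typing import Optional


def dim_order(shape: tuple[int, ...], reverse: bool = False) -> list[int]:
    """Order in which dimensions are offered to a mesh axis (assumed largest-first)."""
    # sorted() is stable even with reverse=True, so ties keep their index order.
    return sorted(range(len(shape)), key=lambda d: shape[d], reverse=not reverse)


def first_divisible(shape: tuple[int, ...], axis_size: int, reverse: bool = False) -> Optional[int]:
    """First dimension, in visiting order, that a mesh axis of `axis_size` divides evenly."""
    for d in dim_order(shape, reverse):
        if shape[d] % axis_size == 0:
            return d
    return None  # no compatible dimension: the axis stays unused


shape = (6, 12, 8)
default_dim = first_divisible(shape, 4)                 # visits 12, 8, 6
reversed_dim = first_divisible(shape, 4, reverse=True)  # visits 6, 8, 12
```

With a mesh axis of size 4, the default places it on dimension 1 (size 12), while reverse=True places it on dimension 2 (size 8), because dimension 0 (size 6) is visited first but is not divisible.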