eformer.escale.partition.auto_spec

eformer.escale.partition.auto_spec.auto_namedsharding(mesh: jax._src.mesh.Mesh | None = None, names: list[str | tuple[str, ...]] | None = None, min_sharding_size: int | None = None, reverse: bool = False)

Returns a function that creates a NamedSharding for an array based on the provided parameters.

Parameters
  • mesh – The device mesh to shard across. If None, uses the current thread’s mesh.

  • names – List of mesh axis names to use for sharding. If None, derives from mesh shape.

  • min_sharding_size – Minimum size of array to shard. If None, uses the product of the mesh’s device shape (the total device count).

  • reverse – If True, reverses the sorting order of array dimensions.

Returns

A function that takes an array as input and returns a NamedSharding object.
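The factory pattern here is "configure once, then apply to many arrays". The stdlib-only sketch below illustrates that pattern. It is not eformer’s implementation. The mesh is modeled as a plain dict of axis sizes, and `FakeNamedSharding` is a hypothetical stand-in for `jax.sharding.NamedSharding`. The placement rule (split the leading dimension jointly over all mesh axes when it divides evenly) is an assumption chosen for brevity.

```python
import math
from typing import Callable, NamedTuple, Optional, Sequence


class FakeNamedSharding(NamedTuple):
    """Hypothetical stand-in for jax.sharding.NamedSharding(mesh, spec)."""
    mesh: dict
    spec: tuple


def make_sharding_fn(
    mesh: dict,
    min_sharding_size: Optional[int] = None,
) -> Callable[[Sequence[int]], FakeNamedSharding]:
    # Bind mesh-level settings once; the default threshold is the total device count.
    device_count = math.prod(mesh.values())
    threshold = device_count if min_sharding_size is None else min_sharding_size
    all_axes = tuple(mesh)  # every axis name, used jointly on one dimension

    def sharding_for(shape: Sequence[int]) -> FakeNamedSharding:
        spec = [None] * len(shape)
        # Assumed rule: split dim 0 over all axes when the array is large
        # enough and dim 0 divides evenly; otherwise replicate.
        if shape and math.prod(shape) >= threshold and shape[0] % device_count == 0:
            spec[0] = all_axes
        return FakeNamedSharding(mesh, tuple(spec))

    return sharding_for


# Apply the same closure to every leaf of a flat parameter dict.
to_sharding = make_sharding_fn({"dp": 2, "tp": 4})
shardings = {path: to_sharding(shape) for path, shape in {"w": (16, 3), "b": (3,)}.items()}
```

In real JAX code, a per-array function like this is typically mapped over a parameter PyTree, for example with `jax.tree_util.tree_map`.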

eformer.escale.partition.auto_spec.auto_partition_spec(x: Union[Array, ndarray, bool, number], mesh: jax._src.mesh.Mesh | None = None, names: list[str | tuple[str, ...]] | None = None, min_sharding_size: int | None = None, reverse: bool = False) → PartitionSpec

Create an optimized PartitionSpec to shard an array across a device mesh.

Parameters
  • x – The input array to be sharded.

  • mesh – The device mesh to shard across. If None, uses the current thread’s mesh.

  • names – List of mesh axis names to use for sharding. If None, derives from mesh shape.

  • min_sharding_size – Minimum size of array to shard. If None, uses mesh device count.

  • reverse – If True, reverses the dimension sorting order used for sharding assignment.

Returns

Optimized sharding specification for the input array.

Return type

PartitionSpec

Raises
  • ValueError – If mesh is unavailable or invalid names are provided.

  • TypeError – If input types are incorrect.
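The parameters above suggest a greedy assignment, and the stdlib-only sketch below models that idea in simplified form. Arrays below `min_sharding_size` stay replicated. Dimensions are visited largest-first, or smallest-first with `reverse=True`. Each mesh axis goes to the first visited dimension it divides evenly. This is an assumption about the general approach, not eformer’s actual algorithm. The function name `auto_spec_sketch`, the dict-based mesh (its key order standing in for `names`), and the tuple-based spec are all hypothetical.

```python
import math
from typing import Optional, Sequence


def auto_spec_sketch(
    shape: Sequence[int],
    mesh: dict[str, int],
    min_sharding_size: Optional[int] = None,
    reverse: bool = False,
) -> tuple:
    """Hypothetical greedy model: one mesh axis name (or None) per dimension."""
    if min_sharding_size is None:
        min_sharding_size = math.prod(mesh.values())  # total device count
    if math.prod(shape) < min_sharding_size:
        return (None,) * len(shape)  # small arrays stay replicated

    # Largest dimension first by default; reverse=True visits smallest first.
    order = sorted(range(len(shape)), key=lambda i: shape[i], reverse=not reverse)
    spec = [None] * len(shape)
    free_axes = dict(mesh)  # mesh axes not yet assigned to any dimension
    for i in order:
        for name, size in list(free_axes.items()):
            if shape[i] % size == 0:  # only use axes that split the dimension evenly
                spec[i] = name
                del free_axes[name]
                break
    return tuple(spec)
```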

eformer.escale.partition.auto_spec.auto_shard_array(x: Union[Array, ndarray, bool, number], mesh: jax._src.mesh.Mesh | None = None, names: list[str | tuple[str, ...]] | None = None, min_sharding_size: int | None = None, reverse: bool = False)

Shards an array across a device mesh according to an automatically derived PartitionSpec.

This function acts as a wrapper around pjit(x, in_axis_resources=…).

Parameters
  • x – The input array to be sharded.

  • mesh – The device mesh to shard across. If None, uses the current thread’s mesh.

  • names – List of mesh axis names to use for sharding. If None, derives from mesh shape.

  • min_sharding_size – Minimum size of array to shard. If None, uses the product of the mesh’s device shape (the total device count).

  • reverse – If True, reverses the sorting order of array dimensions.

Returns

The sharded array.
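Once a spec is applied, each device holds one block of the array. The hypothetical helper below is stdlib-only and not part of eformer. It computes that per-device block shape from a shape, a spec, and a mesh modeled as a dict of axis sizes. Spec entries may be a single axis name or a tuple of names, mirroring the `list[str | tuple[str, ...]]` form of `names`.

```python
import math
from typing import Sequence, Union

AxisSpec = Union[None, str, tuple]  # None, one axis name, or a tuple of axis names


def local_shard_shape(
    shape: Sequence[int], spec: Sequence[AxisSpec], mesh: dict[str, int]
) -> tuple:
    """Per-device block shape for an array laid out by `spec` on `mesh`."""
    out = []
    for dim, axis in zip(shape, spec):
        if axis is None:
            out.append(dim)  # replicated along this dimension
            continue
        names = (axis,) if isinstance(axis, str) else tuple(axis)
        parts = math.prod(mesh[n] for n in names)  # devices this dimension is split over
        if dim % parts:
            raise ValueError(f"dimension {dim} is not divisible by {parts}")
        out.append(dim // parts)
    # Dimensions not covered by the spec are replicated, as with PartitionSpec.
    out.extend(shape[len(spec):])
    return tuple(out)
```

In plain JAX, an array is usually placed according to a given spec with `jax.device_put(x, NamedSharding(mesh, spec))`. For evenly divisible dimensions, the resulting shards have the shape this helper computes.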

eformer.escale.partition.auto_spec.convert_sharding_strategy(array: Union[Array, ndarray, bool, number], old_partition_specs: dict[str, jax.sharding.PartitionSpec], old_mesh: Mesh, new_mesh: Mesh, strategy: str = 'preserve_balance') → dict[str, jax.sharding.PartitionSpec]

Convert sharding strategy between different mesh configurations.

When migrating models between different mesh topologies (e.g., from 8 to 16 devices), this function adapts partition specifications to maintain similar parallelism characteristics.

Parameters
  • array – Reference array used to determine valid new partition specs.

  • old_partition_specs – Dictionary of current partition specifications.

  • old_mesh – The original mesh configuration.

  • new_mesh – The target mesh configuration to convert to.

  • strategy – Conversion strategy. Currently supports:

      ◦ “preserve_balance”: maintains a similar parallelization factor by using the old spec’s total split factor as the minimum sharding size for the new spec.

Returns

A dictionary of new partition specifications adapted to new_mesh.

Example

>>> # Convert from 8-device to 16-device mesh
>>> old_mesh = create_mesh((2, 4), ('dp', 'tp'))
>>> new_mesh = create_mesh((4, 4), ('dp', 'tp'))
>>> new_specs = convert_sharding_strategy(
...     array, old_specs, old_mesh, new_mesh
... )
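The “preserve_balance” strategy depends on the old spec’s total split factor, meaning the product of the sizes of the mesh axes that the spec uses. The stdlib-only sketch below computes that quantity for the 8-device to 16-device migration in the example above. `split_factor` and the dict/tuple models of meshes and specs are hypothetical. The step that derives each new spec from these thresholds is omitted.

```python
import math
from typing import Sequence, Union

AxisSpec = Union[None, str, tuple]


def split_factor(spec: Sequence[AxisSpec], mesh: dict[str, int]) -> int:
    """Number of shards implied by `spec` on `mesh`: product of the used axis sizes."""
    factor = 1
    for axis in spec:
        if axis is None:
            continue  # a replicated dimension contributes no split
        names = (axis,) if isinstance(axis, str) else tuple(axis)
        factor *= math.prod(mesh[n] for n in names)
    return factor


old_mesh = {"dp": 2, "tp": 4}  # 8 devices
new_mesh = {"dp": 4, "tp": 4}  # 16 devices; would be used when deriving the new specs
old_specs = {"w": ("dp", "tp"), "b": ("tp",), "scale": ()}

# "preserve_balance" (as documented): reuse each old split factor as the
# min_sharding_size when deriving that array's spec on new_mesh.
min_sizes = {path: split_factor(spec, old_mesh) for path, spec in old_specs.items()}
```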

eformer.escale.partition.auto_spec.optimize_sharding_for_memory(pytree: Any, mesh: jax._src.mesh.Mesh | None = None, max_memory_per_device: int | None = None, names: list[str] | None = None) → dict[str, jax.sharding.PartitionSpec]

Optimize sharding strategy to fit within per-device memory constraints.

Generates partition specifications that ensure each array’s per-device memory footprint stays within the specified limit. Arrays smaller than the limit remain unsharded for efficiency.

Parameters
  • pytree – A PyTree of arrays to generate partition specs for.

  • mesh – The JAX mesh to shard across. If None, uses the current context’s mesh.

  • max_memory_per_device – Maximum bytes per device. Arrays larger than this will be sharded to fit. If None, no memory constraint is applied (defaults to auto_partition_spec behavior).

  • names – List of mesh axis names to consider for sharding. If None, uses all axis names from the mesh.

Returns

A dictionary mapping paths to PartitionSpecs optimized for the memory constraint.

Example

>>> # Optimize for 8 GiB per device
>>> specs = optimize_sharding_for_memory(
...     params,
...     mesh=mesh,
...     max_memory_per_device=8 * 1024**3
... )
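The underlying budget arithmetic is simple: an array needs at least ceil(bytes / max_memory_per_device) shards to fit. The stdlib-only sketch below uses the hypothetical name `axes_to_fit`. It picks the set of mesh axes with the smallest product that meets that count. It returns an empty tuple when the array already fits, leaving it unsharded, and None when even the full mesh is not enough. The sketch ignores whether individual dimensions divide evenly, so it illustrates the budget math rather than eformer’s selection logic.

```python
import itertools
import math
from typing import Optional, Sequence


def axes_to_fit(
    shape: Sequence[int],
    itemsize: int,
    mesh: dict[str, int],
    max_memory_per_device: int,
) -> Optional[tuple]:
    total_bytes = math.prod(shape) * itemsize
    # Minimum shard count so that total_bytes / shards <= budget (integer ceil).
    required = max(1, -(-total_bytes // max_memory_per_device))
    best = None  # (split factor, axis names)
    for r in range(len(mesh) + 1):
        for combo in itertools.combinations(mesh, r):
            factor = math.prod(mesh[name] for name in combo)
            if factor >= required and (best is None or factor < best[0]):
                best = (factor, combo)
    return None if best is None else best[1]


MiB = 1024**2
# A 4096 x 4096 float32 matrix is 64 MiB in total:
# axes_to_fit((4096, 4096), 4, {"dp": 2, "tp": 4}, 16 * MiB) -> ("tp",)
```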

eformer.escale.partition.auto_spec.validate_sharding_config(pytree: Any, partition_specs: dict[str, jax.sharding.PartitionSpec], mesh: jax._src.mesh.Mesh | None = None) → list[str]

Validate sharding configuration and return any issues found.

Checks that partition specifications are compatible with array shapes and mesh configuration. Identifies potential problems such as:

  • Array dimensions not divisible by mesh axis sizes

  • Small arrays that might not benefit from sharding

Parameters
  • pytree – A PyTree of arrays to validate.

  • partition_specs – Dictionary mapping paths to PartitionSpecs.

  • mesh – The JAX mesh to validate against. If None, uses the current context’s mesh.

Returns

A list of issue descriptions. Empty list means no issues found.

Example

>>> issues = validate_sharding_config(params, specs, mesh)
>>> if issues:
...     for issue in issues:
...         print(f"Warning: {issue}")
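The stdlib-only sketch below models the two checks described for this function and adds an unknown-axis check. `validate_sketch` is hypothetical, and meshes are modeled as dicts of axis sizes. The “small array” threshold (fewer elements than devices) is an assumption. eformer’s actual thresholds and messages may differ.

```python
import math
from typing import Sequence, Union

AxisSpec = Union[None, str, tuple]


def validate_sketch(
    shapes: dict[str, Sequence[int]],
    specs: dict[str, Sequence[AxisSpec]],
    mesh: dict[str, int],
) -> list[str]:
    issues = []
    device_count = math.prod(mesh.values())
    for path, shape in shapes.items():
        spec = specs.get(path, ())
        for dim, (size, axis) in enumerate(zip(shape, spec)):
            if axis is None:
                continue
            names = (axis,) if isinstance(axis, str) else tuple(axis)
            unknown = [n for n in names if n not in mesh]
            if unknown:
                issues.append(f"{path}: unknown mesh axes {unknown}")
                continue
            parts = math.prod(mesh[n] for n in names)
            if size % parts:  # dimension cannot be split evenly
                issues.append(f"{path}: dim {dim} (size {size}) is not divisible by {parts}")
        numel = math.prod(shape)
        # Assumed heuristic: a sharded array with fewer elements than devices gains little.
        if any(axis is not None for axis in spec) and numel < device_count:
            issues.append(f"{path}: only {numel} elements; sharding may not help")
    return issues
```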

eformer.escale.partition.auto_spec.vrn_auto_partition_spec(x: Union[Array, ndarray, bool, number], mesh: jax._src.mesh.Mesh | None = None, names: list[str | tuple[str, ...]] | None = None, min_sharding_size: int | None = None, reverse: bool = False) → PartitionSpec

Create an optimized PartitionSpec to shard an array across a device mesh.

Parameters
  • x – The input array to be sharded.

  • mesh – The device mesh to shard across. If None, uses the current thread’s mesh.

  • names – List of mesh axis names to use for sharding. If None, derives from mesh shape.

  • min_sharding_size – Minimum size of array to shard. If None, uses the product of the mesh’s device shape (the total device count).

  • reverse – If True, reverses the sorting order of array dimensions.

Returns

A PartitionSpec describing an optimized sharding for the input array.

Raises
  • ValueError – If mesh is unavailable or invalid names are provided.

  • TypeError – If input types are incorrect.