eformer.escale.mesh.mesh_helpers
- class eformer.escale.mesh.mesh_helpers.MeshPartitionHelper(mesh: Mesh)
Bases: object
Helper class for analyzing and applying partition strategies to PyTrees.
This class provides utilities for automatically determining optimal sharding strategies based on array shapes and mesh configuration. It supports various parallelism patterns including FSDP, data parallelism, tensor parallelism, and sequence parallelism.
The helper analyzes array shapes and suggests appropriate sharding methods based on the available mesh axes and their sizes.
- mesh
The JAX mesh to use for sharding.
- axis_sizes
Dictionary mapping axis names to their sizes.
Example
>>> mesh = create_mesh(axis_dims=(2, 4, 1, 2, 1),
...                    axis_names=('dp', 'fsdp', 'ep', 'tp', 'sp'))
>>> helper = MeshPartitionHelper(mesh)
>>> # Analyze and auto-shard a model
>>> sharded_params = helper.auto_shard_pytree(model_params)
- analyze_pytree(pytree: Any) → dict[tuple[int, ...], jax.sharding.PartitionSpec]
Analyze a PyTree and suggest partitioning methods for each unique shape.
Collects all unique array shapes in the PyTree and determines appropriate sharding methods for each based on the mesh configuration.
- Parameters
pytree – A PyTree of arrays to analyze.
- Returns
A dictionary mapping array shapes to lists of suggested sharding method tuples. Each method tuple contains the axis names to use for sharding (e.g., ('fsdp', 'sp') for combined sharding).
Example
>>> helper = MeshPartitionHelper(mesh)
>>> shape_methods = helper.analyze_pytree(params)
>>> # {(1024, 4096): [('fsdp', 'sp'), ('tp',)], ...}
- auto_shard_pytree(pytree: Any, min_shard_size: int = 1024)
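As a rough illustration of the shape-collection step described above, the sketch below walks a nested container and gathers the distinct shapes of its leaves. It is not eformer's implementation: `collect_unique_shapes` and the `FakeArray` stand-in are hypothetical names. A plain recursive walk is used in place of JAX's PyTree utilities so the snippet runs without JAX installed.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeArray:
    """Hypothetical stand-in for an array; only `shape` matters here."""
    shape: tuple


def collect_unique_shapes(tree):
    """Return the set of distinct leaf shapes in a nested dict/list/tuple."""
    shapes = set()

    def visit(node):
        if hasattr(node, "shape"):
            # Leaf: record its shape as a hashable tuple.
            shapes.add(tuple(node.shape))
        elif isinstance(node, dict):
            for child in node.values():
                visit(child)
        elif isinstance(node, (list, tuple)):
            for child in node:
                visit(child)

    visit(tree)
    return shapes


params = {
    "embed": FakeArray((1024, 4096)),
    "layers": [
        {"w": FakeArray((4096, 4096)), "b": FakeArray((4096,))},
        {"w": FakeArray((4096, 4096)), "b": FakeArray((4096,))},
    ],
}
unique_shapes = collect_unique_shapes(params)
```

Because the two layers share shapes, only three distinct shapes are collected, which is why a per-shape analysis can be cheaper than a per-array one.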
Automatically shard an entire PyTree based on shape analysis.
Analyzes all arrays in the PyTree, determines optimal sharding strategies, and applies them. This is a convenience method that combines analyze_pytree, create_partition_spec, and shard_array.
- Parameters
pytree – A PyTree of arrays to shard.
min_shard_size – Minimum number of elements to consider sharding. Arrays smaller than this remain unsharded. Defaults to 1024.
- Returns
A PyTree with the same structure where each array has been sharded according to its optimal partition specification.
Example
>>> helper = MeshPartitionHelper(mesh)
>>> sharded_params = helper.auto_shard_pytree(model_params)
- create_partition_spec(array_shape: tuple[int, ...], methods: list[tuple], min_shard_size: int = 1024) → PartitionSpec
Create a PartitionSpec for an array using suggested sharding methods.
Takes a list of sharding method tuples and determines how to apply them to the array dimensions. Handles both single-axis and multi-axis (combined) sharding methods.
- Parameters
array_shape – The shape of the array to create a spec for.
methods – List of sharding method tuples from _suggest_methods.
min_shard_size – Minimum number of elements per shard to consider sharding worthwhile. Prevents over-sharding small arrays. Defaults to 1024.
- Returns
A PartitionSpec assigning mesh axes to array dimensions. Returns an empty PartitionSpec for scalar arrays or if no suitable sharding is found.
Example
>>> helper = MeshPartitionHelper(mesh)
>>> methods = helper._suggest_methods((1024, 4096))
>>> spec = helper.create_partition_spec((1024, 4096), methods)
>>> # PartitionSpec('fsdp', 'tp')
- shard_array(array, partition_spec)
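To make the role of `min_shard_size` and of multi-axis methods more concrete, here is a minimal sketch of one way such a spec could be derived. The greedy heuristic, the function name `sketch_partition_spec`, and the explicit `axis_sizes` argument are all assumptions for illustration. The library's actual algorithm may differ and may produce a different spec for the same inputs. The result is a plain tuple, which could be passed to JAX as `PartitionSpec(*spec)`.

```python
from math import prod


def sketch_partition_spec(array_shape, methods, axis_sizes, min_shard_size=1024):
    """Greedily assign each method to the first free, evenly divisible dimension.

    Returns a tuple with one entry per dimension (None, an axis name, or a
    tuple of axis names), or an empty tuple if nothing is sharded.
    This is an assumed heuristic, not eformer's actual algorithm.
    """
    if not array_shape:
        return ()  # scalars are left unsharded

    spec = [None] * len(array_shape)
    elements_per_shard = prod(array_shape)

    for method in methods:
        # A combined method such as ('fsdp', 'sp') splits by the product of its axes.
        num_shards = prod(axis_sizes[name] for name in method)
        # Skip methods that would not split anything or would over-shard.
        if num_shards <= 1 or elements_per_shard // num_shards < min_shard_size:
            continue
        for dim, size in enumerate(array_shape):
            if spec[dim] is None and size % num_shards == 0:
                spec[dim] = method[0] if len(method) == 1 else tuple(method)
                elements_per_shard //= num_shards
                break

    # Mirror the documented behavior: empty spec when no sharding applies.
    return tuple(spec) if any(s is not None for s in spec) else ()


axis_sizes = {"dp": 2, "fsdp": 4, "ep": 1, "tp": 2, "sp": 1}
methods = [("fsdp", "sp"), ("tp",)]

large = sketch_partition_spec((1024, 4096), methods, axis_sizes)
small = sketch_partition_spec((8, 16), methods, axis_sizes)
```

In this sketch the large array gets both methods applied to separate dimensions, while the small array falls below `min_shard_size` per shard and is left unsharded.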
Shard an array according to a partition specification.
Places the array on devices according to the given partition spec, creating a distributed array with NamedSharding.
- Parameters
array – The array to shard.
partition_spec – The PartitionSpec defining how to distribute the array across devices.
- Returns
A JAX array distributed across devices according to the partition specification.
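The description above suggests behavior close to placing an array with `jax.device_put` and a `NamedSharding`. The sketch below shows that pattern using only public JAX APIs. It assumes, without confirmation from the source, that `shard_array` works roughly this way. The single `'fsdp'` axis and the array shape are illustrative choices.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh over whatever devices are available; 'fsdp' is just a label.
devices = np.array(jax.devices())
mesh = Mesh(devices, ("fsdp",))

# Make the leading dimension divisible by the device count so it splits evenly.
rows = devices.size * 2
array = jnp.arange(rows * 4, dtype=jnp.float32).reshape(rows, 4)

# Shard dimension 0 across 'fsdp'; dimension 1 is replicated.
sharding = NamedSharding(mesh, PartitionSpec("fsdp", None))
sharded = jax.device_put(array, sharding)
```

The sharded result has the same global shape and values as the input. Only its placement across devices changes.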