eformer.escale.helpers.base
- class eformer.escale.helpers.base.AutoShardingRule(mesh: jax._src.mesh.Mesh | None = None, axis_names: list[str] | None = None, min_shard_size: int | None = None, reverse: bool = False)
Bases: ShardingRule
Automatically determines sharding based on array shapes and mesh configuration.
This rule analyzes array shapes and assigns mesh axes to array dimensions to maximize parallelism. It prioritizes larger dimensions and only assigns a mesh axis to a dimension whose size is divisible by that axis's size.
The algorithm works as follows:
1. Skip arrays smaller than min_shard_size (they remain unsharded).
2. Sort array dimensions by size (largest first, unless reverse=True).
3. For each dimension, find the first available mesh axis whose size divides the dimension evenly.
4. Assign that axis to the dimension and remove it from the available axes.
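The four steps can be sketched in plain Python as follows. This is an illustrative sketch, not the library's implementation: the mesh is modeled as a hypothetical `mesh_axes` dict mapping axis name to axis size, and a PartitionSpec as a tuple of axis names (or None), so the example runs without JAX.

```python
from math import prod
from typing import Optional, Sequence


def auto_shard_spec(
    shape: Sequence[int],
    mesh_axes: dict[str, int],  # hypothetical stand-in for the mesh: name -> size
    min_shard_size: Optional[int] = None,
    reverse: bool = False,
) -> tuple[Optional[str], ...]:
    """Greedy mesh-axis assignment following the four steps above (sketch)."""
    spec: list[Optional[str]] = [None] * len(shape)
    # Step 1: arrays with fewer elements than min_shard_size stay unsharded.
    if min_shard_size is not None and prod(shape) < min_shard_size:
        return tuple(spec)
    # Step 2: visit dimensions largest-first (smallest-first when reverse=True).
    order = sorted(range(len(shape)), key=lambda d: shape[d], reverse=not reverse)
    available = list(mesh_axes)
    for dim in order:
        # Step 3: first still-available axis whose size divides this dimension.
        for axis in available:
            if shape[dim] % mesh_axes[axis] == 0:
                spec[dim] = axis
                # Step 4: each mesh axis is used at most once per array.
                available.remove(axis)
                break
    return tuple(spec)
```

For a (1024, 512) array on a mesh with `dp=2` and `tp=4`, the larger dimension takes `dp` (the first axis that divides it) and the smaller one takes `tp`; with `reverse=True` the assignment flips.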
- mesh
The JAX mesh to shard across.
- axis_names
List of mesh axis names to consider for sharding.
- min_shard_size
Minimum array size (in elements) to apply sharding.
- reverse
If True, processes smaller dimensions first.
Example
>>> mesh = create_mesh(axis_dims=(2, 4), axis_names=('dp', 'tp'))
>>> rule = AutoShardingRule(mesh=mesh, axis_names=['dp', 'tp'])
>>> # Apply to model parameters
>>> specs = rule.apply(model_params)
- class eformer.escale.helpers.base.CompositeShardingRule(*rules: ShardingRule)
Bases: ShardingRule
Combines multiple sharding rules with priority ordering.
This rule applies multiple sharding rules in sequence and selects the first non-empty PartitionSpec for each array. This is useful for implementing fallback strategies where a specific rule might not produce valid sharding for all arrays.
The priority order is determined by the order of rules passed to the constructor. Earlier rules have higher priority.
- rules
Tuple of ShardingRule instances to combine.
Example
>>> # Try shape-based first, fall back to auto sharding
>>> shape_rule = ShapeBasedShardingRule({(None, 1024): PartitionSpec('tp')})
>>> auto_rule = AutoShardingRule(mesh=mesh)
>>> combined = CompositeShardingRule(shape_rule, auto_rule)
>>> specs = combined.apply(model_params)
- apply(pytree: Any) -> Any
Apply combined sharding rules to a PyTree.
Applies every rule to the PyTree and, for each leaf, selects the first non-empty PartitionSpec in priority order. If all rules return empty specs for a leaf, an empty PartitionSpec is used for that leaf.
- Parameters
pytree – A PyTree of arrays to generate partition specs for.
- Returns
A PyTree with the same structure containing the highest-priority non-empty PartitionSpec for each array.
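The per-leaf selection can be illustrated with a small sketch. It assumes a flat dict of leaf name to shape instead of a general PyTree, models a PartitionSpec as a tuple, and treats a spec as empty only when it has no entries (like PartitionSpec()); `embedding_rule` and `replicate_rule` are hypothetical rules made up for the example.

```python
from typing import Callable, Optional

# A PartitionSpec modeled as a tuple; () plays the role of PartitionSpec().
Spec = tuple[Optional[str], ...]
Shapes = dict[str, tuple[int, ...]]
Rule = Callable[[Shapes], dict[str, Spec]]


def combine_first_non_empty(rules: list[Rule], shapes: Shapes) -> dict[str, Spec]:
    """Pick, per leaf, the first non-empty spec in rule-priority order."""
    outputs = [rule(shapes) for rule in rules]  # every rule is applied once
    combined: dict[str, Spec] = {}
    for name in shapes:
        # next() yields the first truthy (non-empty) spec, else the fallback ().
        combined[name] = next((out[name] for out in outputs if out[name]), ())
    return combined


def embedding_rule(shapes: Shapes) -> dict[str, Spec]:
    # Hypothetical shape rule: only 2-D arrays with a last dim of 1024 get a spec.
    return {
        n: ("tp", None) if len(s) == 2 and s[1] == 1024 else ()
        for n, s in shapes.items()
    }


def replicate_rule(shapes: Shapes) -> dict[str, Spec]:
    # Hypothetical fallback: an explicit "unsharded" entry per dimension.
    return {n: (None,) * len(s) for n, s in shapes.items()}
```

Putting the more specific rule first lets it win wherever it produces a spec, while the fallback covers the remaining leaves; a scalar leaf that no rule handles ends up with the empty spec.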
- class eformer.escale.helpers.base.MemoryConstrainedShardingRule(max_memory_per_device: int, mesh: jax._src.mesh.Mesh | None = None, axis_names: list[str] | None = None)
Bases: ShardingRule
Creates sharding based on per-device memory constraints.
This rule ensures that each device’s memory usage stays within a specified limit by sharding large arrays across the mesh. It prioritizes sharding the largest dimensions first and uses the largest mesh axes for maximum memory reduction.
The algorithm:
1. If the array fits in memory, return an empty PartitionSpec (no sharding).
2. Sort mesh axes by size (largest first, for maximum memory reduction).
3. Sort array dimensions by size (largest first).
4. Iteratively assign mesh axes to dimensions until the array fits in per-device memory.
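A minimal sketch of these steps, under stated assumptions: the mesh is a hypothetical `mesh_axes` dict of axis name to size, a PartitionSpec is a tuple, step 4 is read as "walk dimensions largest-first and give each one the largest remaining axis that divides it, stopping once the per-device share fits", and the divisibility check is an assumption rather than documented behavior.

```python
from math import prod
from typing import Optional


def memory_constrained_spec(
    shape: tuple[int, ...],
    itemsize: int,  # bytes per element, e.g. 4 for float32
    mesh_axes: dict[str, int],  # hypothetical stand-in for the mesh: name -> size
    max_memory_per_device: int,
) -> tuple[Optional[str], ...]:
    total_bytes = prod(shape) * itemsize
    # Step 1: an array that already fits stays unsharded (empty spec).
    if total_bytes <= max_memory_per_device:
        return ()
    spec: list[Optional[str]] = [None] * len(shape)
    # Step 2: largest mesh axes first, for the biggest reduction per assignment.
    axes = sorted(mesh_axes, key=lambda a: mesh_axes[a], reverse=True)
    # Step 3: largest array dimensions first.
    dims = sorted(range(len(shape)), key=lambda d: shape[d], reverse=True)
    shards = 1
    for dim in dims:
        # Step 4: stop as soon as each device's share fits the budget.
        if total_bytes // shards <= max_memory_per_device:
            break
        for axis in axes:
            if shape[dim] % mesh_axes[axis] == 0:  # assumed divisibility check
                spec[dim] = axis
                shards *= mesh_axes[axis]
                axes.remove(axis)
                break
    return tuple(spec)
```

A 4096 x 4096 float32 array is 64 MiB; with a 32 MiB budget on a `dp=2, tp=4` mesh, assigning `tp` to one dimension already brings each device down to 16 MiB, so the second dimension is left unsharded.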
- max_memory_per_device
Maximum bytes allowed per device.
- mesh
The JAX mesh to shard across.
- axis_names
List of mesh axis names to consider for sharding.
Example
>>> # Allow at most 1 GiB per device
>>> rule = MemoryConstrainedShardingRule(
...     max_memory_per_device=1024**3,
...     mesh=mesh
... )
>>> specs = rule.apply(large_model_params)
- class eformer.escale.helpers.base.ShapeBasedShardingRule(shape_patterns: dict[tuple[int | None, ...], jax.sharding.PartitionSpec])
Bases: ShardingRule
Creates sharding based on array shape patterns.
This rule allows defining specific sharding strategies for arrays matching certain shape patterns. Patterns can include wildcards (None) to match any dimension size.
This is useful when you know certain shapes should always be sharded in a particular way, such as embedding tables or attention weights.
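Wildcard matching of this kind can be sketched as below. Two details are assumptions, not documented behavior: a pattern only matches arrays of the same rank, and when several patterns match (as `(None, 1024)` and `(1024, 1024)` both do for a 1024 x 1024 array) the first one in dict order wins. PartitionSpecs are modeled as tuples and an unmatched shape gets the empty spec `()`.

```python
from typing import Optional

Pattern = tuple[Optional[int], ...]
Spec = tuple[Optional[str], ...]  # stand-in for PartitionSpec


def matches(shape: tuple[int, ...], pattern: Pattern) -> bool:
    # Same rank, and every non-None pattern entry must equal the dimension size.
    return len(shape) == len(pattern) and all(
        p is None or p == s for p, s in zip(pattern, shape)
    )


def spec_for_shape(shape: tuple[int, ...], shape_patterns: dict[Pattern, Spec]) -> Spec:
    # Assumption: first matching pattern (dict insertion order) wins.
    for pattern, spec in shape_patterns.items():
        if matches(shape, pattern):
            return spec
    return ()  # no pattern matched: leave the array unsharded
```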
- shape_patterns
Dictionary mapping shape patterns to PartitionSpecs.
Example
>>> # Shard arrays with shape (vocab_size, embed_dim) along the first axis
>>> patterns = {
...     (None, 1024): PartitionSpec('tp', None),  # Embedding tables
...     (1024, 1024): PartitionSpec('tp', None),  # Square weight matrices
... }
>>> rule = ShapeBasedShardingRule(patterns)
>>> specs = rule.apply(model_params)
- class eformer.escale.helpers.base.ShardingAnalyzer(mesh: jax._src.mesh.Mesh | None = None)
Bases: object
Analyzes and validates sharding strategies.
- mesh
The mesh configuration for sharding. If not provided, it defaults to the physical mesh from the thread resources.
- estimate_memory_usage(pytree: Any, partition_specs: Any) -> dict[str, int]
Estimate memory usage per device after applying sharding.
Calculates the total memory footprint and estimates how much memory each device will need after the sharding strategy is applied.
- Parameters
pytree – A PyTree of arrays to estimate memory for.
partition_specs – A PyTree of PartitionSpecs with the same structure as pytree.
- Returns
A dictionary containing:
"total_size": Total memory in bytes before sharding
"size_per_device": Estimated memory per device in bytes after sharding
- Return type
dict[str, int]
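A rough sketch of how such an estimate can be computed, assuming each leaf's bytes are divided evenly by the product of the mesh-axis sizes named in its spec (unsharded dimensions divide by 1). Leaves are given as hypothetical `(shape, itemsize, spec)` triples and the mesh as a name-to-size dict; this illustrates the arithmetic and is not the analyzer's actual code.

```python
from math import prod
from typing import Union

# One PartitionSpec entry: unsharded, a single mesh axis, or several axes.
SpecEntry = Union[None, str, tuple[str, ...]]


def shard_factor(spec: tuple[SpecEntry, ...], mesh_axes: dict[str, int]) -> int:
    """Number of pieces an array is split into under `spec`."""
    factor = 1
    for entry in spec:
        if entry is None:
            continue  # this dimension is not split
        names = (entry,) if isinstance(entry, str) else entry
        factor *= prod(mesh_axes[name] for name in names)
    return factor


def estimate_memory_sketch(leaves, mesh_axes: dict[str, int]) -> dict[str, int]:
    """`leaves` is an iterable of hypothetical (shape, itemsize, spec) triples."""
    total = per_device = 0
    for shape, itemsize, spec in leaves:
        nbytes = prod(shape) * itemsize
        total += nbytes
        per_device += nbytes // shard_factor(spec, mesh_axes)
    return {"total_size": total, "size_per_device": per_device}
```

For example, a 1024 x 1024 float32 weight sharded as `("tp", None)` on a mesh with `tp=4` contributes 4 MiB to `total_size` but only 1 MiB to `size_per_device`.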
Example
>>> analyzer = ShardingAnalyzer(mesh)
>>> usage = analyzer.estimate_memory_usage(params, specs)
>>> print(f"Memory per device: {usage['size_per_device'] / 1e9:.2f} GB")
- validate_partition_specs(pytree: Any, partition_specs: Any) -> list[str]
Validate compatibility of partition specs with array shapes.
Checks that each array dimension is divisible by the corresponding mesh axis size specified in the partition spec.
- Parameters
pytree – A PyTree of arrays to validate.
partition_specs – A PyTree of PartitionSpecs with the same structure as pytree.
- Returns
A list of validation issue messages. Empty list means no issues.
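The divisibility check can be sketched as follows, assuming a flat dict of leaf name to shape, specs modeled as tuples (an entry may be one axis name or a tuple of names), and a hypothetical `mesh_axes` dict of axis sizes. The rank check is an extra assumption added for illustration, not something this page documents.

```python
from math import prod
from typing import Union

SpecEntry = Union[None, str, tuple[str, ...]]


def validate_specs_sketch(
    shapes: dict[str, tuple[int, ...]],
    specs: dict[str, tuple[SpecEntry, ...]],
    mesh_axes: dict[str, int],
) -> list[str]:
    issues: list[str] = []
    for name, shape in shapes.items():
        spec = specs[name]
        # Assumed extra check: a spec cannot name more dimensions than the array has.
        if len(spec) > len(shape):
            issues.append(f"{name}: spec has {len(spec)} entries but rank is {len(shape)}")
            continue
        for dim, entry in enumerate(spec):
            if entry is None:
                continue  # unsharded dimension, nothing to check
            names = (entry,) if isinstance(entry, str) else tuple(entry)
            size = prod(mesh_axes[n] for n in names)
            if shape[dim] % size != 0:
                issues.append(
                    f"{name}: dim {dim} (size {shape[dim]}) is not divisible "
                    f"by mesh axes {names} (size {size})"
                )
    return issues
```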
Example
>>> analyzer = ShardingAnalyzer(mesh)
>>> issues = analyzer.validate_partition_specs(params, specs)
>>> if issues:
...     print("Validation failed:", issues)
- class eformer.escale.helpers.base.ShardingRule
Bases: ABC
Abstract base class for defining sharding rules.
Sharding rules define how arrays in a PyTree should be partitioned across devices in a mesh. Subclasses must implement the apply method to provide specific sharding logic.
This class follows the Strategy pattern, allowing different sharding algorithms to be swapped at runtime.
Example
>>> class CustomRule(ShardingRule):
...     def apply(self, pytree):
...         return jax.tree_util.tree_map(
...             lambda x: PartitionSpec('data'),
...             pytree
...         )
- abstract apply(pytree: Any) -> Any
Apply the sharding rule to a PyTree of arrays.
- Parameters
pytree – A PyTree (nested structure) of arrays to generate partition specifications for.
- Returns
A PyTree with the same structure as the input, where each leaf is replaced with its corresponding PartitionSpec.
- eformer.escale.helpers.base.barrier_sync(timeout: float = 200)
Synchronize all JAX processes at a barrier point.
Blocks execution until all processes in the distributed JAX runtime reach this barrier. This is essential for ensuring consistency across distributed training, especially before/after collective operations or checkpointing.
The function uses a global counter to create unique barrier names, allowing multiple barriers to be used sequentially without conflicts.
- Parameters
timeout – Maximum time to wait for all processes to reach the barrier, in seconds. Defaults to 200 seconds (3.33 minutes). If the timeout is exceeded, a RuntimeError will be raised by the underlying JAX distributed client.
- Returns
None
- Raises
RuntimeError – If the JAX distributed client is not initialized. This typically means JAX was not started in distributed mode or the distributed runtime failed to initialize.
Note
This function is a no-op when running with a single process (jax.process_count() == 1), allowing code to work seamlessly in both single and multi-process environments.
Each call increments a global counter to ensure unique barrier names, preventing conflicts when multiple barriers are used in sequence.
The timeout is converted to milliseconds for the underlying JAX API.
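The counter-based naming and the seconds-to-milliseconds conversion described in these notes can be sketched in plain Python. The distributed client's barrier call is passed in as a hypothetical `wait_at_barrier(name, timeout_ms)` callable and the process count as an argument, so this shows the control flow only, not the function's real internals.

```python
import itertools
from typing import Callable

# Global counter shared by all calls, so every barrier gets a fresh name.
_barrier_ids = itertools.count()


def barrier_sync_sketch(
    wait_at_barrier: Callable[[str, int], None],  # hypothetical client hook
    process_count: int,
    timeout: float = 200,
) -> None:
    if process_count == 1:
        return  # single process: nothing to wait for
    # Every process makes the same sequence of calls, so the n-th barrier
    # has the same name on all of them.
    name = f"barrier_sync_{next(_barrier_ids)}"
    wait_at_barrier(name, int(timeout * 1000))  # seconds -> milliseconds
```

Because the name comes from a per-process counter, the scheme only lines up if every process calls the barrier the same number of times in the same order, which is exactly the condition the warning on this function spells out.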
Example
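>>> # Synchronize before and after a checkpoint written by process 0
>>> model = train_step(model, batch)
>>> barrier_sync()
>>> if jax.process_index() == 0:
...     save_checkpoint(model)
>>> barrier_sync()
>>> # Allow more time for slow operations
>>> barrier_sync(timeout=600)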
Warning
Ensure all processes call barrier_sync() the same number of times and in the same order, or deadlocks may occur. Conditional barriers based on process rank should be avoided.
- eformer.escale.helpers.base.create_monitored_function(fn: Callable, partition_specs: Any, analyzer: ShardingAnalyzer) -> Callable
Create a monitored version of a function with sharding analysis.
Wraps a function to automatically validate sharding, measure execution time, and track memory usage. Useful for debugging and optimizing distributed training loops.
- Parameters
fn – The function to wrap with monitoring.
partition_specs – The partition specifications to validate inputs against.
analyzer – A ShardingAnalyzer instance for validation and memory estimation.
- Returns
A wrapped function that returns a tuple of (result, metrics) where metrics contains execution_time, memory_usage, and validation_issues.
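A minimal sketch of such a wrapper, not the library's code. It assumes the first positional argument is the PyTree that `partition_specs` describes, and accepts any `analyzer` object exposing the two ShardingAnalyzer methods documented above; the metric keys follow the Returns description.

```python
import functools
import time
import warnings
from typing import Any, Callable


def create_monitored_function_sketch(
    fn: Callable, partition_specs: Any, analyzer: Any
) -> Callable:
    @functools.wraps(fn)
    def monitored(*args: Any, **kwargs: Any) -> tuple[Any, dict[str, Any]]:
        tree = args[0]  # assumption: the first argument is the sharded PyTree
        issues = analyzer.validate_partition_specs(tree, partition_specs)
        if issues:
            # Warn instead of raising so the wrapped call still runs.
            warnings.warn(f"Sharding validation issues: {issues}", stacklevel=2)
        memory = analyzer.estimate_memory_usage(tree, partition_specs)
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        elapsed = time.perf_counter() - start
        metrics = {
            "execution_time": elapsed,
            "memory_usage": memory,
            "validation_issues": issues,
        }
        return result, metrics

    return monitored
```

One design caveat: JAX dispatches work asynchronously, so timing a jitted function this way may measure only dispatch. Calling `jax.block_until_ready` on the result before reading the clock gives a more meaningful execution time.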
Example
>>> analyzer = ShardingAnalyzer(mesh)
>>> monitored_train_step = create_monitored_function(
...     train_step, partition_specs, analyzer
... )
>>> result, metrics = monitored_train_step(params, batch)
>>> print(f"Execution time: {metrics['execution_time']:.2f}s")
Warning
If validation issues are detected, a warning is emitted but execution continues. This allows debugging without interrupting training.