eformer.escale.helpers.base

class eformer.escale.helpers.base.AutoShardingRule(mesh: jax._src.mesh.Mesh | None = None, axis_names: list[str] | None = None, min_shard_size: int | None = None, reverse: bool = False)

Bases: ShardingRule

Automatically determines sharding based on array shapes and mesh configuration.

This rule analyzes array shapes and greedily assigns mesh axes to array dimensions to increase parallelism. It prioritizes larger dimensions and only assigns a mesh axis to a dimension whose size is divisible by that axis's size.

The algorithm works as follows:

  1. Skip arrays smaller than min_shard_size (they remain unsharded).

  2. Sort array dimensions by size (largest first, unless reverse=True).

  3. For each dimension, find the first available mesh axis whose size divides it evenly.

  4. Assign that axis to the dimension and remove it from the available axes.
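The four steps can be sketched for a single array shape as follows. This is a minimal illustration under stated assumptions, not eformer's implementation: to keep it runnable without JAX, the mesh is modeled as a plain dict from axis name to axis size, and the result is a tuple of axis names (or None) that a real rule would wrap in jax.sharding.PartitionSpec(*spec). The function name auto_shard_spec is hypothetical.

```python
import math
from typing import Optional


def auto_shard_spec(
    shape: tuple[int, ...],
    mesh_axes: dict[str, int],
    min_shard_size: Optional[int] = None,
    reverse: bool = False,
) -> tuple[Optional[str], ...]:
    """Hypothetical greedy sketch: at most one mesh axis per array dimension."""
    spec: list[Optional[str]] = [None] * len(shape)

    # Step 1: arrays with fewer than min_shard_size elements stay unsharded.
    if min_shard_size is not None and math.prod(shape) < min_shard_size:
        return tuple(spec)

    # Step 2: visit dimension indices by size, largest first (smallest first if reverse=True).
    order = sorted(range(len(shape)), key=lambda i: shape[i], reverse=not reverse)

    available = list(mesh_axes)  # candidate axis names, in the order given
    for dim in order:
        # Step 3: take the first still-available axis whose size divides this dimension.
        for axis in available:
            if shape[dim] % mesh_axes[axis] == 0:
                # Step 4: assign it and remove it from the pool.
                spec[dim] = axis
                available.remove(axis)
                break
    return tuple(spec)


# A (1024, 4096) array on a dp=2, tp=4 mesh: the 4096 dimension is visited first.
print(auto_shard_spec((1024, 4096), {"dp": 2, "tp": 4}))  # ('tp', 'dp')
```

Because the largest dimension is visited first, it receives the first axis in the list that divides it; passing reverse=True flips the visiting order and therefore the assignment.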

mesh

The JAX mesh to shard across.

axis_names

List of mesh axis names to consider for sharding.

min_shard_size

Minimum array size (in elements) to apply sharding.

reverse

If True, processes smaller dimensions first.

Example

>>> mesh = create_mesh(axis_dims=(2, 4), axis_names=('dp', 'tp'))
>>> rule = AutoShardingRule(mesh=mesh, axis_names=['dp', 'tp'])
>>> # Apply to model parameters
>>> specs = rule.apply(model_params)

apply(pytree: Any) -> Any

Apply auto-sharding to all arrays in a PyTree.

Parameters

pytree – A PyTree of arrays to generate partition specs for.

Returns

A PyTree with the same structure containing PartitionSpecs optimized for each array’s shape.

class eformer.escale.helpers.base.CompositeShardingRule(*rules: ShardingRule)

Bases: ShardingRule

Combines multiple sharding rules with priority ordering.

This rule applies multiple sharding rules in sequence and selects the first non-empty PartitionSpec for each array. This is useful for implementing fallback strategies where a specific rule might not produce valid sharding for all arrays.

The priority order is determined by the order of rules passed to the constructor. Earlier rules have higher priority.
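The first-non-empty selection can be sketched per array shape. This is an illustration, not the library code: rules are modeled as hypothetical plain functions from a shape to a tuple of axis names (a stand-in for PartitionSpec), and "empty" is assumed to mean a zero-length spec. The names compose_rules, shape_rule, and fallback are hypothetical.

```python
from typing import Callable, Optional

# Stand-in for jax.sharding.PartitionSpec: a tuple of axis names / None.
Spec = tuple[Optional[str], ...]
Rule = Callable[[tuple[int, ...]], Spec]


def compose_rules(*rules: Rule) -> Rule:
    """Hypothetical sketch: earlier rules have higher priority."""

    def combined(shape: tuple[int, ...]) -> Spec:
        # Evaluate every rule, then keep the first non-empty result.
        candidates = [rule(shape) for rule in rules]
        for spec in candidates:
            if len(spec) > 0:
                return spec
        return ()  # every rule came back empty

    return combined


# Shape rule for (*, 1024) matrices, with a data-parallel fallback for everything else.
shape_rule: Rule = lambda s: ("tp", None) if len(s) == 2 and s[1] == 1024 else ()
fallback: Rule = lambda s: ("dp",) + (None,) * (len(s) - 1) if s else ()
combined = compose_rules(shape_rule, fallback)
print(combined((512, 1024)), combined((512, 256)))  # ('tp', None) ('dp', None)
```

Swapping the argument order changes the outcome for matching shapes, since the fallback would then always win.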

rules

Tuple of ShardingRule instances to combine.

Example

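>>> from jax.sharding import PartitionSpec
>>> from eformer.escale.helpers.base import (
...     AutoShardingRule, CompositeShardingRule, ShapeBasedShardingRule
... )
>>> # Try shape-based first, fall back to auto sharding
>>> shape_rule = ShapeBasedShardingRule({(None, 1024): PartitionSpec('tp')})
>>> auto_rule = AutoShardingRule(mesh=mesh)
>>> combined = CompositeShardingRule(shape_rule, auto_rule)
>>> specs = combined.apply(model_params)

apply(pytree: Any) -> Any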

Apply combined sharding rules to a PyTree.

Applies all rules and, for each leaf, selects the first non-empty PartitionSpec. If every rule returns an empty spec for a leaf, an empty spec is used.

Parameters

pytree – A PyTree of arrays to generate partition specs for.

Returns

A PyTree with the same structure containing the highest-priority non-empty PartitionSpec for each array.

class eformer.escale.helpers.base.MemoryConstrainedShardingRule(max_memory_per_device: int, mesh: jax._src.mesh.Mesh | None = None, axis_names: list[str] | None = None)

Bases: ShardingRule

Creates sharding based on per-device memory constraints.

This rule ensures that each device’s memory usage stays within a specified limit by sharding large arrays across the mesh. It prioritizes sharding the largest dimensions first and uses the largest mesh axes for maximum memory reduction.

The algorithm:

  1. If the array fits in memory, return an empty PartitionSpec (no sharding).

  2. Sort mesh axes by size (largest first, for maximum memory reduction).

  3. Sort array dimensions by size (largest first).

  4. Iteratively assign mesh axes to dimensions until the array fits.
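The loop can be sketched for a single array as below. This is a hypothetical illustration, not eformer's code: the mesh is a dict from axis name to size, itemsize is the element size in bytes, and the result is a tuple of axis names standing in for a PartitionSpec. The requirement that an axis divide its dimension evenly is an assumption added here so the per-device size stays exact; the source does not state it.

```python
import math
from typing import Optional


def memory_constrained_spec(
    shape: tuple[int, ...],
    itemsize: int,
    mesh_axes: dict[str, int],
    max_bytes: int,
) -> tuple[Optional[str], ...]:
    """Hypothetical sketch of the four steps; best effort if it can never fit."""
    spec: list[Optional[str]] = [None] * len(shape)
    per_device = math.prod(shape) * itemsize

    # Step 1: an array that already fits stays replicated.
    if per_device <= max_bytes:
        return tuple(spec)

    # Step 2: largest mesh axes first, for the biggest reduction per assignment.
    axes = sorted(mesh_axes, key=lambda name: mesh_axes[name], reverse=True)
    # Step 3: largest array dimensions first.
    dims = sorted(range(len(shape)), key=lambda i: shape[i], reverse=True)

    # Step 4: pair dimensions with axes until the per-device footprint fits.
    for dim in dims:
        if per_device <= max_bytes:
            break
        for axis in axes:
            # Assumption: only assign an axis that divides the dimension evenly.
            if shape[dim] % mesh_axes[axis] == 0:
                spec[dim] = axis
                axes.remove(axis)
                per_device //= mesh_axes[axis]
                break
    return tuple(spec)


# 4096 x 1024 float32 = 16 MiB; with a 2 MiB budget, tp=8 on the first dim is enough.
print(memory_constrained_spec((4096, 1024), 4, {"dp": 2, "tp": 8}, 2 * 1024**2))  # ('tp', None)
```

With a tighter 1 MiB budget the same call would also assign dp to the second dimension, giving ('tp', 'dp').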

max_memory_per_device

Maximum bytes allowed per device.

mesh

The JAX mesh to shard across.

axis_names

List of mesh axis names to consider for sharding.

Example

>>> # Allow max 1GB per device
>>> rule = MemoryConstrainedShardingRule(
...     max_memory_per_device=1024**3,
...     mesh=mesh
... )
>>> specs = rule.apply(large_model_params)

apply(pytree: Any) -> Any

Apply memory-constrained sharding to all arrays in a PyTree.

Parameters

pytree – A PyTree of arrays to generate partition specs for.

Returns

A PyTree with the same structure containing PartitionSpecs that ensure each array fits within the memory constraint.

class eformer.escale.helpers.base.ShapeBasedShardingRule(shape_patterns: dict[tuple[int | None, ...], jax.sharding.PartitionSpec])

Bases: ShardingRule

Creates sharding based on array shape patterns.

This rule allows defining specific sharding strategies for arrays matching certain shape patterns. Patterns can include wildcards (None) to match any dimension size.

This is useful when you know certain shapes should always be sharded in a particular way, such as embedding tables or attention weights.
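Wildcard matching can be sketched as follows. The matching semantics here are assumptions, not confirmed by the source: a pattern only matches shapes of the same rank, None matches any size, the first matching pattern in dict insertion order wins, and no match yields an empty spec. Tuples stand in for PartitionSpec; matches and lookup_spec are hypothetical names.

```python
from typing import Optional

Pattern = tuple[Optional[int], ...]
Spec = tuple[Optional[str], ...]  # stand-in for jax.sharding.PartitionSpec


def matches(shape: tuple[int, ...], pattern: Pattern) -> bool:
    """Ranks must agree; None matches any size, an int must match exactly."""
    return len(shape) == len(pattern) and all(
        p is None or p == d for p, d in zip(pattern, shape)
    )


def lookup_spec(shape: tuple[int, ...], patterns: dict[Pattern, Spec]) -> Spec:
    """Return the spec of the first matching pattern (dict order), else an empty spec."""
    for pattern, spec in patterns.items():
        if matches(shape, pattern):
            return spec
    return ()


patterns = {
    (None, 1024): ("tp", None),  # e.g. (vocab_size, 1024) embedding tables
    (1024, 4096): (None, "tp"),  # e.g. a hypothetical MLP up-projection
}
print(lookup_spec((50257, 1024), patterns), lookup_spec((1024, 4096), patterns))
```

Under these assumptions, a broad wildcard pattern listed first can shadow a more specific one, so ordering the dict from specific to general is the safer habit.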

shape_patterns

Dictionary mapping shape patterns to PartitionSpecs.

Example

>>> from jax.sharding import PartitionSpec
>>> from eformer.escale.helpers.base import ShapeBasedShardingRule
>>> # Shard (vocab_size, 1024) tables and 1024 x 1024 matrices along the first axis
>>> patterns = {
...     (None, 1024): PartitionSpec('tp', None),  # Embedding tables
...     (1024, 1024): PartitionSpec('tp', None),  # Square weight matrices
... }
>>> rule = ShapeBasedShardingRule(patterns)
>>> specs = rule.apply(model_params)

apply(pytree: Any) -> Any

Apply shape-based sharding to all arrays in a PyTree.

Parameters

pytree – A PyTree of arrays to generate partition specs for.

Returns

A PyTree with the same structure containing PartitionSpecs based on matching shape patterns.

class eformer.escale.helpers.base.ShardingAnalyzer(mesh: jax._src.mesh.Mesh | None = None)

Bases: object

Analyzes and validates sharding strategies.

mesh

The mesh configuration for sharding. If not provided, it defaults to the physical mesh from the thread resources.

Type

Mesh

Methods

  • validate_partition_specs(pytree: tp.Any, partition_specs: tp.Any) -> tp.List[str]: Validates the compatibility of partition specifications with the shapes of arrays in the pytree. pytree is a pytree of arrays to be validated; partition_specs is a pytree of partition specifications corresponding to the arrays. Returns a list of issues found during validation; if it is empty, no issues were found.

  • estimate_memory_usage(pytree: tp.Any, partition_specs: tp.Any) -> tp.Dict[str, int]: Estimates the memory usage per device after applying the sharding strategy. pytree is a pytree of arrays for which memory usage is to be estimated; partition_specs is a pytree of partition specifications corresponding to the arrays. Returns a dictionary containing the total memory size and the size per device.

estimate_memory_usage(pytree: Any, partition_specs: Any) -> dict[str, int]

Estimate memory usage per device after applying sharding.

Calculates the total memory footprint and estimates how much memory each device will need after the sharding strategy is applied.
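The per-device figure for one array can be worked out by dividing its total bytes by the product of the mesh axis sizes named in its spec; a PyTree-level estimate would sum this over the leaves. The sketch below is a hypothetical illustration of that arithmetic, not eformer's implementation: it ignores padding for uneven splits, and models each spec entry as None, an axis name, or a tuple of axis names (JAX allows one dimension to be sharded over several axes).

```python
import math
from typing import Union

# One PartitionSpec entry: None (replicated), an axis name, or a tuple of axis names.
Entry = Union[None, str, tuple[str, ...]]


def estimate_bytes(
    shape: tuple[int, ...],
    itemsize: int,
    spec: tuple[Entry, ...],
    mesh_axes: dict[str, int],
) -> dict[str, int]:
    """Hypothetical single-array estimate; ignores padding for uneven splits."""
    total = math.prod(shape) * itemsize
    shards = 1
    for entry in spec:
        if entry is None:
            continue
        names = (entry,) if isinstance(entry, str) else entry
        for name in names:
            shards *= mesh_axes[name]  # each named axis splits the array further
    return {"total_size": total, "size_per_device": total // shards}


mesh_axes = {"dp": 2, "tp": 8}
# float32 (4 bytes) 4096 x 1024 weight: 16 MiB total.
print(estimate_bytes((4096, 1024), 4, ("tp", None), mesh_axes))
print(estimate_bytes((4096, 1024), 4, (("dp", "tp"), None), mesh_axes))
```

Sharding the first dimension over tp alone gives 16 MiB / 8 = 2 MiB per device; sharding it over both dp and tp gives 16 MiB / 16 = 1 MiB.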

Parameters
  • pytree – A PyTree of arrays to estimate memory for.

  • partition_specs – A PyTree of PartitionSpecs with the same structure as pytree.

Returns

A dictionary containing:

  • "total_size": Total memory in bytes before sharding

  • "size_per_device": Estimated memory per device after sharding

Return type

dict[str, int]

Example

>>> analyzer = ShardingAnalyzer(mesh)
>>> usage = analyzer.estimate_memory_usage(params, specs)
>>> print(f"Memory per device: {usage['size_per_device'] / 1e9:.2f} GB")

validate_partition_specs(pytree: Any, partition_specs: Any) -> list[str]

Validate compatibility of partition specs with array shapes.

Checks that each array dimension is divisible by the corresponding mesh axis size specified in the partition spec.
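The divisibility check can be sketched for a single array as follows. This is a hypothetical illustration, not the library code: the mesh is a dict from axis name to size, a dimension sharded over a tuple of axes is assumed to be split by the product of their sizes, and the extra rank check (a spec longer than the array) is an assumption added for robustness.

```python
import math
from typing import Union

Entry = Union[None, str, tuple[str, ...]]


def validate_spec(
    shape: tuple[int, ...],
    spec: tuple[Entry, ...],
    mesh_axes: dict[str, int],
) -> list[str]:
    """Hypothetical per-array check; an empty list means no issues."""
    if len(spec) > len(shape):
        return [f"spec {spec} has more entries than array rank {len(shape)}"]
    issues = []
    for dim, entry in enumerate(spec):
        if entry is None:
            continue
        names = (entry,) if isinstance(entry, str) else entry
        # A dimension sharded over several axes is split by their product.
        size = math.prod(mesh_axes[name] for name in names)
        if shape[dim] % size != 0:
            issues.append(
                f"dim {dim} (size {shape[dim]}) is not divisible by {names} (size {size})"
            )
    return issues


mesh_axes = {"dp": 2, "tp": 8}
print(validate_spec((4096, 1024), ("tp", None), mesh_axes))  # []
print(validate_spec((4100, 1024), ("tp", None), mesh_axes))
```

The second call reports one issue because 4100 is not a multiple of 8.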

Parameters
  • pytree – A PyTree of arrays to validate.

  • partition_specs – A PyTree of PartitionSpecs with the same structure as pytree.

Returns

A list of validation issue messages. Empty list means no issues.

Example

>>> analyzer = ShardingAnalyzer(mesh)
>>> issues = analyzer.validate_partition_specs(params, specs)
>>> if issues:
...     print("Validation failed:", issues)

class eformer.escale.helpers.base.ShardingRule

Bases: ABC

Abstract base class for defining sharding rules.

Sharding rules define how arrays in a PyTree should be partitioned across devices in a mesh. Subclasses must implement the apply method to provide specific sharding logic.

This class follows the Strategy pattern, allowing different sharding algorithms to be swapped at runtime.

Example

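>>> import jax
>>> from jax.sharding import PartitionSpec
>>> from eformer.escale.helpers.base import ShardingRule
>>> class CustomRule(ShardingRule):
...     def apply(self, pytree):
...         # Shard the leading dimension of every leaf over the 'data' mesh axis.
...         return jax.tree_util.tree_map(
...             lambda x: PartitionSpec('data'),
...             pytree
...         )

abstract apply(pytree: Any) -> Any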

Apply the sharding rule to a PyTree of arrays.

Parameters

pytree – A PyTree (nested structure) of arrays to generate partition specifications for.

Returns

A PyTree with the same structure as the input, where each leaf is replaced with its corresponding PartitionSpec.

eformer.escale.helpers.base.barrier_sync(timeout: float = 200)

Synchronize all JAX processes at a barrier point.

Blocks execution until all processes in the distributed JAX runtime reach this barrier. This is essential for ensuring consistency across distributed training, especially before and after collective operations or checkpointing.

The function uses a global counter to create unique barrier names, allowing multiple barriers to be used sequentially without conflicts.
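The naming and timeout handling can be sketched without a cluster by injecting the wait call. This is a hypothetical sketch, not eformer's implementation: wait_at_barrier stands in for the JAX distributed client's barrier method (assumed to take a barrier name and a timeout in milliseconds; check your JAX version), process_count would come from jax.process_count(), and the "barrier_sync_" name prefix is invented for illustration.

```python
import itertools
from typing import Callable

# Module-level counter: each call gets a fresh, ordered barrier name.
_barrier_counter = itertools.count()


def barrier_sync_sketch(
    wait_at_barrier: Callable[[str, int], None],
    process_count: int,
    timeout: float = 200,
) -> None:
    # Single-process runs have nobody to wait for: no-op.
    if process_count == 1:
        return
    # Unique name per call, so sequential barriers never collide.
    name = f"barrier_sync_{next(_barrier_counter)}"
    # The underlying client expects the timeout in milliseconds.
    wait_at_barrier(name, int(timeout * 1000))


calls = []
barrier_sync_sketch(lambda name, ms: calls.append((name, ms)), process_count=4)
barrier_sync_sketch(lambda name, ms: calls.append((name, ms)), process_count=4, timeout=600)
print(calls)  # [('barrier_sync_0', 200000), ('barrier_sync_1', 600000)]
```

Since every process keeps its own counter, the names only line up across processes if each one calls the barrier the same number of times in the same order.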

Parameters

timeout – Maximum time to wait for all processes to reach the barrier, in seconds. Defaults to 200 seconds (3.33 minutes). If the timeout is exceeded, a RuntimeError will be raised by the underlying JAX distributed client.

Returns

None

Raises

RuntimeError – If the JAX distributed client is not initialized. This typically means JAX was not started in distributed mode or the distributed runtime failed to initialize.

Note

  • This function is a no-op when running with a single process (jax.process_count() == 1), allowing code to work seamlessly in both single and multi-process environments.

  • Each call increments a global counter to ensure unique barrier names, preventing conflicts when multiple barriers are used in sequence.

  • The timeout is converted to milliseconds for the underlying JAX API.

Example

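>>> import jax
>>> from eformer.escale.helpers.base import barrier_sync
>>> model = train_step(model, batch)
>>> barrier_sync()  # every process finishes its step before continuing
>>> if jax.process_index() == 0:
...     save_checkpoint(model)  # only process 0 writes the checkpoint
>>> barrier_sync()  # no process moves on until the checkpoint is written
>>> barrier_sync(timeout=600)  # wait up to 600 seconds instead of 200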

Warning

Ensure all processes call barrier_sync() the same number of times and in the same order, or deadlocks may occur. Conditional barriers based on process rank should be avoided.

eformer.escale.helpers.base.create_monitored_function(fn: Callable, partition_specs: Any, analyzer: ShardingAnalyzer) -> Callable

Create a monitored version of a function with sharding analysis.

Wraps a function to automatically validate sharding, measure execution time, and track memory usage. Useful for debugging and optimizing distributed training loops.

Parameters
  • fn – The function to wrap with monitoring.

  • partition_specs – The partition specifications to validate inputs against.

  • analyzer – A ShardingAnalyzer instance for validation and memory estimation.

Returns

A wrapped function that returns a tuple (result, metrics), where metrics is a dictionary with the keys execution_time, memory_usage, and validation_issues.
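A wrapper with this return shape can be sketched as follows. This is a hypothetical illustration, not eformer's implementation: validate and estimate are injected callables standing in for the analyzer's validate_partition_specs and estimate_memory_usage (how the real function maps call arguments onto them is not documented here), and monitor is an invented name. One JAX-specific caveat is noted in a comment: JAX dispatches work asynchronously, so a wall-clock timer should wait on jax.block_until_ready(result) before stopping.

```python
import functools
import time
import warnings
from typing import Any, Callable


def monitor(
    fn: Callable[..., Any],
    validate: Callable[..., list[str]],
    estimate: Callable[..., dict[str, int]],
) -> Callable[..., tuple[Any, dict[str, Any]]]:
    """Hypothetical wrapper returning (result, metrics)."""

    @functools.wraps(fn)
    def wrapped(*args: Any, **kwargs: Any) -> tuple[Any, dict[str, Any]]:
        issues = validate(*args)
        if issues:
            # Warn instead of raising, so the call still runs.
            warnings.warn(f"sharding validation issues: {issues}")
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        # With JAX, call jax.block_until_ready(result) here: dispatch is
        # asynchronous, so without it the timer may stop before the work finishes.
        elapsed = time.perf_counter() - start
        metrics = {
            "execution_time": elapsed,
            "memory_usage": estimate(*args),
            "validation_issues": issues,
        }
        return result, metrics

    return wrapped


double = monitor(lambda x: 2 * x, validate=lambda x: [], estimate=lambda x: {"total_size": 8})
result, metrics = double(21)
print(result, metrics["validation_issues"])  # 42 []
```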

Example

>>> analyzer = ShardingAnalyzer(mesh)
>>> monitored_train_step = create_monitored_function(
...     train_step, partition_specs, analyzer
... )
>>> result, metrics = monitored_train_step(params, batch)
>>> print(f"Execution time: {metrics['execution_time']:.2f}s")

Warning

If validation issues are detected, a warning is emitted (no exception is raised) and execution continues. This allows debugging without interrupting training.