eformer.escale.mesh.creation

JAX mesh creation utilities for distributed computation.

This module provides utilities for creating and managing JAX meshes for distributed computation across multiple devices. It supports various parallelism strategies including:

  • Data Parallelism (dp): Replicates model across devices, splits data

  • Fully Sharded Data Parallelism (fsdp): Shards both model and data across devices

  • Expert Parallelism (ep): For mixture-of-experts models

  • Tensor Parallelism (tp): Splits individual tensors across devices

  • Sequence Parallelism (sp): Splits sequence dimension across devices

Key Features:
  • Automatic device mesh creation for single/multi-host setups

  • Support for TPU slices and multi-process environments

  • CPU-specific utilities for debugging and testing

  • String-based mesh configuration parsing

  • Caching for efficient mesh reuse

Typical Usage:
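>>> from eformer.escale.mesh.creation import (
...     cpu_context, create_mesh, parse_mesh_from_string
... )
>>> # Create a 2x4 mesh with two named axes.
>>> mesh = create_mesh(
...     axis_dims=(2, 4),
...     axis_names=('data', 'model')
... )
>>> with mesh:
...     pass  # run sharded computation here
>>> # Build a mesh from a string specification.
>>> mesh = parse_mesh_from_string("dp:2,tp:4", ["dp", "tp"])
>>> # Force CPU execution with a CPU mesh.
>>> with cpu_context() as mesh:
...     pass  # run CPU-only computation here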
eformer.escale.mesh.creation.calculate_host_mesh_shape(global_mesh_shape: Sequence[int], total_devices: int | None = None, num_processes: int | None = None)

Calculate the mesh shape for the local host in a distributed setting.

Determines how to split a global mesh shape across multiple processes, ensuring each host gets an appropriate portion of the mesh.

Parameters
  • global_mesh_shape – The desired global mesh shape across all processes.

  • total_devices – Total number of devices on this host. If None, uses jax.local_device_count().

  • num_processes – Total number of processes in the distributed setup. If None, uses jax.process_count().

Returns

Tuple representing the mesh shape for this host.

Raises

ValueError – If mesh size doesn’t match available devices or if the calculated host mesh doesn’t use the correct number of devices.

Example

>>> calculate_host_mesh_shape((2, 4), total_devices=4, num_processes=2)
(1, 4)
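
The split described above can be illustrated with a small pure-Python sketch. It is not the library's implementation: the helper name sketch_host_mesh_shape and the rule of dividing the leading axes first are assumptions made for illustration. It only reproduces the documented behavior that the global mesh must match the available devices and that each host receives an equal share.

```python
import math
from typing import Sequence


def sketch_host_mesh_shape(
    global_mesh_shape: Sequence[int],
    devices_per_host: int,
    num_processes: int,
) -> tuple:
    """Hypothetical helper: split a global mesh shape evenly across processes."""
    total = math.prod(global_mesh_shape)
    if total != devices_per_host * num_processes:
        raise ValueError(
            f"Global mesh needs {total} devices, but "
            f"{devices_per_host * num_processes} are available."
        )

    # Assumption: processes are laid out along the leading axes first.
    remaining = num_processes
    host_shape = []
    for dim in global_mesh_shape:
        factor = math.gcd(dim, remaining)  # processes absorbed by this axis
        host_shape.append(dim // factor)
        remaining //= factor

    if remaining != 1:
        raise ValueError("Process count does not divide the global mesh shape.")
    return tuple(host_shape)


print(sketch_host_mesh_shape((2, 4), devices_per_host=4, num_processes=2))
```

With the arguments from the example above, the two processes are absorbed by the first axis, so each host keeps a (1, 4) slice of the global mesh.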
eformer.escale.mesh.creation.cpu_context()

Context manager that provides both CPU mesh and forces CPU execution.

Combines force_cpu() and create_cpu_mesh() to provide a complete CPU execution environment. This ensures both that operations run on CPU and that they use a CPU-configured mesh.

Yields

The CPU mesh created for the context.

Example

>>> with cpu_context() as mesh:
...     @jax.jit
...     def fn(x):
...         return x * 2
...     result = fn(jax.numpy.ones((4, 4)))

Note

This is particularly useful for:
  • Unit testing that needs deterministic CPU behavior

  • Debugging distributed code on a single machine

  • Prototyping before deploying to accelerators

eformer.escale.mesh.creation.create_cpu_mesh(axis_dims: Sequence[int] = (1, -1, 1, 1, 1), axis_names: Sequence[str] = ('dp', 'fsdp', 'ep', 'tp', 'sp')) -> Mesh

Create a mesh using CPU devices.

Useful for debugging, testing, or when you want to force operations to run on CPU regardless of available accelerators.

Parameters
  • axis_dims – Dimensions for each mesh axis. Default is (1, -1, 1, 1, 1). The -1 entry is resolved so that the mesh covers all available CPU devices.

  • axis_names – Names for each axis. Default is ('dp', 'fsdp', 'ep', 'tp', 'sp').

Returns

JAX Mesh configured to use CPU device(s).

Note

This uses all available CPU devices on the host and arranges them according to axis_dims.

eformer.escale.mesh.creation.create_mesh(axis_dims: Sequence[int] = (1, -1, 1, 1, 1), axis_names: Sequence[str] = ('dp', 'fsdp', 'ep', 'tp', 'sp'), dcn_mesh_dims: Optional[Sequence[int]] = None, should_sort_granules_by_key: bool = True, allow_split_physical_axes: bool = True, backend: str | None = None, use_jax: bool = False, axis_types: Union[Sequence[jax._src.mesh.AxisType | str], AxisType, str, None, Literal['auto', 'explicit', 'manual']] = None) -> Mesh

Create a JAX mesh for distributed computation.

Creates a mesh that maps logical mesh axes to physical devices, supporting various parallelism strategies including data, fully sharded data, expert, tensor, and sequence parallelism.

Parameters
  • axis_dims – Dimensions for each mesh axis. Default is (1, -1, 1, 1, 1) where -1 means use all remaining devices.

  • axis_names – Names for each axis. Default is ('dp', 'fsdp', 'ep', 'tp', 'sp'), representing data, fully sharded data, expert, tensor, and sequence parallelism respectively.

  • axis_types – Optional axis type(s) for mesh axes. Accepts AxisType values or the strings "auto", "explicit", or "manual". A single value is applied to all axes.

  • dcn_mesh_dims – Data center network mesh dimensions for hybrid device setups. If None, automatically calculated for multi-process environments.

  • should_sort_granules_by_key – Whether to sort device granules for consistent ordering across processes.

  • allow_split_physical_axes – Whether physical device axes can be split across logical mesh axes.

  • backend – JAX backend ('cpu', 'gpu', 'tpu'). If None, uses default.

  • use_jax – If True, uses jax.make_mesh. If False, uses mesh_utils-based explicit device mesh creation (including multi-slice support). When True, multi-slice or multi-process topologies fall back to the mesh_utils path for better topology support.

Returns

JAX Mesh object ready for use with pjit and sharding specifications.

Example

>>> mesh = create_mesh(
...     axis_dims=(2, 4),
...     axis_names=('data', 'model')
... )
>>> with mesh:
...     sharded_fn = pjit(fn, in_shardings=..., out_shardings=...)
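
The meaning of -1 in axis_dims ("use all remaining devices") can be made concrete with a short sketch. The helper resolve_axis_dims below is hypothetical and is not part of eformer. It assumes that at most one axis may be -1 and that the resolved shape must use every device exactly once, which may differ from the checks create_mesh actually performs.

```python
import math
from typing import Sequence


def resolve_axis_dims(axis_dims: Sequence[int], num_devices: int) -> tuple:
    """Hypothetical helper: replace a single -1 with the remaining device count."""
    dims = list(axis_dims)
    if dims.count(-1) > 1:
        raise ValueError("At most one axis may be -1.")

    # Product of the explicitly specified axes.
    known = math.prod(d for d in dims if d != -1)

    if -1 in dims:
        if known <= 0 or num_devices % known != 0:
            raise ValueError(
                f"{num_devices} devices cannot be divided by fixed axes of size {known}."
            )
        dims[dims.index(-1)] = num_devices // known
    elif known != num_devices:
        # Assumption: a fully specified shape must use every device.
        raise ValueError(f"Mesh needs {known} devices, but {num_devices} are available.")
    return tuple(dims)


print(resolve_axis_dims((1, -1, 1, 1, 1), num_devices=8))
print(resolve_axis_dims((2, -1), num_devices=8))
```

On 8 devices, the default (1, -1, 1, 1, 1) places every device on the fsdp axis, while (2, -1) leaves 4 devices for the second axis.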
eformer.escale.mesh.creation.force_cpu()

Context manager that forces JAX operations to run on CPU.

Temporarily sets the default JAX device to CPU for all operations within the context. Useful for debugging or when specific operations need to run on CPU.

Yields

The CPU device being used.

Example

>>> with force_cpu() as cpu_device:
...     result = jax.numpy.sum(array)  # array: any existing JAX array
...     print(f"Running on {cpu_device}")

Note

Device setting is restored when exiting the context.

eformer.escale.mesh.creation.parse_mesh_from_string(axis_dims: str, names: Sequence[str]) -> Mesh

Parse mesh configuration from string representation.

Supports two formats:
  1. Named format ("dp:2,tp:4"): explicitly maps names to dimensions

  2. Positional format ("2,4"): maps dimensions to names by position

Parameters
  • axis_dims – String representation of axis dimensions, in one of two forms:
      - Named: "name1:dim1,name2:dim2,…" (e.g., "dp:2,tp:4")
      - Positional: "dim1,dim2,…" (e.g., "2,4")

  • names – Sequence of axis names that should appear in the mesh.

Returns

JAX Mesh configured according to the string specification.

Raises

ValueError – If axis names don’t match, dimensions and names have different lengths, or unknown axis names are used.

Example

>>> # Named format
>>> mesh = parse_mesh_from_string("dp:2,tp:4", ["dp", "tp"])
>>> # Positional format
>>> mesh = parse_mesh_from_string("2,4", ["data", "model"])
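
The two string formats can be illustrated with a pure-Python sketch of the parsing step alone. The helper parse_axis_dims is hypothetical: it returns only the tuple of dimensions ordered by names, whereas parse_mesh_from_string returns a Mesh. The validation rules are inferred from the Raises section above and may not match the library's exact checks or error messages.

```python
from typing import Sequence


def parse_axis_dims(axis_dims: str, names: Sequence[str]) -> tuple:
    """Hypothetical helper: parse "dp:2,tp:4" or "2,4" into dims ordered by names."""
    parts = [part.strip() for part in axis_dims.split(",")]

    if ":" in axis_dims:
        # Named format: every entry is "name:dim".
        named = {}
        for part in parts:
            name, dim = part.split(":")
            named[name.strip()] = int(dim)
        unknown = set(named) - set(names)
        if unknown:
            raise ValueError(f"Unknown axis names: {sorted(unknown)}")
        missing = set(names) - set(named)
        if missing:
            raise ValueError(f"Missing axis names: {sorted(missing)}")
        # Reorder to follow names, regardless of the order in the string.
        return tuple(named[name] for name in names)

    # Positional format: dimensions map to names by position.
    dims = tuple(int(part) for part in parts)
    if len(dims) != len(names):
        raise ValueError(f"Got {len(dims)} dimensions for {len(names)} axis names.")
    return dims


print(parse_axis_dims("tp:4,dp:2", ["dp", "tp"]))
print(parse_axis_dims("2,4", ["data", "model"]))
```

In the named format the order inside the string does not matter in this sketch, because the result is reordered to follow names. In the positional format the order is the only information available.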