eformer.escale.mesh.creation
JAX mesh creation utilities for distributed computation.
This module provides utilities for creating and managing JAX meshes for distributed computation across multiple devices. It supports various parallelism strategies including:
Data Parallelism (dp): Replicates model across devices, splits data
Fully Sharded Data Parallelism (fsdp): Shards both model and data across devices
Expert Parallelism (ep): For mixture-of-experts models
Tensor Parallelism (tp): Splits individual tensors across devices
Sequence Parallelism (sp): Splits sequence dimension across devices
- Key Features:
Automatic device mesh creation for single/multi-host setups
Support for TPU slices and multi-process environments
CPU-specific utilities for debugging and testing
String-based mesh configuration parsing
Caching for efficient mesh reuse
- Typical Usage:
>>> mesh = create_mesh(
...     axis_dims=(2, 4),
...     axis_names=('data', 'model')
... )
>>> with mesh:
...     pass
>>> mesh = parse_mesh_from_string("dp:2,tp:4", ["dp", "tp"])
>>> with cpu_context() as mesh:
...     pass
- eformer.escale.mesh.creation.calculate_host_mesh_shape(global_mesh_shape: Sequence[int], total_devices: int | None = None, num_processes: int | None = None)
Calculate the mesh shape for the local host in a distributed setting.
Determines how to split a global mesh shape across multiple processes, ensuring each host gets an appropriate portion of the mesh.
- Parameters
global_mesh_shape – The desired global mesh shape across all processes.
total_devices – Total number of devices on this host. If None, uses jax.local_device_count().
num_processes – Total number of processes in the distributed setup. If None, uses jax.process_count().
- Returns
Tuple representing the mesh shape for this host.
- Raises
ValueError – If mesh size doesn’t match available devices or if the calculated host mesh doesn’t use the correct number of devices.
Example
>>> calculate_host_mesh_shape((2, 4), total_devices=4, num_processes=2)
(1, 4)
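The splitting step described above can be illustrated with a small pure-Python sketch. This is not eformer's implementation: `host_mesh_shape_sketch` is a hypothetical helper, and the greedy rule it uses (divide the leading axes by the process count first) is an assumption chosen only because it reproduces the documented example.

```python
import math
from typing import Sequence


def host_mesh_shape_sketch(
    global_mesh_shape: Sequence[int],
    total_devices: int,
    num_processes: int,
) -> tuple[int, ...]:
    """Hypothetical helper: split a global mesh shape across processes.

    Assumption: the process count is divided out of the leading axes first.
    """
    # The global mesh must use exactly the devices of all hosts combined.
    if math.prod(global_mesh_shape) != total_devices * num_processes:
        raise ValueError("global mesh size does not match available devices")

    host_shape = []
    remaining = num_processes  # process factor still to be divided out
    for dim in global_mesh_shape:
        factor = math.gcd(dim, remaining)
        host_shape.append(dim // factor)
        remaining //= factor

    # If some process factor is left over, no axis could absorb it.
    if remaining != 1:
        raise ValueError("host mesh does not use the correct number of devices")
    return tuple(host_shape)


two_hosts = host_mesh_shape_sketch((2, 4), total_devices=4, num_processes=2)
default_axes = host_mesh_shape_sketch((1, 8, 1, 1, 1), total_devices=4, num_processes=2)
```

Here `two_hosts` matches the `(1, 4)` result shown in the example above, and `default_axes` shows the same rule applied to a five-axis shape.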
- eformer.escale.mesh.creation.cpu_context()
Context manager that provides both CPU mesh and forces CPU execution.
Combines force_cpu() and create_cpu_mesh() to provide a complete CPU execution environment. This ensures both that operations run on CPU and that they use a CPU-configured mesh.
- Yields
The CPU mesh created for the context.
Example
>>> with cpu_context() as mesh:
...     @jax.jit
...     def fn(x):
...         return x * 2
...     result = fn(jax.numpy.ones((4, 4)))
Note
This is particularly useful for:
- Unit testing that needs deterministic CPU behavior
- Debugging distributed code on a single machine
- Prototyping before deploying to accelerators
- eformer.escale.mesh.creation.create_cpu_mesh(axis_dims: Sequence[int] = (1, -1, 1, 1, 1), axis_names: Sequence[str] = ('dp', 'fsdp', 'ep', 'tp', 'sp')) -> Mesh
Create a mesh using CPU devices.
Useful for debugging, testing, or when you want to force operations to run on CPU regardless of available accelerators.
- Parameters
axis_dims – Dimensions for each mesh axis. Default is (1, -1, 1, 1, 1). For CPU, this typically resolves to a shape matching the number of available CPU devices.
axis_names – Names for each axis. Default is (‘dp’, ‘fsdp’, ‘ep’, ‘tp’, ‘sp’).
- Returns
JAX Mesh configured to use CPU device(s).
Note
This uses all available CPU devices on the host and arranges them according to axis_dims.
- eformer.escale.mesh.creation.create_mesh(axis_dims: Sequence[int] = (1, -1, 1, 1, 1), axis_names: Sequence[str] = ('dp', 'fsdp', 'ep', 'tp', 'sp'), dcn_mesh_dims: Optional[Sequence[int]] = None, should_sort_granules_by_key: bool = True, allow_split_physical_axes: bool = True, backend: str | None = None, use_jax: bool = False, axis_types: Union[Sequence[jax._src.mesh.AxisType | str], AxisType, str, None, Literal['auto', 'explicit', 'manual']] = None) -> Mesh
Create a JAX mesh for distributed computation.
Creates a mesh that maps logical mesh axes to physical devices, supporting the parallelism strategies named by the default axes: data, fully sharded data, expert, tensor, and sequence parallelism.
- Parameters
axis_dims – Dimensions for each mesh axis. Default is (1, -1, 1, 1, 1) where -1 means use all remaining devices.
axis_names – Names for each axis. Default is (‘dp’, ‘fsdp’, ‘ep’, ‘tp’, ‘sp’) representing data, fully-sharded data, expert, tensor, and sequence parallelism respectively.
axis_types – Optional axis type(s) for mesh axes. Accepts AxisType values or “auto”, “explicit”, “manual” strings. A single value is applied to all axes.
dcn_mesh_dims – Data center network mesh dimensions for hybrid device setups. If None, automatically calculated for multi-process environments.
should_sort_granules_by_key – Whether to sort device granules for consistent ordering across processes.
allow_split_physical_axes – Whether physical device axes can be split across logical mesh axes.
backend – JAX backend (‘cpu’, ‘gpu’, ‘tpu’). If None, uses default.
use_jax – If True, builds the mesh with jax.make_mesh, except that multi-slice or multi-process topologies still fall back to the mesh_utils path for better topology support. If False, builds the device mesh explicitly with mesh_utils, which includes multi-slice support.
- Returns
JAX Mesh object ready for use with pjit and sharding specifications.
Example
>>> mesh = create_mesh(
...     axis_dims=(2, 4),
...     axis_names=('data', 'model')
... )
>>> with mesh:
...     sharded_fn = pjit(fn, in_shardings=..., out_shardings=...)
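The meaning of -1 in axis_dims ("use all remaining devices") can be made concrete with a short sketch. `resolve_axis_dims` is a hypothetical helper written for illustration, not part of eformer, and the error conditions it raises are assumptions about reasonable behavior rather than documented behavior of create_mesh.

```python
import math
from typing import Sequence


def resolve_axis_dims(axis_dims: Sequence[int], num_devices: int) -> tuple[int, ...]:
    """Hypothetical helper: replace a single -1 so the dims multiply to num_devices."""
    dims = list(axis_dims)
    if dims.count(-1) > 1:
        raise ValueError("at most one axis may be -1")

    # Product of the explicitly given axis sizes.
    known = math.prod(d for d in dims if d != -1)

    if -1 in dims:
        if known <= 0 or num_devices % known != 0:
            raise ValueError("explicit axis sizes do not divide the device count")
        # The -1 axis absorbs every device not claimed by the other axes.
        dims[dims.index(-1)] = num_devices // known
    elif known != num_devices:
        raise ValueError("axis sizes do not multiply to the device count")
    return tuple(dims)


# With the default axis_dims and 8 devices, everything lands on the 'fsdp' axis.
default_shape = resolve_axis_dims((1, -1, 1, 1, 1), num_devices=8)
# With dp=2 and tp=2 fixed, the remaining factor of 2 goes to 'fsdp'.
mixed_shape = resolve_axis_dims((2, -1, 1, 2, 1), num_devices=8)
```

Under these assumptions, the default `(1, -1, 1, 1, 1)` on an 8-device host resolves to `(1, 8, 1, 1, 1)`.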
- eformer.escale.mesh.creation.force_cpu()
Context manager that forces JAX operations to run on CPU.
Temporarily sets the default JAX device to CPU for all operations within the context. Useful for debugging or when specific operations need to run on CPU.
- Yields
The CPU device being used.
Example
>>> with force_cpu() as cpu_device:
...     result = jax.numpy.sum(array)
...     print(f"Running on {cpu_device}")
Note
Device setting is restored when exiting the context.
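A context manager with this behavior can be sketched on top of JAX's public `jax.default_device` context manager. Whether force_cpu() is implemented this way is an assumption; `force_cpu_sketch` is a hypothetical stand-in that only mirrors the documented behavior (yield the CPU device, restore the previous default on exit).

```python
from contextlib import contextmanager

import jax
import jax.numpy as jnp


@contextmanager
def force_cpu_sketch():
    """Hypothetical stand-in for force_cpu(), built on jax.default_device."""
    cpu_device = jax.devices("cpu")[0]  # first CPU device on this host
    # jax.default_device restores the previous default when the block exits.
    with jax.default_device(cpu_device):
        yield cpu_device


with force_cpu_sketch() as cpu_device:
    result = jnp.sum(jnp.ones((4, 4)))
    # Record where the result was placed while the context is active.
    result_platform = list(result.devices())[0].platform
```

Arrays created inside the block are placed on the yielded CPU device; arrays that were already committed to another device are not moved by this sketch.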
- eformer.escale.mesh.creation.parse_mesh_from_string(axis_dims: str, names: Sequence[str]) -> Mesh
Parse mesh configuration from string representation.
Supports two formats:
1. Named format: “dp:2,tp:4” - explicitly maps names to dimensions
2. Positional format: “2,4” - maps dimensions to names by position
- Parameters
axis_dims – String representation of axis dimensions. Either:
- Named: “name1:dim1,name2:dim2,…” (e.g., “dp:2,tp:4”)
- Positional: “dim1,dim2,…” (e.g., “2,4”)
names – Sequence of axis names that should appear in the mesh.
- Returns
JAX Mesh configured according to the string specification.
- Raises
ValueError – If axis names don’t match, dimensions and names have different lengths, or unknown axis names are used.
Example
>>> mesh = parse_mesh_from_string("dp:2,tp:4", ["dp", "tp"])
>>> mesh = parse_mesh_from_string("2,4", ["data", "model"])
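The two string formats can be illustrated with a pure-Python sketch that stops at the parsed dimensions rather than building a Mesh. `parse_axis_dims` is a hypothetical helper, not eformer's parser; in particular, reordering named entries to follow `names` is a choice made for this sketch, not confirmed behavior of parse_mesh_from_string.

```python
from typing import Sequence


def parse_axis_dims(axis_dims: str, names: Sequence[str]) -> tuple[int, ...]:
    """Hypothetical helper: return axis sizes ordered like `names`."""
    parts = [part.strip() for part in axis_dims.split(",")]

    if ":" in axis_dims:
        # Named format: "name1:dim1,name2:dim2,..."
        named = {}
        for part in parts:
            name, _, dim = part.partition(":")
            name = name.strip()
            if name not in names:
                raise ValueError(f"unknown axis name: {name!r}")
            named[name] = int(dim)
        if set(named) != set(names):
            raise ValueError("axis names do not match the expected names")
        return tuple(named[name] for name in names)

    # Positional format: "dim1,dim2,..."
    if len(parts) != len(names):
        raise ValueError("dimensions and names have different lengths")
    return tuple(int(part) for part in parts)


named_dims = parse_axis_dims("dp:2,tp:4", ["dp", "tp"])
positional_dims = parse_axis_dims("2,4", ["data", "model"])
reordered_dims = parse_axis_dims("tp:4,dp:2", ["dp", "tp"])
```

All three calls yield `(2, 4)`; the error branches correspond to the ValueError conditions listed above.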