# Copyright 2026 The EasyDeL/eFormer Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAX mesh creation utilities for distributed computation.

This module provides utilities for creating and managing JAX meshes for distributed
computation across multiple devices. It supports various parallelism strategies including:

- Data Parallelism (dp): Replicates model across devices, splits data
- Fully Sharded Data Parallelism (fsdp): Shards both model and data across devices
- Expert Parallelism (ep): For mixture-of-experts models
- Tensor Parallelism (tp): Splits individual tensors across devices
- Sequence Parallelism (sp): Splits sequence dimension across devices

Key Features:
    - Automatic device mesh creation for single/multi-host setups
    - Support for TPU slices and multi-process environments
    - CPU-specific utilities for debugging and testing
    - String-based mesh configuration parsing
    - Caching for efficient mesh reuse

Typical Usage:
    >>> # Create a mesh from explicit axis sizes and names
    >>> mesh = create_mesh(
    ...     axis_dims=(2, 4),
    ...     axis_names=('data', 'model')
    ... )
    >>> with mesh:
    ...     # Run sharded computation inside the mesh context
    ...     pass
    >>> # Create a mesh from a string specification
    >>> mesh = parse_mesh_from_string("dp:2,tp:4", ["dp", "tp"])
    >>> # Run on CPU with a CPU-backed mesh
    >>> with cpu_context() as mesh:
    ...     # Run CPU-only computation inside the context
    ...     pass
"""
import functools
import os
import typing as tp
import contextlib2
import jax
import numpy as np
from jax.experimental.mesh_utils import create_device_mesh, create_hybrid_device_mesh
from jax.sharding import AxisType, Mesh
DEFAULT_SHARDING_STG = (1, -1, 1, 1, 1)
DEFAULT_NAMED_SHARDING_STG = ("dp", "fsdp", "ep", "tp", "sp")
_AXIS_TYPE_BY_NAME = {
"auto": AxisType.Auto,
"explicit": AxisType.Explicit,
"manual": AxisType.Manual,
}
def _get_num_slices(devices: tp.Sequence[tp.Any]) -> int:
    """Count the distinct slices spanned by ``devices``.

    The count is taken from the ``slice_index`` attribute of the devices when
    the first device exposes one. If that yields a single slice (or the
    attribute is unavailable), the ``MEGASCALE_NUM_SLICES`` environment
    variable is consulted instead.

    Args:
        devices: Sequence of JAX device objects to inspect.

    Returns:
        The number of distinct slices detected, or 1 for single-slice setups.
    """
    slice_count = 1
    if devices and hasattr(devices[0], "slice_index"):
        try:
            seen_indices = set()
            for device in devices:
                seen_indices.add(device.slice_index)
            slice_count = len(seen_indices)
        except Exception:
            # Best effort: any failure while reading slice indices leaves the
            # count at 1 so the environment variable below can decide.
            pass
    if slice_count != 1:
        return slice_count
    return int(os.environ.get("MEGASCALE_NUM_SLICES", slice_count))
def _normalize_axis_types(
axis_names: tp.Sequence[str],
axis_types: tp.Sequence[AxisType | str] | AxisType | str | None,
) -> tuple[AxisType, ...] | None:
"""Normalize axis types to a tuple of AxisType enums.
Converts various input formats (strings, single values, sequences) into
a consistent tuple of AxisType enum values that can be passed to mesh
creation functions.
Args:
axis_names: Sequence of axis names to determine required length.
axis_types: Axis type specification in one of these formats:
- None: Returns None (use defaults)
- Single AxisType or string: Applied to all axes
- Sequence of AxisType/strings: One per axis name
Returns:
Tuple of AxisType enums with same length as axis_names,
or None if input was None.
Raises:
ValueError: If string values are not valid axis type names or
if sequence length doesn't match axis_names length.
TypeError: If axis_types contains invalid types.
"""
if axis_types is None:
return None
if isinstance(axis_types, (AxisType, str)):
axis_types_seq = (axis_types,) * len(axis_names)
else:
axis_types_seq = tuple(axis_types)
if len(axis_types_seq) == 1 and len(axis_names) > 1:
axis_types_seq = axis_types_seq * len(axis_names)
normalized = []
for axis_type in axis_types_seq:
if isinstance(axis_type, str):
key = axis_type.strip().lower()
if key not in _AXIS_TYPE_BY_NAME:
raise ValueError(
f"axis_types must be one of {{'auto', 'explicit', 'manual'}} or AxisType, got {axis_type!r}."
)
normalized.append(_AXIS_TYPE_BY_NAME[key])
elif isinstance(axis_type, AxisType):
normalized.append(axis_type)
else:
raise TypeError(f"axis_types entries must be strings or AxisType values, got {type(axis_type)}.")
if len(normalized) != len(axis_names):
raise ValueError(
"axis_types length must match axis_names length. "
f"Got {len(normalized)} types for {len(axis_names)} axis names."
)
return tuple(normalized)
[docs]def calculate_host_mesh_shape(
global_mesh_shape: tp.Sequence[int],
total_devices: int | None = None,
num_processes: int | None = None,
):
"""Calculate the mesh shape for the local host in a distributed setting.
Determines how to split a global mesh shape across multiple processes,
ensuring each host gets an appropriate portion of the mesh.
Args:
global_mesh_shape: The desired global mesh shape across all processes.
total_devices: Total number of devices on this host. If None, uses
jax.local_device_count().
num_processes: Total number of processes in the distributed setup.
If None, uses jax.process_count().
Returns:
Tuple representing the mesh shape for this host.
Raises:
ValueError: If mesh size doesn't match available devices or if
the calculated host mesh doesn't use the correct number of devices.
Example:
>>>
>>> calculate_host_mesh_shape((2, 4), total_devices=4, num_processes=2)
(1, 4)
"""
total_devices = total_devices or jax.local_device_count()
num_processes = num_processes or jax.process_count()
total_mesh_size = int(np.prod(global_mesh_shape))
if total_mesh_size != total_devices * num_processes:
raise ValueError(
f"Mesh size {total_mesh_size} doesn't match available devices "
f"{total_devices * num_processes} (local x processes)"
)
host_mesh = list(global_mesh_shape)
remaining_process_split = num_processes
idx = 0
while remaining_process_split > 1 and idx < len(host_mesh):
dim_size = host_mesh[idx]
if dim_size >= remaining_process_split:
factor = remaining_process_split
host_mesh[idx] = dim_size // factor
remaining_process_split = 1
else:
factor = dim_size
host_mesh[idx] = 1
remaining_process_split = remaining_process_split // factor
idx += 1
host_total = int(np.prod(host_mesh))
if host_total != total_devices:
raise ValueError(
f"Host mesh shape {tuple(host_mesh)} uses {host_total} devices instead of {total_devices}. "
"Ensure that num_processes factors the global mesh shape."
)
return tuple(host_mesh)
def _cached_mesh(
    axis_dims: tp.Sequence[int],
    axis_names: tp.Sequence[str],
    axis_types: tp.Sequence[AxisType] | None = None,
    dcn_mesh_dims: tp.Sequence[int] | None = None,
    should_sort_granules_by_key: bool = True,
    allow_split_physical_axes: bool = True,
    backend: str | None = None,
):
    """Convert the arguments to hashable form and call the cached builder.

    Sequences are turned into tuples (None is passed through) and a missing
    backend is replaced by ``jax.default_backend()`` before delegating to
    ``_cached_mesh_impl``, so that equivalent requests map to the same cache
    entry.

    Args:
        axis_dims: Dimensions for each mesh axis.
        axis_names: Names for each mesh axis.
        axis_types: Optional mesh axis types for explicit/auto/manual sharding.
        dcn_mesh_dims: Data center network mesh dimensions.
        should_sort_granules_by_key: Whether to sort device granules.
        allow_split_physical_axes: Whether to allow splitting physical axes.
        backend: JAX backend to use; falls back to the default backend.

    Returns:
        The JAX Mesh returned by ``_cached_mesh_impl``.
    """

    def _frozen(values):
        """Return ``values`` as a tuple, keeping None as None."""
        return None if values is None else tuple(values)

    return _cached_mesh_impl(
        axis_dims=tuple(axis_dims),
        axis_names=tuple(axis_names),
        axis_types=_frozen(axis_types),
        dcn_mesh_dims=_frozen(dcn_mesh_dims),
        should_sort_granules_by_key=should_sort_granules_by_key,
        allow_split_physical_axes=allow_split_physical_axes,
        backend=backend if backend else jax.default_backend(),
    )
@functools.cache
def _cached_mesh_impl(
    axis_dims: tuple[int, ...],
    axis_names: tuple[str, ...],
    axis_types: tuple[AxisType, ...] | None = None,
    dcn_mesh_dims: tuple[int, ...] | None = None,
    should_sort_granules_by_key: bool = True,
    allow_split_physical_axes: bool = True,
    backend: str = "cpu",
):
    """Cached implementation of mesh creation logic.

    Three scenarios are handled, checked in this order:

    1. Multi-slice (``_get_num_slices`` reports more than one slice): builds a
       hybrid mesh whose granules are slices.
    2. Multi-process (``jax.process_count() > 1``): builds a hybrid mesh whose
       granules are processes.
    3. Otherwise: builds a plain device mesh over all devices of ``backend``.

    In the hybrid scenarios the per-granule shape and the DCN shape are chosen
    so that their elementwise product is the global mesh shape. When
    ``dcn_mesh_dims`` is supplied, the per-granule shape is derived from it.

    Args:
        axis_dims: Tuple of dimensions for each mesh axis; a -1 entry is
            resolved against the total device count.
        axis_names: Tuple of names for each mesh axis.
        axis_types: Optional mesh axis types for explicit/auto/manual sharding.
        dcn_mesh_dims: Data center network dimensions for hybrid setups. May
            contain a single -1 that is filled in automatically.
        should_sort_granules_by_key: Forwarded to ``create_hybrid_device_mesh``.
        allow_split_physical_axes: Forwarded to the mesh_utils constructors.
        backend: Backend to use ('cpu', 'gpu', 'tpu').

    Returns:
        JAX Mesh configured for the detected environment.

    Raises:
        ValueError: If the mesh configuration is invalid for the environment,
            e.g. ``dcn_mesh_dims`` does not multiply to the granule count or
            does not divide the global mesh shape axis by axis.
    """
    devices = jax.devices(backend)
    total_devices = jax.device_count(backend)
    local_devices = jax.local_device_count(backend)
    process_count = jax.process_count()
    # Reshaping a dummy array resolves a -1 entry and rejects shapes whose
    # product does not match the device count.
    global_mesh_shape = np.arange(total_devices).reshape(axis_dims).shape
    num_slices = _get_num_slices(devices)

    def fill_minus_one_to_target(shape: tuple[int, ...], target: int) -> tuple[int, ...]:
        """Replace -1 in shape with the value that makes the product equal target.

        Args:
            shape: Shape tuple potentially containing one -1.
            target: Target product all dimensions should multiply to.

        Returns:
            Shape tuple with -1 replaced by the calculated value.

        Raises:
            ValueError: If multiple -1s exist, an entry is not positive, or
                the product doesn't match target.
        """
        shp = list(shape)
        minus = [i for i, v in enumerate(shp) if v == -1]
        if len(minus) > 1:
            raise ValueError("Only one -1 is supported in dcn_mesh_dims.")
        prod_known = 1
        for v in shp:
            if v != -1:
                if v <= 0:
                    raise ValueError(f"dcn_mesh_dims entries must be > 0 or -1, got {v}")
                prod_known *= v
        if minus:
            if target % prod_known != 0:
                raise ValueError(f"dcn_mesh_dims product ({prod_known}) does not divide target ({target}).")
            shp[minus[0]] = target // prod_known
        if np.prod(shp) != target:
            raise ValueError(f"dcn_mesh_dims product {int(np.prod(shp))} must equal {target}; got {tuple(shp)}")
        return tuple(int(v) for v in shp)

    def per_granule_shape(dcn_shape: tuple[int, ...]) -> tuple[int, ...]:
        """Divide the global mesh shape axis by axis by ``dcn_shape``.

        The hybrid mesh has shape ``mesh_shape * dcn_mesh_shape`` elementwise,
        so the per-granule shape must be the exact quotient for the final mesh
        to match the requested global shape.

        Raises:
            ValueError: If the lengths differ or an axis is not divisible.
        """
        if len(dcn_shape) != len(global_mesh_shape):
            raise ValueError(
                f"dcn_mesh_dims has {len(dcn_shape)} entries but the mesh has "
                f"{len(global_mesh_shape)} axes."
            )
        if any(g % d != 0 for g, d in zip(global_mesh_shape, dcn_shape, strict=True)):
            raise ValueError(
                f"dcn_mesh_dims {tuple(dcn_shape)} must divide the mesh shape "
                f"{tuple(global_mesh_shape)} axis by axis."
            )
        return tuple(g // d for g, d in zip(global_mesh_shape, dcn_shape, strict=True))

    if num_slices > 1:
        if dcn_mesh_dims is None:
            # Put the whole slice split on the first axis that can absorb it.
            dynamic_axis = next((i for i, dim in enumerate(global_mesh_shape) if dim % num_slices == 0), None)
            if dynamic_axis is None:
                raise ValueError(
                    f"Multi-slice detected (num_slices={num_slices}) but no mesh axis in "
                    f"{global_mesh_shape} is divisible by num_slices."
                )
            dcn_list = [1] * len(axis_dims)
            dcn_list[dynamic_axis] = num_slices
            dcn = tuple(dcn_list)
        else:
            dcn = fill_minus_one_to_target(dcn_mesh_dims, num_slices)
        per_slice_mesh_shape = per_granule_shape(dcn)
        ndarray = create_hybrid_device_mesh(
            mesh_shape=per_slice_mesh_shape,
            dcn_mesh_shape=dcn,
            devices=devices,
            allow_split_physical_axes=allow_split_physical_axes,
            process_is_granule=False,
            should_sort_granules_by_key=should_sort_granules_by_key,
        )
    elif process_count > 1:
        if dcn_mesh_dims is None:
            local_mesh_shape = calculate_host_mesh_shape(
                global_mesh_shape=global_mesh_shape,
                total_devices=local_devices,
                num_processes=process_count,
            )
            ratios = [int(g // le) for g, le in zip(global_mesh_shape, local_mesh_shape, strict=False)]
            if np.prod(ratios) != process_count:
                # Fallback kept from the previous behavior: assign the whole
                # process split to the first axis.
                ratios = [1] * len(axis_dims)
                if ratios:
                    ratios[0] = process_count
            dcn = tuple(ratios)
        else:
            dcn = fill_minus_one_to_target(dcn_mesh_dims, process_count)
            local_mesh_shape = per_granule_shape(dcn)
        ndarray = create_hybrid_device_mesh(
            mesh_shape=local_mesh_shape,
            dcn_mesh_shape=dcn,
            devices=devices,
            allow_split_physical_axes=allow_split_physical_axes,
            process_is_granule=True,
            should_sort_granules_by_key=should_sort_granules_by_key,
        )
    else:
        ndarray = create_device_mesh(
            mesh_shape=global_mesh_shape,
            devices=devices,
            allow_split_physical_axes=allow_split_physical_axes,
        )
    return Mesh(ndarray, axis_names, axis_types=axis_types)
def create_mesh(
    axis_dims: tp.Sequence[int] = DEFAULT_SHARDING_STG,
    axis_names: tp.Sequence[str] = DEFAULT_NAMED_SHARDING_STG,
    dcn_mesh_dims: tp.Sequence[int] | None = None,
    should_sort_granules_by_key: bool = True,
    allow_split_physical_axes: bool = True,
    backend: str | None = None,
    use_jax: bool = False,
    axis_types: tp.Sequence[AxisType | str] | AxisType | str | None | tp.Literal["auto", "explicit", "manual"] = None,
) -> Mesh:
    """Create a JAX mesh for distributed computation.

    Axis types are normalized first. The mesh is then built either with
    ``jax.make_mesh`` or through the cached mesh_utils-based builder.

    Args:
        axis_dims: Dimensions for each mesh axis. Default is (1, -1, 1, 1, 1)
            where -1 means use all remaining devices.
        axis_names: Names for each axis. Default is
            ('dp', 'fsdp', 'ep', 'tp', 'sp').
        dcn_mesh_dims: Data center network mesh dimensions for hybrid device
            setups. If None, they are calculated automatically when needed.
        should_sort_granules_by_key: Whether to sort device granules.
        allow_split_physical_axes: Whether physical device axes can be split
            across logical mesh axes.
        backend: JAX backend ('cpu', 'gpu', 'tpu'). If None, uses default.
        use_jax: If True, ``jax.make_mesh`` is used, but only when a single
            slice and a single process are detected and ``dcn_mesh_dims`` is
            None; in every other case the mesh_utils-based builder is used.
        axis_types: Optional axis type(s) for mesh axes. Accepts AxisType
            values or "auto", "explicit", "manual" strings. A single value is
            applied to all axes.

    Returns:
        JAX Mesh object.

    Example:
        >>> mesh = create_mesh(
        ...     axis_dims=(2, 4),
        ...     axis_names=('data', 'model')
        ... )
        >>> with mesh:
        ...     sharded_fn = pjit(fn, in_shardings=..., out_shardings=...)
    """
    axis_types = _normalize_axis_types(axis_names, axis_types)

    if use_jax:
        candidate_devices = jax.devices(backend)
        slice_count = _get_num_slices(candidate_devices)
        n_processes = jax.process_count()
        needs_mesh_utils = slice_count != 1 or n_processes != 1 or dcn_mesh_dims is not None
        if not needs_mesh_utils:
            # Resolve any -1 entry against the number of visible devices.
            resolved_shape = np.arange(len(candidate_devices)).reshape(axis_dims).shape
            return jax.make_mesh(
                axis_shapes=resolved_shape,
                axis_names=axis_names,
                axis_types=axis_types,
                devices=candidate_devices,
            )

    return _cached_mesh(
        axis_dims=axis_dims,
        axis_names=axis_names,
        axis_types=axis_types,
        dcn_mesh_dims=dcn_mesh_dims,
        should_sort_granules_by_key=should_sort_granules_by_key,
        allow_split_physical_axes=allow_split_physical_axes,
        backend=backend,
    )
def parse_mesh_from_string(
    axis_dims: str,
    names: tp.Sequence[str],
) -> Mesh:
    """Parse mesh configuration from string representation.

    Supports two formats:

    1. Named format: "dp:2,tp:4" - explicitly maps names to dimensions. The
       mesh axes follow the order in which they appear in the string.
    2. Positional format: "2,4" - maps dimensions to names by position.

    Whitespace around entries, names and dimensions is ignored.

    Args:
        axis_dims: String representation of axis dimensions. Either:
            - Named: "name1:dim1,name2:dim2,..." (e.g., "dp:2,tp:4")
            - Positional: "dim1,dim2,..." (e.g., "2,4")
        names: Sequence of axis names that should appear in the mesh.

    Returns:
        JAX Mesh configured according to the string specification.

    Raises:
        ValueError: If an entry is malformed, an axis name is unknown or
            repeated, not all names are used, a dimension is not an integer,
            or dimensions and names have different lengths.

    Example:
        >>> mesh = parse_mesh_from_string("dp:2,tp:4", ["dp", "tp"])
        >>> mesh = parse_mesh_from_string("2,4", ["data", "model"])
    """
    entries = [entry.strip() for entry in axis_dims.split(",")]
    if ":" in axis_dims:
        dims = []
        dim_names = []
        for entry in entries:
            name, separator, dim = entry.partition(":")
            name = name.strip()
            # Every entry must be exactly one "name:dim" pair in named format.
            if not separator or ":" in dim:
                raise ValueError(f"Malformed axis entry '{entry}'; expected 'name:dim'.")
            if name not in names:
                raise ValueError(f"Axis name '{name}' not found in provided names: {names}")
            if name in dim_names:
                raise ValueError(f"Axis name '{name}' is specified more than once in 'axis_dims'")
            dims.append(int(dim))
            dim_names.append(name)
        if set(dim_names) != set(names):
            raise ValueError("Not all axis names were used in 'axis_dims'")
    else:
        dims = [int(entry) for entry in entries]
        dim_names = list(names)
    if len(dims) != len(names):
        raise ValueError("Number of dimensions and names must match")
    return create_mesh(tuple(dims), tuple(dim_names))
def create_cpu_mesh(
    axis_dims: tp.Sequence[int] = DEFAULT_SHARDING_STG,
    axis_names: tp.Sequence[str] = DEFAULT_NAMED_SHARDING_STG,
) -> Mesh:
    """Create a mesh on the CPU backend.

    Thin wrapper around ``create_mesh`` that fixes ``backend="cpu"`` and
    passes the axis sizes and names as tuples.

    Args:
        axis_dims: Dimensions for each mesh axis. Default is (1, -1, 1, 1, 1).
        axis_names: Names for each axis. Default is
            ('dp', 'fsdp', 'ep', 'tp', 'sp').

    Returns:
        JAX Mesh built by ``create_mesh`` for the CPU backend.
    """
    cpu_dims = tuple(axis_dims)
    cpu_names = tuple(axis_names)
    return create_mesh(axis_dims=cpu_dims, axis_names=cpu_names, backend="cpu")
@contextlib2.contextmanager
def force_cpu():
    """Context manager that makes the first local CPU device the JAX default.

    The body of the ``with`` statement runs inside
    ``jax.default_device(<first local CPU device>)``.

    Yields:
        The CPU device being used.

    Example:
        >>> with force_cpu() as cpu_device:
        ...     result = jax.numpy.sum(array)
        ...     print(f"Running on {cpu_device}")
    """
    local_cpus = jax.local_devices(backend="cpu")
    primary_cpu = local_cpus[0]
    device_scope = jax.default_device(primary_cpu)
    with device_scope:
        yield primary_cpu
@contextlib2.contextmanager
def cpu_context():
    """Context manager that enters ``force_cpu()`` and a CPU mesh together.

    A mesh is created with ``create_cpu_mesh()`` using its defaults; the body
    of the ``with`` statement then runs inside both ``force_cpu()`` and that
    mesh's context.

    Yields:
        The CPU mesh created for the context.

    Example:
        >>> with cpu_context() as mesh:
        ...     @jax.jit
        ...     def fn(x):
        ...         return x * 2
        ...     result = fn(jax.numpy.ones((4, 4)))
    """
    cpu_mesh = create_cpu_mesh()
    with force_cpu():
        with cpu_mesh:
            yield cpu_mesh