eformer.executor.cluster_util

Cluster utilities for distributed execution with Ray and JAX.

This module provides utilities for managing distributed clusters, with a particular focus on SLURM environments and Ray cluster initialization. It handles automatic discovery of cluster topology, coordinator selection, and allocation of resources (such as accelerator devices and CPU cores) across distributed nodes.

Note

This implementation is adapted from the Levanter project (stanford-crfm/levanter) with modifications for the eFormer/EasyDeL framework.

class eformer.executor.cluster_util.DistributedConfig(coordinator_address: str | None = None, num_processes: int | None = None, process_id: int | None = None, local_device_ids: int | list[int] | None = None)

Bases: object

Configuration for distributed JAX execution.

Encapsulates all settings needed to initialize JAX in a distributed environment, with automatic detection of SLURM clusters.

coordinator_address

Address of the coordinator process. Auto-detected from SLURM if None.

Type

str | None

num_processes

Total number of processes in the cluster.

Type

int | None

process_id

ID of this process (0-indexed).

Type

int | None

local_device_ids

Device IDs to use on this node. Auto-detected from SLURM if None.

Type

int | list[int] | None

coordinator_address: str | None = None
initialize()

Initialize JAX distributed execution.

Sets up JAX for distributed execution based on the configuration, with automatic detection of SLURM cluster settings if needed.

Note

If no distributed configuration is provided and no cluster is detected, JAX will run in single-process mode.

local_device_ids: int | list[int] | None = None
num_processes: int | None = None
process_id: int | None = None
class eformer.executor.cluster_util.RayClusterConfig(address: str | None = None, start_workers: bool = True, auto_start_cluster: bool = True)

Bases: object

Configuration for Ray cluster initialization.

Controls how Ray clusters are started and connected to.

address

Ray cluster address. If None, uses auto-discovery.

Type

str | None

start_workers

Whether to start Ray workers on non-head nodes. Defaults to True.

Type

bool

auto_start_cluster

Whether to automatically start the Ray cluster. Defaults to True.

Type

bool

address: str | None = None
auto_start_cluster: bool = True
initialize()

Initialize the Ray cluster based on configuration.

Calls auto_ray_cluster() with the configured settings if auto_start_cluster is True.

start_workers: bool = True
eformer.executor.cluster_util.auto_ray_cluster(address: str | None = None, namespace: str | None = 'eray-executor', start_workers: bool = True, fail_if_cluster_already_initialized: bool = False, **kwargs)

Automatically initialize a Ray cluster.

Handles automatic discovery and initialization of Ray clusters in various environments. Can start both head and worker nodes as needed, with special support for SLURM clusters.

Parameters
  • address (str | None) – Ray cluster address. If None, attempts auto-discovery from the RAY_ADDRESS environment variable or from the JAX coordinator address.

  • namespace (str | None) – Ray namespace to use. Defaults to “eray-executor”.

  • start_workers (bool) – Whether to start Ray workers on non-head nodes. Defaults to True.

  • fail_if_cluster_already_initialized (bool) – Whether to fail if a cluster is already running. Defaults to False.

  • **kwargs – Additional arguments passed to ray.init().

Raises

RuntimeError – If the Ray head node or a worker node fails to start.

Note

Only the first call in a process takes effect; subsequent calls are ignored with a warning.
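One way to obtain the once-per-process behaviour described in the note is a module-level flag guarded by a lock. This is a sketch under that assumption; init_once and start_fn are hypothetical names, and a real implementation could additionally consult ray.is_initialized().

```python
import logging
import threading
from collections.abc import Callable

logger = logging.getLogger(__name__)

_lock = threading.Lock()
_already_initialized = False  # hypothetical per-process guard


def init_once(start_fn: Callable[[], None]) -> bool:
    """Run start_fn on the first call only; later calls log a warning and return False."""
    global _already_initialized
    with _lock:
        if _already_initialized:
            logger.warning("Cluster already initialized in this process; ignoring call.")
            return False
        start_fn()
        _already_initialized = True  # set only after start_fn succeeds
        return True
```

The flag is set only after start_fn returns, so a call that raises (for example with the RuntimeError listed above) leaves the function callable again.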

class eformer.executor.cluster_util.eSlurmCluster

Bases: SlurmCluster

Extended SLURM cluster implementation for Ray executor.

This class extends JAX’s SlurmCluster to provide additional functionality for Ray-based distributed execution in SLURM environments. It handles automatic coordinator address discovery, device ID assignment, and local process counting.

Inherits all attributes from jax.clusters.SlurmCluster.

classmethod get_coordinator_address() → str

Get the coordinator address for the SLURM cluster.

Automatically determines the coordinator node and port based on SLURM environment variables. The coordinator is typically the first node in the node list.

Returns

The coordinator address in the format “hostname:port”.

Return type

str

Raises

ValueError – If the node list cannot be found in the SLURM environment variables.
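The sketch below shows how a coordinator address of the form "hostname:port" can be derived from SLURM variables. It is modelled on JAX's SlurmCluster as we understand it, not copied from eformer: reading SLURM_STEP_NODELIST before SLURM_JOB_NODELIST and the rule that maps SLURM_JOB_ID onto ports 61440 to 65535 are assumptions, and the hypothetical first_slurm_host only handles the common compressed node-list forms listed in its docstring.

```python
from collections.abc import Mapping


def first_slurm_host(node_list: str) -> str:
    """Return the first hostname of a compressed SLURM node list.

    Handles 'node001', 'node001,host2', 'node[001-015],host2'
    and 'node[001,007-015],host2'.
    """
    i = next((k for k, ch in enumerate(node_list) if ch in {",", "["}), len(node_list))
    if i == len(node_list) or node_list[i] == ",":
        return node_list[:i]  # plain hostname before any comma
    prefix, rest = node_list[:i], node_list[i + 1:]
    # Take the first index inside the brackets.
    j = next((k for k, ch in enumerate(rest) if ch in {",", "-", "]"}), len(rest))
    return prefix + rest[:j]


def coordinator_address(env: Mapping[str, str]) -> str:
    """Build 'hostname:port' from SLURM variables (sketch, not eformer's code)."""
    node_list = env.get("SLURM_STEP_NODELIST") or env.get("SLURM_JOB_NODELIST")
    if node_list is None:
        raise ValueError("Could not find the SLURM node list in the environment")
    # Assumed port rule: map the job ID into the top 4096 ports (61440-65535).
    port = int(env["SLURM_JOB_ID"]) % 2**12 + (65535 - 2**12 + 1)
    return f"{first_slurm_host(node_list)}:{port}"
```

For instance, SLURM_JOB_NODELIST=gpu[3-4] with SLURM_JOB_ID=123 yields "gpu3:61563".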

classmethod get_local_device_ids_for_process() → list[int] | None

Get the device IDs assigned to the current local process.

Determines which CUDA devices should be used by this process based on the SLURM task configuration and CUDA_VISIBLE_DEVICES.

Returns

List of device IDs for this process, or None if device assignment cannot be determined.

Return type

list[int] | None

Raises

ValueError – If the number of visible devices is not evenly divisible by the number of local tasks.

eformer.executor.cluster_util.logical_cpu_core_count()

Get the number of logical CPU cores available to the process.

First checks the SLURM environment variables, then falls back to the CPU count reported by the operating system.

Returns

Number of logical CPU cores available.

Return type

int
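A sketch of this fallback order, assuming SLURM_CPUS_ON_NODE is the SLURM variable consulted; cpu_core_count is a hypothetical name, not the eformer function.

```python
import os
from collections.abc import Mapping


def cpu_core_count(env: Mapping[str, str] = os.environ) -> int:
    """Logical CPU cores: the SLURM allocation if present, otherwise the OS count."""
    slurm_cpus = env.get("SLURM_CPUS_ON_NODE")  # assumed variable name
    if slurm_cpus is not None:
        return int(slurm_cpus)
    return os.cpu_count() or 1  # os.cpu_count() may return None
```

The SLURM value is checked first because os.cpu_count() reports the CPUs of the whole machine, which can exceed what the job was allocated.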