eformer.executor.cluster_util
Cluster utilities for distributed execution with Ray and JAX.
This module provides utilities for managing distributed clusters, particularly focused on SLURM environments and Ray cluster initialization. It handles automatic discovery of cluster topology, coordinator selection, and proper resource allocation across distributed nodes.
Note
This implementation is adapted from the Levanter project (stanford-crfm/levanter) with modifications for the eFormer/EasyDeL framework.
- class eformer.executor.cluster_util.DistributedConfig(coordinator_address: str | None = None, num_processes: int | None = None, process_id: int | None = None, local_device_ids: int | list[int] | None = None)
Bases: object
Configuration for distributed JAX execution.
Encapsulates all settings needed to initialize JAX in a distributed environment, with automatic detection of SLURM clusters.
- coordinator_address
Address of the coordinator process. Auto-detected from SLURM if None.
- Type
str | None
- num_processes
Total number of processes in the cluster.
- Type
int | None
- process_id
ID of this process (0-indexed).
- Type
int | None
- local_device_ids
Device IDs to use on this node. Auto-detected from SLURM if None.
- Type
int | list[int] | None
- coordinator_address: str | None = None
- initialize()
Initialize JAX distributed execution.
Sets up JAX for distributed execution based on the configuration, with automatic detection of SLURM cluster settings if needed.
Note
If no distributed configuration is provided and no cluster is detected, JAX will run in single-process mode.
- local_device_ids: int | list[int] | None = None
- num_processes: int | None = None
- process_id: int | None = None
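The fallback described in the note for initialize() can be pictured with a small stand-alone sketch. It does not import eformer or JAX: `SketchConfig` and `choose_mode` are hypothetical names that only mirror the documented fields, and using the presence of `SLURM_JOB_ID` to detect a SLURM job is an assumption, not a statement about how the library detects clusters.

```python
from dataclasses import dataclass
from typing import List, Mapping, Optional, Union


@dataclass
class SketchConfig:
    """Hypothetical stand-in that mirrors the documented DistributedConfig fields."""

    coordinator_address: Optional[str] = None
    num_processes: Optional[int] = None
    process_id: Optional[int] = None
    local_device_ids: Optional[Union[int, List[int]]] = None


def choose_mode(cfg: SketchConfig, env: Mapping[str, str]) -> str:
    """Decide, for illustration only, which initialization path applies."""
    explicit = (
        cfg.coordinator_address,
        cfg.num_processes,
        cfg.process_id,
        cfg.local_device_ids,
    )
    # Any explicitly provided setting counts as a distributed configuration.
    if any(value is not None for value in explicit):
        return "distributed (explicit settings)"
    # Assumption: a SLURM job is recognised by SLURM_JOB_ID in the environment.
    if "SLURM_JOB_ID" in env:
        return "distributed (SLURM auto-detected)"
    # Nothing provided and nothing detected: stay in single-process mode.
    return "single-process"
```

Passing the environment as an argument (for example `choose_mode(SketchConfig(), os.environ)`) keeps the decision easy to exercise without a real cluster.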
- class eformer.executor.cluster_util.RayClusterConfig(address: str | None = None, start_workers: bool = True, auto_start_cluster: bool = True)
Bases: object
Configuration for Ray cluster initialization.
Controls how Ray clusters are started and connected to.
- address
Ray cluster address. If None, uses auto-discovery.
- Type
str | None
- start_workers
Whether to start Ray workers on non-head nodes. Defaults to True.
- Type
bool
- auto_start_cluster
Whether to automatically start the Ray cluster. Defaults to True.
- Type
bool
- address: str | None = None
- auto_start_cluster: bool = True
- initialize()
Initialize the Ray cluster based on configuration.
Calls auto_ray_cluster() with the configured settings if auto_start_cluster is True.
- start_workers: bool = True
- eformer.executor.cluster_util.auto_ray_cluster(address: str | None = None, namespace: str | None = 'eray-executor', start_workers: bool = True, fail_if_cluster_already_initialized: bool = False, **kwargs)
Automatically initialize a Ray cluster.
Handles automatic discovery and initialization of Ray clusters in various environments. Can start both head and worker nodes as needed, with special support for SLURM clusters.
- Parameters
address (str | None) – Ray cluster address. If None, attempts auto-discovery from RAY_ADDRESS environment variable or JAX coordinator.
namespace (str | None) – Ray namespace to use. Defaults to "eray-executor".
start_workers (bool) – Whether to start Ray workers on non-head nodes. Defaults to True.
fail_if_cluster_already_initialized (bool) – Whether to fail if a cluster is already running. Defaults to False.
**kwargs – Additional arguments passed to ray.init().
- Raises
RuntimeError – If Ray head or worker fails to start.
Note
This function can only be called once per process. Subsequent calls will be ignored with a warning.
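The once-per-process behaviour and the first step of address discovery can be sketched without Ray installed. Everything below is illustrative: `resolve_address`, `init_once`, and the module-level flag are hypothetical names, only the `RAY_ADDRESS` fallback is modelled (the JAX-coordinator fallback mentioned above is left out), and the actual cluster start-up is replaced by a comment.

```python
import os
import warnings
from typing import Mapping, Optional

# Hypothetical module-level flag remembering whether initialization already ran.
_initialized = False


def resolve_address(address: Optional[str], env: Optional[Mapping[str, str]] = None) -> Optional[str]:
    """Prefer an explicit address, otherwise fall back to RAY_ADDRESS if set."""
    env = os.environ if env is None else env
    return address or env.get("RAY_ADDRESS")


def init_once(address: Optional[str] = None, env: Optional[Mapping[str, str]] = None) -> bool:
    """Return True if this call performed initialization, False if it was ignored."""
    global _initialized
    if _initialized:
        # Later calls are ignored with a warning rather than raising.
        warnings.warn("Cluster already initialized; ignoring this call.", stacklevel=2)
        return False
    resolved = resolve_address(address, env)
    # A real implementation would start or connect to the cluster at `resolved` here.
    _initialized = True
    return True
```

Returning a boolean instead of raising keeps repeated calls harmless, which matches the "ignored with a warning" wording in the note.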
- class eformer.executor.cluster_util.eSlurmCluster
Bases: SlurmCluster
Extended SLURM cluster implementation for Ray executor.
This class extends JAX’s SlurmCluster to provide additional functionality for Ray-based distributed execution in SLURM environments. It handles automatic coordinator address discovery, device ID assignment, and local process counting.
- Inherits all attributes from jax.clusters.SlurmCluster.
- classmethod get_coordinator_address() -> str
Get the coordinator address for the SLURM cluster.
Automatically determines the coordinator node and port based on SLURM environment variables. The coordinator is typically the first node in the node list.
- Returns
The coordinator address in the format “hostname:port”.
- Return type
str
- Raises
ValueError – If node list cannot be found in environment variables.
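As a rough illustration of how a "hostname:port" address might be derived from SLURM variables, the sketch below takes the first host from a compact node list and computes a port from the job ID, so that every task arrives at the same value. The helper names, the choice of `SLURM_STEP_NODELIST`, `SLURM_JOB_NODELIST`, and `SLURM_JOB_ID`, and the port formula are assumptions made for this example; the parser handles only simple bracket ranges, not every form SLURM can emit.

```python
import re
from typing import Mapping


def first_hostname(nodelist: str) -> str:
    """Return the first hostname of a node list such as "a,b" or "node[01-03,07]"."""
    match = re.match(r"([^,\[]+)(?:\[([^\]]+)\])?", nodelist)
    if match is None:
        raise ValueError(f"Cannot parse node list: {nodelist!r}")
    prefix, ranges = match.group(1), match.group(2)
    if ranges is None:
        return prefix
    # "01-03,07" -> first token "01", which is the lowest index of the first range.
    first_index = re.split(r"[,-]", ranges)[0]
    return prefix + first_index


def coordinator_address(env: Mapping[str, str]) -> str:
    """Build "hostname:port" from SLURM-style environment variables (sketch)."""
    nodelist = env.get("SLURM_STEP_NODELIST") or env.get("SLURM_JOB_NODELIST")
    if not nodelist:
        raise ValueError("No SLURM node list found in the environment.")
    # Assumed scheme: map the job ID into the top 4096 ports (61440-65535).
    port = 61440 + int(env["SLURM_JOB_ID"]) % 4096
    return f"{first_hostname(nodelist)}:{port}"
```

For example, `coordinator_address({"SLURM_JOB_NODELIST": "node[01-03]", "SLURM_JOB_ID": "12345"})` yields `"node01:61497"` under these assumptions.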
- classmethod get_local_device_ids_for_process() -> list[int] | None
Get the device IDs assigned to the current local process.
Determines which CUDA devices should be used by this process based on the SLURM task configuration and CUDA_VISIBLE_DEVICES.
- Returns
List of device IDs for this process, or None if device assignment cannot be determined.
- Return type
list[int] | None
- Raises
ValueError – If the number of visible devices is not evenly divisible by the number of local tasks.
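The even-split rule described above can be sketched as a pure function. `local_device_ids` here is a hypothetical helper, not the library method: it takes the `CUDA_VISIBLE_DEVICES` string, the number of tasks on the node, and this task's local index as plain arguments, and it assumes the devices are listed as integer indices rather than UUIDs.

```python
from typing import List, Optional


def local_device_ids(
    visible_devices: Optional[str],
    local_task_count: int,
    local_task_id: int,
) -> Optional[List[int]]:
    """Split the visible devices into equal contiguous slices, one per local task."""
    if not visible_devices:
        # Without a device list there is nothing to assign.
        return None
    devices = [int(d) for d in visible_devices.split(",")]
    if len(devices) % local_task_count != 0:
        raise ValueError(
            f"{len(devices)} visible devices cannot be split evenly "
            f"across {local_task_count} local tasks."
        )
    per_task = len(devices) // local_task_count
    start = local_task_id * per_task
    return devices[start : start + per_task]
```

With four visible devices and two tasks per node, task 0 receives `[0, 1]` and task 1 receives `[2, 3]`; three devices across two tasks raises `ValueError`.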