eformer.executor.ray.executor
Ray-based executor for distributed machine learning workloads.
This module provides the core execution framework for running distributed workloads on accelerators (TPUs, GPUs) using Ray. It supports single-pod, multi-slice, and fault-tolerant execution patterns with automatic retry mechanisms.
- Key Features:
Single-pod and multi-slice execution on TPUs/GPUs
Automatic retry mechanisms for preemption and failures
Resource management and allocation via Ray
Support for both synchronous and asynchronous execution
Decorator-based API for easy integration
MegaScale coordination for multi-slice TPU workloads
Flexible result flattening for multi-host scenarios
- Environment Variables:
EXECUTOR_CALL_INDEX: Set to worker index within a pod (0-based)
EXECUTOR_CALL_SLICE: Set to slice ID for multi-slice execution
COORD_PORT: Coordinator port for MegaScale (default: 8192)
TPU_NAME, TPU_VERSION, TPU_ZONE: TPU configuration passed to workers
MEGASCALE_* variables: Auto-configured for multi-slice coordination
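Inside a worker, these variables can be read with the standard library. The sketch below is only an illustration: the helper name `read_worker_context` and the fallback of 0 for an unset index or slice are assumptions, not part of eformer. Only the variable names and the 8192 port default come from the list above.

```python
import os


def read_worker_context(environ=None):
    """Collect executor-related settings from the environment.

    `read_worker_context` is a hypothetical helper; treating the index and
    slice ID as integers that default to 0 is an assumption of this sketch.
    """
    env = os.environ if environ is None else environ
    return {
        "call_index": int(env.get("EXECUTOR_CALL_INDEX", "0")),
        "call_slice": int(env.get("EXECUTOR_CALL_SLICE", "0")),
        "coord_port": int(env.get("COORD_PORT", "8192")),
    }


# Use an explicit mapping here instead of the real process environment.
context = read_worker_context({"EXECUTOR_CALL_INDEX": "3", "EXECUTOR_CALL_SLICE": "1"})
print(context)
```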
Example
Basic single-pod execution:
>>> import ray
>>> from eformer.executor.ray import RayExecutor, TpuAcceleratorConfig
>>>
>>> @ray.remote
... def train_model(data):
...     ...
...     return trained_model
>>>
>>> tpu_config = TpuAcceleratorConfig(type="v4-8")
>>> result = RayExecutor.execute_resumable(
... train_model,
... tpu_config,
... max_retries_preemption=10,
... max_retries_failure=3
... )
Multi-slice execution with decorator:
>>> from eformer.executor.ray import autoscale_execute_resumable
>>>
>>> @autoscale_execute_resumable(tpu_config)
... @ray.remote
... def distributed_train(slice_data):
...     ...
...     return slice_results
>>>
>>> results = distributed_train(training_data)
- Classes:
RayExecutor: Core executor with static methods for various execution patterns
- Functions:
execute: Decorator for single-pod execution without retry
execute_resumable: Decorator for single-pod execution with retry
autoscale_execute: Decorator for multi-slice execution without retry
autoscale_execute_resumable: Decorator for multi-slice execution with retry
- class eformer.executor.ray.executor.RayExecutor
Bases: object
Core executor for Ray-based distributed workloads.
Provides static methods to execute Ray remote functions on various accelerators (TPUs, GPUs) with support for single-pod, multi-slice, and fault-tolerant execution patterns.
This class serves as the main interface for running distributed ML workloads with automatic resource allocation, retry mechanisms, and failure handling.
- All methods return JobStatus objects that encapsulate:
JobSucceeded: Successful completion with results
JobFailed: Failure due to exceptions
JobPreempted: Preemption on preemptible resources
JobError: Unexpected errors
Note
All methods are static and can be called directly on the class. The class does not maintain state between executions.
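A caller typically branches on the status type. The following is a minimal sketch using stand-in classes so that it runs without eformer installed; the field name `error` and the helper `unwrap` are assumptions made for illustration, and only the `result` attribute appears in this page's examples.

```python
from dataclasses import dataclass
from typing import Any


# Stand-in status types for illustration only. The real classes live in
# eformer; their fields (other than `result`) are assumptions here.
@dataclass
class JobSucceeded:
    result: Any


@dataclass
class JobFailed:
    error: Exception


@dataclass
class JobPreempted:
    error: Exception


@dataclass
class JobError:
    error: Exception


def unwrap(status):
    """Return the result on success, otherwise raise the carried error."""
    if isinstance(status, JobSucceeded):
        return status.result
    if isinstance(status, (JobFailed, JobPreempted, JobError)):
        raise status.error
    raise TypeError(f"unexpected status type: {type(status).__name__}")


print(unwrap(JobSucceeded(result=[1, 2, 3])))
```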
- static autoscale_execute(remote_fn: RemoteFunction, accelerator_config: eformer.executor.ray.resource_manager.TpuAcceleratorConfig | eformer.executor.ray.resource_manager.GpuAcceleratorConfig | eformer.executor.ray.resource_manager.CpuAcceleratorConfig, flatten: bool = True, **kwargs) -> JobStatus
Execute a Ray remote function across multiple TPU slices.
Distributes execution of a remote function across multiple TPU slices for large-scale parallel processing. This method sets up the necessary infrastructure including slice actors, placement groups, and MegaScale coordination environment variables.
- Parameters
remote_fn (RemoteFunction) – The Ray remote function to execute on each slice. Must be decorated with @ray.remote.
accelerator_config (AcceleratorConfigType) – Configuration for accelerator resources, must include multi-slice details (pod_count > 1).
flatten (bool) – If True (default), returns a flat list of results from all hosts across all slices. If False, returns nested lists where outer list represents slices and inner lists contain results from hosts within each slice.
**kwargs – Additional keyword arguments passed to the remote function on each slice.
- Returns
- A single JobStatus object containing results from all slices.
JobSucceeded: Contains results list (flat or nested based on flatten)
JobFailed: Contains the exception that caused the failure
JobPreempted: Contains preemption error details
JobError: Contains unexpected error information
- Return type
JobStatus
- Raises
InsufficientSlicesError – If requested number of slices cannot be allocated.
RayError – If slice actor creation fails, coordinator IP cannot be determined, or remote function calls fail.
RuntimeError – If no SliceActors are available after scaling or the coordinator IP cannot be determined.
Note
The method automatically sets up MegaScale environment variables for multi-slice coordination including coordinator address, slice IDs, and port configuration.
Each slice gets its own SliceActor which manages multiple DeviceHostActors.
The pool manager is automatically drained after execution completes or if an error occurs.
Environment variables set include: MEGASCALE_COORDINATOR_ADDRESS, MEGASCALE_NUM_SLICES, MEGASCALE_PORT, MEGASCALE_SLICE_ID, TPU_SLICE_NAME, and more.
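To make the per-slice configuration concrete, here is a rough sketch of how one environment dict per slice could be assembled. The helper `build_slice_envs`, the `ip:port` address format, and the string encoding of values are assumptions; only the variable names are taken from the note above, and the real implementation sets additional variables.

```python
def build_slice_envs(coordinator_ip, num_slices, port=8192):
    """Build one environment dict per slice.

    Hypothetical helper: the "ip:port" address format is an assumption,
    and only a subset of the documented variables is shown.
    """
    envs = []
    for slice_id in range(num_slices):
        envs.append(
            {
                "MEGASCALE_COORDINATOR_ADDRESS": f"{coordinator_ip}:{port}",
                "MEGASCALE_NUM_SLICES": str(num_slices),
                "MEGASCALE_PORT": str(port),
                # The slice ID is the only value that differs between slices.
                "MEGASCALE_SLICE_ID": str(slice_id),
            }
        )
    return envs


slice_envs = build_slice_envs("10.0.0.2", num_slices=2)
for env in slice_envs:
    print(env["MEGASCALE_SLICE_ID"], env["MEGASCALE_COORDINATOR_ADDRESS"])
```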
Example
>>> @ray.remote
... def train_on_slice(data, slice_id):
...     ...
...     return model_weights
>>>
>>> tpu_config = TpuAcceleratorConfig(type="v4-32", pod_count=4)
>>>
>>> job_status = RayExecutor.autoscale_execute(
...     train_on_slice,
...     tpu_config,
...     data=training_data
... )
>>> if isinstance(job_status, JobSucceeded):
...     flat_results = job_status.result
>>>
>>> job_status = RayExecutor.autoscale_execute(
...     train_on_slice,
...     tpu_config,
...     flatten=False,
...     data=training_data
... )
>>> if isinstance(job_status, JobSucceeded):
...     results_by_slice = job_status.result
- classmethod autoscale_execute_resumable(remote_fn: RemoteFunction, accelerator_config: eformer.executor.ray.resource_manager.TpuAcceleratorConfig | eformer.executor.ray.resource_manager.GpuAcceleratorConfig | eformer.executor.ray.resource_manager.CpuAcceleratorConfig, max_retries_preemption: int = 1000000, max_retries_failure: int = 10, **kwargs)
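The difference between the two result shapes can be shown with plain lists. This is a sketch under the assumption that flat results are ordered slice by slice, then host by host; `arrange_results` is a hypothetical helper, not an eformer function.

```python
from itertools import chain


def arrange_results(results_by_slice, flatten=True):
    """Return per-host results either flat or grouped by slice.

    Hypothetical helper; the slice-then-host ordering of the flat list is
    an assumption of this sketch.
    """
    if flatten:
        return list(chain.from_iterable(results_by_slice))
    return [list(hosts) for hosts in results_by_slice]


# Two slices with two hosts each.
nested = [["s0-h0", "s0-h1"], ["s1-h0", "s1-h1"]]
flat = arrange_results(nested)
grouped = arrange_results(nested, flatten=False)
print(flat)
print(grouped)
```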
Execute a multi-slice function with automatic retry on failures.
Provides fault-tolerant execution of Ray remote functions across multiple TPU slices with coordinated retry mechanisms. All slices must succeed for the execution to be considered successful.
- Parameters
remote_fn (RemoteFunction) – The Ray remote function to execute on each slice. Must be decorated with @ray.remote.
accelerator_config (AcceleratorConfigType) – Configuration for accelerator resources with multi-slice support (pod_count > 1).
max_retries_preemption (int) – Maximum number of retries when any slice is preempted. Defaults to 1,000,000.
max_retries_failure (int) – Maximum number of retries when any slice fails. Defaults to 10.
**kwargs – Additional keyword arguments passed to the remote function on each slice. The ‘flatten’ parameter can be used to control result structure.
- Returns
- List of results from successful execution on all slices.
The structure depends on the flatten parameter passed in kwargs:
If flatten=True (default): Flat list of all results
If flatten=False: List of lists, one per slice
- Return type
list[Any]
- Raises
RuntimeError – If any slice is preempted more than max_retries_preemption times, fails more than max_retries_failure times, or if autoscale_execute returns None or unexpected result type.
RayError – If autoscale_execute fails during setup or coordination (slice actor creation, placement group setup, etc.).
ray.exceptions.RayTaskError – Re-raised if it occurs and indicates preemption or failure after max retries.
Exception – The last encountered exception if retries are exhausted.
Note
Implements an all-or-nothing retry policy: if any slice fails or is preempted, the entire multi-slice execution is retried
Errors are routed to separate counters: a RayTaskError/RayError containing “preempted” increments the preemption counter; any other error increments the failure counter
Each retry attempt creates new slice actors and placement groups
Detailed logging tracks retry attempts and error types
Example
>>> @ray.remote
... def distributed_training(data_shard):
...     ...
...     return trained_weights
>>>
>>> tpu_config = TpuAcceleratorConfig(type="v4-32", pod_count=4)
>>>
>>> results = RayExecutor.autoscale_execute_resumable(
...     distributed_training,
...     tpu_config,
...     max_retries_preemption=50,
...     max_retries_failure=3,
...     data_shard=sharded_data
... )
>>>
>>> results_by_slice = RayExecutor.autoscale_execute_resumable(
...     distributed_training,
...     tpu_config,
...     max_retries_preemption=50,
...     max_retries_failure=3,
...     data_shard=sharded_data,
...     flatten=False
... )
- static execute(remote_fn: RemoteFunction, accelerator_config: eformer.executor.ray.resource_manager.TpuAcceleratorConfig | eformer.executor.ray.resource_manager.GpuAcceleratorConfig | eformer.executor.ray.resource_manager.CpuAcceleratorConfig, **kwargs)
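The all-or-nothing policy can be sketched as a function that collapses per-slice outcomes into a single verdict for the attempt. The helper names are hypothetical, and giving preemption precedence when slices report mixed errors is an assumption of this sketch rather than documented behavior.

```python
def classify_error(exc):
    """Label an error as "preempted" if its message says so, else "failed"."""
    return "preempted" if "preempted" in str(exc) else "failed"


def attempt_outcome(slice_outcomes):
    """Collapse per-slice outcomes into one verdict for the whole attempt.

    Each outcome is either a result value or an Exception instance. A single
    bad slice makes the whole attempt count as failed or preempted.
    """
    errors = [o for o in slice_outcomes if isinstance(o, Exception)]
    if not errors:
        return "succeeded"
    kinds = {classify_error(e) for e in errors}
    # Assumption: preemption takes precedence when error kinds are mixed.
    return "preempted" if "preempted" in kinds else "failed"


print(attempt_outcome(["w0", "w1"]))
print(attempt_outcome(["w0", RuntimeError("worker was preempted")]))
print(attempt_outcome([ValueError("bad shape"), "w1"]))
```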
Execute a Ray remote function on a single pod or slice.
Runs a Ray remote function on a single accelerator pod (TPU/GPU) with the specified resource configuration. For multi-slice TPU workloads, use autoscale_execute instead.
- Parameters
remote_fn (RemoteFunction) – The Ray remote function to execute. Must be decorated with @ray.remote.
accelerator_config (AcceleratorConfigType) – Configuration for accelerator resources (TPU, GPU, or CPU).
**kwargs – Additional keyword arguments passed to the remote function.
- Returns
actual result.
- Return type
ray.JobStatus
- Raises
ValueError – If pod_count in accelerator_config is not 1, indicating that autoscale_execute should be used instead.
Example
>>> @ray.remote
... def compute(x):
...     return x * 2
>>>
>>> config = GpuAcceleratorConfig(count=1, type="v100")
>>> result = RayExecutor.execute(compute, config, x=10)
- classmethod execute_resumable(remote_fn: RemoteFunction, accelerator_config: eformer.executor.ray.resource_manager.TpuAcceleratorConfig | eformer.executor.ray.resource_manager.GpuAcceleratorConfig | eformer.executor.ray.resource_manager.CpuAcceleratorConfig, max_retries_preemption: int = 1000000, max_retries_failure: int = 10, **kwargs)
Execute a remote function with automatic retry on failures.
Provides fault-tolerant execution of Ray remote functions with configurable retry policies for both preemptions and failures. Particularly useful for long-running jobs on preemptible resources.
- Parameters
remote_fn (RemoteFunction) – The Ray remote function to execute. Must be decorated with @ray.remote.
accelerator_config (AcceleratorConfigType) – Configuration for accelerator resources.
max_retries_preemption (int) – Maximum number of retries on preemption. Defaults to 1,000,000 (effectively unlimited).
max_retries_failure (int) – Maximum number of retries on failure. Defaults to 10.
**kwargs – Additional keyword arguments passed to the remote function.
- Returns
- The result from successful execution of the remote function.
The actual return type depends on what the remote function returns.
- Return type
Any
- Raises
RuntimeError – If the job is preempted more than max_retries_preemption times or fails more than max_retries_failure times. The error message indicates whether it was due to preemptions or failures.
ray.exceptions.RayTaskError – Re-raised if it occurs and is not preemption-related after max retries.
Exception – The last encountered exception if all retries are exhausted.
Note
Preemptions and failures are tracked separately
Each attempt logs status information for debugging
The method distinguishes between preemption (often recoverable) and failures (may indicate code issues)
RayTaskErrors containing “preempted” are treated as preemptions
Example
>>> @ray.remote
... def long_running_task(data):
...     ...
...     return process(data)
>>>
>>> config = TpuAcceleratorConfig(type="v4-8", preemptible=True)
>>> result = RayExecutor.execute_resumable(
...     long_running_task,
...     config,
...     max_retries_preemption=100,
...     max_retries_failure=5,
...     data=my_data
... )
- class eformer.executor.ray.executor.TpuRemoteManager(tpu_version: str, pod_count: int, *, base_env: dict[str, str] | None = None, runtime_env: dict | None = None, coord_port: int = 8081)
Bases: object
Session-based manager for TPU remote execution across multiple slices.
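The retry behavior described above, with preemptions and failures counted against separate budgets, can be approximated in plain Python. This is a sketch, not the library's code: `run_with_retries` and `run_once` are hypothetical names, and it inspects the message of any exception rather than only Ray error types.

```python
def run_with_retries(run_once, max_retries_preemption=1_000_000, max_retries_failure=10):
    """Call `run_once` until it succeeds or one retry budget is exhausted."""
    preemptions = 0
    failures = 0
    while True:
        try:
            return run_once()
        except Exception as exc:
            # Assumption: any exception mentioning "preempted" is a preemption.
            if "preempted" in str(exc):
                preemptions += 1
                if preemptions > max_retries_preemption:
                    raise RuntimeError("preempted too many times") from exc
            else:
                failures += 1
                if failures > max_retries_failure:
                    raise RuntimeError("failed too many times") from exc


attempts = {"count": 0}


def flaky_job():
    """Simulate a job that is preempted twice before succeeding."""
    attempts["count"] += 1
    if attempts["count"] <= 2:
        raise RuntimeError("node was preempted")
    return "done"


result = run_with_retries(flaky_job, max_retries_preemption=5, max_retries_failure=1)
print(result, attempts["count"])
```

Keeping the two counters apart means a long run on preemptible hardware can survive many preemptions without masking a genuine bug, which exhausts the much smaller failure budget quickly.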
Provides a high-level interface for managing TPU resources and executing functions or class methods across multiple TPU slices. Each execution runs in a short-lived worker process to avoid TPU handle leaks.
Key features:
Automatic slice scaling and warming
Persistent DeviceHostActors for efficient execution
Broadcasting of function/method calls across all hosts
Ephemeral worker processes for each execution
MegaScale coordination for multi-slice workloads
- tpu_version#
TPU version string (e.g., ‘v4’, ‘v5p’)
- pod_count#
Number of TPU pods/slices to use
- base_env#
Base environment variables for all workers
- runtime_env#
Ray runtime environment configuration
- coord_port#
Port for MegaScale coordinator (default 8081)
Example
>>> manager = TpuRemoteManager(
...     tpu_version='v4',
...     pod_count=2,
...     base_env={'MY_VAR': 'value'}
... )
>>> manager.ensure()
>>> results = manager.run_function(my_func, arg1, arg2)
>>> manager.close()
- close()
Clean up all TPU resources and reset state.
Performs cleanup in the following order:
1. Cancels any current work on all host actors
2. Drains the actor pool (terminates all actors)
3. Resets internal state
Note
This method is safe to call multiple times and will suppress any errors during cleanup to ensure resources are freed.
- ensure()
Initialize and prepare TPU resources if not already done.
This method:
1. Scales the slice pool to the requested pod_count
2. Prepares all slices (creates placement groups, etc.)
3. Ensures DeviceHostActors are created on each host
4. Builds per-slice environment variables including MegaScale config
5. Caches host actor handles for efficient execution
- Raises
RuntimeError – If no SliceActors are available after scaling.
Note
This method is idempotent - calling it multiple times has no additional effect after the first successful call.
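The idempotent `ensure()` and the error-suppressing `close()` follow a common session pattern. The toy class below sketches that pattern only; `SessionSketch` and its private helpers are hypothetical and do not reflect how TpuRemoteManager is implemented internally.

```python
import contextlib


class SessionSketch:
    """Toy stand-in for a resource session; not the real TpuRemoteManager."""

    def __init__(self):
        self.ready = False
        self.setup_calls = 0

    def _allocate(self):
        # Placeholder for the expensive setup (scaling, placement groups, ...).
        self.setup_calls += 1

    def ensure(self):
        if self.ready:
            return  # already initialized, so repeated calls do nothing
        self._allocate()
        self.ready = True

    def _release(self):
        if not self.ready:
            raise RuntimeError("nothing to release")

    def close(self):
        # Suppress cleanup errors so the state is always reset afterwards.
        with contextlib.suppress(Exception):
            self._release()
        self.ready = False


session = SessionSketch()
session.ensure()
session.ensure()  # second call performs no extra setup
session.close()
session.close()  # safe even though nothing is left to release
print(session.setup_calls, session.ready)
```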
- run_class_method(cls_obj: type, init_args: tuple, init_kwargs: dict, method_name: str, *call_args, flatten: bool = True, **call_kwargs)
Execute a class method on every TPU host across all slices.
Instantiates the class and calls the specified method in ephemeral worker processes on each host. The class is reconstructed inside each worker to ensure clean TPU state.
- Parameters
cls_obj – The class object (not an instance). Must be pickleable.
init_args – Positional arguments for class initialization.
init_kwargs – Keyword arguments for class initialization.
method_name – Name of the method to call on the instantiated object.
*call_args – Positional arguments for the method call.
flatten – If True, returns a flat list of results from all hosts. If False, returns nested lists by slice.
**call_kwargs – Keyword arguments for the method call.
- Returns
If flatten=True: List of results from all hosts (flattened).
If flatten=False: List of lists, where each inner list contains results from hosts within a slice.
Example
>>> results = manager.run_class_method(
...     MyModel,
...     init_args=(config,),
...     init_kwargs={'checkpoint': 'path/to/ckpt'},
...     method_name='train',
...     epochs=10,
...     batch_size=32
... )
Note
The class is instantiated fresh in each worker process, avoiding TPU handle leaks from persistent objects.
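What each worker does with these arguments can be sketched as "construct, look up the method by name, call it". The helper `call_method` and the `Scaler` class are hypothetical; the real workers additionally run in separate processes on each TPU host.

```python
def call_method(cls_obj, init_args, init_kwargs, method_name, *call_args, **call_kwargs):
    """Build a fresh instance and invoke one of its methods by name."""
    instance = cls_obj(*init_args, **init_kwargs)
    method = getattr(instance, method_name)
    return method(*call_args, **call_kwargs)


class Scaler:
    """Small example class used only to exercise the helper."""

    def __init__(self, factor, offset=0):
        self.factor = factor
        self.offset = offset

    def apply(self, value):
        return value * self.factor + self.offset


result = call_method(Scaler, (3,), {"offset": 1}, "apply", 5)
print(result)
```

Because the instance is rebuilt from the class and its constructor arguments on every call, no state carries over between calls unless the class persists it elsewhere.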
- run_function(fn: Callable, *args, flatten: bool = True, **kwargs)
Execute a Python function on every TPU host across all slices.
Runs the provided function in ephemeral worker processes on each host to avoid TPU handle leaks. The function is executed with the same arguments on every host.
- Parameters
fn – Python callable to execute on each host. Should be pickleable.
*args – Positional arguments to pass to the function.
flatten – If True, returns a flat list of results from all hosts. If False, returns nested lists where outer list represents slices and inner lists contain results from hosts within each slice.
**kwargs – Keyword arguments to pass to the function.
- Returns
If flatten=True: List of results from all hosts (flattened).
If flatten=False: List of lists, where each inner list contains results from hosts within a slice.
Note
Each execution runs in a new process to avoid TPU handle leaks. The function must be pickleable for Ray serialization.
- eformer.executor.ray.executor.autoscale_execute(accelerator_config: eformer.executor.ray.resource_manager.TpuAcceleratorConfig | eformer.executor.ray.resource_manager.GpuAcceleratorConfig | eformer.executor.ray.resource_manager.CpuAcceleratorConfig)
Decorator for multi-slice execution without retry.
Wraps a Ray remote function to automatically use RayExecutor.autoscale_execute with the specified accelerator configuration. Results from all slices are automatically retrieved with ray.get(). The function will be executed across multiple TPU slices in parallel, with MegaScale coordination automatically configured.
- Parameters
accelerator_config (AcceleratorConfigType) – Configuration for accelerator resources with multi-slice support. Must have pod_count > 1.
- Returns
- Decorator function that wraps the remote function and returns
a list of results, one from each slice.
- Return type
Callable
Note
The decorator handles slice actor creation, placement group setup, and MegaScale environment configuration automatically.
Example
>>> tpu_config = TpuAcceleratorConfig(type="v4-32", pod_count=4)
>>>
>>> @autoscale_execute(tpu_config)
... @ray.remote
... def parallel_compute(data_shard):
...     return compute_result(data_shard)
>>>
>>> results = parallel_compute(sharded_data)
- eformer.executor.ray.executor.autoscale_execute_resumable(accelerator_config: eformer.executor.ray.resource_manager.TpuAcceleratorConfig | eformer.executor.ray.resource_manager.GpuAcceleratorConfig | eformer.executor.ray.resource_manager.CpuAcceleratorConfig)
Decorator for fault-tolerant multi-slice execution.
Wraps a Ray remote function to automatically use RayExecutor.autoscale_execute_resumable with the specified accelerator configuration. Provides automatic retry on preemption or failure of any slice. Uses an all-or-nothing retry policy: if any slice fails, the entire multi-slice execution is retried.
- Parameters
accelerator_config (AcceleratorConfigType) – Configuration for accelerator resources with multi-slice support. Must have pod_count > 1.
- Returns
- Decorator function that wraps the remote function and adds
automatic retry logic for all slices.
- Return type
Callable
Note
Default retry limits are 1,000,000 for preemptions and 10 for failures. To customize these limits, use RayExecutor.autoscale_execute_resumable directly with max_retries_preemption and max_retries_failure parameters.
Example
>>> tpu_config = TpuAcceleratorConfig(type="v4-32", pod_count=4, preemptible=True)
>>>
>>> @autoscale_execute_resumable(tpu_config)
... @ray.remote
... def resilient_training(data_batch):
...     ...
...     return train_model(data_batch)
>>>
>>> results = resilient_training(training_data)
- eformer.executor.ray.executor.device_remote(*, accelerator_config: TpuAcceleratorConfig, flatten: bool = True)
Decorator for TPU-remote execution of functions or classes.
Transforms a regular Python function or class into a TPU-remote version that automatically broadcasts execution across all TPU hosts in the specified configuration. Each execution runs in an ephemeral worker process to avoid TPU handle leaks.
- Parameters
accelerator_config – TPU configuration specifying version, pod count, and other execution parameters.
flatten – If True (default), returns flat list of results from all hosts. If False, returns nested lists where outer list represents slices and inner lists contain results from hosts within each slice.
- Returns
Decorator function that wraps the target function or class.
- Example for functions:
>>> @device_remote(accelerator_config=tpu_config)
... def compute(x, y):
...     return jax.numpy.dot(x, y)
>>>
>>> results = compute(array1, array2)
>>>
>>> refs = compute.remote(array1, array2)
>>> results = ray.get(refs)
- Example for classes:
>>> @device_remote(accelerator_config=tpu_config)
... class Model:
...     def __init__(self, config):
...         self.config = config
...
...     def train(self, data):
...         ...
...         return metrics
>>>
>>> model = Model(config)
>>> metrics = model.train(data)
>>> refs = model.train.remote(data)
Note
Functions/classes must be pickleable for Ray serialization
Each method call creates new worker processes on TPU hosts
The decorator manages slice scaling and resource allocation
Call model.close() to clean up resources when done
- eformer.executor.ray.executor.execute(accelerator_config: eformer.executor.ray.resource_manager.TpuAcceleratorConfig | eformer.executor.ray.resource_manager.GpuAcceleratorConfig | eformer.executor.ray.resource_manager.CpuAcceleratorConfig)
Decorator for single-pod execution without retry.
Wraps a Ray remote function to automatically use RayExecutor.execute with the specified accelerator configuration. Results are automatically retrieved with ray.get(). This decorator is suitable for tasks that don’t require fault tolerance or where failures should be handled by the caller.
- Parameters
accelerator_config (AcceleratorConfigType) – Configuration for accelerator resources to use for execution. Should have pod_count=1 for single-pod execution.
- Returns
- Decorator function that wraps the remote function and
automatically retrieves results.
- Return type
Callable
Note
Unlike execute_resumable, this decorator does not retry on failure. Use this for quick tasks or when you want to handle failures yourself.
Example
>>> gpu_config = GpuAcceleratorConfig(count=2, type="a100")
>>>
>>> @execute(gpu_config)
... @ray.remote
... def gpu_task(tensor):
...     return tensor.cuda() * 2
>>>
>>> result = gpu_task(my_tensor)
- eformer.executor.ray.executor.execute_resumable(accelerator_config: eformer.executor.ray.resource_manager.TpuAcceleratorConfig | eformer.executor.ray.resource_manager.GpuAcceleratorConfig | eformer.executor.ray.resource_manager.CpuAcceleratorConfig)
Decorator for fault-tolerant single-pod execution.
Wraps a Ray remote function to automatically use RayExecutor.execute_resumable with the specified accelerator configuration. The decorated function will automatically retry on preemption or failure according to the default retry policies (1,000,000 retries for preemption, 10 for failures).
- Parameters
accelerator_config (AcceleratorConfigType) – Configuration for accelerator resources to use for execution. Should have pod_count=1 for single-pod execution.
- Returns
- Decorator function that wraps the remote function and adds
automatic retry logic.
- Return type
Callable
Note
To customize retry behavior, use RayExecutor.execute_resumable directly with max_retries_preemption and max_retries_failure parameters.
Example
>>> tpu_config = TpuAcceleratorConfig(type="v4-8")
>>>
>>> @execute_resumable(tpu_config)
... @ray.remote
... def my_task(data):
...     return process(data)
>>>
>>> result = my_task(input_data)
- eformer.executor.ray.executor.resolve_maybe_refs(items)
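The module-level decorators all share one shape: a factory that captures the accelerator configuration and returns a wrapper that forwards the call to an executor method. Below is a minimal sketch of that shape; `with_config`, `runner`, and `local_runner` are hypothetical names, and forwarding positional arguments is an assumption.

```python
import functools


def with_config(config, runner):
    """Decorator factory: bind `config` now, run the function via `runner` later.

    Hypothetical sketch; `runner(fn, config, *args, **kwargs)` stands in for
    a call such as RayExecutor.execute_resumable.
    """

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            return runner(fn, config, *args, **kwargs)

        return wrapper

    return decorator


def local_runner(fn, config, *args, **kwargs):
    # Stand-in executor that ignores the config and calls in-process.
    return fn(*args, **kwargs)


@with_config({"type": "v4-8"}, local_runner)
def double(x):
    return x * 2


print(double(x=21))
```

Since the retry limits are fixed inside the wrapper, a decorator of this shape cannot expose them per call, which is why the notes recommend calling the RayExecutor methods directly when custom limits are needed.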
Resolve Ray ObjectRefs to their values if present.
Checks if all items in the provided list are Ray ObjectRefs and resolves them using ray.get(). If the items are not ObjectRefs or the list is empty, returns them unchanged.
- Parameters
items – List of items that may be Ray ObjectRefs or regular values.
- Returns
List of resolved values if input contained ObjectRefs, otherwise the original items unchanged.
Example
>>> refs = [task.remote() for task in tasks]
>>> results = resolve_maybe_refs(refs)
>>>
>>> values = [1, 2, 3]
>>> results = resolve_maybe_refs(values)
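The described behavior (resolve only when the list is non-empty and every item is a reference) can be sketched without a Ray cluster. `FakeRef`, `fake_get`, and `resolve_maybe_refs_sketch` are stand-ins invented for this illustration; the real function checks for Ray ObjectRefs and calls ray.get().

```python
class FakeRef:
    """Stand-in for ray.ObjectRef so the sketch runs without Ray."""

    def __init__(self, value):
        self.value = value


def fake_get(refs):
    # Stand-in for ray.get applied to a list of references.
    return [ref.value for ref in refs]


def resolve_maybe_refs_sketch(items, ref_type=FakeRef, get=fake_get):
    """Resolve items only if the list is non-empty and all items are refs."""
    if items and all(isinstance(item, ref_type) for item in items):
        return get(items)
    return items


print(resolve_maybe_refs_sketch([FakeRef(1), FakeRef(2)]))
print(resolve_maybe_refs_sketch([1, 2, 3]))
```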