eformer.executor.ray.executor

Ray-based executor for distributed machine learning workloads.

This module provides the core execution framework for running distributed workloads on accelerators (TPUs, GPUs) using Ray. It supports single-pod, multi-slice, and fault-tolerant execution patterns with automatic retry mechanisms.

Key Features:
  • Single-pod and multi-slice execution on TPUs/GPUs

  • Automatic retry mechanisms for preemption and failures

  • Resource management and allocation via Ray

  • Support for both synchronous and asynchronous execution

  • Decorator-based API for easy integration

  • MegaScale coordination for multi-slice TPU workloads

  • Flexible result flattening for multi-host scenarios

Environment Variables:
  • EXECUTOR_CALL_INDEX: Set to worker index within a pod (0-based)

  • EXECUTOR_CALL_SLICE: Set to slice ID for multi-slice execution

  • COORD_PORT: Coordinator port for MegaScale (default: 8192)

  • TPU_NAME, TPU_VERSION, TPU_ZONE: TPU configuration passed to workers

  • MEGASCALE_* variables: Auto-configured for multi-slice coordination
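As a rough illustration of how a worker might read the variables listed above, the sketch below parses them from a mapping. The helper name `read_worker_identity`, the integer parsing, and the fallback to 0 for unset indices are assumptions made for this example, not documented library behavior.

```python
import os


def read_worker_identity(environ=None):
    """Return (slice_id, worker_index, coord_port) from the executor variables."""
    env = os.environ if environ is None else environ
    # Falling back to 0 when a variable is unset is an assumption of this sketch.
    worker_index = int(env.get("EXECUTOR_CALL_INDEX", "0"))
    slice_id = int(env.get("EXECUTOR_CALL_SLICE", "0"))
    # 8192 is the COORD_PORT default listed above.
    coord_port = int(env.get("COORD_PORT", "8192"))
    return slice_id, worker_index, coord_port


identity = read_worker_identity(
    {"EXECUTOR_CALL_INDEX": "3", "EXECUTOR_CALL_SLICE": "1"}
)
print(identity)  # (1, 3, 8192)
```

Passing a plain dict keeps the helper testable; calling it with no argument reads the real process environment.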

Example

Basic single-pod execution:

>>> import ray
>>> from eformer.executor.ray import RayExecutor, TpuAcceleratorConfig
>>>
>>> @ray.remote
... def train_model(data):
...
...     return trained_model
>>>
>>> tpu_config = TpuAcceleratorConfig(type="v4-8")
>>> result = RayExecutor.execute_resumable(
...     train_model,
...     tpu_config,
...     max_retries_preemption=10,
...     max_retries_failure=3
... )

Multi-slice execution with decorator:

>>> from eformer.executor.ray import autoscale_execute_resumable
>>>
>>> @autoscale_execute_resumable(tpu_config)
... @ray.remote
... def distributed_train(slice_data):
...
...     return slice_results
>>>
>>> results = distributed_train(training_data)

Classes:

RayExecutor: Core executor with static methods for various execution patterns

Functions:

execute: Decorator for single-pod execution without retry

execute_resumable: Decorator for single-pod execution with retry

autoscale_execute: Decorator for multi-slice execution without retry

autoscale_execute_resumable: Decorator for multi-slice execution with retry

class eformer.executor.ray.executor.RayExecutor

Bases: object

Core executor for Ray-based distributed workloads.

Provides static methods to execute Ray remote functions on various accelerators (TPUs, GPUs) with support for single-pod, multi-slice, and fault-tolerant execution patterns.

This class serves as the main interface for running distributed ML workloads with automatic resource allocation, retry mechanisms, and failure handling.

execute()

Single-pod execution without retry

autoscale_execute()

Multi-slice execution without retry

execute_resumable()

Single-pod execution with automatic retry

autoscale_execute_resumable()

Multi-slice execution with automatic retry

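execute and autoscale_execute return a JobStatus object; the resumable variants unwrap it and return the result directly, raising once retries are exhausted. A JobStatus is one of: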
  • JobSucceeded: Successful completion with results

  • JobFailed: Failure due to exceptions

  • JobPreempted: Preemption on preemptible resources

  • JobError: Unexpected errors
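A caller typically branches on the status type. The sketch below uses hypothetical stand-in dataclasses rather than the library's own classes: only the `result` attribute of JobSucceeded appears in this page's examples, and the `error` attribute on the other three is an assumption.

```python
from dataclasses import dataclass
from typing import Any


# Hypothetical stand-ins for the library's status types (not the real classes).
@dataclass
class JobSucceeded:
    result: Any


@dataclass
class JobPreempted:
    error: Exception  # attribute name assumed for this sketch


@dataclass
class JobFailed:
    error: Exception  # attribute name assumed for this sketch


@dataclass
class JobError:
    error: Exception  # attribute name assumed for this sketch


def unwrap(status):
    """Return the result on success, otherwise raise the carried error."""
    if isinstance(status, JobSucceeded):
        return status.result
    if isinstance(status, (JobPreempted, JobFailed, JobError)):
        raise status.error
    raise TypeError(f"unexpected status: {status!r}")


print(unwrap(JobSucceeded(result=[1, 2])))  # [1, 2]
```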

Note

All methods are static methods or class methods, so they can be called directly on the class. The class does not maintain state between executions.

static autoscale_execute(remote_fn: RemoteFunction, accelerator_config: eformer.executor.ray.resource_manager.TpuAcceleratorConfig | eformer.executor.ray.resource_manager.GpuAcceleratorConfig | eformer.executor.ray.resource_manager.CpuAcceleratorConfig, flatten: bool = True, **kwargs) → JobStatus

Execute a Ray remote function across multiple TPU slices.

Distributes execution of a remote function across multiple TPU slices for large-scale parallel processing. This method sets up the necessary infrastructure including slice actors, placement groups, and MegaScale coordination environment variables.

Parameters
  • remote_fn (RemoteFunction) – The Ray remote function to execute on each slice. Must be decorated with @ray.remote.

  • accelerator_config (AcceleratorConfigType) – Configuration for accelerator resources, must include multi-slice details (pod_count > 1).

  • flatten (bool) – If True (default), returns a flat list of results from all hosts across all slices. If False, returns nested lists where outer list represents slices and inner lists contain results from hosts within each slice.

  • **kwargs – Additional keyword arguments passed to the remote function on each slice.
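The difference between the two result shapes controlled by flatten can be pictured with plain lists. This is only an illustration: the helper `flatten_results` is hypothetical, and the slice-major ordering of the flat list is an assumption.

```python
def flatten_results(results_by_slice):
    """Concatenate per-slice host results into one flat list (slice-major order)."""
    return [
        host_result
        for slice_results in results_by_slice
        for host_result in slice_results
    ]


# Nested shape (flatten=False): outer list = slices, inner lists = hosts.
nested = [["s0-h0", "s0-h1"], ["s1-h0", "s1-h1"]]

# Flat shape (flatten=True): one entry per host across all slices.
flat = flatten_results(nested)
print(flat)  # ['s0-h0', 's0-h1', 's1-h0', 's1-h1']
```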

Returns

A single JobStatus object containing results from all slices.
  • JobSucceeded: Contains results list (flat or nested based on flatten)

  • JobFailed: Contains the exception that caused the failure

  • JobPreempted: Contains preemption error details

  • JobError: Contains unexpected error information

Return type

JobStatus

Raises
  • InsufficientSlicesError – If requested number of slices cannot be allocated.

  • RayError – If slice actor creation fails, coordinator IP cannot be determined, or remote function calls fail.

  • RuntimeError – If no SliceActors available after scaling or coordinator IP cannot be determined.

Note

  • The method automatically sets up MegaScale environment variables for multi-slice coordination including coordinator address, slice IDs, and port configuration.

  • Each slice gets its own SliceActor which manages multiple DeviceHostActors.

  • The pool manager is automatically drained after execution completes or if an error occurs.

  • Environment variables set include: MEGASCALE_COORDINATOR_ADDRESS, MEGASCALE_NUM_SLICES, MEGASCALE_PORT, MEGASCALE_SLICE_ID, TPU_SLICE_NAME, and more.
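To make the per-slice environment concrete, here is a minimal sketch that builds one dictionary per slice using the variable names listed above. The helper `build_megascale_env`, the `ip:port` format of the coordinator address, and the string encoding of the values are assumptions, not the library's implementation.

```python
def build_megascale_env(coordinator_ip, num_slices, port=8192):
    """Build one environment dict per slice for multi-slice coordination."""
    envs = []
    for slice_id in range(num_slices):
        envs.append(
            {
                # The "ip:port" format is an assumption of this sketch.
                "MEGASCALE_COORDINATOR_ADDRESS": f"{coordinator_ip}:{port}",
                "MEGASCALE_NUM_SLICES": str(num_slices),
                "MEGASCALE_PORT": str(port),
                # Only the slice ID differs between slices.
                "MEGASCALE_SLICE_ID": str(slice_id),
            }
        )
    return envs


envs = build_megascale_env("10.0.0.5", num_slices=2)
print(envs[1]["MEGASCALE_SLICE_ID"])  # 1
```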

Example

>>> @ray.remote
... def train_on_slice(data, slice_id):
...
...     return model_weights
>>>
>>> tpu_config = TpuAcceleratorConfig(type="v4-32", pod_count=4)
>>>
>>>
>>> job_status = RayExecutor.autoscale_execute(
...     train_on_slice,
...     tpu_config,
...     data=training_data
... )
>>> if isinstance(job_status, JobSucceeded):
...     flat_results = job_status.result
>>>
>>>
>>> job_status = RayExecutor.autoscale_execute(
...     train_on_slice,
...     tpu_config,
...     flatten=False,
...     data=training_data
... )
>>> if isinstance(job_status, JobSucceeded):
...     results_by_slice = job_status.result

classmethod autoscale_execute_resumable(remote_fn: RemoteFunction, accelerator_config: eformer.executor.ray.resource_manager.TpuAcceleratorConfig | eformer.executor.ray.resource_manager.GpuAcceleratorConfig | eformer.executor.ray.resource_manager.CpuAcceleratorConfig, max_retries_preemption: int = 1000000, max_retries_failure: int = 10, **kwargs)

Execute a multi-slice function with automatic retry on failures.

Provides fault-tolerant execution of Ray remote functions across multiple TPU slices with coordinated retry mechanisms. All slices must succeed for the execution to be considered successful.

Parameters
  • remote_fn (RemoteFunction) – The Ray remote function to execute on each slice. Must be decorated with @ray.remote.

  • accelerator_config (AcceleratorConfigType) – Configuration for accelerator resources with multi-slice support (pod_count > 1).

  • max_retries_preemption (int) – Maximum number of retries when any slice is preempted. Defaults to 1,000,000.

  • max_retries_failure (int) – Maximum number of retries when any slice fails. Defaults to 10.

  • **kwargs – Additional keyword arguments passed to the remote function on each slice. The ‘flatten’ parameter can be used to control result structure.

Returns

List of results from successful execution on all slices.

The structure depends on the flatten parameter passed in kwargs:
  • If flatten=True (default): Flat list of all results

  • If flatten=False: List of lists, one per slice

Return type

list[Any]

Raises
  • RuntimeError – If any slice is preempted more than max_retries_preemption times, fails more than max_retries_failure times, or if autoscale_execute returns None or unexpected result type.

  • RayError – If autoscale_execute fails during setup or coordination (slice actor creation, placement group setup, etc.).

  • ray.exceptions.RayTaskError – Re-raised if it occurs and indicates preemption or failure after max retries.

  • Exception – The last encountered exception if retries are exhausted.

Note

  • Implements an all-or-nothing retry policy: if any slice fails or is preempted, the entire multi-slice execution is retried

  • Different error types are handled with appropriate retry logic: a RayTaskError or RayError containing “preempted” increments the preemption counter, and any other error increments the failure counter

  • Each retry attempt creates new slice actors and placement groups

  • Detailed logging tracks retry attempts and error types
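The all-or-nothing policy can be sketched as a function that classifies one multi-slice attempt from its per-slice outcomes. Everything here is illustrative: `classify_attempt` is a hypothetical helper, the case-insensitive match on "preempted" is an assumption, and so is the choice to let a preemption take precedence when slices fail in different ways.

```python
def classify_attempt(slice_errors):
    """Classify one attempt from per-slice errors (None means the slice succeeded)."""
    errors = [err for err in slice_errors if err is not None]
    if not errors:
        # Every slice succeeded, so the attempt as a whole succeeded.
        return "succeeded"
    # Assumption: any preempted slice marks the whole attempt as preempted.
    if any("preempted" in str(err).lower() for err in errors):
        return "preempted"
    # Any other error on any slice fails the whole attempt.
    return "failed"


print(classify_attempt([None, RuntimeError("TPU node was preempted")]))  # preempted
```

A "preempted" or "failed" outcome would then trigger a retry of all slices, not only the ones that reported an error.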

Example

>>> @ray.remote
... def distributed_training(data_shard):
...
...     return trained_weights
>>>
>>> tpu_config = TpuAcceleratorConfig(type="v4-32", pod_count=4)
>>>
>>>
>>> results = RayExecutor.autoscale_execute_resumable(
...     distributed_training,
...     tpu_config,
...     max_retries_preemption=50,
...     max_retries_failure=3,
...     data_shard=sharded_data
... )
>>>
>>>
>>> results_by_slice = RayExecutor.autoscale_execute_resumable(
...     distributed_training,
...     tpu_config,
...     max_retries_preemption=50,
...     max_retries_failure=3,
...     data_shard=sharded_data,
...     flatten=False
... )

static execute(remote_fn: RemoteFunction, accelerator_config: eformer.executor.ray.resource_manager.TpuAcceleratorConfig | eformer.executor.ray.resource_manager.GpuAcceleratorConfig | eformer.executor.ray.resource_manager.CpuAcceleratorConfig, **kwargs)

Execute a Ray remote function on a single pod or slice.

Runs a Ray remote function on a single accelerator pod (TPU/GPU) with the specified resource configuration. For multi-slice TPU workloads, use autoscale_execute instead.

Parameters
  • remote_fn (RemoteFunction) – The Ray remote function to execute. Must be decorated with @ray.remote.

  • accelerator_config (AcceleratorConfigType) – Configuration for accelerator resources (TPU, GPU, or CPU).

  • **kwargs – Additional keyword arguments passed to the remote function.

Returns

A JobStatus object wrapping the outcome of the execution; on success it holds the actual result.

Return type

JobStatus

Raises

ValueError – If pod_count in accelerator_config is not 1, indicating that autoscale_execute should be used instead.

Example

>>> @ray.remote
... def compute(x):
...     return x * 2
>>>
>>> config = GpuAcceleratorConfig(count=1, type="v100")
>>> result = RayExecutor.execute(compute, config, x=10)

classmethod execute_resumable(remote_fn: RemoteFunction, accelerator_config: eformer.executor.ray.resource_manager.TpuAcceleratorConfig | eformer.executor.ray.resource_manager.GpuAcceleratorConfig | eformer.executor.ray.resource_manager.CpuAcceleratorConfig, max_retries_preemption: int = 1000000, max_retries_failure: int = 10, **kwargs)

Execute a remote function with automatic retry on failures.

Provides fault-tolerant execution of Ray remote functions with configurable retry policies for both preemptions and failures. Particularly useful for long-running jobs on preemptible resources.

Parameters
  • remote_fn (RemoteFunction) – The Ray remote function to execute. Must be decorated with @ray.remote.

  • accelerator_config (AcceleratorConfigType) – Configuration for accelerator resources.

  • max_retries_preemption (int) – Maximum number of retries on preemption. Defaults to 1,000,000 (effectively unlimited).

  • max_retries_failure (int) – Maximum number of retries on failure. Defaults to 10.

  • **kwargs – Additional keyword arguments passed to the remote function.

Returns

The result from successful execution of the remote function.

The actual return type depends on what the remote function returns.

Return type

Any

Raises
  • RuntimeError – If the job is preempted more than max_retries_preemption times or fails more than max_retries_failure times. The error message indicates whether it was due to preemptions or failures.

  • ray.exceptions.RayTaskError – Re-raised if it occurs and is not preemption-related after max retries.

  • Exception – The last encountered exception if all retries are exhausted.

Note

  • Preemptions and failures are tracked separately

  • Each attempt logs status information for debugging

  • The method distinguishes between preemption (often recoverable) and failures (may indicate code issues)

  • RayTaskErrors containing “preempted” are treated as preemptions
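The separate tracking of preemptions and failures can be sketched as a retry loop with two independent budgets. This is not the library's code: `Preempted` is a hypothetical marker exception and `run_with_retries` is an illustrative helper; only the parameter names and defaults come from the signature above.

```python
class Preempted(Exception):
    """Hypothetical marker for an attempt that was preempted."""


def run_with_retries(attempt, max_retries_preemption=1_000_000, max_retries_failure=10):
    """Call `attempt` until it returns, counting preemptions and failures apart."""
    preemptions = 0
    failures = 0
    while True:
        try:
            return attempt()
        except Preempted as exc:
            preemptions += 1
            if preemptions > max_retries_preemption:
                raise RuntimeError("preempted too many times") from exc
        except Exception as exc:
            failures += 1
            if failures > max_retries_failure:
                raise RuntimeError("failed too many times") from exc


calls = {"n": 0}


def flaky():
    """Preempted twice, fails once, then succeeds."""
    calls["n"] += 1
    if calls["n"] <= 2:
        raise Preempted("preempted")
    if calls["n"] == 3:
        raise ValueError("transient failure")
    return "done"


print(run_with_retries(flaky, max_retries_preemption=5, max_retries_failure=3))  # done
```

Keeping the two counters apart lets a job survive many preemptions without masking a genuine bug, which exhausts the much smaller failure budget quickly.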

Example

>>> @ray.remote
... def long_running_task(data):
...
...     return process(data)
>>>
>>> config = TpuAcceleratorConfig(type="v4-8", preemptible=True)
>>> result = RayExecutor.execute_resumable(
...     long_running_task,
...     config,
...     max_retries_preemption=100,
...     max_retries_failure=5,
...     data=my_data
... )

class eformer.executor.ray.executor.TpuRemoteManager(tpu_version: str, pod_count: int, *, base_env: dict[str, str] | None = None, runtime_env: dict | None = None, coord_port: int = 8081)

Bases: object

Session-based manager for TPU remote execution across multiple slices.

Provides a high-level interface for managing TPU resources and executing functions or class methods across multiple TPU slices. Each execution runs in a short-lived worker process to avoid TPU handle leaks.

Key features:
  • Automatic slice scaling and warming

  • Persistent DeviceHostActors for efficient execution

  • Broadcasting of function/method calls across all hosts

  • Ephemeral worker processes for each execution

  • MegaScale coordination for multi-slice workloads

tpu_version

TPU version string (e.g., ‘v4’, ‘v5p’)

pod_count

Number of TPU pods/slices to use

base_env

Base environment variables for all workers

runtime_env

Ray runtime environment configuration

coord_port

Port for MegaScale coordinator (default 8081)

Example

>>> manager = TpuRemoteManager(
...     tpu_version='v4',
...     pod_count=2,
...     base_env={'MY_VAR': 'value'}
... )
>>> manager.ensure()
>>> results = manager.run_function(my_func, arg1, arg2)
>>> manager.close()

close()

Clean up all TPU resources and reset state.

Performs cleanup in the following order:
  1. Cancels any current work on all host actors

  2. Drains the actor pool (terminates all actors)

  3. Resets internal state

Note

This method is safe to call multiple times and will suppress any errors during cleanup to ensure resources are freed.

ensure()

Initialize and prepare TPU resources if not already done.

This method:
  1. Scales the slice pool to the requested pod_count

  2. Prepares all slices (creates placement groups, etc.)

  3. Ensures DeviceHostActors are created on each host

  4. Builds per-slice environment variables including MegaScale config

  5. Caches host actor handles for efficient execution

Raises

RuntimeError – If no SliceActors are available after scaling.

Note

This method is idempotent - calling it multiple times has no additional effect after the first successful call.
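Because ensure() is idempotent and close() is safe to call repeatedly, the pair fits naturally into a context manager that guarantees cleanup. The helper `managed` below is a hypothetical convenience, not part of the library, and `FakeManager` is a stand-in used only so the sketch runs without TPU resources.

```python
from contextlib import contextmanager


@contextmanager
def managed(manager):
    """Call ensure() on entry and always call close() on exit."""
    try:
        manager.ensure()
        yield manager
    finally:
        # close() is documented as safe to call multiple times.
        manager.close()


class FakeManager:
    """Stand-in that records calls; not the real TpuRemoteManager."""

    def __init__(self):
        self.calls = []

    def ensure(self):
        self.calls.append("ensure")

    def close(self):
        self.calls.append("close")

    def run_function(self, fn, *args):
        self.calls.append("run")
        return [fn(*args)]


fake = FakeManager()
with managed(fake) as m:
    results = m.run_function(lambda x: x + 1, 1)
print(results, fake.calls)  # [2] ['ensure', 'run', 'close']
```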

run_class_method(cls_obj: type, init_args: tuple, init_kwargs: dict, method_name: str, *call_args, flatten: bool = True, **call_kwargs)

Execute a class method on every TPU host across all slices.

Instantiates the class and calls the specified method in ephemeral worker processes on each host. The class is reconstructed inside each worker to ensure clean TPU state.

Parameters
  • cls_obj – The class object (not an instance). Must be pickleable.

  • init_args – Positional arguments for class initialization.

  • init_kwargs – Keyword arguments for class initialization.

  • method_name – Name of the method to call on the instantiated object.

  • *call_args – Positional arguments for the method call.

  • flatten – If True, returns a flat list of results from all hosts. If False, returns nested lists by slice.

  • **call_kwargs – Keyword arguments for the method call.

Returns

If flatten=True: a flat list of results from all hosts. If flatten=False: a list of lists, where each inner list contains results from hosts within a slice.

Return type

list

Example

>>> results = manager.run_class_method(
...     MyModel,
...     init_args=(config,),
...     init_kwargs={'checkpoint': 'path/to/ckpt'},
...     method_name='train',
...     epochs=10,
...     batch_size=32
... )

Note

The class is instantiated fresh in each worker process, avoiding TPU handle leaks from persistent objects.

run_function(fn: Callable, *args, flatten: bool = True, **kwargs)

Execute a Python function on every TPU host across all slices.

Runs the provided function in ephemeral worker processes on each host to avoid TPU handle leaks. The function is executed with the same arguments on every host.

Parameters
  • fn – Python callable to execute on each host. Should be pickleable.

  • *args – Positional arguments to pass to the function.

  • flatten – If True, returns a flat list of results from all hosts. If False, returns nested lists where outer list represents slices and inner lists contain results from hosts within each slice.

  • **kwargs – Keyword arguments to pass to the function.

Returns

If flatten=True: a flat list of results from all hosts. If flatten=False: a list of lists, where each inner list contains results from hosts within a slice.

Return type

list

Note

Each execution runs in a new process to avoid TPU handle leaks. The function must be pickleable for Ray serialization.

eformer.executor.ray.executor.autoscale_execute(accelerator_config: eformer.executor.ray.resource_manager.TpuAcceleratorConfig | eformer.executor.ray.resource_manager.GpuAcceleratorConfig | eformer.executor.ray.resource_manager.CpuAcceleratorConfig)

Decorator for multi-slice execution without retry.

Wraps a Ray remote function to automatically use RayExecutor.autoscale_execute with the specified accelerator configuration. Results from all slices are automatically retrieved with ray.get(). The function will be executed across multiple TPU slices in parallel, with MegaScale coordination automatically configured.

Parameters

accelerator_config (AcceleratorConfigType) – Configuration for accelerator resources with multi-slice support. Must have pod_count > 1.

Returns

Decorator function that wraps the remote function and returns a list of results, one from each slice.

Return type

Callable

Note

The decorator handles slice actor creation, placement group setup, and MegaScale environment configuration automatically.

Example

>>> tpu_config = TpuAcceleratorConfig(type="v4-32", pod_count=4)
>>>
>>> @autoscale_execute(tpu_config)
... @ray.remote
... def parallel_compute(data_shard):
...     return compute_result(data_shard)
>>>
>>> results = parallel_compute(sharded_data)

eformer.executor.ray.executor.autoscale_execute_resumable(accelerator_config: eformer.executor.ray.resource_manager.TpuAcceleratorConfig | eformer.executor.ray.resource_manager.GpuAcceleratorConfig | eformer.executor.ray.resource_manager.CpuAcceleratorConfig)

Decorator for fault-tolerant multi-slice execution.

Wraps a Ray remote function to automatically use RayExecutor.autoscale_execute_resumable with the specified accelerator configuration. Provides automatic retry on preemption or failure of any slice. Uses an all-or-nothing retry policy: if any slice fails, the entire multi-slice execution is retried.

Parameters

accelerator_config (AcceleratorConfigType) – Configuration for accelerator resources with multi-slice support. Must have pod_count > 1.

Returns

Decorator function that wraps the remote function and adds automatic retry logic for all slices.

Return type

Callable

Note

Default retry limits are 1,000,000 for preemptions and 10 for failures. To customize these limits, use RayExecutor.autoscale_execute_resumable directly with max_retries_preemption and max_retries_failure parameters.

Example

>>> tpu_config = TpuAcceleratorConfig(type="v4-32", pod_count=4, preemptible=True)
>>>
>>> @autoscale_execute_resumable(tpu_config)
... @ray.remote
... def resilient_training(data_batch):
...
...     return train_model(data_batch)
>>>
>>> results = resilient_training(training_data)

eformer.executor.ray.executor.device_remote(*, accelerator_config: TpuAcceleratorConfig, flatten: bool = True)

Decorator for TPU-remote execution of functions or classes.

Transforms a regular Python function or class into a TPU-remote version that automatically broadcasts execution across all TPU hosts in the specified configuration. Each execution runs in an ephemeral worker process to avoid TPU handle leaks.

Parameters
  • accelerator_config – TPU configuration specifying version, pod count, and other execution parameters.

  • flatten – If True (default), returns flat list of results from all hosts. If False, returns nested lists where outer list represents slices and inner lists contain results from hosts within each slice.

Returns

Decorator function that wraps the target function or class.

Example for functions:
>>> @device_remote(accelerator_config=tpu_config)
... def compute(x, y):
...     return jax.numpy.dot(x, y)
>>>
>>>
>>> results = compute(array1, array2)
>>>
>>>
>>> refs = compute.remote(array1, array2)
>>> results = ray.get(refs)

Example for classes:
>>> @device_remote(accelerator_config=tpu_config)
... class Model:
...     def __init__(self, config):
...         self.config = config
...
...     def train(self, data):
...
...         return metrics
>>>
>>> model = Model(config)
>>> metrics = model.train(data)
>>> refs = model.train.remote(data)

Note

  • Functions/classes must be pickleable for Ray serialization

  • Each method call creates new worker processes on TPU hosts

  • The decorator manages slice scaling and resource allocation

  • Call model.close() to clean up resources when done

eformer.executor.ray.executor.execute(accelerator_config: eformer.executor.ray.resource_manager.TpuAcceleratorConfig | eformer.executor.ray.resource_manager.GpuAcceleratorConfig | eformer.executor.ray.resource_manager.CpuAcceleratorConfig)

Decorator for single-pod execution without retry.

Wraps a Ray remote function to automatically use RayExecutor.execute with the specified accelerator configuration. Results are automatically retrieved with ray.get(). This decorator is suitable for tasks that don’t require fault tolerance or where failures should be handled by the caller.

Parameters

accelerator_config (AcceleratorConfigType) – Configuration for accelerator resources to use for execution. Should have pod_count=1 for single-pod execution.

Returns

Decorator function that wraps the remote function and automatically retrieves results.

Return type

Callable

Note

Unlike execute_resumable, this decorator does not retry on failure. Use this for quick tasks or when you want to handle failures yourself.
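The general shape of such a decorator is a factory that binds the accelerator configuration and forwards each call to an executor. The sketch below is a generic illustration under stated assumptions: `execute_with` and `fake_execute` are hypothetical names, the stand-in executor simply calls the function locally, and forwarding positional arguments is an assumption about how the real decorator behaves.

```python
import functools


def execute_with(executor, accelerator_config):
    """Decorator factory: bind a config and route calls through `executor`."""

    def decorator(remote_fn):
        @functools.wraps(remote_fn)
        def wrapper(*args, **kwargs):
            # Every call goes through the executor with the bound config.
            return executor(remote_fn, accelerator_config, *args, **kwargs)

        return wrapper

    return decorator


def fake_execute(fn, config, *args, **kwargs):
    """Stand-in executor that runs the function locally and ignores the config."""
    return fn(*args, **kwargs)


@execute_with(fake_execute, {"type": "cpu"})
def double(x):
    return x * 2


print(double(21))  # 42
```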

Example

>>> gpu_config = GpuAcceleratorConfig(count=2, type="a100")
>>>
>>> @execute(gpu_config)
... @ray.remote
... def gpu_task(tensor):
...     return tensor.cuda() * 2
>>>
>>> result = gpu_task(my_tensor)

eformer.executor.ray.executor.execute_resumable(accelerator_config: eformer.executor.ray.resource_manager.TpuAcceleratorConfig | eformer.executor.ray.resource_manager.GpuAcceleratorConfig | eformer.executor.ray.resource_manager.CpuAcceleratorConfig)

Decorator for fault-tolerant single-pod execution.

Wraps a Ray remote function to automatically use RayExecutor.execute_resumable with the specified accelerator configuration. The decorated function will automatically retry on preemption or failure according to the default retry policies (1,000,000 retries for preemption, 10 for failures).

Parameters

accelerator_config (AcceleratorConfigType) – Configuration for accelerator resources to use for execution. Should have pod_count=1 for single-pod execution.

Returns

Decorator function that wraps the remote function and adds automatic retry logic.

Return type

Callable

Note

To customize retry behavior, use RayExecutor.execute_resumable directly with max_retries_preemption and max_retries_failure parameters.

Example

>>> tpu_config = TpuAcceleratorConfig(type="v4-8")
>>>
>>> @execute_resumable(tpu_config)
... @ray.remote
... def my_task(data):
...     return process(data)
>>>
>>> result = my_task(input_data)

eformer.executor.ray.executor.resolve_maybe_refs(items)

Resolve Ray ObjectRefs to their values if present.

Checks if all items in the provided list are Ray ObjectRefs and resolves them using ray.get(). If the items are not ObjectRefs or the list is empty, returns them unchanged.

Parameters

items – List of items that may be Ray ObjectRefs or regular values.

Returns

List of resolved values if input contained ObjectRefs, otherwise the original items unchanged.

Example

>>> refs = [task.remote() for task in tasks]
>>> results = resolve_maybe_refs(refs)
>>>
>>> values = [1, 2, 3]
>>> results = resolve_maybe_refs(values)
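The described behavior (resolve only when every item is a reference, otherwise pass through) can be sketched without Ray by injecting the reference check and the getter. `resolve_if_all_refs`, `FakeRef`, and `fake_get` are hypothetical names for this example; with Ray installed, the check would presumably be `isinstance(item, ray.ObjectRef)` and the getter `ray.get`.

```python
def resolve_if_all_refs(items, is_ref, get):
    """Resolve `items` with `get` only if the list is non-empty and all are refs."""
    if items and all(is_ref(item) for item in items):
        return get(items)
    # Empty lists, plain values, and mixed lists are returned unchanged.
    return items


class FakeRef:
    """Stand-in for an object reference; holds its value directly."""

    def __init__(self, value):
        self.value = value


def fake_get(refs):
    """Stand-in getter that unwraps every FakeRef."""
    return [ref.value for ref in refs]


def is_fake_ref(item):
    return isinstance(item, FakeRef)


print(resolve_if_all_refs([FakeRef(1), FakeRef(2)], is_fake_ref, fake_get))  # [1, 2]
print(resolve_if_all_refs([1, 2, 3], is_fake_ref, fake_get))  # [1, 2, 3]
```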