# Source code for eformer.executor.ray.pool_manager

# Copyright 2026 The EasyDeL/eFormer Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Resource pool management for distributed Ray actors.

This module provides comprehensive abstractions for managing pools of Ray actors,
with specialized focus on TPU/GPU slice management for distributed computing.
It includes health monitoring, automatic scaling, resource lifecycle management,
and placement group coordination for optimal resource allocation.

Key Components:
    - **ActorPoolMember**: Wrapper for actor handles with metadata
    - **ResourcePoolManager**: Abstract base for managing actor pools
    - **SlicePoolManager**: Specialized manager for TPU/GPU slices with placement groups
    - **SliceActor**: Ray actor for managing individual compute slices
    - **DeviceHostActor**: Ray actor for managing individual TPU hosts within slices

Resource Management Features:
    - Placement group coordination with STRICT_SPREAD strategy
    - Automatic resource request handling through Ray autoscaler
    - Health monitoring with graceful shutdown sequences
    - Robust error handling with actor restart capabilities
    - Slot-based actor allocation for deterministic placement

Environment Variables:
    - **EFORMER_SCALE_POLL_S**: Scaling operation polling interval (default: "30")
    - **EFORMER_SCALE_ADD_TIMEOUT_S**: Timeout for adding new actors (default: "604800")

Example:
    Managing a multi-slice TPU configuration with placement groups:

    >>> from eformer.executor.ray import SlicePoolManager
    >>>
    >>>
    >>> manager = SlicePoolManager(tpu_type="v4-8")
    >>> manager.scale_multislice(num_slices=4)
    >>> actors = manager.get_all_actors_in_pool()
    >>>
    >>>
    >>> manager.prepare_all_slices()
    >>> manager.drain_actor_pool()
"""

from __future__ import annotations

import logging
import os
import time
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Generic, TypeVar
from uuid import uuid4

import ray
import requests
from ray.actor import ActorHandle
from ray.autoscaler.sdk import request_resources
from ray.util.placement_group import placement_group, remove_placement_group
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy

from .types import HostInfo, SliceInfo

# Log through Ray's own logger so messages show up alongside Ray's output.
logger = logging.getLogger("ray")

# Seconds to wait for every actor in a pool to answer a health probe.
HEALTH_CHECK_TIMEOUT_S = 60
# Seconds to wait for newly created actors to report their info (4 hours).
SLICE_ACTOR_START_TIMEOUT_S = 4 * 60 * 60
# Both values below are read once, at import time, from the environment.
SCALE_POLL_S = int(os.environ.get("EFORMER_SCALE_POLL_S", "30"))
SCALE_ADD_TIMEOUT_S = int(os.environ.get("EFORMER_SCALE_ADD_TIMEOUT_S", "604800"))
# Type of the per-actor metadata object carried by a pool member.
ActorInfoT = TypeVar("ActorInfoT")


class InsufficientSlicesError(RuntimeError):
    """Raised when the requested number of TPU slices cannot be allocated.

    This exception is raised by SlicePoolManager.scale_multislice when none of
    the requested slice counts can be satisfied, typically due to:

    - Insufficient TPU resources in the cluster
    - Preemption of TPU nodes during scaling
    - Ray autoscaler unable to provision required nodes

    The exception message includes details about requested vs available slices.

    Example:
        >>> manager = SlicePoolManager(tpu_type="v4-32")
        >>> try:
        ...     manager.scale_multislice([4, 8])
        ... except InsufficientSlicesError as e:
        ...     print(f"Could not allocate TPU slices: {e}")
        ...
    """
@dataclass(frozen=True)
class ActorPoolMember(Generic[ActorInfoT]):
    """Immutable pairing of a Ray actor handle with its metadata.

    Attributes:
        actor: Ray actor handle for remote execution.
        actor_info: Metadata about the actor (type depends on ActorInfoT).
    """

    actor: ActorHandle
    actor_info: ActorInfoT
class ResourcePoolManager(Generic[ActorInfoT]):
    """Abstract base class for managing pools of Ray actors.

    Provides common functionality for scaling, health monitoring, and
    lifecycle management of actor pools. Subclasses should implement
    create_actor() to define how actors are created.

    Actors in the pool are expected to expose ``healthy``, ``get_info`` and
    ``shutdown`` remote methods, since those are the calls issued here.

    Attributes:
        _actor_pool: List of active actor pool members.
    """

    def __init__(self) -> None:
        """Initialize an empty actor pool."""
        self._actor_pool: list[ActorPoolMember[ActorInfoT]] = []

    def get_all_actors_in_pool(self) -> list[ActorHandle]:
        """Get all actor handles in the pool.

        Returns:
            List of Ray actor handles.
        """
        return [m.actor for m in self._actor_pool]

    def get_all_pool_members(self) -> list[ActorPoolMember[ActorInfoT]]:
        """Get a copy of all pool members with their metadata.

        Returns:
            List of ActorPoolMember objects containing actors and their info.
        """
        return self._actor_pool.copy()

    def get_actor_pool_name(self) -> str:
        """Get a human-readable name for this actor pool.

        Returns:
            String identifier for the pool, defaults to class name.
        """
        return self.__class__.__name__

    def get_actor_name_from_actor_info(self, actor_info: ActorInfoT) -> str:
        """Generate a human-readable name from actor info.

        Args:
            actor_info: Metadata about the actor.

        Returns:
            String representation of the actor for logging.
        """
        return str(actor_info)

    def create_actor(self) -> ActorHandle:
        """Create a new actor instance.

        Must be implemented by subclasses to define actor creation logic.

        Returns:
            Ray actor handle for the newly created actor.

        Raises:
            NotImplementedError: If not overridden by subclass.
        """
        raise NotImplementedError

    def _kill_actor_best_effort(self, actor: ActorHandle) -> None:
        """Kill ``actor`` without restart, never raising.

        Cleanup of an actor that is already considered lost must not abort
        the surrounding scaling operation, so any error is only logged.

        Args:
            actor: Handle of the actor to kill.
        """
        try:
            ray.kill(actor, no_restart=True)
        except Exception as e:
            logger.debug(f"Ignoring error while killing actor: {e}")

    def _remove_unhealthy_members_from_actor_pool(self) -> None:
        """Remove unhealthy actors from the pool.

        Performs health checks on all actors and removes those that are
        unresponsive, dead, or unhealthy. Attempts to kill removed actors.
        """
        if not self._actor_pool:
            return

        # Pair members with their probe refs in a list rather than a dict keyed
        # by member: ActorPoolMember hashes over its fields, so a dict would
        # fail on unhashable actor_info and merge members that compare equal.
        checks = [(member, member.actor.healthy.remote()) for member in self._actor_pool]
        refs = [ref for _, ref in checks]
        done, _ = ray.wait(refs, num_returns=len(refs), timeout=HEALTH_CHECK_TIMEOUT_S)
        done_set = set(done)

        healthy: list[ActorPoolMember[ActorInfoT]] = []
        for member, ref in checks:
            name = self.get_actor_name_from_actor_info(member.actor_info)

            if ref not in done_set:
                logger.warning(f"Actor {name} health timeout; killing")
                self._kill_actor_best_effort(member.actor)
                continue

            try:
                # The ref is already resolved, so this returns immediately or
                # re-raises the actor-side failure.
                is_healthy = ray.get(ref, timeout=0)
            except Exception as e:
                logger.warning(f"Actor {name} health check exception ({e}); killing")
                self._kill_actor_best_effort(member.actor)
                continue

            if is_healthy:
                healthy.append(member)
            else:
                logger.warning(f"Actor {name} reported unhealthy; killing")
                self._kill_actor_best_effort(member.actor)

        self._actor_pool = healthy

    def _add_members_to_actor_pool(self, desired_num_actors: int) -> None:
        """Add new actors to the pool to reach desired size.

        Creates new actors asynchronously and waits for them to start.
        Actors that fail to start within the timeout are killed.

        Args:
            desired_num_actors: Target number of actors in the pool.
        """
        current = len(self._actor_pool)
        if current >= desired_num_actors:
            return

        num_to_add = desired_num_actors - current
        logger.info(f"Scaling up pool {self.get_actor_pool_name()} from {current} to {desired_num_actors}")

        actors = [self.create_actor() for _ in range(num_to_add)]
        awaitables = [(actor, actor.get_info.remote()) for actor in actors]

        logger.info(f"Waiting up to {SLICE_ACTOR_START_TIMEOUT_S}s for {num_to_add} slice actors to start...")
        ray.wait(
            [info_ref for _, info_ref in awaitables],
            num_returns=len(awaitables),
            timeout=SLICE_ACTOR_START_TIMEOUT_S,
        )

        started = 0
        for actor, info_ref in awaitables:
            # Only the fetch is guarded: a failure after the append must not
            # kill an actor that has already been placed in the pool.
            try:
                info = ray.get(info_ref, timeout=0)
            except Exception as e:
                logger.warning(f"SliceActor failed to start in time: {e}; killing actor")
                self._kill_actor_best_effort(actor)
                continue
            self._actor_pool.append(ActorPoolMember(actor, info))
            started += 1
            logger.info(f"Added actor {self.get_actor_name_from_actor_info(info)}")

        logger.info(f"Started {started}/{num_to_add} slice actors")

    def _remove_members_from_actor_pool(self, desired_num_actors: int) -> None:
        """Remove actors to reach the desired pool size.

        Members are removed from the end of the pool. Each one gets a short
        graceful ``shutdown`` attempt before being killed.

        Args:
            desired_num_actors: Target number of actors in the pool. Negative
                values are treated as 0.
        """
        # Clamp so a negative target cannot pop from an empty list.
        target = max(desired_num_actors, 0)
        while len(self._actor_pool) > target:
            member = self._actor_pool.pop()
            name = self.get_actor_name_from_actor_info(member.actor_info)
            try:
                try:
                    ray.get(member.actor.shutdown.remote(), timeout=5)
                except Exception:
                    # Graceful shutdown is optional; the kill below is decisive.
                    pass
                ray.kill(member.actor, no_restart=True)
                logger.info(f"Removed actor {name}")
            except Exception as e:
                logger.error(f"Failed to kill actor {name}: {e}")

    def _scale_actor_pool(self, desired_num_actors: int) -> None:
        """Scale the actor pool to the desired size.

        First removes unhealthy actors, then adds or removes actors as
        needed to reach the target size.

        Args:
            desired_num_actors: Target number of actors in the pool.
        """
        self._remove_unhealthy_members_from_actor_pool()
        current = len(self._actor_pool)
        if current < desired_num_actors:
            self._add_members_to_actor_pool(desired_num_actors)
        elif current > desired_num_actors:
            self._remove_members_from_actor_pool(desired_num_actors)

    def drain_actor_pool(self) -> None:
        """Shut down and remove all actors from the pool.

        Attempts graceful shutdown first, then forcefully kills actors.
        Clears the actor pool after draining.
        """
        if not self._actor_pool:
            return

        shutdown_refs = []
        for member in self._actor_pool:
            try:
                shutdown_refs.append(member.actor.shutdown.remote())
            except Exception:
                # A handle that cannot accept the call is killed below anyway.
                pass

        # Give graceful shutdown a bounded window; skip the wait entirely when
        # no shutdown call could be submitted.
        if shutdown_refs:
            try:
                ray.wait(shutdown_refs, num_returns=len(shutdown_refs), timeout=5.0)
            except Exception:
                pass

        for member in self._actor_pool:
            name = self.get_actor_name_from_actor_info(member.actor_info)
            try:
                ray.kill(member.actor, no_restart=True)
                logger.info(f"Killed actor {name}")
            except Exception as e:
                logger.error(f"Failed to kill actor {name}: {e}")

        self._actor_pool = []
@ray.remote
class DeviceHostActor:
    """Ray actor for managing a single TPU host within a slice.

    Handles task execution on a specific TPU host, managing TPU resources,
    environment variables, and task lifecycle. Supports cancellation and
    health monitoring. Each DeviceHostActor records the Ray node it was
    started on and pins the tasks it launches to that node.

    Attributes:
        host_id: Unique identifier for this host within its slice (0-based).
        slice_name: Name of the TPU slice this host belongs to.
        num_devices: Number of TPU devices on this host (0 when unknown).
        _failed: Whether this host has been shut down or marked failed.
        _awaitable: ObjectRef of the most recently launched task, if any.
        _node_id: Ray node ID where this actor is running.

    Environment Variables Set (on launched tasks):
        TPU_HOST_ID: Host index within the slice.
        TPU_SLICE_NAME: Name of the parent slice.
        TPU_NUM_DEVICES: Number of devices on this host (only when non-zero).
    """

    def __init__(self, host_id: int, slice_name: str, num_devices: int | None = None):
        """Initialize a DeviceHostActor.

        Args:
            host_id: Unique identifier for this host within its slice.
            slice_name: Name of the TPU slice this host belongs to.
            num_devices: Optional number of TPU devices available on this host;
                ``None`` is stored as 0.
        """
        self.host_id = host_id
        self.slice_name = slice_name
        self.num_devices = num_devices or 0
        self._failed = False
        self._awaitable: ray.ObjectRef | None = None
        self._node_id = ray.get_runtime_context().get_node_id()
        logger.info(f"DeviceHostActor[{slice_name}#{host_id}] init; num_devices={num_devices}; node_id={self._node_id}")

    def healthy(self) -> bool:
        """Check if this host is healthy and operational.

        Returns:
            True if host is not failed and not being preempted.
        """
        return not self._failed and not self.is_being_preempted()

    def is_being_preempted(self) -> bool:
        """Check if this GCP instance is being preempted.

        Queries the GCP metadata server with a one-second timeout.

        Returns:
            True if the metadata server answers 200 with body ``TRUE``;
            False otherwise, including when the request fails.
        """
        try:
            r = requests.get(
                "http://metadata.google.internal/computeMetadata/v1/instance/preempted",
                headers={"Metadata-Flavor": "Google"},
                timeout=1.0,
            )
            return r.status_code == 200 and r.text.strip().upper() == "TRUE"
        except requests.RequestException:
            return False

    def get_info(self) -> HostInfo:
        """Get current information about this host.

        Returns:
            HostInfo object with host metadata and status.
        """
        return HostInfo(
            host_id=self.host_id,
            slice_name=self.slice_name,
            num_devices=self.num_devices,
            healthy=self.healthy(),
            failed=self._failed,
            node_id=self._node_id,
        )

    def _kill_vfio_holders(self):
        """Quietly kill processes holding /dev/vfio/*.

        Controlled by:
            - EFORMER_KILL_VFIO: the step runs unless this is set to something
              other than "1" (it is enabled by default).
            - EFORMER_INSTALL_LSOF=1 to attempt a quiet, noninteractive lsof
              install when lsof is missing (default "0").

        All command outputs are suppressed; sudo is invoked with ``-n`` so it
        never prompts. Every failure is swallowed: this is best-effort cleanup.
        """
        import os

        if os.getenv("EFORMER_KILL_VFIO", "1") != "1":
            return
        try:
            import shutil
            import signal
            import subprocess

            def run_quiet(cmd: str, capture: bool = False) -> subprocess.CompletedProcess:
                return subprocess.run(
                    ["bash", "-lc", cmd],
                    check=False,
                    stdout=(subprocess.PIPE if capture else subprocess.DEVNULL),
                    stderr=subprocess.DEVNULL,
                    text=True,
                    env=dict(os.environ, DEBIAN_FRONTEND="noninteractive"),
                )

            if shutil.which("lsof") is None and os.getenv("EFORMER_INSTALL_LSOF", "0") == "1":
                run_quiet("sudo -n apt-get -qq update || true")
                run_quiet("sudo -n apt-get -qq -y install lsof || true")
            if shutil.which("lsof") is None:
                return
            p = run_quiet("lsof -t /dev/vfio/* 2>/dev/null | sort -u", capture=True)
            pids = []
            if p and p.stdout:
                # Never signal ourselves: this actor may legitimately hold the device.
                pids = [int(pid) for pid in p.stdout.split() if pid.isdigit() and int(pid) != os.getpid()]
            for pid in pids:
                try:
                    os.kill(pid, signal.SIGKILL)
                except Exception:
                    pass
        except Exception:
            pass

    def _merge_runtime_env(self, runtime_env: dict | None, env_vars: dict | None) -> dict:
        """Merge environment variables into a runtime environment dict.

        The input dict is copied, not mutated. Keys and values are converted
        to ``str``; entries whose value is ``None`` are skipped.

        Args:
            runtime_env: Base runtime environment configuration.
            env_vars: Environment variables to merge in.

        Returns:
            Merged runtime environment dictionary.
        """
        merged = dict(runtime_env or {})
        if env_vars:
            ev = dict(merged.get("env_vars", {}))
            ev.update({str(k): str(v) for k, v in env_vars.items() if v is not None})
            merged["env_vars"] = ev
        return merged

    def _hacky_remove_tpu_lockfile(self):
        """Remove TPU lockfile that may prevent TPU initialization.

        Attempts to remove /tmp/libtpu_lockfile. Falls back to ``sudo rm``
        when the unlink is refused with a permission error.
        """
        try:
            if os.path.exists("/tmp/libtpu_lockfile"):
                os.unlink("/tmp/libtpu_lockfile")
        except FileNotFoundError:
            pass
        except PermissionError:
            try:
                os.system("sudo rm /tmp/libtpu_lockfile")
            except Exception:
                pass

    def _cancel_tasks_and_wait(self, tasks: list[ray.ObjectRef], timeout_s: float = 240.0) -> None:
        """Cancel Ray tasks and wait for them to complete.

        Force-cancels every provided task, then waits for completion or
        timeout. A failure to cancel one task does not prevent the remaining
        tasks from being cancelled.

        Args:
            tasks: List of Ray ObjectRefs to cancel.
            timeout_s: Maximum time to wait for cancellation.
        """
        if not tasks:
            return
        for task in tasks:
            try:
                ray.cancel(task, force=True, recursive=True)
            except Exception as e:
                logger.warning(f"Failed to cancel some tasks: {e}")
        done, pending = ray.wait(tasks, num_returns=len(tasks), timeout=timeout_s)
        if pending:
            logger.warning(f"Cancelled {len(done)} tasks; {len(pending)} still pending after {timeout_s}s.")

    def cancel_current(self):
        """Cancel the currently running task if any.

        Cancels and waits for the current task to complete, then clears
        the awaitable reference.
        """
        if self._awaitable:
            self._cancel_tasks_and_wait([self._awaitable])
        self._awaitable = None

    def run_remote_fn(
        self,
        remote_fn,
        *,
        f_args: tuple = (),
        f_kwargs: dict | None = None,
        runtime_env: dict | None = None,
        env: dict | None = None,
        num_cpus: float = 0.0,
        memory_bytes: float = 20e9,
        extra_resources: dict | None = None,
    ) -> ray.ObjectRef:
        """Launch a cancelable task on this host's node, reserving TPU resources.

        Any previously launched task is cancelled first, processes holding
        /dev/vfio are optionally killed, and the TPU lockfile is removed.
        The callable is then executed inside a fresh ``max_calls=1`` Ray task
        pinned to this actor's node.

        Args:
            remote_fn: Ray remote function or plain callable. For a Ray
                ``RemoteFunction`` the underlying Python function is extracted
                and re-wrapped.
            f_args: Positional arguments passed to the callable.
            f_kwargs: Keyword arguments passed to the callable.
            runtime_env: Optional Ray runtime environment configuration.
            env: Additional environment variables; these override the
                host-derived TPU_* variables on key collision.
            num_cpus: Number of CPUs to reserve for the task (default: 0.0).
            memory_bytes: Memory to reserve in bytes (default: 20e9).
            extra_resources: Additional custom resources to request. A "TPU"
                entry here takes precedence over the automatic reservation.

        Returns:
            ray.ObjectRef: Reference to the launched task.

        Raises:
            RuntimeError: If host is unhealthy or being preempted.

        Note:
            - The task uses hard (non-soft) node affinity and ``max_retries=0``.
            - ``jax.distributed.shutdown()`` is attempted after the callable
              returns or raises; failures of that cleanup are ignored.
        """
        if not self.healthy():
            raise RuntimeError(f"Host {self.host_id} unhealthy or preempted")
        if self._awaitable:
            self._cancel_tasks_and_wait([self._awaitable])
        self._kill_vfio_holders()
        self._hacky_remove_tpu_lockfile()

        host_env = {"TPU_HOST_ID": str(self.host_id), "TPU_SLICE_NAME": self.slice_name}
        if self.num_devices:
            host_env["TPU_NUM_DEVICES"] = str(self.num_devices)
        merged_runtime_env = self._merge_runtime_env(runtime_env, {**host_env, **(env or {})})

        resources = dict(extra_resources or {})
        if self.num_devices and "TPU" not in resources:
            resources["TPU"] = self.num_devices

        # Unwrap an existing RemoteFunction so it can be re-wrapped with our own options.
        try:
            from ray.remote_function import RemoteFunction as _RF

            py_fn = remote_fn._function if isinstance(remote_fn, _RF) else remote_fn
        except Exception:
            py_fn = remote_fn
        f_kwargs = f_kwargs or {}

        @ray.remote(max_calls=1)
        def _runner(fn, args, kwargs):
            try:
                return fn(*args, **kwargs)
            finally:
                try:
                    import jax.distributed as jdist

                    jdist.shutdown()
                except Exception:
                    pass

        self._awaitable = _runner.options(
            scheduling_strategy=NodeAffinitySchedulingStrategy(self._node_id, soft=False),
            resources=resources or None,
            num_cpus=num_cpus,
            num_gpus=0,
            memory=int(memory_bytes),
            runtime_env=merged_runtime_env,
            max_retries=0,
        ).remote(py_fn, f_args, f_kwargs)
        return self._awaitable

    def shutdown(self) -> None:
        """Gracefully shut down this host actor.

        Cancels any running task and marks the host as failed, even when
        the cancellation itself raises.
        """
        try:
            self.cancel_current()
        finally:
            self._failed = True
            logger.info(f"Shut down DeviceHostActor[{self.slice_name}#{self.host_id}]")


@ray.remote
class SliceActor:
    """Ray actor for managing a TPU slice with multiple hosts.

    Coordinates multiple TPU hosts within a single slice, handling placement
    groups, resource allocation, and distributed task execution.

    Attributes:
        _actor_pool: List of DeviceHostActor pool members for this slice.
        _failed: Whether this slice has failed or been shut down.
        _slice_info: Detailed information about the TPU slice configuration.
        _host_placement_group: Ray placement group (STRICT_SPREAD) used to
            discover one node per host.
        _host_infos: Node and device information for each host in the slice.

    Lifecycle:
        1. Created by SlicePoolManager with TPU head resource requirement.
        2. Discovers slice configuration from TPU environment.
        3. Creates placement group for host distribution.
        4. Spawns DeviceHostActors on each host node.
        5. Manages task execution across all hosts.
        6. Cleans up resources on shutdown.
    """

    def __init__(self):
        """Initialize a slice actor.

        Creates an empty actor pool and discovers slice information from the
        TPU environment. Discovery failures are recorded in ``_failed``
        rather than raised.
        """
        self._actor_pool: list[ActorPoolMember[HostInfo]] = []
        self._failed = False
        self._slice_info: SliceInfo | None = None
        self._host_placement_group = None
        self._host_infos: list[dict] | None = None
        self._initialize_slice_info()

    @staticmethod
    @ray.remote(num_cpus=0)
    def discover_node_info():
        """Discover information about the current Ray node.

        Remote function gathering node IP, node ID, pod name, TPU device
        count and host count. When EFORMER_MODERATE is "1" (the default) and
        the cluster advertises fewer hosts for the pod than the pod reports,
        the host count is reduced to the advertised value and the device
        count is forced to 4.

        Returns:
            Dictionary with keys ``ip``, ``node_id``, ``pod_name``,
            ``num_devices`` and ``num_hosts``.
        """
        import ray

        pod_name = None
        ray_tpu = None
        try:
            from ray.util.accelerators import tpu as ray_tpu

            pod_name = ray_tpu.get_current_pod_name()
        except Exception:
            ray_tpu = None

        num_devices = None
        try:
            from ray._private.accelerators import TPUAcceleratorManager

            num_devices = TPUAcceleratorManager.get_current_node_num_accelerators()
        except Exception:
            pass

        num_hosts = 1
        if ray_tpu is not None:
            try:
                num_hosts = int(ray_tpu.get_current_pod_worker_count())
            except Exception:
                num_hosts = 1

        if os.getenv("EFORMER_MODERATE", "1") == "1" and pod_name:
            available_hosts = ray.cluster_resources().get(pod_name, None)
            if available_hosts is not None and num_hosts > available_hosts:
                available_hosts = int(available_hosts)
                # NOTE(review): 4 devices per host is hard-coded here; confirm it
                # holds for every TPU generation this is used with.
                real_num_devices = 4
                # num_devices is left untouched until after the message so the
                # log reports the value that is actually being replaced.
                print(
                    f"auto-discovered to set num_hosts from {num_hosts} to {available_hosts} and "
                    f"num_devices from {num_devices} to {real_num_devices}"
                )
                num_hosts = available_hosts
                num_devices = real_num_devices

        return {
            "ip": ray.util.get_node_ip_address(),
            "node_id": ray.get_runtime_context().get_node_id(),
            "pod_name": pod_name,
            "num_devices": num_devices,
            "num_hosts": num_hosts,
        }

    def _kill_actor_quietly(self, actor: ActorHandle) -> None:
        """Force-kill an actor, ignoring any error raised by the kill itself."""
        try:
            ray.kill(actor, no_restart=True)
        except Exception:
            pass

    def _free_host_ids(self, count: int) -> list[int]:
        """Return the ``count`` lowest host ids not used by any pool member.

        For a pool whose members hold ids ``0..k-1`` this yields
        ``k, k+1, ...``; when a member was pruned, its id is reused first.
        """
        taken = {member.actor_info.host_id for member in self._actor_pool}
        free: list[int] = []
        candidate = 0
        while len(free) < count:
            if candidate not in taken:
                free.append(candidate)
            candidate += 1
        return free

    def _create_actor_for_host_id(self, host_id: int) -> ActorHandle:
        """Create a DeviceHostActor for a specific host ID.

        The actor is pinned with hard node affinity to the node recorded for
        ``host_id`` and given a unique name derived from the slice name, the
        host id and a random suffix.

        Args:
            host_id: Zero-based index of the host within the slice.

        Returns:
            Ray actor handle for the DeviceHostActor.

        Raises:
            RuntimeError: If slice/host info not initialized or host_id invalid.
        """
        if not self._slice_info or not self._host_infos:
            raise RuntimeError("Slice or host info not initialized")
        if host_id >= len(self._host_infos):
            raise RuntimeError(f"Missing host info for host_id={host_id}")
        info = self._host_infos[host_id]
        node_id = info["node_id"]
        num_devices_for_host = info.get("num_devices")
        return DeviceHostActor.options(
            num_cpus=0,
            scheduling_strategy=NodeAffinitySchedulingStrategy(node_id=node_id, soft=False),
            name=f"{self._slice_info.slice_name}-host-{host_id}-{uuid4().hex[:8]}",
        ).remote(host_id, self._slice_info.slice_name, num_devices_for_host)

    def _initialize_slice_info(self) -> None:
        """Initialize slice information from TPU environment.

        Discovers slice name, host count, and accelerators per host. When
        EFORMER_MODERATE is "1" (the default) and the cluster advertises
        fewer hosts than the pod reports, the host count is reduced and the
        accelerator count is forced to 4. Sets ``_failed`` if anything raises.
        """
        try:
            from ray.util.accelerators import tpu as ray_tpu

            num_accelerators_per_host = None
            try:
                from ray._private.accelerators import TPUAcceleratorManager

                num_accelerators_per_host = TPUAcceleratorManager.get_current_node_num_accelerators()
            except Exception:
                pass

            slice_name = ray_tpu.get_current_pod_name()
            num_hosts = int(ray_tpu.get_current_pod_worker_count())
            if os.getenv("EFORMER_MODERATE", "1") == "1":
                available_hosts = ray.cluster_resources().get(slice_name, None)
                if available_hosts is not None and num_hosts > available_hosts:
                    available_hosts = int(available_hosts)
                    # NOTE(review): hard-coded accelerator count; confirm.
                    real_accelerators_per_host = 4
                    print(
                        f"setting {num_hosts=} to {available_hosts=} and "
                        f"{num_accelerators_per_host=} to {real_accelerators_per_host=}"
                    )
                    num_hosts = available_hosts
                    num_accelerators_per_host = real_accelerators_per_host

            ip_address = ray.util.get_node_ip_address()
            self._slice_info = SliceInfo(
                slice_name=slice_name,
                num_hosts=num_hosts,
                ip_address=ip_address,
                num_accelerators_per_host=num_accelerators_per_host or 0,
            )
            logger.info(f"Initialized SliceActor: {self._slice_info}")
        except Exception as e:
            logger.error(f"Failed to initialize slice info: {e}")
            self._failed = True

    def healthy(self) -> bool:
        """Check if the slice is healthy and operational.

        Returns:
            True if the slice has not failed and is not being preempted.
        """
        if self._failed:
            return False
        return not self.is_being_preempted()

    def is_being_preempted(self) -> bool:
        """Check if this GCP instance is being preempted.

        Queries the GCP metadata server with a one-second timeout.

        Returns:
            True if the metadata server answers 200 with body ``TRUE``;
            False otherwise, including when the request fails.
        """
        try:
            r = requests.get(
                "http://metadata.google.internal/computeMetadata/v1/instance/preempted",
                headers={"Metadata-Flavor": "Google"},
                timeout=1.0,
            )
            return r.status_code == 200 and r.text.strip().upper() == "TRUE"
        except requests.RequestException:
            return False

    def get_info(self) -> SliceInfo:
        """Get current information about this slice.

        Returns:
            The SliceInfo recorded for this slice.

        Raises:
            RuntimeError: If slice information has not been initialized.
        """
        if not self._slice_info:
            raise RuntimeError("Slice info not initialized")
        return self._slice_info

    def get_all_actors_in_pool(self) -> list[ActorHandle]:
        """Get all actor handles in the pool.

        Returns:
            List of Ray actor handles.
        """
        return [m.actor for m in self._actor_pool]

    def get_all_pool_members(self) -> list[ActorPoolMember[HostInfo]]:
        """Get a copy of all pool members with their metadata.

        Returns:
            Shallow copy of the pool list.
        """
        return self._actor_pool.copy()

    def get_actor_pool_name(self) -> str:
        """Get a human-readable name for this actor pool.

        Returns:
            ``SliceActor(<slice name>)``, or ``SliceActor(uninitialized)``
            when slice info is not available.
        """
        if self._slice_info:
            return f"SliceActor({self._slice_info.slice_name})"
        return "SliceActor(uninitialized)"

    def get_actor_name_from_actor_info(self, actor_info: HostInfo) -> str:
        """Generate a human-readable name from actor info.

        Args:
            actor_info: Metadata about the actor.

        Returns:
            ``<slice name>-host-<host id>``.
        """
        return f"{actor_info.slice_name}-host-{actor_info.host_id}"

    def _ensure_host_placement_group(self) -> None:
        """Ensure placement group exists for distributing hosts.

        Creates a STRICT_SPREAD placement group with one bundle per host,
        waits for it, and runs ``discover_node_info`` in each bundle to learn
        the node behind every host slot. State is only committed once all of
        that succeeded; on failure the placement group is removed so a later
        call can retry from scratch.

        Raises:
            RuntimeError: If slice info is not initialized.
        """
        if not self._slice_info:
            raise RuntimeError("Slice info not initialized")
        if self._host_placement_group is not None:
            return

        slice_label = self._slice_info.slice_name
        num_hosts = self._slice_info.num_hosts
        bundles = [{"CPU": 0, slice_label: 1} for _ in range(num_hosts)]
        request_resources(bundles=bundles)

        pg = placement_group(bundles, strategy="STRICT_SPREAD")
        try:
            ray.get(pg.ready())
            futures = [
                SliceActor.discover_node_info.options(
                    scheduling_strategy=PlacementGroupSchedulingStrategy(
                        pg,
                        placement_group_bundle_index=i,
                        placement_group_capture_child_tasks=False,
                    )
                ).remote()
                for i in range(num_hosts)
            ]
            host_infos = ray.get(futures)
        except Exception:
            # Do not keep a half-initialized group around: it would make every
            # later call return early while _host_infos is still missing.
            try:
                remove_placement_group(pg)
            except Exception:
                pass
            raise

        self._host_placement_group = pg
        self._host_infos = host_infos
        self._slice_info = SliceInfo(
            slice_name=self._slice_info.slice_name,
            num_hosts=self._slice_info.num_hosts,
            ip_address=self._slice_info.ip_address,
            num_accelerators_per_host=self._slice_info.num_accelerators_per_host,
            node_ids=[h.get("node_id") for h in self._host_infos],
            host_infos=self._host_infos,
        )
        logger.info(f"Prepared host placement group for slice {self._slice_info.slice_name};")

    def prepare_hosts(self) -> None:
        """Prepare host placement group for this slice.

        Ensures the placement group is created and ready for host actors.
        """
        self._ensure_host_placement_group()

    def create_actor(self) -> ActorHandle:
        """Create a new TPU host actor within this slice.

        Picks the lowest host id not used by a pool member and creates an
        unnamed DeviceHostActor pinned to that host's node.

        Returns:
            Ray actor handle for the newly created DeviceHostActor.

        Raises:
            RuntimeError: If slice is not initialized or host info is missing.
        """
        if not self._slice_info:
            raise RuntimeError("Cannot create host actor: slice not initialized")
        self._ensure_host_placement_group()
        host_id = self._free_host_ids(1)[0]
        if not self._host_infos or host_id >= len(self._host_infos):
            raise RuntimeError(f"Missing host info for host_id={host_id}")
        info = self._host_infos[host_id]
        node_id = info["node_id"]
        num_devices_for_host = info.get("num_devices")
        return DeviceHostActor.options(
            num_cpus=0,
            scheduling_strategy=NodeAffinitySchedulingStrategy(node_id=node_id, soft=False),
        ).remote(host_id, self._slice_info.slice_name, num_devices_for_host)

    def ensure_host_pool(self, desired_hosts: int | None = None) -> None:
        """Ensure the host actor pool has the desired number of hosts.

        Args:
            desired_hosts: Target number of hosts, defaults to slice's host count.

        Raises:
            RuntimeError: If slice info is not initialized.
        """
        if not self._slice_info:
            raise RuntimeError("Slice info not initialized")
        self._ensure_host_placement_group()
        target = desired_hosts if desired_hosts is not None else self._slice_info.num_hosts
        self._scale_actor_pool(target)

    def _remove_unhealthy_members_from_actor_pool(self) -> None:
        """Remove unhealthy actors from the pool.

        Issues a ``healthy`` call to every member, waits up to
        HEALTH_CHECK_TIMEOUT_S, and keeps only members that answered truthy.
        Members that timed out, raised, or reported unhealthy are killed.
        """
        if not self._actor_pool:
            return
        members = list(self._actor_pool)
        # Parallel lists instead of a member-keyed dict: pool members need not be hashable.
        refs = [member.actor.healthy.remote() for member in members]
        done, _ = ray.wait(refs, num_returns=len(refs), timeout=HEALTH_CHECK_TIMEOUT_S)
        done_set = set(done)

        healthy: list[ActorPoolMember[HostInfo]] = []
        for member, ref in zip(members, refs):
            name = self.get_actor_name_from_actor_info(member.actor_info)
            if ref not in done_set:
                logger.warning(f"Actor {name} health timeout; killing")
                self._kill_actor_quietly(member.actor)
                continue
            try:
                is_healthy = ray.get(ref, timeout=0)
            except Exception as e:
                logger.warning(f"Actor {name} health check exception ({e}); killing")
                self._kill_actor_quietly(member.actor)
                continue
            if is_healthy:
                healthy.append(member)
            else:
                logger.warning(f"Actor {name} reported unhealthy; killing")
                self._kill_actor_quietly(member.actor)
        self._actor_pool = healthy

    def _add_members_to_actor_pool(self, desired_num_actors: int) -> None:
        """Add new host actors to reach the desired pool size.

        New actors take the lowest host ids not already present in the pool,
        so a pruned host is re-created under its own id instead of duplicating
        a surviving one. Actors are added as soon as their ``get_info`` call
        resolves; any actor that has not answered within
        HEALTH_CHECK_TIMEOUT_S is killed. The pool is left sorted by host id.

        Args:
            desired_num_actors: Target number of actors in the pool.
        """
        current = len(self._actor_pool)
        if current >= desired_num_actors:
            return
        to_add = desired_num_actors - current
        logger.info(f"Scaling up pool {self.get_actor_pool_name()} from {current} to {desired_num_actors}")

        info_ref_to_actor: dict[ray.ObjectRef, ActorHandle] = {}
        try:
            for host_id in self._free_host_ids(to_add):
                actor = self._create_actor_for_host_id(host_id)
                info_ref_to_actor[actor.get_info.remote()] = actor
        except Exception:
            # Creation failed part-way: do not leak the actors already started.
            for actor in info_ref_to_actor.values():
                self._kill_actor_quietly(actor)
            raise

        pending = list(info_ref_to_actor)
        started = 0
        poll_s = 2.0
        deadline = time.time() + HEALTH_CHECK_TIMEOUT_S
        while pending and time.time() < deadline:
            done, pending = ray.wait(pending, num_returns=len(pending), timeout=poll_s)
            for info_ref in done:
                actor = info_ref_to_actor.pop(info_ref, None)
                if actor is None:
                    continue
                try:
                    info = ray.get(info_ref, timeout=0)
                    self._actor_pool.append(ActorPoolMember(actor, info))
                    started += 1
                    logger.info(f"Added actor {self.get_actor_name_from_actor_info(info)}")
                except Exception as e:
                    logger.error(f"Failed to start host actor: {e}")
                    self._kill_actor_quietly(actor)

        # Whatever is still mapped never answered before the deadline.
        for actor in info_ref_to_actor.values():
            logger.warning(f"Host actor did not start within {HEALTH_CHECK_TIMEOUT_S}s; killing actor")
            self._kill_actor_quietly(actor)

        # Completion order is arbitrary; keep the pool ordered by host id.
        self._actor_pool.sort(key=lambda member: member.actor_info.host_id)
        logger.info(f"Started {started}/{to_add} host actors")

    def _remove_members_from_actor_pool(self, desired_num_actors: int) -> None:
        """Remove actors to reach the desired pool size.

        Members are taken from the end of the pool; each gets up to five
        seconds for a graceful ``shutdown`` before being killed.

        Args:
            desired_num_actors: Target number of actors in the pool.
        """
        while len(self._actor_pool) > desired_num_actors:
            member = self._actor_pool.pop()
            name = self.get_actor_name_from_actor_info(member.actor_info)
            try:
                try:
                    ray.get(member.actor.shutdown.remote(), timeout=5)
                except Exception:
                    pass
                ray.kill(member.actor, no_restart=True)
                logger.info(f"Removed actor {name}")
            except Exception as e:
                logger.error(f"Failed to kill actor {name}: {e}")

    def _scale_actor_pool(self, desired_num_actors: int) -> None:
        """Scale the actor pool to the desired size.

        First removes unhealthy actors, then adds or removes actors as
        needed to reach the target size.

        Args:
            desired_num_actors: Target number of actors in the pool.
        """
        self._remove_unhealthy_members_from_actor_pool()
        current = len(self._actor_pool)
        if current < desired_num_actors:
            self._add_members_to_actor_pool(desired_num_actors)
        elif current > desired_num_actors:
            self._remove_members_from_actor_pool(desired_num_actors)

    def drain_actor_pool(self) -> None:
        """Shut down and remove all actors from the pool.

        Requests graceful shutdown from every member, waits up to five
        seconds for the batch, then kills every member and clears the pool.
        """
        if not self._actor_pool:
            return
        shutdown_refs = []
        for member in self._actor_pool:
            try:
                shutdown_refs.append(member.actor.shutdown.remote())
            except Exception:
                pass
        try:
            ray.wait(shutdown_refs, num_returns=len(shutdown_refs), timeout=5.0)
        except Exception:
            pass
        for member in self._actor_pool:
            name = self.get_actor_name_from_actor_info(member.actor_info)
            try:
                ray.kill(member.actor, no_restart=True)
                logger.info(f"Killed actor {name}")
            except Exception as e:
                logger.error(f"Failed to kill actor {name}: {e}")
        self._actor_pool = []

    def _await_all_hosts_healthy(self, timeout_s: int = 60, poll_s: float = 2.0) -> bool:
        """Wait for all hosts in the pool to become healthy.

        Args:
            timeout_s: Maximum time to wait for hosts to become healthy (default: 60).
            poll_s: Interval between health checks in seconds (default: 2.0).

        Returns:
            True if all hosts reported healthy within the timeout, False otherwise.
        """
        deadline = time.time() + timeout_s
        while time.time() < deadline:
            statuses = ray.get([m.actor.healthy.remote() for m in self._actor_pool])
            if all(statuses):
                return True
            time.sleep(poll_s)
        return False

    def run_remote_fn(
        self,
        remote_fn,
        runtime_env: dict | None = None,
        env: dict | None = None,
        f_args: tuple = (),
        f_kwargs: dict | None = None,
    ):
        """Execute a remote function on all hosts in this slice.

        Brings the host pool to the slice's full host count, waits
        (best effort) for the hosts to report healthy, then calls
        ``run_remote_fn`` on every host actor.

        Args:
            remote_fn: Ray remote function or callable to execute once per host.
            runtime_env: Optional Ray runtime environment configuration.
            env: Optional environment variables to set on all hosts.
            f_args: Positional arguments forwarded to the callable.
            f_kwargs: Keyword arguments forwarded to the callable.

        Returns:
            List with one ObjectRef per host actor in the pool, in pool order.

        Raises:
            RuntimeError: If slice info is not initialized.

        Note:
            The health wait is advisory: a timeout or error is logged and the
            launch proceeds anyway. The wait is bounded by
            EFORMER_HOST_HEALTH_WAIT_S (default "60").
        """
        if not self._slice_info:
            raise RuntimeError("Slice info not initialized")
        self.ensure_host_pool(self._slice_info.num_hosts)
        try:
            wait_s = int(os.getenv("EFORMER_HOST_HEALTH_WAIT_S", "60"))
            if not self._await_all_hosts_healthy(timeout_s=wait_s):
                logger.warning(f"Not all hosts reported healthy within {wait_s}s; launching anyway")
        except Exception as e:
            logger.warning(f"Host health wait failed ({e}); launching anyway")
        futures = [
            member.actor.run_remote_fn.remote(
                remote_fn,
                f_args=f_args,
                f_kwargs=f_kwargs,
                runtime_env=runtime_env,
                env=env,
            )
            for member in self._actor_pool
        ]
        return futures

    def shutdown(self):
        """Gracefully shut down this slice actor.

        Drains the host pool, removes the placement group, and marks the
        slice as failed. Errors during cleanup are ignored so that shutdown
        always completes.
        """
        try:
            self.drain_actor_pool()
        except Exception:
            pass
        if self._host_placement_group:
            try:
                remove_placement_group(self._host_placement_group)
            except Exception:
                pass
            self._host_placement_group = None
        self._failed = True
class SlicePoolManager(ResourcePoolManager[SliceInfo]):
    """Pool manager whose members are SliceActors, one per TPU slice.

    Sits at the top of the hierarchy
    ``SlicePoolManager -> SliceActors -> DeviceHostActors -> Tasks`` and
    provides scaling plus fan-out execution across every slice it holds.

    Attributes:
        _tpu_type: Type of TPU (e.g., "v4-8", "v5e-16"); used to build the
            ``TPU-{type}-head`` resource name and the pool's display name.
        _last_scale_ts: Timestamp of the last ``scale_multislice`` call,
            consulted for rate limiting.
        _last_scale_check_ts: Timestamp of the last scale-up check.
        _head_pg: Placement group used when scheduling SliceActors
            (``None`` until created).
        _head_pg_target: Starts at 0; presumably the bundle count the head
            placement group was sized for (confirm against ``_ensure_head_pg``).
        _actor_pool: List of SliceActor pool members.

    Resource Requirements:
        - Each SliceActor requires a TPU-{type}-head resource.
        - Each slice requires placement group bundles for host distribution.
        - Resources are requested from the Ray autoscaler.
    """

    def __init__(self, tpu_type: str | None):
        """Set up an empty manager for the given TPU type.

        Args:
            tpu_type: Type of TPU to manage (e.g., "v4-8", "v5e-16").
        """
        super().__init__()
        self._tpu_type = tpu_type

        # Head placement-group state; nothing is reserved yet.
        self._head_pg = None
        self._head_pg_target = 0

        # Scaling bookkeeping; None means "never happened".
        self._last_scale_check_ts: float | None = None
        self._last_scale_ts: float | None = None
def get_actor_pool_name(self) -> str:
    """Return the display name of this pool.

    Returns:
        The string ``SlicePool(<tpu type>)`` built from ``_tpu_type``.
    """
    label = f"SlicePool({self._tpu_type})"
    return label
[docs] def get_actor_name_from_actor_info(self, actor_info: SliceInfo) -> str: """Generate a human-readable name from actor info. Args: actor_info: Metadata about the actor. Returns: String representation of the actor for logging. """ return actor_info.slice_name
def create_actor(self) -> ActorHandle:
    """Start one SliceActor bound to a TPU head resource.

    The actor reserves no CPU and one unit of the custom resource
    ``TPU-<tpu type>-head``.

    Returns:
        Ray actor handle for the newly created SliceActor.
    """
    head_requirement = {f"TPU-{self._tpu_type}-head": 1}
    configured = SliceActor.options(num_cpus=0, resources=head_requirement)
    return configured.remote()
[docs] def scale_multislice(self, num_slices: int | Sequence[int]) -> None: """Scale the pool to the desired number of slices. Supports flexible scaling with multiple valid sizes. Will scale to the largest feasible size from the provided options. This method is typically called by RayExecutor.autoscale_execute to set up the required number of slices. Args: num_slices: Target number of slices or list of valid sizes. If int: exact number of slices required. If sequence: will try largest first, falling back to smaller. Raises: ValueError: If target is invalid or empty list provided. InsufficientSlicesError: If none of the requested sizes can be achieved. Example: >>> manager.scale_multislice(4) >>> manager.scale_multislice([2, 4, 8]) Note: - Requests TPU head resources from Ray autoscaler. - Removes unhealthy actors before scaling. - Falls back to smaller sizes if larger ones unavailable. """ self._last_scale_ts = time.time() if isinstance(num_slices, int): valid = [int(num_slices)] else: valid = sorted({int(x) for x in num_slices}) if not valid: raise ValueError("valid sizes list is empty") target = valid[-1] if target <= 0: raise ValueError(f"Target slice count must be > 0, got {target}") head_bundles = [{"CPU": 0, f"TPU-{self._tpu_type}-head": 1} for _ in range(target)] request_resources(bundles=head_bundles) self._scale_actor_pool(target) current = len(self._actor_pool) if current not in valid: feasible = [v for v in valid if v <= current] if not feasible: raise InsufficientSlicesError(f"Requested one of {valid}, but only {current} slices available") self._scale_actor_pool(feasible[-1])
def prepare_all_slices(self) -> None:
    """Request host resources for every slice and prepare its placement group.

    Steps:
        1. Collect ``SliceInfo`` from every SliceActor in the pool.
        2. Request one ``{"CPU": 0, <slice name>: 1}`` bundle per host of
           every slice from the autoscaler.
        3. Call ``prepare_hosts`` on every SliceActor and wait for all of them.

    When EFORMER_SAFE_GATHER is "1" (the default), step 1 queries actors one
    at a time with a 30-second timeout; actors that fail to answer are killed
    and dropped from the pool.

    Raises:
        RuntimeError: In safe-gather mode, when no SliceActor is left in the pool.
    """
    if os.getenv("EFORMER_SAFE_GATHER", "1") == "1":
        slice_infos = []
        responsive = []
        for entry in self._actor_pool:
            try:
                fetched = ray.get(entry.actor.get_info.remote(), timeout=30)
            except Exception as e:
                try:
                    ray.kill(entry.actor, no_restart=True)
                except Exception:
                    pass
                logger.warning(f"Pruned dead SliceActor during prepare_all_slices: {e}")
                continue
            slice_infos.append(fetched)
            responsive.append(entry)
        if len(responsive) != len(self._actor_pool):
            self._actor_pool = responsive
        if not self._actor_pool:
            raise RuntimeError("No SliceActors available after pruning.")
    else:
        # Fast path: one batched fetch, any failure propagates to the caller.
        slice_infos = ray.get([entry.actor.get_info.remote() for entry in self._actor_pool])

    host_bundles = []
    for info in slice_infos:
        host_bundles.extend([{"CPU": 0, info.slice_name: 1}] * info.num_hosts)
    if host_bundles:
        request_resources(bundles=host_bundles)
    ray.get([entry.actor.prepare_hosts.remote() for entry in self._actor_pool])
def should_scale_up_multislice(self, valid_sizes: Sequence[int]) -> bool:
    """Decide whether a scale-up attempt is worthwhile right now.

    Records the time of this check, then answers True only when some
    allowed size exceeds the current pool size and the last scaling
    operation (if any) is at least 60 seconds old.

    Args:
        valid_sizes: Allowed pool sizes.

    Returns:
        True if scaling up is recommended, False otherwise.
    """
    self._last_scale_check_ts = time.time()

    pool_size = len(self._actor_pool)
    if not any(candidate > pool_size for candidate in valid_sizes):
        return False

    last_scale = self._last_scale_ts
    cooling_down = bool(last_scale) and (time.time() - last_scale) < 60
    return not cooling_down
def execute_on_each_slice(self, remote_fn, env: dict | None = None, runtime_env: dict | None = None):
    """Run a function through every SliceActor and collect the per-slice output.

    Prepares all slices first, then calls ``run_remote_fn`` on each
    SliceActor and blocks until every one of those calls has returned.

    Args:
        remote_fn: Ray remote function or callable to execute.
        env: Optional environment variables to set.
        runtime_env: Optional Ray runtime environment configuration.

    Returns:
        List with one entry per slice, in pool order; each entry is what
        that SliceActor's ``run_remote_fn`` returned.
    """
    self.prepare_all_slices()
    launches = []
    for entry in self._actor_pool:
        launches.append(entry.actor.run_remote_fn.remote(remote_fn, runtime_env=runtime_env, env=env))
    return ray.get(launches)
[docs] def execute_on_each_host_flat(self, remote_fn, env: dict | None = None, runtime_env: dict | None = None): """Execute a function on all hosts, returning a flat list. Similar to execute_on_each_slice but flattens the nested result structure into a single list of ObjectRefs. Args: remote_fn: Ray remote function or callable to execute. env: Optional environment variables to set. runtime_env: Optional Ray runtime environment configuration. Returns: Flat list of ObjectRefs from all hosts across all slices. """ per_slice = self.execute_on_each_slice(remote_fn, env=env, runtime_env=runtime_env) return [f for sub in per_slice for f in sub]
def execute_on_each_host(self, fn, *args, env: dict | None = None, **kwargs):
    """Run ``fn(*args, **kwargs)`` through every slice's host actors.

    Prepares all slices, makes sure every slice has its host pool, wraps
    the call in a ``max_calls=1`` Ray task and hands that task to each
    SliceActor's ``run_remote_fn``.

    Args:
        fn: Function to execute on each host.
        *args: Positional arguments for fn.
        env: Optional environment variables.
        **kwargs: Keyword arguments for fn.

    Returns:
        List with one entry per slice, in pool order; each entry is what
        that SliceActor's ``run_remote_fn`` returned.
    """
    self.prepare_all_slices()

    pool_ready = [entry.actor.ensure_host_pool.remote() for entry in self._actor_pool]
    ray.get(pool_ready)

    # The arguments are bound by closure so the remote task takes no parameters.
    @ray.remote(max_calls=1)
    def _runner():
        return fn(*args, **kwargs)

    launches = [entry.actor.run_remote_fn.remote(_runner, env=env) for entry in self._actor_pool]
    return ray.get(launches)
def schedule_on_each_host(self, remote_fn, env: dict | None = None, runtime_env: dict | None = None):
    """Launch a function on every host and return the launched references.

    Prepares all slices, ensures each slice's host pool, then calls
    ``run_remote_fn`` on every SliceActor. This waits for the SliceActor
    calls themselves (the launch), and flattens what they return.

    Args:
        remote_fn: Ray remote function or callable to execute.
        env: Optional environment variables to set.
        runtime_env: Optional Ray runtime environment configuration.

    Returns:
        Flat list of the items returned by each slice, in pool order.
    """
    self.prepare_all_slices()

    pool_ready = [entry.actor.ensure_host_pool.remote() for entry in self._actor_pool]
    ray.get(pool_ready)

    launches = [
        entry.actor.run_remote_fn.remote(remote_fn, runtime_env=runtime_env, env=env)
        for entry in self._actor_pool
    ]
    flat = []
    for slice_items in ray.get(launches):
        flat.extend(slice_items)
    return flat
def _add_members_to_actor_pool(self, desired_num_actors: int) -> None:
    """Add new SliceActors to the pool with timeout and retry logic.

    Creates one SliceActor per placement-group bundle ("slot") up front, then
    polls for their ``get_info`` results, adding each actor to the pool as
    soon as it answers. While waiting, the Ray autoscaler is nudged on every
    poll to provision TPU head resources.

    Args:
        desired_num_actors: Target number of SliceActors in the pool.

    Behavior:
        1. Returns immediately if the pool already has enough actors.
        2. If the pool is non-empty but too small, drains it completely so
           that placement-group bundles line up with slots again, then
           rebuilds from slot 0.
        3. Ensures the head placement group has ``desired_num_actors`` bundles.
        4. Creates an actor handle for every slot (scheduling is deferred).
        5. Polls (at most ``SCALE_POLL_S`` seconds per wait) until all actors
           have started or ``SCALE_ADD_TIMEOUT_S`` seconds have elapsed.
        6. An actor whose ``get_info`` fails is killed and its slot is
           restarted with a fresh actor.
        7. After the deadline, actors that never joined the pool are killed.

    Environment Variables:
        EFORMER_SCALE_POLL_S: Poll interval in seconds (default: 30)
        EFORMER_SCALE_ADD_TIMEOUT_S: Total timeout in seconds (default: 604800/7 days)

    Note:
        - Partial success is supported (some actors may start).
        - Autoscaler requests and ``ray.kill`` calls are best-effort: their
          failures are logged at debug level and never raised.
    """
    current = len(self._actor_pool)
    if current >= desired_num_actors:
        return

    if current != 0:
        logger.info("Recreating head PG due to non-zero current pool (align bundles with slots)")
        self.drain_actor_pool()
        current = 0

    logger.info(f"Scaling up pool {self.get_actor_pool_name()} from {current} to {desired_num_actors}")
    self._ensure_head_pg(desired_num_actors)

    slot_to_actor: dict[int, ActorHandle] = {}
    slot_to_info_ref: dict[int, ray.ObjectRef] = {}

    def _start_for_slot(slot: int):
        # Pin the actor to bundle `slot` of the head placement group.
        actor = SliceActor.options(
            num_cpus=0,
            scheduling_strategy=PlacementGroupSchedulingStrategy(
                self._head_pg,
                placement_group_bundle_index=slot,
                placement_group_capture_child_tasks=False,
            ),
        ).remote()
        slot_to_actor[slot] = actor
        slot_to_info_ref[slot] = actor.get_info.remote()

    for slot in range(current, desired_num_actors):
        _start_for_slot(slot)

    deadline = time.time() + SCALE_ADD_TIMEOUT_S
    started = 0
    while slot_to_info_ref and time.time() < deadline:
        remaining = len(slot_to_info_ref)
        head_bundles = [{"CPU": 0, f"TPU-{self._tpu_type}-head": 1} for _ in range(remaining)]
        try:
            request_resources(bundles=head_bundles)
        except Exception as e:
            # Best-effort nudge: scaling may still succeed without it.
            logger.debug(f"request_resources failed while scaling up (ignored): {e}")

        pending_refs = list(slot_to_info_ref.values())
        # Never wait past the overall deadline, even if the poll interval is longer.
        wait_timeout = max(0.0, min(SCALE_POLL_S, deadline - time.time()))
        done, _ = ray.wait(pending_refs, num_returns=1, timeout=wait_timeout)
        if not done:
            continue

        for ref in done:
            slot = next((s for s, r in slot_to_info_ref.items() if r == ref), None)
            if slot is None:
                continue
            actor = slot_to_actor.get(slot)
            slot_to_info_ref.pop(slot, None)
            try:
                info = ray.get(ref, timeout=0)
                self._actor_pool.append(ActorPoolMember(actor, info))
                started += 1
                logger.info(f"Added actor {self.get_actor_name_from_actor_info(info)} (slot {slot})")
            except Exception as e:
                logger.warning(f"SliceActor for slot {slot} failed to start: {e}; killing and re-queuing")
                try:
                    ray.kill(actor, no_restart=True)
                except Exception as kill_err:
                    logger.debug(f"Failed to kill SliceActor for slot {slot} (ignored): {kill_err}")
                _start_for_slot(slot)

        logger.info(f"Started {started}/{desired_num_actors - current} slice actors so far")

    if slot_to_info_ref:
        # Deadline hit: tear down every actor that never made it into the pool.
        for actor in slot_to_actor.values():
            try:
                if not any(m.actor == actor for m in self._actor_pool):
                    ray.kill(actor, no_restart=True)
            except Exception as kill_err:
                logger.debug(f"Failed to kill unstarted SliceActor (ignored): {kill_err}")
        logger.info(f"Started {started}/{desired_num_actors - current} slice actors (timed out for the rest)")

def _ensure_head_pg(self, desired_num_actors: int) -> None:
    """Ensure the head placement group exists with the correct number of bundles.

    If a placement group already exists and was built for exactly
    ``desired_num_actors`` bundles, nothing happens. Otherwise any existing
    group is destroyed and a new one is created with the ``STRICT_SPREAD``
    strategy, one bundle per SliceActor, each bundle requesting one
    ``TPU-{tpu_type}-head`` resource.

    Args:
        desired_num_actors: Target number of SliceActors (and placement group bundles).

    Side Effects:
        - Destroys the existing placement group on size mismatch
        - Requests resources from the Ray autoscaler for the new bundles
        - Updates ``_head_pg`` and ``_head_pg_target``
        - Blocks until the new placement group is ready
    """
    if self._head_pg and self._head_pg_target == desired_num_actors:
        return

    # Reuse the shared teardown path instead of duplicating it here.
    self._destroy_head_pg()

    label = f"TPU-{self._tpu_type}-head"
    bundles = [{"CPU": 0, label: 1} for _ in range(desired_num_actors)]
    request_resources(bundles=bundles)
    self._head_pg = placement_group(bundles, strategy="STRICT_SPREAD")
    ray.get(self._head_pg.ready())
    # Record the size only once the group is actually ready.
    self._head_pg_target = desired_num_actors

def _destroy_head_pg(self) -> None:
    """Destroy the current head placement group and reset tracking variables.

    Removal is best-effort: if ``remove_placement_group`` raises, the error is
    logged as a warning and not propagated, so shutdown or recreation can
    continue. The tracking attributes are reset in every case.

    Side Effects:
        - Removes the placement group from the Ray cluster (if one exists)
        - Sets ``_head_pg`` to None
        - Resets ``_head_pg_target`` to 0
    """
    if self._head_pg:
        try:
            remove_placement_group(self._head_pg)
        except Exception as e:
            logger.warning(f"Failed to remove head placement group (ignored): {e}")
    self._head_pg = None
    self._head_pg_target = 0
def drain_actor_pool(self) -> None:
    """Drain the actor pool, then release the head placement group.

    Actor shutdown and pool clearing are delegated to the parent class's
    ``drain_actor_pool`` (described upstream as a graceful shutdown followed
    by a forced kill). Afterwards the head placement group is destroyed via
    ``_destroy_head_pg``.
    """
    # Order matters: stop the actors before tearing down the placement group.
    super().drain_actor_pool()
    self._destroy_head_pg()