Source code for eformer.paths

# Copyright 2025 The EasyDeL/eFormer Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Universal path utilities for local and cloud storage.

Provides a unified API for working with paths across different storage
backends including local filesystem and Google Cloud Storage (GCS).

Classes:
    UniversalPath: Abstract base class for path operations
    LocalPath: Local filesystem path implementation
    GCSPath: Google Cloud Storage path implementation
    PathManager: Factory for creating appropriate path objects
    MLUtilPath: Extended path manager with ML utilities

Key Features:
    - Unified API for local and cloud storage
    - Transparent switching between storage backends
    - Support for JAX array and dictionary I/O
    - Recursive directory operations
    - Path manipulation and traversal

Example:
    >>> from eformer.paths import ePath
    >>>
    >>> # Local filesystem
    >>> local_path = ePath("data/model.pkl")
    >>> local_path.write_bytes(data)
    >>>
    >>> # Google Cloud Storage
    >>> gcs_path = ePath("gs://bucket/model.pkl")
    >>> gcs_path.write_bytes(data)
    >>>
    >>> # JAX array round-trip
    >>> ePath.save_jax_array(array, "gs://bucket/weights.npy")
    >>> loaded = ePath.load_jax_array("gs://bucket/weights.npy")

"""

import io
import json
import os
import pickle
import typing as tp
from abc import ABC, abstractmethod
from collections.abc import Iterator
from pathlib import Path
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np
from google.cloud import storage


class UniversalPath(ABC):
    """Storage-agnostic path interface.

    Every storage backend subclasses this ABC and implements all of the
    abstract methods below. Method names and semantics follow
    :class:`pathlib.Path` wherever practical, so code written against this
    interface reads like ordinary ``pathlib`` code.
    """

    # ----------------------------------------------------------------- queries

    @abstractmethod
    def exists(self) -> bool:
        """Tell whether something exists at this path.

        Returns:
            ``True`` if the path exists, ``False`` otherwise.
        """

    # ----------------------------------------------------------- content I/O

    @abstractmethod
    def read_text(self, encoding: str = "utf-8") -> str:
        """Return the file contents decoded as text.

        Args:
            encoding: Codec used to decode the stored bytes.

        Returns:
            The decoded text.

        Raises:
            FileNotFoundError: The path does not exist.
            ValueError: The path refers to a directory.
        """

    @abstractmethod
    def write_text(self, data: str, encoding: str = "utf-8") -> None:
        """Store ``data`` at this path as encoded text.

        Args:
            data: Text to store.
            encoding: Codec used to encode ``data``.

        Raises:
            ValueError: The path refers to a directory.
        """

    @abstractmethod
    def read_bytes(self) -> bytes:
        """Return the raw contents of the file.

        Returns:
            The stored bytes.

        Raises:
            FileNotFoundError: The path does not exist.
            ValueError: The path refers to a directory.
        """

    @abstractmethod
    def write_bytes(self, data: bytes) -> None:
        """Store raw ``data`` at this path.

        Args:
            data: Bytes to store.

        Raises:
            ValueError: The path refers to a directory.
        """

    # ------------------------------------------------------------ directories

    @abstractmethod
    def mkdir(self, parents: bool = True, exist_ok: bool = True) -> None:
        """Create a directory at this path.

        Args:
            parents: Also create any missing parent directories.
            exist_ok: Do not complain if the directory already exists.

        Raises:
            FileExistsError: ``exist_ok`` is False and the path already exists.
        """

    @abstractmethod
    def is_dir(self) -> bool:
        """Tell whether this path is a directory.

        Returns:
            ``True`` for a directory, ``False`` otherwise.
        """

    @abstractmethod
    def is_file(self) -> bool:
        """Tell whether this path is a regular file.

        Returns:
            ``True`` for a file, ``False`` otherwise.
        """

    @abstractmethod
    def iterdir(self) -> Iterator["UniversalPath"]:
        """Yield one path object per entry of this directory.

        Yields:
            A UniversalPath for every child entry.

        Raises:
            NotADirectoryError: This path is not a directory.
        """

    @abstractmethod
    def glob(self, pattern: str, recursive: bool = False) -> Iterator["UniversalPath"]:
        """Yield paths below this one that match a glob pattern.

        Args:
            pattern: Glob expression such as ``"*.txt"`` or ``"**/*.py"``.
            recursive: Descend into subdirectories when True.

        Yields:
            A UniversalPath for every match.
        """

    # ------------------------------------------------------ path manipulation

    @abstractmethod
    def __truediv__(self, other) -> "UniversalPath":
        """Append a path component with the ``/`` operator.

        Args:
            other: Component to append.

        Returns:
            A new path combining this path and ``other``.
        """

    @abstractmethod
    def __str__(self) -> str:
        """Render the path as a string.

        Returns:
            The string form of the path.
        """

    @abstractmethod
    def as_posix(self) -> str:
        """Render the path using forward slashes as separators.

        Returns:
            The POSIX-style path string.
        """

    @abstractmethod
    def stem(self) -> str:
        """Return the last path component minus its final suffix.

        Returns:
            The stem of the last component.

        Example:
            >>> LocalPath("/data/model.tar.gz").stem()
            'model.tar'
        """

    @abstractmethod
    def suffixes(self) -> list[str]:
        """Return every suffix of the last path component, in order.

        Returns:
            Suffix strings, each with its leading dot.

        Example:
            >>> LocalPath("/data/model.tar.gz").suffixes()
            ['.tar', '.gz']
        """

    @abstractmethod
    def with_name(self, name: str) -> "UniversalPath":
        """Return a copy of this path whose last component is ``name``.

        Args:
            name: Replacement for the final component.

        Returns:
            The renamed path.
        """

    @abstractmethod
    def with_suffix(self, suffix: str) -> "UniversalPath":
        """Return a copy of this path with its suffix replaced.

        Args:
            suffix: Replacement suffix, leading dot included.

        Returns:
            The path with the new suffix.
        """

    @abstractmethod
    def with_stem(self, stem: str) -> "UniversalPath":
        """Return a copy of this path with its stem replaced.

        Args:
            stem: Replacement stem for the final component.

        Returns:
            The path with the new stem.
        """

    @abstractmethod
    def parts(self) -> tuple[str, ...]:
        """Split the path into its components.

        Returns:
            A tuple holding each component in order.
        """

    @abstractmethod
    def relative_to(self, other: "UniversalPath") -> "UniversalPath":
        """Express this path relative to ``other``.

        Args:
            other: The base path.

        Returns:
            The path leading from ``other`` to this path.

        Raises:
            ValueError: This path is not relative to ``other``.
        """

    @abstractmethod
    def is_absolute(self) -> bool:
        """Tell whether the path is absolute.

        Returns:
            ``True`` for an absolute path, ``False`` otherwise.
        """

    @abstractmethod
    def resolve(self) -> "UniversalPath":
        """Return the absolute form of this path with symlinks resolved.

        Returns:
            The resolved absolute path.
        """

    # --------------------------------------------------------- mutations/meta

    @abstractmethod
    def rmdir(self) -> None:
        """Delete this directory, which must be empty.

        Raises:
            OSError: The directory still has entries.
            NotADirectoryError: This path is not a directory.
        """

    @abstractmethod
    def rename(self, target: "UniversalPath") -> "UniversalPath":
        """Move this path to ``target``.

        Args:
            target: The destination path.

        Returns:
            A path object for the destination.
        """

    @abstractmethod
    def stat(self) -> dict[str, Any]:
        """Collect metadata for this path.

        Returns:
            A dict of metadata such as size and modification time.

        Raises:
            FileNotFoundError: The path does not exist.
        """
class LocalPath(UniversalPath):
    """Local filesystem path implementation.

    Wraps :class:`pathlib.Path` to provide the UniversalPath interface for
    local filesystem operations.

    Attributes:
        path: The underlying pathlib.Path object.

    Example:
        >>> path = LocalPath("/data/model.pkl")
        >>> path.parent
        LocalPath('/data')
        >>> (path.parent / "config.json").write_text(config)
    """

    def __init__(self, path: str | Path):
        """Initialize LocalPath.

        Args:
            path: Path string, pathlib.Path, or any os.PathLike object
                (including another LocalPath).
        """
        self.path = Path(path)

    def exists(self) -> bool:
        """Return True if the path exists on the local filesystem."""
        return self.path.exists()

    def read_text(self, encoding: str = "utf-8") -> str:
        """Read the file as text using ``encoding``."""
        return self.path.read_text(encoding=encoding)

    def write_text(self, data: str, encoding: str = "utf-8") -> None:
        """Write ``data`` as text, creating missing parent directories first."""
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self.path.write_text(data, encoding=encoding)

    def read_bytes(self) -> bytes:
        """Read the file's raw bytes."""
        return self.path.read_bytes()

    def write_bytes(self, data: bytes) -> None:
        """Write raw ``data``, creating missing parent directories first."""
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self.path.write_bytes(data)

    def mkdir(self, parents: bool = True, exist_ok: bool = True) -> None:
        """Create this directory.

        Args:
            parents: Create missing parent directories.
            exist_ok: Do not raise if the directory already exists.
        """
        # NOTE(review): paths under /dev/null are skipped on purpose --
        # presumably used as a discard destination, where mkdir would fail.
        if not str(self.path).startswith("/dev/null"):
            self.path.mkdir(parents=parents, exist_ok=exist_ok)

    def is_dir(self) -> bool:
        """Return True if the path is an existing directory."""
        return self.path.is_dir()

    def is_file(self) -> bool:
        """Return True if the path is an existing regular file."""
        return self.path.is_file()

    def iterdir(self) -> Iterator["LocalPath"]:
        """Yield the entries of this directory.

        Yields nothing (rather than raising) when the path is not a directory.
        """
        if self.path.is_dir():
            for item in self.path.iterdir():
                yield LocalPath(item)

    def glob(self, pattern: str, recursive: bool = False) -> Iterator["LocalPath"]:
        """Yield paths under this directory that match ``pattern``.

        Args:
            pattern: Glob pattern, e.g. ``"*.txt"`` or ``"**/*.py"``.
            recursive: When False, ``"**"`` is downgraded to a single-level
                ``"*"``. When True, a pattern without ``"**"`` is matched at
                every depth (like :meth:`pathlib.Path.rglob`).

        Yields:
            LocalPath objects for every match.
        """
        if not recursive:
            matches = self.path.glob(pattern.replace("**", "*"))
        elif "**" in pattern:
            # The pattern already spells out where to recurse.
            matches = self.path.glob(pattern)
        else:
            # Plain pattern such as "*.txt": a bare glob() would only look at
            # the top level, so search every subdirectory explicitly.
            matches = self.path.rglob(pattern)
        for item in matches:
            yield LocalPath(item)

    def __truediv__(self, other) -> "LocalPath":
        return LocalPath(self.path / str(other))

    def __str__(self) -> str:
        return str(self.path)

    def __repr__(self) -> str:
        return f"LocalPath('{self.path}')"

    def __fspath__(self) -> str:
        """Allow LocalPath wherever an ``os.PathLike`` is accepted (open, Path, ...)."""
        return str(self.path)

    @property
    def name(self) -> str:
        """Final path component."""
        return self.path.name

    @property
    def suffix(self) -> str:
        """Final suffix of the last component (e.g. ``'.gz'``)."""
        return self.path.suffix

    @property
    def parent(self) -> "LocalPath":
        """The logical parent directory."""
        return LocalPath(self.path.parent)

    def as_posix(self) -> str:
        return self.path.as_posix()

    def stem(self) -> str:
        return self.path.stem

    def suffixes(self) -> list[str]:
        return self.path.suffixes

    def with_name(self, name: str) -> "LocalPath":
        return LocalPath(self.path.with_name(name))

    def with_suffix(self, suffix: str) -> "LocalPath":
        return LocalPath(self.path.with_suffix(suffix))

    def with_stem(self, stem: str) -> "LocalPath":
        return LocalPath(self.path.with_stem(stem))

    def parts(self) -> tuple[str, ...]:
        return self.path.parts

    def relative_to(self, other: "LocalPath") -> "LocalPath":
        """Return this path relative to ``other`` (a LocalPath or anything str()-able).

        Raises:
            ValueError: If this path is not under ``other`` (from pathlib).
        """
        base = other.path if isinstance(other, LocalPath) else Path(str(other))
        return LocalPath(self.path.relative_to(base))

    def is_absolute(self) -> bool:
        return self.path.is_absolute()

    def resolve(self) -> "LocalPath":
        return LocalPath(self.path.resolve())

    def rmdir(self) -> None:
        """Remove this (empty) directory."""
        self.path.rmdir()

    def unlink(self, missing_ok: bool = False) -> None:
        """Remove this file.

        Args:
            missing_ok: If True, do nothing when the file does not exist.

        Raises:
            FileNotFoundError: If the file is missing and ``missing_ok`` is False.
        """
        self.path.unlink(missing_ok=missing_ok)

    def rename(self, target: "LocalPath") -> "LocalPath":
        """Rename this path to ``target`` and return the new location."""
        destination = target.path if isinstance(target, LocalPath) else Path(str(target))
        return LocalPath(self.path.rename(destination))

    def stat(self) -> dict[str, Any]:
        """Return size, timestamps, mode and ownership from ``os.stat``."""
        stat_result = self.path.stat()
        return {
            "size": stat_result.st_size,
            "mtime": stat_result.st_mtime,
            "ctime": stat_result.st_ctime,
            "atime": stat_result.st_atime,
            "mode": stat_result.st_mode,
            "uid": stat_result.st_uid,
            "gid": stat_result.st_gid,
        }
class GCSPath(UniversalPath):
    """Google Cloud Storage path implementation.

    Provides the UniversalPath interface for Google Cloud Storage. GCS has no
    real directories; a "directory" is a blob-name prefix ending in ``/``, and
    :meth:`mkdir` materializes one by writing an empty ``.keep`` placeholder.

    Attributes:
        path: Full GCS path string (gs://bucket/path).
        client: Google Cloud Storage client.
        bucket_name: Name of the GCS bucket.
        blob_name: Path within the bucket ("" for the bucket root).

    Example:
        >>> path = GCSPath("gs://my-bucket/data/model.pkl")
        >>> path.write_bytes(model_bytes)
        >>> for item in path.parent.iterdir():
        ...     print(item.name)
    """

    def __init__(self, path: str, client: storage.Client | None = None):
        """Initialize GCSPath.

        Args:
            path: GCS path starting with gs://.
            client: Optional GCS client; a default client is created if None.

        Raises:
            ValueError: If path doesn't start with gs://.
        """
        if not path.startswith("gs://"):
            raise ValueError(f"GCS path must start with 'gs://': {path}")
        self.path = path
        self.client = client or storage.Client()
        path_parts = path[5:].split("/", 1)
        self.bucket_name = path_parts[0]
        self.blob_name = path_parts[1] if len(path_parts) > 1 else ""
        self._bucket = None
        self._blob = None

    @property
    def bucket(self):
        """Lazily created bucket handle."""
        if self._bucket is None:
            self._bucket = self.client.bucket(self.bucket_name)
        return self._bucket

    @property
    def blob(self):
        """Lazily created blob handle, or None for the bucket root."""
        if self._blob is None and self.blob_name:
            self._blob = self.bucket.blob(self.blob_name)
        return self._blob

    def _dir_prefix(self) -> str:
        """Blob-name prefix of this path viewed as a directory ("" for the bucket root)."""
        if not self.blob_name:
            return ""
        return self.blob_name if self.blob_name.endswith("/") else self.blob_name + "/"

    def exists(self) -> bool:
        """Check the bucket (for the root) or a blob with exactly this name."""
        if not self.blob_name:
            return self.bucket.exists()
        return self.blob.exists() if self.blob else False

    def read_text(self, encoding: str = "utf-8") -> str:
        if not self.blob:
            raise ValueError("Cannot read text from bucket root")
        return self.blob.download_as_text(encoding=encoding)

    def write_text(self, data: str, encoding: str = "utf-8") -> None:
        if not self.blob:
            raise ValueError("Cannot write text to bucket root")
        # Encode explicitly: upload_from_string() always UTF-8-encodes a str,
        # which would silently ignore ``encoding``.
        self.blob.upload_from_string(data.encode(encoding), content_type="text/plain")

    def read_bytes(self) -> bytes:
        if not self.blob:
            raise ValueError("Cannot read bytes from bucket root")
        return self.blob.download_as_bytes()

    def write_bytes(self, data: bytes) -> None:
        if not self.blob:
            raise ValueError("Cannot write bytes to bucket root")
        self.blob.upload_from_string(data, content_type="application/octet-stream")

    def mkdir(self, parents: bool = True, exist_ok: bool = True) -> None:
        """Create a ``.keep`` placeholder blob so the prefix shows up as a directory.

        ``parents`` and ``exist_ok`` are accepted for interface compatibility
        but ignored. The bucket root needs no placeholder.
        """
        directory = self._dir_prefix()
        if not directory:
            return
        placeholder_blob = self.bucket.blob(directory + ".keep")
        if not placeholder_blob.exists():
            placeholder_blob.upload_from_string("", content_type="text/plain")

    def is_dir(self) -> bool:
        """True for the bucket root or when at least one blob lives under this prefix."""
        if not self.blob_name:
            return True
        blobs = list(self.bucket.list_blobs(prefix=self._dir_prefix(), max_results=1))
        return len(blobs) > 0

    def is_file(self) -> bool:
        return self.exists() and not self.is_dir()

    def iterdir(self) -> Iterator["GCSPath"]:
        """Yield immediate children: blobs first, then sub-prefixes, page by page."""
        prefix = self._dir_prefix()
        for page in self.client.list_blobs(self.bucket_name, prefix=prefix, delimiter="/").pages:
            for blob in page:
                # Skip a marker blob named exactly like the directory itself.
                if blob.name != prefix:
                    yield GCSPath(f"gs://{self.bucket_name}/{blob.name}", self.client)
            for sub_prefix in page.prefixes:
                yield GCSPath(f"gs://{self.bucket_name}/{sub_prefix}", self.client)

    def glob(self, pattern: str, recursive: bool = False) -> Iterator["GCSPath"]:
        """Yield paths under this prefix that match ``pattern`` (fnmatch syntax).

        Args:
            pattern: Glob pattern relative to this path.
            recursive: When True, match the pattern against every blob and
                intermediate directory name below this prefix (``*`` may cross
                ``/``). When False, match segment by segment and treat ``**``
                as ``*``.
        """
        prefix = self._dir_prefix()
        if recursive:
            yield from self._glob_recursive(prefix, pattern)
        else:
            yield from self._glob_segments(prefix, pattern)

    def _glob_recursive(self, prefix: str, pattern: str) -> Iterator["GCSPath"]:
        """Match ``pattern`` against all file and intermediate directory names under ``prefix``."""
        import fnmatch

        relative_names: set[str] = set()
        for blob in self.bucket.list_blobs(prefix=prefix):
            parts = blob.name[len(prefix) :].split("/")
            for depth in range(1, len(parts) + 1):
                rel_name = "/".join(parts[:depth])
                relative_names.add(rel_name)
                if depth < len(parts):
                    # Intermediate component: also register its directory form.
                    relative_names.add(rel_name + "/")
        # Sorted for deterministic output order.
        for relative_name in sorted(relative_names):
            if fnmatch.fnmatch(relative_name, pattern):
                yield GCSPath(f"gs://{self.bucket_name}/{prefix}{relative_name}", self.client)

    def _glob_segments(self, prefix: str, pattern: str) -> Iterator["GCSPath"]:
        """Match ``pattern`` one path segment at a time, descending via iterdir()."""
        import fnmatch

        pattern = pattern.replace("**", "*")
        dirs_only = pattern.endswith("/")
        sub_patterns = pattern.rstrip("/").split("/")
        listing_cache: dict[str, list[GCSPath]] = {}
        stack = [(prefix, sub_patterns)]
        while stack:
            current_prefix, patterns = stack.pop()
            if not patterns:
                continue
            current_pattern, remaining = patterns[0], patterns[1:]
            if current_prefix not in listing_cache:
                listing_cache[current_prefix] = list(
                    GCSPath(f"gs://{self.bucket_name}/{current_prefix}", self.client).iterdir()
                )
            for candidate in listing_cache[current_prefix]:
                blob_name = candidate.blob_name
                # Directory entries end with "/"; strip it so "sub" matches "sub/".
                relative_name = blob_name[len(current_prefix) :].rstrip("/")
                if not fnmatch.fnmatch(relative_name, current_pattern):
                    continue
                if not remaining:
                    if not dirs_only or candidate.is_dir():
                        yield candidate
                elif blob_name.endswith("/"):
                    stack.append((blob_name, remaining))

    def __truediv__(self, other) -> "GCSPath":
        other = str(other)
        if self.blob_name:
            new_path = f"{self.blob_name.rstrip('/')}/{other}"
        else:
            new_path = other
        return GCSPath(f"gs://{self.bucket_name}/{new_path}", self.client)

    def __str__(self) -> str:
        return self.path

    def __repr__(self) -> str:
        return f"GCSPath('{self.path}')"

    @property
    def name(self) -> str:
        """Last path component (the bucket name for the bucket root)."""
        if not self.blob_name:
            return self.bucket_name
        return os.path.basename(self.blob_name.rstrip("/"))

    @property
    def suffix(self) -> str:
        name = self.name
        return os.path.splitext(name)[1] if "." in name else ""

    @property
    def parent(self) -> "GCSPath":
        """Parent directory, always returned with a trailing slash.

        Raises:
            ValueError: For the bucket root.
        """
        if not self.blob_name:
            raise ValueError("Bucket root has no parent")
        parent_blob = os.path.dirname(self.blob_name.rstrip("/"))
        if parent_blob:
            return GCSPath(f"gs://{self.bucket_name}/{parent_blob}/", self.client)
        return GCSPath(f"gs://{self.bucket_name}/", self.client)

    def as_posix(self) -> str:
        return self.path

    def stem(self) -> str:
        name = self.name
        return os.path.splitext(name)[0] if "." in name else name

    def suffixes(self) -> list[str]:
        """Suffixes of the last component, following pathlib's rules.

        A name ending in "." has no suffixes, and leading dots (dotfiles such
        as ``.bashrc``) are not treated as suffixes.
        """
        name = self.name
        if name.endswith("."):
            return []
        name = name.lstrip(".")
        return ["." + part for part in name.split(".")[1:]]

    def with_name(self, name: str) -> "GCSPath":
        if not self.blob_name:
            raise ValueError("Cannot change name of bucket root")
        parent_path = os.path.dirname(self.blob_name.rstrip("/"))
        new_blob = f"{parent_path}/{name}" if parent_path else name
        return GCSPath(f"gs://{self.bucket_name}/{new_blob}", self.client)

    def with_suffix(self, suffix: str) -> "GCSPath":
        return self.with_name(self.stem() + suffix)

    def with_stem(self, stem: str) -> "GCSPath":
        return self.with_name(stem + self.suffix)

    def parts(self) -> tuple[str, ...]:
        parts = ["gs://", self.bucket_name]
        if self.blob_name:
            parts.extend(self.blob_name.strip("/").split("/"))
        return tuple(parts)

    def relative_to(self, other: "GCSPath") -> "GCSPath":
        """Return a ``..``-style path from ``other`` to this path within the same bucket.

        NOTE(review): unlike pathlib, this does not raise when ``self`` is not
        under ``other``; it emits ``..`` components instead.

        Raises:
            TypeError: If ``other`` is not a GCSPath.
            ValueError: If the paths live in different buckets.
        """
        if not isinstance(other, GCSPath):
            raise TypeError("Can only compute relative path to another GCSPath")
        if self.bucket_name != other.bucket_name:
            raise ValueError("Cannot compute relative path across different buckets")
        if not other.blob_name:
            return GCSPath(f"gs://{self.bucket_name}/{self.blob_name}", self.client)
        self_parts = self.blob_name.strip("/").split("/")
        other_parts = other.blob_name.strip("/").split("/")
        common_len = 0
        for a, b in zip(self_parts, other_parts, strict=False):
            if a != b:
                break
            common_len += 1
        relative_parts = [".."] * (len(other_parts) - common_len) + self_parts[common_len:]
        relative_blob = "/".join(relative_parts) if relative_parts else "."
        return GCSPath(f"gs://{self.bucket_name}/{relative_blob}", self.client)

    def is_absolute(self) -> bool:
        return True

    def resolve(self) -> "GCSPath":
        return self

    def rmdir(self) -> None:
        """Remove an empty directory by deleting its ``.keep`` placeholder.

        Raises:
            NotADirectoryError: If the path is not a directory.
            OSError: If the directory holds anything besides the placeholder.
        """
        if not self.is_dir():
            raise NotADirectoryError(f"'{self.path}' is not a directory")
        # The placeholder written by mkdir() is not real content; counting it
        # would make every mkdir()-created directory impossible to remove.
        items = [item for item in self.iterdir() if item.name != ".keep"]
        if items:
            raise OSError(f"Directory not empty: '{self.path}'")
        if self.blob_name:
            keep_blob = self.bucket.blob(self._dir_prefix() + ".keep")
            if keep_blob.exists():
                keep_blob.delete()

    def unlink(self, missing_ok: bool = False) -> None:
        """Delete the blob at this path.

        Args:
            missing_ok: If True, silently ignore a missing blob.

        Raises:
            ValueError: For the bucket root.
            FileNotFoundError: If the blob is missing and ``missing_ok`` is False.
        """
        if not self.blob:
            raise ValueError("Cannot unlink bucket root")
        if not self.blob.exists():
            if missing_ok:
                return
            raise FileNotFoundError(f"'{self.path}' does not exist")
        self.blob.delete()

    def rename(self, target: "GCSPath") -> "GCSPath":
        """Copy this blob's bytes to ``target`` and then delete this blob.

        Raises:
            TypeError: If ``target`` is not a GCSPath.
            ValueError: If this path is the bucket root.
        """
        if not isinstance(target, GCSPath):
            raise TypeError("Target must be a GCSPath")
        if not self.blob:
            raise ValueError("Cannot rename bucket root")
        target.write_bytes(self.read_bytes())
        self.unlink()
        return target

    def stat(self) -> dict[str, Any]:
        """Return blob metadata (size, timestamps, etag, content type, generation).

        ``atime`` mirrors ``mtime`` because the code reads both from ``blob.updated``.
        """
        if not self.blob:
            raise ValueError("Cannot get stats for bucket root")
        if not self.blob.exists():
            raise FileNotFoundError(f"'{self.path}' does not exist")
        self.blob.reload()
        return {
            "size": self.blob.size or 0,
            "mtime": self.blob.updated.timestamp() if self.blob.updated else 0,
            "ctime": self.blob.time_created.timestamp() if self.blob.time_created else 0,
            "atime": self.blob.updated.timestamp() if self.blob.updated else 0,
            "etag": self.blob.etag,
            "content_type": self.blob.content_type,
            "generation": self.blob.generation,
        }
class PathManager:
    """Factory for creating appropriate path objects.

    Creates a LocalPath or GCSPath depending on the path string, and lazily
    builds (and caches) the GCS client used for ``gs://`` paths.

    Attributes:
        gcs_client: Cached GCS client instance (created on first access).

    Example:
        >>> manager = PathManager()
        >>> isinstance(manager("/data/file.txt"), LocalPath)
        True
        >>> isinstance(manager("gs://bucket/file.txt"), GCSPath)
        True
    """

    def __init__(
        self,
        gcs_client: storage.Client | None = None,
        gcs_credentials_path: str | None = None,
    ):
        """Initialize PathManager.

        Args:
            gcs_client: Optional pre-configured GCS client.
            gcs_credentials_path: Path to a GCS service-account JSON file used
                when the client has to be created.
        """
        self._gcs_client = gcs_client
        self._gcs_credentials_path = gcs_credentials_path

    @property
    def gcs_client(self):
        """Return the cached GCS client, creating it on first use.

        Uses the service-account file if one was configured, otherwise the
        default credentials. Creation is best-effort: on failure a
        RuntimeWarning is emitted and None is returned, and creation is
        retried on the next access.
        """
        if self._gcs_client is None:
            try:
                if self._gcs_credentials_path:
                    from google.oauth2 import service_account

                    credentials = service_account.Credentials.from_service_account_file(
                        self._gcs_credentials_path
                    )
                    self._gcs_client = storage.Client(credentials=credentials)
                else:
                    self._gcs_client = storage.Client()
            except Exception as err:
                # Keep the best-effort contract (return None) but surface the
                # failure instead of swallowing it silently -- e.g. a wrong
                # credentials path previously left no trace at all.
                import warnings

                warnings.warn(
                    f"Could not create Google Cloud Storage client: {err!r}",
                    RuntimeWarning,
                    stacklevel=2,
                )
        return self._gcs_client

    def __call__(self, path: str | Path) -> UniversalPath:
        """Create the path object matching the path string.

        Args:
            path: Path string or any object whose ``str()`` is a path.

        Returns:
            GCSPath for ``gs://`` paths, LocalPath otherwise.
        """
        path_str = str(path)
        if path_str.startswith("gs://"):
            return GCSPath(path_str, self.gcs_client)
        return LocalPath(path_str)
class MLUtilPath(PathManager):
    """Extended path manager with ML-specific utilities.

    Adds JAX array and dictionary I/O on top of PathManager, working for both
    local and GCS destinations.

    Example:
        >>> path_manager = MLUtilPath()
        >>> path_manager.save_jax_array(array, "gs://bucket/weights.npy")
        >>> loaded = path_manager.load_jax_array("gs://bucket/weights.npy")
        >>> path_manager.save_dict({"weights": weights}, "config.json")
    """

    def save_jax_array(self, array: jax.Array, path: str | UniversalPath, format: str = "npy") -> None:  # noqa:A002
        """Save a JAX array.

        Args:
            array: JAX array to save.
            path: Destination path (local or GCS).
            format: ``'npy'`` (NumPy .npy bytes) or ``'pickle'``.

        Raises:
            ValueError: If format is not supported.

        Example:
            >>> manager.save_jax_array(weights, "weights.npy")
            >>> manager.save_jax_array(biases, "gs://bucket/biases.pkl", "pickle")
        """
        if isinstance(path, str):
            path = self(path)
        buffer = io.BytesIO()
        if format == "npy":
            np.save(buffer, np.array(array))
        elif format == "pickle":
            pickle.dump(array, buffer)
        else:
            raise ValueError(f"Unsupported format: {format}")
        path.write_bytes(buffer.getvalue())

    def load_jax_array(self, path: str | UniversalPath, format: str = "npy") -> jax.Array:  # noqa:A002
        """Load a JAX array saved by :meth:`save_jax_array`.

        Args:
            path: Source path (local or GCS).
            format: ``'npy'`` or ``'pickle'``.

        Returns:
            The loaded array (``'npy'`` is converted to a jnp array; ``'pickle'``
            returns whatever object was pickled).

        Raises:
            ValueError: If format is not supported (checked before any I/O).
            FileNotFoundError: If the path doesn't exist.
        """
        if format not in ("npy", "pickle"):
            raise ValueError(f"Unsupported format: {format}")
        if isinstance(path, str):
            path = self(path)
        buffer = io.BytesIO(path.read_bytes())
        if format == "npy":
            return jnp.array(np.load(buffer))
        # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
        return pickle.load(buffer)

    def save_dict(self, data: dict[str, Any], path: str | UniversalPath, format: str = "json") -> None:  # noqa:A002
        """Save a dictionary.

        Args:
            data: Dictionary to save. For JSON, JAX/NumPy arrays and NumPy
                scalars are converted to plain Python values first.
            path: Destination path (local or GCS).
            format: ``'json'`` or ``'pickle'``.

        Raises:
            ValueError: If format is not supported.

        Example:
            >>> manager.save_dict({"weights": [1, 2, 3]}, "config.json")
            >>> manager.save_dict(complex_data, "gs://bucket/data.pkl", "pickle")
        """
        if isinstance(path, str):
            path = self(path)
        if format == "json":
            serializable_data = self._make_json_serializable(data)
            path.write_text(json.dumps(serializable_data, indent=2))
        elif format == "pickle":
            buffer = io.BytesIO()
            pickle.dump(data, buffer)
            path.write_bytes(buffer.getvalue())
        else:
            raise ValueError(f"Unsupported format: {format}")

    def load_dict(self, path: str | UniversalPath, format: str = "json") -> dict[str, Any]:  # noqa:A002
        """Load a dictionary saved by :meth:`save_dict`.

        Args:
            path: Source path (local or GCS).
            format: ``'json'`` or ``'pickle'``.

        Returns:
            The loaded dictionary.

        Raises:
            ValueError: If format is not supported (checked before any I/O).
            FileNotFoundError: If the path doesn't exist.
        """
        if format not in ("json", "pickle"):
            raise ValueError(f"Unsupported format: {format}")
        if isinstance(path, str):
            path = self(path)
        if format == "json":
            return json.loads(path.read_text())
        # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
        return pickle.load(io.BytesIO(path.read_bytes()))

    def _make_json_serializable(self, obj):
        """Recursively convert arrays and NumPy scalars into JSON-safe Python values.

        Args:
            obj: A dict, list, tuple, array, scalar or any other object.

        Returns:
            Arrays become (nested) lists, tuples become lists, NumPy bool/int/
            float scalars become Python scalars; other objects pass through.
        """
        if isinstance(obj, jax.Array | np.ndarray):
            return obj.tolist()
        if isinstance(obj, dict):
            return {k: self._make_json_serializable(v) for k, v in obj.items()}
        if isinstance(obj, list | tuple):
            return [self._make_json_serializable(item) for item in obj]
        # np.bool_ is not an np.integer subclass and json.dumps rejects it.
        if isinstance(obj, np.integer | np.floating | np.bool_):
            return obj.item()
        return obj

    def copy_tree(self, src: str | UniversalPath, dst: str | UniversalPath) -> None:
        """Copy a file or a whole directory tree, across backends if needed.

        A source for which ``is_file()`` is False is treated as a directory:
        ``dst`` is created and every child from ``src.iterdir()`` is copied
        recursively to ``dst / child.name``.

        Args:
            src: Source path (directory or file).
            dst: Destination path.

        Example:
            >>> manager.copy_tree("data/", "gs://bucket/data/")
            >>> manager.copy_tree("gs://bucket/model/", "local_model/")
        """
        if isinstance(src, str):
            src = self(src)
        if isinstance(dst, str):
            dst = self(dst)
        if src.is_file():
            dst.write_bytes(src.read_bytes())
            return
        dst.mkdir(parents=True, exist_ok=True)
        for item in src.iterdir():
            self.copy_tree(item, dst / item.name)
# Module-wide default path manager. An optional service-account file for GCS
# is taken from the EASYDEL_GCS_CLIENT environment variable; when it is unset
# the manager falls back to default GCS credentials.
ePath: MLUtilPath = MLUtilPath(gcs_credentials_path=os.environ.get("EASYDEL_GCS_CLIENT"))

# Annotation alias covering the concrete path classes and the path manager.
ePathLike: tp.TypeAlias = GCSPath | LocalPath | MLUtilPath