eformer.serialization.async_manager#

class eformer.serialization.async_manager.AsyncCheckpointManager(enable: bool | None = None, float_dtype: numpy.dtype = <class 'jax.numpy.bfloat16'>, verbose: bool = False, gcs_bucket: str | None = None, gcs_credentials_path: str | None = None, enable_validation: bool = False, enable_compression: bool = False, use_tensorstore: bool = True)[source]#

Bases: object

Checkpoint manager with concurrent operations.

This manager saves and loads checkpoints with optional parallel I/O, checksum validation, and compression. It supports two formats: TensorStore (for large-scale distributed checkpoints) and SafeTensors (for smaller, single-file checkpoints).

Key Features:
  • Automatic format detection (TensorStore vs SafeTensors)

  • Parallel I/O operations for faster loading/saving

  • CPU offloading to prevent OOM on accelerators

  • Checksum validation for data integrity

  • Support for sharded checkpoints across multiple files

  • Pattern-based partition rules with preserved ordering

float_dtype#

Default data type for floating point arrays.

enable#

Whether checkpointing is enabled.

verbose#

Enable verbose output.

gcs_bucket#

Google Cloud Storage bucket name.

enable_validation#

Enable checksum validation.

enable_compression#

Enable compression for tensorstore.

use_tensorstore#

Use tensorstore backend when available.
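
The description of float_dtype as the default type for floating point arrays suggests that only floating-point leaves are cast, while integer leaves keep their dtype. A minimal sketch of that behaviour follows, assuming NumPy arrays and using np.float16 as a stand-in for the bfloat16 default. cast_floats is a hypothetical helper for illustration, not part of eformer.

```python
import numpy as np


def cast_floats(arrays: dict, float_dtype=np.float16) -> dict:
    """Cast floating-point arrays to float_dtype; leave other dtypes unchanged.

    Hypothetical helper: the real manager may apply the cast differently.
    """
    out = {}
    for key, value in arrays.items():
        value = np.asarray(value)
        if np.issubdtype(value.dtype, np.floating):
            value = value.astype(float_dtype)
        out[key] = value
    return out


state = {
    "kernel": np.ones((2, 2), dtype=np.float32),  # floating point: will be cast
    "step": np.array(3, dtype=np.int32),          # integer: left untouched
}
cast = cast_floats(state)
```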

Example

>>> manager = AsyncCheckpointManager(
...     enable_validation=True,
...     use_tensorstore=True
... )
>>>
>>> manager.save(model_state, "checkpoint", mesh=mesh)
>>>
>>> rules = [(".*kernel", PartitionSpec("model", None))]
>>> state, meta = manager.load("checkpoint", mesh, partition_rules=rules)

static compute_checksum(array: Array) → str[source]#

Compute SHA256 checksum for validation.

Converts array to bytes and computes SHA256 hash for data integrity verification.

Parameters

array – JAX array to compute checksum for.

Returns

SHA256 checksum as hexadecimal string.

Note

Arrays are converted to numpy before hashing for consistency.
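
The steps described above (convert to NumPy, take the bytes, hash with SHA256) can be sketched as follows. sha256_of_array is a hypothetical stand-in written for illustration; whether the real method also mixes shape or dtype into the hash is not stated here, so this sketch hashes the raw bytes only.

```python
import hashlib

import numpy as np


def sha256_of_array(array) -> str:
    """Return the SHA256 hex digest of an array's raw bytes (illustrative sketch)."""
    # Convert to a contiguous host-side NumPy array so tobytes() is well defined.
    host = np.ascontiguousarray(np.asarray(array))
    return hashlib.sha256(host.tobytes()).hexdigest()


a = np.arange(6, dtype=np.float32).reshape(2, 3)
digest = sha256_of_array(a)
```

Because only the raw bytes are hashed in this sketch, two arrays with identical bytes but different shapes produce the same digest. Shape and dtype would have to be checked separately, for example through the array_metadata field of CheckpointMetadata.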

property global_manager: GlobalAsyncCheckpointManager#

Get or create the global async checkpoint manager.

Returns

The singleton manager instance.

Return type

GlobalAsyncCheckpointManager

static is_tensorstore(path) → bool[source]#

Check if a checkpoint path uses TensorStore format.

Parameters

path – Path to check for TensorStore format.

Returns

True if the path contains or points to a TensorStore checkpoint.

Return type

bool

load(path: eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath | str | os.PathLike, mesh: Mesh, shardings: dict[jax.sharding.NamedSharding] | None | dict[Callable] = None, mismatch_allowed: bool = True, callback: Optional[Callable[[Array, str], Array]] = None, partition_rules: tuple[tuple[str, jax.sharding.PartitionSpec]] | None = None, dtype: str | numpy.dtype | None = None, validate: bool | None = None, prefix_filter: str | None = None, prefix: str | None = None, use_async: bool = True) → tuple[eformer.pytree._pytree.PyTree | dict, dict][source]#

Load a checkpoint through a synchronous call, reading in parallel or sequentially depending on use_async.

Automatically detects the checkpoint format (TensorStore or SafeTensors) and loads accordingly. No async/await is needed at the call site.

Parameters
  • path – Path to the checkpoint directory or file.

  • mesh – JAX mesh for distributed computation. Required for proper sharding.

  • shardings – PyTree of sharding specifications matching checkpoint structure, or dict mapping keys to functions that process/reshard arrays after loading.

  • mismatch_allowed – Whether to allow missing shard functions without error.

  • callback – Optional callback to process each array after loading. Receives (array, key) and returns processed array.

  • partition_rules – List of (regex, PartitionSpec) tuples for pattern-based sharding. Applied to arrays matching the regex patterns. Preserves order of arrays during loading.

  • dtype – Data type to cast arrays to after loading.

  • validate – Whether to validate checksums. If None, uses self.enable_validation.

  • prefix_filter – Deprecated. Use ‘prefix’ instead.

  • prefix – Optional prefix for loading specific tree (e.g., ‘model’, ‘optimizer’). Required when checkpoint contains multiple prefixes.

  • use_async – Whether to use parallel loading (faster) or sequential loading.

Returns

Tuple of (loaded tree, metadata dictionary). Tree is unflattened to nested structure.

Raises
  • ValueError – If validation fails or the prefix is not found.

  • FileNotFoundError – If checkpoint doesn’t exist.

Note

  • Automatically detects TensorStore format by checking for .zarray files or tensorstore_index.json.

  • When using partition_rules, the order of loaded arrays is preserved to ensure consistent sharding application.
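
The format detection described in the first note can be sketched with pathlib. looks_like_tensorstore is a hypothetical helper that only mirrors the two documented signals (.zarray files or tensorstore_index.json); the real is_tensorstore may check more, and remote paths such as GCS would need a different file API.

```python
import tempfile
from pathlib import Path


def looks_like_tensorstore(path) -> bool:
    """Heuristic sketch: True if path holds a .zarray file or tensorstore_index.json."""
    p = Path(path)
    # The path may point directly at the index file.
    if p.name == "tensorstore_index.json":
        return p.exists()
    if not p.is_dir():
        return False
    if (p / "tensorstore_index.json").exists():
        return True
    # Otherwise look for any .zarray metadata file below the directory.
    return next(p.rglob(".zarray"), None) is not None


with tempfile.TemporaryDirectory() as tmp:
    # Directory containing a .zarray file: treated as TensorStore.
    ts_dir = Path(tmp) / "ts_ckpt"
    (ts_dir / "model" / "kernel").mkdir(parents=True)
    (ts_dir / "model" / "kernel" / ".zarray").write_text("{}")

    # Directory containing only a .safetensors file: not TensorStore.
    st_dir = Path(tmp) / "st_ckpt"
    st_dir.mkdir()
    (st_dir / "model.safetensors").write_bytes(b"")

    ts_detected = looks_like_tensorstore(ts_dir)
    st_detected = looks_like_tensorstore(st_dir)
```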

Example

>>> manager = AsyncCheckpointManager()
>>> rules = [(".*weight", PartitionSpec("model", None))]
>>> tree, meta = manager.load("checkpoint", mesh, partition_rules=rules)
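
To illustrate how pattern-based partition rules can map array keys to sharding specs, here is a minimal sketch. It makes several assumptions that are not confirmed by this page: rules are tried in order and the first match wins, matching uses re.search, and keys are slash-separated. Plain tuples stand in for PartitionSpec so the snippet runs without JAX; match_partition_spec is a hypothetical helper.

```python
import re


def match_partition_spec(key, rules, default=()):
    """Return the spec of the first rule whose regex matches key (assumed semantics)."""
    for pattern, spec in rules:
        if re.search(pattern, key):
            return spec
    # No rule matched: fall back to an empty (replicated) spec in this sketch.
    return default


# Tuples stand in for jax.sharding.PartitionSpec(...) values.
rules = [
    (r".*embedding", ("model", None)),
    (r".*kernel", (None, "model")),
]

keys = ["decoder/embedding", "decoder/mlp/kernel", "decoder/mlp/bias"]
resolved = {k: match_partition_spec(k, rules) for k in keys}
```

Under the first-match assumption, rule order matters: a broad pattern placed early shadows the more specific rules that follow it.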

load_pytree(path: eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath | str, mesh: Mesh, *, prefix: str, shardings: dict[str, Callable] | None = None, partition_rules: Optional[Sequence[tuple[str, jax.sharding.PartitionSpec]]] = None, dtype: numpy.dtype | None = None, template: eformer.pytree._pytree.PyTree | None = None, strict_shapes: bool = True) → tuple[eformer.pytree._pytree.PyTree, dict][source]#

Load a PyTree saved by save_pytree with the same prefix.

Loads a PyTree structure from disk that was previously saved with save_pytree. Supports both TensorStore and SafeTensors backends with automatic detection.

Parameters
  • path – Directory path containing the saved checkpoint.

  • mesh – JAX mesh for distributed computation and array sharding.

  • prefix – Required prefix that must match the one used during save.

  • shardings – Optional dictionary mapping array keys to sharding functions.

  • partition_rules – Optional sequence of (regex, PartitionSpec) tuples for pattern-based array sharding.

  • dtype – Optional data type to cast arrays to after loading.

Returns

Tuple of (loaded PyTree, metadata dictionary).

Raises
  • ValueError – If prefix is empty or does not match the saved prefix, or if the data is corrupted.

  • FileNotFoundError – If structure file or arrays are missing.

load_tree_parallel(path: eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath | str | os.PathLike, shardings: None | dict[Callable] = None, mismatch_allowed: bool = True, callback: Optional[Callable[[Array, str], Array]] = None, dtype: str | numpy.dtype | None = None, validate: bool | None = None, prefix_filter: str | None = None) → tuple[eformer.pytree._pytree.PyTree | dict, dict][source]#

Load checkpoint with parallel shard reading.

Parameters
  • path – Path to the checkpoint.

  • shardings – PyTree of sharding specifications or dict of functions.

  • mismatch_allowed – Whether to allow missing shard functions.

  • callback – Optional callback to process arrays.

  • dtype – Data type to cast arrays to.

  • validate – Whether to validate checksums.

  • prefix_filter – Optional prefix to filter shards.

Returns

Tuple of (loaded tree, metadata dictionary).

Raises

ValueError – If checkpoint validation fails.
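
The general idea of parallel shard reading with a per-array callback can be sketched with a thread pool. This is an illustration under assumptions, not the manager's implementation: .npy files stand in for real shards, the thread pool is a choice made for this sketch, and load_shards_parallel and upcast are hypothetical names. The callback follows the documented (array, key) → array shape.

```python
import tempfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import numpy as np


def load_shards_parallel(shard_paths, callback=None, max_workers=4):
    """Read each shard in a worker thread, then apply callback(array, key)."""

    def read_one(item):
        key, path = item
        array = np.load(path)
        if callback is not None:
            array = callback(array, key)
        return key, array

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map yields results in input order, so key order is preserved.
        return dict(pool.map(read_one, shard_paths.items()))


def upcast(array, key):
    """Example callback: cast every loaded array to float32."""
    return array.astype(np.float32)


with tempfile.TemporaryDirectory() as tmp:
    paths = {}
    for i, key in enumerate(["a", "b", "c"]):
        paths[key] = Path(tmp) / f"{key}.npy"
        np.save(paths[key], np.full((2,), i, dtype=np.float16))
    loaded = load_shards_parallel(paths, callback=upcast)
```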

static safe_loadpath(path) → eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath[source]#

Convert a checkpoint path to a safe loadable format.

Strips TensorStore index filename if present to get the base directory.

Parameters

path – Checkpoint path that may include index filename.

Returns

Cleaned path suitable for loading.

Return type

ePathLike
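
A rough sketch of the path cleanup follows, assuming the index filename is tensorstore_index.json (the name mentioned in the load() notes). strip_index_filename is hypothetical and returns a PurePosixPath rather than an eformer path object.

```python
from pathlib import PurePosixPath

# Assumed index filename, taken from the load() notes on this page.
INDEX_NAME = "tensorstore_index.json"


def strip_index_filename(path):
    """Return the checkpoint base directory, dropping a trailing index filename."""
    p = PurePosixPath(str(path))
    return p.parent if p.name == INDEX_NAME else p


cleaned = strip_index_filename("run-1/checkpoint/tensorstore_index.json")
unchanged = strip_index_filename("run-1/checkpoint")
```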

save(tree: PyTree, path: eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath | str | os.PathLike, mesh: jax._src.mesh.Mesh | None = None, gather_fns: dict[Callable] | bool | None = None, float_dtype: str | numpy.dtype | None = None, metadata: dict[str, str] | None = None, callback: Optional[Callable[[str], None]] = None, prefix: str | None = None, do_all_gather: bool = False, cpu_offload: bool = False) → str#

Save checkpoint with parallel shard writing.

Saves a PyTree structure to disk using either TensorStore or SafeTensors format, with support for sharding large checkpoints and parallel I/O operations.

Parameters
  • tree – PyTree structure to save.

  • path – Path where the checkpoint will be saved.

  • mesh – JAX mesh for distributed computation. If None, creates a CPU mesh with a warning about potential sharding issues.

  • gather_fns – Dictionary of gather functions or bool for device gathering. If True, uses jax.device_get for all arrays.

  • float_dtype – Data type for floating point arrays. Defaults to self.float_dtype.

  • metadata – Additional metadata to save with checkpoint.

  • callback – Optional callback function called after save completes.

  • prefix – Optional prefix for saving specific tree (e.g., ‘model’, ‘optimizer’). Used for organizing multiple trees in same directory.

  • do_all_gather – Whether to gather all arrays to host before saving. The signature default is False; pass True for safer and more consistent checkpoint saving.

  • cpu_offload – Whether to offload arrays to CPU during gathering. The signature default is False; pass True to reduce memory pressure on accelerators and prevent OOM errors.

Returns

Path where the checkpoint was saved.

Note

  • Automatically chooses between TensorStore (if available and enabled) or SafeTensors format based on configuration.

  • When mesh is not provided, a warning is logged and a CPU mesh is used as a fallback.

  • CPU offloading helps prevent out-of-memory errors on GPUs/TPUs during checkpointing.

  • Arrays are automatically flattened before saving and unflattened when loading.
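
The flatten-on-save, unflatten-on-load round trip mentioned in the last note can be sketched for nested dicts. The "." separator and the helper names flatten and unflatten are assumptions for illustration; the key format the manager actually uses is not stated on this page.

```python
def flatten(tree, sep=".", parent=""):
    """Flatten a nested dict into {joined_key: leaf} (sketch, dict-only trees)."""
    flat = {}
    for key, value in tree.items():
        path = f"{parent}{sep}{key}" if parent else str(key)
        if isinstance(value, dict):
            flat.update(flatten(value, sep, path))
        else:
            flat[path] = value
    return flat


def unflatten(flat, sep="."):
    """Rebuild the nested dict from joined keys."""
    tree = {}
    for path, value in flat.items():
        *parents, leaf = path.split(sep)
        node = tree
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree


state = {"params": {"dense": {"kernel": 1, "bias": 2}}, "step": 3}
flat = flatten(state)
restored = unflatten(flat)
```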

Example

>>> manager = AsyncCheckpointManager()
>>> manager.save(
...     tree=model_state,
...     path="checkpoint",
...     mesh=mesh,
...     prefix="model",
...     cpu_offload=True
... )

save_pytree(pytree: PyTree, path: eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath | str, mesh: jax._src.mesh.Mesh | None = None, *, prefix: str, do_all_gather: bool = False, cpu_offload: bool = True, dtype: numpy.dtype | None = None, extras: dict | None = None, write_index: bool = True) → str[source]#

Save a PyTree with exact structure and prefix.

Saves a PyTree structure to disk with support for both TensorStore and SafeTensors backends. Arrays are saved to <path>/<prefix>/…, while index and structure metadata go to <path>/.

Parameters
  • pytree – PyTree structure to save.

  • path – Directory path where the checkpoint will be saved.

  • mesh – JAX mesh for distributed computation.

  • prefix – Required prefix for organizing the saved tree (e.g., ‘model’, ‘optimizer’).

  • do_all_gather – Whether to gather all arrays to host before saving.

  • cpu_offload – Whether to offload arrays to CPU during gathering to prevent OOM.

  • dtype – Optional data type to cast arrays to before saving.

  • extras – Additional metadata to save with the checkpoint.

  • write_index – Whether to write the index file (for TensorStore backend).

Returns

Path where the checkpoint was saved.

Return type

str

Raises
  • ValueError – If prefix is empty or not a string.

  • FileNotFoundError – If TensorStore index creation fails.

save_tree(tree: PyTree, path: eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath | str | os.PathLike, mesh: jax._src.mesh.Mesh | None = None, gather_fns: dict[Callable] | bool | None = None, float_dtype: str | numpy.dtype | None = None, metadata: dict[str, str] | None = None, callback: Optional[Callable[[str], None]] = None, prefix: str | None = None, do_all_gather: bool = False, cpu_offload: bool = False) → str[source]#

Save checkpoint with parallel shard writing.

Saves a PyTree structure to disk using either TensorStore or SafeTensors format, with support for sharding large checkpoints and parallel I/O operations.

Parameters
  • tree – PyTree structure to save.

  • path – Path where the checkpoint will be saved.

  • mesh – JAX mesh for distributed computation. If None, creates a CPU mesh with a warning about potential sharding issues.

  • gather_fns – Dictionary of gather functions or bool for device gathering. If True, uses jax.device_get for all arrays.

  • float_dtype – Data type for floating point arrays. Defaults to self.float_dtype.

  • metadata – Additional metadata to save with checkpoint.

  • callback – Optional callback function called after save completes.

  • prefix – Optional prefix for saving specific tree (e.g., ‘model’, ‘optimizer’). Used for organizing multiple trees in same directory.

  • do_all_gather – Whether to gather all arrays to host before saving. The signature default is False; pass True for safer and more consistent checkpoint saving.

  • cpu_offload – Whether to offload arrays to CPU during gathering. The signature default is False; pass True to reduce memory pressure on accelerators and prevent OOM errors.

Returns

Path where the checkpoint was saved.

Note

  • Automatically chooses between TensorStore (if available and enabled) or SafeTensors format based on configuration.

  • When mesh is not provided, a warning is logged and a CPU mesh is used as a fallback.

  • CPU offloading helps prevent out-of-memory errors on GPUs/TPUs during checkpointing.

  • Arrays are automatically flattened before saving and unflattened when loading.

Example

>>> manager = AsyncCheckpointManager()
>>> manager.save_tree(
...     tree=model_state,
...     path="checkpoint",
...     mesh=mesh,
...     prefix="model",
...     cpu_offload=True
... )

class eformer.serialization.async_manager.CheckpointMetadata(version: str = '0.0.90', timestamp: str = None, checksum: dict[str, str] = None, array_metadata: dict[str, dict] = None, framework_version: str = None, custom_metadata: dict = None)[source]#

Bases: object

Enhanced metadata for checkpoints with versioning and validation.

Stores comprehensive metadata about a checkpoint including version information, timestamps, checksums for validation, and custom user metadata.

version#

Version string for the checkpoint format.

Type

str

timestamp#

ISO format timestamp of when checkpoint was created.

Type

str

checksum#

Dictionary mapping array keys to SHA256 checksums.

Type

dict[str, str]

array_metadata#

Dictionary mapping array keys to shape/dtype info.

Type

dict[str, dict]

framework_version#

Version of the framework used to create checkpoint.

Type

str

custom_metadata#

User-defined metadata dictionary.

Type

dict

array_metadata: dict[str, dict] = None#
checksum: dict[str, str] = None#
custom_metadata: dict = None#
framework_version: str = None#
classmethod from_dict(data: dict) → CheckpointMetadata[source]#

Create CheckpointMetadata from dictionary.

Parameters

data – Dictionary containing metadata fields.

Returns

CheckpointMetadata instance.

timestamp: str = None#
to_dict() → dict[source]#

Convert metadata to dictionary format.

Returns

Dictionary representation of the metadata.

version: str = '0.0.90'#
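
The to_dict / from_dict round trip can be illustrated with a small dataclass that mirrors the fields listed above. MetadataSketch is a hypothetical stand-in, not the real class, and ignoring unknown keys in from_dict is an assumption of this sketch.

```python
from dataclasses import asdict, dataclass, fields
from datetime import datetime, timezone


@dataclass
class MetadataSketch:
    """Stand-in mirroring the documented CheckpointMetadata fields and defaults."""

    version: str = "0.0.90"
    timestamp: str = None
    checksum: dict = None
    array_metadata: dict = None
    framework_version: str = None
    custom_metadata: dict = None

    def to_dict(self) -> dict:
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> "MetadataSketch":
        # Assumption: keys that are not known fields are silently dropped.
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in known})


meta = MetadataSketch(
    timestamp=datetime.now(timezone.utc).isoformat(),  # ISO format timestamp
    checksum={"params.kernel": "0" * 64},               # placeholder SHA256 hex string
    custom_metadata={"step": "100"},
)
roundtrip = MetadataSketch.from_dict(meta.to_dict())
```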