eformer.serialization.async_manager
- class eformer.serialization.async_manager.AsyncCheckpointManager(enable: bool | None = None, float_dtype: numpy.dtype = jax.numpy.bfloat16, verbose: bool = False, gcs_bucket: str | None = None, gcs_credentials_path: str | None = None, enable_validation: bool = False, enable_compression: bool = False, use_tensorstore: bool = True)
Bases: object
Checkpoint manager with concurrent operations.
This manager saves and loads checkpoints with support for parallel operations, a TensorStore backend, checksum validation, and compression. It supports both TensorStore (for large-scale distributed checkpoints) and SafeTensors (for smaller, single-file checkpoints) formats.
- Key Features:
Automatic format detection (TensorStore vs SafeTensors)
Parallel I/O operations for faster loading/saving
CPU offloading to prevent OOM on accelerators
Checksum validation for data integrity
Support for sharded checkpoints across multiple files
Pattern-based partition rules with preserved ordering
- float_dtype
Default data type for floating-point arrays.
- enable
Whether checkpointing is enabled.
- verbose
Enable verbose output.
- gcs_bucket
Google Cloud Storage bucket name.
- enable_validation
Enable checksum validation.
- enable_compression
Enable compression for the TensorStore backend.
- use_tensorstore
Use the TensorStore backend when available.
Example
>>> manager = AsyncCheckpointManager(
...     enable_validation=True,
...     use_tensorstore=True,
... )
>>> manager.save(model_state, "checkpoint", mesh=mesh)
>>> rules = [(".*kernel", PartitionSpec("model", None))]
>>> state, meta = manager.load("checkpoint", mesh, partition_rules=rules)
- static compute_checksum(array: Array) → str
Compute SHA256 checksum for validation.
Converts array to bytes and computes SHA256 hash for data integrity verification.
- Parameters
array – JAX array to compute checksum for.
- Returns
SHA256 checksum as hexadecimal string.
Note
Arrays are converted to numpy before hashing for consistency.
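The hashing step described above can be approximated as follows. This is a minimal sketch, not eformer's code: `compute_checksum_sketch` is a hypothetical name, and the sketch assumes the digest covers only the raw array bytes (whether the real method also mixes in shape or dtype is not stated here).

```python
import hashlib

import numpy as np


def compute_checksum_sketch(array) -> str:
    """Hypothetical stand-in for compute_checksum: SHA256 over raw array bytes."""
    # Convert to a host NumPy array first; np.asarray also accepts JAX arrays,
    # mirroring the "converted to numpy before hashing" note.
    host = np.asarray(array)
    # Hash the C-ordered byte buffer. Shape and dtype are NOT part of the digest here.
    return hashlib.sha256(host.tobytes()).hexdigest()


a = np.arange(6, dtype=np.float32).reshape(2, 3)
digest = compute_checksum_sketch(a)  # 64-character hex string
```

Because only the bytes are hashed in this sketch, a reshaped array with the same contents yields the same digest; shape and dtype would need to be recorded separately (as `CheckpointMetadata.array_metadata` does) if that distinction matters.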
- property global_manager: GlobalAsyncCheckpointManager
Get or create the global async checkpoint manager.
- Returns
The singleton manager instance.
- Return type
GlobalAsyncCheckpointManager
- static is_tensorstore(path) → bool
Check if a checkpoint path uses TensorStore format.
- Parameters
path – Path to check for TensorStore format.
- Returns
True if the path contains or points to a TensorStore checkpoint.
- Return type
bool
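A rough local-filesystem approximation of the detection heuristic (the `load` notes mention `.zarray` files and `tensorstore_index.json`) might look like this. `looks_like_tensorstore` is hypothetical, handles only local paths via `pathlib`, and may differ from the real checks, which also have to work with GCS paths.

```python
from pathlib import Path

INDEX_NAME = "tensorstore_index.json"  # name taken from the load() note


def looks_like_tensorstore(path) -> bool:
    """Hypothetical heuristic: index file or Zarr array metadata means TensorStore."""
    p = Path(path)
    # Case 1: the path names the index file itself (existence is not checked).
    if p.name == INDEX_NAME:
        return True
    if not p.is_dir():
        return False
    # Case 2: the directory has an index file at its top level.
    if (p / INDEX_NAME).exists():
        return True
    # Case 3: a Zarr ".zarray" metadata file exists anywhere below the directory.
    return any(p.rglob(".zarray"))
```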
- load(path: eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath | str | os.PathLike, mesh: Mesh, shardings: dict[jax.sharding.NamedSharding] | None | dict[Callable] = None, mismatch_allowed: bool = True, callback: Optional[Callable[[Array, str], Array]] = None, partition_rules: tuple[tuple[str, jax.sharding.PartitionSpec]] | None = None, dtype: str | numpy.dtype | None = None, validate: bool | None = None, prefix_filter: str | None = None, prefix: str | None = None, use_async: bool = True) → tuple[eformer.pytree._pytree.PyTree | dict, dict]
Synchronous load method; use_async selects parallel or sequential loading.
Automatically detects the checkpoint format (TensorStore or SafeTensors) and loads it accordingly. It can be called directly, without async/await.
- Parameters
path – Path to the checkpoint directory or file.
mesh – JAX mesh for distributed computation. Required for proper sharding.
shardings – PyTree of sharding specifications matching checkpoint structure, or dict mapping keys to functions that process/reshard arrays after loading.
mismatch_allowed – Whether to allow missing shard functions without error.
callback – Optional callback to process each array after loading. Receives (array, key) and returns processed array.
partition_rules – List of (regex, PartitionSpec) tuples for pattern-based sharding. Applied to arrays matching the regex patterns. Preserves order of arrays during loading.
dtype – Data type to cast arrays to after loading.
validate – Whether to validate checksums. If None, uses self.enable_validation.
prefix_filter – Deprecated. Use ‘prefix’ instead.
prefix – Optional prefix for loading specific tree (e.g., ‘model’, ‘optimizer’). Required when checkpoint contains multiple prefixes.
use_async – Whether to use parallel loading (faster) or sequential loading.
- Returns
Tuple of (loaded tree, metadata dictionary). Tree is unflattened to nested structure.
- Raises
ValueError – If validation fails or prefix not found.
FileNotFoundError – If checkpoint doesn’t exist.
Note
Automatically detects TensorStore format by checking for .zarray files or tensorstore_index.json.
When using partition_rules, the order of loaded arrays is preserved to ensure consistent sharding application.
Example
>>> manager = AsyncCheckpointManager()
>>> rules = [(".*weight", PartitionSpec("model", None))]
>>> tree, meta = manager.load("checkpoint", mesh, partition_rules=rules)
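To make the pattern-based rules concrete, the sketch below assigns a spec to each flattened key by regex. It assumes first-match-wins ordering and `re.search` semantics, and it uses plain tuples as stand-ins for `jax.sharding.PartitionSpec` so it runs without JAX; `match_partition_rule` and the `/`-separated key names are hypothetical.

```python
import re
from typing import Optional, Sequence


def match_partition_rule(key: str, rules: Sequence[tuple[str, object]]) -> Optional[object]:
    """Return the spec of the first rule whose regex matches `key` (assumed first-match-wins)."""
    for pattern, spec in rules:
        # re.search lets the pattern match anywhere, so ".*kernel" hits "encoder/dense/kernel".
        if re.search(pattern, key):
            return spec
    return None  # no rule matched; the caller decides how to handle it


# Stand-in specs: plain tuples instead of jax.sharding.PartitionSpec.
rules = [(".*kernel", ("model", None)), (".*bias", (None,))]
keys = ["encoder/dense/kernel", "encoder/dense/bias", "embed/table"]
assignment = {k: match_partition_rule(k, rules) for k in keys}
```

Under this assumption, rule order matters: a broad pattern placed first shadows every more specific pattern after it.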
- load_pytree(path: eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath | str, mesh: Mesh, *, prefix: str, shardings: dict[str, Callable] | None = None, partition_rules: Optional[Sequence[tuple[str, jax.sharding.PartitionSpec]]] = None, dtype: numpy.dtype | None = None, template: eformer.pytree._pytree.PyTree | None = None, strict_shapes: bool = True) → tuple[eformer.pytree._pytree.PyTree, dict]
Load a PyTree saved by save_pytree with the same prefix.
Loads a PyTree structure from disk that was previously saved with save_pytree. Supports both TensorStore and SafeTensors backends with automatic detection.
- Parameters
path – Directory path containing the saved checkpoint.
mesh – JAX mesh for distributed computation and array sharding.
prefix – Required prefix that must match the one used during save.
shardings – Optional dictionary mapping array keys to sharding functions.
partition_rules – Optional sequence of (regex, PartitionSpec) tuples for pattern-based array sharding.
dtype – Optional data type to cast arrays to after loading.
- Returns
Tuple of (loaded PyTree, metadata dictionary).
- Raises
ValueError – If prefix is empty, doesn’t match saved prefix, or data is corrupted.
FileNotFoundError – If structure file or arrays are missing.
- load_tree_parallel(path: eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath | str | os.PathLike, shardings: None | dict[Callable] = None, mismatch_allowed: bool = True, callback: Optional[Callable[[Array, str], Array]] = None, dtype: str | numpy.dtype | None = None, validate: bool | None = None, prefix_filter: str | None = None) → tuple[eformer.pytree._pytree.PyTree | dict, dict]
Load checkpoint with parallel shard reading.
- Parameters
path – Path to the checkpoint.
shardings – PyTree of sharding specifications or dict of functions.
mismatch_allowed – Whether to allow missing shard functions.
callback – Optional callback to process arrays.
dtype – Data type to cast arrays to.
validate – Whether to validate checksums.
prefix_filter – Optional prefix to filter shards.
- Returns
Tuple of (loaded tree, metadata dictionary).
- Raises
ValueError – If checkpoint validation fails.
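Parallel shard reading combined with the documented `callback(array, key)` contract can be sketched with a thread pool. Everything here is hypothetical: `load_parallel_sketch`, the `read_shard` reader, and the in-memory `storage` dict stand in for real shard files, and eformer's actual scheduling is not described in the source.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Optional

import numpy as np


def load_parallel_sketch(
    shards: dict,
    read_shard: Callable[[str], np.ndarray],
    callback: Optional[Callable[[np.ndarray, str], np.ndarray]] = None,
    max_workers: int = 8,
) -> dict:
    """Read every shard concurrently, then apply callback(array, key) to each result."""
    keys = list(shards)
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map preserves input order, so results line up with `keys`.
        arrays = list(pool.map(lambda k: read_shard(shards[k]), keys))
    out = {}
    for key, arr in zip(keys, arrays):
        out[key] = callback(arr, key) if callback is not None else arr
    return out


# Hypothetical in-memory "files" standing in for shard storage.
storage = {"a.bin": np.ones(3, np.float32), "b.bin": np.zeros(2, np.int32)}


def fake_read(location: str) -> np.ndarray:
    return storage[location].copy()


def to_f16(arr: np.ndarray, key: str) -> np.ndarray:
    # Example callback: downcast floating arrays, leave everything else untouched.
    return arr.astype(np.float16) if np.issubdtype(arr.dtype, np.floating) else arr


tree = load_parallel_sketch({"w": "a.bin", "step": "b.bin"}, fake_read, callback=to_f16)
```

Threads suit this workload because shard reads are I/O-bound, so the GIL is released while waiting on storage.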
- static safe_loadpath(path) → eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath
Convert a checkpoint path to a safe loadable format.
Strips TensorStore index filename if present to get the base directory.
- Parameters
path – Checkpoint path that may include index filename.
- Returns
Cleaned path suitable for loading.
- Return type
ePathLike
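The path cleanup can be sketched with plain string operations, which keeps `gs://` prefixes intact (a `pathlib.PurePosixPath` would collapse the double slash). `safe_loadpath_sketch` is hypothetical, and the index filename is assumed to be `tensorstore_index.json`, the name mentioned in the `load` notes.

```python
INDEX_NAME = "tensorstore_index.json"  # assumed to be the only filename stripped


def safe_loadpath_sketch(path) -> str:
    """Hypothetical: drop a trailing TensorStore index filename, keep everything else."""
    s = str(path).rstrip("/")
    head, sep, tail = s.rpartition("/")
    if tail == INDEX_NAME:
        # A bare filename with no directory part maps to the current directory.
        return head if sep else "."
    return s
```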
- save(tree: PyTree, path: eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath | str | os.PathLike, mesh: jax._src.mesh.Mesh | None = None, gather_fns: dict[Callable] | bool | None = None, float_dtype: str | numpy.dtype | None = None, metadata: dict[str, str] | None = None, callback: Optional[Callable[[str], None]] = None, prefix: str | None = None, do_all_gather: bool = False, cpu_offload: bool = False) → str
Save checkpoint with parallel shard writing.
Saves a PyTree structure to disk using either TensorStore or SafeTensors format, with support for sharding large checkpoints and parallel I/O operations.
- Parameters
tree – PyTree structure to save.
path – Path where the checkpoint will be saved.
mesh – JAX mesh for distributed computation. If None, creates a CPU mesh with a warning about potential sharding issues.
gather_fns – Dictionary of gather functions or bool for device gathering. If True, uses jax.device_get for all arrays.
float_dtype – Data type for floating point arrays. Defaults to self.float_dtype.
metadata – Additional metadata to save with checkpoint.
callback – Optional callback function called after save completes.
prefix – Optional prefix for saving specific tree (e.g., ‘model’, ‘optimizer’). Used for organizing multiple trees in same directory.
do_all_gather – Whether to gather all arrays to host before saving. Defaults to False; enabling it gives safer, more consistent checkpoint saving.
cpu_offload – Whether to offload arrays to CPU during gathering. Defaults to False; enabling it reduces memory pressure on accelerators and helps prevent OOM errors.
- Returns
Path where the checkpoint was saved.
Note
Automatically chooses between TensorStore (if available and enabled) or SafeTensors format based on configuration.
When mesh is not provided, a warning is logged and CPU mesh is used as fallback.
CPU offloading helps prevent out-of-memory errors on GPUs/TPUs during checkpointing.
Arrays are automatically flattened before saving and unflattened when loading.
Example
>>> manager = AsyncCheckpointManager()
>>> manager.save(
...     tree=model_state,
...     path="checkpoint",
...     mesh=mesh,
...     prefix="model",
...     cpu_offload=True,
... )
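A minimal sketch of what the `float_dtype` cast is documented to target: only floating-point leaves are converted, while integer and boolean arrays (step counters, masks) keep their dtype. `cast_floating_leaves` is hypothetical; the example uses `np.float16` because NumPy has no built-in bfloat16 (JAX provides `jax.numpy.bfloat16`).

```python
import numpy as np


def cast_floating_leaves(flat: dict, float_dtype) -> dict:
    """Hypothetical: cast floating-point arrays only; other dtypes pass through unchanged."""
    out = {}
    for key, value in flat.items():
        arr = np.asarray(value)
        if np.issubdtype(arr.dtype, np.floating):
            arr = arr.astype(float_dtype)
        out[key] = arr
    return out


flat = {
    "w": np.ones((2, 2), np.float32),
    "step": np.array(10, np.int64),
    "mask": np.array([True, False]),
}
cast = cast_floating_leaves(flat, np.float16)
```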
- save_pytree(pytree: PyTree, path: eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath | str, mesh: jax._src.mesh.Mesh | None = None, *, prefix: str, do_all_gather: bool = False, cpu_offload: bool = True, dtype: numpy.dtype | None = None, extras: dict | None = None, write_index: bool = True) → str
Save a PyTree with exact structure and prefix.
Saves a PyTree structure to disk with support for both TensorStore and SafeTensors backends. Arrays are saved to <path>/<prefix>/…, while index and structure metadata go to <path>/.
- Parameters
pytree – PyTree structure to save.
path – Directory path where the checkpoint will be saved.
mesh – JAX mesh for distributed computation.
prefix – Required prefix for organizing the saved tree (e.g., ‘model’, ‘optimizer’).
do_all_gather – Whether to gather all arrays to host before saving.
cpu_offload – Whether to offload arrays to CPU during gathering to prevent OOM.
dtype – Optional data type to cast arrays to before saving.
extras – Additional metadata to save with the checkpoint.
write_index – Whether to write the index file (for TensorStore backend).
- Returns
Path where the checkpoint was saved.
- Return type
str
- Raises
ValueError – If prefix is empty or not a string.
FileNotFoundError – If TensorStore index creation fails.
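The directory split described above, plus the documented `ValueError` on a bad prefix, can be sketched as a small path helper. `pytree_layout` is hypothetical and only computes directory strings; it does not model the files written inside them.

```python
def pytree_layout(path, prefix) -> dict:
    """Hypothetical: arrays under <path>/<prefix>, index/structure metadata under <path>."""
    if not isinstance(prefix, str) or not prefix:
        raise ValueError("prefix must be a non-empty string")
    root = str(path).rstrip("/")
    return {"arrays_dir": f"{root}/{prefix}", "metadata_dir": root}
```

Under this layout, saving `model` and `optimizer` into the same `path` keeps their array directories separate while both share the same metadata root.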
- save_tree(tree: PyTree, path: eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath | str | os.PathLike, mesh: jax._src.mesh.Mesh | None = None, gather_fns: dict[Callable] | bool | None = None, float_dtype: str | numpy.dtype | None = None, metadata: dict[str, str] | None = None, callback: Optional[Callable[[str], None]] = None, prefix: str | None = None, do_all_gather: bool = False, cpu_offload: bool = False) → str
Save checkpoint with parallel shard writing.
Saves a PyTree structure to disk using either TensorStore or SafeTensors format, with support for sharding large checkpoints and parallel I/O operations.
- Parameters
tree – PyTree structure to save.
path – Path where the checkpoint will be saved.
mesh – JAX mesh for distributed computation. If None, creates a CPU mesh with a warning about potential sharding issues.
gather_fns – Dictionary of gather functions or bool for device gathering. If True, uses jax.device_get for all arrays.
float_dtype – Data type for floating point arrays. Defaults to self.float_dtype.
metadata – Additional metadata to save with checkpoint.
callback – Optional callback function called after save completes.
prefix – Optional prefix for saving specific tree (e.g., ‘model’, ‘optimizer’). Used for organizing multiple trees in same directory.
do_all_gather – Whether to gather all arrays to host before saving. Defaults to False; enabling it gives safer, more consistent checkpoint saving.
cpu_offload – Whether to offload arrays to CPU during gathering. Defaults to False; enabling it reduces memory pressure on accelerators and helps prevent OOM errors.
- Returns
Path where the checkpoint was saved.
Note
Automatically chooses between TensorStore (if available and enabled) or SafeTensors format based on configuration.
When mesh is not provided, a warning is logged and CPU mesh is used as fallback.
CPU offloading helps prevent out-of-memory errors on GPUs/TPUs during checkpointing.
Arrays are automatically flattened before saving and unflattened when loading.
Example
>>> manager = AsyncCheckpointManager()
>>> manager.save_tree(
...     tree=model_state,
...     path="checkpoint",
...     mesh=mesh,
...     prefix="model",
...     cpu_offload=True,
... )
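The flatten/unflatten round trip mentioned in the notes can be sketched for nested dicts. The `.` separator and the helpers `flatten_tree` and `unflatten_tree` are assumptions; real PyTrees also contain lists, tuples, and custom nodes, which `jax.tree_util` handles and this sketch does not.

```python
def flatten_tree(tree: dict, sep: str = ".", parent: str = "") -> dict:
    """Hypothetical: nested dict -> {"a.b.c": leaf}. Keys must not contain `sep`."""
    flat = {}
    for key, value in tree.items():
        name = f"{parent}{sep}{key}" if parent else str(key)
        if isinstance(value, dict):
            flat.update(flatten_tree(value, sep, name))  # recurse into sub-dicts
        else:
            flat[name] = value
    return flat


def unflatten_tree(flat: dict, sep: str = ".") -> dict:
    """Inverse of flatten_tree (empty sub-dicts are not recoverable)."""
    tree = {}
    for name, value in flat.items():
        *parents, leaf = name.split(sep)
        node = tree
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree


state = {"params": {"dense": {"kernel": 1, "bias": 2}}, "step": 3}
flat = flatten_tree(state)
```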
- class eformer.serialization.async_manager.CheckpointMetadata(version: str = '0.0.90', timestamp: str = None, checksum: dict[str, str] = None, array_metadata: dict[str, dict] = None, framework_version: str = None, custom_metadata: dict = None)
Bases: object
Enhanced metadata for checkpoints with versioning and validation.
Stores comprehensive metadata about a checkpoint including version information, timestamps, checksums for validation, and custom user metadata.
- version
Version string for the checkpoint format.
- Type
str
- timestamp
ISO-format timestamp of when the checkpoint was created.
- Type
str
- checksum
Dictionary mapping array keys to SHA256 checksums.
- Type
dict[str, str]
- array_metadata
Dictionary mapping array keys to shape/dtype info.
- Type
dict[str, dict]
- framework_version
Version of the framework used to create the checkpoint.
- Type
str
- custom_metadata
User-defined metadata dictionary.
- Type
dict
- array_metadata: dict[str, dict] = None
- checksum: dict[str, str] = None
- custom_metadata: dict = None
- framework_version: str = None
- classmethod from_dict(data: dict) → CheckpointMetadata
Create CheckpointMetadata from dictionary.
- Parameters
data – Dictionary containing metadata fields.
- Returns
CheckpointMetadata instance.
- timestamp: str = None
- to_dict() → dict
Convert metadata to dictionary format.
- Returns
Dictionary representation of the metadata.
- version: str = '0.0.90'
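The `to_dict`/`from_dict` pair can be mirrored with a dataclass to show the intended round trip. `MetadataSketch` is a hypothetical re-creation using the documented field names and defaults; ignoring unknown keys in `from_dict` is an assumption, not documented behaviour.

```python
from dataclasses import asdict, dataclass, fields
from datetime import datetime, timezone
from typing import Optional


@dataclass
class MetadataSketch:
    """Hypothetical mirror of CheckpointMetadata's documented fields and defaults."""

    version: str = "0.0.90"
    timestamp: Optional[str] = None
    checksum: Optional[dict] = None
    array_metadata: Optional[dict] = None
    framework_version: Optional[str] = None
    custom_metadata: Optional[dict] = None

    def to_dict(self) -> dict:
        return asdict(self)  # plain, JSON-friendly dict of all fields

    @classmethod
    def from_dict(cls, data: dict) -> "MetadataSketch":
        # Assumption: silently drop keys this version does not know about.
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in known})


meta = MetadataSketch(
    timestamp=datetime.now(timezone.utc).isoformat(),
    checksum={"params.w": "ab" * 32},
    custom_metadata={"step": "1000"},
)
restored = MetadataSketch.from_dict(meta.to_dict())
```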