eformer.serialization.checkpointer
High-level checkpoint management for eFormer.
This module provides a checkpoint manager with time- and run-based saving policies for training workflows. Key features include:
- Flexible Checkpoint Policies: configure time-based and run-based saving policies
- TensorStore Backend: efficient storage for large-scale distributed arrays
- TP+FSDP Compatibility: no all-gather required; existing shardings are preserved
- Async Operations: non-blocking checkpoint saves with background cleanup
- Temporary Checkpoints: automatic management of temporary vs. permanent checkpoints
- Multi-host Support: distributed checkpoint operations across multiple hosts
The Checkpointer class is designed to be used in training loops where you want automatic checkpoint management without manual intervention.
- class eformer.serialization.checkpointer.CheckpointInterval(every: int, until: int | None = None)
Bases: object
Configuration for run-based checkpoint saving policy.
Defines when to save checkpoints based on training steps. Multiple intervals can be combined to create sophisticated checkpoint policies (e.g., save every 100 steps for the first 1000 steps, then every 1000 steps thereafter).
- every
Save checkpoint every N steps within this interval.
- Type
int
- until
Save using this policy until this step (inclusive). If None, this policy applies indefinitely. Only the last policy in a sequence can have until=None.
- Type
int | None
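The sketch below shows one way a sequence of intervals could be evaluated for a given step. It assumes that the first interval whose `until` covers the step decides, and that steps past the last bounded interval are never saved. `Interval`, `validate_policies`, and `is_policy_step` are hypothetical names standing in for eformer's internals; they are not part of its API.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass(frozen=True)
class Interval:
    """Hypothetical stand-in for CheckpointInterval (same two fields)."""

    every: int
    until: Optional[int] = None


def validate_policies(policies: Sequence[Interval]) -> None:
    """Reject sequences where an open-ended interval is not the last one."""
    for i, policy in enumerate(policies):
        if policy.every <= 0:
            raise ValueError(f"policy {i}: every must be positive")
        if policy.until is None and i != len(policies) - 1:
            raise ValueError(f"policy {i}: only the last policy may have until=None")


def is_policy_step(step: int, policies: Sequence[Interval]) -> bool:
    """True if `step` lands on a save boundary of the first interval covering it."""
    for policy in policies:
        # An interval covers all steps up to and including `until` (assumption).
        if policy.until is None or step <= policy.until:
            return step % policy.every == 0
    # Past the last bounded interval: no run-based save.
    return False


policies = [Interval(every=100, until=1000), Interval(every=1000)]
validate_policies(policies)
```

With this two-stage policy, steps 500 and 1000 trigger a save, step 1500 does not, and step 2000 does.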
Examples
```python
from eformer.serialization.checkpointer import CheckpointInterval

# Save every 100 steps
interval = CheckpointInterval(every=100)

# Save every 50 steps until step 1000
interval = CheckpointInterval(every=50, until=1000)

# Multi-stage policy: frequent saves early, less frequent later
policies = [
    CheckpointInterval(every=100, until=1000),  # Every 100 steps up to 1000
    CheckpointInterval(every=1000),  # Every 1000 steps thereafter
]
```
- class eformer.serialization.checkpointer.Checkpointer(base_path: str, save_interval: datetime.timedelta | None, step_policies: Sequence[CheckpointInterval], *, manager: eformer.serialization.async_manager.AsyncCheckpointManager | None = None, dt_now_injection: Optional[Callable[[], datetime]] = None, delete_old_temp_checkpoints: bool = True)
Bases: object
High-level checkpoint manager with time- and run-based policies for eFormer.
This class provides automatic checkpoint management for training loops with support for both time-based and run-based saving policies. It integrates with JAX’s distributed training capabilities and TensorStore for efficient storage.
- Key Features:
  - Multi-policy checkpointing: configure different save intervals at different training stages
  - Time-based saves: automatically save at regular time intervals
  - Temporary checkpoints: distinguish between temporary (time-based) and permanent (run-based) checkpoints
  - Async cleanup: background deletion of old temporary checkpoints
  - Multi-host support: coordinated saves across distributed training hosts
  - TensorStore backend: efficient storage without all-gather operations
The checkpointer maintains existing array shardings (TP/FSDP) during saves, avoiding expensive all-gather operations.
- base_path
Root directory for all checkpoints.
- save_interval
Optional time interval for temporary checkpoint saves.
- step_policies
Sequence of run-based checkpoint policies.
Examples
```python
from datetime import timedelta

from eformer.serialization.checkpointer import Checkpointer, CheckpointInterval

# Create checkpointer with time and step policies
checkpointer = Checkpointer(
    base_path="/checkpoints/my_model",
    save_interval=timedelta(minutes=15),  # Temp checkpoint every 15 min
    step_policies=[
        CheckpointInterval(every=500, until=5000),  # Every 500 steps until 5000
        CheckpointInterval(every=1000),  # Every 1000 steps after
    ],
)

# In training loop (mesh is the jax.sharding.Mesh used for training)
for step, batch in enumerate(train_loader):
    # … training code …
    checkpointer.on_step(mesh, training_state, step=step)

# Wait for all saves to complete
checkpointer.wait_until_finished()
```
- load_checkpoint(mesh: Mesh, *, path: str | None = None, discover_latest: bool = True, shardings: dict | None = None, partition_rules: Any = None, dtype: Any = None, prefix: str | None = None, structured: bool = False, template: eformer.pytree._pytree.PyTree | None = None, strict_shapes: bool = True) → tuple[eformer.pytree._pytree.PyTree, dict[str, Any]]
Load a checkpoint from disk with automatic discovery.
Loads a checkpoint directory and restores the PyTree structure with proper array shardings for distributed training. Can automatically discover the most recent checkpoint based on metadata timestamps.
- Parameters
mesh – JAX mesh for distributed array loading. Required for properly restoring sharded arrays across devices.
path – Specific checkpoint directory to load. If None, uses base_path and discovers the latest checkpoint if discover_latest=True.
discover_latest – If True, automatically finds and loads the most recent checkpoint under the specified path based on metadata timestamps. If False, loads from the exact path specified.
shardings – Dictionary mapping array names to sharding specifications. Used to restore arrays with specific shardings. If None, attempts to restore original shardings from checkpoint metadata. Only used in non-structured mode.
partition_rules – Optional partition rules for automatic sharding inference. Alternative to explicit shardings dictionary.
dtype – Optional dtype to cast loaded arrays to. If None, preserves original dtypes from the checkpoint.
prefix – Optional prefix if the checkpoint contains multiple trees. Must match the prefix used during save_checkpoint. Required when structured=True.
structured – If True, loads the checkpoint in structured mode which restores the full PyTree structure (treedef) and requires a prefix. If False, uses non-structured TensorStore mode loading only array leaves. Default: False.
- Returns
A tuple of (tree, metadata) where:
tree: Restored PyTree with the same structure as saved
metadata: Dictionary containing checkpoint metadata (step, timestamp, etc.)
- Return type
tuple[PyTree, dict[str, Any]]
- Raises
FileNotFoundError – If no checkpoint is found at the specified path or if discover_latest=True and no checkpoints exist.
ValueError – If structured=True but prefix is not provided.
Note
- Automatically handles both local and cloud storage paths
- Restores arrays with distributed shardings (no all-gather)
- All processes must call this method (distributed operation)
- Discovered checkpoints are sorted by timestamp, then by step number
- structured=True requires a prefix matching the one used during save
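How `path`, `discover_latest`, `structured`, and `prefix` interact can be sketched as a small resolution function. This is a hedged illustration, not eformer's code: `resolve_load_path` is a hypothetical helper, the `find_latest` callable stands in for a discovery routine such as `find_latest_checkpoint`, and the order of the checks is an assumption.

```python
from typing import Callable, Optional


def resolve_load_path(
    base_path: str,
    find_latest: Callable[[str], Optional[str]],
    *,
    path: Optional[str] = None,
    discover_latest: bool = True,
    structured: bool = False,
    prefix: Optional[str] = None,
) -> str:
    """Pick the directory a load would read from (hypothetical helper)."""
    # Structured checkpoints are namespaced by the prefix used at save time.
    if structured and not prefix:
        raise ValueError("structured=True requires a prefix")
    root = path if path is not None else base_path
    if not discover_latest:
        return root  # load exactly the directory given
    latest = find_latest(root)
    if latest is None:
        raise FileNotFoundError(f"no checkpoint found under {root!r}")
    return latest
```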
Examples
```python
# Load latest checkpoint automatically
state, metadata = checkpointer.load_checkpoint(mesh=my_mesh)
print(f"Loaded checkpoint from step {metadata['step']}")

# Load specific checkpoint
state, _ = checkpointer.load_checkpoint(
    mesh=my_mesh,
    path="/checkpoints/my_model/run-1000",
    discover_latest=False,
)

# Load structured checkpoint with prefix
state, _ = checkpointer.load_checkpoint(
    mesh=my_mesh,
    prefix="optimizer",
    structured=True,
)

# Load with custom shardings (non-structured mode)
state, _ = checkpointer.load_checkpoint(
    mesh=my_mesh,
    shardings=custom_shardings,
    structured=False,
)
```
- load_pytree(mesh: Mesh, *, prefix: str, path: str | None = None, discover_latest: bool = True, discover_raise: bool = True, partition_rules: Any = None, dtype: numpy.dtype | None = None, load_treedef: bool = False, callback: Optional[Callable[[Array, str], Array]] = None, template: eformer.pytree._pytree.PyTree | None = None, strict_shapes: bool = True) → tuple[eformer.pytree._pytree.PyTree, dict[str, Any]]
Load a treedef-preserving PyTree saved under a specific prefix.
This method loads checkpoints saved in structured mode using save_pytree(), which preserves the full PyTree structure definition (treedef).
- Parameters
mesh – JAX Mesh for array sharding on load. Required for properly restoring sharded arrays across devices.
prefix – Namespace/prefix (e.g., “tx”, “model”) used at save time. Must match the prefix used when saving the checkpoint.
path – Optional exact checkpoint directory to load from. If None, uses base_path and discovers the latest checkpoint if discover_latest=True.
discover_latest – If True, automatically finds and loads the most recent checkpoint under the specified path based on metadata timestamps. If False, loads from the exact path specified. Default: True.
discover_raise – If True, raises FileNotFoundError when no checkpoint is found during discovery. If False, returns None silently when no checkpoint exists. Only used when discover_latest=True. Default: True.
partition_rules – Optional partition rules for automatic sharding inference. Alternative to explicit shardings.
dtype – Optional dtype to cast loaded arrays to. If None, preserves original dtypes from the checkpoint.
load_treedef – If True, uses load_pytree() which restores the full PyTree structure definition. If False, uses load() which restores only the array values. Default: False.
callback – Optional callback function to process each loaded array. Receives (array, key_path) and should return the processed array. Useful for custom transformations during loading.
- Returns
A tuple of (pytree, extras_metadata) where:
pytree: Restored PyTree with the same structure as saved
extras_metadata: Dictionary containing checkpoint metadata from save
- Return type
tuple[PyTree, dict[str, Any]]
- Raises
FileNotFoundError – If no checkpoint is found and discover_raise=True.
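The `callback` contract (called once per leaf with the array and its key path, returning the replacement) can be illustrated with a stdlib-only sketch. Plain floats stand in for JAX arrays and a nested dict stands in for the PyTree; the "/"-joined key-path format is an assumption, and `apply_leaf_callback` and `halve_kernels` are hypothetical names.

```python
from typing import Any, Callable, Dict


def apply_leaf_callback(
    tree: Dict[str, Any],
    callback: Callable[[Any, str], Any],
    parent: str = "",
) -> Dict[str, Any]:
    """Return a new nested dict where each leaf is replaced by callback(leaf, key_path)."""
    out: Dict[str, Any] = {}
    for key, value in tree.items():
        key_path = f"{parent}/{key}" if parent else key
        if isinstance(value, dict):
            out[key] = apply_leaf_callback(value, callback, key_path)
        else:
            out[key] = callback(value, key_path)
    return out


def halve_kernels(leaf: Any, key_path: str) -> Any:
    """Example callback: scale leaves whose key path ends in 'kernel'."""
    return leaf * 0.5 if key_path.endswith("kernel") else leaf
```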
Examples
```python
import jax.numpy as jnp

# Load latest checkpoint for a specific prefix
optimizer_state, metadata = checkpointer.load_pytree(
    mesh=my_mesh,
    prefix="optimizer",
)

# Load specific checkpoint without discovery
model, _ = checkpointer.load_pytree(
    mesh=my_mesh,
    prefix="model",
    path="/checkpoints/run-1000",
    discover_latest=False,
)

# Load with callback for custom processing
def convert_to_fp16(arr, key):
    if arr.dtype == jnp.float32:
        return arr.astype(jnp.float16)
    return arr

state, _ = checkpointer.load_pytree(
    mesh=my_mesh,
    prefix="model",
    callback=convert_to_fp16,
)

# Silently handle missing checkpoint
state, metadata = checkpointer.load_pytree(
    mesh=my_mesh,
    prefix="tx",
    discover_raise=False,
)
if state is None:
    print("No checkpoint found, using fresh state")
```
- on_step(mesh: Mesh, pytree: Any | None = None, force: bool = False, *, step: int, true_callbacks: list[Callable[[str, jax._src.mesh.Mesh, dict], NoneType]] | None = None, extras: dict | None = None) → None
Process a training step and save checkpoint if policies dictate.
This method should be called once per training step. It evaluates both time-based and run-based policies to determine whether to save a checkpoint. The decision is made on process 0 and broadcast to all processes to ensure consistency in distributed settings.
- Parameters
mesh – JAX mesh for distributed arrays. Required for checkpoint saving with proper sharding. Passed to save_checkpoint and callbacks.
pytree – Training state PyTree to save. Can be None if only using true_callbacks to handle checkpoint saving externally.
force – If True, force a permanent checkpoint save regardless of policies. Useful for saving at the end of training or before evaluation.
step – Current training step number. Used to determine if checkpoint should be saved based on step_policies.
true_callbacks – Optional list of callback functions to execute when a checkpoint save is triggered. Each callback receives three arguments: destination (str), mesh (Mesh), and metadata (dict). Useful for custom checkpoint handling logic.
extras – Optional dictionary of extra metadata to include in the checkpoint. This metadata will be passed to save_checkpoint and stored with the checkpoint.
Note
- Step 0 (the initialization step) is skipped unless force=True
- Duplicate step saves are skipped unless force=True
- Time-based saves create temporary checkpoints (auto-cleaned)
- Run-based saves create permanent checkpoints
- All processes synchronize on the save decision via broadcast
- Old temporary checkpoints are queued for async deletion
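The decision rules in the note above can be written as a pure function. This sketch is hypothetical (`decide_save` is not part of eformer), omits the process-0 broadcast, and assumes that a run-based save takes precedence when a step both matches a policy and falls after the time interval has elapsed.

```python
from datetime import datetime, timedelta
from typing import Callable, Optional


def decide_save(
    step: int,
    now: datetime,
    *,
    last_save_step: Optional[int],
    last_save_time: datetime,
    save_interval: Optional[timedelta],
    matches_policy: Callable[[int], bool],
    force: bool = False,
) -> Optional[str]:
    """Return "permanent", "temporary", or None (skip) for this step."""
    if force:
        return "permanent"  # forced saves ignore every policy
    # Initialization step and already-saved steps are skipped.
    if step == 0 or step == last_save_step:
        return None
    if matches_policy(step):
        return "permanent"  # run-based save
    if save_interval is not None and now - last_save_time >= save_interval:
        return "temporary"  # time-based save, eligible for cleanup
    return None
```

Passing the clock in as `now` keeps the rule easy to test, in the same spirit as the constructor's `dt_now_injection` hook.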
Examples
```python
from jax.sharding import Mesh

# Regular usage in training loop
for step, batch in enumerate(dataloader):
    loss = train_step(state, batch)
    checkpointer.on_step(my_mesh, state, step=step)

# Force save at end of training
checkpointer.on_step(my_mesh, state, force=True, step=final_step)

# With custom callbacks
def log_save(destination: str, mesh: Mesh, metadata: dict):
    print(f"Saved to {destination} with step {metadata.get('step')}")

checkpointer.on_step(
    my_mesh, state, step=step, true_callbacks=[log_save]
)

# With extras metadata
checkpointer.on_step(
    my_mesh, state, step=step, extras={"loss": float(loss), "accuracy": float(acc)}
)
```
- save_checkpoint(tree: PyTree, destination: str, *, commit_callback: Optional[Callable[[], None]] = None, is_temporary: bool = False, mesh: Mesh = None, step: int = -1, shardings: Any = None, partition_rules: Any = None, prefix: str | None = None, structured: bool = False, dtype: numpy.dtype | None = None, extras: dict | None = None) → None
Save a checkpoint to the specified destination.
This method saves a PyTree checkpoint using TensorStore backend with support for distributed training. It preserves existing array shardings (TP/FSDP) without performing all-gather operations, making it efficient for large models.
- Parameters
tree – PyTree to save. Can be any nested structure containing JAX arrays, NumPy arrays, or other serializable Python objects.
destination – Subdirectory name under base_path where checkpoint will be saved. The full path will be base_path/destination.
commit_callback – Optional callback function to execute after the checkpoint save completes. Used internally for metadata writing and cleanup.
is_temporary – If True, marks this checkpoint as temporary in metadata. Temporary checkpoints are subject to automatic cleanup.
mesh – Optional JAX mesh for distributed arrays. If None, the current mesh from the tree’s shardings will be used.
step – Training step number for this checkpoint. Stored in metadata. Defaults to -1 if not specified.
shardings – Optional sharding specifications for arrays in the tree. If None, existing shardings are preserved from the tree. Only used in non-structured mode.
partition_rules – Optional partition rules for automatic sharding. Typically not needed if arrays already have shardings. Only used in non-structured mode.
prefix – Optional prefix for organizing multiple trees within the same checkpoint directory. Required when structured=True. Useful for saving multiple model states or components (e.g., “model”, “optimizer”).
structured – If True, saves the checkpoint in structured mode which preserves the PyTree structure (treedef) and requires a prefix. If False, uses non-structured TensorStore mode saving only array leaves. Default: False.
dtype – Optional dtype to cast arrays to before saving. If None, preserves original array dtypes.
extras – Optional dictionary of extra metadata to store in the checkpoint. Merged with standard metadata (step, is_temporary).
Note
- Uses AsyncCheckpointManager for non-blocking I/O
- Does NOT perform all-gather (preserves distributed arrays)
- Creates the destination directory if it doesn’t exist
- Updates internal tracking of the last save step and time
- All processes must call this method (distributed operation)
- structured=True requires a prefix and uses a different save path
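The metadata side of a save (extras merged with the standard fields and written into the checkpoint directory) can be sketched with the standard library. The file name `metadata.json`, the ISO-8601 `timestamp` field, and the rule that standard fields override extras with the same key are assumptions for illustration; `write_checkpoint_metadata` is a hypothetical helper, and no array data is written.

```python
import json
import os
from datetime import datetime, timezone
from typing import Any, Dict, Optional


def write_checkpoint_metadata(
    base_path: str,
    destination: str,
    *,
    step: int = -1,
    is_temporary: bool = False,
    extras: Optional[Dict[str, Any]] = None,
) -> str:
    """Write <base_path>/<destination>/metadata.json and return the checkpoint directory."""
    ckpt_dir = os.path.join(base_path, destination)
    os.makedirs(ckpt_dir, exist_ok=True)  # create the destination if missing
    metadata: Dict[str, Any] = dict(extras or {})
    # Standard fields are applied last, so they win over same-named extras (assumption).
    metadata.update(
        step=step,
        is_temporary=is_temporary,
        timestamp=datetime.now(timezone.utc).isoformat(),
    )
    with open(os.path.join(ckpt_dir, "metadata.json"), "w") as f:
        json.dump(metadata, f)
    return ckpt_dir
```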
Examples
```python
import jax.numpy as jnp

# Simple checkpoint save
checkpointer.save_checkpoint(
    tree=model_state,
    destination="run-1000",
    step=1000,
)

# Structured save with prefix
checkpointer.save_checkpoint(
    tree=optimizer_state,
    destination="run-1000",
    prefix="optimizer",
    structured=True,
    step=1000,
)

# Save with mesh, dtype, and extras
checkpointer.save_checkpoint(
    tree=training_state,
    destination=f"run-{current_step}",
    mesh=my_mesh,
    step=current_step,
    dtype=jnp.bfloat16,
    extras={"loss": float(loss), "learning_rate": 0.001},
)
```
- save_pytree(tree: PyTree, prefix: str, *, step: int | None = None, destination: str | None = None, mesh: Mesh = None, dtype: numpy.dtype | None = None, extras: dict | None = None, temporary: bool = False, write_index: bool = True) → str
Save a PyTree under a specific prefix with treedef preserved (structured checkpoint).
This method provides structured checkpoint saving that preserves the exact PyTree structure definition (treedef), enabling reconstruction of the original structure during loading without a template.
- Parameters
tree – PyTree to save. Can contain JAX arrays, numpy arrays, and other serializable Python objects.
prefix – Namespace/prefix for organizing the saved tree (e.g., “model”, “optimizer”, “tx”). Required and must be a non-empty string.
step – Training step number for metadata. Used to generate default destination name (f”run-{step}”) and stored in metadata.json.
destination – Optional subdirectory name under base_path. If None and step is provided, defaults to f”run-{step}”. If both are None, saves directly to base_path.
mesh – Optional JAX Mesh for sharding context. Used when do_all_gather is performed to preserve proper device placement.
dtype – Optional dtype to cast floating point arrays to before saving. If None, preserves original dtypes.
extras – Optional dictionary of extra metadata to store in checkpoint_metadata.json alongside standard fields.
temporary – If True, marks this checkpoint as temporary in metadata.json. Temporary checkpoints are subject to automatic cleanup. Defaults to False.
write_index – Whether to write/update the TensorStore index file. Set to False only if you’re managing the index externally. Defaults to True.
- Returns
The full checkpoint directory path (base_path/destination) where the checkpoint was saved.
- Raises
ValueError – If prefix is empty or not a string.
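The destination rules for `save_pytree` (an explicit `destination` wins, otherwise `run-{step}`, otherwise `base_path` itself) can be expressed as a small helper. `resolve_save_dir` is hypothetical; `posixpath.join` is used so that URL-style paths such as `gs://bucket/ckpt` keep forward slashes.

```python
import posixpath
from typing import Optional


def resolve_save_dir(
    base_path: str,
    prefix: str,
    step: Optional[int] = None,
    destination: Optional[str] = None,
) -> str:
    """Return the directory a structured save would write to (hypothetical helper)."""
    if not isinstance(prefix, str) or not prefix:
        raise ValueError("prefix must be a non-empty string")
    if destination is None and step is not None:
        destination = f"run-{step}"
    if destination is None:
        return base_path  # neither given: save directly under base_path
    return posixpath.join(base_path, destination)
```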
Example
>>> path = checkpointer.save_pytree(
...     tree=model_state,
...     prefix="model",
...     step=1000,
...     extras={"learning_rate": 0.001},
... )
>>> print(f"Saved to: {path}")
- wait_until_finished() → None
Block until all checkpoint operations complete.
Waits for both:
1. All async checkpoint saves to finish writing to disk
2. All background checkpoint deletion operations to complete
This method should be called before program exit or when you need to ensure all checkpoint I/O has completed (e.g., before final evaluation or shutdown).
Note
- Blocks the calling thread until completion
- On non-primary processes, only waits for save operations
- On the primary process (0), also waits for the deletion queue to empty
- Uses polling with 0.2 s sleep intervals for the deletion queue
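The cleanup pattern described above (a deletion queue drained by a background thread, plus a wait that polls until the queue is empty) can be sketched with `queue` and `threading`. `BackgroundDeleter` is a hypothetical class illustrating the pattern, not eformer's implementation; the worker records failures instead of dying, so a failed deletion cannot leave the wait loop hanging.

```python
import queue
import threading
import time
from typing import Callable, List, Tuple


class BackgroundDeleter:
    """Hypothetical sketch: delete old temporary checkpoints on a worker thread."""

    def __init__(self, delete_fn: Callable[[str], None]) -> None:
        self._delete_fn = delete_fn
        self._queue: "queue.Queue[str]" = queue.Queue()
        self._pending = 0  # enqueued but not yet processed
        self._lock = threading.Lock()
        self.errors: List[Tuple[str, Exception]] = []
        # Daemon thread so it never blocks interpreter exit.
        threading.Thread(target=self._run, daemon=True).start()

    def enqueue(self, path: str) -> None:
        with self._lock:
            self._pending += 1
        self._queue.put(path)

    def _run(self) -> None:
        while True:
            path = self._queue.get()
            try:
                self._delete_fn(path)
            except Exception as exc:  # keep the worker alive on failures
                self.errors.append((path, exc))
            finally:
                with self._lock:
                    self._pending -= 1

    def wait_until_empty(self, poll_seconds: float = 0.2) -> None:
        """Block until every enqueued deletion has been attempted."""
        while True:
            with self._lock:
                if self._pending == 0:
                    return
            time.sleep(poll_seconds)
```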
Examples
```python
# At end of training
for step in range(num_steps):
    train_step(state)
    checkpointer.on_step(my_mesh, state, step=step)

checkpointer.wait_until_finished()
print("All checkpoints saved and cleaned up")
```
- eformer.serialization.checkpointer.find_latest_checkpoint(base_path: str) → str | None
Find the most recent checkpoint under a directory.
Searches for checkpoint directories containing metadata.json files and returns the path to the most recent one based on timestamp and step number.
The function searches:
1. All immediate subdirectories under base_path
2. The base_path itself (in case it’s a checkpoint directory)
- Parameters
base_path – Root directory to search for checkpoints. Can be a local path or cloud storage path (e.g., “gs://bucket/path”).
- Returns
Full path to the latest checkpoint directory, or None if no valid checkpoints are found. If base_path has a URL scheme (e.g., “gs://”), the returned path preserves that scheme.
Note
- Checkpoints are identified by the presence of metadata.json
- Sorting priority: timestamp first, then step number
- Handles both local and cloud storage via fsspec
- Logs a warning if no checkpoints are found
- Handles FileNotFoundError gracefully (returns None)
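For a local directory, the search described above can be sketched with `os` and `json`. `find_latest_local_checkpoint` is hypothetical and skips the fsspec/cloud handling; it assumes every `metadata.json` carries an ISO-8601 `timestamp` (all using the same timezone convention) and an optional integer `step`.

```python
import json
import logging
import os
from datetime import datetime
from typing import Optional, Tuple


def find_latest_local_checkpoint(base_path: str) -> Optional[str]:
    """Newest directory (by timestamp, then step) that contains metadata.json."""
    try:
        names = os.listdir(base_path)
    except FileNotFoundError:
        return None
    children = [os.path.join(base_path, n) for n in names]
    # Search immediate subdirectories, then base_path itself.
    candidates = [p for p in children if os.path.isdir(p)] + [base_path]
    best: Optional[Tuple[datetime, int, str]] = None
    for ckpt_dir in candidates:
        meta_path = os.path.join(ckpt_dir, "metadata.json")
        if not os.path.isfile(meta_path):
            continue
        with open(meta_path) as f:
            meta = json.load(f)
        key = (datetime.fromisoformat(meta["timestamp"]), int(meta.get("step", -1)), ckpt_dir)
        if best is None or key > best:
            best = key
    if best is None:
        logging.warning("No checkpoints found under %s", base_path)
        return None
    return best[2]
```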
Examples
```python
from eformer.serialization.checkpointer import find_latest_checkpoint

# Local filesystem
latest = find_latest_checkpoint("/checkpoints/my_model")
# Returns: "/checkpoints/my_model/run-5000"

# Cloud storage
latest = find_latest_checkpoint("gs://bucket/checkpoints")
# Returns: "gs://bucket/checkpoints/run-10000"

# No checkpoints found
latest = find_latest_checkpoint("/empty/dir")
# Returns: None
```