eformer.serialization.checkpointer

High-level checkpoint management for eFormer.

This module provides a checkpoint manager for training workflows that decides when to save using time-based and run-based (training-step) policies. Key features include:

  • Flexible Checkpoint Policies: Configure time-based and run-based saving policies

  • TensorStore Backend: Efficient storage for large-scale distributed arrays

  • TP+FSDP Compatibility: No all-gather required, preserves existing shardings

  • Async Operations: Non-blocking checkpoint saves with background cleanup

  • Temporary Checkpoints: Automatic management of temporary vs permanent checkpoints

  • Multi-host Support: Distributed checkpoint operations across multiple hosts

The Checkpointer class is designed for training loops that need checkpoints saved, tracked, and cleaned up automatically.

class eformer.serialization.checkpointer.CheckpointInterval(every: int, until: int | None = None)

Bases: object

Configuration for run-based checkpoint saving policy.

Defines when to save checkpoints based on training steps. Multiple intervals can be combined into multi-stage checkpoint policies (e.g., save every 100 steps for the first 1000 steps, then every 1000 steps thereafter).

every

Save checkpoint every N steps within this interval.

Type

int

until

Save using this policy until this step (inclusive). If None, this policy applies indefinitely. Only the last policy in a sequence can have until=None.

Type

int | None

Examples

```python
from eformer.serialization.checkpointer import CheckpointInterval

# Save every 100 steps
interval = CheckpointInterval(every=100)

# Save every 50 steps until step 1000
interval = CheckpointInterval(every=50, until=1000)

# Multi-stage policy: frequent saves early, less frequent later
policies = [
    CheckpointInterval(every=100, until=1000),  # Every 100 steps up to 1000
    CheckpointInterval(every=1000),  # Every 1000 steps thereafter
]
```

every: int
until: int | None = None
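One way to read a sequence of intervals is shown below as a minimal plain-Python sketch. It assumes the first interval whose inclusive `until` bound covers a step is the one that applies, and that a step is selected when it is a multiple of that interval's `every`. `Interval`, `validate_policies` and `should_save_at_step` are hypothetical names used for illustration, not part of eformer.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass(frozen=True)
class Interval:
    """Hypothetical stand-in for CheckpointInterval (illustration only)."""

    every: int
    until: Optional[int] = None


def validate_policies(policies: Sequence[Interval]) -> None:
    """Enforce the documented rule that only the last policy may have until=None."""
    for i, policy in enumerate(policies):
        if policy.every <= 0:
            raise ValueError(f"policy {i}: every must be a positive integer")
        if policy.until is None and i != len(policies) - 1:
            raise ValueError(f"policy {i}: only the last policy may have until=None")


def should_save_at_step(step: int, policies: Sequence[Interval]) -> bool:
    """Return True if the first policy covering `step` selects it (assumed rule)."""
    for policy in policies:
        # `until` is inclusive, so this policy covers every step <= until.
        if policy.until is None or step <= policy.until:
            return step % policy.every == 0
    # Every policy was bounded and `step` lies past the last bound.
    return False
```

With `[Interval(every=100, until=1000), Interval(every=1000)]`, step 500 is selected, step 1100 is not, and step 2000 is.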
class eformer.serialization.checkpointer.Checkpointer(base_path: str, save_interval: datetime.timedelta | None, step_policies: Sequence[CheckpointInterval], *, manager: eformer.serialization.async_manager.AsyncCheckpointManager | None = None, dt_now_injection: Optional[Callable[[], datetime]] = None, delete_old_temp_checkpoints: bool = True)

Bases: object

High-level checkpoint manager with time- and run-based policies for eFormer.

This class provides automatic checkpoint management for training loops with support for both time-based and run-based saving policies. It integrates with JAX’s distributed training capabilities and TensorStore for efficient storage.

Key Features:
  • Multi-policy checkpointing: Configure different save intervals at different training stages

  • Time-based saves: Automatically save at regular time intervals

  • Temporary checkpoints: Distinguish between temporary (time-based) and permanent (run-based) checkpoints

  • Async cleanup: Background deletion of old temporary checkpoints

  • Multi-host support: Coordinated saves across distributed training hosts

  • TensorStore backend: Efficient storage without all-gather operations

The checkpointer maintains existing array shardings (TP/FSDP) during saves, avoiding expensive all-gather operations.

base_path

Root directory for all checkpoints.

save_interval

Optional time interval for temporary checkpoint saves.

step_policies

Sequence of run-based checkpoint policies.

Examples

```python
from datetime import timedelta

from eformer.serialization.checkpointer import Checkpointer, CheckpointInterval

# Create checkpointer with time and step policies
checkpointer = Checkpointer(
    base_path="/checkpoints/my_model",
    save_interval=timedelta(minutes=15),  # Temporary checkpoint every 15 min
    step_policies=[
        CheckpointInterval(every=500, until=5000),  # Every 500 steps until 5000
        CheckpointInterval(every=1000),  # Every 1000 steps after
    ],
)

# In training loop
for step, batch in enumerate(train_loader):
    # ... training code ...
    # on_step takes the mesh first; step is keyword-only
    checkpointer.on_step(my_mesh, training_state, step=step)

# Wait for all saves to complete
checkpointer.wait_until_finished()
```

load_checkpoint(mesh: Mesh, *, path: str | None = None, discover_latest: bool = True, shardings: dict | None = None, partition_rules: Any = None, dtype: Any = None, prefix: str | None = None, structured: bool = False, template: eformer.pytree._pytree.PyTree | None = None, strict_shapes: bool = True) -> tuple[eformer.pytree._pytree.PyTree, dict[str, Any]]

Load a checkpoint from disk with automatic discovery.

Loads a checkpoint directory and restores the PyTree structure with proper array shardings for distributed training. Can automatically discover the most recent checkpoint based on metadata timestamps.

Parameters
  • mesh – JAX mesh for distributed array loading. Required for properly restoring sharded arrays across devices.

  • path – Specific checkpoint directory to load. If None, uses base_path and discovers the latest checkpoint if discover_latest=True.

  • discover_latest – If True, automatically finds and loads the most recent checkpoint under the specified path based on metadata timestamps. If False, loads from the exact path specified.

  • shardings – Dictionary mapping array names to sharding specifications. Used to restore arrays with specific shardings. If None, attempts to restore original shardings from checkpoint metadata. Only used in non-structured mode.

  • partition_rules – Optional partition rules for automatic sharding inference. Alternative to explicit shardings dictionary.

  • dtype – Optional dtype to cast loaded arrays to. If None, preserves original dtypes from the checkpoint.

  • prefix – Optional prefix if the checkpoint contains multiple trees. Must match the prefix used during save_checkpoint. Required when structured=True.

  • structured – If True, loads the checkpoint in structured mode which restores the full PyTree structure (treedef) and requires a prefix. If False, uses non-structured TensorStore mode loading only array leaves. Default: False.

Returns

A tuple of (tree, metadata) where

  • tree: Restored PyTree with the same structure as saved

  • metadata: Dictionary containing checkpoint metadata (step, timestamp, etc.)

Return type

tuple[PyTree, dict[str, Any]]

Raises
  • FileNotFoundError – If no checkpoint is found at the specified path or if discover_latest=True and no checkpoints exist.

  • ValueError – If structured=True but prefix is not provided.

Note

  • Automatically handles both local and cloud storage paths

  • Restores arrays with distributed shardings (no all-gather)

  • All processes must call this method (distributed operation)

  • Discovered checkpoints are sorted by timestamp then step number

  • structured=True requires a prefix matching the one used during save

Examples

```python
# Load latest checkpoint automatically
state, metadata = checkpointer.load_checkpoint(mesh=my_mesh)
print(f"Loaded checkpoint from step {metadata['step']}")

# Load specific checkpoint
state, _ = checkpointer.load_checkpoint(
    mesh=my_mesh,
    path="/checkpoints/my_model/run-1000",
    discover_latest=False,
)

# Load structured checkpoint with prefix
state, _ = checkpointer.load_checkpoint(
    mesh=my_mesh,
    prefix="optimizer",
    structured=True,
)

# Load with custom shardings (non-structured mode)
state, _ = checkpointer.load_checkpoint(
    mesh=my_mesh,
    shardings=custom_shardings,
    structured=False,
)
```

load_pytree(mesh: Mesh, *, prefix: str, path: str | None = None, discover_latest: bool = True, discover_raise: bool = True, partition_rules: Any = None, dtype: numpy.dtype | None = None, load_treedef: bool = False, callback: Optional[Callable[[Array, str], Array]] = None, template: eformer.pytree._pytree.PyTree | None = None, strict_shapes: bool = True, chunk_size: int | None = None) -> tuple[eformer.pytree._pytree.PyTree, dict[str, Any]]

Load a treedef-preserving PyTree saved under a specific prefix.

This method loads checkpoints saved in structured mode using save_pytree(), which preserves the full PyTree structure definition (treedef).

Parameters
  • mesh – JAX Mesh for array sharding on load. Required for properly restoring sharded arrays across devices.

  • prefix – Namespace/prefix (e.g., “tx”, “model”) used at save time. Must match the prefix used when saving the checkpoint.

  • path – Optional exact checkpoint directory to load from. If None, uses base_path and discovers the latest checkpoint if discover_latest=True.

  • discover_latest – If True, automatically finds and loads the most recent checkpoint under the specified path based on metadata timestamps. If False, loads from the exact path specified. Default: True.

  • discover_raise – If True, raises FileNotFoundError when no checkpoint is found during discovery. If False, returns None silently when no checkpoint exists. Only used when discover_latest=True. Default: True.

  • partition_rules – Optional partition rules for automatic sharding inference. Alternative to explicit shardings.

  • dtype – Optional dtype to cast loaded arrays to. If None, preserves original dtypes from the checkpoint.

  • load_treedef – If True, uses load_pytree() which restores the full PyTree structure definition. If False, uses load() which restores only the array values. Default: False.

  • callback – Optional callback function to process each loaded array. Receives (array, key_path) and should return the processed array. Useful for custom transformations during loading.

  • chunk_size – Optional number of arrays to load per batch.

Returns

A tuple of (pytree, extras_metadata) where

  • pytree: Restored PyTree with the same structure as saved

  • extras_metadata: Dictionary containing checkpoint metadata from save

Return type

tuple[PyTree, dict[str, Any]]

Raises

FileNotFoundError – If no checkpoint is found and discover_raise=True.
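The `callback` and `chunk_size` parameters can be pictured with the minimal sketch below. It assumes arrays are read from a flat mapping of key paths (written here as hypothetical "a/w"-style strings) and that `chunk_size=None` means "everything in one batch"; NumPy arrays stand in for JAX arrays, and `iter_chunks` and `load_with_callback` are hypothetical helpers, not eformer functions.

```python
from typing import Callable, Dict, Iterator, List, Optional

import numpy as np


def iter_chunks(keys: List[str], chunk_size: Optional[int]) -> Iterator[List[str]]:
    """Yield keys in batches of `chunk_size`, or all at once when it is None."""
    if chunk_size is None:
        yield list(keys)
        return
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    for start in range(0, len(keys), chunk_size):
        yield keys[start : start + chunk_size]


def load_with_callback(
    store: Dict[str, np.ndarray],
    callback: Optional[Callable[[np.ndarray, str], np.ndarray]] = None,
    chunk_size: Optional[int] = None,
) -> Dict[str, np.ndarray]:
    """Hypothetical sketch: read arrays batch by batch, applying callback(array, key)."""
    result: Dict[str, np.ndarray] = {}
    for batch in iter_chunks(sorted(store), chunk_size):
        for key in batch:
            array = store[key]  # stands in for reading one array from storage
            result[key] = callback(array, key) if callback is not None else array
    return result
```

Batching bounds how many arrays are in flight at once, which is the usual reason to expose a chunk size when restoring large trees.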

Examples

```python
import jax.numpy as jnp

# Load latest checkpoint for a specific prefix
optimizer_state, metadata = checkpointer.load_pytree(
    mesh=my_mesh,
    prefix="optimizer",
)

# Load specific checkpoint without discovery
model, _ = checkpointer.load_pytree(
    mesh=my_mesh,
    prefix="model",
    path="/checkpoints/run-1000",
    discover_latest=False,
)

# Load with callback for custom processing
def convert_to_fp16(arr, key):
    if arr.dtype == jnp.float32:
        return arr.astype(jnp.float16)
    return arr

state, _ = checkpointer.load_pytree(
    mesh=my_mesh,
    prefix="model",
    callback=convert_to_fp16,
)

# Silently handle missing checkpoint
state, metadata = checkpointer.load_pytree(
    mesh=my_mesh,
    prefix="tx",
    discover_raise=False,
)
if state is None:
    print("No checkpoint found, using fresh state")
```

on_step(mesh: Mesh, pytree: Any | None = None, force: bool = False, *, step: int, true_callbacks: list[Callable[[str, jax._src.mesh.Mesh, dict], NoneType]] | None = None, extras: dict | None = None) -> None

Process a training step and save checkpoint if policies dictate.

This method should be called once per training step. It evaluates both time-based and run-based policies to determine whether to save a checkpoint. The decision is made on process 0 and broadcast to all processes to ensure consistency in distributed settings.

Parameters
  • mesh – JAX mesh for distributed arrays. Required for checkpoint saving with proper sharding. Passed to save_checkpoint and callbacks.

  • pytree – Training state PyTree to save. Can be None if only using true_callbacks to handle checkpoint saving externally.

  • force – If True, force a permanent checkpoint save regardless of policies. Useful for saving at the end of training or before evaluation.

  • step – Current training step number. Used to determine if checkpoint should be saved based on step_policies.

  • true_callbacks – Optional list of callback functions to execute when a checkpoint save is triggered. Each callback receives three arguments: destination (str), mesh (Mesh), and metadata (dict). Useful for custom checkpoint handling logic.

  • extras – Optional dictionary of extra metadata to include in the checkpoint. This metadata will be passed to save_checkpoint and stored with the checkpoint.

Note

  • Step 0 (the initialization step) is skipped unless force=True

  • Duplicate step saves are skipped unless force=True

  • Time-based saves create temporary checkpoints (auto-cleaned)

  • Run-based saves create permanent checkpoints

  • All processes synchronize on the save decision via broadcast

  • Old temporary checkpoints are queued for async deletion
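The save rules in the note above can be modelled as a small state machine. This is a minimal single-process sketch that assumes a single `every` step policy instead of a list of `CheckpointInterval`s and assumes any save resets the time-based timer; `SaveDecider` and its methods are hypothetical names. It omits the process-0 decision and broadcast that `on_step` performs in multi-host runs.

```python
from datetime import datetime, timedelta
from typing import Callable, Optional


class SaveDecider:
    """Hypothetical sketch of an on_step-style save decision (not eformer code)."""

    def __init__(
        self,
        save_interval: Optional[timedelta],
        every: int,
        now: Callable[[], datetime] = datetime.now,
    ) -> None:
        self.save_interval = save_interval
        self.every = every  # simplified: a single step policy
        self.now = now  # injectable clock, similar in spirit to dt_now_injection
        self.last_save_time = now()
        self.last_save_step: Optional[int] = None

    def decide(self, step: int, force: bool = False) -> Optional[str]:
        """Return "permanent", "temporary", or None (no save) for this step."""
        if not force and (step == 0 or step == self.last_save_step):
            return None  # step 0 and repeated steps are skipped unless forced
        if force or step % self.every == 0:
            kind = "permanent"  # forced or step-policy save
        elif (
            self.save_interval is not None
            and self.now() - self.last_save_time >= self.save_interval
        ):
            kind = "temporary"  # time-based save
        else:
            return None
        self.last_save_time = self.now()
        self.last_save_step = step
        return kind
```

Injecting the clock makes the time-based branch testable without sleeping: a test can advance a fake "now" and check that a temporary save is triggered.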

Examples

```python
from jax.sharding import Mesh

# Regular usage in training loop
for step, batch in enumerate(dataloader):
    state, loss = train_step(state, batch)  # assumes train_step returns new state and loss
    checkpointer.on_step(my_mesh, state, step=step)

# Force save at end of training
checkpointer.on_step(my_mesh, state, force=True, step=final_step)

# With custom callbacks
def log_save(destination: str, mesh: Mesh, metadata: dict) -> None:
    print(f"Saved to {destination} with step {metadata.get('step')}")

checkpointer.on_step(
    my_mesh, state, step=step, true_callbacks=[log_save]
)

# With extras metadata
checkpointer.on_step(
    my_mesh, state, step=step, extras={"loss": float(loss), "accuracy": float(acc)}
)
```

save_checkpoint(tree: PyTree, destination: str, *, commit_callback: Optional[Callable[[], None]] = None, is_temporary: bool = False, mesh: Mesh = None, step: int = -1, shardings: Any = None, partition_rules: Any = None, prefix: str | None = None, structured: bool = False, dtype: numpy.dtype | None = None, extras: dict | None = None) -> None

Save a checkpoint to the specified destination.

This method saves a PyTree checkpoint using TensorStore backend with support for distributed training. It preserves existing array shardings (TP/FSDP) without performing all-gather operations, making it efficient for large models.

Parameters
  • tree – PyTree to save. Can be any nested structure containing JAX arrays, NumPy arrays, or other serializable Python objects.

  • destination – Subdirectory name under base_path where checkpoint will be saved. The full path will be base_path/destination.

  • commit_callback – Optional callback function to execute after the checkpoint save completes. Used internally for metadata writing and cleanup.

  • is_temporary – If True, marks this checkpoint as temporary in metadata. Temporary checkpoints are subject to automatic cleanup.

  • mesh – Optional JAX mesh for distributed arrays. If None, the current mesh from the tree’s shardings will be used.

  • step – Training step number for this checkpoint. Stored in metadata. Defaults to -1 if not specified.

  • shardings – Optional sharding specifications for arrays in the tree. If None, existing shardings are preserved from the tree. Only used in non-structured mode.

  • partition_rules – Optional partition rules for automatic sharding. Typically not needed if arrays already have shardings. Only used in non-structured mode.

  • prefix – Optional prefix for organizing multiple trees within the same checkpoint directory. Required when structured=True. Useful for saving multiple model states or components (e.g., “model”, “optimizer”).

  • structured – If True, saves the checkpoint in structured mode which preserves the PyTree structure (treedef) and requires a prefix. If False, uses non-structured TensorStore mode saving only array leaves. Default: False.

  • dtype – Optional dtype to cast arrays to before saving. If None, preserves original array dtypes.

  • extras – Optional dictionary of extra metadata to store in the checkpoint. Merged with standard metadata (step, is_temporary).

Note

  • Uses AsyncCheckpointManager for non-blocking I/O

  • Does NOT perform all-gather (preserves distributed arrays)

  • Creates destination directory if it doesn’t exist

  • Updates internal tracking of last save step and time

  • All processes must call this method (distributed operation)

  • structured=True requires a prefix and uses a different save path
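How `step`, `is_temporary` and `extras` might end up together on disk is sketched below. This is an assumption-laden illustration: the file name `metadata.json`, the UTC ISO-8601 `timestamp` field, and the rule that standard fields override same-named `extras` keys are all choices made for the example, not documented eformer behaviour. `write_checkpoint_metadata` is a hypothetical helper.

```python
import json
import os
from datetime import datetime, timezone
from typing import Any, Dict, Optional


def write_checkpoint_metadata(
    directory: str,
    step: int = -1,
    is_temporary: bool = False,
    extras: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Hypothetical sketch: merge extras with standard fields into metadata.json."""
    os.makedirs(directory, exist_ok=True)  # create the destination if missing
    metadata: Dict[str, Any] = dict(extras or {})
    # Assumption: standard fields win over same-named keys in `extras`.
    metadata.update(
        step=step,
        is_temporary=is_temporary,
        timestamp=datetime.now(timezone.utc).isoformat(),
    )
    with open(os.path.join(directory, "metadata.json"), "w") as f:
        json.dump(metadata, f, indent=2)
    return metadata
```

Letting standard fields win keeps `step` and `is_temporary` trustworthy for discovery and cleanup even if a caller passes a conflicting key in `extras`.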

Examples

```python
import jax.numpy as jnp

# Simple checkpoint save
checkpointer.save_checkpoint(
    tree=model_state,
    destination="run-1000",
    step=1000,
)

# Structured save with prefix
checkpointer.save_checkpoint(
    tree=optimizer_state,
    destination="run-1000",
    prefix="optimizer",
    structured=True,
    step=1000,
)

# Save with mesh, dtype, and extras
checkpointer.save_checkpoint(
    tree=training_state,
    destination=f"run-{current_step}",
    mesh=my_mesh,
    step=current_step,
    dtype=jnp.bfloat16,
    extras={"loss": float(loss), "learning_rate": 0.001},
)
```

save_pytree(tree: PyTree, prefix: str, *, step: int | None = None, destination: str | None = None, mesh: Mesh = None, dtype: numpy.dtype | None = None, extras: dict | None = None, temporary: bool = False, write_index: bool = True) -> str

Save a PyTree under a specific prefix with treedef preserved (structured checkpoint).

This method provides structured checkpoint saving that preserves the exact PyTree structure definition (treedef), enabling reconstruction of the original structure during loading without a template.

Parameters
  • tree – PyTree to save. Can contain JAX arrays, numpy arrays, and other serializable Python objects.

  • prefix – Namespace/prefix for organizing the saved tree (e.g., “model”, “optimizer”, “tx”). Required and must be a non-empty string.

  • step – Training step number for metadata. Used to generate default destination name (f”run-{step}”) and stored in metadata.json.

  • destination – Optional subdirectory name under base_path. If None and step is provided, defaults to f”run-{step}”. If both are None, saves directly to base_path.

  • mesh – Optional JAX Mesh for sharding context. Used when do_all_gather is performed to preserve proper device placement.

  • dtype – Optional dtype to cast floating point arrays to before saving. If None, preserves original dtypes.

  • extras – Optional dictionary of extra metadata to store in checkpoint_metadata.json alongside standard fields.

  • temporary – If True, marks this checkpoint as temporary in metadata.json. Temporary checkpoints are subject to automatic cleanup. Defaults to False.

  • write_index – Whether to write/update the TensorStore index file. Set to False only if you’re managing the index externally. Defaults to True.

Returns

The full checkpoint directory path (base_path/destination) where the checkpoint was saved.

Raises

ValueError – If prefix is empty or not a string.
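The destination defaulting described for `step` and `destination`, together with the prefix check, can be sketched as two small functions. This is a minimal illustration assuming paths are joined with a plain "/" (so URL-style base paths keep their scheme); `resolve_checkpoint_dir` and `check_prefix` are hypothetical helpers, not eformer functions.

```python
from typing import Optional


def resolve_checkpoint_dir(
    base_path: str, step: Optional[int] = None, destination: Optional[str] = None
) -> str:
    """Hypothetical sketch of the documented destination defaulting rules."""
    if destination is None and step is not None:
        destination = f"run-{step}"  # default name derived from the step
    if destination is None:
        return base_path  # neither given: save directly to base_path
    # Plain string join keeps URL schemes such as "gs://" intact.
    return f"{base_path.rstrip('/')}/{destination}"


def check_prefix(prefix: object) -> str:
    """Reject anything that is not a non-empty string, as the Raises entry describes."""
    if not isinstance(prefix, str) or not prefix:
        raise ValueError("prefix must be a non-empty string")
    return prefix
```

For example, `resolve_checkpoint_dir("/ckpt", step=1000)` gives `"/ckpt/run-1000"`, while an explicit `destination` takes precedence over the step-derived name.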

Example

>>> path = checkpointer.save_pytree(
...     tree=model_state,
...     prefix="model",
...     step=1000,
...     extras={"learning_rate": 0.001}
... )
>>> print(f"Saved to: {path}")

wait_until_finished() -> None

Block until all checkpoint operations complete.

Waits for both:

  1. All async checkpoint saves to finish writing to disk

  2. All background checkpoint deletion operations to complete

This method should be called before program exit or when you need to ensure all checkpoint I/O has completed (e.g., before final evaluation or shutdown).

Note

  • Blocks the calling thread until completion

  • On non-primary processes, only waits for save operations

  • On primary process (0), also waits for deletion queue to empty

  • Uses polling with 0.2s sleep intervals for deletion queue
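The deletion side of this wait can be illustrated with a background worker thread and a polling wait. This is a minimal standard-library sketch, not eformer's implementation: it assumes deletions are plain `shutil.rmtree` calls on local directories and tracks outstanding work with a lock-protected counter. `BackgroundDeleter` is a hypothetical name.

```python
import queue
import shutil
import threading
import time


class BackgroundDeleter:
    """Hypothetical sketch: delete old temporary checkpoints off the main thread."""

    def __init__(self) -> None:
        self._queue: "queue.Queue[str]" = queue.Queue()
        self._pending = 0  # deletions scheduled but not yet finished
        self._lock = threading.Lock()
        self._worker = threading.Thread(target=self._run, daemon=True)
        self._worker.start()

    def schedule(self, path: str) -> None:
        """Queue a checkpoint directory for asynchronous deletion."""
        with self._lock:
            self._pending += 1
        self._queue.put(path)

    def _run(self) -> None:
        while True:
            path = self._queue.get()
            try:
                shutil.rmtree(path, ignore_errors=True)
            finally:
                with self._lock:
                    self._pending -= 1

    def wait_until_empty(self, poll_seconds: float = 0.2) -> None:
        """Block by polling until every scheduled deletion has finished."""
        while True:
            with self._lock:
                if self._pending == 0:
                    return
            time.sleep(poll_seconds)
```

The counter is incremented before the path is queued, so a wait that starts right after `schedule` cannot return before that deletion completes.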

Examples

```python
# At end of training
for step in range(num_steps):
    state = train_step(state)  # assumes train_step returns the updated state
    checkpointer.on_step(my_mesh, state, step=step)

checkpointer.wait_until_finished()
print("All checkpoints saved and cleaned up")
```

eformer.serialization.checkpointer.find_latest_checkpoint(base_path: str) -> str | None

Find the most recent checkpoint under a directory.

Searches for checkpoint directories containing metadata.json files and returns the path to the most recent one based on timestamp and step number.

The function searches:

  1. All immediate subdirectories under base_path

  2. The base_path itself (in case it is a checkpoint directory)

Parameters

base_path – Root directory to search for checkpoints. Can be a local path or cloud storage path (e.g., “gs://bucket/path”).

Returns

Full path to the latest checkpoint directory, or None if no valid checkpoints are found. If base_path has a URL scheme (e.g., “gs://”), the returned path preserves that scheme.

Note

  • Checkpoints are identified by presence of metadata.json

  • Sorting priority: timestamp first, then step number

  • Handles both local and cloud storage via fsspec

  • Logs warning if no checkpoints found

  • Handles FileNotFoundError gracefully (returns None)
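For a local directory, the discovery rules above can be approximated as follows. This is a sketch only: it uses `os` instead of fsspec, so it does not handle `gs://` paths, and it assumes `metadata.json` stores an ISO-8601 `timestamp` string and an integer `step`; the real metadata schema may differ. `find_latest_local` is a hypothetical name.

```python
import json
import os
from datetime import datetime
from typing import Optional


def find_latest_local(base_path: str) -> Optional[str]:
    """Hypothetical local-filesystem sketch of latest-checkpoint discovery.

    Assumes each checkpoint directory holds a metadata.json containing an
    ISO-8601 "timestamp" string and an integer "step".
    """
    if not os.path.isdir(base_path):
        return None  # missing directory: nothing to discover
    # Immediate subdirectories first, then base_path itself.
    candidates = [os.path.join(base_path, name) for name in os.listdir(base_path)]
    candidates.append(base_path)

    found = []
    for path in candidates:
        meta_file = os.path.join(path, "metadata.json")
        if not os.path.isfile(meta_file):
            continue  # not a checkpoint directory
        with open(meta_file) as f:
            meta = json.load(f)
        timestamp = datetime.fromisoformat(meta["timestamp"])
        # Sort key: timestamp first, then step number.
        found.append((timestamp, int(meta.get("step", -1)), path))

    if not found:
        return None
    return max(found)[2]
```

Because the sort key is `(timestamp, step)`, two checkpoints written with the same timestamp are ordered by step, which matches the documented tie-breaking order.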

Examples

```python
from eformer.serialization.checkpointer import find_latest_checkpoint

# Local filesystem
latest = find_latest_checkpoint("/checkpoints/my_model")
# Returns: "/checkpoints/my_model/run-5000"

# Cloud storage
latest = find_latest_checkpoint("gs://bucket/checkpoints")
# Returns: "gs://bucket/checkpoints/run-10000"

# No checkpoints found
latest = find_latest_checkpoint("/empty/dir")
# Returns: None
```