eformer.serialization.base_manager
- class eformer.serialization.base_manager.CheckpointManager(checkpoint_dir: eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath | str | os.PathLike, enable: bool | None = None, float_dtype: numpy.dtype = jax.numpy.bfloat16, save_optimizer_state: bool = True, verbose: bool = False, gcs_bucket: str | None = None, gcs_credentials_path: str | None = None)
Bases: object
Base checkpoint manager for saving and loading PyTree structures.
This manager saves and loads checkpoints, with support for sharded files, Google Cloud Storage, and the SafeTensors and msgpack formats.
- float_dtype#
Default data type for floating point arrays.
- save_optimizer_state#
Whether to save optimizer state.
- checkpoint_dir#
Directory for saving checkpoints.
- enable#
Whether checkpointing is enabled.
- verbose#
Enable verbose output.
- gcs_bucket#
Google Cloud Storage bucket name.
- gcs_client#
GCS client instance.
- static create_gcs_client(gcs_credentials_path: str | None = None)
Create a Google Cloud Storage client.
- Parameters
gcs_credentials_path – Optional path to service account credentials.
- Returns
Google Cloud Storage client instance.
- static load_checkpoint(path: eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath | str | os.PathLike, shard_fns: dict[Callable] | None = None, verbose: bool = False, mismatch_allowed: bool = True, callback: Optional[Callable[[Array, str], Array]] = None, dtype: str | numpy.dtype | None = None, gcs_client: google.cloud.storage.client.Client | None = None) -> tuple[eformer.pytree._pytree.PyTree | dict, dict]
Load a checkpoint from local path or GCS.
- Supports:
Single safetensors file
Sharded safetensors with index (prefix.safetensors.index.json)
- Parameters
path – Path to the checkpoint file or directory.
shard_fns – Dictionary of functions to apply to specific shards.
verbose – Enable verbose output.
mismatch_allowed – Whether to allow missing shard functions.
callback – Optional callback to process arrays after loading.
dtype – Data type to cast arrays to.
gcs_client – Optional GCS client instance.
- Returns
Tuple of (loaded PyTree or dict, metadata dict).
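For the sharded layout, the index file maps each tensor name to the shard that holds it. The following is a minimal sketch of how a reader could group tensor names by shard so that each file is opened only once. The `weight_map` key and the shard file names are assumptions borrowed from the common Hugging Face-style index layout and are not confirmed by this page; `group_by_shard` is a hypothetical helper, not part of eformer.

```python
import json
from collections import defaultdict


def group_by_shard(index_text: str) -> dict[str, list[str]]:
    """Group tensor names by the shard file that stores them."""
    index = json.loads(index_text)
    # Assumption: the name-to-file mapping lives under "weight_map".
    weight_map = index["weight_map"]
    groups = defaultdict(list)
    for tensor_name, shard_file in weight_map.items():
        groups[shard_file].append(tensor_name)
    # Sort for a deterministic read order.
    return {shard: sorted(names) for shard, names in sorted(groups.items())}


# Hypothetical index contents, for illustration only.
example_index = json.dumps({
    "metadata": {"total_size": 12},
    "weight_map": {
        "decoder.bias": "model-00002-of-00002.safetensors",
        "embed.kernel": "model-00001-of-00002.safetensors",
        "decoder.kernel": "model-00002-of-00002.safetensors",
    },
})
shards = group_by_shard(example_index)
```

Grouping first means a loader can read every tensor it needs from a shard in one pass, instead of reopening the file for each tensor name.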
- load_state_from_gcs_msgpack(gcs_path: str, verbose: bool = False) -> dict
Load tree from GCS msgpack format.
- Parameters
gcs_path – GCS path to the msgpack checkpoint.
verbose – Enable verbose output.
- Returns
Dictionary containing the loaded checkpoint.
- Raises
ValueError – If GCS client is not initialized.
- classmethod save_checkpoint(tree: PyTree, path: eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath | str | os.PathLike, mesh: Mesh, gather_fns: dict[Callable] | bool | None = None, float_dtype: str | numpy.dtype | None = None, verbose: bool = True, mismatch_allowed: bool = True, metadata: dict[str, str] | None = None, enable: bool | None = None, shard_size_gb: float | None = 5.0, write_index_file: bool = True) -> eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath | str | os.PathLike
Save a checkpoint to local path or GCS using SafeTensors.
If shard_size_gb is provided, the tree is saved as multiple shards, each at most that size. When write_index_file is true, an index file ‘prefix.safetensors.index.json’ is also written, mapping every tensor name to its shard file.
- Parameters
tree – PyTree structure to save.
path – Path where the checkpoint will be saved.
mesh – JAX mesh for distributed computation.
gather_fns – Dictionary of gather functions or bool for device gathering.
float_dtype – Data type for floating point arrays.
verbose – Enable verbose output.
mismatch_allowed – Whether to allow missing gather functions.
metadata – Additional metadata to save with checkpoint.
enable – Whether checkpointing is enabled (None = auto-detect process 0).
shard_size_gb – Maximum size of each shard in GB.
write_index_file – Whether to write the index file for sharded saves.
- Returns
Path where the checkpoint was saved.
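To make the sharding rule concrete, here is a minimal sketch of one way tensors could be packed into shards of at most shard_size_gb and then recorded in an index. It is an illustration under stated assumptions, not eformer's implementation: greedy in-order packing, binary gigabytes, the `prefix-0000i-of-0000n.safetensors` naming, and the `weight_map` index layout are all assumed, and `plan_shards` is a hypothetical helper.

```python
import json

GB = 1024 ** 3  # assumption: shard_size_gb is counted in binary gigabytes


def plan_shards(tensor_sizes: dict[str, int], shard_size_gb: float,
                prefix: str = "model") -> dict:
    """Greedily pack tensors, in order, into shards of at most shard_size_gb."""
    limit = int(shard_size_gb * GB)
    shards: list[list[str]] = [[]]
    used = 0
    for name, nbytes in tensor_sizes.items():
        # Open a new shard when this tensor would overflow the current one.
        # A tensor larger than the limit still gets a shard of its own.
        if shards[-1] and used + nbytes > limit:
            shards.append([])
            used = 0
        shards[-1].append(name)
        used += nbytes

    total = len(shards)
    weight_map = {}
    for i, names in enumerate(shards, start=1):
        # Hypothetical naming scheme for the shard files.
        shard_file = f"{prefix}-{i:05d}-of-{total:05d}.safetensors"
        for name in names:
            weight_map[name] = shard_file
    return {
        "metadata": {"total_size": sum(tensor_sizes.values())},
        "weight_map": weight_map,
    }


# Sizes in bytes for four made-up tensors.
sizes = {"embed": 3 * GB, "layer_0": 2 * GB, "layer_1": 2 * GB, "head": 1 * GB}
index = plan_shards(sizes, shard_size_gb=5.0)
index_json = json.dumps(index, indent=2)
```

With a 5 GB limit, the first two tensors fill one shard exactly and the remaining two land in a second, smaller shard.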
- save_state_to_gcs_msgpack(tree: PyTree, gcs_path: str, gather_fns: dict[Callable] | None = None, float_dtype: str | numpy.dtype | None = None, verbose: bool = False, mismatch_allowed: bool = True)
Save tree to GCS using msgpack format (streaming).
- Parameters
tree – PyTree structure to save.
gcs_path – GCS path where the checkpoint will be saved.
gather_fns – Dictionary of gather functions.
float_dtype – Data type for floating point arrays.
verbose – Enable verbose output.
mismatch_allowed – Whether to allow missing gather functions.
- Raises
ValueError – If GCS client is not initialized.
KeyError – If gather function is missing and mismatch not allowed.
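The interaction between gather_fns and mismatch_allowed can be sketched as follows. This is an illustration only: the slash-joined key format, the pass-through behaviour for entries without a gather function, and the helper name `apply_gather_fns` are assumptions, and plain Python lists stand in for arrays.

```python
from typing import Any, Callable


def apply_gather_fns(
    flat_tree: dict[str, Any],
    gather_fns: dict[str, Callable[[Any], Any]],
    mismatch_allowed: bool = True,
) -> dict[str, Any]:
    """Apply a per-key gather function to each entry of a flattened tree."""
    gathered = {}
    for key, value in flat_tree.items():
        fn = gather_fns.get(key)
        if fn is None:
            if not mismatch_allowed:
                raise KeyError(f"no gather function for {key!r}")
            # Assumption: entries without a function pass through unchanged.
            gathered[key] = value
        else:
            gathered[key] = fn(value)
    return gathered


flat_tree = {"params/kernel": [1.0, 2.0], "params/bias": [0.5]}
# Stand-in "gather" that doubles each value, so its effect is visible.
gather_fns = {"params/kernel": lambda xs: [x * 2 for x in xs]}

lenient = apply_gather_fns(flat_tree, gather_fns, mismatch_allowed=True)

try:
    apply_gather_fns(flat_tree, gather_fns, mismatch_allowed=False)
    strict_error = None
except KeyError as err:
    strict_error = err
```

In the lenient call the bias entry is kept as-is, while the strict call raises because `params/bias` has no gather function.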