eformer.serialization.base_manager

class eformer.serialization.base_manager.CheckpointManager(checkpoint_dir: eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath | str | os.PathLike, enable: bool | None = None, float_dtype: numpy.dtype = jax.numpy.bfloat16, save_optimizer_state: bool = True, verbose: bool = False, gcs_bucket: str | None = None, gcs_credentials_path: str | None = None)

Bases: object

Base checkpoint manager for saving and loading PyTree structures.

This manager saves and loads checkpoints, with support for sharding, Google Cloud Storage, and the SafeTensors and msgpack formats.
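A SafeTensors file stores a flat mapping from tensor names to arrays, so a nested PyTree has to be turned into string keys before saving and rebuilt after loading. The sketch below illustrates that idea with plain nested dicts. The `.` separator and the helper names `flatten_tree` / `unflatten_tree` are hypothetical and chosen for illustration; they are not eformer's actual key scheme or API.

```python
from typing import Any


def flatten_tree(tree: dict[str, Any], sep: str = ".", prefix: str = "") -> dict[str, Any]:
    """Flatten a nested dict into {"a.b.c": leaf} form (hypothetical helper)."""
    flat: dict[str, Any] = {}
    for key, value in tree.items():
        name = f"{prefix}{sep}{key}" if prefix else str(key)
        if isinstance(value, dict):
            # Recurse into the sub-tree, extending the dotted name.
            flat.update(flatten_tree(value, sep=sep, prefix=name))
        else:
            flat[name] = value
    return flat


def unflatten_tree(flat: dict[str, Any], sep: str = ".") -> dict[str, Any]:
    """Rebuild the nested dict from dotted names (inverse of flatten_tree)."""
    tree: dict[str, Any] = {}
    for name, value in flat.items():
        *parents, leaf = name.split(sep)
        node = tree
        for part in parents:
            # Create intermediate levels on first use.
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree


params = {"encoder": {"dense": {"kernel": [1.0, 2.0], "bias": [0.0]}}, "step": 3}
flat = flatten_tree(params)
print(sorted(flat))  # ['encoder.dense.bias', 'encoder.dense.kernel', 'step']
assert unflatten_tree(flat) == params
```

This round trip only works if no original key contains the separator; a real implementation has to guarantee that or escape it.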

float_dtype

Default data type for floating-point arrays.

save_optimizer_state

Whether to save optimizer state.

checkpoint_dir

Directory for saving checkpoints.

enable

Whether checkpointing is enabled.

verbose

Whether verbose output is enabled.

gcs_bucket

Google Cloud Storage bucket name.

gcs_client

GCS client instance.

static create_gcs_client(gcs_credentials_path: str | None = None)

Create a Google Cloud Storage client.

Parameters

gcs_credentials_path – Optional path to service account credentials.

Returns

Google Cloud Storage client instance.

static load_checkpoint(path: eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath | str | os.PathLike, shard_fns: dict[Callable] | None = None, verbose: bool = False, mismatch_allowed: bool = True, callback: Optional[Callable[[Array, str], Array]] = None, dtype: str | numpy.dtype | None = None, gcs_client: google.cloud.storage.client.Client | None = None) -> tuple[eformer.pytree._pytree.PyTree | dict, dict]

Load a checkpoint from a local path or GCS.

Supports:
  • Single safetensors file

  • Sharded safetensors with index (prefix.safetensors.index.json)

Parameters
  • path – Path to the checkpoint file or directory.

  • shard_fns – Dictionary of functions to apply to specific shards.

  • verbose – Enable verbose output.

  • mismatch_allowed – Whether to allow missing shard functions.

  • callback – Optional callback to process arrays after loading.

  • dtype – Data type to cast arrays to.

  • gcs_client – Optional GCS client instance.

Returns

Tuple of (loaded PyTree or dict, metadata dict).
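For a sharded checkpoint, the loader needs to know which shard file holds each tensor so that each file is opened only once. The sketch below groups tensor names by shard file. It assumes a Hugging Face-style index (a JSON object whose `weight_map` maps each tensor name to a shard filename); the `weight_map` key, the shard filenames, and the helper `group_by_shard` are assumptions, not confirmed details of eformer's index format.

```python
import json
from collections import defaultdict


def group_by_shard(index_json: str) -> dict[str, list[str]]:
    """Return {shard_file: [tensor names]} from an index file's contents.

    Assumes a Hugging Face-style layout: {"metadata": {...}, "weight_map": {name: file}}.
    """
    index = json.loads(index_json)
    groups: dict[str, list[str]] = defaultdict(list)
    for tensor_name, shard_file in index["weight_map"].items():
        groups[shard_file].append(tensor_name)
    # Sort shards and names so files are visited in a deterministic order.
    return {shard: sorted(names) for shard, names in sorted(groups.items())}


# Hypothetical index contents; the shard filenames are illustrative only.
example_index = json.dumps({
    "metadata": {"total_size": 48},
    "weight_map": {
        "encoder.dense.kernel": "model-00001-of-00002.safetensors",
        "encoder.dense.bias": "model-00001-of-00002.safetensors",
        "head.kernel": "model-00002-of-00002.safetensors",
    },
})

for shard, names in group_by_shard(example_index).items():
    # A real loader would open `shard` once here and read only `names`.
    print(shard, names)
```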

load_state_from_gcs_msgpack(gcs_path: str, verbose: bool = False) -> dict

Load a tree from a msgpack checkpoint stored on GCS.

Parameters
  • gcs_path – GCS path to the msgpack checkpoint.

  • verbose – Enable verbose output.

Returns

Dictionary containing the loaded checkpoint.

Raises

ValueError – If the GCS client is not initialized.
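Both msgpack helpers take a `gcs_path` string. One way to prepare such a path for a GCS client is to split it into a bucket name and an object name, as sketched below. This assumes paths of the form `gs://bucket/object/key`; `split_gcs_path` is a hypothetical helper, not part of eformer.

```python
def split_gcs_path(gcs_path: str) -> tuple[str, str]:
    """Split 'gs://bucket/dir/file' into ('bucket', 'dir/file') (hypothetical helper)."""
    prefix = "gs://"
    if not gcs_path.startswith(prefix):
        raise ValueError(f"expected a path starting with {prefix!r}, got {gcs_path!r}")
    # Everything before the first '/' is the bucket; the rest is the object name.
    bucket, _, blob_name = gcs_path[len(prefix):].partition("/")
    if not bucket or not blob_name:
        raise ValueError(f"path must include a bucket and an object name: {gcs_path!r}")
    return bucket, blob_name


print(split_gcs_path("gs://my-bucket/runs/step_1000/state.msgpack"))
# ('my-bucket', 'runs/step_1000/state.msgpack')
```

With the `google-cloud-storage` library, the resulting pair would typically be passed to `client.bucket(bucket).blob(blob_name)` to obtain the object handle.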

classmethod save_checkpoint(tree: PyTree, path: eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath | str | os.PathLike, mesh: Mesh, gather_fns: dict[Callable] | bool | None = None, float_dtype: str | numpy.dtype | None = None, verbose: bool = True, mismatch_allowed: bool = True, metadata: dict[str, str] | None = None, enable: bool | None = None, shard_size_gb: float | None = 5.0, write_index_file: bool = True) -> eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath | str | os.PathLike

Save a checkpoint to a local path or GCS in SafeTensors format.

If shard_size_gb is set (the default is 5.0), the tree is saved as multiple shard files, each at most that size; the last shard is typically smaller. An index file, prefix.safetensors.index.json, is also written that maps every tensor name to the shard file containing it.

Parameters
  • tree – PyTree structure to save.

  • path – Path where the checkpoint will be saved.

  • mesh – JAX mesh for distributed computation.

  • gather_fns – Dictionary of gather functions or bool for device gathering.

  • float_dtype – Data type for floating point arrays.

  • verbose – Enable verbose output.

  • mismatch_allowed – Whether to allow missing gather functions.

  • metadata – Additional metadata to save with checkpoint.

  • enable – Whether checkpointing is enabled. If None, this is detected automatically so that only process 0 writes.

  • shard_size_gb – Maximum size of each shard in GB.

  • write_index_file – Whether to write the index file for sharded saves.

Returns

Path where the checkpoint was saved.
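The sharding rule described for save_checkpoint (fill a shard up to shard_size_gb, start a new one when the next tensor would overflow it, and record every tensor in an index) can be sketched as a greedy pass over tensor sizes. This is an illustration under assumptions: tensors are visited in the given order, one GB is treated as 1024**3 bytes, shard files follow a hypothetical `model-XXXXX-of-YYYYY.safetensors` naming scheme, and the index uses a Hugging Face-style `weight_map` layout. eformer's actual ordering, file names, and index keys may differ, and `plan_shards` is not part of its API.

```python
import json


def plan_shards(tensor_bytes: dict[str, int], shard_size_gb: float, prefix: str = "model") -> dict:
    """Greedily assign tensors to shards of at most `shard_size_gb` and build an index.

    Hypothetical helper. A tensor larger than the limit is placed alone in its own shard.
    """
    limit = int(shard_size_gb * 1024**3)  # Assumption: GB means 1024**3 bytes.
    shards: list[list[str]] = [[]]
    current = 0
    for name, size in tensor_bytes.items():
        # Start a new shard if this tensor would overflow a non-empty one.
        if shards[-1] and current + size > limit:
            shards.append([])
            current = 0
        shards[-1].append(name)
        current += size

    total = len(shards)
    weight_map = {}
    for i, names in enumerate(shards, start=1):
        filename = f"{prefix}-{i:05d}-of-{total:05d}.safetensors"
        for name in names:
            weight_map[name] = filename
    return {"metadata": {"total_size": sum(tensor_bytes.values())}, "weight_map": weight_map}


GiB = 1024**3
sizes = {"embed": 3 * GiB, "block.0": 2 * GiB, "block.1": 2 * GiB, "head": 1 * GiB}
index = plan_shards(sizes, shard_size_gb=5.0)
print(json.dumps(index["weight_map"], indent=2))
```

With a 5 GB limit, `embed` and `block.0` exactly fill the first shard, so `block.1` starts a second shard that `head` then joins.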

save_state_to_gcs_msgpack(tree: PyTree, gcs_path: str, gather_fns: dict[Callable] | None = None, float_dtype: str | numpy.dtype | None = None, verbose: bool = False, mismatch_allowed: bool = True)

Save a tree to GCS in msgpack format (streaming).

Parameters
  • tree – PyTree structure to save.

  • gcs_path – GCS path where the checkpoint will be saved.

  • gather_fns – Dictionary of gather functions.

  • float_dtype – Data type for floating point arrays.

  • verbose – Enable verbose output.

  • mismatch_allowed – Whether to allow missing gather functions.

Raises
  • ValueError – If the GCS client is not initialized.

  • KeyError – If a gather function is missing and mismatch_allowed is False.
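The gather_fns and mismatch_allowed parameters together behave like a keyed lookup with an optional fallback. A minimal sketch, assuming gather functions are keyed by the same flattened tensor names used in the checkpoint and that a missing function means "keep the value unchanged" when mismatches are allowed. `apply_gather_fns` and `concat_chunks` are hypothetical helpers, and nested Python lists stand in for sharded device arrays.

```python
from typing import Any, Callable, Optional


def _identity(value: Any) -> Any:
    """Fallback used when no gather function is registered for a tensor."""
    return value


def apply_gather_fns(
    flat_tree: dict[str, Any],
    gather_fns: Optional[dict[str, Callable[[Any], Any]]],
    mismatch_allowed: bool = True,
) -> dict[str, Any]:
    """Apply a per-tensor gather function to every leaf (hypothetical helper)."""
    if gather_fns is None:
        return dict(flat_tree)
    gathered: dict[str, Any] = {}
    for name, value in flat_tree.items():
        fn = gather_fns.get(name)
        if fn is None:
            if not mismatch_allowed:
                # Mirrors the documented KeyError for a missing gather function.
                raise KeyError(f"no gather function for {name!r}")
            fn = _identity
        gathered[name] = fn(value)
    return gathered


def concat_chunks(chunks: list) -> list:
    """Stand-in 'gather': merge per-device chunks into one flat list."""
    return [x for chunk in chunks for x in chunk]


tree = {"kernel": [[1, 2], [3, 4]], "step": 7}
print(apply_gather_fns(tree, {"kernel": concat_chunks}))
# {'kernel': [1, 2, 3, 4], 'step': 7}
```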