eformer.serialization.utils

eformer.serialization.utils.apply_gather_functions(state: dict, gather_fns: dict | bool, mismatch_allowed: bool, verbose: bool) -> dict

Apply gather functions to state.

Parameters
  • state – State dictionary

  • gather_fns – Gather functions or boolean flag

  • mismatch_allowed – Whether to allow mismatches

  • verbose – Enable verbose output

Returns

Processed state dictionary

eformer.serialization.utils.broadcast_tensor(tensor: Array, memory_limit_bytes: int, target_sharding: jax.sharding.NamedSharding | None = None) -> Array

Efficiently broadcast tensor from single replica to all devices.

Parameters
  • tensor – Input tensor

  • memory_limit_bytes – Memory limit, in bytes, used for chunking

  • target_sharding – Optional target sharding

Returns

Broadcasted tensor

eformer.serialization.utils.chunk_tensor_by_memory(tensor: Array, memory_limit: int) -> list[jax.jaxlib._jax.Array]

Split tensor into chunks that fit within memory limit.

Parameters
  • tensor – Input tensor

  • memory_limit – Memory limit in bytes

Returns

List of tensor chunks
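
The chunking idea can be illustrated without JAX. The sketch below is not the library's implementation: it assumes chunks are cut along the leading axis, and `plan_row_chunks` is a hypothetical helper that works from a shape and an item size instead of a real array.

```python
import math


def plan_row_chunks(shape, itemsize, memory_limit):
    """Return (start, stop) ranges along axis 0 whose slices fit in memory_limit bytes.

    Hypothetical helper. Assumption: a single row that exceeds the limit
    still gets a chunk of its own, so at least one row is always emitted.
    """
    rows = shape[0]
    bytes_per_row = math.prod(shape[1:]) * itemsize
    rows_per_chunk = max(1, memory_limit // bytes_per_row)
    return [
        (start, min(start + rows_per_chunk, rows))
        for start in range(0, rows, rows_per_chunk)
    ]


# A float32 array of shape (10, 256) uses 1024 bytes per row,
# so a 4096-byte limit allows four rows per chunk.
print(plan_row_chunks((10, 256), 4, 4096))  # [(0, 4), (4, 8), (8, 10)]
```

With a real array, each `(start, stop)` pair would correspond to a slice such as `tensor[start:stop]`.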

eformer.serialization.utils.derive_base_prefix_from_path(path_str: str) -> str

Normalize a path into its ‘base prefix’ used for sharded file naming.

Examples

/x/model.safetensors -> /x/model
/x/model.safetensors.index.json -> /x/model
/x/model-00001-of-00004.safetensors -> /x/model

Parameters

path_str – Input path string

Returns

Base prefix for the path
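
The three documented mappings can be reproduced with plain string handling. This is a sketch only, not the library's code: `sketch_base_prefix` is a hypothetical name, and returning other paths unchanged is an assumption.

```python
import re

INDEX_SUFFIX = ".safetensors.index.json"
PLAIN_SUFFIX = ".safetensors"
# Matches a trailing "-00001-of-00004.safetensors" style shard suffix.
SHARD_SUFFIX = re.compile(r"-\d{5}-of-\d{5}\.safetensors$")


def sketch_base_prefix(path_str):
    """Strip the index, shard, or plain safetensors suffix from a path."""
    if path_str.endswith(INDEX_SUFFIX):
        return path_str[: -len(INDEX_SUFFIX)]
    stripped = SHARD_SUFFIX.sub("", path_str)
    if stripped != path_str:
        return stripped
    if path_str.endswith(PLAIN_SUFFIX):
        return path_str[: -len(PLAIN_SUFFIX)]
    # Assumption: anything else is already a base prefix.
    return path_str


for p in (
    "/x/model.safetensors",
    "/x/model.safetensors.index.json",
    "/x/model-00001-of-00004.safetensors",
):
    print(p, "->", sketch_base_prefix(p))
```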

eformer.serialization.utils.estimate_array_nbytes(array: Array) -> int

Estimate number of bytes for a JAX array without device_get.

Parameters

array – JAX array to estimate size for

Returns

Estimated size in bytes
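
A size estimate that avoids transferring data only needs metadata. The sketch below assumes the estimate is element count times item size; `sketch_nbytes` is a hypothetical helper that takes the shape and item size directly, which for a JAX array would be `array.shape` and `array.dtype.itemsize`.

```python
import math


def sketch_nbytes(shape, itemsize):
    """Estimate storage as (number of elements) * (bytes per element)."""
    return math.prod(shape) * itemsize


# A (1024, 1024) bfloat16 array: 2 bytes per element.
print(sketch_nbytes((1024, 1024), 2))  # 2097152
```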

eformer.serialization.utils.estimate_available_memory() -> int

Dynamically estimate available memory for safe loading.

Returns

Available memory in bytes

eformer.serialization.utils.flatten_for_broadcast(state: dict) -> dict

Flatten state for efficient broadcasting.

Parameters

state – State dictionary

Returns

Flattened state dictionary

eformer.serialization.utils.group_keys_by_shard_size(flat_state: dict[str, jax.jaxlib._jax.Array], max_shard_size_bytes: int) -> list[list[str]]

Group keys into shards under max_shard_size_bytes each.

Parameters
  • flat_state – Flattened state dictionary

  • max_shard_size_bytes – Maximum size per shard in bytes

Returns

List of key groups (shards)
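
One plausible way to form such groups is a greedy pass over the keys. The sketch below is an illustration under that assumption, not the library's algorithm: `sketch_group_keys` is a hypothetical helper, and it takes a mapping from key to byte size instead of real arrays.

```python
def sketch_group_keys(sizes, max_shard_size_bytes):
    """Greedily pack keys, in insertion order, into shards under the byte limit.

    Assumption: a key larger than the limit is placed alone in its own shard.
    """
    shards, current, current_bytes = [], [], 0
    for key, nbytes in sizes.items():
        # Start a new shard when adding this key would exceed the limit.
        if current and current_bytes + nbytes > max_shard_size_bytes:
            shards.append(current)
            current, current_bytes = [], 0
        current.append(key)
        current_bytes += nbytes
    if current:
        shards.append(current)
    return shards


sizes = {"embed": 600, "layer0.w": 300, "layer0.b": 50, "layer1.w": 300, "head": 700}
print(sketch_group_keys(sizes, 1000))
# [['embed', 'layer0.w', 'layer0.b'], ['layer1.w', 'head']]
```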

eformer.serialization.utils.index_filename(base_prefix: str) -> str

Generate the index filename for a sharded checkpoint.

Creates the filename for the JSON index file that maps tensor names to their shard files.

Parameters

base_prefix – Base path/prefix for the checkpoint.

Returns

Full path to the index JSON file.

Example

>>> index_filename("/checkpoints/model")
'/checkpoints/model.safetensors.index.json'

eformer.serialization.utils.is_gcs_path(path: str) -> bool

Check if a path points to Google Cloud Storage.

Determines whether a path is a GCS path by checking for the gs:// prefix or by checking if it’s a GCSPath instance.

Parameters

path – Path string or path object to check.

Returns

True if the path is a GCS path, False otherwise.

Example

>>> is_gcs_path("gs://my-bucket/checkpoint")
True
>>> is_gcs_path("/local/path/checkpoint")
False

eformer.serialization.utils.optimize_shard_layout(state: dict[str, jax.jaxlib._jax.Array], max_shard_size_bytes: int) -> list[list[str]]

Optimize shard layout for better loading performance. Groups related tensors and considers access patterns.

Parameters
  • state – State dictionary with arrays

  • max_shard_size_bytes – Maximum size per shard

Returns

Optimized list of key groups

eformer.serialization.utils.parse_gcs_path(gcs_path: str) -> tuple[str, str]

Parse gs://bucket/path into bucket and blob name.

Parameters

gcs_path – GCS path string

Returns

Tuple of (bucket_name, blob_name)
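
The split described above can be sketched in a few lines. This is not the library's code: `sketch_parse_gcs_path` is a hypothetical name, and raising `ValueError` for a path without the gs:// prefix is an assumption about error handling.

```python
def sketch_parse_gcs_path(gcs_path):
    """Split 'gs://bucket/path/to/blob' into ('bucket', 'path/to/blob')."""
    prefix = "gs://"
    if not gcs_path.startswith(prefix):
        # Assumption: non-GCS paths are rejected rather than passed through.
        raise ValueError(f"not a GCS path: {gcs_path!r}")
    bucket, _, blob = gcs_path[len(prefix):].partition("/")
    return bucket, blob


print(sketch_parse_gcs_path("gs://my-bucket/checkpoints/model.safetensors"))
# ('my-bucket', 'checkpoints/model.safetensors')
```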

eformer.serialization.utils.read_process_array(key: str, shard_fns: dict | None, mismatch_allowed: bool, manager, callback: Optional[Callable[[Array, str], Array]] = None, dtype: str | numpy.dtype | None = None) -> tuple[str, jax.jaxlib._jax.Array, int]

Helper function to process a single tensor from a checkpoint.

Parameters
  • key – Tensor key

  • shard_fns – Shard functions dictionary

  • mismatch_allowed – Whether to allow shard function mismatches

  • manager – Checkpoint manager instance

  • callback – Optional callback for tensor processing

  • dtype – Target dtype for conversion

Returns

Tuple of (key, processed_tensor, mismatch_count)

eformer.serialization.utils.reshard(x, target_sharding)

Reshard an array to a new sharding specification.

Uses JIT compilation with sharding constraints to efficiently move an array to a new sharding layout across devices.

Parameters
  • x – Input JAX array to reshard.

  • target_sharding – Target sharding specification (e.g., NamedSharding).

Returns

The input array resharded according to target_sharding.

Example

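>>> import jax, jax.numpy as jnp, numpy as np
>>> from jax.sharding import Mesh, NamedSharding, PartitionSpec
>>> mesh = Mesh(np.array(jax.devices()), ("data",))  # 1-D mesh over all devices
>>> my_array = jnp.zeros((2 * jax.device_count(), 4))  # axis 0 divisible by device count
>>> new_sharding = NamedSharding(mesh, PartitionSpec("data"))
>>> resharded = reshard(my_array, new_sharding)

eformer.serialization.utils.shard_filename(base_prefix: str, idx: int, total: int) -> str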

Generate a standardized shard filename for sharded checkpoints.

Creates filenames following the pattern: <prefix>-XXXXX-of-YYYYY.safetensors

Parameters
  • base_prefix – Base path/prefix for the shard files.

  • idx – Shard index (1-indexed).

  • total – Total number of shards.

Returns

Full shard filename with zero-padded indices.

Example

>>> shard_filename("/checkpoints/model", 1, 4)
'/checkpoints/model-00001-of-00004.safetensors'

eformer.serialization.utils.to_host(x: Array, float_dtype: numpy.floating | None, mesh: Mesh, cpu_offload: bool)

Move array to host with optional dtype conversion and CPU offloading.

Reshards the array to a fully replicated sharding on the provided mesh, optionally offloads to CPU, and converts floating point arrays to the specified dtype.

Parameters
  • x – JAX array to move to host.

  • float_dtype – Target dtype for floating point arrays. Can be a string (e.g., “float32”) or jnp dtype. If None, dtype is unchanged.

  • mesh – JAX mesh to use for resharding to replicated layout.

  • cpu_offload – If True, additionally reshard to CPU mesh to free accelerator memory.

Returns

Array moved to host, potentially with converted dtype.

Note

CPU offloading is useful for preventing OOM during checkpointing by moving data off accelerators before serialization.