eformer.serialization.utils
- eformer.serialization.utils.apply_gather_functions(state: dict, gather_fns: dict | bool, mismatch_allowed: bool, verbose: bool) → dict
Apply gather functions to a state dictionary.
- Parameters
state – State dictionary
gather_fns – Gather functions or boolean flag
mismatch_allowed – Whether to allow mismatches
verbose – Enable verbose output
- Returns
Processed state dictionary
- eformer.serialization.utils.broadcast_tensor(tensor: Array, memory_limit_bytes: int, target_sharding: jax.sharding.NamedSharding | None = None) → Array
Efficiently broadcast a tensor from a single replica to all devices.
- Parameters
tensor – Input tensor
memory_limit_bytes – Memory limit for chunking
target_sharding – Optional target sharding
- Returns
Broadcasted tensor
- eformer.serialization.utils.chunk_tensor_by_memory(tensor: Array, memory_limit: int) → list[jax.jaxlib._jax.Array]
Split a tensor into chunks that each fit within the memory limit.
- Parameters
tensor – Input tensor
memory_limit – Memory limit in bytes
- Returns
List of tensor chunks
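The core of memory-bounded chunking is a little arithmetic on shape and dtype size. A minimal sketch, assuming the tensor is split along its leading axis and each chunk holds as many whole rows as fit in memory_limit bytes; the helper plan_row_chunks is hypothetical and is not eformer's implementation.

```python
import math


def plan_row_chunks(shape: tuple[int, ...], itemsize: int, memory_limit: int) -> list[tuple[int, int]]:
    """Return (start, stop) ranges along axis 0 whose slices fit in memory_limit bytes.

    Hypothetical helper: assumes chunking happens along the leading axis only.
    A single row larger than the limit still gets a chunk of its own.
    """
    if len(shape) == 0:
        raise ValueError("a scalar cannot be split along axis 0")
    rows = shape[0]
    # Bytes occupied by one slice tensor[i] along axis 0.
    row_bytes = math.prod(shape[1:]) * itemsize
    # Take at least one row per chunk; guard against zero-sized rows.
    rows_per_chunk = max(1, memory_limit // max(row_bytes, 1))
    return [(start, min(start + rows_per_chunk, rows)) for start in range(0, rows, rows_per_chunk)]
```

Each (start, stop) pair would then be materialized as tensor[start:stop]. For example, a float32 tensor of shape (10, 4) with a 40-byte limit has 16-byte rows, so it is planned as five chunks of two rows each.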
- eformer.serialization.utils.derive_base_prefix_from_path(path_str: str) → str
Normalize a path into its ‘base prefix’ used for sharded file naming.
Examples
/x/model.safetensors -> /x/model
/x/model.safetensors.index.json -> /x/model
/x/model-00001-of-00004.safetensors -> /x/model
- Parameters
path_str – Input path string
- Returns
Base prefix for the path
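The three documented cases reduce to stripping a known suffix and then an optional shard marker. A minimal sketch that reproduces exactly those cases, assuming the shard marker always uses the five-digit -XXXXX-of-YYYYY form; derive_base_prefix_sketch is a hypothetical name, not eformer's code.

```python
import re

# Check the longer suffix first so ".safetensors.index.json" is not cut short.
_SUFFIXES = (".safetensors.index.json", ".safetensors")
# Assumed shard marker: "-XXXXX-of-YYYYY" at the end of the remaining path.
_SHARD_RE = re.compile(r"-\d{5}-of-\d{5}$")


def derive_base_prefix_sketch(path_str: str) -> str:
    base = path_str
    for suffix in _SUFFIXES:
        if base.endswith(suffix):
            base = base[: -len(suffix)]
            break
    # Drop a trailing shard marker, if present.
    return _SHARD_RE.sub("", base)
```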
- eformer.serialization.utils.estimate_array_nbytes(array: Array) → int
Estimate the number of bytes occupied by a JAX array without calling device_get.
- Parameters
array – JAX array to estimate size for
- Returns
Estimated size in bytes
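An array's size in bytes depends only on its shape and dtype, so it can be computed from metadata alone without transferring any data off the devices. A minimal pure-Python sketch of that arithmetic, assuming this is the quantity being estimated; estimate_nbytes_sketch is a hypothetical name.

```python
import math


def estimate_nbytes_sketch(shape: tuple[int, ...], itemsize: int) -> int:
    # Element count times bytes per element; only metadata is read.
    # math.prod(()) == 1, so a scalar counts as a single element.
    return math.prod(shape) * itemsize
```

For a jax.Array a, the inputs would be a.shape and a.dtype.itemsize; a bfloat16 array of shape (1024, 1024) works out to 2,097,152 bytes (2 MiB).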
- eformer.serialization.utils.estimate_available_memory() → int
Dynamically estimate available memory for safe loading.
- Returns
Available memory in bytes
- eformer.serialization.utils.flatten_for_broadcast(state: dict) → dict
Flatten a state dictionary for efficient broadcasting.
- Parameters
state – State dictionary
- Returns
Flattened state dictionary
- eformer.serialization.utils.group_keys_by_shard_size(flat_state: dict[str, jax.jaxlib._jax.Array], max_shard_size_bytes: int) → list[list[str]]
Group keys into shards that each stay under max_shard_size_bytes.
- Parameters
flat_state – Flattened state dictionary
max_shard_size_bytes – Maximum size per shard in bytes
- Returns
List of key groups (shards)
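One straightforward way to respect a per-shard size bound is greedy packing in key order: keep adding keys to the current shard until the next one would overflow it, then start a new shard. This is a sketch of that technique, not necessarily what eformer does; it takes precomputed byte sizes (for example from estimate_array_nbytes) instead of arrays, and group_keys_sketch is a hypothetical name.

```python
def group_keys_sketch(sizes: dict[str, int], max_shard_size_bytes: int) -> list[list[str]]:
    """Greedily pack keys, in insertion order, into shards of at most max_shard_size_bytes.

    A key whose own size exceeds the limit ends up alone in its own shard.
    """
    shards: list[list[str]] = []
    current: list[str] = []
    current_bytes = 0
    for key, nbytes in sizes.items():
        # Close the current shard if this key would push it over the limit.
        if current and current_bytes + nbytes > max_shard_size_bytes:
            shards.append(current)
            current, current_bytes = [], 0
        current.append(key)
        current_bytes += nbytes
    if current:
        shards.append(current)
    return shards
```

With sizes {"a": 40, "b": 50, "c": 30, "d": 120, "e": 10} and a 100-byte limit, this yields [["a", "b"], ["c"], ["d"], ["e"]]. Greedy packing preserves key order, which keeps neighboring tensors together, at the cost of sometimes leaving shards underfilled.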
- eformer.serialization.utils.index_filename(base_prefix: str) → str
Generate the index filename for a sharded checkpoint.
Creates the filename for the JSON index file that maps tensor names to their shard files.
- Parameters
base_prefix – Base path/prefix for the checkpoint.
- Returns
Full path to the index JSON file.
Example
>>> index_filename("/checkpoints/model")
'/checkpoints/model.safetensors.index.json'
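The index file is what lets a loader find which shard holds a given tensor. A sketch of building one, assuming eformer follows the Hugging Face sharded-safetensors convention (a top-level "weight_map" from tensor name to shard file name, plus a "metadata" entry holding "total_size"); that layout and the helper name build_index_sketch are assumptions, not taken from eformer's source. Shard names follow the <prefix>-XXXXX-of-YYYYY.safetensors pattern documented for shard_filename.

```python
import json


def build_index_sketch(base_prefix: str, shards: list[list[str]], total_size: int) -> tuple[str, str]:
    """Return (index_path, json_text) for a sharded checkpoint.

    Assumed layout: {"metadata": {"total_size": ...}, "weight_map": {tensor: file}}.
    """
    total = len(shards)
    weight_map = {}
    for idx, keys in enumerate(shards, start=1):
        # Shard names use 1-indexed, zero-padded indices.
        shard_file = f"{base_prefix}-{idx:05d}-of-{total:05d}.safetensors"
        for key in keys:
            # Store the bare file name, relative to the index file's directory.
            weight_map[key] = shard_file.rsplit("/", 1)[-1]
    index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
    return f"{base_prefix}.safetensors.index.json", json.dumps(index, indent=2)
```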
- eformer.serialization.utils.is_gcs_path(path: str) → bool
Check if a path points to Google Cloud Storage.
Determines whether a path is a GCS path by checking for the gs:// prefix or by checking if it’s a GCSPath instance.
- Parameters
path – Path string or path object to check.
- Returns
True if the path is a GCS path, False otherwise.
Example
>>> is_gcs_path("gs://my-bucket/checkpoint")
True
>>> is_gcs_path("/local/path/checkpoint")
False
- eformer.serialization.utils.optimize_shard_layout(state: dict[str, jax.jaxlib._jax.Array], max_shard_size_bytes: int) → list[list[str]]
Optimize the shard layout for faster loading by grouping related tensors and taking access patterns into account.
- Parameters
state – State dictionary with arrays
max_shard_size_bytes – Maximum size per shard
- Returns
Optimized list of key groups
- eformer.serialization.utils.parse_gcs_path(gcs_path: str) → tuple[str, str]
Parse a gs://bucket/path string into its bucket name and blob name.
- Parameters
gcs_path – GCS path string
- Returns
Tuple of (bucket_name, blob_name)
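After the gs:// scheme, everything up to the first "/" is the bucket and the rest is the blob name. A minimal sketch under that assumption; raising on non-GCS input and returning an empty blob name for a bare bucket are choices made by this sketch, not documented eformer behavior, and parse_gcs_path_sketch is a hypothetical name.

```python
def parse_gcs_path_sketch(gcs_path: str) -> tuple[str, str]:
    prefix = "gs://"
    if not gcs_path.startswith(prefix):
        raise ValueError(f"not a GCS path: {gcs_path!r}")
    # Split once on the first "/" after the scheme: bucket, then blob name.
    bucket, _, blob = gcs_path[len(prefix):].partition("/")
    return bucket, blob
```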
- eformer.serialization.utils.read_process_array(key: str, shard_fns: dict | None, mismatch_allowed: bool, manager, callback: Optional[Callable[[Array, str], Array]] = None, dtype: str | numpy.dtype | None = None) → tuple[str, jax.jaxlib._jax.Array, int]
Helper function to process a single tensor from a checkpoint.
- Parameters
key – Tensor key
shard_fns – Shard functions dictionary
mismatch_allowed – Whether to allow shard function mismatches
manager – Checkpoint manager instance
callback – Optional callback for tensor processing
dtype – Target dtype for conversion
- Returns
Tuple of (key, processed_tensor, mismatch_count)
- eformer.serialization.utils.reshard(x, target_sharding)
Reshard an array to a new sharding specification.
Uses JIT compilation with sharding constraints to efficiently move an array to a new sharding layout across devices.
- Parameters
x – Input JAX array to reshard.
target_sharding – Target sharding specification (e.g., NamedSharding).
- Returns
The input array resharded according to target_sharding.
Example
>>> from jax.sharding import NamedSharding, PartitionSpec
>>> new_sharding = NamedSharding(mesh, PartitionSpec("data"))
>>> resharded = reshard(my_array, new_sharding)
- eformer.serialization.utils.shard_filename(base_prefix: str, idx: int, total: int) → str
Generate a standardized shard filename for sharded checkpoints.
Creates filenames following the pattern: <prefix>-XXXXX-of-YYYYY.safetensors
- Parameters
base_prefix – Base path/prefix for the shard files.
idx – Shard index (1-indexed).
total – Total number of shards.
- Returns
Full shard filename with zero-padded indices.
Example
>>> shard_filename("/checkpoints/model", 1, 4)
'/checkpoints/model-00001-of-00004.safetensors'
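The zero padding comes from Python's format specification: :05d pads an integer to five digits. A sketch reproducing the documented pattern; the range check on idx is an added assumption rather than documented behavior, and shard_filename_sketch is a hypothetical name.

```python
def shard_filename_sketch(base_prefix: str, idx: int, total: int) -> str:
    # Assumed validation: shard indices are 1-indexed and cannot exceed total.
    if not 1 <= idx <= total:
        raise ValueError(f"idx must be in [1, {total}], got {idx}")
    # :05d zero-pads both numbers to five digits.
    return f"{base_prefix}-{idx:05d}-of-{total:05d}.safetensors"
```

Listing every shard of a three-way split is then [shard_filename_sketch("/ckpt/model", i, 3) for i in range(1, 4)], which runs from /ckpt/model-00001-of-00003.safetensors to /ckpt/model-00003-of-00003.safetensors.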
- eformer.serialization.utils.to_host(x: Array, float_dtype: numpy.floating | None, mesh: Mesh, cpu_offload: bool)
Move array to host with optional dtype conversion and CPU offloading.
Reshards the array to a fully replicated sharding on the provided mesh, optionally offloads to CPU, and converts floating point arrays to the specified dtype.
- Parameters
x – JAX array to move to host.
float_dtype – Target dtype for floating-point arrays: either a string (e.g., “float32”) or a jnp dtype. If None, the dtype is left unchanged.
mesh – JAX mesh to use for resharding to replicated layout.
cpu_offload – If True, additionally reshard onto a CPU mesh to free accelerator memory.
- Returns
Array moved to host, potentially with converted dtype.
Note
CPU offloading helps prevent out-of-memory (OOM) errors during checkpointing by moving data off the accelerators before serialization.