# Copyright 2026 The EasyDeL/eFormer Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import threading
from contextlib import contextmanager
from typing import Any
import jax
import msgpack
import numpy as np
# Maps a registered type to its ``(to_state_dict, from_state_dict)`` handler
# pair; populated through ``register_serialization_state``.
_STATE_DICT_REGISTRY: dict[Any, Any] = {}
class _ErrorContext(threading.local):
"""Thread-local context for tracking the path during deserialization.
This class maintains a stack of path components that are pushed/popped
as the deserialization traverses the nested structure. This enables
meaningful error messages that indicate where in the structure a
problem occurred.
Attributes:
path: List of path components representing the current location
in the nested structure being deserialized.
"""
def __init__(self):
"""Initialize the error context with an empty path."""
self.path = []
# Module-wide thread-local path stack shared by the deserialization helpers.
_error_context = _ErrorContext()
@contextmanager
def _record_path(name):
    """Track ``name`` as the current path component for the enclosed block.

    The component is pushed onto the thread-local error-context stack on
    entry and popped again on exit, including when the enclosed block raises.

    Args:
        name: The path component to record (e.g. a field name or key).

    Yields:
        None. The path stack is managed internally.

    Examples:
        >>> with _record_path("layer1"):
        ...     with _record_path("weights"):
        ...         # current_path() would return "layer1/weights"
        ...         pass
    """
    # Push before entering the ``try``: if the push itself failed, a pop in
    # ``finally`` would remove an entry that belongs to an outer caller.
    _error_context.path.append(name)
    try:
        yield
    finally:
        _error_context.path.pop()
def current_path():
    """Return the state-dict path currently being deserialized.

    Returns:
        str: The recorded path components joined with ``/``; an empty string
        when no component is recorded. Intended for error messages.

    Examples:
        >>> with _record_path("layer1"):
        ...     with _record_path("weights"):
        ...         current_path()
        'layer1/weights'
    """
    components = _error_context.path
    return "/".join(components)
class _NamedTuple:
"""Sentinel type marker for namedtuple serialization registration.
This class is used as a key in the state dict registry to handle
namedtuple types uniformly, since namedtuples are created dynamically
by the namedtuple factory and don't share a common base class.
"""
pass
def _is_namedtuple(x):
"""Check if an object is a namedtuple instance using duck typing.
Namedtuples are identified by being tuples that have a '_fields'
attribute containing the field names.
Args:
x: The object to check.
Returns:
bool: True if x appears to be a namedtuple instance, False otherwise.
Examples:
>>> from collections import namedtuple
>>> Point = namedtuple('Point', ['x', 'y'])
>>> _is_namedtuple(Point(1, 2))
True
>>> _is_namedtuple((1, 2))
False
"""
return isinstance(x, tuple) and hasattr(x, "_fields")
def from_state_dict(target, state: dict[str, Any], name: str = "."):
    """Rebuild ``target`` from a state dict.

    ``target`` acts as a template: its type selects the registered restore
    handler, which receives both ``target`` and ``state``. Namedtuple
    instances are dispatched through the shared ``_NamedTuple`` registry key.
    When no handler is registered for the type, ``state`` is returned
    unchanged.

    Args:
        target: The object whose state should be restored.
        state: A dictionary produced by ``to_state_dict`` holding the new
            state for ``target``.
        name: Name of the branch being entered; recorded on the path stack to
            improve deserialization error messages.

    Returns:
        The value produced by the registered restore handler, or ``state``
        itself for unregistered types.
    """
    ty = _NamedTuple if _is_namedtuple(target) else type(target)
    if ty not in _STATE_DICT_REGISTRY:
        # Unregistered types are leaves: the stored value is the result.
        return state
    restore = _STATE_DICT_REGISTRY[ty][1]
    with _record_path(name):
        return restore(target, state)
def to_state_dict(target) -> dict[str, Any]:
    """Convert an object into its state-dict representation.

    The registered handler for ``target``'s type performs the conversion.
    Namedtuple instances are dispatched through the shared ``_NamedTuple``
    registry key.

    Args:
        target: The object to convert.

    Returns:
        The handler's output for registered types; ``target`` itself when its
        type has no registered handler.

    Raises:
        TypeError: If the handler returns a dict containing a non-string key.

    Examples:
        >>> to_state_dict([10, 20])
        {'0': 10, '1': 20}
    """
    ty = _NamedTuple if _is_namedtuple(target) else type(target)
    if ty not in _STATE_DICT_REGISTRY:
        return target
    serialize = _STATE_DICT_REGISTRY[ty][0]
    result = serialize(target)
    if isinstance(result, dict):
        if any(not isinstance(key, str) for key in result.keys()):
            raise TypeError("A state dict must only have string keys.")
    return result
def is_serializable(target):
    """Check whether a target has a registered serialization handler.

    The lookup mirrors the dispatch used by ``to_state_dict`` and
    ``from_state_dict``: namedtuples are resolved through the shared
    ``_NamedTuple`` registry key rather than through their own dynamically
    created class.

    Args:
        target: Object or type to check.

    Returns:
        bool: True if the (resolved) type is present in the registry.

    Examples:
        >>> is_serializable({"a": 1})
        True
        >>> is_serializable(42)
        False
    """
    if isinstance(target, type):
        ty = target
        # A namedtuple *class* was passed: resolve it the same way its
        # instances are resolved during (de)serialization.
        if issubclass(ty, tuple) and hasattr(ty, "_fields"):
            ty = _NamedTuple
    elif _is_namedtuple(target):
        ty = _NamedTuple
    else:
        ty = type(target)
    return ty in _STATE_DICT_REGISTRY
def register_serialization_state(ty, ty_to_state_dict, ty_from_state_dict, override=False):
    """Register the state-dict handlers for a type.

    Args:
        ty: The type to register.
        ty_to_state_dict: Callable taking an instance of ``ty`` and returning
            its state as a dictionary.
        ty_from_state_dict: Callable taking an instance of ``ty`` and a state
            dict, returning a copy of the instance with the restored state.
        override: Replace a previously registered handler instead of raising
            (default: False).

    Raises:
        ValueError: If ``ty`` is already registered and ``override`` is false.
    """
    already_registered = ty in _STATE_DICT_REGISTRY
    if already_registered and not override:
        raise ValueError(f'a serialization handler for "{ty.__name__}" is already registered')
    handlers = (ty_to_state_dict, ty_from_state_dict)
    _STATE_DICT_REGISTRY[ty] = handlers
def _list_state_dict(xs: list[Any]) -> dict[str, Any]:
    """Turn a sequence into a state dict keyed by stringified position.

    Args:
        xs: The sequence to convert.

    Returns:
        dict[str, Any]: Mapping from ``"0"``, ``"1"``, ... to the state-dict
        form of the element at that position.
    """
    result = {}
    for index, item in enumerate(xs):
        result[str(index)] = to_state_dict(item)
    return result
def _restore_list(xs, state_dict: dict[str, Any]) -> list[Any]:
    """Rebuild a list from its position-keyed state dict.

    Args:
        xs: The target list; its elements serve as templates for restoring
            the corresponding entries.
        state_dict: Mapping from stringified positions to serialized values.

    Returns:
        list[Any]: The restored elements in positional order.

    Raises:
        ValueError: If ``state_dict`` and ``xs`` differ in length.
    """
    if len(xs) != len(state_dict):
        raise ValueError(
            "The size of the list and the state dict do not match,"
            f" got {len(xs)} and {len(state_dict)} "
            f"at path {current_path()}"
        )
    return [
        from_state_dict(xs[position], state_dict[str(position)], name=str(position))
        for position in range(len(state_dict))
    ]
def _dict_state_dict(xs: dict[str, Any]) -> dict[str, Any]:
    """Turn a dictionary into a state dict with stringified keys.

    Args:
        xs: The dictionary to convert.

    Returns:
        dict[str, Any]: Mapping from ``str(key)`` to the state-dict form of
        each value.

    Raises:
        ValueError: If two keys share the same string representation.
    """
    str_keys = set(map(str, xs.keys()))
    # Fewer distinct strings than entries means at least two keys collide.
    if len(str_keys) != len(xs):
        raise ValueError(f"Dict keys do not have a unique string representation: {str_keys} vs given: {xs}")
    converted = {}
    for key, value in xs.items():
        converted[str(key)] = to_state_dict(value)
    return converted
def _restore_dict(xs, states: dict[str, Any]) -> dict[str, Any]:
    """Rebuild a dictionary from its state dict.

    The result keeps the original keys of ``xs``; the matching state entry
    is looked up under ``str(key)``.

    Args:
        xs: The target dictionary; its values serve as restore templates.
        states: Dictionary holding the serialized values.

    Returns:
        dict[str, Any]: Dictionary with the keys of ``xs`` and restored
        values.

    Raises:
        ValueError: If ``xs`` has keys that are absent from ``states``.
    """
    diff = {str(key) for key in xs.keys()}.difference(states.keys())
    if diff:
        raise ValueError(
            "The target dict keys and state dict keys do not match, target dict"
            f" contains keys {diff} which are not present in state dict at path"
            f" {current_path()}"
        )
    restored = {}
    for key, value in xs.items():
        restored[key] = from_state_dict(value, states[str(key)], name=str(key))
    return restored
def _namedtuple_state_dict(nt) -> dict[str, Any]:
    """Turn a namedtuple into a state dict keyed by field name.

    Args:
        nt: The namedtuple instance to convert.

    Returns:
        dict[str, Any]: Mapping from each name in ``nt._fields`` to the
        state-dict form of that field's value.
    """
    result = {}
    for field in nt._fields:
        result[field] = to_state_dict(getattr(nt, field))
    return result
def _restore_namedtuple(xs, state_dict: dict[str, Any]):
    """Rebuild a namedtuple from its state dict.

    Accepts the field-name keyed format as well as the legacy layout whose
    keys are exactly ``name``, ``fields`` and ``values``; the legacy layout
    is first translated into the field-name keyed format.

    Args:
        xs: The target namedtuple instance; supplies the type and the
            per-field restore templates.
        state_dict: Dictionary holding the serialized field values.

    Returns:
        A new instance of ``type(xs)`` built from the restored fields.

    Raises:
        ValueError: If the field names in ``state_dict`` differ from those of
            the namedtuple.
    """
    if set(state_dict.keys()) == {"name", "fields", "values"}:
        # Legacy layout: parallel position-keyed dicts of names and values.
        legacy_fields = state_dict["fields"]
        legacy_values = state_dict["values"]
        state_dict = {legacy_fields[str(i)]: legacy_values[str(i)] for i in range(len(legacy_fields))}
    sd_keys = set(state_dict.keys())
    nt_keys = set(xs._fields)
    if sd_keys != nt_keys:
        raise ValueError(
            "The field names of the state dict and the named tuple do not match,"
            f" got {sd_keys} and {nt_keys} at path {current_path()}"
        )
    fields = {}
    for field, value in state_dict.items():
        fields[field] = from_state_dict(getattr(xs, field), value, name=field)
    return type(xs)(**fields)
# Built-in containers: dicts and lists get their dedicated handler pairs.
register_serialization_state(dict, _dict_state_dict, _restore_dict)
register_serialization_state(list, _list_state_dict, _restore_list)
# Tuples reuse the list handlers; the restored list is converted back to a
# tuple.
register_serialization_state(
    tuple,
    _list_state_dict,
    lambda xs, state_dict: tuple(_restore_list(list(xs), state_dict)),
)
# All namedtuples share the ``_NamedTuple`` marker as their registry key.
register_serialization_state(_NamedTuple, _namedtuple_state_dict, _restore_namedtuple)
# ``jax.tree_util.Partial``: only ``args`` and ``keywords`` are stored; the
# wrapped function is taken from the template object when restoring.
register_serialization_state(
    jax.tree_util.Partial,
    lambda x: (
        {
            "args": to_state_dict(x.args),
            "keywords": to_state_dict(x.keywords),
        }
    ),
    lambda x, sd: jax.tree_util.Partial(
        x.func,
        *from_state_dict(x.args, sd["args"]),
        **from_state_dict(x.keywords, sd["keywords"]),
    ),
)
def _ndarray_to_bytes(arr) -> bytes:
    """Pack a numpy or JAX array into msgpack bytes.

    The payload is the triple ``(shape, dtype_name, raw_bytes)`` where the
    raw bytes are the array contents in C order. JAX arrays are converted to
    numpy arrays first.

    Args:
        arr: A numpy ndarray or JAX Array to serialize.

    Returns:
        bytes: Msgpack-encoded shape, dtype name and array data.

    Raises:
        ValueError: If the dtype contains Python objects or is an aligned
            struct dtype.
    """
    if isinstance(arr, jax.Array):
        arr = np.array(arr)
    dtype = arr.dtype
    if dtype.hasobject or dtype.isalignedstruct:
        raise ValueError("Object and structured dtypes not supported for serialization of ndarrays.")
    payload = (arr.shape, dtype.name, arr.tobytes("C"))
    return msgpack.packb(payload, use_bin_type=True)
def _dtype_from_name(name: str):
"""Convert a dtype name string to a numpy/JAX dtype object.
Handles special cases like JAX's bfloat16 which is not a standard
numpy dtype.
Args:
name: The dtype name as a string (may be bytes for msgpack compatibility).
Returns:
A numpy dtype or JAX dtype object corresponding to the name.
"""
if name == b"bfloat16":
return jax.numpy.bfloat16
else:
return np.dtype(name)
def _ndarray_from_bytes(data: bytes) -> np.ndarray:
    """Rebuild a numpy array from msgpack bytes.

    Expects the ``(shape, dtype_name, raw_bytes)`` triple written by
    ``_ndarray_to_bytes``.

    Args:
        data: Msgpack-encoded shape, dtype name and array data.

    Returns:
        np.ndarray: Array created over the decoded buffer and reshaped in C
        order to the stored shape.
    """
    shape, dtype_name, buffer = msgpack.unpackb(data, raw=True)
    dtype = _dtype_from_name(dtype_name)
    flat = np.frombuffer(buffer, dtype=dtype, count=-1, offset=0)
    return flat.reshape(shape, order="C")
class _MsgpackExtType(enum.IntEnum):
"""MessagePack extension type identifiers for custom serialization.
These integer codes are used to identify custom types when encoding
and decoding MessagePack data, allowing proper handling of numpy arrays,
complex numbers, and numpy scalars.
Attributes:
ndarray: Extension type code for numpy/JAX arrays.
native_complex: Extension type code for Python complex numbers.
npscalar: Extension type code for numpy scalar values.
"""
ndarray = 1
native_complex = 2
npscalar = 3
def _msgpack_ext_pack(x):
    """Encode custom value types as MessagePack extension objects.

    Arrays (numpy or JAX), numpy scalars and Python complex numbers are
    wrapped in ``msgpack.ExtType``; any other value is returned unchanged.

    Args:
        x: The object to encode.

    Returns:
        ``msgpack.ExtType`` for supported custom types, otherwise ``x``.

    Raises:
        ValueError: If ``x`` is a JAX array that is not fully addressable.

    Note:
        Arrays must be fully addressable to be serialized. Gather sharded
        arrays to a single device before serializing them.
    """
    if isinstance(x, np.ndarray | jax.Array):
        not_addressable = (
            isinstance(x, jax.Array) and hasattr(x, "is_fully_addressable") and not x.is_fully_addressable
        )
        if not_addressable:
            raise ValueError(
                "Cannot serialize non-fully-addressable array. "
                "Consider gathering the array to a single device first."
            )
        return msgpack.ExtType(_MsgpackExtType.ndarray, _ndarray_to_bytes(x))
    # numpy scalars are checked before ``complex`` and stored as 0-d arrays.
    if isinstance(x, np.generic):
        return msgpack.ExtType(_MsgpackExtType.npscalar, _ndarray_to_bytes(np.asarray(x)))
    if isinstance(x, complex):
        return msgpack.ExtType(_MsgpackExtType.native_complex, msgpack.packb((x.real, x.imag)))
    return x
def _msgpack_ext_unpack(code, data):
    """Decode MessagePack extension objects back into Python values.

    Args:
        code: MessagePack extension type code.
        data: Encoded payload bytes.

    Returns:
        A numpy array, a Python complex number or a numpy scalar for the
        known codes; for any other code the payload is handed back wrapped in
        ``msgpack.ExtType(code, data)``.
    """
    if code == _MsgpackExtType.ndarray:
        return _ndarray_from_bytes(data)
    if code == _MsgpackExtType.native_complex:
        real, imag = msgpack.unpackb(data)[:2]
        return complex(real, imag)
    if code == _MsgpackExtType.npscalar:
        # Scalars are stored as 0-d arrays; index with ``()`` to unwrap.
        return _ndarray_from_bytes(data)[()]
    return msgpack.ExtType(code, data)
# Size threshold in bytes (2**30): arrays whose raw data exceeds this are
# split into chunks before being packed.
MAX_CHUNK_SIZE = 2**30
def _np_convert_in_place(d):
"""Convert JAX arrays to numpy arrays in a nested structure, in place.
Recursively traverses dictionaries and converts any JAX Array leaves
to numpy arrays. This is necessary because msgpack cannot directly
serialize JAX arrays.
Args:
d: A dictionary or JAX array to convert.
Returns:
The input with JAX arrays converted to numpy arrays. For dictionaries,
the conversion is done in place; for direct JAX arrays, a new numpy
array is returned.
"""
if isinstance(d, dict):
for k, v in d.items():
if isinstance(v, jax.Array):
d[k] = np.array(v)
elif isinstance(v, dict):
_np_convert_in_place(v)
elif isinstance(d, jax.Array):
return np.array(d)
return d
def _tuple_to_dict(tpl):
"""Convert a tuple to a dictionary with string index keys.
Args:
tpl: A tuple to convert.
Returns:
dict: Dictionary mapping string indices to tuple values.
"""
return {str(x): y for x, y in enumerate(tpl)}
def _dict_to_tuple(dct):
"""Convert a string-index dictionary back to a tuple.
Args:
dct: Dictionary with string indices as keys ("0", "1", ...).
Returns:
tuple: Tuple of values in index order.
"""
return tuple(dct[str(i)] for i in range(len(dct)))
def _chunk(arr) -> dict[str, Any]:
    """Split an array into flat chunks for serialization.

    The array is flattened and cut into slices of at most
    ``MAX_CHUNK_SIZE / itemsize`` elements (never fewer than one element per
    chunk).

    Args:
        arr: A numpy array to chunk.

    Returns:
        dict[str, Any]: A dictionary containing:
            - "__msgpack_chunked_array__": True (marker)
            - "shape": Position-keyed dictionary of the original shape
            - "chunks": Position-keyed dictionary of the flat chunks
    """
    elements_per_chunk = max(1, int(MAX_CHUNK_SIZE / arr.dtype.itemsize))
    flat = arr.reshape(-1)
    pieces = []
    for start in range(0, flat.size, elements_per_chunk):
        pieces.append(flat[start : start + elements_per_chunk])
    return {
        "__msgpack_chunked_array__": True,
        "shape": _tuple_to_dict(arr.shape),
        "chunks": _tuple_to_dict(pieces),
    }
def _unchunk(data: dict[str, Any]):
    """Reassemble an array from its chunked dictionary form.

    Args:
        data: Dictionary with the keys:
            - "__msgpack_chunked_array__": marker
            - "shape": Position-keyed dictionary of the original shape
            - "chunks": Position-keyed dictionary of the flat chunks

    Returns:
        np.ndarray: The concatenated chunks reshaped to the stored shape.

    Raises:
        ValueError: If the chunked-array marker key is missing.
    """
    if "__msgpack_chunked_array__" not in data:
        raise ValueError("Expected chunked array marker '__msgpack_chunked_array__'.")
    target_shape = _dict_to_tuple(data["shape"])
    pieces = _dict_to_tuple(data["chunks"])
    return np.concatenate(pieces).reshape(target_shape)
def _chunk_array_leaves_in_place(d):
    """Replace oversized numpy arrays with their chunked representation.

    Only dictionaries are traversed: numpy arrays stored as dict values (at
    any dict nesting depth) whose raw size exceeds ``MAX_CHUNK_SIZE`` bytes
    are replaced in place. A bare oversized array is returned in chunked
    form.

    Args:
        d: A dictionary, a numpy array, or any other value.

    Returns:
        The same dictionary object after in-place replacement, the chunked
        form of a bare oversized array, or ``d`` unchanged otherwise.
    """
    if isinstance(d, np.ndarray):
        oversized = d.size * d.dtype.itemsize > MAX_CHUNK_SIZE
        return _chunk(d) if oversized else d
    if isinstance(d, dict):
        for key, value in d.items():
            if isinstance(value, dict):
                _chunk_array_leaves_in_place(value)
            elif isinstance(value, np.ndarray) and value.size * value.dtype.itemsize > MAX_CHUNK_SIZE:
                d[key] = _chunk(value)
    return d
def _unchunk_array_leaves_in_place(d):
    """Replace chunked-array dictionaries with reassembled numpy arrays.

    A dictionary carrying the "__msgpack_chunked_array__" marker is treated
    as a chunked array. Only dictionaries are traversed; marked dictionaries
    found as dict values are replaced in place.

    Args:
        d: A dictionary that may contain chunked-array representations, or
            any other value.

    Returns:
        A reassembled array when ``d`` itself is a chunked array, the same
        dictionary object after in-place replacement, or ``d`` unchanged for
        non-dictionary input.
    """
    if not isinstance(d, dict):
        return d
    if "__msgpack_chunked_array__" in d:
        return _unchunk(d)
    for key, value in d.items():
        if not isinstance(value, dict):
            continue
        if "__msgpack_chunked_array__" in value:
            d[key] = _unchunk(value)
        else:
            _unchunk_array_leaves_in_place(value)
    return d
def validate_serializable(pytree) -> tuple[bool, list[str]]:
    """Check a pytree for leaves that cannot be serialized.

    Every leaf visited by ``tree_map_with_path`` is inspected; JAX arrays
    that report ``is_fully_addressable`` as false are recorded as issues.

    Args:
        pytree: PyTree to validate.

    Returns:
        Tuple of ``(is_valid, list_of_issues)``; ``is_valid`` is True exactly
        when no issue was recorded.
    """
    from ._tree_util import tree_map_with_path

    issues = []

    def check_leaf(path, x):
        non_addressable = (
            isinstance(x, jax.Array) and hasattr(x, "is_fully_addressable") and not x.is_fully_addressable
        )
        if non_addressable:
            issues.append(f"Non-addressable array at {path}")
        return x

    # The mapped result is discarded; the traversal is run for its side
    # effect of collecting issues.
    tree_map_with_path(check_leaf, pytree)
    return not issues, issues
def get_serialization_info(pytree) -> dict[str, Any]:
    """Collect summary information about a pytree for serialization.

    Args:
        pytree: PyTree to analyze.

    Returns:
        Dictionary with the keys:
            - "num_leaves": number of leaves reported by
              ``jax.tree_util.tree_leaves``
            - "tree_structure": string form of the tree structure
            - "total_elements": result of ``tree_size(pytree)``
            - "memory_bytes": result of ``tree_bytes(pytree)``
            - "is_serializable": result of ``is_serializable(pytree)``
            - "dtypes": sorted list of the distinct dtype names found on
              leaves that expose a ``dtype`` attribute
    """
    from ._tree_util import tree_bytes, tree_size

    # Flatten once and reuse the leaves for both the count and the dtypes.
    leaves = jax.tree_util.tree_leaves(pytree)
    dtypes = {str(leaf.dtype) for leaf in leaves if hasattr(leaf, "dtype")}
    return {
        "num_leaves": len(leaves),
        "tree_structure": str(jax.tree_util.tree_structure(pytree)),
        "total_elements": tree_size(pytree),
        "memory_bytes": tree_bytes(pytree),
        "is_serializable": is_serializable(pytree),
        # Sorted so the report is stable across runs (set order is not).
        "dtypes": sorted(dtypes),
    }
def msgpack_serialize(pytree, in_place: bool = False) -> bytes:
    """Encode a data structure as msgpack bytes.

    Low-level function for python trees with array leaves; use ``to_bytes``
    for custom objects. Arrays larger than ``MAX_CHUNK_SIZE`` are split into
    multiple chunks before packing.

    Args:
        pytree: Python tree of containers with python primitives and array
            leaves.
        in_place: If False, the tree is first rebuilt with an identity
            ``jax.tree_util.tree_map`` so the preparation steps operate on
            new containers rather than on the caller's.

    Returns:
        Msgpack-encoded bytes of ``pytree``.
    """
    if not in_place:
        pytree = jax.tree_util.tree_map(lambda leaf: leaf, pytree)
    prepared = _chunk_array_leaves_in_place(_np_convert_in_place(pytree))
    return msgpack.packb(prepared, default=_msgpack_ext_pack, strict_types=True)
def msgpack_restore(encoded_pytree: bytes):
    """Decode a data structure from msgpack bytes.

    Low-level function for python trees with array leaves; use
    ``from_bytes`` for custom objects. Chunked arrays are reassembled after
    decoding.

    Args:
        encoded_pytree: Msgpack-encoded bytes of a python tree.

    Returns:
        The decoded python tree with python primitive and array leaves.
    """
    decoded = msgpack.unpackb(encoded_pytree, ext_hook=_msgpack_ext_unpack, raw=False)
    restored = _unchunk_array_leaves_in_place(decoded)
    return restored
def from_bytes(target, encoded_bytes: bytes):
    """Restore an object from a msgpack-serialized state dict.

    Args:
        target: Template object with state-dict registrations whose structure
            matches the data stored in ``encoded_bytes``.
        encoded_bytes: Msgpack-serialized state dict structurally isomorphic
            to ``target``.

    Returns:
        The result of applying the decoded state dict to ``target`` via
        ``from_state_dict``.
    """
    return from_state_dict(target, msgpack_restore(encoded_bytes))
def to_bytes(target) -> bytes:
    """Serialize an object as a msgpack-encoded state dict.

    Args:
        target: Object with state-dict registrations to serialize.

    Returns:
        Bytes of the msgpack-encoded state dict of ``target``.
    """
    # The state dict is passed with ``in_place=True``, so it is not copied
    # before being prepared for packing.
    return msgpack_serialize(to_state_dict(target), in_place=True)
def save_to_file(target, filepath: str) -> None:
    """Serialize a pytree and write it to a file.

    Missing parent directories of ``filepath`` are created first.

    Args:
        target: PyTree to save.
        filepath: Path of the file to write.
    """
    from pathlib import Path

    destination = Path(filepath)
    destination.parent.mkdir(parents=True, exist_ok=True)
    payload = to_bytes(target)
    with open(filepath, "wb") as handle:
        handle.write(payload)
def load_from_file(target, filepath: str):
    """Read a file and restore a pytree from its contents.

    Args:
        target: Template pytree supplying the structure.
        filepath: Path of the file to read.

    Returns:
        The result of ``from_bytes(target, ...)`` on the file's contents.
    """
    with open(filepath, "rb") as handle:
        payload = handle.read()
    return from_bytes(target, payload)
def compress_bytes(data: bytes) -> bytes:
    """Gzip-compress a bytes payload at compression level 6.

    Args:
        data: Bytes to compress.

    Returns:
        The gzip-compressed bytes.
    """
    import gzip

    compressed = gzip.compress(data, compresslevel=6)
    return compressed
def decompress_bytes(data: bytes) -> bytes:
    """Decompress a gzip-compressed bytes payload.

    Args:
        data: Gzip-compressed bytes.

    Returns:
        The decompressed bytes.
    """
    import gzip

    restored = gzip.decompress(data)
    return restored
def to_compressed_bytes(target) -> bytes:
    """Serialize a pytree to msgpack bytes and compress the result.

    Args:
        target: PyTree to serialize.

    Returns:
        The compressed msgpack bytes.
    """
    serialized = to_bytes(target)
    return compress_bytes(serialized)
def from_compressed_bytes(target, compressed_data: bytes):
    """Decompress msgpack bytes and restore a pytree from them.

    Args:
        target: Template pytree.
        compressed_data: Compressed msgpack bytes.

    Returns:
        The pytree restored via ``from_bytes``.
    """
    serialized = decompress_bytes(compressed_data)
    return from_bytes(target, serialized)