eformer.pytree._tree_util


class eformer.pytree._tree_util.MetaValueRecreator(seed: int = 42)[source]#

Bases: object

Helper class for generating unique meta values with state tracking.

Provides methods to generate incrementing counts and random keys in a reproducible manner. Useful for reinitializing model metadata that requires unique values.

_count#

Internal counter for generating unique count values.

_rng#

JAX random key state for generating random values.

Examples

>>> recreator = MetaValueRecreator(seed=42)
>>> count1 = recreator.get_count()  # Returns 0
>>> count2 = recreator.get_count()  # Returns 1
>>> key = recreator.get_rng()  # Returns a unique random key
get_count() Array[source]#

Get the next count value and increment the counter.

Returns

Current count as a uint32 array.

Return type

jnp.ndarray

get_rng() PRNGKey[source]#

Get a new random key and update internal state.

Returns

A new random key split from the internal state.

Return type

jax.random.PRNGKey
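The sketch below shows the state-tracking pattern described above. It assumes the class amounts to an integer counter plus a JAX key that is split on every call. `CountAndKeySketch` is a hypothetical name used only for illustration. It is not the eformer class, and eformer's internals may differ.

```python
import jax
import jax.numpy as jnp


class CountAndKeySketch:
    """Hypothetical stand-in for the counter/key pattern; not eformer's implementation."""

    def __init__(self, seed: int = 42):
        self._count = 0                       # next value returned by get_count()
        self._rng = jax.random.PRNGKey(seed)  # key state advanced by get_rng()

    def get_count(self) -> jax.Array:
        # Return the current count as uint32, then increment the counter.
        value = jnp.asarray(self._count, dtype=jnp.uint32)
        self._count += 1
        return value

    def get_rng(self) -> jax.Array:
        # Split the stored key: keep one half as the new state and hand out the other,
        # so every call yields a fresh key while the sequence stays reproducible per seed.
        self._rng, subkey = jax.random.split(self._rng)
        return subkey


recreator = CountAndKeySketch(seed=0)
first, second = recreator.get_count(), recreator.get_count()
key_a, key_b = recreator.get_rng(), recreator.get_rng()
```

Each key depends only on the seed and on how many times `get_rng()` has been called before. As a result, two instances created with the same seed hand out identical key sequences.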

class eformer.pytree._tree_util.NonePolicy(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: StrEnum

Policy for handling None values in tree operations.

PRESERVE#

Keep None values as-is in the tree.

REPLACE#

Replace None values with a specified replacement.

ERROR#

Raise an error when None values are encountered.

ERROR = 'error'#
PRESERVE = 'preserve'#
REPLACE = 'replace'#
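A small recursive walker over plain Python containers can illustrate the three policies. This is a sketch under assumptions. `NonePolicySketch` mirrors the enum values listed above. `apply_none_policy` and its `replacement=0` default are hypothetical and are not exported by this module.

```python
from enum import Enum


class NonePolicySketch(str, Enum):
    """Hypothetical mirror of NonePolicy's three values."""

    PRESERVE = "preserve"
    REPLACE = "replace"
    ERROR = "error"


def apply_none_policy(tree, policy, replacement=0):
    # Accept either an enum member or its string value.
    policy = NonePolicySketch(policy)
    if tree is None:
        if policy is NonePolicySketch.ERROR:
            raise ValueError("encountered None in tree")
        return replacement if policy is NonePolicySketch.REPLACE else None
    if isinstance(tree, dict):
        return {k: apply_none_policy(v, policy, replacement) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(apply_none_policy(v, policy, replacement) for v in tree)
    return tree  # any other value is a leaf and passes through unchanged


tree = {"a": None, "b": [1, None]}
kept = apply_none_policy(tree, "preserve")
filled = apply_none_policy(tree, NonePolicySketch.REPLACE, replacement=-1)
```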
class eformer.pytree._tree_util.PackedLeaf(offset: int, shape: tuple[int, ...])[source]#

Bases: object

Metadata describing the location and shape of a leaf in a packed array.

Used by pack_pytree and unpack_pytree to track where each leaf’s data is stored within the flattened 1-D array representation.

offset#

Starting index of this leaf’s data in the packed array.

Type

int

shape#

Original shape of this leaf array before packing.

Type

tuple[int, …]

offset: int#
shape: tuple[int, ...]#
class eformer.pytree._tree_util.StateValidationResult(is_valid: bool, missing_keys: set, invalid_types: dict[str, type])[source]#

Bases: object

Result of validating a state dictionary against a target structure.

This class stores the outcome of state validation, including whether the validation passed and details about any issues found.

is_valid#

True if validation passed, False otherwise.

Type

bool

missing_keys#

Set of keys present in target but missing from state.

Type

set

invalid_types#

Dictionary mapping key paths to their incorrect types.

Type

dict[str, type]

classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

invalid_types: dict[str, type]#
is_valid: bool#
missing_keys: set#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class eformer.pytree._tree_util.TreeFilter(*args, **kwargs)[source]#

Bases: Protocol

Protocol defining the interface for tree filter functions.

Tree filters are callable objects that take a mask (boolean or callable) and an argument (the tree to filter), returning a filtered tree dictionary.

This protocol enables type checking for functions that implement tree filtering logic.

eformer.pytree._tree_util.deepcopy_tree(model)[source]#

Creates a deep copy of a JAX model.

This function takes a JAX model, extracts its leaves (the individual components of the model), deep copies them, and then reconstructs the model with the copied leaves.

Parameters

model – A JAX model to be deep copied. This can be any nested structure of JAX arrays, lists, tuples, dicts, etc.

Returns

A deep copy of the input model with the same structure but with all leaves deep copied.
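The flatten, copy, unflatten approach described above can be sketched with `jax.tree_util`. `deepcopy_tree_sketch` is a hypothetical name. Rebuilding from the tree definition also creates fresh containers (dicts, lists, tuples) instead of reusing the originals.

```python
import copy

import jax


def deepcopy_tree_sketch(model):
    # Separate the container structure (treedef) from the leaves.
    leaves, treedef = jax.tree_util.tree_flatten(model)
    # Deep-copy every leaf, then rebuild the same structure around the copies.
    return jax.tree_util.tree_unflatten(treedef, [copy.deepcopy(leaf) for leaf in leaves])


original = {"a": [1, 2], "b": {"buf": bytearray(b"xy")}}
clone = deepcopy_tree_sketch(original)
clone["b"]["buf"][0] = ord("z")  # mutating the copy must not touch the original
```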

eformer.pytree._tree_util.empty_node = _EmptyNode(  )#

Singleton instance of _EmptyNode used as the empty node marker.

eformer.pytree._tree_util.flatten_dict(xs: Union[dict, Mapping], keep_empty_nodes: bool = False, is_leaf: Optional[Callable[[tuple, Any], bool]] = None, sep: str | None = None, fumap: bool = False) dict[tuple | str, Any][source]#

Flatten a nested dictionary or mapping, validating that the input is a dictionary or mapping.

Parameters
  • xs – Dictionary or mapping to flatten

  • keep_empty_nodes – Whether to keep empty dictionary nodes

  • is_leaf – Optional function to determine leaf nodes

  • sep – Optional separator for string keys

Returns

Flattened dictionary whose keys are tuples of path components, or sep-joined strings when sep is given.

Raises

TypeError – If input is not a dictionary or mapping

eformer.pytree._tree_util.flatten_mapping(xs: Mapping[Any, Any], /, *, keep_empty_nodes: bool = False, is_leaf: None | collections.abc.Callable[[tuple[Any, ...], collections.abc.Mapping[Any, Any]], bool] = None, sep: None = None) dict[tuple[Any, ...], Any][source]#
eformer.pytree._tree_util.flatten_mapping(xs: Mapping[Any, Any], /, *, keep_empty_nodes: bool = False, is_leaf: None | collections.abc.Callable[[tuple[Any, ...], collections.abc.Mapping[Any, Any]], bool] = None, sep: str) dict[str, Any]

Flatten a nested mapping.

The nested keys are flattened to a tuple. See unflatten_mapping on how to restore the nested mapping.

Example:

>>> from eformer.pytree._tree_util import flatten_mapping
>>> xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
>>> flat_xs = flatten_mapping(xs)
>>> flat_xs
{('foo',): 1, ('bar', 'a'): 2}

Note that empty mappings are ignored and will not be restored by unflatten_mapping.

Parameters
  • xs – a nested mapping

  • keep_empty_nodes – replaces empty mappings with empty_node.

  • is_leaf – an optional function that takes the nested keys and the next nested mapping, and returns True if the nested mapping is a leaf (i.e., should not be flattened further).

  • sep – if specified, then the keys of the returned mapping will be sep-joined strings (if None, then keys will be tuples).

Returns

The flattened mapping.
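Here is a pure-Python sketch of the flattening rules above. It assumes the Flax-style semantics described: tuple keys by default, optional `sep` joining, and empty mappings dropped unless `keep_empty_nodes` is set. `flatten_mapping_sketch` and the `EMPTY_NODE` sentinel are hypothetical stand-ins for this module's `flatten_mapping` and `empty_node`.

```python
from collections.abc import Mapping

EMPTY_NODE = object()  # hypothetical sentinel standing in for empty_node


def flatten_mapping_sketch(xs, *, keep_empty_nodes=False, is_leaf=None, sep=None):
    def _flatten(node, prefix):
        out = {}
        for key, value in node.items():
            path = prefix + (key,)
            # Non-mappings, and mappings that is_leaf(path, value) marks as leaves, stop here.
            if not isinstance(value, Mapping) or (is_leaf is not None and is_leaf(path, value)):
                out[path] = value
            elif not value:
                if keep_empty_nodes:
                    out[path] = EMPTY_NODE  # otherwise empty mappings are dropped
            else:
                out.update(_flatten(value, path))
        return out

    flat = _flatten(xs, ())
    if sep is not None:
        flat = {sep.join(str(part) for part in key): value for key, value in flat.items()}
    return flat


xs = {"foo": 1, "bar": {"a": 2, "b": {}}}
by_tuple = flatten_mapping_sketch(xs)
by_string = flatten_mapping_sketch(xs, sep="/")
with_empty = flatten_mapping_sketch(xs, keep_empty_nodes=True)
```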

eformer.pytree._tree_util.flatten_to_sequence(xs: Mapping[Any, Any], /, *, is_leaf: collections.abc.Callable[[tuple[Any, ...], collections.abc.Mapping[Any, Any]], bool] | None = None) list[tuple[Any, Any]][source]#

Flatten a nested mapping into a list of (key_tuple, value) pairs.

The nested keys are flattened to a tuple. See unflatten_mapping on how to restore the nested mapping.

Example:

>>> from eformer.pytree._tree_util import flatten_to_sequence
>>> xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
>>> flat_xs = flatten_to_sequence(xs)
>>> flat_xs
[(('foo',), 1), (('bar', 'a'), 2)]

Note that empty mappings are ignored and will not be restored by unflatten_mapping.

Parameters
  • xs – a nested mapping

  • is_leaf – an optional function that takes the nested keys and the next nested mapping, and returns True if the nested mapping is a leaf (i.e., should not be flattened further).

Returns

The flattened mapping.

eformer.pytree._tree_util.flatten_tree(xs: PyTree, is_leaf: Optional[Callable[[Any], bool]] = None, sep: str | None = None) dict[str, Any][source]#

Flatten a JAX tree and convert paths to strings.

Parameters
  • xs – The JAX tree to flatten.

  • is_leaf – Optional function to determine leaf nodes.

  • sep – Separator to use when joining path elements.

Returns

A flattened dictionary with string keys representing the tree paths.
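One way to implement this is to flatten with `jax.tree_util.tree_flatten_with_path` and convert each key-path entry to text. The sketch below is hypothetical: `flatten_tree_sketch` is not this module's code. It also assumes `sep="."` as its default, because the page does not describe how the documented `sep=None` default behaves.

```python
import jax
from jax.tree_util import DictKey, FlattenedIndexKey, GetAttrKey, SequenceKey


def _entry_to_str(entry):
    # Map each JAX key-path entry type to plain text.
    if isinstance(entry, DictKey):
        return str(entry.key)
    if isinstance(entry, SequenceKey):
        return str(entry.idx)
    if isinstance(entry, GetAttrKey):
        return entry.name
    if isinstance(entry, FlattenedIndexKey):
        return str(entry.key)
    return str(entry)


def flatten_tree_sketch(xs, is_leaf=None, sep="."):
    path_leaves, _ = jax.tree_util.tree_flatten_with_path(xs, is_leaf=is_leaf)
    return {sep.join(_entry_to_str(e) for e in path): leaf for path, leaf in path_leaves}


tree = {"layer": {"w": 1, "b": [2, 3]}}
flat = flatten_tree_sketch(tree)
flat_slash = flatten_tree_sketch(tree, sep="/")
```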

eformer.pytree._tree_util.int_key_to_string(xs)[source]#

Convert integer keys in a dictionary to strings.

Parameters

xs – Dictionary possibly with integer or tuple keys.

Returns

Dictionary with string keys.

Return type

dict

Examples

>>> d = {(0, 1): 'value'}
>>> int_key_to_string(d)
eformer.pytree._tree_util.is_array(element: Any) bool[source]#

Check if an element is a JAX array or NumPy array.

Parameters

element – The object to check.

Returns

True if element is a JAX Array, NumPy ndarray, or NumPy generic type.

Return type

bool

Examples

>>> is_array(jnp.array([1, 2, 3]))
True
>>> is_array(np.array([1, 2, 3]))
True
>>> is_array([1, 2, 3])
False
eformer.pytree._tree_util.is_array_like(element: Any) bool[source]#

Check if an element is array-like (arrays or scalar numeric types).

Parameters

element – The object to check.

Returns

True if element is an array or numeric scalar type.

Return type

bool

Note

This includes JAX arrays, NumPy arrays, and Python numeric types (int, float, complex, bool), as well as objects with __jax_array__ attribute.

Examples

>>> is_array_like(jnp.array([1, 2]))
True
>>> is_array_like(5.0)
True
>>> is_array_like("string")
False
eformer.pytree._tree_util.is_flatten(tree: dict) bool[source]#

Checks if a dictionary represents a flattened tree.

A flattened tree is a dictionary where the keys are tuples representing the path to the leaf nodes. This function checks if any of the keys in the input dictionary is a tuple, indicating a flattened tree.

Parameters

tree – The dictionary to check.

Returns

True if the dictionary is a flattened tree, False otherwise.

Return type

bool

eformer.pytree._tree_util.is_iterable(obj)[source]#

Check if an object is iterable.

Parameters

obj – Object to check.

Returns

True if object is iterable, False otherwise.

Return type

bool

Examples

>>> is_iterable([1, 2, 3])
True
>>> is_iterable(42)
False
eformer.pytree._tree_util.join_key(prefix, k)[source]#

Concatenate a prefix and key using dot-notation.

Creates hierarchical key paths by joining components with dots. Handles None keys and empty prefixes gracefully.

Parameters
  • prefix – The prefix string (can be empty string).

  • k – The key to append (can be None).

Returns

The joined key path.

Return type

str

Examples

>>> join_key('layer', 'weight')
'layer.weight'
>>> join_key('', 'bias')
'bias'
>>> join_key('layer', None)
'layer'
eformer.pytree._tree_util.key_path_to_str(path: Sequence) str[source]#

Convert a JAX key path element to a string representation.

Handles various JAX key types (SequenceKey, DictKey, GetAttrKey, FlattenedIndexKey) and converts them to readable string format.

Parameters

path – A sequence containing JAX key path elements. Only the last element is processed.

Returns

String representation of the last path element, or an empty string if path is empty.

Return type

str

Examples

>>> from jax.tree_util import DictKey, SequenceKey
>>> key_path_to_str([DictKey("weights")])
'weights'
>>> key_path_to_str([SequenceKey(0)])
'0'
eformer.pytree._tree_util.leaf_key_paths(pytree, prefix: str | None = '', *, is_leaf: collections.abc.Callable[[Any], bool] | None = None, use_state_dict_keys: bool = False)[source]#

Return a tree mirroring pytree whose leaves are their dot-path strings.

Parameters
  • pytree – The input tree to traverse.

  • prefix – Optional prefix added to every returned path. None resets to "".

  • is_leaf – Optional custom leaf predicate forwarded to jax.tree_util.tree_flatten_with_path().

  • use_state_dict_keys – Reserved for compatibility with other libraries; currently unused.

Returns

A PyTree with the same structure as pytree whose leaves are strings representing the dotted traversal path, or None when pytree has no leaves.

Example

>>> tree = {"layer": {"w": 1, "b": 2}, "scale": 3}
>>> leaf_key_paths(tree)
{'layer': {'w': 'layer.w', 'b': 'layer.b'}, 'scale': 'scale'}
eformer.pytree._tree_util.merge(*pytrees: PyTree, is_leaf: Optional[Callable[[Any], bool]] = None) PyTree[source]#

Combine multiple PyTrees into a single PyTree.

Takes the first non-None value at each position across all input trees.

Parameters
  • *pytrees – Variable number of PyTrees to merge.

  • is_leaf – Optional function to determine if a node is a leaf.

Returns

Combined tree with first non-None value at each position.

Return type

PyTree

Note

This is useful for combining partial trees or filling in missing values.

Examples

>>> tree1 = {"a": 1, "b": None}
>>> tree2 = {"a": None, "b": 2}
>>> merged = merge(tree1, tree2)
>>> merged
{'a': 1, 'b': 2}
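The first-non-None rule can be sketched with `jax.tree_util.tree_map`, treating `None` as a leaf so it holds a slot instead of being dropped as an empty subtree. `merge_sketch` is hypothetical. In this version the first tree fixes the structure. Where the first tree holds `None` and a later tree holds a whole subtree, that subtree is taken unchanged.

```python
import jax


def merge_sketch(*pytrees, is_leaf=None):
    # Treat None as a leaf so it occupies a position in the structure.
    def none_or_leaf(x):
        return x is None or (is_leaf is not None and is_leaf(x))

    def first_not_none(*values):
        return next((v for v in values if v is not None), None)

    return jax.tree_util.tree_map(first_not_none, *pytrees, is_leaf=none_or_leaf)


merged = merge_sketch({"a": 1, "b": None}, {"a": None, "b": 2})
all_missing = merge_sketch({"a": None}, {"a": None})
```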
eformer.pytree._tree_util.named_tree_map(f: Callable[[str, Any, Any], Any], tree: PyTree, *rest: Any, is_leaf: Optional[Callable[[Any], bool]] = None, sep: str | None = None) PyTree[source]#

An extended version of jax.tree_util.tree_map.

This function extends jax.tree_util.tree_map by providing the path (as a string) to the current leaf node as an argument to the mapped function f.

Parameters
  • f – The function to apply to each leaf node, taking the path and value as input.

  • tree – The JAX tree to map over.

  • *rest – Additional arguments to be passed to f.

  • is_leaf – Optional function to determine leaf nodes.

  • sep – Separator to use when joining path elements.

Returns

A new tree with the same structure as tree but with the values modified by f.

eformer.pytree._tree_util.pack_pytree(tree: ~eformer.pytree._pytree.PyTree, dtype=<class 'jax.numpy.float32'>) tuple[eformer.pytree._pytree.PyTree, jax.jaxlib._jax.Array][source]#

Pack all leaves of a pytree into a single 1-D array.

This function flattens all array leaves into a contiguous 1-D array, which is useful for optimization algorithms that work on flat parameter vectors or for efficient storage/transmission.

Parameters
  • tree – Pytree of array-like objects to pack.

  • dtype – Desired dtype of the packed array (default: jnp.float32).

Returns

A pair (offset_tree, flat_array) where:
  • offset_tree has the same structure as tree but each leaf is replaced with a PackedLeaf containing offset and shape information.

  • flat_array is a 1-D array containing all leaf data.

Return type

tuple

Examples

>>> tree = {"weights": jnp.ones((2, 3)), "bias": jnp.zeros(3)}
>>> offset_tree, packed = pack_pytree(tree)
>>> packed.shape
(9,)
>>> original = unpack_pytree(offset_tree, packed)
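The packing scheme above can be sketched in three steps: flatten the leaves, record each leaf's start offset and shape, and concatenate the raveled leaves. Unpacking slices each segment back out and restores its shape. `LeafInfo`, `pack_sketch`, and `unpack_sketch` are hypothetical stand-ins for `PackedLeaf`, `pack_pytree`, and `unpack_pytree`. The sketch assumes every leaf is array-like.

```python
import math
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class LeafInfo:
    """Hypothetical stand-in for PackedLeaf; unregistered, so JAX treats it as a leaf."""

    offset: int
    shape: tuple


def pack_sketch(tree, dtype=jnp.float32):
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    infos, chunks, offset = [], [], 0
    for leaf in leaves:
        arr = jnp.asarray(leaf, dtype=dtype)
        infos.append(LeafInfo(offset, tuple(arr.shape)))
        chunks.append(arr.ravel())
        offset += arr.size  # the next leaf starts right after this one
    flat = jnp.concatenate(chunks) if chunks else jnp.zeros((0,), dtype=dtype)
    return jax.tree_util.tree_unflatten(treedef, infos), flat


def unpack_sketch(offset_tree, flat):
    infos, treedef = jax.tree_util.tree_flatten(offset_tree)
    # Slice each leaf's segment out of the flat array and restore its shape.
    arrays = [flat[i.offset : i.offset + math.prod(i.shape)].reshape(i.shape) for i in infos]
    return jax.tree_util.tree_unflatten(treedef, arrays)


tree = {"weights": jnp.ones((2, 3)), "bias": jnp.zeros(3)}
offset_tree, packed = pack_sketch(tree)
restored = unpack_sketch(offset_tree, packed)
```

Leaves are laid out in JAX's flattening order, which sorts dict keys. That is why `bias` lands at offset 0 in this example.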
eformer.pytree._tree_util.recursive_merge(full_tree, updates)[source]#

Recursively merge two PyTrees where updates may have fewer parameters.

Parameters
  • full_tree – The complete parameter tree

  • updates – Tree with updated values (subset of full_tree)

Returns

Merged tree with updated values where available
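For nested dictionaries, the merge described above can be sketched recursively. Values present in `updates` win, and anything absent from `updates` is kept from `full_tree`. `recursive_merge_sketch` is hypothetical and only handles plain `dict` nodes.

```python
def recursive_merge_sketch(full_tree, updates):
    # A non-dict value in updates replaces the corresponding entry outright.
    if not isinstance(full_tree, dict) or not isinstance(updates, dict):
        return updates
    merged = dict(full_tree)  # shallow copy so full_tree itself is not mutated
    for key, value in updates.items():
        if key in full_tree:
            merged[key] = recursive_merge_sketch(full_tree[key], value)
        else:
            merged[key] = value
    return merged


full = {"encoder": {"w": 1, "b": 2}, "decoder": {"w": 3}}
partial = {"encoder": {"w": 10}}
result = recursive_merge_sketch(full, partial)
```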

eformer.pytree._tree_util.specs_to_name_sharding(tree: dict, mesh: jax._src.mesh.Mesh | None = None) dict[source]#

Converts a dictionary of specifications to a dictionary of NamedSharding objects.

Parameters
  • tree (Dict) – A dictionary where the keys are names and the values are specifications.

  • mesh (Optional[Mesh]) – An optional Mesh object. If not provided, the default physical mesh from pxla.thread_resources.env.physical_mesh is used.

Returns

A dictionary where the keys are the same as the input dictionary, and the values are NamedSharding objects created from the specifications and the provided or default mesh.

Return type

Dict
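This sketch performs the conversion with the public `jax.sharding` API, mapping every `PartitionSpec` leaf to a `NamedSharding` on a given mesh. Unlike the documented function, the hypothetical `specs_to_named_sharding_sketch` requires an explicit `mesh` and does not fall back to the thread-local physical mesh. The one-axis mesh named `"dp"` is an illustrative assumption.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def specs_to_named_sharding_sketch(tree, mesh):
    # Treat each PartitionSpec as a leaf and pair it with the mesh.
    return jax.tree_util.tree_map(
        lambda spec: NamedSharding(mesh, spec),
        tree,
        is_leaf=lambda x: isinstance(x, PartitionSpec),
    )


mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))
specs = {"w": PartitionSpec("dp", None), "b": PartitionSpec()}
shardings = specs_to_named_sharding_sketch(specs, mesh)
```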

eformer.pytree._tree_util.split(pytree: PyTree, filter_spec: Union[bool, Callable[[Any], bool]], replace: Any = None, is_leaf: Optional[Callable[[Any], bool]] = None) tuple[eformer.pytree._pytree.PyTree, eformer.pytree._pytree.PyTree][source]#

Split a PyTree into two based on a filter specification.

Parameters
  • pytree – The PyTree to split.

  • filter_spec – Either a boolean or callable that determines the split. If bool, applies uniformly. If callable, applied to each leaf.

  • replace – Value to use for filtered-out positions (default: None).

  • is_leaf – Optional function to determine leaf nodes.

Returns

Two PyTrees where:
  • First contains values where filter is True (others replaced)

  • Second contains values where filter is False (others replaced)

Return type

tuple[PyTree, PyTree]

Examples

>>> tree = {"a": jnp.array([1, 2, 3]), "b": jnp.array([3, 4])}
>>> large, small = split(tree, lambda x: x.size > 2)
>>> # large: {"a": [1, 2, 3], "b": None}; small: {"a": None, "b": [3, 4]}
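The split can be sketched with two `jax.tree_util.tree_map` passes. One pass keeps the positions where the filter holds and the other keeps the rest. `split_sketch` is a hypothetical name. With the default `replace=None`, later JAX tree operations treat the filled-in positions as empty subtrees.

```python
import jax
import jax.numpy as jnp


def split_sketch(pytree, filter_spec, replace=None, is_leaf=None):
    def keep(x):
        # A bool applies uniformly; a callable is evaluated per leaf.
        return filter_spec(x) if callable(filter_spec) else bool(filter_spec)

    selected = jax.tree_util.tree_map(lambda x: x if keep(x) else replace, pytree, is_leaf=is_leaf)
    rest = jax.tree_util.tree_map(lambda x: replace if keep(x) else x, pytree, is_leaf=is_leaf)
    return selected, rest


tree = {"a": jnp.array([1, 2, 3]), "b": jnp.array([3, 4])}
large, small = split_sketch(tree, lambda x: x.size > 2)
```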
eformer.pytree._tree_util.string_key_to_int(xs)[source]#

Convert string keys in a dictionary to integers where possible.

Parameters

xs – Dictionary with string or tuple keys.

Returns

Dictionary with integer keys where applicable.

Return type

dict

Examples

>>> d = {('0', '1'): 'value'}
>>> string_key_to_int(d)
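`int_key_to_string` and `string_key_to_int` act as inverses on flattened dictionaries with tuple keys. The sketch below assumes exactly that convention: tuple key components are converted one by one, and only all-digit strings are turned back into integers. Both helper names are hypothetical.

```python
def int_key_to_string_sketch(xs):
    # Turn integer components of tuple (or scalar) keys into strings.
    def convert(key):
        parts = key if isinstance(key, tuple) else (key,)
        new = tuple(str(p) if isinstance(p, int) else p for p in parts)
        return new if isinstance(key, tuple) else new[0]

    return {convert(k): v for k, v in xs.items()}


def string_key_to_int_sketch(xs):
    # Turn all-digit string components back into integers; leave other strings alone.
    def convert(key):
        parts = key if isinstance(key, tuple) else (key,)
        new = tuple(int(p) if isinstance(p, str) and p.isdigit() else p for p in parts)
        return new if isinstance(key, tuple) else new[0]

    return {convert(k): v for k, v in xs.items()}


as_strings = int_key_to_string_sketch({(0, 1): "value", ("layer", 2): "w"})
round_trip = string_key_to_int_sketch(as_strings)
```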
eformer.pytree._tree_util.tree_abs(tree: PyTree) PyTree[source]#

Compute absolute values of all elements in a pytree.

Parameters

tree – Input pytree containing numerical values.

Returns

New tree with absolute values.

Return type

PyTree

Examples

>>> tree = {"a": jnp.array([-1, 2, -3]), "b": -4.5}
>>> result = tree_abs(tree)
>>> result["a"].tolist(), float(result["b"])
([1, 2, 3], 4.5)
eformer.pytree._tree_util.tree_add(tree1: PyTree, tree2: PyTree) PyTree[source]#

Element-wise addition of two pytrees.

Parameters
  • tree1 – First pytree.

  • tree2 – Second pytree (must have same structure as tree1).

Returns

New tree with element-wise sum of values.

Return type

PyTree

Examples

>>> tree1 = {"a": jnp.array([1, 2]), "b": jnp.array([3])}
>>> tree2 = {"a": jnp.array([4, 5]), "b": jnp.array([6])}
>>> result = tree_add(tree1, tree2)
>>> result["a"].tolist(), result["b"].tolist()
([5, 7], [9])
eformer.pytree._tree_util.tree_all(tree: PyTree) bool[source]#

Check if all values in the pytree are True.

Parameters

tree – Input pytree containing boolean or numerical values.

Returns

True if all elements in all arrays are True/non-zero.

Return type

bool

Examples

>>> tree = {"a": jnp.array([True, True]), "b": jnp.array([True])}
>>> bool(tree_all(tree))
True
>>> tree2 = {"x": jnp.array([1, 2]), "y": jnp.array([0])}
>>> bool(tree_all(tree2))
False
eformer.pytree._tree_util.tree_any(tree: PyTree) bool[source]#

Check if any value in the pytree is True.

Parameters

tree – Input pytree containing boolean or numerical values.

Returns

True if any element in any array is True/non-zero.

Return type

bool

Examples

>>> tree = {"a": jnp.array([False, False]), "b": jnp.array([True])}
>>> bool(tree_any(tree))
True
>>> tree2 = {"x": jnp.array([0, 0]), "y": jnp.array([0])}
>>> bool(tree_any(tree2))
False
eformer.pytree._tree_util.tree_apply(fns: dict[Any, Callable[[Any], Any]], tree: dict[Any, Any]) dict[Any, Any][source]#

Apply a dictionary of functions to a corresponding PyTree.

Parameters
  • fns – A dictionary where keys match the PyTree structure and values are functions.

  • tree – The PyTree to apply functions to.

Returns

A new PyTree with the same structure as tree, but with values modified by the functions in fns.
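This sketch assumes `fns` mirrors the structure of `tree`, with a callable at each position. Under that assumption, the application is a two-tree `jax.tree_util.tree_map`. `tree_apply_sketch` is hypothetical. Because the function tree comes first, a callable placed at an inner node receives that node's whole subtree.

```python
import jax


def tree_apply_sketch(fns, tree):
    # Plain functions are leaves, so fns defines the structure and each
    # callable is paired with the matching value (or subtree) of tree.
    return jax.tree_util.tree_map(lambda fn, value: fn(value), fns, tree)


fns = {"scale": lambda x: x * 2, "name": str.upper, "block": len}
tree = {"scale": 3, "name": "dense", "block": {"w": 1, "b": 2}}
result = tree_apply_sketch(fns, tree)
```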

eformer.pytree._tree_util.tree_bytes(tree: PyTree) int[source]#

Calculate the total memory usage of a pytree in bytes.

Parameters

tree – Input pytree

Returns

Total memory usage in bytes

eformer.pytree._tree_util.tree_cast(tree: PyTree, dtype: Any) PyTree[source]#

Cast all arrays in a pytree to a specified dtype.

Parameters
  • tree – Input pytree containing arrays.

  • dtype – Target dtype (e.g., jnp.float32, jnp.int32).

Returns

New tree with arrays cast to the specified dtype.

Return type

PyTree

Examples

>>> tree = {"a": jnp.array([1, 2], dtype=jnp.int32)}
>>> result = tree_cast(tree, jnp.float32)
>>> result["a"].dtype
dtype('float32')
eformer.pytree._tree_util.tree_clip(tree: PyTree, min_val: Any = None, max_val: Any = None) PyTree[source]#

Clip values in a pytree to a specified range.

Parameters
  • tree – Input pytree containing numerical arrays.

  • min_val – Minimum value for clipping (inclusive).

  • max_val – Maximum value for clipping (inclusive).

Returns

New tree with values clipped to [min_val, max_val].

Return type

PyTree

Examples

>>> tree = {"weights": jnp.array([-2, 0, 5, 10])}
>>> clipped = tree_clip(tree, min_val=0, max_val=5)
>>> clipped["weights"].tolist()
[0, 0, 5, 5]
eformer.pytree._tree_util.tree_concatenate(trees: list[eformer.pytree._pytree.PyTree], axis: int = 0) PyTree[source]#

Concatenate corresponding arrays in a list of PyTrees.

Parameters
  • trees – List of PyTrees with matching structure.

  • axis – Axis along which to concatenate arrays (default: 0).

Returns

Single tree with concatenated arrays.

Return type

PyTree

Examples

>>> tree1 = {"a": jnp.array([1, 2])}
>>> tree2 = {"a": jnp.array([3, 4])}
>>> result = tree_concatenate([tree1, tree2])
>>> result["a"].tolist()
[1, 2, 3, 4]
eformer.pytree._tree_util.tree_divide(tree1: PyTree, tree2: eformer.pytree._pytree.PyTree | Any) PyTree[source]#

Element-wise division of pytrees or scalar division.

Parameters
  • tree1 – First pytree (dividend).

  • tree2 – Second pytree (same structure) or scalar divisor.

Returns

New tree with element-wise or scalar quotient.

Return type

PyTree

Examples

>>> tree = {"a": jnp.array([4.0, 6.0]), "b": jnp.array([8.0])}
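>>> result1 = tree_divide(tree, 2.0)  # scalar divisor
>>> result1["a"].tolist(), result1["b"].tolist()
([2.0, 3.0], [4.0])
>>> tree2 = {"a": jnp.array([2.0, 3.0]), "b": jnp.array([4.0])}
>>> result2 = tree_divide(tree, tree2)  # element-wise divisor
>>> result2["a"].tolist(), result2["b"].tolist()
([2.0, 2.0], [2.0])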
eformer.pytree._tree_util.tree_dot(tree1: PyTree, tree2: PyTree) Any[source]#

Compute dot product of two pytrees.

Computes the sum of element-wise products across all arrays in the trees.

Parameters
  • tree1 – First pytree.

  • tree2 – Second pytree (must have same structure).

Returns

Scalar value representing the dot product.

Examples

>>> tree1 = {"a": jnp.array([1, 2]), "b": jnp.array([3])}
>>> tree2 = {"a": jnp.array([4, 5]), "b": jnp.array([6])}
>>> result = tree_dot(tree1, tree2)
>>> int(result)  # 1*4 + 2*5 + 3*6
32
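The definition above, a sum of element-wise products across all leaves, can be sketched as follows. `tree_dot_sketch` is a hypothetical name, not this module's implementation.

```python
import jax
import jax.numpy as jnp


def tree_dot_sketch(tree1, tree2):
    # Per-leaf inner products, then a sum over all leaves.
    per_leaf = jax.tree_util.tree_map(lambda a, b: jnp.sum(a * b), tree1, tree2)
    return sum(jax.tree_util.tree_leaves(per_leaf))


tree1 = {"a": jnp.array([1, 2]), "b": jnp.array([3])}
tree2 = {"a": jnp.array([4, 5]), "b": jnp.array([6])}
result = tree_dot_sketch(tree1, tree2)  # 1*4 + 2*5 + 3*6
```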
eformer.pytree._tree_util.tree_equal(*pytrees: PyTree, typematch: bool = False, rtol=0.0, atol=0.0) bool[source]#

Check if multiple PyTrees are equal in structure and values.

Parameters
  • *pytrees – Variable number of PyTrees to compare.

  • typematch – If True, also check that types match exactly.

  • rtol – Relative tolerance for floating point comparison.

  • atol – Absolute tolerance for floating point comparison.

Returns

True if all trees have same structure and equal values.

Return type

bool

Examples

>>> tree1 = {"a": jnp.array([1.0, 2.0])}
>>> tree2 = {"a": jnp.array([1.0, 2.0])}
>>> tree_equal(tree1, tree2)
True
>>> tree3 = {"a": jnp.array([1.0, 2.1])}
>>> tree_equal(tree1, tree3, atol=0.2)
True
eformer.pytree._tree_util.tree_exp(tree: PyTree) PyTree[source]#

Compute exponential (e^x) of all elements in a pytree.

Parameters

tree – Input pytree containing numerical arrays.

Returns

New tree with exponential values.

Return type

PyTree

Examples

>>> tree = {"a": jnp.array([0.0, 1.0]), "b": jnp.array([2.0])}
>>> result = tree_exp(tree)
>>> # result["a"] ≈ [1.0, 2.718], result["b"] ≈ [7.389]
eformer.pytree._tree_util.tree_expand_dims(tree: PyTree, axis: int) PyTree[source]#

Expand dimensions of arrays in a pytree.

Parameters
  • tree – Input pytree containing arrays.

  • axis – Position in the expanded axes where the new axis is placed.

Returns

New tree with arrays having an additional dimension.

Return type

PyTree

Examples

>>> tree = {"a": jnp.array([1, 2, 3])}
>>> result = tree_expand_dims(tree, axis=0)
>>> result["a"].shape
(1, 3)
>>> result2 = tree_expand_dims(tree, axis=1)
>>> result2["a"].shape
(3, 1)
eformer.pytree._tree_util.tree_filter(tree: PyTree, predicate: Callable[[Any], bool]) PyTree[source]#

Filter a PyTree keeping only leaves that satisfy the predicate.

Parameters
  • tree – Input PyTree to filter.

  • predicate – Function that returns True for leaves to keep.

Returns

Filtered tree containing only the leaves for which predicate returns True.

Return type

PyTree

Note

This may change the tree structure if entire branches are filtered out.

Examples

>>> tree = {"a": jnp.array([1, 2]), "b": jnp.array([3])}
>>> filtered = tree_filter(tree, lambda x: x.size > 1)
>>> # only "a" (size 2) satisfies the predicate; "b" (size 1) is filtered out
eformer.pytree._tree_util.tree_flatten_one_level_with_keys(pytree: PyTree) tuple[list[tuple[Optional[KeyEntry], eformer.pytree._pytree.PyTree]], jaxlib._jax.pytree.PyTreeDef][source]#

Adapted from equinox.tree_flatten_one_level to also return keys.

If the passed-in PyTree is a leaf, it returns a single-element list with None as the key and the PyTree as the value.

eformer.pytree._tree_util.tree_flatten_with_paths(tree: PyTree, is_leaf: Optional[Callable[[Any], bool]] = None) tuple[list[tuple[tuple, Any]], jaxlib._jax.pytree.PyTreeDef][source]#

Flattens a pytree while keeping track of paths to leaves.

This function is useful when you need both the flattened values and their locations in the original tree structure.

Parameters
  • tree – Input pytree to flatten.

  • is_leaf – Optional function to determine if a node is a leaf.

Returns

A pair of (paths_and_values, treedef) where:
  • paths_and_values is a list of (path, value) tuples

  • treedef is the tree structure definition

Return type

tuple

Examples

>>> tree = {"weights": jnp.array([1, 2]), "bias": jnp.array([3])}
>>> paths_vals, treedef = tree_flatten_with_paths(tree)
>>> len(paths_vals)
2
eformer.pytree._tree_util.tree_isfinite(tree: PyTree) PyTree[source]#

Check for finite values in a pytree.

Parameters

tree – Input pytree containing numerical arrays.

Returns

New tree with boolean arrays indicating finite values (not NaN or infinity).

Return type

PyTree

Examples

>>> tree = {"a": jnp.array([1.0, jnp.nan, jnp.inf, 2.0])}
>>> result = tree_isfinite(tree)
>>> result["a"].tolist()
[True, False, False, True]
eformer.pytree._tree_util.tree_isinf(tree: PyTree) PyTree[source]#

Check for infinite values in a pytree.

Parameters

tree – Input pytree containing numerical arrays.

Returns

New tree with boolean arrays indicating infinity locations.

Return type

PyTree

Examples

>>> tree = {"a": jnp.array([1.0, jnp.inf, -jnp.inf])}
>>> result = tree_isinf(tree)
>>> result["a"].tolist()
[False, True, True]
eformer.pytree._tree_util.tree_isnan(tree: PyTree) PyTree[source]#

Check for NaN values in a pytree.

Parameters

tree – Input pytree containing numerical arrays.

Returns

New tree with boolean arrays indicating NaN locations.

Return type

PyTree

Examples

>>> tree = {"a": jnp.array([1.0, jnp.nan, 3.0])}
>>> result = tree_isnan(tree)
>>> result["a"].tolist()
[False, True, False]
eformer.pytree._tree_util.tree_leaves_with_paths(tree: PyTree, is_leaf: Optional[Callable[[Any], bool]] = None) list[tuple[tuple, Any]][source]#

Returns list of (path, leaf_value) pairs in the pytree.

Parameters
  • tree – Input PyTree to extract leaves from.

  • is_leaf – Optional function to determine if a node is a leaf.

Returns

List of tuples where each tuple is (path, leaf_value).

Return type

list

Examples

>>> tree = {"a": 1, "b": {"c": 2}}
>>> paths_and_vals = tree_leaves_with_paths(tree)
>>> sorted(value for _, value in paths_and_vals)
[1, 2]
eformer.pytree._tree_util.tree_log(tree: PyTree) PyTree[source]#

Compute natural logarithm of all elements in a pytree.

Parameters

tree – Input pytree containing positive numerical arrays.

Returns

New tree with natural logarithm values.

Return type

PyTree

Examples

>>> tree = {"a": jnp.array([1.0, jnp.e]), "b": jnp.array([jnp.e**2])}
>>> result = tree_log(tree)
>>> # result["a"] ≈ [0.0, 1.0], result["b"] ≈ [2.0]
eformer.pytree._tree_util.tree_map_with_path(f: Callable, tree: PyTree, is_leaf: Optional[Callable[[Any], bool]] = None) PyTree[source]#

Maps a function over a pytree while providing the path to each leaf.

Parameters
  • f – Function that takes (path, leaf_value) as arguments. The path is a tuple of string keys representing the location in the tree.

  • tree – Input pytree to map over.

  • is_leaf – Optional function to determine if a node is a leaf.

Returns

New tree with same structure but values transformed by f.

Return type

PyTree

Examples

>>> tree = {"a": 1, "b": {"c": 2, "d": 3}}
>>> result = tree_map_with_path(
...     lambda path, x: f"path={path}, value={x}",
...     tree
... )
>>> # result["b"]["c"] is the string built from path ('b', 'c') and value 2
eformer.pytree._tree_util.tree_max(tree: PyTree) Any[source]#

Find maximum value across all arrays in a pytree.

Parameters

tree – Input pytree

Returns

Maximum value

eformer.pytree._tree_util.tree_mean(tree: PyTree, axis: int | None = None) eformer.pytree._pytree.PyTree | Any[source]#

Compute mean of all values in a pytree.

Parameters
  • tree – Input pytree

  • axis – Optional axis for mean (applies to each array)

Returns

Mean of all values

eformer.pytree._tree_util.tree_min(tree: PyTree) Any[source]#

Find minimum value across all arrays in a pytree.

Parameters

tree – Input pytree

Returns

Minimum value

eformer.pytree._tree_util.tree_multiply(tree1: PyTree, tree2: eformer.pytree._pytree.PyTree | Any) PyTree[source]#

Element-wise multiplication of pytrees or scalar multiplication.

Parameters
  • tree1 – First pytree.

  • tree2 – Second pytree (same structure) or scalar value.

Returns

New tree with element-wise or scalar product.

Return type

PyTree

Examples

>>> tree = {"a": jnp.array([1, 2]), "b": jnp.array([3])}
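>>> result1 = tree_multiply(tree, 2)  # scalar factor
>>> result1["a"].tolist(), result1["b"].tolist()
([2, 4], [6])
>>> tree2 = {"a": jnp.array([2, 3]), "b": jnp.array([4])}
>>> result2 = tree_multiply(tree, tree2)  # element-wise factor
>>> result2["a"].tolist(), result2["b"].tolist()
([2, 6], [12])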
eformer.pytree._tree_util.tree_norm(tree: PyTree, ord: Any = 2) Any[source]#

Compute the norm of a pytree.

Parameters
  • tree – Input pytree

  • ord – Order of the norm (default: 2 for L2 norm)

Returns

Norm value
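The following sketch assumes the norm treats the whole tree as one long vector, the usual "global norm" convention. This page does not say so explicitly. Under that assumption, it concatenates the raveled leaves and calls `jnp.linalg.norm`. `tree_norm_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def tree_norm_sketch(tree, ord=2):
    # Flatten every leaf to 1-D and join them into a single vector.
    leaves = [jnp.ravel(jnp.asarray(x)) for x in jax.tree_util.tree_leaves(tree)]
    return jnp.linalg.norm(jnp.concatenate(leaves), ord=ord)


tree = {"a": jnp.array([3.0]), "b": jnp.array([[4.0]])}
l2 = tree_norm_sketch(tree)
l1 = tree_norm_sketch(tree, ord=1)
linf = tree_norm_sketch(tree, ord=jnp.inf)
```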

eformer.pytree._tree_util.tree_ones_like(tree: PyTree) PyTree[source]#

Create a PyTree of ones with the same structure and shapes.

Parameters

tree – Template PyTree to match structure and shapes.

Returns

New tree with same structure but all array values set to one.

Return type

PyTree

Examples

>>> tree = {"a": jnp.array([1.5, 2.5])}
>>> ones = tree_ones_like(tree)
>>> ones["a"].tolist()
[1.0, 1.0]
eformer.pytree._tree_util.tree_path_to_string(path: tuple[Any, ...], sep: str | None = None) str | tuple[str, ...][source]#

Convert a JAX tree path to a string representation.

Parameters
  • path – The JAX tree path tuple.

  • sep – Separator to use when joining path elements.

Returns

The string representation of the path.

eformer.pytree._tree_util.tree_random_like(tree: PyTree, key: PRNGKey, distribution: str = 'normal', **kwargs) PyTree[source]#

Create a pytree with random values matching the structure of input tree.

Parameters
  • tree – Template pytree to match structure and shapes.

  • key – JAX random key for reproducible randomness.

  • distribution – Distribution type (‘normal’, ‘uniform’, ‘bernoulli’).

  • **kwargs – Additional arguments for the distribution: mean and std for ‘normal’; minval and maxval for ‘uniform’; p (probability) for ‘bernoulli’.

Returns

New tree with same structure but random values.

Return type

PyTree

Examples

>>> key = jax.random.PRNGKey(0)
>>> tree = {"weights": jnp.zeros((2, 3))}
>>> result1 = tree_random_like(tree, key, "normal")
>>> result1["weights"].shape
(2, 3)
>>> result2 = tree_random_like(tree, key, "uniform")
>>> result3 = tree_random_like(tree, key, "uniform", minval=-1, maxval=1)  # values between -1 and 1
eformer.pytree._tree_util.tree_reciprocal(tree: PyTree) PyTree[source]#

Compute reciprocal (1/x) of all elements in a pytree.

Parameters

tree – Input pytree containing numerical arrays.

Returns

New tree with reciprocal values (1/x for each element).

Return type

PyTree

Examples

>>> tree = {"a": jnp.array([2.0, 4.0]), "b": jnp.array([0.5])}
>>> result = tree_reciprocal(tree)
>>> result["a"].tolist(), result["b"].tolist()
([0.5, 0.25], [2.0])
eformer.pytree._tree_util.tree_reduce(reducer: Callable[[Any, Any], Any], tree: PyTree, initializer: Any | None = None) Any[source]#

Reduce a pytree to a single value using a reduction function.

Parameters
  • reducer – Binary function to reduce values

  • tree – Input pytree

  • initializer – Initial value for reduction

Returns

Reduced value
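The reduction over the tree's leaves can be sketched with `functools.reduce`. `tree_reduce_sketch` is hypothetical, and JAX itself ships a comparable `jax.tree_util.tree_reduce`. The sketch assumes `initializer=None` means "no initial value", so it cannot use `None` as an actual starting value.

```python
import functools

import jax


def tree_reduce_sketch(reducer, tree, initializer=None):
    leaves = jax.tree_util.tree_leaves(tree)  # leaves in JAX's flattening order
    if initializer is None:
        return functools.reduce(reducer, leaves)
    return functools.reduce(reducer, leaves, initializer)


tree = {"a": 1, "b": [2, 3]}
total = tree_reduce_sketch(lambda x, y: x + y, tree)
total_from_ten = tree_reduce_sketch(lambda x, y: x + y, tree, initializer=10)
largest = tree_reduce_sketch(max, tree)
```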

eformer.pytree._tree_util.tree_replace_infs(tree: PyTree, value: Any = 0.0) PyTree[source]#

Replace infinite values in a pytree.

Parameters
  • tree – Input pytree containing numerical arrays.

  • value – Value to replace infinities with (default: 0.0).

Returns

New tree with infinite values replaced.

Return type

PyTree

Examples

>>> tree = {"a": jnp.array([1.0, jnp.inf, -jnp.inf, 2.0])}
>>> result = tree_replace_infs(tree, value=999.0)
>>> result["a"].tolist()
[1.0, 999.0, 999.0, 2.0]
eformer.pytree._tree_util.tree_replace_nans(tree: PyTree, value: Any = 0.0) PyTree[source]#

Replace NaN values in a pytree.

Parameters
  • tree – Input pytree containing numerical arrays.

  • value – Value to replace NaNs with (default: 0.0).

Returns

New tree with NaN values replaced.

Return type

PyTree

Examples

>>> tree = {"a": jnp.array([1.0, jnp.nan, 3.0])}
>>> result = tree_replace_nans(tree, value=-1.0)
>>> result["a"].tolist()
[1.0, -1.0, 3.0]
eformer.pytree._tree_util.tree_reshape(tree: PyTree, shape: tuple[int, ...]) PyTree[source]#

Reshape arrays in a pytree to a new shape.

Parameters
  • tree – Input pytree containing arrays.

  • shape – New shape for arrays. Use -1 for automatic dimension.

Returns

New tree with reshaped arrays.

Return type

PyTree

Examples

>>> tree = {"a": jnp.array([[1, 2], [3, 4]])}
>>> result = tree_reshape(tree, (4,))
>>> result["a"].tolist()
[1, 2, 3, 4]
>>> result2 = tree_reshape(tree, (-1, 1))
>>> result2["a"].shape
(4, 1)
eformer.pytree._tree_util.tree_round(tree: PyTree, decimals: int = 0) PyTree[source]#

Round all values in a pytree to a given number of decimals.

Parameters
  • tree – Input pytree containing numerical arrays.

  • decimals – Number of decimal places to round to (default: 0).

Returns

New tree with rounded values.

Return type

PyTree

Examples

>>> tree = {"a": jnp.array([1.234, 5.678])}
>>> result = tree_round(tree, decimals=1)
>>> result2 = tree_round(tree)
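
A minimal sketch, assuming tree_round delegates to jnp.round on each leaf; round_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def round_sketch(tree, decimals=0):
    # Hypothetical helper: round each leaf to `decimals` places.
    return jax.tree_util.tree_map(lambda x: jnp.round(x, decimals), tree)


tree = {"a": jnp.array([1.234, 5.678])}
one_decimal = round_sketch(tree, decimals=1)  # approximately [1.2, 5.7]
whole = round_sketch(tree)                    # [1.0, 6.0]
# jnp.round rounds exact halves to the nearest even value.
halves = round_sketch({"h": jnp.array([0.5, 1.5, 2.5])})  # [0.0, 2.0, 2.0]
```

If eformer does use jnp.round, ties resolve to even numbers (banker's rounding), not away from zero.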
eformer.pytree._tree_util.tree_sign(tree: PyTree) PyTree[source]#

Compute the sign of every element in a pytree.

Parameters

tree – Input pytree containing numerical values.

Returns

New tree with sign values (-1, 0, or 1).

Return type

PyTree

Examples

>>> tree = {"a": jnp.array([-2.5, 0, 3.7])}
>>> result = tree_sign(tree)
eformer.pytree._tree_util.tree_size(tree: PyTree) int[source]#

Calculate the total number of elements in a pytree.

Parameters

tree – Input pytree.

Returns

Total number of elements across all arrays in the tree.
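
A minimal sketch of the count, assuming tree_size sums the element counts of all array leaves; size_sketch is hypothetical.

```python
import jax
import jax.numpy as jnp


def size_sketch(tree):
    # Hypothetical helper: jnp.size gives the element count of one leaf.
    return sum(jnp.size(leaf) for leaf in jax.tree_util.tree_leaves(tree))


tree = {"w": jnp.ones((2, 3)), "b": jnp.zeros(4)}
n = size_sketch(tree)  # 6 + 4
```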

eformer.pytree._tree_util.tree_sqrt(tree: PyTree) PyTree[source]#

Compute the square root of every element in a pytree.

Parameters

tree – Input pytree containing non-negative numerical arrays.

Returns

New tree with square root values.

Return type

PyTree

Examples

>>> tree = {"a": jnp.array([4.0, 9.0]), "b": jnp.array([16.0])}
>>> result = tree_sqrt(tree)
eformer.pytree._tree_util.tree_squeeze(tree: PyTree, axis: int | tuple[int, ...] | None = None) PyTree[source]#

Remove single-dimensional entries from arrays in a pytree.

Parameters
  • tree – Input pytree containing arrays.

  • axis – Axis or axes to squeeze. If None, all axes of size 1 are removed.

Returns

New tree with squeezed arrays.

Return type

PyTree

Examples

>>> tree = {"a": jnp.array([[[1], [2]]])}
>>> result = tree_squeeze(tree, axis=2)
>>> tree2 = {"b": jnp.array([[[3]]])}
>>> result2 = tree_squeeze(tree2)
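
A minimal sketch, assuming tree_squeeze forwards axis to jnp.squeeze for each leaf; squeeze_sketch is hypothetical.

```python
import jax
import jax.numpy as jnp


def squeeze_sketch(tree, axis=None):
    # Hypothetical helper: drop size-1 axes (all of them when axis is None).
    return jax.tree_util.tree_map(lambda x: jnp.squeeze(x, axis=axis), tree)


a = squeeze_sketch({"a": jnp.array([[[1], [2]]])}, axis=2)  # (1, 2, 1) -> (1, 2)
b = squeeze_sketch({"b": jnp.array([[[3]]])})               # (1, 1, 1) -> ()
```

Naming an axis whose size is not 1 makes jnp.squeeze raise an error, so an explicit axis is only safe when every leaf has size 1 there.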
eformer.pytree._tree_util.tree_stack(trees: list[eformer.pytree._pytree.PyTree], axis: int = 0) PyTree[source]#

Stack corresponding arrays in a list of PyTrees.

Parameters
  • trees – List of PyTrees with matching structure.

  • axis – Axis along which to stack arrays (default: 0).

Returns

Single tree with stacked arrays.

Return type

PyTree

Examples

>>> tree1 = {"a": jnp.array([1, 2])}
>>> tree2 = {"a": jnp.array([3, 4])}
>>> result = tree_stack([tree1, tree2])
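
A minimal sketch, assuming tree_stack zips matching leaves across the input trees and calls jnp.stack on each group; stack_sketch is hypothetical.

```python
import jax
import jax.numpy as jnp


def stack_sketch(trees, axis=0):
    # Hypothetical helper: tree_map receives the i-th leaf of every tree
    # together and stacks them along `axis`.
    return jax.tree_util.tree_map(lambda *xs: jnp.stack(xs, axis=axis), *trees)


tree1 = {"a": jnp.array([1, 2])}
tree2 = {"a": jnp.array([3, 4])}
rows = stack_sketch([tree1, tree2])          # rows["a"] holds [[1, 2], [3, 4]]
cols = stack_sketch([tree1, tree2], axis=1)  # cols["a"] holds [[1, 3], [2, 4]]
```

Indexing every stacked leaf with i, for example tree_map(lambda x: x[i], rows), recovers the i-th input tree; this layout is the usual way to feed per-layer parameters to jax.lax.scan.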
eformer.pytree._tree_util.tree_structure_equal(tree1: PyTree, tree2: PyTree) bool[source]#

Check if two PyTrees have the same structure.

Parameters
  • tree1 – First PyTree to compare.

  • tree2 – Second PyTree to compare.

Returns

True if both trees have identical structure, False otherwise.

Return type

bool

Note

This compares only structure, not values: trees with the same nesting but different leaf values return True.

Examples

>>> tree1 = {"a": 1, "b": {"c": 2}}
>>> tree2 = {"a": 10, "b": {"c": 20}}
>>> tree_structure_equal(tree1, tree2)
True
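
A minimal sketch of a structure-only comparison, assuming tree_structure_equal compares JAX tree definitions; structure_equal_sketch is hypothetical.

```python
import jax
import jax.numpy as jnp


def structure_equal_sketch(tree1, tree2):
    # Hypothetical helper: PyTreeDef objects support ==, comparing
    # container types, keys, and nesting but never leaf values.
    return jax.tree_util.tree_structure(tree1) == jax.tree_util.tree_structure(tree2)


same_nesting = structure_equal_sketch({"a": 1, "b": {"c": 2}}, {"a": 10, "b": {"c": 20}})
leaf_kind_ignored = structure_equal_sketch({"a": 1}, {"a": jnp.zeros(3)})
list_vs_tuple = structure_equal_sketch([1, 2], (1, 2))
extra_key = structure_equal_sketch({"a": 1}, {"a": 1, "b": 2})
none_vs_leaf = structure_equal_sketch({"a": None}, {"a": 0})
```

In JAX, None is an empty subtree rather than a leaf, so {"a": None} and {"a": 0} have different structures; a list and a tuple with the same elements also differ.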
eformer.pytree._tree_util.tree_subtract(tree1: PyTree, tree2: PyTree) PyTree[source]#

Element-wise subtraction of two pytrees.

Parameters
  • tree1 – First pytree (minuend).

  • tree2 – Second pytree (subtrahend); must have the same structure as tree1.

Returns

New tree with element-wise difference (tree1 - tree2).

Return type

PyTree

Examples

>>> tree1 = {"a": jnp.array([5, 7]), "b": jnp.array([9])}
>>> tree2 = {"a": jnp.array([1, 2]), "b": jnp.array([3])}
>>> result = tree_subtract(tree1, tree2)
eformer.pytree._tree_util.tree_sum(tree: PyTree, axis: int | None = None) eformer.pytree._pytree.PyTree | Any[source]#

Sum all values in a pytree.

Parameters
  • tree – Input pytree.

  • axis – Optional axis to sum over, applied to each array separately.

Returns

Sum of all values.

eformer.pytree._tree_util.tree_transpose(tree: PyTree, axes: tuple[int, ...] | None = None) PyTree[source]#

Transpose arrays in a pytree.

Parameters
  • tree – Input pytree containing arrays.

  • axes – Permutation of axes. If None, reverses axis order.

Returns

New tree with transposed arrays.

Return type

PyTree

Examples

>>> tree = {"matrix": jnp.array([[1, 2], [3, 4]])}
>>> result = tree_transpose(tree)
>>> tensor = {"data": jnp.ones((2, 3, 4))}
>>> result = tree_transpose(tensor, axes=(2, 0, 1))
eformer.pytree._tree_util.tree_where(condition: PyTree, x: PyTree, y: PyTree) PyTree[source]#

Element-wise where operation on PyTrees.

Parameters
  • condition – PyTree of boolean conditions.

  • x – PyTree of values to select when condition is True.

  • y – PyTree of values to select when condition is False.

Returns

Tree with selected values based on conditions.

Return type

PyTree

Examples

>>> cond = {"a": jnp.array([True, False])}
>>> x = {"a": jnp.array([1, 2])}
>>> y = {"a": jnp.array([3, 4])}
>>> result = tree_where(cond, x, y)
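
A minimal sketch, assuming tree_where walks the three trees in lockstep and calls jnp.where on each triple of leaves; where_sketch is hypothetical.

```python
import jax
import jax.numpy as jnp


def where_sketch(condition, x, y):
    # Hypothetical helper: tree_map pairs up matching leaves, so
    # condition, x, and y must share the same structure.
    return jax.tree_util.tree_map(jnp.where, condition, x, y)


cond = {"a": jnp.array([True, False])}
x = {"a": jnp.array([1, 2])}
y = {"a": jnp.array([3, 4])}
result = where_sketch(cond, x, y)  # takes x where True, y where False
```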
eformer.pytree._tree_util.tree_zeros_like(tree: PyTree) PyTree[source]#

Create a PyTree of zeros with the same structure and shapes.

Parameters

tree – Template PyTree to match structure and shapes.

Returns

New tree with same structure but all array values set to zero.

Return type

PyTree

Examples

>>> tree = {"a": jnp.array([1.5, 2.5])}
>>> zeros = tree_zeros_like(tree)
eformer.pytree._tree_util.unflatten_dict(xs, sep=None)[source]#

Unflatten a dictionary with tuple or string keys.

Parameters
  • xs – Flattened dictionary with tuple or separated string keys.

  • sep – Optional separator for string keys.

Returns

Nested dictionary structure.

Return type

dict

Examples

>>> flat = {('a', 'b'): 1, ('a', 'c'): 2}
>>> unflatten_dict(flat)
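
A minimal sketch of the unflattening step, assuming tuple keys are paths and string keys are split on sep; unflatten_dict_sketch is hypothetical and written in plain Python.

```python
def unflatten_dict_sketch(xs, sep=None):
    # Hypothetical helper, not eformer's implementation.
    result = {}
    for key, value in xs.items():
        # String keys are split on `sep`; tuple keys are already a path.
        path = key.split(sep) if sep is not None else key
        node = result
        for part in path[:-1]:
            # Create intermediate dicts on first visit, reuse them afterwards.
            node = node.setdefault(part, {})
        node[path[-1]] = value
    return result


nested = unflatten_dict_sketch({("a", "b"): 1, ("a", "c"): 2})
from_strings = unflatten_dict_sketch({"a/b": 1, "a/c": 2, "d": 3}, sep="/")
```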
eformer.pytree._tree_util.unflatten_mapping(xs: Sequence[tuple[tuple[Any, ...], Any]], /, *, sep: None = None) dict[Any, Any][source]#
eformer.pytree._tree_util.unflatten_mapping(xs: Mapping[tuple[Any, ...], Any], /, *, sep: None = None) dict[Any, Any]
eformer.pytree._tree_util.unflatten_mapping(xs: Mapping[str, Any], /, *, sep: str) dict[Any, Any]

Unflatten a mapping.

See flatten_mapping().

Examples

>>> from flax import nnx
>>> flat_xs = {
...   ('foo',): 1,
...   ('bar', 'a'): 2,
... }
>>> xs = nnx.traversals.unflatten_mapping(flat_xs)
>>> xs
{'foo': 1, 'bar': {'a': 2}}

Parameters
  • xs – A flattened mapping.

  • sep – Separator for string keys (the same one used with flatten_mapping()).

Returns

The nested mapping.

eformer.pytree._tree_util.unpack_pytree(offset_tree: PyTree, packed: Array) PyTree[source]#

Reconstruct a pytree from its packed representation.

This is the inverse operation of pack_pytree(). It uses the offset and shape information stored in offset_tree to extract and reshape data from the packed array.

Parameters
  • offset_tree – Tree of PackedLeaf objects from pack_pytree.

  • packed – The 1-D array containing packed leaf data.

Returns

Reconstructed tree with original structure and array shapes.

Return type

PyTree

Examples

>>> tree = {"weights": jnp.ones((2, 3)), "bias": jnp.zeros(3)}
>>> offset_tree, packed = pack_pytree(tree)
>>> reconstructed = unpack_pytree(offset_tree, packed)
>>> jnp.allclose(tree["weights"], reconstructed["weights"])
True
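
The offset and shape bookkeeping that unpack_pytree relies on can be sketched as follows. This assumes pack_pytree ravels leaves in jax.tree_util.tree_leaves order and concatenates them; LeafSlot, pack_sketch, and unpack_sketch are hypothetical stand-ins, with LeafSlot mirroring the offset and shape fields of PackedLeaf.

```python
import dataclasses
import math

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class LeafSlot:
    # Hypothetical stand-in for PackedLeaf. An unregistered dataclass is
    # treated as a single leaf by jax.tree_util.
    offset: int
    shape: tuple


def pack_sketch(tree):
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    slots, offset = [], 0
    for leaf in leaves:
        slots.append(LeafSlot(offset, tuple(leaf.shape)))
        offset += leaf.size
    # Concatenate every raveled leaf into one 1-D buffer.
    packed = jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])
    return jax.tree_util.tree_unflatten(treedef, slots), packed


def unpack_sketch(offset_tree, packed):
    def take(slot):
        size = math.prod(slot.shape)  # 1 for a scalar shape ()
        return packed[slot.offset:slot.offset + size].reshape(slot.shape)

    return jax.tree_util.tree_map(take, offset_tree)


tree = {"weights": jnp.ones((2, 3)), "bias": jnp.zeros(3)}
offset_tree, packed = pack_sketch(tree)
restored = unpack_sketch(offset_tree, packed)
```

jnp.concatenate promotes mixed dtypes to a common dtype, so this sketch does not preserve per-leaf dtypes; the offsets are plain Python ints, which keeps every slice static.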