eformer.pytree._tree_util#
- class eformer.pytree._tree_util.MetaValueRecreator(seed: int = 42)[source]#
Bases: object
Helper class for generating unique meta values with state tracking.
Provides methods to generate incrementing counts and random keys in a reproducible manner. Useful for reinitializing model metadata that requires unique values.
- _count#
Internal counter for generating unique count values.
- _rng#
JAX random key state for generating random values.
Examples
>>> recreator = MetaValueRecreator(seed=42)
>>> count1 = recreator.get_count()  # Returns 0
>>> count2 = recreator.get_count()  # Returns 1
>>> key = recreator.get_rng()  # Returns a unique random key
- class eformer.pytree._tree_util.NonePolicy(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
Bases: str, Enum
Policy for handling None values in tree operations.
- PRESERVE#
Keep None values as-is in the tree.
- REPLACE#
Replace None values with a specified replacement.
- ERROR#
Raise an error when None values are encountered.
- ERROR = 'error'#
- PRESERVE = 'preserve'#
- REPLACE = 'replace'#
- class eformer.pytree._tree_util.PackedLeaf(offset: int, shape: tuple[int, ...])[source]#
Bases: object
Metadata describing the location and shape of a leaf in a packed array.
Used by pack_pytree and unpack_pytree to track where each leaf’s data is stored within the flattened 1-D array representation.
- offset#
Starting index of this leaf’s data in the packed array.
- Type
int
- shape#
Original shape of this leaf array before packing.
- Type
tuple[int, …]
- offset: int#
- shape: tuple[int, ...]#
- class eformer.pytree._tree_util.StateValidationResult(is_valid: bool, missing_keys: set, invalid_types: dict[str, type])[source]#
Bases: object
Result of validating a state dictionary against a target structure.
This class stores the outcome of state validation, including whether the validation passed and details about any issues found.
- is_valid#
True if validation passed, False otherwise.
- Type
bool
- missing_keys#
Set of keys present in target but missing from state.
- Type
set
- invalid_types#
Dictionary mapping key paths to their incorrect types.
- Type
dict[str, type]
- classmethod from_dict(data: dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- invalid_types: dict[str, type]#
- is_valid: bool#
- missing_keys: set#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class eformer.pytree._tree_util.TreeFilter(*args, **kwargs)[source]#
Bases: Protocol
Protocol defining the interface for tree filter functions.
Tree filters are callable objects that take a mask (boolean or callable) and an argument (the tree to filter), returning a filtered tree dictionary.
This protocol enables type checking for functions that implement tree filtering logic.
- eformer.pytree._tree_util.deepcopy_tree(model)[source]#
Creates a deep copy of a JAX model.
This function takes a JAX model, extracts its leaves (the individual components of the model), deep copies them, and then reconstructs the model with the copied leaves.
- Parameters
model – A JAX model to be deep copied. This can be any nested structure of JAX arrays, lists, tuples, dicts, etc.
- Returns
A deep copy of the input model with the same structure but with all leaves deep copied.
- eformer.pytree._tree_util.empty_node = _EmptyNode( )#
Singleton instance of _EmptyNode used as the empty node marker.
- eformer.pytree._tree_util.flatten_dict(xs: Union[dict, Mapping], keep_empty_nodes: bool = False, is_leaf: Optional[Callable[[tuple, Any], bool]] = None, sep: str | None = None, fumap: bool = False) dict[tuple | str, Any][source]#
Enhanced dictionary flattening with better type handling and validation.
- Parameters
xs – Dictionary or mapping to flatten
keep_empty_nodes – Whether to keep empty dictionary nodes
is_leaf – Optional function to determine leaf nodes
sep – Optional separator for string keys
- Returns
Flattened dictionary
- Raises
TypeError – If input is not a dictionary or mapping
- eformer.pytree._tree_util.flatten_mapping(xs: Mapping[Any, Any], /, *, keep_empty_nodes: bool = False, is_leaf: None | collections.abc.Callable[[tuple[Any, ...], collections.abc.Mapping[Any, Any]], bool] = None, sep: None = None) dict[tuple[Any, ...], Any][source]#
- eformer.pytree._tree_util.flatten_mapping(xs: Mapping[Any, Any], /, *, keep_empty_nodes: bool = False, is_leaf: None | collections.abc.Callable[[tuple[Any, ...], collections.abc.Mapping[Any, Any]], bool] = None, sep: str) dict[str, Any]
Flatten a nested mapping.
The nested keys are flattened to a tuple. See unflatten_mapping on how to restore the nested mapping.
Example:
>>> from flax import nnx
>>> xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
>>> flat_xs = nnx.traversals.flatten_mapping(xs)
>>> flat_xs
{('foo',): 1, ('bar', 'a'): 2}
Note that empty mappings are ignored and will not be restored by unflatten_mapping.
- Parameters
xs – a nested mapping
keep_empty_nodes – replaces empty mappings with traverse_util.empty_node.
is_leaf – an optional function that takes the next nested mapping and nested keys and returns True if the nested mapping is a leaf (i.e., should not be flattened further).
sep – if specified, then the keys of the returned mapping will be sep-joined strings (if None, then keys will be tuples).
- Returns
The flattened mapping.
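The flattening described above can be sketched in pure Python. `flatten_nested` is a hypothetical stand-in written for this illustration, not the eformer or Flax implementation; it assumes plain nested mappings and leaves out `keep_empty_nodes` and `is_leaf`.

```python
from collections.abc import Mapping
from typing import Any, Optional


def flatten_nested(xs: Mapping, sep: Optional[str] = None) -> dict:
    """Hypothetical stand-in: flatten a nested mapping to tuple or sep-joined keys."""
    flat: dict = {}

    def visit(node: Any, prefix: tuple) -> None:
        if isinstance(node, Mapping):
            # An empty mapping has no items, so it contributes nothing.
            for key, value in node.items():
                visit(value, (*prefix, key))
        else:
            flat[prefix] = node

    visit(xs, ())
    if sep is None:
        return flat
    # With a separator, each key path becomes a single joined string.
    return {sep.join(str(k) for k in path): v for path, v in flat.items()}


xs = {"foo": 1, "bar": {"a": 2, "b": {}}}
print(flatten_nested(xs))           # {('foo',): 1, ('bar', 'a'): 2}
print(flatten_nested(xs, sep="."))  # {'foo': 1, 'bar.a': 2}
```

The empty mapping under `'b'` disappears in both outputs, which is why it cannot be restored when unflattening.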
- eformer.pytree._tree_util.flatten_to_sequence(xs: Mapping[Any, Any], /, *, is_leaf: collections.abc.Callable[[tuple[Any, ...], collections.abc.Mapping[Any, Any]], bool] | None = None) list[tuple[Any, Any]][source]#
Flatten a nested mapping.
The nested keys are flattened to a tuple. See unflatten_mapping on how to restore the nested mapping.
Example:
>>> from flax import nnx
>>> xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
>>> flat_xs = nnx.traversals.flatten_to_sequence(xs)
>>> flat_xs
[(('foo',), 1), (('bar', 'a'), 2)]
Note that empty mappings are ignored and will not be restored by unflatten_mapping.
- Parameters
xs – a nested mapping
is_leaf – an optional function that takes the next nested mapping and nested keys and returns True if the nested mapping is a leaf (i.e., should not be flattened further).
- Returns
The flattened mapping as a list of (key path, value) pairs.
- eformer.pytree._tree_util.flatten_tree(xs: PyTree, is_leaf: Optional[Callable[[Any], bool]] = None, sep: str | None = None) dict[str, Any][source]#
Flatten a JAX tree and convert paths to strings.
- Parameters
xs – The JAX tree to flatten.
is_leaf – Optional function to determine leaf nodes.
sep – Separator to use when joining path elements.
- Returns
A flattened dictionary with string keys representing the tree paths.
- eformer.pytree._tree_util.int_key_to_string(xs)[source]#
Convert integer keys in a dictionary to strings.
- Parameters
xs – Dictionary possibly with integer or tuple keys.
- Returns
Dictionary with string keys.
- Return type
dict
Examples
>>> d = {(0, 1): 'value'}
>>> int_key_to_string(d)
- eformer.pytree._tree_util.is_array(element: Any) bool[source]#
Check if an element is a JAX array or NumPy array.
- Parameters
element – The object to check.
- Returns
True if element is a JAX Array, NumPy ndarray, or NumPy generic type.
- Return type
bool
Examples
>>> is_array(jnp.array([1, 2, 3]))
True
>>> is_array(np.array([1, 2, 3]))
True
>>> is_array([1, 2, 3])
False
- eformer.pytree._tree_util.is_array_like(element: Any) bool[source]#
Check if an element is array-like (arrays or scalar numeric types).
- Parameters
element – The object to check.
- Returns
True if element is an array or numeric scalar type.
- Return type
bool
Note
This includes JAX arrays, NumPy arrays, and Python numeric types (int, float, complex, bool), as well as objects with __jax_array__ attribute.
Examples
>>> is_array_like(jnp.array([1, 2]))
True
>>> is_array_like(5.0)
True
>>> is_array_like("string")
False
- eformer.pytree._tree_util.is_flatten(tree: dict) bool[source]#
Checks if a dictionary represents a flattened tree.
A flattened tree is a dictionary where the keys are tuples representing the path to the leaf nodes. This function checks if any of the keys in the input dictionary is a tuple, indicating a flattened tree.
- Parameters
tree – The dictionary to check.
- Returns
True if the dictionary is a flattened tree, False otherwise.
- Return type
bool
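The check described above is a single `any` over the keys. The sketch below is a hypothetical re-statement of that rule for illustration, not the library source.

```python
def is_flatten_sketch(tree: dict) -> bool:
    """Hypothetical stand-in: a dict counts as flattened if any key is a tuple."""
    return any(isinstance(key, tuple) for key in tree)


print(is_flatten_sketch({("layer", "w"): 1}))  # True
print(is_flatten_sketch({"layer": {"w": 1}}))  # False
```

Because the rule is "any key", a dictionary that mixes tuple and string keys is also reported as flattened.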
- eformer.pytree._tree_util.is_iterable(obj)[source]#
Check if an object is iterable.
- Parameters
obj – Object to check.
- Returns
True if object is iterable, False otherwise.
- Return type
bool
Examples
>>> is_iterable([1, 2, 3])
True
>>> is_iterable(42)
False
- eformer.pytree._tree_util.join_key(prefix, k)[source]#
Concatenate a prefix and key using dot-notation.
Creates hierarchical key paths by joining components with dots. Handles None keys and empty prefixes gracefully.
- Parameters
prefix – The prefix string (can be empty string).
k – The key to append (can be None).
- Returns
The joined key path.
- Return type
str
Examples
>>> join_key('layer', 'weight')
'layer.weight'
>>> join_key('', 'bias')
'bias'
>>> join_key('layer', None)
'layer'
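The three documented cases can be reproduced with a few lines of Python. `join_key_sketch` is a hypothetical re-implementation inferred from the examples above; the library's handling of other inputs may differ.

```python
from typing import Optional


def join_key_sketch(prefix: str, k: Optional[str]) -> str:
    """Hypothetical stand-in for the documented join_key behavior."""
    if k is None:
        return prefix      # nothing to append
    if not prefix:
        return str(k)      # avoid a leading dot when the prefix is empty
    return f"{prefix}.{k}"


print(join_key_sketch("layer", "weight"))  # layer.weight
print(join_key_sketch("", "bias"))         # bias
print(join_key_sketch("layer", None))      # layer
```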
- eformer.pytree._tree_util.key_path_to_str(path: Sequence) str[source]#
Convert a JAX key path element to a string representation.
Handles various JAX key types (SequenceKey, DictKey, GetAttrKey, FlattenedIndexKey) and converts them to readable string format.
- Parameters
path – A sequence containing JAX key path elements. Only the last element is processed.
- Returns
String representation of the last path element, or empty string if path is empty.
- Return type
str
Examples
>>> from jax._src.tree_util import DictKey, SequenceKey
>>> key_path_to_str([DictKey("weights")])
'weights'
>>> key_path_to_str([SequenceKey(0)])
'0'
- eformer.pytree._tree_util.leaf_key_paths(pytree, prefix: str | None = '', *, is_leaf: collections.abc.Callable[[Any], bool] | None = None, use_state_dict_keys: bool = False)[source]#
Return a tree mirroring pytree whose leaves are their dot-path strings.
- Parameters
pytree – The input tree to traverse.
prefix – Optional prefix added to every returned path. None resets to "".
is_leaf – Optional custom leaf predicate forwarded to jax.tree_util.tree_flatten_with_path().
use_state_dict_keys – Reserved for compatibility with other libraries; currently unused.
- Returns
A PyTree with the same structure as pytree whose leaves are strings representing the dotted traversal path, or None when pytree has no leaves.
Example
>>> tree = {"layer": {"w": 1, "b": 2}, "scale": 3}
>>> leaf_key_paths(tree)
{'layer': {'w': 'layer.w', 'b': 'layer.b'}, 'scale': 'scale'}
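To make the path mirroring concrete, here is a minimal sketch restricted to nested dictionaries. `leaf_paths_sketch` is a hypothetical helper for illustration only; the real function works on arbitrary PyTrees through JAX, and how it renders sequence indices or attribute names is not covered here.

```python
from typing import Any


def leaf_paths_sketch(tree: Any, prefix: str = "") -> Any:
    """Hypothetical stand-in: replace each leaf of a nested dict by its dotted path."""
    if isinstance(tree, dict):
        return {
            key: leaf_paths_sketch(value, f"{prefix}.{key}" if prefix else str(key))
            for key, value in tree.items()
        }
    # A leaf is replaced by the path that led to it.
    return prefix


tree = {"layer": {"w": 1, "b": 2}, "scale": 3}
print(leaf_paths_sketch(tree))
# {'layer': {'w': 'layer.w', 'b': 'layer.b'}, 'scale': 'scale'}
```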
- eformer.pytree._tree_util.merge(*pytrees: PyTree, is_leaf: Optional[Callable[[Any], bool]] = None) PyTree[source]#
Combine multiple PyTrees into a single PyTree.
Takes the first non-None value at each position across all input trees.
- Parameters
*pytrees – Variable number of PyTrees to merge.
is_leaf – Optional function to determine if a node is a leaf.
- Returns
Combined tree with first non-None value at each position.
- Return type
PyTree
Note
This is useful for combining partial trees or filling in missing values.
Examples
>>> tree1 = {"a": 1, "b": None}
>>> tree2 = {"a": None, "b": 2}
>>> merged = merge(tree1, tree2)
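The "first non-None value at each position" rule can be sketched for nested dictionaries as follows. `merge_sketch` is a hypothetical stand-in that assumes all inputs share the same nesting; it is not the library implementation, which operates on general PyTrees.

```python
from typing import Any


def merge_sketch(*trees: Any) -> Any:
    """Hypothetical stand-in: first non-None value at each position of nested dicts."""
    dicts = [t for t in trees if isinstance(t, dict)]
    if dicts:
        # Collect keys in first-seen order, then merge position by position.
        keys = list(dict.fromkeys(k for d in dicts for k in d))
        return {k: merge_sketch(*(d.get(k) for d in dicts)) for k in keys}
    # Leaf position: take the first value that is not None, if there is one.
    return next((t for t in trees if t is not None), None)


tree1 = {"a": 1, "b": None}
tree2 = {"a": None, "b": 2}
print(merge_sketch(tree1, tree2))  # {'a': 1, 'b': 2}
```

Order matters: when two trees both hold a value at the same position, the earlier argument wins.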
- eformer.pytree._tree_util.named_tree_map(f: Callable[[str, Any, Any], Any], tree: PyTree, *rest: Any, is_leaf: Optional[Callable[[Any], bool]] = None, sep: str | None = None) PyTree[source]#
An extended version of jax.tree_util.tree_map.
This function extends jax.tree_util.tree_map by providing the path (as a string) to the current leaf node as an argument to the mapped function f.
- Parameters
f – The function to apply to each leaf node, taking the path and value as input.
tree – The JAX tree to map over.
*rest – Additional arguments to be passed to f.
is_leaf – Optional function to determine leaf nodes.
sep – Separator to use when joining path elements.
- Returns
A new tree with the same structure as tree but with the values modified by f.
- eformer.pytree._tree_util.pack_pytree(tree: PyTree, dtype=jnp.float32) tuple[eformer.pytree._pytree.PyTree, jax.jaxlib._jax.Array][source]#
Pack all leaves of a pytree into a single 1-D array.
This function flattens all array leaves into a contiguous 1-D array, which is useful for optimization algorithms that work on flat parameter vectors or for efficient storage/transmission.
- Parameters
tree – Pytree of array-like objects to pack.
dtype – Desired dtype of the packed array (default: jnp.float32).
- Returns
- A pair (offset_tree, flat_array) where:
offset_tree has the same structure as tree but each leaf is replaced with a PackedLeaf containing offset and shape information.
flat_array is a 1-D array containing all leaf data.
- Return type
tuple
Examples
>>> tree = {"weights": jnp.ones((2, 3)), "bias": jnp.zeros(3)}
>>> offset_tree, packed = pack_pytree(tree)
>>> packed.shape
(9,)
>>> original = unpack_pytree(offset_tree, packed)
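The offset-and-shape bookkeeping can be illustrated without JAX, using nested lists in place of arrays. Everything below (`Leaf`, `pack`, `unpack` and the helpers) is a hypothetical pure-Python analogue for illustration. It assumes a flat dictionary of non-empty rectangular nested lists and packs leaves in insertion order, whereas JAX flattens dictionary keys in sorted order, so real offsets may differ.

```python
from dataclasses import dataclass
from math import prod
from typing import Any


@dataclass(frozen=True)
class Leaf:
    """Hypothetical analogue of PackedLeaf: where a leaf lives in the flat buffer."""
    offset: int
    shape: tuple


def shape_of(x: Any) -> tuple:
    # Shape of a non-empty rectangular nested list; scalars have shape ().
    return (len(x), *shape_of(x[0])) if isinstance(x, list) else ()


def ravel(x: Any) -> list:
    # Row-major flattening of a nested list.
    return [v for row in x for v in ravel(row)] if isinstance(x, list) else [x]


def pack(tree: dict) -> tuple:
    offsets, flat = {}, []
    for key, leaf in tree.items():
        # Record where this leaf starts before appending its values.
        offsets[key] = Leaf(offset=len(flat), shape=shape_of(leaf))
        flat.extend(float(v) for v in ravel(leaf))
    return offsets, flat


def reshape(values: list, shape: tuple) -> Any:
    if not shape:
        return values[0]
    step = prod(shape[1:])
    return [reshape(values[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


def unpack(offsets: dict, flat: list) -> dict:
    return {
        key: reshape(flat[leaf.offset:leaf.offset + prod(leaf.shape)], leaf.shape)
        for key, leaf in offsets.items()
    }


tree = {"weights": [[1, 1, 1], [1, 1, 1]], "bias": [0, 0, 0]}
offsets, flat = pack(tree)
print(len(flat))        # 9
print(offsets["bias"])  # Leaf(offset=6, shape=(3,))
restored = unpack(offsets, flat)
```

Each leaf can be recovered independently from its offset and shape, so a single leaf can be read back without unpacking the whole buffer.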
- eformer.pytree._tree_util.recursive_merge(full_tree, updates)[source]#
Recursively merge two PyTrees where updates may have fewer parameters.
- Parameters
full_tree – The complete parameter tree
updates – Tree with updated values (subset of full_tree)
- Returns
Merged tree with updated values where available
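A minimal sketch of this kind of merge for nested dictionaries is shown below. `recursive_merge_sketch` is a hypothetical stand-in; it assumes `updates` only contains keys that make sense for `full_tree` and does not claim to match the library's handling of other container types.

```python
def recursive_merge_sketch(full_tree: dict, updates: dict) -> dict:
    """Hypothetical stand-in: overlay a partial nested dict onto a complete one."""
    merged = dict(full_tree)  # shallow copy so the input is not mutated
    for key, value in updates.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are subtrees: descend instead of overwriting.
            merged[key] = recursive_merge_sketch(merged[key], value)
        else:
            merged[key] = value
    return merged


full = {"layer": {"w": 1, "b": 2}, "scale": 3}
updates = {"layer": {"w": 10}}
print(recursive_merge_sketch(full, updates))
# {'layer': {'w': 10, 'b': 2}, 'scale': 3}
```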
- eformer.pytree._tree_util.specs_to_name_sharding(tree: dict, mesh: jax._src.mesh.Mesh | None = None) dict[source]#
Converts a dictionary of specifications to a dictionary of NamedSharding objects.
- Parameters
tree (Dict) – A dictionary where the keys are names and the values are specifications.
mesh (Optional[Mesh]) – An optional Mesh object. If not provided, the default physical mesh from pxla.thread_resources.env.physical_mesh is used.
- Returns
A dictionary where the keys are the same as the input dictionary, and the values are NamedSharding objects created from the specifications and the provided or default mesh.
- Return type
Dict
- eformer.pytree._tree_util.split(pytree: PyTree, filter_spec: Union[bool, Callable[[Any], bool]], replace: Any = None, is_leaf: Optional[Callable[[Any], bool]] = None) tuple[eformer.pytree._pytree.PyTree, eformer.pytree._pytree.PyTree][source]#
Split a PyTree into two based on a filter specification.
- Parameters
pytree – The PyTree to split.
filter_spec – Either a boolean or callable that determines the split. If bool, applies uniformly. If callable, applied to each leaf.
replace – Value to use for filtered-out positions (default: None).
is_leaf – Optional function to determine leaf nodes.
- Returns
- Two PyTrees where:
First contains values where filter is True (others replaced)
Second contains values where filter is False (others replaced)
- Return type
tuple[PyTree, PyTree]
Examples
>>> tree = {"a": jnp.array([1, 2]), "b": jnp.array([3, 4])}
>>> large, small = split(tree, lambda x: x.size > 2)
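The two-way split with placeholder values can be sketched for nested dictionaries of plain Python leaves. `split_sketch` is a hypothetical helper for illustration; it only covers the callable form of `filter_spec` and uses lists where the library would see arrays.

```python
from typing import Any, Callable


def split_sketch(tree: dict, predicate: Callable[[Any], bool], replace: Any = None) -> tuple:
    """Hypothetical stand-in: route each leaf to one of two same-shaped trees."""
    kept, rest = {}, {}
    for key, value in tree.items():
        if isinstance(value, dict):
            kept[key], rest[key] = split_sketch(value, predicate, replace)
        elif predicate(value):
            kept[key], rest[key] = value, replace
        else:
            kept[key], rest[key] = replace, value
    return kept, rest


tree = {"a": [1, 2, 3], "b": [4]}
large, small = split_sketch(tree, lambda x: len(x) > 2)
print(large)  # {'a': [1, 2, 3], 'b': None}
print(small)  # {'a': None, 'b': [4]}
```

Both outputs keep the full key structure, so the two halves can later be recombined by taking the first non-None value at each position.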
- eformer.pytree._tree_util.string_key_to_int(xs)[source]#
Convert string keys in a dictionary to integers where possible.
- Parameters
xs – Dictionary with string or tuple keys.
- Returns
Dictionary with integer keys where applicable.
- Return type
dict
Examples
>>> d = {('0', '1'): 'value'}
>>> string_key_to_int(d)
- eformer.pytree._tree_util.tree_abs(tree: PyTree) PyTree[source]#
Compute absolute values of all elements in a pytree.
- Parameters
tree – Input pytree containing numerical values.
- Returns
New tree with absolute values.
- Return type
PyTree
Examples
>>> tree = {"a": jnp.array([-1, 2, -3]), "b": -4.5}
>>> result = tree_abs(tree)
- eformer.pytree._tree_util.tree_add(tree1: PyTree, tree2: PyTree) PyTree[source]#
Element-wise addition of two pytrees.
- Parameters
tree1 – First pytree.
tree2 – Second pytree (must have same structure as tree1).
- Returns
New tree with element-wise sum of values.
- Return type
PyTree
Examples
>>> tree1 = {"a": jnp.array([1, 2]), "b": jnp.array([3])}
>>> tree2 = {"a": jnp.array([4, 5]), "b": jnp.array([6])}
>>> result = tree_add(tree1, tree2)
- eformer.pytree._tree_util.tree_all(tree: PyTree) bool[source]#
Check if all values in the pytree are True.
- Parameters
tree – Input pytree containing boolean or numerical values.
- Returns
True if all elements in all arrays are True/non-zero.
- Return type
bool
Examples
>>> tree = {"a": jnp.array([True, True]), "b": jnp.array([True])}
>>> tree_all(tree)
>>> tree2 = {"x": jnp.array([1, 2]), "y": jnp.array([0])}
>>> tree_all(tree2)
- eformer.pytree._tree_util.tree_any(tree: PyTree) bool[source]#
Check if any value in the pytree is True.
- Parameters
tree – Input pytree containing boolean or numerical values.
- Returns
True if any element in any array is True/non-zero.
- Return type
bool
Examples
>>> tree = {"a": jnp.array([False, False]), "b": jnp.array([True])}
>>> tree_any(tree)
>>> tree2 = {"x": jnp.array([0, 0]), "y": jnp.array([0])}
>>> tree_any(tree2)
- eformer.pytree._tree_util.tree_apply(fns: dict[Any, Callable[[Any], Any]], tree: dict[Any, Any]) dict[Any, Any][source]#
Apply a dictionary of functions to a corresponding PyTree.
- Parameters
fns – A dictionary where keys match the PyTree structure and values are functions.
tree – The PyTree to apply functions to.
- Returns
A new PyTree with the same structure as tree, but with values modified by the functions in fns.
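As a rough illustration, the pairing of functions and values by key can be written for nested dictionaries as below. `tree_apply_sketch` is a hypothetical stand-in that assumes `fns` has an entry for every key in `tree`; how the library treats missing keys is not stated here.

```python
def tree_apply_sketch(fns: dict, tree: dict) -> dict:
    """Hypothetical stand-in: apply fns[key] to tree[key], descending into sub-dicts."""
    return {
        key: tree_apply_sketch(fns[key], value) if isinstance(value, dict) else fns[key](value)
        for key, value in tree.items()
    }


fns = {"a": lambda x: x + 1, "b": {"c": str}}
tree = {"a": 1, "b": {"c": 2}}
print(tree_apply_sketch(fns, tree))  # {'a': 2, 'b': {'c': '2'}}
```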
- eformer.pytree._tree_util.tree_bytes(tree: PyTree) int[source]#
Calculate the total memory usage of a pytree in bytes.
- Parameters
tree – Input pytree
- Returns
Total memory usage in bytes
- eformer.pytree._tree_util.tree_cast(tree: PyTree, dtype: Any) PyTree[source]#
Cast all arrays in a pytree to a specified dtype.
- Parameters
tree – Input pytree containing arrays.
dtype – Target dtype (e.g., jnp.float32, jnp.int32).
- Returns
New tree with arrays cast to the specified dtype.
- Return type
PyTree
Examples
>>> tree = {"a": jnp.array([1, 2], dtype=jnp.int32)}
>>> result = tree_cast(tree, jnp.float32)
- eformer.pytree._tree_util.tree_clip(tree: PyTree, min_val: Any = None, max_val: Any = None) PyTree[source]#
Clip values in a pytree to a specified range.
- Parameters
tree – Input pytree containing numerical arrays.
min_val – Minimum value for clipping (inclusive).
max_val – Maximum value for clipping (inclusive).
- Returns
New tree with values clipped to [min_val, max_val].
- Return type
PyTree
Examples
>>> tree = {"weights": jnp.array([-2, 0, 5, 10])}
>>> clipped = tree_clip(tree, min_val=0, max_val=5)
- eformer.pytree._tree_util.tree_concatenate(trees: list[eformer.pytree._pytree.PyTree], axis: int = 0) PyTree[source]#
Concatenate corresponding arrays in a list of PyTrees.
- Parameters
trees – List of PyTrees with matching structure.
axis – Axis along which to concatenate arrays (default: 0).
- Returns
Single tree with concatenated arrays.
- Return type
PyTree
Examples
>>> tree1 = {"a": jnp.array([1, 2])}
>>> tree2 = {"a": jnp.array([3, 4])}
>>> result = tree_concatenate([tree1, tree2])
- eformer.pytree._tree_util.tree_divide(tree1: PyTree, tree2: eformer.pytree._pytree.PyTree | Any) PyTree[source]#
Element-wise division of pytrees or scalar division.
- Parameters
tree1 – First pytree (dividend).
tree2 – Second pytree (same structure) or scalar divisor.
- Returns
New tree with element-wise or scalar quotient.
- Return type
PyTree
Examples
>>> tree = {"a": jnp.array([4.0, 6.0]), "b": jnp.array([8.0])}
>>> result1 = tree_divide(tree, 2.0)
>>> tree2 = {"a": jnp.array([2.0, 3.0]), "b": jnp.array([4.0])}
>>> result2 = tree_divide(tree, tree2)
- eformer.pytree._tree_util.tree_dot(tree1: PyTree, tree2: PyTree) Any[source]#
Compute dot product of two pytrees.
Computes the sum of element-wise products across all arrays in the trees.
- Parameters
tree1 – First pytree.
tree2 – Second pytree (must have same structure).
- Returns
Scalar value representing the dot product.
Examples
>>> tree1 = {"a": jnp.array([1, 2]), "b": jnp.array([3])}
>>> tree2 = {"a": jnp.array([4, 5]), "b": jnp.array([6])}
>>> result = tree_dot(tree1, tree2)
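The arithmetic behind the example can be checked by hand with a pure-Python sketch. `tree_dot_sketch` is a hypothetical stand-in that uses lists instead of arrays and assumes both trees share the same keys; it only illustrates the "sum of element-wise products" definition given above.

```python
def tree_dot_sketch(tree1: dict, tree2: dict):
    """Hypothetical stand-in: sum of element-wise products over matching leaves."""
    total = 0
    for key, a in tree1.items():
        b = tree2[key]
        if isinstance(a, dict):
            total += tree_dot_sketch(a, b)
        else:
            total += sum(x * y for x, y in zip(a, b))
    return total


tree1 = {"a": [1, 2], "b": [3]}
tree2 = {"a": [4, 5], "b": [6]}
# 1*4 + 2*5 + 3*6 = 32
print(tree_dot_sketch(tree1, tree2))  # 32
```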
- eformer.pytree._tree_util.tree_equal(*pytrees: PyTree, typematch: bool = False, rtol=0.0, atol=0.0) bool[source]#
Check if multiple PyTrees are equal in structure and values.
- Parameters
*pytrees – Variable number of PyTrees to compare.
typematch – If True, also check that types match exactly.
rtol – Relative tolerance for floating point comparison.
atol – Absolute tolerance for floating point comparison.
- Returns
True if all trees have same structure and equal values.
- Return type
bool
Examples
>>> tree1 = {"a": jnp.array([1.0, 2.0])}
>>> tree2 = {"a": jnp.array([1.0, 2.0])}
>>> tree_equal(tree1, tree2)
True
>>> tree3 = {"a": jnp.array([1.0, 2.1])}
>>> tree_equal(tree1, tree3, atol=0.2)
True
- eformer.pytree._tree_util.tree_exp(tree: PyTree) PyTree[source]#
Compute exponential (e^x) of all elements in a pytree.
- Parameters
tree – Input pytree containing numerical arrays.
- Returns
New tree with exponential values.
- Return type
PyTree
Examples
>>> tree = {"a": jnp.array([0.0, 1.0]), "b": jnp.array([2.0])}
>>> result = tree_exp(tree)
- eformer.pytree._tree_util.tree_expand_dims(tree: PyTree, axis: int) PyTree[source]#
Expand dimensions of arrays in a pytree.
- Parameters
tree – Input pytree containing arrays.
axis – Position in the expanded axes where the new axis is placed.
- Returns
New tree with arrays having an additional dimension.
- Return type
PyTree
Examples
>>> tree = {"a": jnp.array([1, 2, 3])}
>>> result = tree_expand_dims(tree, axis=0)
>>> result2 = tree_expand_dims(tree, axis=1)
- eformer.pytree._tree_util.tree_filter(tree: PyTree, predicate: Callable[[Any], bool]) PyTree[source]#
Filter a PyTree keeping only leaves that satisfy the predicate.
- Parameters
tree – Input PyTree to filter.
predicate – Function that returns True for leaves to keep.
- Returns
Filtered tree with same structure but only matching leaves.
- Return type
PyTree
Note
This may change the tree structure if entire branches are filtered out.
Examples
>>> tree = {"a": jnp.array([1, 2]), "b": jnp.array([3])}
>>> filtered = tree_filter(tree, lambda x: x.size > 1)
- eformer.pytree._tree_util.tree_flatten_one_level_with_keys(pytree: PyTree) tuple[list[tuple[Optional[KeyEntry], eformer.pytree._pytree.PyTree]], jaxlib._jax.pytree.PyTreeDef][source]#
Adapted from equinox.tree_flatten_one_level to return keys.
If the passed in PyTree is a leaf, it will return a single-element list with None as the key and the PyTree as the value.
- eformer.pytree._tree_util.tree_flatten_with_paths(tree: PyTree, is_leaf: Optional[Callable[[Any], bool]] = None) tuple[list[tuple[tuple, Any]], jaxlib._jax.pytree.PyTreeDef][source]#
Flattens a pytree while keeping track of paths to leaves.
This function is useful when you need both the flattened values and their locations in the original tree structure.
- Parameters
tree – Input pytree to flatten.
is_leaf – Optional function to determine if a node is a leaf.
- Returns
- A pair of (paths_and_values, treedef) where:
paths_and_values is a list of (path, value) tuples
treedef is the tree structure definition
- Return type
tuple
Examples
>>> tree = {"weights": jnp.array([1, 2]), "bias": jnp.array([3])}
>>> paths_vals, treedef = tree_flatten_with_paths(tree)
- eformer.pytree._tree_util.tree_isfinite(tree: PyTree) PyTree[source]#
Check for finite values in a pytree.
- Parameters
tree – Input pytree containing numerical arrays.
- Returns
New tree with boolean arrays indicating finite values (not NaN or infinity).
- Return type
PyTree
Examples
>>> tree = {"a": jnp.array([1.0, jnp.nan, jnp.inf, 2.0])}
>>> result = tree_isfinite(tree)
- eformer.pytree._tree_util.tree_isinf(tree: PyTree) PyTree[source]#
Check for infinite values in a pytree.
- Parameters
tree – Input pytree containing numerical arrays.
- Returns
New tree with boolean arrays indicating infinity locations.
- Return type
PyTree
Examples
>>> tree = {"a": jnp.array([1.0, jnp.inf, -jnp.inf])}
>>> result = tree_isinf(tree)
- eformer.pytree._tree_util.tree_isnan(tree: PyTree) PyTree[source]#
Check for NaN values in a pytree.
- Parameters
tree – Input pytree containing numerical arrays.
- Returns
New tree with boolean arrays indicating NaN locations.
- Return type
PyTree
Examples
>>> tree = {"a": jnp.array([1.0, jnp.nan, 3.0])}
>>> result = tree_isnan(tree)
- eformer.pytree._tree_util.tree_leaves_with_paths(tree: PyTree, is_leaf: Optional[Callable[[Any], bool]] = None) list[tuple[tuple, Any]][source]#
Returns list of (path, leaf_value) pairs in the pytree.
- Parameters
tree – Input PyTree to extract leaves from.
is_leaf – Optional function to determine if a node is a leaf.
- Returns
List of tuples where each tuple is (path, leaf_value).
- Return type
list
Examples
>>> tree = {"a": 1, "b": {"c": 2}}
>>> paths_and_vals = tree_leaves_with_paths(tree)
- eformer.pytree._tree_util.tree_log(tree: PyTree) PyTree[source]#
Compute natural logarithm of all elements in a pytree.
- Parameters
tree – Input pytree containing positive numerical arrays.
- Returns
New tree with natural logarithm values.
- Return type
PyTree
Examples
>>> tree = {"a": jnp.array([1.0, jnp.e]), "b": jnp.array([jnp.e**2])}
>>> result = tree_log(tree)
- eformer.pytree._tree_util.tree_map_with_path(f: Callable, tree: PyTree, is_leaf: Optional[Callable[[Any], bool]] = None) PyTree[source]#
Maps a function over a pytree while providing the path to each leaf.
- Parameters
f – Function that takes (path, leaf_value) as arguments. The path is a tuple of string keys representing the location in the tree.
tree – Input pytree to map over.
is_leaf – Optional function to determine if a node is a leaf.
- Returns
New tree with same structure but values transformed by f.
- Return type
PyTree
Examples
>>> tree = {"a": 1, "b": {"c": 2, "d": 3}}
>>> result = tree_map_with_path(
...     lambda path, x: f"path={path}, value={x}",
...     tree
... )
- eformer.pytree._tree_util.tree_max(tree: PyTree) Any[source]#
Find maximum value across all arrays in a pytree.
- Parameters
tree – Input pytree
- Returns
Maximum value
- eformer.pytree._tree_util.tree_mean(tree: PyTree, axis: int | None = None) eformer.pytree._pytree.PyTree | Any[source]#
Compute mean of all values in a pytree.
- Parameters
tree – Input pytree
axis – Optional axis for mean (applies to each array)
- Returns
Mean of all values
- eformer.pytree._tree_util.tree_min(tree: PyTree) Any[source]#
Find minimum value across all arrays in a pytree.
- Parameters
tree – Input pytree
- Returns
Minimum value
- eformer.pytree._tree_util.tree_multiply(tree1: PyTree, tree2: eformer.pytree._pytree.PyTree | Any) PyTree[source]#
Element-wise multiplication of pytrees or scalar multiplication.
- Parameters
tree1 – First pytree.
tree2 – Second pytree (same structure) or scalar value.
- Returns
New tree with element-wise or scalar product.
- Return type
PyTree
Examples
>>> tree = {"a": jnp.array([1, 2]), "b": jnp.array([3])}
>>> result1 = tree_multiply(tree, 2)
>>> tree2 = {"a": jnp.array([2, 3]), "b": jnp.array([4])}
>>> result2 = tree_multiply(tree, tree2)
- eformer.pytree._tree_util.tree_norm(tree: PyTree, ord: Any = 2) Any[source]#
Compute the norm of a pytree.
- Parameters
tree – Input pytree
ord – Order of the norm (default: 2 for L2 norm)
- Returns
Norm value
- eformer.pytree._tree_util.tree_ones_like(tree: PyTree) PyTree[source]#
Create a PyTree of ones with the same structure and shapes.
- Parameters
tree – Template PyTree to match structure and shapes.
- Returns
New tree with same structure but all array values set to one.
- Return type
PyTree
Examples
>>> tree = {"a": jnp.array([1.5, 2.5])}
>>> ones = tree_ones_like(tree)
- eformer.pytree._tree_util.tree_path_to_string(path: tuple[Any, ...], sep: str | None = None) str[source]#
Convert a JAX tree path to a string representation.
- Parameters
path – The JAX tree path tuple.
sep – Separator to use when joining path elements.
- Returns
The string representation of the path.
- eformer.pytree._tree_util.tree_random_like(tree: PyTree, key: PRNGKey, distribution: str = 'normal', **kwargs) PyTree[source]#
Create a pytree with random values matching the structure of input tree.
- Parameters
tree – Template pytree to match structure and shapes.
key – JAX random key for reproducible randomness.
distribution – Distribution type (‘normal’, ‘uniform’, ‘bernoulli’).
**kwargs – Additional arguments for the distribution:
- For ‘normal’: mean, std
- For ‘uniform’: minval, maxval
- For ‘bernoulli’: p (probability)
- Returns
New tree with same structure but random values.
- Return type
PyTree
Examples
>>> key = jax.random.PRNGKey(0)
>>> tree = {"weights": jnp.zeros((2, 3))}
>>> result1 = tree_random_like(tree, key, "normal")
>>> result2 = tree_random_like(tree, key, "uniform")
>>> result3 = tree_random_like(tree, key, "uniform", minval=-1, maxval=1)
- eformer.pytree._tree_util.tree_reciprocal(tree: PyTree) PyTree[source]#
Compute reciprocal (1/x) of all elements in a pytree.
- Parameters
tree – Input pytree containing numerical arrays.
- Returns
New tree with reciprocal values (1/x for each element).
- Return type
PyTree
Examples
>>> tree = {"a": jnp.array([2.0, 4.0]), "b": jnp.array([0.5])}
>>> result = tree_reciprocal(tree)
- eformer.pytree._tree_util.tree_reduce(reducer: Callable[[Any, Any], Any], tree: PyTree, initializer: Any | None = None) Any[source]#
Reduce a pytree to a single value using a reduction function.
- Parameters
reducer – Binary function to reduce values
tree – Input pytree
initializer – Initial value for reduction
- Returns
Reduced value
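A reduction over leaves can be sketched with `functools.reduce`. The helpers below are hypothetical and written for nested dictionaries only; treating `initializer=None` as "no initializer" is an assumption of this sketch, not a documented property of the library function.

```python
import operator
from functools import reduce
from typing import Any, Callable, Iterator


def iter_leaves(tree: Any) -> Iterator[Any]:
    """Yield the leaves of a nested dict in insertion order."""
    if isinstance(tree, dict):
        for value in tree.values():
            yield from iter_leaves(value)
    else:
        yield tree


def tree_reduce_sketch(reducer: Callable[[Any, Any], Any], tree: Any, initializer: Any = None) -> Any:
    """Hypothetical stand-in: fold a binary function over all leaves."""
    leaves = list(iter_leaves(tree))
    if initializer is None:
        # Assumption: None means the first leaf seeds the reduction.
        return reduce(reducer, leaves)
    return reduce(reducer, leaves, initializer)


tree = {"a": 1, "b": {"c": 2, "d": 3}}
print(tree_reduce_sketch(operator.add, tree))      # 6
print(tree_reduce_sketch(operator.add, tree, 10))  # 16
print(tree_reduce_sketch(max, tree))               # 3
```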
- eformer.pytree._tree_util.tree_replace_infs(tree: PyTree, value: Any = 0.0) PyTree[source]#
Replace infinite values in a pytree.
- Parameters
tree – Input pytree containing numerical arrays.
value – Value to replace infinities with (default: 0.0).
- Returns
New tree with infinite values replaced.
- Return type
PyTree
Examples
>>> tree = {"a": jnp.array([1.0, jnp.inf, -jnp.inf, 2.0])}
>>> result = tree_replace_infs(tree, value=999.0)
- eformer.pytree._tree_util.tree_replace_nans(tree: PyTree, value: Any = 0.0) PyTree[source]#
Replace NaN values in a pytree.
- Parameters
tree – Input pytree containing numerical arrays.
value – Value to replace NaNs with (default: 0.0).
- Returns
New tree with NaN values replaced.
- Return type
PyTree
Examples
>>> tree = {"a": jnp.array([1.0, jnp.nan, 3.0])}
>>> result = tree_replace_nans(tree, value=-1.0)
- eformer.pytree._tree_util.tree_reshape(tree: PyTree, shape: tuple[int, ...]) PyTree[source]#
Reshape arrays in a pytree to a new shape.
- Parameters
tree – Input pytree containing arrays.
shape – New shape for arrays. Use -1 for automatic dimension.
- Returns
New tree with reshaped arrays.
- Return type
PyTree
Examples
>>> tree = {"a": jnp.array([[1, 2], [3, 4]])}
>>> result = tree_reshape(tree, (4,))
>>> result2 = tree_reshape(tree, (-1, 1))
- eformer.pytree._tree_util.tree_round(tree: PyTree, decimals: int = 0) PyTree[source]#
Round all values in a pytree to a given number of decimals.
- Parameters
tree – Input pytree containing numerical arrays.
decimals – Number of decimal places to round to (default: 0).
- Returns
New tree with rounded values.
- Return type
PyTree
Examples
>>> tree = {"a": jnp.array([1.234, 5.678])}
>>> result = tree_round(tree, decimals=1)
>>> result2 = tree_round(tree)
- eformer.pytree._tree_util.tree_sign(tree: PyTree) PyTree[source]#
Compute sign of all elements in a pytree.
- Parameters
tree – Input pytree containing numerical values.
- Returns
New tree with sign values (-1, 0, or 1).
- Return type
PyTree
Examples
>>> tree = {"a": jnp.array([-2.5, 0, 3.7])}
>>> result = tree_sign(tree)
- eformer.pytree._tree_util.tree_size(tree: PyTree) int[source]#
Calculate the total number of elements in a pytree.
- Parameters
tree – Input pytree
- Returns
Total number of elements across all arrays in the tree
- eformer.pytree._tree_util.tree_sqrt(tree: PyTree) PyTree[source]#
Compute square root of all elements in a pytree.
- Parameters
tree – Input pytree containing non-negative numerical arrays.
- Returns
New tree with square root values.
- Return type
PyTree
Examples
>>> tree = {"a": jnp.array([4.0, 9.0]), "b": jnp.array([16.0])}
>>> result = tree_sqrt(tree)
- eformer.pytree._tree_util.tree_squeeze(tree: PyTree, axis: int | tuple[int, ...] | None = None) PyTree[source]#
Remove single-dimensional entries from arrays in a pytree.
- Parameters
tree – Input pytree containing arrays.
axis – Axis or axes to squeeze. If None, all axes of size 1 are removed.
- Returns
New tree with squeezed arrays.
- Return type
PyTree
Examples
>>> tree = {"a": jnp.array([[[1], [2]]])}
>>> result = tree_squeeze(tree, axis=2)
>>> tree2 = {"b": jnp.array([[[3]]])}
>>> result2 = tree_squeeze(tree2)
- eformer.pytree._tree_util.tree_stack(trees: list[eformer.pytree._pytree.PyTree], axis: int = 0) PyTree[source]#
Stack corresponding arrays in a list of PyTrees.
- Parameters
trees – List of PyTrees with matching structure.
axis – Axis along which to stack arrays (default: 0).
- Returns
Single tree with stacked arrays.
- Return type
PyTree
Examples
>>> tree1 = {"a": jnp.array([1, 2])}
>>> tree2 = {"a": jnp.array([3, 4])}
>>> result = tree_stack([tree1, tree2])
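Stacking matching leaves can be sketched with a multi-tree `tree_map`. This is an illustration under the assumption that every tree in the list shares one structure; `tree_stack_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def tree_stack_sketch(trees, axis=0):
    """Hypothetical stand-in: stack corresponding leaves of several trees."""
    # tree_map passes one leaf from each tree to the lambda, in list order.
    return jax.tree_util.tree_map(
        lambda *leaves: jnp.stack(leaves, axis=axis), *trees
    )


tree1 = {"a": jnp.array([1, 2])}
tree2 = {"a": jnp.array([3, 4])}
result = tree_stack_sketch([tree1, tree2])
print(result["a"].shape)  # a new leading axis of length 2 is added
```

With the default `axis=0`, the new axis indexes the position of each tree in the input list.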
- eformer.pytree._tree_util.tree_structure_equal(tree1: PyTree, tree2: PyTree) bool[source]#
Check if two PyTrees have the same structure.
- Parameters
tree1 – First PyTree to compare.
tree2 – Second PyTree to compare.
- Returns
True if both trees have identical structure, False otherwise.
- Return type
bool
Note
This only compares structure, not values. Trees with different values but same nesting will return True.
Examples
>>> tree1 = {"a": 1, "b": {"c": 2}}
>>> tree2 = {"a": 10, "b": {"c": 20}}
>>> tree_structure_equal(tree1, tree2)
True
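A structure-only comparison can be sketched by comparing tree definitions. The snippet below assumes that "structure" means the JAX treedef (container types, keys, and nesting); `structure_equal_sketch` is a hypothetical name.

```python
import jax


def structure_equal_sketch(tree1, tree2):
    """Hypothetical stand-in: compare tree definitions, ignoring leaf values."""
    return jax.tree_util.tree_structure(tree1) == jax.tree_util.tree_structure(tree2)


tree1 = {"a": 1, "b": {"c": 2}}
tree2 = {"a": 10, "b": {"c": 20}}
tree3 = {"a": 1, "b": [2]}  # same keys, but "b" holds a list instead of a dict

same = structure_equal_sketch(tree1, tree2)
different = structure_equal_sketch(tree1, tree3)
print(same, different)
```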
- eformer.pytree._tree_util.tree_subtract(tree1: PyTree, tree2: PyTree) PyTree[source]#
Element-wise subtraction of two pytrees.
- Parameters
tree1 – First pytree (minuend).
tree2 – Second pytree (subtrahend, must have same structure).
- Returns
New tree with element-wise difference (tree1 - tree2).
- Return type
PyTree
Examples
>>> tree1 = {"a": jnp.array([5, 7]), "b": jnp.array([9])}
>>> tree2 = {"a": jnp.array([1, 2]), "b": jnp.array([3])}
>>> result = tree_subtract(tree1, tree2)
- eformer.pytree._tree_util.tree_sum(tree: PyTree, axis: int | None = None) eformer.pytree._pytree.PyTree | Any[source]#
Sum all values in a pytree.
- Parameters
tree – Input pytree.
axis – Optional axis for the sum (applied to each array).
- Returns
Sum of all values.
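The signature returns either a PyTree or a single value, which suggests two modes. The sketch below shows one plausible reading and is an assumption, not confirmed library behavior: with `axis=None` every leaf is reduced into one scalar total, and with an explicit axis each leaf is summed separately. `tree_sum_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def tree_sum_sketch(tree, axis=None):
    """Hypothetical stand-in illustrating two possible reduction modes."""
    if axis is None:
        # Assumed behavior: one scalar covering every element of every leaf.
        return sum(jnp.sum(leaf) for leaf in jax.tree_util.tree_leaves(tree))
    # Assumed behavior: reduce each leaf along `axis`, keeping the tree structure.
    return jax.tree_util.tree_map(lambda leaf: jnp.sum(leaf, axis=axis), tree)


tree = {"a": jnp.array([[1.0, 2.0], [3.0, 4.0]]), "b": jnp.array([[5.0, 6.0]])}
total = tree_sum_sketch(tree)
per_leaf = tree_sum_sketch(tree, axis=0)
print(total, per_leaf)
```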
- eformer.pytree._tree_util.tree_transpose(tree: PyTree, axes: tuple[int, ...] | None = None) PyTree[source]#
Transpose arrays in a pytree.
- Parameters
tree – Input pytree containing arrays.
axes – Permutation of axes. If None, reverses axis order.
- Returns
New tree with transposed arrays.
- Return type
PyTree
Examples
>>> tree = {"matrix": jnp.array([[1, 2], [3, 4]])}
>>> result = tree_transpose(tree)
>>> tensor = {"data": jnp.ones((2, 3, 4))}
>>> result = tree_transpose(tensor, axes=(2, 0, 1))
- eformer.pytree._tree_util.tree_where(condition: PyTree, x: PyTree, y: PyTree) PyTree[source]#
Element-wise where operation on PyTrees.
- Parameters
condition – PyTree of boolean conditions.
x – PyTree of values to select when condition is True.
y – PyTree of values to select when condition is False.
- Returns
Tree with selected values based on conditions.
- Return type
PyTree
Examples
>>> cond = {"a": jnp.array([True, False])}
>>> x = {"a": jnp.array([1, 2])}
>>> y = {"a": jnp.array([3, 4])}
>>> result = tree_where(cond, x, y)
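The selection described above can be sketched as `jnp.where` applied leaf by leaf. This assumes the three trees share one structure; `tree_where_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def tree_where_sketch(condition, x, y):
    """Hypothetical stand-in: apply jnp.where to matching leaves."""
    # Each call receives one leaf from condition, x, and y, in that order.
    return jax.tree_util.tree_map(jnp.where, condition, x, y)


cond = {"a": jnp.array([True, False])}
x = {"a": jnp.array([1, 2])}
y = {"a": jnp.array([3, 4])}
result = tree_where_sketch(cond, x, y)
print(result["a"])  # first element comes from x, second from y
```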
- eformer.pytree._tree_util.tree_zeros_like(tree: PyTree) PyTree[source]#
Create a PyTree of zeros with the same structure and shapes.
- Parameters
tree – Template PyTree to match structure and shapes.
- Returns
New tree with same structure but all array values set to zero.
- Return type
PyTree
Examples
>>> tree = {"a": jnp.array([1.5, 2.5])}
>>> zeros = tree_zeros_like(tree)
- eformer.pytree._tree_util.unflatten_dict(xs, sep=None)[source]#
Unflatten a dictionary with tuple or string keys.
- Parameters
xs – Flattened dictionary with tuple or separated string keys.
sep – Optional separator for string keys.
- Returns
Nested dictionary structure.
- Return type
dict
Examples
>>> flat = {('a', 'b'): 1, ('a', 'c'): 2}
>>> unflatten_dict(flat)
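The unflattening logic can be sketched in pure Python. This is an illustration only: `unflatten_dict_sketch` is a hypothetical name, and it assumes keys are tuples when `sep` is None and separator-joined strings otherwise.

```python
def unflatten_dict_sketch(xs, sep=None):
    """Hypothetical stand-in: rebuild a nested dict from flat path keys."""
    result = {}
    for key, value in xs.items():
        # String keys are split on `sep`; tuple keys are used as the path directly.
        path = key.split(sep) if sep is not None else key
        cursor = result
        for part in path[:-1]:
            # Descend into the nested dict, creating levels as needed.
            cursor = cursor.setdefault(part, {})
        cursor[path[-1]] = value
    return result


nested_from_tuples = unflatten_dict_sketch({("a", "b"): 1, ("a", "c"): 2})
nested_from_strings = unflatten_dict_sketch({"a.b": 1, "a.c": 2}, sep=".")
print(nested_from_tuples)
print(nested_from_strings)
```

Both calls build the same nested dictionary, because the string keys split into the same paths as the tuple keys.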
- eformer.pytree._tree_util.unflatten_mapping(xs: Sequence[tuple[tuple[Any, ...], Any]], /, *, sep: None = None) dict[Any, Any][source]#
- eformer.pytree._tree_util.unflatten_mapping(xs: Mapping[tuple[Any, ...], Any], /, *, sep: None = None) dict[Any, Any]
- eformer.pytree._tree_util.unflatten_mapping(xs: Mapping[str, Any], /, *, sep: str) dict[Any, Any]
Unflatten a mapping.
See flatten_mapping().
Example:
>>> from flax import nnx
>>> flat_xs = {
...     ('foo',): 1,
...     ('bar', 'a'): 2,
... }
>>> xs = nnx.traversals.unflatten_mapping(flat_xs)
>>> xs
{'foo': 1, 'bar': {'a': 2}}
- Parameters
xs – a flattened mapping.
sep – separator (same as used with flatten_mapping()).
- Returns
The nested mapping.
- eformer.pytree._tree_util.unpack_pytree(offset_tree: PyTree, packed: Array) PyTree[source]#
Reconstruct a pytree from its packed representation.
This is the inverse operation of pack_pytree(). It uses the offset and shape information stored in offset_tree to extract and reshape data from the packed array.
- Parameters
offset_tree – Tree of PackedLeaf objects from pack_pytree.
packed – The 1-D array containing packed leaf data.
- Returns
Reconstructed tree with original structure and array shapes.
- Return type
PyTree
Examples
>>> tree = {"weights": jnp.ones((2, 3)), "bias": jnp.zeros(3)}
>>> offset_tree, packed = pack_pytree(tree)
>>> reconstructed = unpack_pytree(offset_tree, packed)
>>> jnp.allclose(tree["weights"], reconstructed["weights"])
True
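To make the offset-and-shape bookkeeping concrete, here is a minimal round-trip sketch in plain JAX. `PackedLeafSketch`, `pack_sketch`, and `unpack_sketch` are hypothetical names that mirror the documented fields (`offset`, `shape`); the library's actual implementation may differ.

```python
import math
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class PackedLeafSketch:
    """Hypothetical mirror of PackedLeaf: where a leaf lives in the packed array."""
    offset: int
    shape: tuple


def pack_sketch(tree):
    """Flatten every leaf into one 1-D array and record offsets and shapes."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    metas, offset = [], 0
    for leaf in leaves:
        metas.append(PackedLeafSketch(offset, tuple(leaf.shape)))
        offset += leaf.size
    packed = jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])
    return jax.tree_util.tree_unflatten(treedef, metas), packed


def unpack_sketch(offset_tree, packed):
    """Slice each leaf back out of the packed array and restore its shape."""
    def extract(meta):
        size = math.prod(meta.shape)
        return packed[meta.offset : meta.offset + size].reshape(meta.shape)

    # The unregistered dataclass instances are treated as leaves by tree_map.
    return jax.tree_util.tree_map(extract, offset_tree)


tree = {"weights": jnp.arange(6.0).reshape(2, 3), "bias": jnp.zeros(3)}
offset_tree, packed = pack_sketch(tree)
restored = unpack_sketch(offset_tree, packed)
print(packed.shape, restored["weights"].shape)
```

Storing only an offset and a shape per leaf keeps the metadata tree small, while the packed array holds all numerical data in one contiguous buffer.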