Source code for eformer.escale.partition.manager

# Copyright 2026 The EasyDeL/eFormer Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This module provides classes and functions for managing JAX sharding configurations
and applying sharding constraints within a context.

It includes the `PartitionAxis` class for defining logical-to-physical axis mappings
and the `PartitionManager` context manager for applying these rules.
"""

import contextvars
import dataclasses
import hashlib
import threading
import typing as tp

import jax
from jax.sharding import PartitionSpec

from eformer.common_types import (
    BATCH,
    BIAS_HEAD_SEQ,
    BIAS_KV_SEQ,
    DATA_PARALLEL,
    EMBED,
    EMPTY,
    EXPERT,
    EXPERT_GATE,
    EXPERT_PARALLEL,
    FULLY_SHARDED_DATA_PARALLEL,
    GENERATION_MODES,
    HEAD,
    HEAD_DIM,
    KV_HEAD,
    KV_HEAD_DIM,
    KV_LENGTH,
    LENGTH,
    MLP_INTERMEDIATE,
    MODE_DECODE,
    MODE_TRAIN,
    NOT_GIVEN,
    QUERY_LENGTH,
    RUNTIME_MODE_TYPES,
    SEQUENCE_PARALLEL,
    TENSOR_PARALLEL,
    VOCAB,
    AxisType,
    DynamicShardingAxes,
)
from eformer.pytree import PyTree, xTree

from .constraints import get_corrected_named_sharding, with_sharding_constraint

# Context-local slot holding the manager activated by `with PartitionManager(...)`.
# `PartitionManager.__enter__` sets it and `__exit__` resets it; it is read by
# `get_current_partition_manager()`. Defaults to None when no manager is active.
_CURRENT_PARTITION_MANAGER = contextvars.ContextVar("_CURRENT_PARTITION_MANAGER", default=None)
# Process-global reference to the most recently constructed or entered manager.
# Written by `PartitionManager.__post_init__` / `__enter__` and returned by
# `get_partition_manager()`; unlike the context variable it is never reset.
_LAST_PARTITION_MANAGER: tp.Any = None


def _to_hashable(value: tp.Any) -> tp.Any:
    """Convert nested structures and dataclass-like objects to hashable tuples."""
    if dataclasses.is_dataclass(value):
        value = {field.name: getattr(value, field.name) for field in dataclasses.fields(value)}

    if isinstance(value, dict):
        return tuple(sorted((str(k), _to_hashable(v)) for k, v in value.items()))
    if isinstance(value, list | tuple):
        return tuple(_to_hashable(v) for v in value)
    if isinstance(value, set):
        return tuple(sorted(_to_hashable(v) for v in value))
    if hasattr(value, "__dict__"):
        return (value.__class__.__qualname__, _to_hashable(vars(value)))

    try:
        hash(value)
    except TypeError:
        return repr(value)
    return value


def hash_fn(self) -> int:
    """Hash an object from its class name plus its dataclass fields (or `vars` fallback).

    For dataclass instances only fields declared with ``compare=True`` take part
    in the hash; for any other object the full instance ``__dict__`` is used.
    Values are normalised through `_to_hashable` before hashing.

    Returns:
        The integer hash of ``(class qualname, normalised payload)``.
    """
    type_name = self.__class__.__qualname__

    if not dataclasses.is_dataclass(self):
        return hash((type_name, _to_hashable(vars(self))))

    compared_fields = [fld for fld in dataclasses.fields(self) if fld.compare]
    payload = tuple((fld.name, _to_hashable(getattr(self, fld.name))) for fld in compared_fields)
    return hash((type_name, payload))
def get_safe_hash_int(text, algorithm="md5"):
    """Convert text to an integer hash using the specified algorithm.

    Provides a safe way to generate integer hashes from strings, useful for
    creating hashable keys from complex objects.

    Args:
        text: The text to hash. Will be converted to string if not already.
        algorithm: Name of the hash algorithm, resolved through ``hashlib.new``.
            Defaults to "md5". Must be a fixed-length digest (e.g. "sha256",
            "sha1"); variable-length digests such as "shake_128" are rejected.

    Returns:
        An integer representation of the hash digest (big-endian).

    Raises:
        ValueError: If the specified algorithm is not a supported fixed-length
            hashlib algorithm.

    Example:
        >>> hex(get_safe_hash_int("hello world"))
        '0x5eb63bbbe01eeed093cb22bb8f5acdc3'
        >>> get_safe_hash_int("hello world", algorithm="sha256")
        ... # Different integer value
    """
    payload = str(text).encode()
    try:
        # `hashlib.new` only accepts real algorithm names, unlike `getattr(hashlib, ...)`
        # which would also resolve helpers such as `new` or `pbkdf2_hmac`.
        digest = hashlib.new(algorithm, payload).digest()
    except (TypeError, ValueError) as e:
        # ValueError: unknown algorithm; TypeError: non-string name or a digest
        # that requires an explicit length (shake_*).
        raise ValueError(f"Unsupported hash algorithm: {algorithm}") from e
    return int.from_bytes(digest, byteorder="big")
class PartitionAxis(xTree):
    """
    Configuration for partitioning model axes across a device mesh.

    Defines the mesh dimension names for standard parallelism strategies and maps
    logical model axes to these dimensions. Allows overriding defaults.

    Mesh Dimensions Attributes:
        data_parallel_axis: Name for data parallel mesh dim. Default: "dp".
        fully_sharded_data_parallel_axis: Name for FSDP mesh dim. Default: "fsdp".
        tensor_parallel_axis: Name for tensor parallel mesh dim. Default: "tp".
        sequence_parallel_axis: Name for sequence parallel mesh dim. Default: "sp".
        expert_parallel_axis: Name for expert parallel mesh dim (MoE). Default: "ep".

    Logical Model Axes Attributes:
        Maps logical tensor axes (like batch, sequence, hidden) to one or more
        mesh dimension names defined above, or None if not partitioned.
        Fields left as ``NOT_GIVEN`` are filled in by ``__post_init__`` from the
        mesh dimension names; fields defaulting to ``None`` stay unpartitioned
        unless set explicitly. For example, `head_axis` defaults to the value of
        `tensor_parallel_axis` ('tp').

        batch_axis: Mesh axis for the batch dimension. Default: (fsdp, dp).
        sequence_axis: Mesh axis for the general sequence length dimension.
        query_sequence_axis: Mesh axis for the query sequence length dimension.
        head_axis: Mesh axis for the attention head dimension.
        kv_head_axis: Mesh axis for the key/value head dimension.
        key_sequence_axis: Mesh axis for the key/value sequence length dimension.
        hidden_state_axis: Mesh axis for the embedding or hidden state dimension.
        mlp_intermediate_axis: Mesh axis for the intermediate dimension in MLP layers.
        vocab_axis: Mesh axis for the vocabulary dimension.
        expert_axis: Mesh axis for the expert dimension.
        expert_gate_axis: Mesh axis for the expert gate dimension.
        attention_dim_axis: Mesh axis for the dimension within each attention head.
        attention_kv_dim_axis: Mesh axis for the dimension within each key/value head.
        bias_head_sequence_axis: Mesh axis for bias related to head and sequence dimensions.
        bias_key_sequence_axis: Mesh axis for bias related to key/value sequence dimensions.

        decode_batch_axis: Mesh axis for the batch dimension during decoding.
            Defaults to the resolved `batch_axis`.
        decode_query_sequence_axis: Mesh axis for the query sequence length during decoding.
        decode_head_axis: Mesh axis for the attention head dimension during decoding.
            Defaults to the resolved `head_axis`.
        decode_kv_head_axis: Mesh axis for the key/value head dimension during decoding.
            Defaults to the resolved `kv_head_axis`.
        decode_key_sequence_axis: Mesh axis for the key/value sequence length during decoding.
            Defaults to the resolved `key_sequence_axis`.
        decode_attention_dim_axis: Mesh axis for the dimension within each attention head
            during decoding.
        decode_attention_kv_dim_axis: Mesh axis for the dimension within each key/value head
            during decoding.
    """

    data_parallel_axis: str = "dp"
    fully_sharded_data_parallel_axis: str = "fsdp"
    tensor_parallel_axis: str = "tp"
    sequence_parallel_axis: str = "sp"
    expert_parallel_axis: str = "ep"

    batch_axis: AxisType = NOT_GIVEN
    sequence_axis: AxisType = NOT_GIVEN
    query_sequence_axis: AxisType = NOT_GIVEN
    head_axis: AxisType = NOT_GIVEN
    kv_head_axis: AxisType = NOT_GIVEN
    key_sequence_axis: AxisType = NOT_GIVEN
    hidden_state_axis: AxisType = NOT_GIVEN
    mlp_intermediate_axis: AxisType = NOT_GIVEN
    vocab_axis: AxisType = NOT_GIVEN
    expert_axis: AxisType = NOT_GIVEN
    expert_gate_axis: AxisType = None

    attention_dim_axis: AxisType = None
    attention_kv_dim_axis: AxisType = None
    bias_head_sequence_axis: AxisType = None
    bias_key_sequence_axis: AxisType = None

    decode_batch_axis: AxisType = NOT_GIVEN
    decode_query_sequence_axis: AxisType = None
    decode_head_axis: AxisType = NOT_GIVEN
    decode_kv_head_axis: AxisType = NOT_GIVEN
    decode_key_sequence_axis: AxisType = NOT_GIVEN
    decode_attention_dim_axis: AxisType = None
    decode_attention_kv_dim_axis: AxisType = None

    _SEMANTIC_MAP: tp.ClassVar[dict[str, str]] = {
        BATCH: "batch_axis",
        LENGTH: "sequence_axis",
        QUERY_LENGTH: "query_sequence_axis",
        KV_LENGTH: "key_sequence_axis",
        EMBED: "hidden_state_axis",
        HEAD: "head_axis",
        KV_HEAD: "kv_head_axis",
        MLP_INTERMEDIATE: "mlp_intermediate_axis",
        VOCAB: "vocab_axis",
        EXPERT: "expert_axis",
        EXPERT_GATE: "expert_gate_axis",
        HEAD_DIM: "attention_dim_axis",
        KV_HEAD_DIM: "attention_kv_dim_axis",
        BIAS_HEAD_SEQ: "bias_head_sequence_axis",
        BIAS_KV_SEQ: "bias_key_sequence_axis",
        EMPTY: None,
        DATA_PARALLEL: "data_parallel_axis",
        FULLY_SHARDED_DATA_PARALLEL: "fully_sharded_data_parallel_axis",
        TENSOR_PARALLEL: "tensor_parallel_axis",
        SEQUENCE_PARALLEL: "sequence_parallel_axis",
        EXPERT_PARALLEL: "expert_parallel_axis",
    }
    """
    Maps semantic axis name constants (e.g., BATCH) to their corresponding
    attribute names in the PartitionAxis class (e.g., "batch_axis").
    """

    _STANDARD_TO_GENERATION_ATTR_MAP: tp.ClassVar[dict[str, str]] = {
        "batch_axis": "decode_batch_axis",
        "query_sequence_axis": "decode_query_sequence_axis",
        "key_sequence_axis": "decode_key_sequence_axis",
        "head_axis": "decode_head_axis",
        "kv_head_axis": "decode_kv_head_axis",
        "attention_dim_axis": "decode_attention_dim_axis",
        "attention_kv_dim_axis": "decode_attention_kv_dim_axis",
    }
    """
    Maps standard axis attribute names to their corresponding generation-specific
    attribute names. Used to apply different sharding rules during generation modes.
    """

    # Process-global registry of custom semantic axes; mutations go through
    # `register` / `unregister` / `clear_registered_axes` under `_REGISTRY_LOCK`.
    _REGISTRY_LOCK: tp.ClassVar[threading.RLock] = threading.RLock()
    _REGISTERED_SEMANTIC_MAP: tp.ClassVar[dict[str, tp.Any]] = {}
    _REGISTERED_GENERATION_MAP: tp.ClassVar[dict[str, tp.Any]] = {}

    @classmethod
    def register(
        cls,
        semantic_axis: str,
        axis_rule: tp.Any,
        *,
        generation_axis_rule: tp.Any = NOT_GIVEN,
        override: bool = False,
    ) -> None:
        """Register a semantic axis mapping globally.

        This updates a process-global registry used by all current/future
        ``PartitionAxis`` instances. Mapping values may be:
        - an attribute name on ``PartitionAxis`` (for example ``"head_axis"``),
        - a literal mesh axis name (for example ``"tp"``),
        - a tuple/list of either of the above,
        - or ``None``.

        Args:
            semantic_axis: Semantic axis token to register. It is converted to
                ``str`` and stripped of surrounding whitespace.
            axis_rule: Resolution rule for standard/train mode.
            generation_axis_rule: Optional explicit rule for generation modes.
                If omitted and ``axis_rule`` maps to a known standard attribute,
                the corresponding decode attribute mapping is inferred; otherwise
                any previously registered generation rule for the name is removed.
            override: If ``False``, raises when ``semantic_axis`` already exists
                in built-in or custom maps. Set ``True`` to replace existing rules.

        Raises:
            ValueError: If the name is empty, or it already exists and
                ``override`` is ``False``.
        """
        name = str(semantic_axis).strip()
        if not name:
            raise ValueError("`semantic_axis` must be a non-empty string.")

        with cls._REGISTRY_LOCK:
            is_existing_builtin = name in cls._SEMANTIC_MAP
            is_existing_custom = name in cls._REGISTERED_SEMANTIC_MAP
            if not override and (is_existing_builtin or is_existing_custom):
                raise ValueError(f"Semantic axis '{name}' already exists. Use override=True to replace it.")

            cls._REGISTERED_SEMANTIC_MAP[name] = axis_rule

            if generation_axis_rule is NOT_GIVEN:
                inferred = None
                if isinstance(axis_rule, str):
                    inferred = cls._STANDARD_TO_GENERATION_ATTR_MAP.get(axis_rule)
                if inferred is not None:
                    cls._REGISTERED_GENERATION_MAP[name] = inferred
                else:
                    # Drop a stale generation rule left by an earlier registration.
                    cls._REGISTERED_GENERATION_MAP.pop(name, None)
            else:
                cls._REGISTERED_GENERATION_MAP[name] = generation_axis_rule

    @classmethod
    def unregister(cls, semantic_axis: str, *, missing_ok: bool = True) -> None:
        """Remove a previously registered semantic axis mapping.

        Args:
            semantic_axis: Semantic axis token to remove (stringified and stripped).
            missing_ok: If ``False``, raise when nothing was registered under the name.

        Raises:
            ValueError: If the name is empty.
            KeyError: If the name is not registered and ``missing_ok`` is ``False``.
        """
        name = str(semantic_axis).strip()
        if not name:
            raise ValueError("`semantic_axis` must be a non-empty string.")

        with cls._REGISTRY_LOCK:
            removed = False
            if name in cls._REGISTERED_SEMANTIC_MAP:
                cls._REGISTERED_SEMANTIC_MAP.pop(name, None)
                removed = True
            if name in cls._REGISTERED_GENERATION_MAP:
                cls._REGISTERED_GENERATION_MAP.pop(name, None)
                removed = True
            if not removed and not missing_ok:
                raise KeyError(f"Semantic axis '{name}' is not registered.")

    @classmethod
    def clear_registered_axes(cls) -> None:
        """Clear all custom semantic axis registrations."""
        with cls._REGISTRY_LOCK:
            cls._REGISTERED_SEMANTIC_MAP.clear()
            cls._REGISTERED_GENERATION_MAP.clear()

    @classmethod
    def get_registered_axes(cls) -> dict[str, dict[str, tp.Any]]:
        """Return a snapshot of globally registered custom axis mappings.

        Returns:
            A new dict keyed by semantic axis name. Each value holds ``"axis_rule"``
            and ``"generation_axis_rule"`` (``NOT_GIVEN`` when no generation rule
            is registered for that name).
        """
        with cls._REGISTRY_LOCK:
            return {
                name: {
                    "axis_rule": cls._REGISTERED_SEMANTIC_MAP[name],
                    "generation_axis_rule": cls._REGISTERED_GENERATION_MAP.get(name, NOT_GIVEN),
                }
                for name in cls._REGISTERED_SEMANTIC_MAP
            }

    @classmethod
    def _lookup_semantic_mapping(cls, semantic_axis: str) -> tp.Any:
        """Lookup semantic mapping from custom registry, then built-ins."""
        if semantic_axis in cls._REGISTERED_SEMANTIC_MAP:
            return cls._REGISTERED_SEMANTIC_MAP[semantic_axis]
        return cls._SEMANTIC_MAP.get(semantic_axis)

    @classmethod
    def _lookup_generation_mapping(cls, semantic_axis: str) -> tp.Any:
        """Lookup generation mapping from custom registry (``NOT_GIVEN`` if absent)."""
        return cls._REGISTERED_GENERATION_MAP.get(semantic_axis, NOT_GIVEN)

    def _resolve_axis_rule(self, axis_rule: tp.Any, _visited: set[str] | None = None) -> tp.Any:
        """Resolve rule references (attribute names or semantic aliases) to concrete axis rules.

        Lists and tuples are resolved element-wise (keeping their container type).
        A string is resolved, in order, as an attribute of this instance, then as
        a registered/built-in semantic alias (followed recursively), and finally
        returned unchanged as a literal mesh axis name. Non-string, non-sequence
        values are returned as-is.

        Raises:
            ValueError: If semantic aliases reference each other cyclically.
        """
        if isinstance(axis_rule, list):
            return [
                self._resolve_axis_rule(
                    item,
                    # Each element gets its own copy so siblings do not share cycle state.
                    _visited=set(_visited) if _visited is not None else None,
                )
                for item in axis_rule
            ]
        if isinstance(axis_rule, tuple):
            return tuple(
                self._resolve_axis_rule(
                    item,
                    _visited=set(_visited) if _visited is not None else None,
                )
                for item in axis_rule
            )
        if isinstance(axis_rule, str):
            if hasattr(self, axis_rule):
                return getattr(self, axis_rule)

            mapped = self._lookup_semantic_mapping(axis_rule)
            if mapped is not None:
                visited = set() if _visited is None else set(_visited)
                if axis_rule in visited:
                    raise ValueError(f"Cyclic semantic axis registration detected at '{axis_rule}'.")
                visited.add(axis_rule)
                return self._resolve_axis_rule(mapped, _visited=visited)
        return axis_rule

    def __post_init__(self):
        """
        Post-initialization hook to resolve default axis values.

        If an axis attribute is set to NOT_GIVEN, its value is resolved based on
        default logic, typically using the standard mesh dimension names. Decode
        axes fall back to the already-resolved standard axes.
        """
        resolved_values = {}

        def resolve_field(name, default_logic):
            """Helper to resolve a single field's value if it's NOT_GIVEN."""
            current_value = getattr(self, name)
            if current_value is NOT_GIVEN:
                resolved_values[name] = default_logic()
            elif name not in resolved_values:
                resolved_values[name] = current_value

        def get_resolved(name):
            """Helper to get a field's value, prioritizing resolved values."""
            return resolved_values.get(name, getattr(self, name))

        resolve_field(
            "batch_axis",
            lambda: (self.fully_sharded_data_parallel_axis, self.data_parallel_axis),
        )
        resolve_field("sequence_axis", lambda: self.sequence_parallel_axis)
        resolve_field("query_sequence_axis", lambda: self.sequence_parallel_axis)
        resolve_field("head_axis", lambda: self.tensor_parallel_axis)
        resolve_field("kv_head_axis", lambda: self.tensor_parallel_axis)
        resolve_field("key_sequence_axis", lambda: self.sequence_parallel_axis)
        resolve_field("hidden_state_axis", lambda: self.tensor_parallel_axis)
        resolve_field("mlp_intermediate_axis", lambda: self.tensor_parallel_axis)
        resolve_field("vocab_axis", lambda: self.tensor_parallel_axis)
        resolve_field("expert_axis", lambda: self.expert_parallel_axis)

        # Decode axes must come after the standard ones they default to.
        resolve_field("decode_batch_axis", lambda: get_resolved("batch_axis"))
        resolve_field("decode_head_axis", lambda: get_resolved("head_axis"))
        resolve_field("decode_kv_head_axis", lambda: get_resolved("kv_head_axis"))
        resolve_field("decode_key_sequence_axis", lambda: get_resolved("key_sequence_axis"))

        for fld in dataclasses.fields(self):
            if fld.name not in resolved_values and fld.name not in [
                "_SEMANTIC_MAP",
                "_STANDARD_TO_GENERATION_ATTR_MAP",
            ]:
                resolved_values[fld.name] = getattr(self, fld.name)

        # `object.__setattr__` is used so assignment works regardless of how the
        # base class restricts attribute setting.
        for name, value in resolved_values.items():
            object.__setattr__(self, name, value)

        self._safety_check()

    def _safety_check(self):
        """
        Checks if any axis attribute still has the NOT_GIVEN value after resolution.

        Raises:
            ValueError: If any attribute is still NOT_GIVEN, indicating a configuration error.
        """
        for fld in dataclasses.fields(self):
            if fld.name not in ["_SEMANTIC_MAP", "_STANDARD_TO_GENERATION_ATTR_MAP"]:
                val = getattr(self, fld.name)
                # Identity check, consistent with every other NOT_GIVEN test in this
                # module; `==` would defer to the value's own `__eq__`.
                if val is NOT_GIVEN:
                    raise ValueError(f"Partitioning rule `{fld.name}` was not resolved.")

    def _generation_rule(self, semantic_axis: tp.Any, standard_rule: tp.Any) -> tp.Any:
        """Return the generation-mode rule for one axis, or NOT_GIVEN to keep `standard_rule`.

        A custom generation rule registered for `semantic_axis` wins. Otherwise,
        when `standard_rule` names a standard attribute with a decode counterpart
        whose value on this instance is neither None nor NOT_GIVEN, the decode
        attribute name is returned.
        """
        custom_rule = self._lookup_generation_mapping(semantic_axis)
        if custom_rule is not NOT_GIVEN:
            return custom_rule
        if isinstance(standard_rule, str):
            gen_attr_name = self._STANDARD_TO_GENERATION_ATTR_MAP.get(standard_rule)
            if gen_attr_name and hasattr(self, gen_attr_name):
                gen_val = getattr(self, gen_attr_name)
                if gen_val is not None and gen_val is not NOT_GIVEN:
                    return gen_attr_name
        return NOT_GIVEN

    def resolve_axis(
        self,
        axes: tp.Sequence[str | None],
        mode: RUNTIME_MODE_TYPES,  # type:ignore
    ) -> list[AxisType]:
        """
        Resolves a sequence of semantic axis names into mesh axis rules for a mode.

        Maps a sequence of semantic axis name strings (like BATCH, LENGTH) to the
        actual mesh axis names defined in this `PartitionAxis` instance, considering
        the current runtime mode (e.g., training vs. generation).

        Args:
            axes: A sequence of semantic axis name strings (e.g., [BATCH, LENGTH, HEAD])
                  or None (or "_") for axes that shouldn't be sharded. An entry may
                  also be a list/tuple mixing semantic names and literal mesh names.
            mode: The current operational mode (e.g., MODE_TRAIN, MODE_DECODE)
                  which determines if generation-specific rules should be applied.

        Returns:
            A list with one resolved mesh axis rule per entry of `axes`.

        Raises:
            ValueError: If an unknown semantic axis name is encountered or if a resolved
                        axis rule is still NOT_GIVEN (should be caught by `_safety_check`
                        but included for robustness).
        """
        resolved_rules: list[AxisType] = []
        use_generation_rules = mode in GENERATION_MODES

        for axis_name in axes:
            if axis_name is None or axis_name == "_":
                resolved_rules.append(None)
                continue

            is_composite = isinstance(axis_name, (list, tuple))

            # Composite axis rules can include direct mesh names and/or semantic names.
            if is_composite:
                standard_rule = []
                for sub_axis in axis_name:
                    sub_mapped = self._lookup_semantic_mapping(sub_axis)
                    standard_rule.append(sub_axis if sub_mapped is None else sub_mapped)
            else:
                standard_rule = self._lookup_semantic_mapping(axis_name)
                if standard_rule is None:
                    raise ValueError(f"Unknown semantic axis name: '{axis_name}'")

            target_rule = standard_rule
            if use_generation_rules:
                if is_composite:
                    # Swap in the generation rule per element; elements without
                    # one keep their standard rule.
                    gen_composite = []
                    for sub_axis, sub_standard in zip(axis_name, standard_rule):
                        sub_gen = self._generation_rule(sub_axis, sub_standard)
                        gen_composite.append(sub_standard if sub_gen is NOT_GIVEN else sub_gen)
                    target_rule = gen_composite
                else:
                    gen_rule = self._generation_rule(axis_name, standard_rule)
                    if gen_rule is not NOT_GIVEN:
                        target_rule = gen_rule

            mesh_axis_rule = self._resolve_axis_rule(target_rule)
            if mesh_axis_rule is NOT_GIVEN:
                raise ValueError(f"Resolved axis rule for '{axis_name}' is still NOT_GIVEN.")

            resolved_rules.append(mesh_axis_rule)
        return resolved_rules

    def resolve_spec(
        self,
        axes: tp.Sequence[str | None],
        mode: RUNTIME_MODE_TYPES,  # type:ignore
    ) -> PartitionSpec:
        """
        Generates a PartitionSpec from a sequence of semantic axis names and a mode.

        Maps a sequence of semantic axis name strings (like BATCH, LENGTH) to the
        actual mesh axis names defined in this `PartitionAxis` instance, considering
        the current runtime mode (e.g., training vs. generation).

        Args:
            axes: A sequence of semantic axis name strings (e.g., [BATCH, LENGTH, HEAD])
                  or None (or "_") for axes that shouldn't be sharded.
            mode: The current operational mode (e.g., MODE_TRAIN, MODE_DECODE)
                  which determines if generation-specific rules should be applied.

        Returns:
            A jax.sharding.PartitionSpec instance representing the sharding for
            the given sequence of axes.

        Raises:
            ValueError: If an unknown semantic axis name is encountered or if a resolved
                        axis rule is still NOT_GIVEN (should be caught by `_safety_check`
                        but included for robustness).
        """
        return PartitionSpec(*self.resolve_axis(axes=axes, mode=mode))

    __hash__ = hash_fn
class PartitionManager(PyTree):
    """
    Context manager for applying sharding constraints using PartitionAxis.

    This class acts as a context manager (`with PartitionManager(...)`) to set a
    context-local variable (`_CURRENT_PARTITION_MANAGER`) that makes the current
    manager implicitly available via functions like `get_current_partition_manager()`.
    Constructing or entering a manager also records it as the module-wide
    "last" manager returned by `get_partition_manager()`.

    The same instance may be entered more than once (nested `with` blocks); each
    exit restores the value that was active before the matching enter.

    Args:
        paxis: The PartitionAxis instance defining the sharding strategy to be
               used within this context.
    """

    paxis: PartitionAxis

    def __post_init__(self):
        """Validate `paxis` and record this instance as the last created manager.

        Raises:
            TypeError: If `paxis` is not a `PartitionAxis`.
        """
        global _LAST_PARTITION_MANAGER
        if not isinstance(self.paxis, PartitionAxis):
            raise TypeError(f"Expected PartitionAxis, got {type(self.paxis)}")
        _LAST_PARTITION_MANAGER = self

    def __enter__(self):
        """Activate this manager for the current context and return it."""
        global _LAST_PARTITION_MANAGER
        token = _CURRENT_PARTITION_MANAGER.set(self)
        # Tokens are kept on a stack so that re-entering the same instance does
        # not overwrite the token needed to unwind the outer `with` block.
        tokens = getattr(self, "_context_tokens", None)
        if tokens is None:
            tokens = []
            # `object.__setattr__` mirrors how this class stores private state
            # without going through the base class's attribute handling.
            object.__setattr__(self, "_context_tokens", tokens)
        tokens.append(token)
        _LAST_PARTITION_MANAGER = self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the manager that was active before the matching `__enter__`.

        Returns:
            False, so exceptions raised inside the `with` block propagate.
        """
        tokens = getattr(self, "_context_tokens", None)
        if tokens:
            _CURRENT_PARTITION_MANAGER.reset(tokens[-1])
            # Pop only after a successful reset so a failed reset keeps the token.
            tokens.pop()
        return False

    def shard(
        self,
        x: jax.Array,
        axes: tp.Sequence[str | None] = NOT_GIVEN,
        mode: RUNTIME_MODE_TYPES | int = NOT_GIVEN,  # type:ignore
        dynamic_axes: DynamicShardingAxes | None = NOT_GIVEN,
        auto_correct: bool = True,
    ) -> jax.Array:
        """
        Applies sharding constraint to a JAX array using this manager's PartitionAxis.

        Uses this `PartitionManager` instance to resolve semantic axis names (`axes`)
        into a `PartitionSpec`, then applies the sharding constraint to `x`.
        Supports specifying axes and mode directly, or providing a `DynamicShardingAxes`
        named tuple. Can also infer the mode based on a dimension size if an integer
        `mode` is provided.

        Args:
            x: The JAX array to apply the sharding constraint to.
            axes: A sequence of semantic axis name strings or None. Required if
                  `dynamic_axes` is not given.
            mode: The runtime mode (string constant) or an integer representing
                  the dimension index to check for mode inference. Required if
                  `dynamic_axes` is not given.
            dynamic_axes: An optional `DynamicShardingAxes` named tuple that provides
                          both `axes` and `mode`. It is used whenever `axes` or `mode`
                          is missing, in which case it supplies both values.
            auto_correct: If True, automatically corrects the resolved `PartitionSpec`
                          based on array shape and mesh compatibility using
                          `get_corrected_named_sharding`. Defaults to True.

        Returns:
            The array `x` with the sharding constraint applied.

        Raises:
            ValueError: If neither `axes`/`mode` nor `dynamic_axes` are provided.
            ValueError: Propagated from `PartitionAxis.resolve_spec` or if resolved
                        axis rule is NOT_GIVEN.
        """
        spec = self.resolve(
            axes=axes,
            mode=mode,
            dynamic_axes=dynamic_axes,
            shape=x.shape,
        )
        if auto_correct:
            spec = get_corrected_named_sharding(x.shape, spec).spec

        return with_sharding_constraint(x, spec)

    def resolve(
        self,
        axes: tp.Sequence[str | None] | DynamicShardingAxes = NOT_GIVEN,
        mode: RUNTIME_MODE_TYPES | int = NOT_GIVEN,  # type:ignore
        dynamic_axes: DynamicShardingAxes | None = NOT_GIVEN,
        shape: tp.Sequence[int] = NOT_GIVEN,
    ) -> PartitionSpec:
        """Resolve semantic axis names to a PartitionSpec.

        Converts semantic axis names (like BATCH, LENGTH, HEAD) into a concrete
        PartitionSpec using the configured PartitionAxis mapping. Supports
        dynamic mode detection based on array shape.

        Args:
            axes: Sequence of semantic axis names, or a DynamicShardingAxes
                tuple (instance or class) containing both axes and mode.
            mode: Runtime mode (MODE_TRAIN, MODE_DECODE) or an integer dimension
                index for dynamic mode detection. When an integer is provided,
                mode is inferred based on whether shape[mode] == 1.
            dynamic_axes: Alternative way to provide axes and mode together
                as a DynamicShardingAxes named tuple. Consulted only when `axes`
                or `mode` is missing; ``None`` is treated as not provided.
            shape: Array shape, required when mode is an integer for
                dynamic mode detection.

        Returns:
            A PartitionSpec mapping semantic axes to mesh dimensions.

        Raises:
            ValueError: If axes/mode are not provided and dynamic_axes is
                also not provided, or if shape is missing for dynamic mode.

        Example:
            >>> manager = PartitionManager(paxis=paxis)
            >>> # Direct specification
            >>> spec = manager.resolve([BATCH, LENGTH, HEAD], mode=MODE_TRAIN)
            >>> # Dynamic mode detection (decode if dim 1 has size 1)
            >>> spec = manager.resolve([BATCH, LENGTH, HEAD], mode=1, shape=x.shape)
        """
        # The signature allows `None`; normalise it so it reaches the documented
        # ValueError below instead of failing on `None.axes`.
        if dynamic_axes is None:
            dynamic_axes = NOT_GIVEN

        if dynamic_axes is NOT_GIVEN and axes is not NOT_GIVEN:
            if isinstance(axes, tuple) and hasattr(axes, "_fields"):
                # A named-tuple instance was passed positionally as `axes`.
                dynamic_axes = axes
                axes = NOT_GIVEN
            elif isinstance(axes, type) and issubclass(axes, tuple) and hasattr(axes, "_fields"):
                # A named-tuple *class* carrying `axes`/`mode` as class attributes.
                dynamic_axes = DynamicShardingAxes(axes=axes.axes, mode=axes.mode)
                axes = NOT_GIVEN

        if axes is NOT_GIVEN or mode is NOT_GIVEN:
            if dynamic_axes is NOT_GIVEN:
                raise ValueError("if axes or mode is empty you should provide dynamic axes")
            axes = dynamic_axes.axes
            mode = dynamic_axes.mode

        if isinstance(mode, int):
            if shape is NOT_GIVEN:
                raise ValueError("when using dynamic mode detection shape should be provided")
            mode = MODE_DECODE if shape[mode] == 1 else MODE_TRAIN

        return self.paxis.resolve_spec(axes, mode)

    def __str__(self):
        """String representation of the PartitionManager."""
        return "PartitionManager(...)"

    def __repr__(self):
        """Representation of the PartitionManager."""
        return "PartitionManager(...)"

    __hash__ = hash_fn
def get_current_partition_manager() -> PartitionManager | None:
    """Return the partition manager active in the current context, or None if there is none."""
    active_manager = _CURRENT_PARTITION_MANAGER.get()
    return active_manager
def get_partition_manager() -> PartitionManager | None:
    """Return the most recently created or entered partition manager, or None if there is none."""
    last_manager = _LAST_PARTITION_MANAGER
    return last_manager
def apply_logical_sharding(
    x: jax.Array,
    partition_manager: PartitionManager | None = NOT_GIVEN,
    axes: tp.Sequence[str | None] = NOT_GIVEN,
    mode: RUNTIME_MODE_TYPES | int = NOT_GIVEN,  # type:ignore
    dynamic_axes: DynamicShardingAxes | None = NOT_GIVEN,
    auto_correct: bool = True,
):
    """
    Shard a JAX array by logical axis names using whichever PartitionManager is available.

    Thin convenience wrapper over `PartitionManager.shard`. When no manager is
    passed explicitly (or `None` is passed), the manager active in the current
    context (`get_current_partition_manager`) is tried first, then the last
    created manager (`get_partition_manager`).

    Args:
        x: The JAX array to apply sharding to.
        partition_manager: Explicit `PartitionManager` to use. If omitted or None,
            the lookup described above is performed.
        axes: A sequence of semantic axis name strings or None; forwarded to
            `PartitionManager.shard`.
        mode: The runtime mode or dimension index for inference; forwarded to
            `PartitionManager.shard`.
        dynamic_axes: An optional `DynamicShardingAxes` tuple; forwarded to
            `PartitionManager.shard`.
        auto_correct: If True, automatically corrects the resolved PartitionSpec.
            Defaults to True.

    Returns:
        The JAX array with sharding constraints applied.

    Raises:
        ValueError: If no manager is passed and none can be found, or propagated
            from `PartitionManager.shard` when the axes/mode cannot be determined.
    """
    manager = partition_manager
    needs_lookup = manager is NOT_GIVEN or manager is None
    if needs_lookup:
        # Context-local manager takes precedence over the process-wide fallback.
        manager = get_current_partition_manager() or get_partition_manager()

    if manager is None:
        raise ValueError(
            "No PartitionManager is available. Provide `partition_manager` or use `with PartitionManager(...)`."
        )

    shard_kwargs = dict(
        x=x,
        axes=axes,
        mode=mode,
        dynamic_axes=dynamic_axes,
        auto_correct=auto_correct,
    )
    return manager.shard(**shard_kwargs)