eformer.escale.partition.manager
This module provides classes and functions for managing JAX sharding configurations and applying sharding constraints within a context.
It includes the PartitionAxis class for defining logical-to-physical axis mappings and the PartitionManager context manager for applying these rules.
- class eformer.escale.partition.manager.PartitionAxis(data_parallel_axis: str = 'dp', fully_sharded_data_parallel_axis: str = 'fsdp', tensor_parallel_axis: str = 'tp', sequence_parallel_axis: str = 'sp', expert_parallel_axis: str = 'ep', batch_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, sequence_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, query_sequence_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, head_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, kv_head_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, key_sequence_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, hidden_state_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, mlp_intermediate_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, vocab_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, expert_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, expert_gate_axis: tuple[str, ...] | str | Any | None = None, attention_dim_axis: tuple[str, ...] | str | Any | None = None, attention_kv_dim_axis: tuple[str, ...] | str | Any | None = None, bias_head_sequence_axis: tuple[str, ...] | str | Any | None = None, bias_key_sequence_axis: tuple[str, ...] | str | Any | None = None, decode_batch_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, decode_query_sequence_axis: tuple[str, ...] | str | Any | None = None, decode_head_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, decode_kv_head_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, decode_key_sequence_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, decode_attention_dim_axis: tuple[str, ...] | str | Any | None = None, decode_attention_kv_dim_axis: tuple[str, ...] | str | Any | None = None)
Bases: xTree
Configuration for partitioning model axes across a device mesh.
Defines the mesh dimension names for standard parallelism strategies and maps logical model axes to these dimensions. Allows overriding defaults.
- Mesh Dimensions Attributes:
data_parallel_axis: Name of the data-parallel mesh dimension. Default: "dp".
fully_sharded_data_parallel_axis: Name of the fully sharded data-parallel (FSDP) mesh dimension. Default: "fsdp".
tensor_parallel_axis: Name of the tensor-parallel mesh dimension. Default: "tp".
sequence_parallel_axis: Name of the sequence-parallel mesh dimension. Default: "sp".
expert_parallel_axis: Name of the expert-parallel mesh dimension (MoE). Default: "ep".
- Logical Model Axes Attributes:
Map logical tensor axes (such as batch, sequence, and hidden) to one or more of the mesh dimension names above, or to None if the axis is not partitioned. Fields whose default is shown as <eformer.common_types._Empty object> are unset sentinels whose values are derived from the mesh dimension names; fields that default to None are left unpartitioned. Any of them can be overridden at instantiation. For example, head_axis defaults to the value of tensor_parallel_axis ('tp').
batch_axis: Mesh axis for the batch dimension.
sequence_axis: Mesh axis for the general sequence-length dimension.
query_sequence_axis: Mesh axis for the query sequence-length dimension.
head_axis: Mesh axis for the attention-head dimension.
kv_head_axis: Mesh axis for the key/value attention-head dimension.
key_sequence_axis: Mesh axis for the key/value sequence-length dimension.
hidden_state_axis: Mesh axis for the embedding or hidden-state dimension.
mlp_intermediate_axis: Mesh axis for the intermediate dimension in MLP layers.
vocab_axis: Mesh axis for the vocabulary dimension.
expert_axis: Mesh axis for the expert dimension.
expert_gate_axis: Mesh axis for the expert gate dimension.
attention_dim_axis: Mesh axis for the dimension within each attention head.
attention_kv_dim_axis: Mesh axis for the dimension within each key/value head.
bias_head_sequence_axis: Mesh axis for bias related to the head and sequence dimensions.
bias_key_sequence_axis: Mesh axis for bias related to the key/value sequence dimensions.
decode_batch_axis: Mesh axis for the batch dimension during decoding.
decode_query_sequence_axis: Mesh axis for the query sequence length during decoding.
decode_head_axis: Mesh axis for the attention-head dimension during decoding.
decode_kv_head_axis: Mesh axis for the key/value head dimension during decoding.
decode_key_sequence_axis: Mesh axis for the key/value sequence length during decoding.
decode_attention_dim_axis: Mesh axis for the dimension within each attention head during decoding.
decode_attention_kv_dim_axis: Mesh axis for the dimension within each key/value head during decoding.
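The distinction between an unset sentinel (derive a default from the mesh names) and None (do not partition) can be modelled with a small stand-in dataclass. This is a sketch, not eformer's implementation: MiniPartitionAxis and UNSET are hypothetical names, and only the documented rule (head_axis follows tensor_parallel_axis) is reproduced.

```python
from dataclasses import dataclass
from typing import Any


class _Unset:
    """Hypothetical stand-in for the library's unset sentinel."""

    def __repr__(self) -> str:
        return "<unset>"


UNSET: Any = _Unset()


@dataclass
class MiniPartitionAxis:
    # Mesh dimension name, with the same default as PartitionAxis.
    tensor_parallel_axis: str = "tp"
    # UNSET means "derive from the mesh names"; None means "do not partition".
    head_axis: Any = UNSET
    attention_dim_axis: Any = None

    def __post_init__(self) -> None:
        # Documented rule: head_axis defaults to the value of tensor_parallel_axis.
        if self.head_axis is UNSET:
            self.head_axis = self.tensor_parallel_axis
```

Renaming the tensor-parallel mesh dimension (for example MiniPartitionAxis(tensor_parallel_axis="model")) then carries through to head_axis automatically, while an explicit head_axis=None stays unpartitioned.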
- attention_dim_axis: tuple[str, ...] | str | Any | None = None
- attention_kv_dim_axis: tuple[str, ...] | str | Any | None = None
- batch_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>
- bias_head_sequence_axis: tuple[str, ...] | str | Any | None = None
- bias_key_sequence_axis: tuple[str, ...] | str | Any | None = None
- data_parallel_axis: str = 'dp'
- decode_attention_dim_axis: tuple[str, ...] | str | Any | None = None
- decode_attention_kv_dim_axis: tuple[str, ...] | str | Any | None = None
- decode_batch_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>
- decode_head_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>
- decode_key_sequence_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>
- decode_kv_head_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>
- decode_query_sequence_axis: tuple[str, ...] | str | Any | None = None
- expert_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>
- expert_gate_axis: tuple[str, ...] | str | Any | None = None
- expert_parallel_axis: str = 'ep'
- fully_sharded_data_parallel_axis: str = 'fsdp'
- classmethod get_registered_axes() -> dict[str, dict[str, Any]]
Return a snapshot of globally registered custom axis mappings.
- head_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>
- hidden_state_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>
- key_sequence_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>
- kv_head_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>
- mlp_intermediate_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>
- query_sequence_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>
- classmethod register(semantic_axis: str, axis_rule: Any, *, generation_axis_rule: Any = <eformer.common_types._Empty object>, override: bool = False) -> None
Register a semantic axis mapping globally.
This updates a process-global registry used by all current and future PartitionAxis instances. A mapping value may be an attribute name on PartitionAxis (for example "head_axis"), a literal mesh axis name (for example "tp"), a tuple or list of either of the above, or None.
- Parameters
semantic_axis – Semantic axis token to register.
axis_rule – Resolution rule for standard/train mode.
generation_axis_rule – Optional explicit rule for generation modes. If omitted and axis_rule maps to a known standard attribute, the corresponding decode attribute mapping is inferred.
override – If False, raises when semantic_axis already exists in the built-in or custom maps. Set to True to replace existing rules.
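The registry semantics above can be sketched in pure Python. This is a simplified model, not eformer's code: the plain-dict registries, the __HEAD__ token, the attribute-to-decode table, and the choice of ValueError are all assumptions, and the fallback of reusing axis_rule unchanged when it is not a known standard attribute is also an assumption (the docs do not say what happens in that case).

```python
from typing import Any

# Hypothetical stand-ins; the real registry lives inside eformer and may differ.
_BUILTIN_AXES: dict[str, dict[str, Any]] = {
    "__HEAD__": {"rule": "head_axis", "generation_rule": "decode_head_axis"},
}
_CUSTOM_AXES: dict[str, dict[str, Any]] = {}
# Standard PartitionAxis attributes that have a decode_* counterpart.
_DECODE_COUNTERPARTS = {
    "batch_axis": "decode_batch_axis",
    "head_axis": "decode_head_axis",
    "kv_head_axis": "decode_kv_head_axis",
    "key_sequence_axis": "decode_key_sequence_axis",
}
_UNSET: Any = object()


def register(semantic_axis: str, axis_rule: Any, *,
             generation_axis_rule: Any = _UNSET, override: bool = False) -> None:
    """Register a semantic axis mapping in the (model) global registry."""
    exists = semantic_axis in _BUILTIN_AXES or semantic_axis in _CUSTOM_AXES
    if exists and not override:
        # Assumption: ValueError; the docs only say that register "raises".
        raise ValueError(f"semantic axis {semantic_axis!r} is already registered")
    if generation_axis_rule is _UNSET:
        # Infer the decode mapping only for known standard attributes;
        # otherwise (assumption) reuse the standard rule unchanged.
        if isinstance(axis_rule, str) and axis_rule in _DECODE_COUNTERPARTS:
            generation_axis_rule = _DECODE_COUNTERPARTS[axis_rule]
        else:
            generation_axis_rule = axis_rule
    _CUSTOM_AXES[semantic_axis] = {"rule": axis_rule, "generation_rule": generation_axis_rule}


def get_registered_axes() -> dict[str, dict[str, Any]]:
    """Return a snapshot (copy) so callers cannot mutate the registry through it."""
    return {name: dict(entry) for name, entry in _CUSTOM_AXES.items()}
```

Returning a copy from get_registered_axes matches the "snapshot" wording: later edits to the returned dict do not leak back into the process-global state.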
- replace(**updates)
Returns a new instance of the dataclass with specified fields updated.
- Parameters
**updates – Keyword arguments where keys are field names and values are the new values for those fields.
- Returns
A new instance of the dataclass with the updated fields.
- resolve_axis(axes: Sequence[str | None], mode: Literal['__autoregressive__', '__prefill__', '__train__', '__insert__']) -> list[str | None]
Generates a list of mesh axis names from a sequence of semantic axis names and a mode.
Maps a sequence of semantic axis name strings (like BATCH, LENGTH) to the actual mesh axis names defined in this PartitionAxis instance, considering the current runtime mode (e.g., training vs. generation).
- Parameters
axes – A sequence of semantic axis name strings (e.g., [BATCH, LENGTH, HEAD]) or None (or “_”) for axes that shouldn’t be sharded.
mode – The current operational mode (e.g., MODE_TRAIN, MODE_DECODE) which determines if generation-specific rules should be applied.
- Returns
A list of mesh axis names (None for unsharded entries) representing the sharding for the given sequence of axes.
- Raises
ValueError – If an unknown semantic axis name is encountered or if a resolved axis rule is still NOT_GIVEN (should be caught by _safety_check but included for robustness).
LookupError – If an internal attribute name derived from the semantic map isn’t found in the instance (shouldn’t happen with correct class definition).
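The resolution loop described above can be sketched as follows, modelling the PartitionAxis instance as a plain mapping from attribute names to mesh axes. The token strings, the SEMANTIC_MAP table, and the assumption that only MODE_DECODE selects the decode_* attributes are hypothetical; eformer defines its own constants and rules.

```python
from typing import Any, Mapping, Optional, Sequence

# Hypothetical token and mode values; eformer defines its own constants.
BATCH, LENGTH, HEAD = "__BATCH__", "__LENGTH__", "__HEAD__"
MODE_TRAIN, MODE_DECODE = "__train__", "__autoregressive__"
GENERATION_MODES = {MODE_DECODE}  # assumption: which modes use decode_* rules

# semantic token -> (attribute used in training, attribute used in generation)
SEMANTIC_MAP = {
    BATCH: ("batch_axis", "decode_batch_axis"),
    LENGTH: ("query_sequence_axis", "decode_query_sequence_axis"),
    HEAD: ("head_axis", "decode_head_axis"),
}


def resolve_axis(paxis: Mapping[str, Any], axes: Sequence[Optional[str]], mode: str) -> list:
    """Map semantic tokens to mesh axis names, choosing decode rules in generation modes."""
    use_generation = mode in GENERATION_MODES
    resolved: list = []
    for axis in axes:
        if axis is None or axis == "_":
            resolved.append(None)  # this dimension is not sharded
            continue
        if axis not in SEMANTIC_MAP:
            raise ValueError(f"unknown semantic axis: {axis!r}")
        train_attr, gen_attr = SEMANTIC_MAP[axis]
        attr = gen_attr if use_generation else train_attr
        if attr not in paxis:
            raise LookupError(f"PartitionAxis has no attribute {attr!r}")
        resolved.append(paxis[attr])
    return resolved
```

resolve_spec documents the same mapping but returns a jax.sharding.PartitionSpec; given a resolved list, PartitionSpec(*resolved) constructs the equivalent spec.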
- resolve_spec(axes: Sequence[str | None], mode: Literal['__autoregressive__', '__prefill__', '__train__', '__insert__']) -> PartitionSpec
Generates a PartitionSpec from a sequence of semantic axis names and a mode.
Maps a sequence of semantic axis name strings (like BATCH, LENGTH) to the actual mesh axis names defined in this PartitionAxis instance, considering the current runtime mode (e.g., training vs. generation).
- Parameters
axes – A sequence of semantic axis name strings (e.g., [BATCH, LENGTH, HEAD]) or None (or “_”) for axes that shouldn’t be sharded.
mode – The current operational mode (e.g., MODE_TRAIN, MODE_DECODE) which determines if generation-specific rules should be applied.
- Returns
A jax.sharding.PartitionSpec instance representing the sharding for the given sequence of axes.
- Raises
ValueError – If an unknown semantic axis name is encountered or if a resolved axis rule is still NOT_GIVEN (should be caught by _safety_check but included for robustness).
LookupError – If an internal attribute name derived from the semantic map isn’t found in the instance (shouldn’t happen with correct class definition).
- sequence_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>
- sequence_parallel_axis: str = 'sp'
- tensor_parallel_axis: str = 'tp'
- classmethod unregister(semantic_axis: str, *, missing_ok: bool = True) -> None
Remove a previously registered semantic axis mapping.
- vocab_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>
- class eformer.escale.partition.manager.PartitionManager(paxis: PartitionAxis)
Bases: PyTree
Context manager for applying sharding constraints using PartitionAxis.
This class acts as a context manager (with PartitionManager(…)) that sets a context-local variable (_CURRENT_PARTITION_MANAGER), making the active manager implicitly available to helpers such as get_current_partition_manager() and apply_logical_sharding().
- Parameters
paxis – The PartitionAxis instance defining the sharding strategy to be used within this context.
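The context-local mechanism can be sketched with the standard library's contextvars module. MiniManager, _CURRENT_MANAGER, and get_current_manager are hypothetical names for illustration; the sketch only assumes that entering the context publishes the manager and leaving it restores the previous one.

```python
import contextvars
from typing import Any

# Analogous to _CURRENT_PARTITION_MANAGER; the name here is hypothetical.
_CURRENT_MANAGER: contextvars.ContextVar = contextvars.ContextVar("current_manager", default=None)


class MiniManager:
    """Sketch of a context manager that publishes itself via a ContextVar."""

    def __init__(self, paxis: Any) -> None:
        self.paxis = paxis
        self._tokens: list = []  # one token per active `with`, so re-entry nests correctly

    def __enter__(self) -> "MiniManager":
        self._tokens.append(_CURRENT_MANAGER.set(self))
        return self

    def __exit__(self, *exc_info: Any) -> bool:
        # reset() restores whatever manager was active before this block.
        _CURRENT_MANAGER.reset(self._tokens.pop())
        return False  # do not swallow exceptions


def get_current_manager() -> Any:
    return _CURRENT_MANAGER.get()
```

Using ContextVar.set/reset tokens rather than a plain global keeps nested managers correct and isolates concurrent asyncio tasks or threads that run in separate contexts.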
- classmethod from_dict(data: dict[str, Any]) -> T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) -> T
Deserializes a JSON string into a PyTree object.
- paxis: PartitionAxis
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- resolve(axes: Union[Sequence[str | None], DynamicShardingAxes] = <eformer.common_types._Empty object>, mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int] = <eformer.common_types._Empty object>, dynamic_axes: eformer.common_types.DynamicShardingAxes | None = <eformer.common_types._Empty object>, shape: Sequence[int] = <eformer.common_types._Empty object>) -> PartitionSpec
Resolve semantic axis names to a PartitionSpec.
Converts semantic axis names (like BATCH, LENGTH, HEAD) into a concrete PartitionSpec using the configured PartitionAxis mapping. Supports dynamic mode detection based on array shape.
- Parameters
axes – Sequence of semantic axis names, or a DynamicShardingAxes tuple containing both axes and mode.
mode – Runtime mode (MODE_TRAIN, MODE_DECODE) or an integer dimension index for dynamic mode detection. When an integer is provided, mode is inferred based on whether shape[mode] == 1.
dynamic_axes – Alternative way to provide axes and mode together as a DynamicShardingAxes named tuple.
shape – Array shape, required when mode is an integer for dynamic mode detection.
- Returns
A PartitionSpec mapping semantic axes to mesh dimensions.
- Raises
ValueError – If axes/mode are not provided and dynamic_axes is also not provided, or if shape is missing for dynamic mode.
Example
>>> manager = PartitionManager(paxis=paxis)
>>> # Direct specification
>>> spec = manager.resolve([BATCH, LENGTH, HEAD], mode=MODE_TRAIN)
>>> # Dynamic mode detection (decode if dim 1 has size 1)
>>> spec = manager.resolve([BATCH, LENGTH, HEAD], mode=1, shape=x.shape)
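The integer-mode rule (decode when shape[mode] == 1) can be isolated in a small helper. infer_mode is a hypothetical name, the MODE_* string values are assumptions, and so is the fallback to training mode for any other size; eformer may choose a different non-decode mode.

```python
from typing import Optional, Sequence, Union

# Assumed constant values for illustration; eformer defines its own MODE_* constants.
MODE_TRAIN = "__train__"
MODE_DECODE = "__autoregressive__"


def infer_mode(mode: Union[str, int], shape: Optional[Sequence[int]] = None) -> str:
    """Return an explicit mode unchanged, or infer one from a dimension index."""
    if isinstance(mode, str):
        return mode
    if shape is None:
        raise ValueError("shape is required when mode is a dimension index")
    # A length-1 dimension (e.g. one new token per step) signals decoding.
    # Assumption: any other size falls back to training mode.
    return MODE_DECODE if shape[mode] == 1 else MODE_TRAIN
```

With a (batch, length, heads) activation, infer_mode(1, shape=x.shape) picks the decode rules only on single-token steps, which is the situation the example above describes.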
- shard(x: jax.Array, axes: Sequence[str | None] = <eformer.common_types._Empty object>, mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int] = <eformer.common_types._Empty object>, dynamic_axes: eformer.common_types.DynamicShardingAxes | None = <eformer.common_types._Empty object>, auto_correct: bool = True) -> jax.Array
Applies a sharding constraint to a JAX array using this manager’s PartitionAxis.
Uses this PartitionManager instance to resolve semantic axis names (axes) into a PartitionSpec, then applies the sharding constraint to x.
Supports specifying axes and mode directly, or providing a DynamicShardingAxes named tuple. Can also infer the mode based on a dimension size if an integer mode is provided.
- Parameters
x – The JAX array to apply the sharding constraint to.
axes – A sequence of semantic axis name strings or None. Required if dynamic_axes is NOT_GIVEN.
mode – The runtime mode (string constant) or an integer representing the dimension index to check for mode inference. Required if dynamic_axes is NOT_GIVEN.
dynamic_axes – An optional DynamicShardingAxes named tuple that provides both axes and mode. If provided, axes and mode arguments are ignored.
auto_correct – If True, automatically corrects the resolved PartitionSpec based on array shape and mesh compatibility using get_corrected_named_sharding. Defaults to True.
- Returns
The array x with the sharding constraint applied.
- Raises
ValueError – If neither axes/mode nor dynamic_axes are provided.
ValueError – Propagated from PartitionAxis.resolve_spec or if resolved axis rule is NOT_GIVEN.
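These docs do not spell out which correction get_corrected_named_sharding performs when auto_correct is True. As a hypothetical illustration only, one simple compatibility rule is to replicate (use None for) any dimension whose mesh axes are missing from the mesh or whose size is not divisible by the product of those axes' sizes. correct_spec and the dict-based mesh shape are assumptions, not eformer's API.

```python
from math import prod
from typing import Any, Mapping, Sequence


def correct_spec(spec: Sequence[Any], shape: Sequence[int], mesh_shape: Mapping[str, int]) -> list:
    """Hypothetical correction: replicate dimensions the mesh cannot evenly shard."""
    corrected: list = []
    for dim, entry in zip(shape, spec):
        if entry is None:
            corrected.append(None)
            continue
        names = (entry,) if isinstance(entry, str) else tuple(entry)
        if any(name not in mesh_shape for name in names):
            corrected.append(None)  # axis absent from the mesh: replicate
        elif dim % prod(mesh_shape[name] for name in names) != 0:
            corrected.append(None)  # uneven split: replicate
        else:
            corrected.append(entry)  # compatible: keep the requested sharding
    return corrected
```

On a mesh of {"dp": 2, "fsdp": 4, "tp": 8}, a batch of 16 can be split over ("fsdp", "dp") (8 shards), but a batch of 4 cannot, so this rule falls back to replication for that dimension.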
- to_dict() -> dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) -> str
Serializes the PyTree object to a JSON string.
- eformer.escale.partition.manager.apply_logical_sharding(x: jax.Array, partition_manager: eformer.escale.partition.manager.PartitionManager | None = <eformer.common_types._Empty object>, axes: Sequence[str | None] = <eformer.common_types._Empty object>, mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int] = <eformer.common_types._Empty object>, dynamic_axes: eformer.common_types.DynamicShardingAxes | None = <eformer.common_types._Empty object>, auto_correct: bool = True)
Applies logical sharding to a JAX array using an available PartitionManager.
This function is a convenience wrapper around PartitionManager.shard. It attempts to find a PartitionManager from the current context first (get_current_partition_manager), and if none is found, it falls back to the last created manager (get_partition_manager).
- Parameters
x – The JAX array to apply sharding to.
partition_manager – An explicit PartitionManager instance to use. If not provided, the function tries the current context manager first, then the last created manager.
axes – A sequence of semantic axis name strings or None. Required if dynamic_axes is NOT_GIVEN, whether the manager is passed explicitly or looked up.
mode – The runtime mode or dimension index for inference. Required if dynamic_axes is NOT_GIVEN, whether the manager is passed explicitly or looked up.
dynamic_axes – An optional DynamicShardingAxes tuple. If provided, axes and mode are ignored.
auto_correct – If True, automatically corrects the resolved PartitionSpec. Defaults to True.
- Returns
The JAX array with sharding constraints applied.
- Raises
ValueError – If neither axes/mode nor dynamic_axes are provided when a manager is found or provided.
- eformer.escale.partition.manager.get_current_partition_manager() -> eformer.escale.partition.manager.PartitionManager | None
Get the current context-local partition manager, if set.
- eformer.escale.partition.manager.get_partition_manager() -> eformer.escale.partition.manager.PartitionManager | None
Get the last created partition manager instance.
- eformer.escale.partition.manager.get_safe_hash_int(text, algorithm='md5')
Convert text to an integer hash using the specified algorithm.
Provides a deterministic way to generate integer hashes from strings (unlike the built-in hash(), whose results for str are randomized per process), which is useful for creating hashable keys from complex objects.
- Parameters
text – The text to hash. Will be converted to string if not already.
algorithm – Hash algorithm to use. Defaults to “md5”. Supports any algorithm available in hashlib (e.g., “sha256”, “sha1”).
- Returns
An integer representation of the hash digest.
- Raises
ValueError – If the specified algorithm is not supported by hashlib.
Exception – If any other error occurs during hash generation.
Example
>>> get_safe_hash_int("hello world")
309817674445039181685702831361671
>>> get_safe_hash_int("hello world", algorithm="sha256")  # returns a different integer
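A minimal stdlib sketch of the behavior described above, assuming the digest is read as a base-16 integer. safe_hash_int is a hypothetical name, and this sketch is not guaranteed to reproduce the exact integers shown in the example.

```python
import hashlib


def safe_hash_int(text, algorithm: str = "md5") -> int:
    """Deterministic integer hash of str(text) using a hashlib algorithm."""
    data = str(text).encode("utf-8")
    # hashlib.new raises ValueError for algorithm names it does not support.
    digest = hashlib.new(algorithm, data).hexdigest()
    return int(digest, 16)
```

Because the input is passed through str() first, safe_hash_int(123) and safe_hash_int("123") produce the same value, which matches the "will be converted to string" note in the parameters.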