eformer.escale.partition.manager#
This module provides classes and functions for managing JAX sharding configurations and applying sharding constraints within a context.
It includes the PartitionAxis class for defining logical-to-physical axis mappings and the PartitionManager context manager for applying these rules.
- class eformer.escale.partition.manager.PartitionAxis(data_parallel_axis: str = 'dp', fully_sharded_data_parallel_axis: str = 'fsdp', tensor_parallel_axis: str = 'tp', sequence_parallel_axis: str = 'sp', expert_parallel_axis: str = 'ep', batch_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, sequence_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, query_sequence_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, head_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, kv_head_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, key_sequence_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, hidden_state_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, mlp_intermediate_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, vocab_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, expert_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, expert_gate_axis: tuple[str, ...] | str | Any | None = None, attention_dim_axis: tuple[str, ...] | str | Any | None = None, attention_kv_dim_axis: tuple[str, ...] | str | Any | None = None, bias_head_sequence_axis: tuple[str, ...] | str | Any | None = None, bias_key_sequence_axis: tuple[str, ...] | str | Any | None = None, decode_batch_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, decode_query_sequence_axis: tuple[str, ...] | str | Any | None = None, decode_head_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, decode_kv_head_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, decode_key_sequence_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>, decode_attention_dim_axis: tuple[str, ...] | str | Any | None = None, decode_attention_kv_dim_axis: tuple[str, ...] | str | Any | None = None)[source]#
Bases: xTree
Configuration for partitioning model axes across a device mesh.
Defines the mesh dimension names for standard parallelism strategies and maps logical model axes to these dimensions. Allows overriding defaults.
- Mesh Dimensions Attributes:
data_parallel_axis: Name for data parallel mesh dim. Default: “dp”.
fully_sharded_data_parallel_axis: Name for FSDP mesh dim. Default: “fsdp”.
tensor_parallel_axis: Name for tensor parallel mesh dim. Default: “tp”.
sequence_parallel_axis: Name for sequence parallel mesh dim. Default: “sp”.
expert_parallel_axis: Name for expert parallel mesh dim (MoE). Default: “ep”.
- Logical Model Axes Attributes:
Maps logical tensor axes (like batch, sequence, hidden) to one or more mesh dimension names defined above, or None if not partitioned. Defaults are derived from the standard mesh dimension names but can be overridden during instantiation. For example, head_axis defaults to the value of tensor_parallel_axis (‘tp’).
batch_axis: Mesh axis for the batch dimension.
sequence_axis: Mesh axis for the general sequence length dimension.
query_sequence_axis: Mesh axis for the query sequence length dimension.
head_axis: Mesh axis for the attention head dimension.
key_sequence_axis: Mesh axis for the key/value sequence length dimension.
hidden_state_axis: Mesh axis for the embedding or hidden state dimension.
mlp_intermediate_axis: Mesh axis for the intermediate dimension in MLP layers.
vocab_axis: Mesh axis for the vocabulary dimension.
expert_axis: Mesh axis for the expert dimension.
expert_gate_axis: Mesh axis for the expert gate dimension.
attention_dim_axis: Mesh axis for the dimension within each attention head.
bias_head_sequence_axis: Mesh axis for bias related to head and sequence dimensions.
bias_key_sequence_axis: Mesh axis for bias related to key/value sequence dimensions.
decode_batch_axis: Mesh axis for the batch dimension during decoding.
decode_query_sequence_axis: Mesh axis for the query sequence length during decoding.
decode_head_axis: Mesh axis for the attention head dimension during decoding.
decode_key_sequence_axis: Mesh axis for the key/value sequence length during decoding.
decode_attention_dim_axis: Mesh axis for the dimension within each attention head during decoding.
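The way unset logical axes fall back to the mesh dimension names can be pictured with a small stand-alone sketch. It is not the library's implementation: `AxisConfigSketch` and `NOT_GIVEN` are hypothetical stand-ins, and only the `head_axis` default (the value of `tensor_parallel_axis`) is taken from the description above. The `query_sequence_axis` default is an assumption made for illustration.

```python
from dataclasses import dataclass


class _NotGiven:
    """Hypothetical sentinel standing in for eformer.common_types._Empty."""

    def __repr__(self):
        return "NOT_GIVEN"


NOT_GIVEN = _NotGiven()


@dataclass
class AxisConfigSketch:
    """Illustrative stand-in for PartitionAxis; not the library class."""

    tensor_parallel_axis: str = "tp"
    sequence_parallel_axis: str = "sp"
    head_axis: object = NOT_GIVEN
    query_sequence_axis: object = NOT_GIVEN

    def __post_init__(self):
        # Documented behaviour: head_axis defaults to tensor_parallel_axis.
        if self.head_axis is NOT_GIVEN:
            self.head_axis = self.tensor_parallel_axis
        # Assumption for this sketch only: query sequence follows "sp".
        if self.query_sequence_axis is NOT_GIVEN:
            self.query_sequence_axis = self.sequence_parallel_axis


default = AxisConfigSketch()
renamed = AxisConfigSketch(tensor_parallel_axis="model")
unsharded = AxisConfigSketch(head_axis=None)
```

Here `renamed.head_axis` follows the renamed mesh dimension, while passing `None` explicitly keeps the head dimension unpartitioned. A sentinel is used because `None` already carries a meaning ("not partitioned") and so cannot also mean "use the default".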
- attention_dim_axis: tuple[str, ...] | str | Any | None = None#
- attention_kv_dim_axis: tuple[str, ...] | str | Any | None = None#
- batch_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>#
- bias_head_sequence_axis: tuple[str, ...] | str | Any | None = None#
- bias_key_sequence_axis: tuple[str, ...] | str | Any | None = None#
- data_parallel_axis: str = 'dp'#
- decode_attention_dim_axis: tuple[str, ...] | str | Any | None = None#
- decode_attention_kv_dim_axis: tuple[str, ...] | str | Any | None = None#
- decode_batch_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>#
- decode_head_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>#
- decode_key_sequence_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>#
- decode_kv_head_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>#
- decode_query_sequence_axis: tuple[str, ...] | str | Any | None = None#
- expert_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>#
- expert_gate_axis: tuple[str, ...] | str | Any | None = None#
- expert_parallel_axis: str = 'ep'#
- fully_sharded_data_parallel_axis: str = 'fsdp'#
- head_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>#
- key_sequence_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>#
- kv_head_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>#
- mlp_intermediate_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>#
- query_sequence_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>#
- replace(**updates)#
Returns a new instance of the dataclass with specified fields updated.
- Parameters
**updates – Keyword arguments where keys are field names and values are the new values for those fields.
- Returns
A new instance of the dataclass with the updated fields.
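As a rough analogy, and assuming replace behaves like the standard library's `dataclasses.replace`, the call leaves the original instance untouched and returns a copy with the named fields changed. `MeshNamesSketch` below is a hypothetical stand-in, not the library class.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class MeshNamesSketch:
    """Hypothetical two-field stand-in for PartitionAxis."""

    data_parallel_axis: str = "dp"
    tensor_parallel_axis: str = "tp"


base = MeshNamesSketch()
# Only the named field changes; every other field is copied over.
updated = replace(base, tensor_parallel_axis="model")
```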
- resolve_axis(axes: Sequence[str | None], mode: Literal['__autoregressive__', '__prefill__', '__train__', '__insert__']) list[str | None][source]#
Generates a list of mesh axis names from a sequence of semantic axis names and a mode.
Maps a sequence of semantic axis name strings (like BATCH, LENGTH) to the actual mesh axis names defined in this PartitionAxis instance, considering the current runtime mode (e.g., training vs. generation).
- Parameters
axes – A sequence of semantic axis name strings (e.g., [BATCH, LENGTH, HEAD]) or None (or “_”) for axes that shouldn’t be sharded.
mode – The current operational mode (e.g., MODE_TRAIN, MODE_DECODE) which determines if generation-specific rules should be applied.
- Returns
A list of mesh axis names (or None for unsharded axes) representing the sharding for the given sequence of axes.
- Raises
ValueError – If an unknown semantic axis name is encountered or if a resolved axis rule is still NOT_GIVEN (should be caught by _safety_check but included for robustness).
LookupError – If an internal attribute name derived from the semantic map isn’t found in the instance (shouldn’t happen with correct class definition).
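The resolution step described above can be sketched in plain Python. Everything in this block is illustrative: the constant values, the `SEMANTIC_TO_ATTR` map, the `RULES` table and the `decode_` prefix lookup are assumptions chosen to mirror the attribute names in this class, not the library's internals.

```python
# Hypothetical semantic names and mode constants, for this sketch only.
BATCH, LENGTH, HEAD = "batch", "length", "head"
MODE_TRAIN, MODE_DECODE = "__train__", "__autoregressive__"

# Semantic name -> attribute that stores the mesh axis rule.
SEMANTIC_TO_ATTR = {
    BATCH: "batch_axis",
    LENGTH: "sequence_axis",
    HEAD: "head_axis",
}

# Stand-in for a configured PartitionAxis (arbitrary example values).
RULES = {
    "batch_axis": ("fsdp", "dp"),
    "sequence_axis": "sp",
    "head_axis": "tp",
    "decode_batch_axis": "dp",
    "decode_head_axis": "tp",
}


def resolve_axis_sketch(axes, mode):
    """Map semantic axis names to mesh axis rules for the given mode."""
    resolved = []
    for name in axes:
        # None and "_" both mean "do not shard this dimension".
        if name is None or name == "_":
            resolved.append(None)
            continue
        if name not in SEMANTIC_TO_ATTR:
            raise ValueError(f"Unknown semantic axis name: {name!r}")
        attr = SEMANTIC_TO_ATTR[name]
        # In decode mode, prefer a decode-specific rule when one exists.
        if mode == MODE_DECODE and f"decode_{attr}" in RULES:
            attr = f"decode_{attr}"
        resolved.append(RULES[attr])
    return resolved


train_axes = resolve_axis_sketch([BATCH, LENGTH, HEAD], MODE_TRAIN)
decode_axes = resolve_axis_sketch([BATCH, None, "_"], MODE_DECODE)
```

A list in this form can be unpacked into `jax.sharding.PartitionSpec(*resolved)`, which accepts `None`, a single mesh axis name, or a tuple of names per dimension.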
- resolve_spec(axes: Sequence[str | None], mode: Literal['__autoregressive__', '__prefill__', '__train__', '__insert__']) PartitionSpec[source]#
Generates a PartitionSpec from a sequence of semantic axis names and a mode.
Maps a sequence of semantic axis name strings (like BATCH, LENGTH) to the actual mesh axis names defined in this PartitionAxis instance, considering the current runtime mode (e.g., training vs. generation).
- Parameters
axes – A sequence of semantic axis name strings (e.g., [BATCH, LENGTH, HEAD]) or None (or “_”) for axes that shouldn’t be sharded.
mode – The current operational mode (e.g., MODE_TRAIN, MODE_DECODE) which determines if generation-specific rules should be applied.
- Returns
A jax.sharding.PartitionSpec instance representing the sharding for the given sequence of axes.
- Raises
ValueError – If an unknown semantic axis name is encountered or if a resolved axis rule is still NOT_GIVEN (should be caught by _safety_check but included for robustness).
LookupError – If an internal attribute name derived from the semantic map isn’t found in the instance (shouldn’t happen with correct class definition).
- sequence_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>#
- sequence_parallel_axis: str = 'sp'#
- tensor_parallel_axis: str = 'tp'#
- vocab_axis: tuple[str, ...] | str | Any | None = <eformer.common_types._Empty object>#
- class eformer.escale.partition.manager.PartitionManager(paxis: PartitionAxis)[source]#
Bases: PyTree
Context manager for applying sharding constraints using PartitionAxis.
This class acts as a context manager (with PartitionManager(…)) to set a context-local variable (_CURRENT_PARTITION_MANAGER) that makes the current manager implicitly available via functions like get_current_partition_manager() or the static shard() method.
- Parameters
paxis – The PartitionAxis instance defining the sharding strategy to be used within this context.
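The context-local pattern described above can be illustrated with the standard library's `contextvars` module. This is a minimal sketch of the general technique, not the library's code: `ManagerSketch`, `_CURRENT` and `get_current_manager_sketch` are hypothetical names.

```python
import contextvars

# Context-local slot holding the active manager (None when nothing is active).
_CURRENT = contextvars.ContextVar("current_manager_sketch", default=None)


class ManagerSketch:
    """Hypothetical stand-in for PartitionManager."""

    def __init__(self, paxis):
        self.paxis = paxis
        self._token = None

    def __enter__(self):
        # Remember the token so the previous value can be restored on exit.
        self._token = _CURRENT.set(self)
        return self

    def __exit__(self, exc_type, exc, tb):
        _CURRENT.reset(self._token)
        return False  # never swallow exceptions


def get_current_manager_sketch():
    """Return the active manager or raise LookupError outside a context."""
    manager = _CURRENT.get()
    if manager is None:
        raise LookupError("No manager is active in this context.")
    return manager


outer, inner = ManagerSketch("outer"), ManagerSketch("inner")
with outer:
    with inner:
        innermost = get_current_manager_sketch()
    restored = get_current_manager_sketch()
```

Resetting with the saved token is what lets nested contexts restore the outer manager when the inner block exits.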
- classmethod from_dict(data: dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- paxis: PartitionAxis#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- resolve(axes: Union[Sequence[str | None], DynamicShardingAxes] = <eformer.common_types._Empty object>, mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int] = <eformer.common_types._Empty object>, dynamic_axes: eformer.common_types.DynamicShardingAxes | None = <eformer.common_types._Empty object>, shape: Sequence[int] = <eformer.common_types._Empty object>) PartitionSpec[source]#
Resolve semantic axis names to a PartitionSpec.
Converts semantic axis names (like BATCH, LENGTH, HEAD) into a concrete PartitionSpec using the configured PartitionAxis mapping. Supports dynamic mode detection based on array shape.
- Parameters
axes – Sequence of semantic axis names, or a DynamicShardingAxes tuple containing both axes and mode.
mode – Runtime mode (MODE_TRAIN, MODE_DECODE) or an integer dimension index for dynamic mode detection. When an integer is provided, mode is inferred based on whether shape[mode] == 1.
dynamic_axes – Alternative way to provide axes and mode together as a DynamicShardingAxes named tuple.
shape – Array shape, required when mode is an integer for dynamic mode detection.
- Returns
A PartitionSpec mapping semantic axes to mesh dimensions.
- Raises
ValueError – If axes/mode are not provided and dynamic_axes is also not provided, or if shape is missing for dynamic mode.
Example
>>> manager = PartitionManager(paxis=paxis)
>>> # Direct specification
>>> spec = manager.resolve([BATCH, LENGTH, HEAD], mode=MODE_TRAIN)
>>> # Dynamic mode detection (decode if dim 1 has size 1)
>>> spec = manager.resolve([BATCH, LENGTH, HEAD], mode=1, shape=x.shape)
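The dynamic mode detection rule (`shape[mode] == 1` selects decode) can be sketched as a small helper. The helper name and the constant values are assumptions, and the choice of the train mode as the fallback when the dimension is larger than 1 is also an assumption; the text above only states the decode condition.

```python
MODE_TRAIN = "__train__"
MODE_DECODE = "__autoregressive__"  # assumed value of MODE_DECODE


def infer_mode_sketch(mode, shape=None):
    """Return a mode string, inferring it from shape when mode is an int."""
    if isinstance(mode, int):
        if shape is None:
            raise ValueError("shape is required when mode is an integer")
        # A size-1 dimension at index `mode` signals single-token decoding.
        return MODE_DECODE if shape[mode] == 1 else MODE_TRAIN
    # String modes are passed through unchanged.
    return mode


decode_mode = infer_mode_sketch(1, (8, 1, 64))
train_mode = infer_mode_sketch(1, (8, 128, 64))
```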
- shard(x: Array, axes: Sequence[str | None] = <eformer.common_types._Empty object>, mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int] = <eformer.common_types._Empty object>, dynamic_axes: eformer.common_types.DynamicShardingAxes | None = <eformer.common_types._Empty object>, auto_correct: bool = True) Array[source]#
Applies a sharding constraint to a JAX array based on the active PartitionManager context.
Retrieves the current PartitionManager implicitly using get_current_partition_manager() and uses its PartitionAxis to resolve the semantic axis names (axes) into a PartitionSpec. It then applies the sharding constraint to the array x.
Supports specifying axes and mode directly, or providing a DynamicShardingAxes named tuple. Can also infer the mode based on a dimension size if an integer mode is provided.
- Parameters
x – The JAX array to apply the sharding constraint to.
axes – A sequence of semantic axis name strings or None. Required if dynamic_axes is NOT_GIVEN.
mode – The runtime mode (string constant) or an integer representing the dimension index to check for mode inference. Required if dynamic_axes is NOT_GIVEN.
dynamic_axes – An optional DynamicShardingAxes named tuple that provides both axes and mode. If provided, axes and mode arguments are ignored.
auto_correct – If True, automatically corrects the resolved PartitionSpec based on array shape and mesh compatibility using get_corrected_named_sharding. Defaults to True.
- Returns
The array x with the sharding constraint applied.
- Raises
LookupError – If called outside of an active PartitionManager context.
ValueError – If neither axes/mode nor dynamic_axes are provided.
ValueError – Propagated from PartitionAxis.resolve_spec or if resolved axis rule is NOT_GIVEN.
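This page does not spell out what the auto_correct step does beyond "array shape and mesh compatibility". One plausible reading, shown here purely as an assumption, is that a mesh axis is dropped for any dimension whose size is not divisible by the number of shards that axis implies. The function name and the `mesh_sizes` mapping are hypothetical and do not describe `get_corrected_named_sharding` itself.

```python
from math import prod


def correct_spec_sketch(shape, spec, mesh_sizes):
    """Drop mesh axes that cannot evenly split the matching dimension.

    shape: array shape, e.g. (8, 1, 64)
    spec: per-dimension rule (None, a mesh axis name, or a tuple of names)
    mesh_sizes: hypothetical mapping of mesh axis name -> number of devices
    """
    corrected = []
    for dim, axis in zip(shape, spec):
        if axis is None:
            corrected.append(None)
            continue
        names = (axis,) if isinstance(axis, str) else tuple(axis)
        # An axis missing from the mesh cannot be used for sharding.
        if any(name not in mesh_sizes for name in names):
            corrected.append(None)
            continue
        shards = prod(mesh_sizes[name] for name in names)
        # Keep the rule only when the dimension splits evenly.
        corrected.append(axis if dim % shards == 0 else None)
    return corrected


mesh = {"fsdp": 2, "dp": 2, "sp": 4, "tp": 8}
fixed = correct_spec_sketch((8, 1, 64), [("fsdp", "dp"), "sp", "tp"], mesh)
```

In this example the length-1 sequence dimension cannot be split four ways, so its rule becomes `None` while the batch and head rules are kept.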
- to_dict() dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- eformer.escale.partition.manager.apply_logical_sharding(x: Array, partition_manager: PartitionManager, axes: Sequence[str | None] = <eformer.common_types._Empty object>, mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int] = <eformer.common_types._Empty object>, dynamic_axes: eformer.common_types.DynamicShardingAxes | None = <eformer.common_types._Empty object>, auto_correct: bool = True)[source]#
Applies logical sharding to a JAX array using an available PartitionManager.
This function is a convenience wrapper around PartitionManager.shard. It attempts to find a PartitionManager from the current context first (get_current_partition_manager), and if none is found, it falls back to the last created manager (get_partition_manager).
- Parameters
x – The JAX array to apply sharding to.
partition_manager – An explicit PartitionManager instance to use.
axes – A sequence of semantic axis name strings or None. Required if dynamic_axes is NOT_GIVEN and partition_manager is NOT_GIVEN.
mode – The runtime mode or dimension index for inference. Required if dynamic_axes is NOT_GIVEN and partition_manager is NOT_GIVEN.
dynamic_axes – An optional DynamicShardingAxes tuple. If provided, axes and mode are ignored.
auto_correct – If True, automatically corrects the resolved PartitionSpec. Defaults to True.
- Returns
The JAX array with sharding constraints applied.
- Raises
ValueError – If neither axes/mode nor dynamic_axes are provided when a manager is found or provided.
- eformer.escale.partition.manager.get_safe_hash_int(text, algorithm='md5')[source]#
Convert text to an integer hash using the specified algorithm.
Provides a safe way to generate integer hashes from strings, useful for creating hashable keys from complex objects.
- Parameters
text – The text to hash. Will be converted to string if not already.
algorithm – Hash algorithm to use. Defaults to “md5”. Supports any algorithm available in hashlib (e.g., “sha256”, “sha1”).
- Returns
An integer representation of the hash digest.
- Raises
ValueError – If the specified algorithm is not supported by hashlib.
Exception – If any other error occurs during hash generation.
Example
>>> get_safe_hash_int("hello world")
309817674445039181685702831361671
>>> get_safe_hash_int("hello world", algorithm="sha256")
... # Different integer value
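A text-to-integer hash of this kind can be sketched with `hashlib` alone. The version below is an assumption about the general approach: it reads the full hex digest as a base-16 integer. The library may post-process the digest differently, so the integers produced here are not expected to match the value shown in the example.

```python
import hashlib


def safe_hash_int_sketch(text, algorithm="md5"):
    """Hash str(text) with the named hashlib algorithm and return an int."""
    try:
        hasher = hashlib.new(algorithm)
    except ValueError as err:
        # hashlib.new raises ValueError for unknown algorithm names.
        raise ValueError(f"Unsupported hash algorithm: {algorithm!r}") from err
    # Non-string inputs are stringified first, as the parameter notes describe.
    hasher.update(str(text).encode("utf-8"))
    return int(hasher.hexdigest(), 16)


md5_value = safe_hash_int_sketch("hello world")
sha_value = safe_hash_int_sketch("hello world", algorithm="sha256")
```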
- eformer.escale.partition.manager.hash_fn(self) int[source]#
Compute a hash value for an object based on its dictionary values.
Creates a hash by concatenating string representations of hashable attribute values (int, float, bool, dict, list) and computing an MD5 hash of the result.
- Parameters
self – The object to hash (bound method).
- Returns
An integer hash value derived from the object’s attributes.
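The attribute-based hashing described above can be approximated as follows. This is a sketch under stated assumptions: the values are joined with no separator, the object's attributes are read with `vars()`, and `ConfigSketch` is a hypothetical class used only to exercise the function.

```python
import hashlib


def hash_fn_sketch(obj):
    """Hash an object from its int, float, bool, dict and list attributes."""
    parts = [
        str(value)
        for value in vars(obj).values()
        if isinstance(value, (int, float, bool, dict, list))
    ]
    digest = hashlib.md5("".join(parts).encode("utf-8")).hexdigest()
    return int(digest, 16)


class ConfigSketch:
    """Hypothetical object whose hash is derived from its attributes."""

    def __init__(self, layers, rate, name):
        self.layers = layers
        self.rate = rate
        self.name = name  # str values are skipped by the type filter

    __hash__ = hash_fn_sketch


a = ConfigSketch(2, 0.1, "a")
b = ConfigSketch(2, 0.1, "b")
c = ConfigSketch(3, 0.1, "a")
```

Because string attributes are not in the listed types, `a` and `b` hash identically here even though their `name` fields differ, while `c` hashes differently.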