eformer.common_types
This module defines common types, constants, and named tuples used across the eformer library, particularly for JAX and sharding configurations.
- eformer.common_types.Array
Type alias for JAX arrays.
- class eformer.common_types.AttnKVSharding(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])
Bases: DynamicShardingAxes
Dynamic sharding specification for attention keys/values.
- axes: ClassVar = ['__BATCH__', '__KV_LENGTH__', '__KV_HEAD__', '__KV_HEAD_DIM__']
- mode: ClassVar = 1
- class eformer.common_types.AttnQSharding(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])
Bases: DynamicShardingAxes
Dynamic sharding specification for attention queries.
- axes: ClassVar = ['__BATCH__', '__QUERY_LENGTH__', '__HEAD__', '__HEAD_DIM__']
- mode: ClassVar = 1
- eformer.common_types.AxisIdxes
Type alias for a tuple of axis indices.
alias of
tuple[int, ...]
- eformer.common_types.AxisNames
Type alias for a tuple of mesh axis names.
alias of
tuple[str, ...]
- eformer.common_types.AxisType = tuple[str, ...] | str | typing.Any | None
Type alias for a mesh axis specification.
It can be a single string (an axis name), a tuple of strings, None (no sharding), or potentially other types depending on context, though in practice it is typically str or tuple[str, ...].
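A minimal sketch of a hypothetical helper (not part of eformer) that coerces the common AxisType forms into a tuple of axis names, mapping None to an empty tuple for "no sharding". It rejects any other type, which is stricter than the typing.Any in the alias.

```python
from typing import Any


def normalize_axis(axis: Any) -> tuple[str, ...]:
    """Hypothetical helper: coerce an AxisType-style value into a tuple of names.

    None -> () (no sharding), "tp" -> ("tp",), ("fsdp", "sp") -> ("fsdp", "sp").
    """
    if axis is None:
        return ()
    if isinstance(axis, str):
        return (axis,)
    if isinstance(axis, (tuple, list)):
        # Every element must itself be an axis-name string.
        if not all(isinstance(a, str) for a in axis):
            raise TypeError(f"expected only axis-name strings, got {axis!r}")
        return tuple(axis)
    raise TypeError(f"unsupported axis specification: {axis!r}")


print(normalize_axis(("fsdp", "sp")))  # ('fsdp', 'sp')
```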
- eformer.common_types.BATCH = '__BATCH__'
Semantic axis name for the batch dimension.
- eformer.common_types.BIAS_HEAD_SEQ = '__BIAS_HEAD_SEQ__'
Semantic axis name for bias related to the head and sequence dimensions.
- eformer.common_types.BIAS_KV_SEQ = '__BIAS_KV_SEQ__'
Semantic axis name for bias related to the key/value sequence dimension.
- class eformer.common_types.ColumnWise(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])
Bases: DynamicShardingAxes
Dynamic sharding specification for column-wise sharding: the first dimension is sharded jointly across FSDP and SP, and the second across TP.
- axes: ClassVar = [['__FULLY_SHARDED_DATA_PARALLEL__', '__SEQUENCE_PARALLEL__'], '__TENSOR_PARALLEL__']
- mode: ClassVar = '__train__'
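The axes entries use eformer's semantic parallelism names, and a nested list such as ['__FULLY_SHARDED_DATA_PARALLEL__', '__SEQUENCE_PARALLEL__'] presumably describes one tensor dimension split over several mesh axes, which corresponds to a tuple entry in a jax.sharding.PartitionSpec. The sketch below translates an axes list into PartitionSpec-style entries. The mesh-axis names ('dp', 'fsdp', 'sp', 'tp', 'ep') and the SEMANTIC_TO_MESH mapping are hypothetical, and treating '_' as "not sharded" is an assumption based on Replicated using ['_'].

```python
# Hypothetical mapping from semantic parallelism names to the physical axis
# names of a user-defined mesh (the mesh names here are made up).
SEMANTIC_TO_MESH = {
    "__DATA_PARALLEL__": "dp",
    "__FULLY_SHARDED_DATA_PARALLEL__": "fsdp",
    "__SEQUENCE_PARALLEL__": "sp",
    "__TENSOR_PARALLEL__": "tp",
    "__EXPERT_PARALLEL__": "ep",
}


def to_partition_entries(axes):
    """Translate a semantic axes list into PartitionSpec-style entries.

    A string becomes one mesh-axis name, a nested list becomes a tuple
    (one tensor dimension split over several mesh axes), and None or "_"
    becomes None (ASSUMPTION: "_" means the dimension is not sharded).
    """
    entries = []
    for spec in axes:
        if spec is None or spec == "_":
            entries.append(None)
        elif isinstance(spec, str):
            entries.append(SEMANTIC_TO_MESH[spec])
        else:
            entries.append(tuple(SEMANTIC_TO_MESH[s] for s in spec))
    return tuple(entries)


column_wise_axes = [
    ["__FULLY_SHARDED_DATA_PARALLEL__", "__SEQUENCE_PARALLEL__"],
    "__TENSOR_PARALLEL__",
]
print(to_partition_entries(column_wise_axes))  # (('fsdp', 'sp'), 'tp')
```

The resulting tuple could be unpacked into jax.sharding.PartitionSpec(*entries), which accepts None, a single axis name, or a tuple of axis names for each dimension.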
- eformer.common_types.DEFAULT_MASK_VALUE = -2.381976426469702e+38
Default value used for masking, typically in attention mechanisms.
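The constant matches -0.7 times the largest finite float32 value (about 3.4028e38), so it stays finite in float32. A common reason to prefer a large finite value over -inf is that a fully masked row then produces a uniform softmax instead of NaN, because (-inf) - (-inf) is NaN. The plain-Python sketch below copies the constant's value and illustrates this; masked_softmax is a hypothetical helper, not eformer's attention code.

```python
import math

# Value copied from eformer.common_types.DEFAULT_MASK_VALUE.
DEFAULT_MASK_VALUE = -2.381976426469702e38
# Largest finite float32; the mask value equals -0.7 times this.
FLOAT32_MAX = 3.4028234663852886e38


def masked_softmax(logits: list[float], keep: list[bool]) -> list[float]:
    """Softmax over `logits`, masking out positions where keep[i] is False."""
    masked = [x if k else DEFAULT_MASK_VALUE for x, k in zip(logits, keep)]
    # Subtract the row maximum for numerical stability.
    m = max(masked)
    # exp of a hugely negative number underflows to 0.0 rather than raising.
    exps = [math.exp(x - m) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]


print(masked_softmax([1.0, 2.0, 3.0], [True, True, False])[2])  # 0.0
print(masked_softmax([1.0, 2.0], [False, False]))               # [0.5, 0.5]
```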
- eformer.common_types.DType
Type alias for JAX data types.
- class eformer.common_types.DynamicShardingAxes(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])
Bases: NamedTuple
A NamedTuple that defines sharding axes and a runtime mode dynamically.
Used to specify sharding that depends on the runtime mode or other dynamic factors.
- axes
A sequence of semantic axis names or None.
- Type
Sequence[str | None]
- mode
The runtime mode (a string constant) or an integer giving the index of the dimension to inspect for mode inference.
- Type
Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int]
- axes: Sequence[str | None]
Alias for field number 0
- mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int]
Alias for field number 1
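The integer form of mode is easiest to see with an example. The sketch below uses a local NamedTuple with the same two fields instead of importing eformer, and its resolve_mode rule (a size-1 dimension at the given index means autoregressive decoding, anything else means prefill) is an assumption made for illustration, not eformer's documented behavior. It borrows the axes of AttnQSharding, whose mode of 1 points at the query-length dimension.

```python
from typing import Literal, NamedTuple, Sequence, Union

MODE_DECODE = "__autoregressive__"
MODE_PREFILL = "__prefill__"
MODE_TRAIN = "__train__"
MODE_INSERT = "__insert__"

RuntimeMode = Literal["__autoregressive__", "__prefill__", "__train__", "__insert__"]


class ShardingAxesSketch(NamedTuple):
    """Local stand-in with the same two fields as DynamicShardingAxes."""

    axes: Sequence[Union[str, None]]
    mode: Union[RuntimeMode, int]


def resolve_mode(spec: ShardingAxesSketch, shape: Sequence[int]) -> str:
    """Return a concrete runtime mode for a tensor of the given shape.

    ASSUMPTION (illustrative only): when spec.mode is an int, a size-1
    dimension at that index is read as autoregressive decoding and any
    other size as prefill. A string mode is returned unchanged.
    """
    if isinstance(spec.mode, int):
        return MODE_DECODE if shape[spec.mode] == 1 else MODE_PREFILL
    return spec.mode


q_spec = ShardingAxesSketch(["__BATCH__", "__QUERY_LENGTH__", "__HEAD__", "__HEAD_DIM__"], 1)
print(resolve_mode(q_spec, (8, 1, 32, 128)))    # __autoregressive__
print(resolve_mode(q_spec, (8, 512, 32, 128)))  # __prefill__
```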
- eformer.common_types.EMBED = '__EMBED__'
Semantic axis name for the embedding or hidden state dimension.
- eformer.common_types.EMPTY_VAL
Sentinel value indicating that a parameter was not explicitly provided.
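A sentinel lets a function tell an omitted argument apart from an explicit None, which matters when None itself is meaningful (for instance, "no sharding" in AxisType). The sketch below defines its own stand-in object rather than importing eformer's EMPTY_VAL, and configure is a hypothetical function; the checks use `is` because the sentinel is a unique object.

```python
class _EmptyVal:
    """Stand-in sentinel type (hypothetical; not eformer's EMPTY_VAL)."""

    def __repr__(self) -> str:
        return "EMPTY_VAL"


EMPTY_VAL = _EmptyVal()


def configure(sharding=EMPTY_VAL):
    """Hypothetical function that treats 'omitted' and None differently."""
    if sharding is EMPTY_VAL:
        return "default"     # caller passed nothing
    if sharding is None:
        return "replicated"  # caller explicitly asked for no sharding
    return f"custom:{sharding}"


print(configure(), configure(None), configure("tp"))  # default replicated custom:tp
```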
- eformer.common_types.EXPERT = '__EXPERT__'
Semantic axis name for the expert dimension in Mixture-of-Experts models.
- eformer.common_types.EXPERT_GATE = '__EXPERT_GATE__'
Semantic axis name for the expert gate dimension in Mixture-of-Experts models.
- class eformer.common_types.ExpertActivations(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])
Bases: DynamicShardingAxes
Sharding for expert activation tensors of shape [batch, sequence, num_experts, hidden]:
  - Batch dimension: DP (data parallel)
  - Sequence dimension: SP (sequence parallel)
  - Expert dimension: EP (expert parallel)
  - Hidden dimension: TP (tensor parallel) or FSDP
- axes: ClassVar = ['__DATA_PARALLEL__', '__SEQUENCE_PARALLEL__', '__EXPERT_PARALLEL__', '__TENSOR_PARALLEL__']
- mode: ClassVar = '__train__'
- class eformer.common_types.ExpertActivationsAlt(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])
Bases: DynamicShardingAxes
Alternative activation sharding for tensors of shape [batch, sequence, hidden], for use when experts have already been selected/routed.
- axes: ClassVar = ['__DATA_PARALLEL__', '__SEQUENCE_PARALLEL__', ['__TENSOR_PARALLEL__', '__FULLY_SHARDED_DATA_PARALLEL__']]
- mode: ClassVar = '__train__'
- class eformer.common_types.ExpertColumnWise(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])
Bases: DynamicShardingAxes
Dynamic sharding specification for column-wise sharding of expert weights.
For a typical expert layer weight tensor of shape [num_experts, hidden_size, intermediate_size]:
  - Dimension 0 (num_experts): shard across EP (expert parallel)
  - Dimension 1 (hidden_size): shard across FSDP (parameter sharding)
  - Dimension 2 (intermediate_size): shard across TP (tensor parallel, column-wise)
DP is used for the batch dimension of activations and SP for the sequence length.
- axes: ClassVar = ['__EXPERT_PARALLEL__', '__FULLY_SHARDED_DATA_PARALLEL__', '__TENSOR_PARALLEL__']
- mode: ClassVar = '__train__'
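As a worked example of this layout, the sketch below computes the block each device holds for a [num_experts, hidden_size, intermediate_size] weight under ExpertColumnWise. The mesh sizes and tensor shape are made up, and shard_shape is a hypothetical helper that only does the division arithmetic; it does not create any arrays.

```python
from math import prod

# Hypothetical mesh sizes, chosen for illustration only.
MESH_SIZES = {
    "__EXPERT_PARALLEL__": 4,
    "__FULLY_SHARDED_DATA_PARALLEL__": 2,
    "__TENSOR_PARALLEL__": 8,
}


def shard_shape(shape, axes, mesh_sizes):
    """Per-device block shape when dimension i is split over axes[i].

    An axes entry may be None (not split), one axis name, or a list of names
    (the dimension is split over the product of those axis sizes).
    """
    local = []
    for dim, spec in zip(shape, axes):
        if spec is None:
            names = []
        elif isinstance(spec, str):
            names = [spec]
        else:
            names = list(spec)
        ways = prod(mesh_sizes[name] for name in names)  # prod of nothing is 1
        if dim % ways:
            raise ValueError(f"dimension of size {dim} cannot be split {ways} ways")
        local.append(dim // ways)
    return tuple(local)


expert_column_wise_axes = [
    "__EXPERT_PARALLEL__",
    "__FULLY_SHARDED_DATA_PARALLEL__",
    "__TENSOR_PARALLEL__",
]
print(shard_shape((64, 4096, 14336), expert_column_wise_axes, MESH_SIZES))  # (16, 2048, 1792)
```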
- class eformer.common_types.ExpertColumnWiseAlt(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])
Bases: DynamicShardingAxes
Alternative column-wise sharding that also uses SP for sequence-related parameters: dimension 1 is sharded jointly across FSDP and SP. Use this if your expert weights have a sequence-related dimension.
- axes: ClassVar = ['__EXPERT_PARALLEL__', ['__FULLY_SHARDED_DATA_PARALLEL__', '__SEQUENCE_PARALLEL__'], '__TENSOR_PARALLEL__']
- mode: ClassVar = '__train__'
- class eformer.common_types.ExpertRowWise(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])
Bases: DynamicShardingAxes
Dynamic sharding specification for row-wise sharding of expert weights.
For a typical expert layer weight tensor of shape [num_experts, intermediate_size, hidden_size]:
  - Dimension 0 (num_experts): shard across EP (expert parallel)
  - Dimension 1 (intermediate_size): shard across TP (tensor parallel, row-wise)
  - Dimension 2 (hidden_size): shard across FSDP (parameter sharding)
DP is used for the batch dimension of activations and SP for the sequence length.
- axes: ClassVar = ['__EXPERT_PARALLEL__', '__TENSOR_PARALLEL__', '__FULLY_SHARDED_DATA_PARALLEL__']
- mode: ClassVar = '__train__'
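ExpertColumnWise shards the intermediate dimension of the up-projection weight across TP, so its output activation is split along that dimension; ExpertRowWise shards the same intermediate dimension, now the contraction dimension, of the down-projection weight. In the usual tensor-parallel MLP pairing this means each device's row-wise product is only a partial sum that must be reduced across TP. The plain-Python emulation below (no JAX; devices are simulated with list slices, and the helper names are hypothetical) checks that summing the per-shard partials reproduces the full product.

```python
def matmul(a, b):
    """Plain-Python matrix product of two nested lists."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def row_wise_tp_matmul(x, w, tp):
    """Emulate `tp` devices that each hold one slice of the contraction dimension.

    Device d holds columns [d*s, (d+1)*s) of x (already split by the preceding
    column-wise layer) and the matching rows of w, and computes a partial product.
    Summing the partials stands in for the all-reduce across TP.
    """
    k = len(w)
    if k % tp:
        raise ValueError("contraction dimension must divide evenly across TP")
    s = k // tp
    partials = [
        matmul([row[d * s:(d + 1) * s] for row in x], w[d * s:(d + 1) * s])
        for d in range(tp)
    ]
    return [
        [sum(p[i][j] for p in partials) for j in range(len(w[0]))]
        for i in range(len(x))
    ]


x = [[1, 2, 3, 4], [5, 6, 7, 8]]       # [tokens, intermediate]
w = [[1, 0], [0, 1], [1, 1], [2, -1]]  # [intermediate, hidden] for one expert
print(row_wise_tp_matmul(x, w, tp=2))  # [[12, 1], [28, 5]]
```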
- class eformer.common_types.ExpertRowWiseAlt(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])
Bases: DynamicShardingAxes
Alternative row-wise sharding that also uses SP for sequence-related parameters: dimension 2 is sharded jointly across FSDP and SP. Use this if your expert weights have a sequence-related dimension.
- axes: ClassVar = ['__EXPERT_PARALLEL__', '__TENSOR_PARALLEL__', ['__FULLY_SHARDED_DATA_PARALLEL__', '__SEQUENCE_PARALLEL__']]
- mode: ClassVar = '__train__'
- eformer.common_types.GENERATION_MODES = {'__autoregressive__', '__insert__'}
Set of runtime modes considered as generation modes.
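Checking whether a runtime mode counts as a generation mode is a plain set-membership test. The sketch below copies the documented constant values locally; uses_generation_path is a hypothetical helper name.

```python
MODE_DECODE = "__autoregressive__"
MODE_INSERT = "__insert__"
MODE_PREFILL = "__prefill__"
MODE_TRAIN = "__train__"

# Mirrors the documented value of GENERATION_MODES.
GENERATION_MODES = {MODE_DECODE, MODE_INSERT}


def uses_generation_path(mode: str) -> bool:
    """Hypothetical helper: True for the runtime modes listed in GENERATION_MODES."""
    return mode in GENERATION_MODES


print(uses_generation_path(MODE_DECODE), uses_generation_path(MODE_TRAIN))  # True False
```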
- class eformer.common_types.HiddenStateSharding(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])
Bases: DynamicShardingAxes
Dynamic sharding specification for hidden states.
- axes: ClassVar = ['__BATCH__', '__QUERY_LENGTH__', '__EMBED__']
- mode: ClassVar = 1
- eformer.common_types.KV_HEAD = '__KV_HEAD__'
Semantic axis name for the key/value attention head dimension.
- eformer.common_types.KV_HEAD_DIM = '__KV_HEAD_DIM__'
Semantic axis name for the dimension within each key/value attention head.
- eformer.common_types.KV_LENGTH = '__KV_LENGTH__'
Semantic axis name for the key/value sequence length dimension.
- eformer.common_types.LENGTH = '__LENGTH__'
Semantic axis name for the sequence length dimension (general).
- eformer.common_types.MLP_INTERMEDIATE = '__MLP_INTERMEDIATE__'
Semantic axis name for the intermediate dimension in MLP layers.
- eformer.common_types.MODE_DECODE = '__autoregressive__'
Runtime mode for autoregressive decoding.
- eformer.common_types.MODE_INSERT = '__insert__'
Runtime mode for inserting into the cache.
- eformer.common_types.MODE_PREFILL = '__prefill__'
Runtime mode for prefilling the cache.
- eformer.common_types.MODE_TRAIN = '__train__'
Runtime mode for training.
- class eformer.common_types.Mesh(devices: numpy.ndarray | collections.abc.Sequence[jaxlib._jax.Device], axis_names: str | collections.abc.Sequence[Any], axis_types: tuple[jax._src.mesh.AxisType, ...] | None = None)
Bases: BaseMesh, ContextDecorator
Declare the hardware resources available in the scope of this manager.
See the JAX tutorials "Distributed arrays and automatic parallelization" and "Explicit Sharding".
- Parameters
devices – A NumPy ndarray object containing JAX device objects (as obtained e.g. from jax.devices()).
axis_names – A sequence of resource axis names to be assigned to the dimensions of the devices argument. Its length should match the rank of devices.
axis_types – An optional tuple of jax.sharding.AxisType entries corresponding to the axis_names. See Explicit Sharding for more information.
Examples
>>> import jax
>>> import numpy as np
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P, NamedSharding
>>> # Declare a 2D mesh with axes `x` and `y` (this reshape requires 8 devices).
>>> devices = np.array(jax.devices()).reshape(4, 2)
>>> mesh = Mesh(devices, ('x', 'y'))
>>> inp = np.arange(16).reshape(8, 2)
>>> arr = jax.device_put(inp, NamedSharding(mesh, P('x', 'y')))
>>> out = jax.jit(lambda x: x * 2)(arr)
>>> assert out.sharding == NamedSharding(mesh, P('x', 'y'))
- property abstract_mesh
- axis_names: tuple[Any, ...]
- property axis_sizes: tuple[int, ...]
- property device_ids
- property empty
- property is_multi_process
- property local_devices
- property local_mesh
- property shape
- property shape_tuple
- property size
- eformer.common_types.PRNGKey
Type alias for JAX PRNG keys.
- eformer.common_types.QUERY_LENGTH = '__QUERY_LENGTH__'
Semantic axis name for the query sequence length dimension.
- eformer.common_types.RUNTIME_MODE_TYPES
Type alias for the possible runtime modes.
alias of
Literal['__autoregressive__', '__prefill__', '__train__', '__insert__']
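Because RUNTIME_MODE_TYPES is a typing.Literal, its allowed values can be recovered at runtime with typing.get_args, for example to validate a user-supplied mode string. The alias is redefined locally in this sketch, and check_mode is a hypothetical helper.

```python
from typing import Literal, get_args

# Local copy of the documented RUNTIME_MODE_TYPES alias.
RUNTIME_MODE_TYPES = Literal["__autoregressive__", "__prefill__", "__train__", "__insert__"]


def check_mode(mode: str) -> str:
    """Raise ValueError unless `mode` is one of the literal runtime modes."""
    allowed = get_args(RUNTIME_MODE_TYPES)
    if mode not in allowed:
        raise ValueError(f"unknown runtime mode {mode!r}; expected one of {allowed}")
    return mode


print(check_mode("__prefill__"))  # __prefill__
```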
- class eformer.common_types.Replicated(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])
Bases: DynamicShardingAxes
Dynamic sharding specification for replicated (unsharded) tensors.
- axes: ClassVar = ['_']
- mode: ClassVar = '__train__'
- class eformer.common_types.RowWise(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])
Bases: DynamicShardingAxes
Dynamic sharding specification for row-wise sharding: the first dimension is sharded across TP, and the second jointly across FSDP and SP.
- axes: ClassVar = ['__TENSOR_PARALLEL__', ['__FULLY_SHARDED_DATA_PARALLEL__', '__SEQUENCE_PARALLEL__']]
- mode: ClassVar = '__train__'
- class eformer.common_types.SColumnWise(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])
Bases: DynamicShardingAxes
Dynamic sharding specification for column-wise sharding with a single axes entry (one dimension sharded jointly across FSDP and SP).
- axes: ClassVar = [['__FULLY_SHARDED_DATA_PARALLEL__', '__SEQUENCE_PARALLEL__']]
- mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int] = '__train__'
- class eformer.common_types.SRowWise(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])
Bases: DynamicShardingAxes
Dynamic sharding specification for row-wise sharding with a single axes entry (one dimension sharded across TP).
- axes: ClassVar = ['__TENSOR_PARALLEL__']
- mode: ClassVar = '__train__'
- eformer.common_types.Shape
Type alias for array shapes.
alias of
Sequence[int]
- class eformer.common_types.UnifiedExpertColumnWise(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])
Bases: DynamicShardingAxes
Unified column-wise sharding using SP for sequence-related parameters: dimension 0 is sharded jointly across FSDP, SP, and EP, dimension 1 is given the placeholder '_', and dimension 2 is sharded across TP. Use this if your expert weights have a sequence-related dimension.
- axes: ClassVar = [['__FULLY_SHARDED_DATA_PARALLEL__', '__SEQUENCE_PARALLEL__', '__EXPERT_PARALLEL__'], '_', '__TENSOR_PARALLEL__']
- mode: ClassVar = '__train__'
- class eformer.common_types.UnifiedExpertRowWise(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])
Bases: DynamicShardingAxes
Unified row-wise sharding using SP for sequence-related parameters: dimension 0 is sharded jointly across FSDP, SP, and EP, dimension 1 across TP, and dimension 2 is given the placeholder '_'. Use this if your expert weights have a sequence-related dimension.
- axes: ClassVar = [['__FULLY_SHARDED_DATA_PARALLEL__', '__SEQUENCE_PARALLEL__', '__EXPERT_PARALLEL__'], '__TENSOR_PARALLEL__', '_']
- mode: ClassVar = '__train__'
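If the nested first entry of UnifiedExpertColumnWise and UnifiedExpertRowWise is interpreted the way a tuple entry in a JAX PartitionSpec is, dimension 0 is split over the product of the FSDP, SP, and EP mesh sizes, so the expert count must divide evenly by that product. The sketch below, a hypothetical helper with made-up mesh sizes, computes which contiguous range of experts one device holds; the major-to-minor ordering of the combined axes is an assumption borrowed from JAX's PartitionSpec convention.

```python
def expert_block(fsdp_idx, sp_idx, ep_idx, sizes, num_experts):
    """Return the [start, stop) range of experts held along dimension 0.

    sizes = (fsdp, sp, ep) mesh sizes. ASSUMPTION: the first listed mesh axis
    is the most major, so the shard index is ((fsdp_idx * sp) + sp_idx) * ep + ep_idx.
    """
    fsdp, sp, ep = sizes
    ways = fsdp * sp * ep
    if num_experts % ways:
        raise ValueError(f"{num_experts} experts cannot be split {ways} ways")
    per_shard = num_experts // ways
    shard = (fsdp_idx * sp + sp_idx) * ep + ep_idx
    return shard * per_shard, (shard + 1) * per_shard


# 64 experts on a hypothetical mesh with fsdp=2, sp=1, ep=4 -> 8 experts per shard.
print(expert_block(1, 0, 2, (2, 1, 4), 64))  # (48, 56)
```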
- eformer.common_types.VOCAB = '__VOCAB__'
Semantic axis name for the vocabulary dimension.