eformer.common_types

This module defines common types, constants, and named tuples used across the eformer library, particularly for JAX and sharding configurations.

eformer.common_types.Array

Type alias for JAX arrays.

class eformer.common_types.AttnKVSharding(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])

Bases: DynamicShardingAxes

Dynamic sharding specification for attention keys/values.

axes: ClassVar = ['__BATCH__', '__KV_LENGTH__', '__KV_HEAD__', '__KV_HEAD_DIM__']
mode: ClassVar = 1
class eformer.common_types.AttnQSharding(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])

Bases: DynamicShardingAxes

Dynamic sharding specification for attention queries.

axes: ClassVar = ['__BATCH__', '__QUERY_LENGTH__', '__HEAD__', '__HEAD_DIM__']
mode: ClassVar = 1
eformer.common_types.AxisIdxes

Type alias for a tuple of axis indices.

alias of tuple[int, ...]

eformer.common_types.AxisNames

Type alias for a tuple of mesh axis names.

alias of tuple[str, ...]

eformer.common_types.AxisType = tuple[str, ...] | str | typing.Any | None

Type alias for a mesh axis specification.

Can be a single axis name (str), a tuple of axis names, or None (no sharding). Because the alias also includes typing.Any, other types may appear depending on context, though str and tuple[str, ...] are typical.

eformer.common_types.BATCH = '__BATCH__'

Semantic axis name for the batch dimension.

eformer.common_types.BIAS_HEAD_SEQ = '__BIAS_HEAD_SEQ__'

Semantic axis name for bias related to head and sequence dimensions.

eformer.common_types.BIAS_KV_SEQ = '__BIAS_KV_SEQ__'

Semantic axis name for bias related to key/value sequence dimensions.

class eformer.common_types.ColumnWise(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])

Bases: DynamicShardingAxes

Dynamic sharding specification for column-wise sharding.

axes: ClassVar = [['__FULLY_SHARDED_DATA_PARALLEL__', '__SEQUENCE_PARALLEL__'], '__TENSOR_PARALLEL__']
mode: ClassVar = '__train__'
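The axes lists in these classes mix plain semantic names, nested lists, and the placeholder '_'. The sketch below shows one way such a list could be turned into a per-dimension mesh spec. It assumes that a nested list shards one dimension across every listed mesh axis (the meaning a tuple entry has in jax.sharding.PartitionSpec), that '_' leaves a dimension unsharded, and that semantic names map to the mesh axis names in the hypothetical SEMANTIC_TO_MESH table. It is an illustration, not eformer's own resolution code.

```python
from typing import Sequence, Union

# Hypothetical mapping from semantic axis names to mesh axis names.
SEMANTIC_TO_MESH = {
    "__DATA_PARALLEL__": "dp",
    "__FULLY_SHARDED_DATA_PARALLEL__": "fsdp",
    "__TENSOR_PARALLEL__": "tp",
    "__SEQUENCE_PARALLEL__": "sp",
    "__EXPERT_PARALLEL__": "ep",
}

Entry = Union[str, None, Sequence[str]]


def resolve(axes: Sequence[Entry]) -> tuple:
    """Translate one semantic entry per tensor dimension into mesh axes."""
    spec = []
    for entry in axes:
        if entry is None or entry == "_":
            # Assumed: '_' (or None) leaves this dimension unsharded.
            spec.append(None)
        elif isinstance(entry, str):
            spec.append(SEMANTIC_TO_MESH[entry])
        else:
            # A nested list shards one dimension across several mesh axes.
            spec.append(tuple(SEMANTIC_TO_MESH[name] for name in entry))
    return tuple(spec)


# ColumnWise.axes as documented above.
column_wise = [["__FULLY_SHARDED_DATA_PARALLEL__", "__SEQUENCE_PARALLEL__"], "__TENSOR_PARALLEL__"]
print(resolve(column_wise))  # (('fsdp', 'sp'), 'tp')
```

With JAX installed, the returned tuple could be unpacked into jax.sharding.PartitionSpec(*spec), which accepts None, a single axis name, or a tuple of axis names for each dimension.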
eformer.common_types.DEFAULT_MASK_VALUE = -2.381976426469702e+38

Default value used for masking, typically in attention mechanisms.
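The documented value equals -0.7 times the largest finite float32 value. A finite value (rather than -inf) avoids NaNs when an entire row is masked, and the 0.7 factor leaves headroom below the float32 limit. The pure-Python sketch below checks that relationship and applies the constant in a masked softmax; the logits and mask are made-up illustration values, and masked_softmax is a hypothetical helper.

```python
import math

DEFAULT_MASK_VALUE = -2.381976426469702e38  # value documented above

# Largest finite float32: (2 - 2**-23) * 2**127, about 3.4028e38.
FLOAT32_MAX = (2 - 2**-23) * 2**127
assert math.isclose(DEFAULT_MASK_VALUE, -0.7 * FLOAT32_MAX)


def masked_softmax(logits, mask):
    """Softmax over `logits`; positions where `mask` is False are suppressed."""
    # Replace masked logits with the large negative fill value.
    scores = [x if keep else DEFAULT_MASK_VALUE for x, keep in zip(logits, mask)]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


probs = masked_softmax([2.0, 1.0, 3.0], [True, True, False])
print(probs)  # the masked (last) position gets probability 0.0
```

If every position is masked, all scores are equal, so the result is a uniform distribution instead of NaN.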

eformer.common_types.DType

Type alias for JAX data types.

class eformer.common_types.DynamicShardingAxes(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])

Bases: NamedTuple

A NamedTuple to define sharding axes and mode dynamically.

Used to specify sharding based on the runtime mode or other dynamic factors.

axes

A sequence of semantic axis names or None.

Type

Sequence[str | None]

mode

The runtime mode (string constant) or an integer representing the dimension index to check for mode inference.

Type

Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int]

axes: Sequence[str | None]

Alias for field number 0

mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int]

Alias for field number 1
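A minimal sketch of the two forms the mode field can take, using a stand-in NamedTuple with the same fields. This page does not spell out the inference rule, so the sketch assumes that an integer mode names a dimension whose size is inspected at runtime, and that a size of 1 along it (for example __QUERY_LENGTH__ in AttnQSharding, whose mode is 1) indicates autoregressive decoding. ShardingAxesSketch, infer_mode, and its fallback parameter are hypothetical.

```python
from typing import Literal, NamedTuple, Sequence, Union

RUNTIME_MODE_TYPES = Literal["__autoregressive__", "__prefill__", "__train__", "__insert__"]
MODE_DECODE = "__autoregressive__"
MODE_TRAIN = "__train__"


class ShardingAxesSketch(NamedTuple):
    """Stand-in with the same fields as DynamicShardingAxes."""

    axes: Sequence[Union[str, None]]
    mode: Union[RUNTIME_MODE_TYPES, int]


def infer_mode(spec: ShardingAxesSketch, shape: Sequence[int], fallback: str = MODE_TRAIN) -> str:
    """Hypothetical: resolve `spec.mode` to a runtime-mode string."""
    if isinstance(spec.mode, str):
        return spec.mode  # fixed mode, e.g. '__train__'
    # Integer mode: inspect the size of dimension `spec.mode`.
    # Assumed rule: length 1 along that axis means one-token decoding.
    return MODE_DECODE if shape[spec.mode] == 1 else fallback


q_spec = ShardingAxesSketch(["__BATCH__", "__QUERY_LENGTH__", "__HEAD__", "__HEAD_DIM__"], 1)
assert q_spec.mode == q_spec[1]  # field aliasing: mode is field number 1

print(infer_mode(q_spec, (4, 1, 32, 128)))    # __autoregressive__
print(infer_mode(q_spec, (4, 512, 32, 128)))  # __train__
```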

eformer.common_types.EMBED = '__EMBED__'

Semantic axis name for the embedding or hidden state dimension.

eformer.common_types.EMPTY_VAL

Sentinel value indicating that a parameter was not explicitly provided.
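A sentinel lets a function tell "argument omitted" apart from an explicit None, which matters when None is itself meaningful (None means "no sharding" in AxisType, for instance). The sketch below uses a local object() as a stand-in, since the actual type of EMPTY_VAL is not documented here; shard_dim and DEFAULT_AXIS are hypothetical.

```python
_EMPTY = object()  # stand-in for EMPTY_VAL; its real type is not documented here

DEFAULT_AXIS = "tp"  # hypothetical default mesh axis


def shard_dim(axis=_EMPTY):
    """Hypothetical helper: return the mesh axis to use for one dimension."""
    if axis is _EMPTY:
        return DEFAULT_AXIS  # caller omitted the argument
    return axis  # caller passed a value, possibly None ("do not shard")


print(shard_dim())      # tp
print(shard_dim(None))  # None
print(shard_dim("dp"))  # dp
```

The identity check (is) is what makes this work: no value a caller can pass is the same object as the private sentinel.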

eformer.common_types.EXPERT = '__EXPERT__'

Semantic axis name for the expert dimension in Mixture-of-Experts models.

eformer.common_types.EXPERT_GATE = '__EXPERT_GATE__'

Semantic axis name for the expert gate dimension in Mixture-of-Experts models.

class eformer.common_types.ExpertActivations(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])

Bases: DynamicShardingAxes

Sharding for expert activation tensors of shape [batch, sequence, num_experts, hidden]:

  • Batch dimension: DP (data parallel)
  • Sequence dimension: SP (sequence parallel)
  • Expert dimension: EP (expert parallel)
  • Hidden dimension: TP (tensor parallel) or FSDP

axes: ClassVar = ['__DATA_PARALLEL__', '__SEQUENCE_PARALLEL__', '__EXPERT_PARALLEL__', '__TENSOR_PARALLEL__']
mode: ClassVar = '__train__'
class eformer.common_types.ExpertActivationsAlt(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])

Bases: DynamicShardingAxes

Alternative activation sharding for tensors of shape [batch, sequence, hidden], for use once experts have already been selected (routed).

axes: ClassVar = ['__DATA_PARALLEL__', '__SEQUENCE_PARALLEL__', ['__TENSOR_PARALLEL__', '__FULLY_SHARDED_DATA_PARALLEL__']]
mode: ClassVar = '__train__'
class eformer.common_types.ExpertColumnWise(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])

Bases: DynamicShardingAxes

Dynamic sharding specification for column-wise sharding of expert layer weights.

For a typical expert layer weight tensor of shape [num_experts, hidden_size, intermediate_size]:

  • Dimension 0 (num_experts): shard across EP (expert parallel)
  • Dimension 1 (hidden_size): shard across FSDP (parameter sharding)
  • Dimension 2 (intermediate_size): shard across TP (tensor parallel, column-wise)

DP is used for the batch dimension of activations, and SP for the sequence-length dimension.

axes: ClassVar = ['__EXPERT_PARALLEL__', '__FULLY_SHARDED_DATA_PARALLEL__', '__TENSOR_PARALLEL__']
mode: ClassVar = '__train__'
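To make the ExpertColumnWise layout concrete, the arithmetic below computes the per-device shard shape of an expert weight of shape [num_experts, hidden_size, intermediate_size]. The sizes (8 experts, hidden 4096, intermediate 14336) and mesh-axis sizes (EP=4, FSDP=2, TP=8) are made-up illustration values, local_shape is a hypothetical helper, and every dimension is assumed to divide evenly by its number of shards.

```python
from math import prod

# Hypothetical mesh-axis sizes.
MESH = {"ep": 4, "fsdp": 2, "tp": 8}


def local_shape(global_shape, spec, mesh=MESH):
    """Per-device shard shape; each spec entry is None, an axis name, or a tuple of names."""
    out = []
    for size, entry in zip(global_shape, spec):
        if entry is None:
            shards = 1
        elif isinstance(entry, str):
            shards = mesh[entry]
        else:
            shards = prod(mesh[name] for name in entry)  # dimension split over several axes
        if size % shards:
            raise ValueError(f"dimension {size} not divisible by {shards} shards")
        out.append(size // shards)
    return tuple(out)


# ExpertColumnWise: num_experts -> EP, hidden_size -> FSDP, intermediate_size -> TP
print(local_shape((8, 4096, 14336), ("ep", "fsdp", "tp")))  # (2, 2048, 1792)
```

With 4 × 2 × 8 = 64 devices, each device holds a [2, 2048, 1792] block, which is 1/64 of the full weight.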
class eformer.common_types.ExpertColumnWiseAlt(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])

Bases: DynamicShardingAxes

Alternative column-wise sharding using SP for sequence-related parameters. Use this if your expert weights have a sequence-related dimension.

axes: ClassVar = ['__EXPERT_PARALLEL__', ['__FULLY_SHARDED_DATA_PARALLEL__', '__SEQUENCE_PARALLEL__'], '__TENSOR_PARALLEL__']
mode: ClassVar = '__train__'
class eformer.common_types.ExpertRowWise(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])

Bases: DynamicShardingAxes

Dynamic sharding specification for row-wise sharding of expert layer weights.

For a typical expert layer weight tensor of shape [num_experts, intermediate_size, hidden_size]:

  • Dimension 0 (num_experts): shard across EP (expert parallel)
  • Dimension 1 (intermediate_size): shard across TP (tensor parallel, row-wise)
  • Dimension 2 (hidden_size): shard across FSDP (parameter sharding)

DP is used for the batch dimension of activations, and SP for the sequence-length dimension.

axes: ClassVar = ['__EXPERT_PARALLEL__', '__TENSOR_PARALLEL__', '__FULLY_SHARDED_DATA_PARALLEL__']
mode: ClassVar = '__train__'
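ExpertColumnWise shards the intermediate dimension over TP as the last axis, and ExpertRowWise shards it over TP as the first non-expert axis. A common reason for pairing these two layouts (Megatron-style tensor parallelism) is that the up-projection output stays sharded on the intermediate dimension and feeds the down-projection directly, so only the final result needs a sum across TP shards. The pure-Python sketch below checks that identity for a single expert with tiny made-up matrices and TP=2; the element-wise activation between the two matmuls is omitted because it does not change the argument.

```python
def matmul(a, b):
    """Plain nested-list matrix product."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add(a, b):
    """Element-wise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(r1, r2)] for r1, r2 in zip(a, b)]


# One expert, tiny sizes: tokens=2, hidden=3, intermediate=4, TP=2 shards.
x = [[1, 2, 3], [4, 5, 6]]                              # [tokens, hidden]
w_up = [[1, 0, 2, 1], [0, 1, 1, 0], [1, 1, 0, 2]]       # [hidden, intermediate]
w_down = [[1, 0, 1], [0, 1, 0], [2, 0, 1], [1, 1, 1]]   # [intermediate, hidden]
tp = 2
chunk = len(w_up[0]) // tp

partials = []
for shard in range(tp):
    cols = slice(shard * chunk, (shard + 1) * chunk)
    up_shard = [row[cols] for row in w_up]   # column-wise: split intermediate (last dim)
    down_shard = w_down[cols]                # row-wise: split intermediate (first dim)
    h = matmul(x, up_shard)                  # local; no communication needed
    partials.append(matmul(h, down_shard))   # partial sum over this shard's slice

y_sharded = partials[0]
for p in partials[1:]:
    y_sharded = add(y_sharded, p)            # the cross-TP reduction (an all-reduce)

assert y_sharded == matmul(matmul(x, w_up), w_down)
print(y_sharded)
```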
class eformer.common_types.ExpertRowWiseAlt(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])

Bases: DynamicShardingAxes

Alternative row-wise sharding using SP for sequence-related parameters. Use this if your expert weights have a sequence-related dimension.

axes: ClassVar = ['__EXPERT_PARALLEL__', '__TENSOR_PARALLEL__', ['__FULLY_SHARDED_DATA_PARALLEL__', '__SEQUENCE_PARALLEL__']]
mode: ClassVar = '__train__'
class eformer.common_types.ExpertTensorParallel(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])

Bases: DynamicShardingAxes

Expert Tensor Parallelism (EPxTP) sharding axes.

axes: ClassVar = ['__TENSOR_PARALLEL__', '_', '_']
mode: ClassVar = '__train__'
eformer.common_types.GENERATION_MODES = {'__autoregressive__', '__insert__'}

Set of runtime modes considered as generation modes.

class eformer.common_types.HiddenStateSharding(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])

Bases: DynamicShardingAxes

Dynamic sharding specification for hidden states.

axes: ClassVar = ['__BATCH__', '__QUERY_LENGTH__', '__EMBED__']
mode: ClassVar = 1
eformer.common_types.KV_HEAD = '__KV_HEAD__'

Semantic axis name for the key/value attention head dimension.

eformer.common_types.KV_HEAD_DIM = '__KV_HEAD_DIM__'

Semantic axis name for the dimension within each key/value attention head.

eformer.common_types.KV_LENGTH = '__KV_LENGTH__'

Semantic axis name for the key/value sequence length dimension.

eformer.common_types.LENGTH = '__LENGTH__'

Semantic axis name for the sequence length dimension (general).

eformer.common_types.MLP_INTERMEDIATE = '__MLP_INTERMEDIATE__'

Semantic axis name for the intermediate dimension in MLP layers.

eformer.common_types.MODE_DECODE = '__autoregressive__'

Runtime mode for autoregressive decoding.

eformer.common_types.MODE_INSERT = '__insert__'

Runtime mode for inserting into the cache.

eformer.common_types.MODE_PREFILL = '__prefill__'

Runtime mode for prefilling the cache.

eformer.common_types.MODE_TRAIN = '__train__'

Runtime mode for training.
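The mode constants are plain strings, so checking whether a run is in a generation phase reduces to set membership in GENERATION_MODES. The constants below are copied from this page so the snippet runs without eformer installed; is_generation is a hypothetical helper, not part of the library.

```python
MODE_DECODE = "__autoregressive__"
MODE_INSERT = "__insert__"
MODE_PREFILL = "__prefill__"
MODE_TRAIN = "__train__"
GENERATION_MODES = {MODE_DECODE, MODE_INSERT}


def is_generation(mode: str) -> bool:
    """Hypothetical helper: True for modes listed in GENERATION_MODES."""
    return mode in GENERATION_MODES


for mode in (MODE_TRAIN, MODE_PREFILL, MODE_DECODE, MODE_INSERT):
    print(mode, is_generation(mode))
```

Under this definition prefill is not a generation mode, even though it runs at inference time.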

class eformer.common_types.Mesh(devices: numpy.ndarray | collections.abc.Sequence[jaxlib._jax.Device], axis_names: str | collections.abc.Sequence[Any], axis_types: tuple[jax._src.mesh.AxisType, ...] | None = None)

Bases: BaseMesh, ContextDecorator

Declare the hardware resources available in the scope of this manager.

See the "Distributed arrays and automatic parallelization" and "Explicit Sharding" tutorials in the JAX documentation.

Parameters
  • devices – A NumPy ndarray object containing JAX device objects (as obtained e.g. from jax.devices()).

  • axis_names – A sequence of resource axis names to be assigned to the dimensions of the devices argument. Its length should match the rank of devices.

  • axis_types – An optional tuple of jax.sharding.AxisType entries, one per entry in axis_names. See the Explicit Sharding tutorial for more information.

Examples

>>> import jax
>>> import numpy as np
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P, NamedSharding
>>> # Declare a 2D mesh with axes `x` and `y` (assumes exactly 8 devices).
>>> devices = np.array(jax.devices()).reshape(4, 2)
>>> mesh = Mesh(devices, ('x', 'y'))
>>> inp = np.arange(16).reshape(8, 2)
>>> arr = jax.device_put(inp, NamedSharding(mesh, P('x', 'y')))
>>> out = jax.jit(lambda x: x * 2)(arr)
>>> assert out.sharding == NamedSharding(mesh, P('x', 'y'))
property abstract_mesh
axis_names: tuple[Any, ...]
property axis_sizes: tuple[int, ...]
property device_ids
devices: ndarray
property empty
property is_multi_process
property is_scalar
property local_devices
property local_mesh
property shape
property shape_tuple
size: int
update(devices=None, axis_names=None, axis_types=None)
eformer.common_types.PRNGKey

Type alias for JAX PRNG keys.

eformer.common_types.QUERY_LENGTH = '__QUERY_LENGTH__'

Semantic axis name for the query sequence length dimension.

eformer.common_types.RUNTIME_MODE_TYPES

Type alias for the possible runtime modes.

alias of Literal['__autoregressive__', '__prefill__', '__train__', '__insert__']
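Literal types are only checked statically, so a runtime guard has to read the allowed values back out of the alias. typing.get_args does this; in the sketch below the alias is re-declared locally so the snippet is self-contained, and check_mode is a hypothetical helper.

```python
from typing import Literal, get_args

# Re-declared from the alias documented above.
RUNTIME_MODE_TYPES = Literal["__autoregressive__", "__prefill__", "__train__", "__insert__"]
VALID_MODES = frozenset(get_args(RUNTIME_MODE_TYPES))


def check_mode(mode: str) -> RUNTIME_MODE_TYPES:
    """Reject strings that are not members of the Literal at runtime."""
    if mode not in VALID_MODES:
        raise ValueError(f"unknown runtime mode {mode!r}; expected one of {sorted(VALID_MODES)}")
    return mode


print(check_mode("__prefill__"))  # __prefill__
```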

class eformer.common_types.Replicated(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])

Bases: DynamicShardingAxes

Dynamic sharding specification for replicated (unsharded) tensors.

axes: ClassVar = ['_']
mode: ClassVar = '__train__'
class eformer.common_types.RowWise(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])

Bases: DynamicShardingAxes

Dynamic sharding specification for row-wise sharding.

axes: ClassVar = ['__TENSOR_PARALLEL__', ['__FULLY_SHARDED_DATA_PARALLEL__', '__SEQUENCE_PARALLEL__']]
mode: ClassVar = '__train__'
class eformer.common_types.SColumnWise(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])

Bases: DynamicShardingAxes

Dynamic sharding specification for column-wise sharding.

axes: ClassVar = [['__FULLY_SHARDED_DATA_PARALLEL__', '__SEQUENCE_PARALLEL__']]
mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int] = '__train__'
class eformer.common_types.SRowWise(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])

Bases: DynamicShardingAxes

Dynamic sharding specification for row-wise sharding.

axes: ClassVar = ['__TENSOR_PARALLEL__']
mode: ClassVar = '__train__'
eformer.common_types.Shape

Type alias for array shapes.

alias of Sequence[int]

class eformer.common_types.UnifiedExpertColumnWise(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])

Bases: DynamicShardingAxes

Unified column-wise sharding using SP for sequence-related parameters. Use this if your expert weights have a sequence-related dimension.

axes: ClassVar = [['__FULLY_SHARDED_DATA_PARALLEL__', '__SEQUENCE_PARALLEL__', '__EXPERT_PARALLEL__'], '_', '__TENSOR_PARALLEL__']
mode: ClassVar = '__train__'
class eformer.common_types.UnifiedExpertRowWise(axes: Sequence[str | None], mode: Union[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], int])

Bases: DynamicShardingAxes

Unified row-wise sharding using SP for sequence-related parameters. Use this if your expert weights have a sequence-related dimension.

axes: ClassVar = [['__FULLY_SHARDED_DATA_PARALLEL__', '__SEQUENCE_PARALLEL__', '__EXPERT_PARALLEL__'], '__TENSOR_PARALLEL__', '_']
mode: ClassVar = '__train__'
eformer.common_types.VOCAB = '__VOCAB__'

Semantic axis name for the vocabulary dimension.