eformer.optimizers._stage_local

eformer.optimizers._stage_local.StageLocalApplyFn

Type alias for a stage-local apply function.

A callable with keyword-only arguments params, grads, opt_state, learning_rate_fn (optional), and delete_grads (optional) that returns a tuple of (new_params, new_opt_state).

alias of Callable[[…], tuple[Union[Array, ndarray, bool, number, bool, int, float, complex, Iterable[ArrayTree], Mapping[Any, ArrayTree]], Union[Array, ndarray, bool, number, bool, int, float, complex, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]]
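The following is a minimal sketch of a function that follows this calling convention: a leafwise SGD step. It assumes that plain nested dicts of Python floats stand in for JAX pytrees and that opt_state is a hypothetical {"count": int} dict. The kernel name, the state layout, and the fallback learning rate are illustrative and are not taken from eformer.

```python
from __future__ import annotations

from typing import Any, Callable, Mapping, Optional


def _map_leaves(fn: Callable[[float, Any], float], params: Mapping, grads: Any) -> dict:
    """Apply fn(param, grad) leaf by leaf over nested dicts with matching keys."""
    out = {}
    for key, p in params.items():
        # A missing or None gradient subtree is passed down as None.
        g = grads.get(key) if isinstance(grads, Mapping) else None
        if isinstance(p, Mapping):
            out[key] = _map_leaves(fn, p, g)
        else:
            out[key] = fn(p, g)
    return out


def sgd_stage_local_apply(
    *,
    params: Mapping,
    grads: Mapping,
    opt_state: dict,
    learning_rate_fn: Optional[Callable[[int], float]] = None,
    delete_grads: bool = False,
) -> tuple[dict, dict]:
    """Hypothetical StageLocalApplyFn: plain SGD applied to each leaf independently."""
    step = opt_state["count"]
    # The constant fallback rate is an assumption of this sketch.
    lr = learning_rate_fn(step) if learning_rate_fn is not None else 1e-3

    def leaf_update(p: float, g: Optional[float]) -> float:
        if g is None:  # None gradients count as zero, so the leaf is unchanged
            return p
        return p - lr * g

    new_params = _map_leaves(leaf_update, params, grads)
    # delete_grads is accepted but ignored: Python floats own no device buffers.
    return new_params, {"count": step + 1}
```

Every argument is keyword-only, so a positional call such as sgd_stage_local_apply(params, grads, state) raises TypeError. This matches the keyword-only contract described above.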

class eformer.optimizers._stage_local.StageLocalGradientTransformation(init: TransformInitFn, update: TransformUpdateFn)

Bases: GradientTransformation

Optax transformation with an explicit pipeline-parallel (PP) stage-local apply API.

This class wraps a standard optax.GradientTransformation and adds an apply_gradients_stage_local() entry point that pipeline-parallel training loops can call when parameters and gradients are partitioned per stage. The normal update() path is preserved, so the same object also works with regular Optax code paths.

The stage-local path relies on a custom _eformer_stage_local_apply attribute attached to the internal update callable. If that attribute is missing, calling apply_gradients_stage_local() raises a clear NotImplementedError.
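The lookup can be sketched in plain Python as shown below. MiniTransformation is a hypothetical stand-in for optax.GradientTransformation, which is a NamedTuple of init and update. The kernel is a toy example. Only the attribute name _eformer_stage_local_apply and the NotImplementedError behavior come from this page.

```python
from typing import Any, Callable, NamedTuple


class MiniTransformation(NamedTuple):
    """Hypothetical stand-in for optax.GradientTransformation."""

    init: Callable[..., Any]
    update: Callable[..., Any]

    def apply_gradients_stage_local(self, **kwargs: Any) -> Any:
        # The stage-local kernel travels as an attribute of the update callable.
        apply_fn = getattr(self.update, "_eformer_stage_local_apply", None)
        if apply_fn is None:
            raise NotImplementedError(
                "This transformation was not built with a stage-local apply function."
            )
        return apply_fn(**kwargs)


def make_update() -> Callable[..., Any]:
    """Return a fresh identity update so attributes are never shared between transforms."""

    def update(updates, state, params=None):
        return updates, state

    return update


def toy_kernel(*, params, grads, opt_state, learning_rate_fn=None, delete_grads=False):
    # Subtract each gradient from its parameter; a flat dict keeps the example short.
    return {k: params[k] - grads[k] for k in params}, opt_state


with_kernel = MiniTransformation(init=lambda params: (), update=make_update())
with_kernel.update._eformer_stage_local_apply = toy_kernel

without_kernel = MiniTransformation(init=lambda params: (), update=make_update())
```

The kernel is attached to update rather than stored as a new field. As a result, the object remains an ordinary (init, update) pair for code that never calls the stage-local method.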

apply_gradients_stage_local(*, params: optax.Params, grads: optax.Updates, opt_state: optax.OptState, learning_rate_fn: optax.Schedule | None = None, delete_grads: bool = False) → tuple[optax.Params, optax.OptState]

Apply gradients leaf by leaf, without whole-tree operations that would require cross-stage communication.

This method dispatches to the stage-local kernel that was attached at construction time. It is intended for scheduled pipeline-parallel training, where each stage has only a local view of parameters and gradients.

Parameters
  • params – Current model parameters (pytree). Must be placed on the same devices as the stage-local gradient buffers.

  • grads – Gradient updates (pytree) with the same structure as params. None leaves are interpreted as zero gradients.

  • opt_state – Optimizer state produced by init() or previous calls to update() / apply_gradients_stage_local().

  • learning_rate_fn – Optional schedule override. When None, the schedule stored in StageLocalOptimizerMetadata.scheduler is used.

  • delete_grads – If True, a best-effort deletion of gradient arrays is performed after the update to reduce peak memory usage.

Returns

A tuple (new_params, new_opt_state) with updated values.

Raises

NotImplementedError – If the underlying transformation was not built with a stage-local apply function.

class eformer.optimizers._stage_local.StageLocalOptimizerMetadata(scheduler: Callable[[Union[Array, ndarray, bool, number, bool, int, float, complex]], Union[Array, ndarray, bool, number, bool, int, float, complex]], weight_decay: float = 0.0, weight_decay_mask: Any | None = None, gradient_accumulation_steps: int = 1, clip_grad: float | None = None, adamw_b1: float | None = None, adamw_b2: float | None = None, adamw_eps: float | None = None, adamw_eps_root: float | None = None, adamw_mu_dtype: Any | None = None, optimizer_config: Any | None = None, extra_kwargs: Mapping[str, Any] = <factory>)

Bases: object

Immutable metadata container for pipeline-parallel stage-local optimizer updates.

This dataclass captures all hyperparameters and configuration needed to reconstruct and apply optimizer updates leaf-by-leaf inside a single pipeline stage without whole-tree cross-stage communication. It is used by the stage-local apply functions to correctly schedule learning rates, apply weight decay, and clip gradients while respecting per-parameter masks and accumulation settings.

scheduler#

The learning-rate schedule that was paired with the optimizer at construction time.

Type

collections.abc.Callable[[Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex]], Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex]]

weight_decay#

Global weight-decay coefficient applied externally through optax_add_scheduled_weight_decay().

Type

float

weight_decay_mask#

Optional pytree or callable mask controlling which parameters receive external weight decay. None means all parameters are decayed.

Type

Any | None
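Below is a minimal sketch of how such a mask is typically interpreted. It follows the convention of optax.add_decayed_weights, where the mask is either a pytree of booleans or a callable that builds one from the parameters. Flat dicts of floats stand in for pytrees, and the helper name is hypothetical. Whether eformer's optax_add_scheduled_weight_decay() resolves masks in exactly this way is an assumption.

```python
from typing import Callable, Mapping, Optional, Union

Mask = Union[Mapping[str, bool], Callable[[Mapping[str, float]], Mapping[str, bool]]]


def add_masked_weight_decay(
    updates: Mapping[str, float],
    params: Mapping[str, float],
    weight_decay: float,
    mask: Optional[Mask] = None,
) -> dict:
    """Add weight_decay * param to each update whose mask entry is True."""
    if mask is None:
        mask_tree = {k: True for k in params}  # None: decay every parameter
    elif callable(mask):
        mask_tree = mask(params)  # callable: build the mask from the params
    else:
        mask_tree = mask  # explicit pytree of booleans
    return {
        k: (u + weight_decay * params[k]) if mask_tree[k] else u
        for k, u in updates.items()
    }
```

A common use is to exclude biases and normalization scales. For example, passing lambda p: {k: k != "bias" for k in p} leaves the "bias" update untouched.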

gradient_accumulation_steps#

Number of micro-steps accumulated before an update is applied. Values greater than 1 are currently unsupported on stage-local paths and raise an error at runtime.

Type

int

clip_grad#

Optional global gradient-norm clipping threshold applied before the base optimizer update.

Type

float | None
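Assume that clip_grad follows the usual global-norm rule of optax.clip_by_global_norm, with threshold c = clip_grad and gradient leaves g_ℓ. The clipping step is then:

```latex
\|g\|_2 = \sqrt{\sum_{\ell} \lVert g_\ell \rVert_2^2},
\qquad
g_\ell \leftarrow g_\ell \cdot \min\!\left(1, \frac{c}{\|g\|_2}\right)
```

‖g‖₂ sums over every leaf, so an exact global norm requires the squared leaf norms from all pipeline stages. This page does not state whether the stage-local path reduces these norms across stages or clips each stage separately.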

adamw_b1#

AdamW first-moment decay (b1) when the underlying optimizer is AdamW-like. None for non-AdamW optimizers.

Type

float | None

adamw_b2#

AdamW second-moment decay (b2) when the underlying optimizer is AdamW-like. None for non-AdamW optimizers.

Type

float | None

adamw_eps#

AdamW epsilon for numerical stability. None for non-AdamW optimizers.

Type

float | None

adamw_eps_root#

AdamW epsilon applied inside the square root. None for non-AdamW optimizers.

Type

float | None

adamw_mu_dtype#

Optional dtype for the first-moment buffer. None for non-AdamW or when the default dtype is acceptable.

Type

Any | None
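The fields above correspond to the hyperparameters of Optax's Adam moment update (optax.scale_by_adam), followed by decoupled weight decay and a learning-rate step, in the same order that optax.adamw chains them. The sketch below applies one such step to a single scalar leaf. The function and state names are hypothetical. mu_dtype is omitted because plain floats have no dtype. This is not eformer's stage-local AdamW kernel.

```python
from __future__ import annotations

import math
from dataclasses import dataclass


@dataclass(frozen=True)
class AdamLeafState:
    mu: float = 0.0  # first-moment (mean) estimate
    nu: float = 0.0  # second-moment (uncentered variance) estimate
    count: int = 0  # number of steps applied so far


def adamw_leaf_step(
    param: float,
    grad: float,
    state: AdamLeafState,
    *,
    lr: float,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    weight_decay: float = 0.0,
) -> tuple[float, AdamLeafState]:
    """One AdamW step on a scalar parameter."""
    count = state.count + 1
    mu = b1 * state.mu + (1.0 - b1) * grad
    nu = b2 * state.nu + (1.0 - b2) * grad * grad
    # Bias correction compensates for the zero initialization of mu and nu.
    mu_hat = mu / (1.0 - b1**count)
    nu_hat = nu / (1.0 - b2**count)
    # eps_root sits inside the square root, eps outside it.
    update = mu_hat / (math.sqrt(nu_hat + eps_root) + eps)
    # Decoupled weight decay is added to the update, not to the gradient.
    update += weight_decay * param
    return param - lr * update, AdamLeafState(mu, nu, count)
```

On the first step, bias correction makes mu_hat equal the gradient and nu_hat equal its square. The update therefore has magnitude close to 1 regardless of the gradient's scale.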

optimizer_config#

The original optimizer-specific configuration object (e.g. AdamWConfig). This gives stage-local kernels access to extra hyperparameters that are not explicitly mirrored above.

Type

Any | None

extra_kwargs#

Additional MPMD optimizer options forwarded by the builder/factory. Optimizer-specific stage-local kernels can read this mapping for extension arguments without changing this dataclass every time.

Type

Mapping[str, Any]

adamw_b1: float | None = None
adamw_b2: float | None = None
adamw_eps: float | None = None
adamw_eps_root: float | None = None
adamw_mu_dtype: Any | None = None
clip_grad: float | None = None
extra_kwargs: Mapping[str, Any]
gradient_accumulation_steps: int = 1
optimizer_config: Any | None = None
scheduler: Callable[[Union[Array, ndarray, bool, number, bool, int, float, complex]], Union[Array, ndarray, bool, number, bool, int, float, complex]]
weight_decay: float = 0.0
weight_decay_mask: Any | None = None
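The trimmed-down stand-in below shows what "immutable" means in practice for a dataclass like this one, and how a kernel might read an option from extra_kwargs. The class name and the "nesterov" key are hypothetical. frozen=True and a dict default factory for extra_kwargs are assumptions that are consistent with the signature above but not confirmed by it.

```python
from __future__ import annotations

from dataclasses import FrozenInstanceError, dataclass, field, replace
from typing import Any, Callable, Mapping


@dataclass(frozen=True)
class MetadataSketch:
    scheduler: Callable[[int], float]
    weight_decay: float = 0.0
    weight_decay_mask: Any | None = None
    gradient_accumulation_steps: int = 1
    clip_grad: float | None = None
    # A default factory gives every instance its own empty mapping.
    extra_kwargs: Mapping[str, Any] = field(default_factory=dict)


def read_nesterov(meta: MetadataSketch) -> bool:
    """Hypothetical kernel helper: read an extension option with a safe default."""
    return bool(meta.extra_kwargs.get("nesterov", False))


base = MetadataSketch(scheduler=lambda step: 1e-3 * 0.99**step)

try:
    base.weight_decay = 0.01
except FrozenInstanceError:
    pass  # expected: frozen dataclasses reject attribute assignment

# Derive a modified copy instead of mutating in place.
decayed = replace(base, weight_decay=0.01, extra_kwargs={"nesterov": True})
```

Reading options with .get() and a default lets kernels accept new extension arguments without changing the dataclass itself. This is the purpose that extra_kwargs serves.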
eformer.optimizers._stage_local.make_stage_local_gradient_transformation(tx: optax.GradientTransformation, metadata: StageLocalOptimizerMetadata | None = None, apply_fn: StageLocalApplyFn | None = None) → StageLocalGradientTransformation

Attach an explicit stage-local MPMD apply path to an Optax transform.

This function wraps an existing optax.GradientTransformation so that it also exposes StageLocalGradientTransformation.apply_gradients_stage_local(). The normal tx.init and tx.update callables are forwarded verbatim, while the extra apply_fn is stashed as an attribute on the internal update function for later retrieval.

Optimizer builders pass their optimizer-specific stage-local kernel directly. When metadata is also provided, this helper injects that metadata and applies factory-level gradient clipping before calling the kernel.

Parameters
  • tx – Base optax transformation to wrap.

  • metadata – Hyperparameter metadata used by eFormer-provided stage-local kernels. Optional for fully custom apply_fn callables.

  • apply_fn – Explicit stage-local apply callable (a StageLocalApplyFn). It defaults to None in the signature but is required.

Returns

A StageLocalGradientTransformation that behaves like tx for standard Optax calls but also supports stage-local updates.

Raises

ValueError – If apply_fn is None.
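The construction steps can be sketched in plain Python. Everything in this sketch is hypothetical: Tx stands in for optax.GradientTransformation, Meta carries only clip_grad, and flat dicts of floats stand in for pytrees. The sketch also assumes that metadata reaches the kernel as a metadata= keyword; this detail comes from the sketch, not from eformer. Clipping uses the usual global-norm rule.

```python
import functools
import math
from typing import Any, Callable, NamedTuple, Optional


class Tx(NamedTuple):
    init: Callable[..., Any]
    update: Callable[..., Any]


class Meta(NamedTuple):
    clip_grad: Optional[float] = None


def clip_by_global_norm(grads: dict, max_norm: float) -> dict:
    """Scale all gradients by min(1, max_norm / global_norm); None counts as zero."""
    norm = math.sqrt(sum(g * g for g in grads.values() if g is not None))
    if norm <= max_norm:
        return dict(grads)
    scale = max_norm / norm
    return {k: None if g is None else g * scale for k, g in grads.items()}


def wrap_stage_local(
    tx: Tx, metadata: Optional[Meta] = None, apply_fn: Optional[Callable] = None
) -> Tx:
    if apply_fn is None:
        raise ValueError("apply_fn must be provided")

    kernel = apply_fn
    if metadata is not None:

        def kernel_with_metadata(*, grads, **kwargs):
            # Factory-level clipping runs before the optimizer-specific kernel.
            if metadata.clip_grad is not None:
                grads = clip_by_global_norm(grads, metadata.clip_grad)
            return apply_fn(grads=grads, metadata=metadata, **kwargs)

        kernel = kernel_with_metadata

    @functools.wraps(tx.update)
    def update(*args, **kwargs):
        # Standard Optax calls pass straight through to the original update.
        return tx.update(*args, **kwargs)

    # Stash the kernel on a wrapper so the caller's update function is not mutated.
    update._eformer_stage_local_apply = kernel
    return Tx(init=tx.init, update=update)
```

The sketch stashes the kernel on a thin wrapper rather than on tx.update itself. This is a design choice that keeps the original transformation reusable. The actual helper may attach the kernel differently.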

eformer.optimizers._stage_local.make_unsupported_stage_local_gradient_transformation(tx: GradientTransformation, *, optimizer_type: str, reason: str | None = None) → StageLocalGradientTransformation

Expose a clear pipeline-parallel (PP) error while preserving the normal Optax update path.

Use this helper when an optimizer builder does not override OptimizerBuilder.build_mpmd(). The returned transformation still works via the standard optax.GradientTransformation.update(), but calling StageLocalGradientTransformation.apply_gradients_stage_local() raises a descriptive NotImplementedError telling users exactly which optimizer is unsupported and how to fix it.

Parameters
  • tx – Base optax transformation to wrap.

  • optimizer_type – Registered name of the unsupported optimizer (used only for the error message).

  • reason – Optional extra detail appended to the error message.

Returns

A StageLocalGradientTransformation whose stage-local path always raises NotImplementedError.
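Below is a minimal sketch of the kind of failing kernel that such a helper could attach. The function name and the message wording are illustrative assumptions. Only the behavior comes from the description above: the kernel always raises NotImplementedError, names the optimizer, and appends the optional reason.

```python
from typing import Any, Callable, NoReturn, Optional


def make_unsupported_apply(
    optimizer_type: str, reason: Optional[str] = None
) -> Callable[..., NoReturn]:
    """Build a stage-local apply function that always fails with a descriptive error."""
    message = (
        f"Optimizer {optimizer_type!r} has no pipeline-parallel stage-local update; "
        "its builder must override OptimizerBuilder.build_mpmd()."
    )
    if reason:
        message = f"{message} {reason}"

    def unsupported_apply(**_kwargs: Any) -> NoReturn:
        # Accept any keyword arguments so the call site matches StageLocalApplyFn.
        raise NotImplementedError(message)

    return unsupported_apply
```

The message is built once, when the transformation is created, so the failure point reports the optimizer name without needing access to the builder.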