eformer.optimizers._stage_local
- eformer.optimizers._stage_local.StageLocalApplyFn
Type alias for a stage-local apply function.
A callable with keyword-only arguments params, grads, opt_state, learning_rate_fn (optional), and delete_grads (optional) that returns a tuple (new_params, new_opt_state).
Alias of
Callable[[…], tuple[Union[Array, ndarray, bool, number, bool, int, float, complex, Iterable[ArrayTree], Mapping[Any, ArrayTree]], Union[Array, ndarray, bool, number, bool, int, float, complex, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]]
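The following is a minimal sketch of a callable that satisfies this alias. It assumes plain SGD as the update rule and a dict with a "count" entry as the optimizer state; the name sgd_stage_local_apply, the fallback rate of 1e-2, and the state layout are hypothetical and not part of eFormer. It shows the keyword-only signature, None gradient leaves treated as zero gradients, and the (new_params, new_opt_state) return value.

```python
import jax
import jax.numpy as jnp


def sgd_stage_local_apply(
    *,
    params,
    grads,
    opt_state,
    learning_rate_fn=None,
    delete_grads=False,
):
    """Hypothetical StageLocalApplyFn: plain SGD, applied leaf by leaf."""
    step = opt_state["count"]
    # Assumption: fall back to a fixed rate when no schedule override is given.
    # (eFormer's kernels use StageLocalOptimizerMetadata.scheduler instead.)
    lr = learning_rate_fn(step) if learning_rate_fn is not None else 1e-2

    def update_leaf(g, p):
        # A None gradient leaf counts as a zero gradient: keep the parameter.
        return p if g is None else p - lr * g

    # Treat None as a leaf so every gradient slot lines up with its parameter.
    new_params = jax.tree_util.tree_map(
        update_leaf, grads, params, is_leaf=lambda x: x is None
    )
    if delete_grads:
        # Best-effort, eager mode only: free gradient buffers after the update.
        for g in jax.tree_util.tree_leaves(grads):
            if isinstance(g, jax.Array):
                g.delete()
    return new_params, {"count": step + 1}


# Usage: "b" has no gradient on this stage, so it is left unchanged.
params = {"w": jnp.ones(3), "b": jnp.zeros(2)}
grads = {"w": jnp.ones(3), "b": None}
new_params, new_state = sgd_stage_local_apply(
    params=params, grads=grads, opt_state={"count": jnp.asarray(0)}
)
```

Because all arguments are keyword-only, a training loop cannot accidentally swap params and grads positionally.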
- class eformer.optimizers._stage_local.StageLocalGradientTransformation(init: TransformInitFn, update: TransformUpdateFn)
Bases: GradientTransformation
Optax transformation with an explicit pipeline-parallel (PP) stage-local apply API.
This class wraps a standard optax.GradientTransformation and attaches an additional apply_gradients_stage_local() entry point that pipeline-parallel training loops can call when parameters and gradients are partitioned per stage. The normal update() path is preserved, so the same object works unchanged in regular Optax code paths.
The stage-local path relies on a custom _eformer_stage_local_apply attribute attached to the internal update callable. If that attribute is missing, calling apply_gradients_stage_local() raises a clear NotImplementedError.
- apply_gradients_stage_local(*, params: optax.Params, grads: optax.Updates, opt_state: optax.OptState, learning_rate_fn: optax.Schedule | None = None, delete_grads: bool = False) → tuple[optax.Params, optax.OptState]
Apply gradients leaf by leaf, without whole-tree computations that span pipeline stages.
This method dispatches to the stage-local kernel that was attached at construction time. It is intended for scheduled pipeline-parallel training where each stage only has a local view of parameters and gradients.
- Parameters
params – Current model parameters (pytree). Must be partitioned across the same devices as the stage-local gradient buffers.
grads – Gradient updates (pytree) with the same structure as params. None leaves are interpreted as zero gradients.
opt_state – Optimizer state produced by init() or by previous calls to update() / apply_gradients_stage_local().
learning_rate_fn – Optional schedule override. When None, the schedule stored in StageLocalOptimizerMetadata.scheduler is used.
delete_grads – If True, gradient arrays are deleted on a best-effort basis after the update to reduce peak memory usage.
- Returns
A tuple (new_params, new_opt_state) with the updated values.
- Raises
NotImplementedError – If the underlying transformation was not built with a stage-local apply function.
- class eformer.optimizers._stage_local.StageLocalOptimizerMetadata(scheduler: Callable[[Union[Array, ndarray, bool, number, bool, int, float, complex]], Union[Array, ndarray, bool, number, bool, int, float, complex]], weight_decay: float = 0.0, weight_decay_mask: Any | None = None, gradient_accumulation_steps: int = 1, clip_grad: float | None = None, adamw_b1: float | None = None, adamw_b2: float | None = None, adamw_eps: float | None = None, adamw_eps_root: float | None = None, adamw_mu_dtype: Any | None = None, optimizer_config: Any | None = None, extra_kwargs: Mapping[str, Any] = <factory>)
Bases: object
Immutable metadata container for pipeline-parallel stage-local optimizer updates.
This dataclass captures all hyperparameters and configuration needed to reconstruct and apply optimizer updates leaf-by-leaf inside a single pipeline stage without whole-tree cross-stage communication. It is used by the stage-local apply functions to correctly schedule learning rates, apply weight decay, and clip gradients while respecting per-parameter masks and accumulation settings.
- scheduler
The learning-rate schedule that was paired with the optimizer at construction time.
- Type
collections.abc.Callable[[Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex]], Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex]]
- weight_decay
Global weight-decay coefficient applied externally through optax_add_scheduled_weight_decay().
- Type
float
- weight_decay_mask
Optional pytree or callable mask controlling which parameters receive external weight decay. None means all parameters are decayed.
- Type
Any | None
- gradient_accumulation_steps
Number of micro-steps accumulated before applying an update. Values > 1 are currently unsupported on stage-local paths and raise an error at runtime.
- Type
int
- clip_grad
Optional global gradient-norm clipping threshold applied before the base optimizer update.
- Type
float | None
- adamw_b1
AdamW first-moment decay rate (b1) when the underlying optimizer is AdamW-like; None for non-AdamW optimizers.
- Type
float | None
- adamw_b2
AdamW second-moment decay rate (b2) when the underlying optimizer is AdamW-like; None for non-AdamW optimizers.
- Type
float | None
- adamw_eps
AdamW epsilon for numerical stability; None for non-AdamW optimizers.
- Type
float | None
- adamw_eps_root
AdamW epsilon applied inside the square root; None for non-AdamW optimizers.
- Type
float | None
- adamw_mu_dtype
Optional dtype for the first-moment buffer; None for non-AdamW optimizers or when the default dtype is acceptable.
- Type
Any | None
- optimizer_config
The original optimizer-specific configuration object (e.g. AdamWConfig). This gives stage-local kernels access to extra hyperparameters that are not explicitly mirrored above.
- Type
Any | None
- extra_kwargs
Additional MPMD optimizer options forwarded by the builder/factory. Optimizer-specific stage-local kernels can read extension arguments from this mapping, so new options do not require changes to this dataclass.
- Type
Mapping[str, Any]
- adamw_b1: float | None = None
- adamw_b2: float | None = None
- adamw_eps: float | None = None
- adamw_eps_root: float | None = None
- adamw_mu_dtype: Any | None = None
- clip_grad: float | None = None
- extra_kwargs: Mapping[str, Any]
- gradient_accumulation_steps: int = 1
- optimizer_config: Any | None = None
- scheduler: Callable[[Union[Array, ndarray, bool, number, bool, int, float, complex]], Union[Array, ndarray, bool, number, bool, int, float, complex]]
- weight_decay: float = 0.0
- weight_decay_mask: Any | None = None
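To illustrate how a stage-local kernel might consume these fields, here is a sketch of a leaf-wise AdamW step. MiniStageLocalMetadata is a hypothetical, trimmed-down mirror of the dataclass above (it is not imported from eFormer), and the update rule assumes standard Optax-style AdamW math with decoupled weight decay scaled by the learning rate. eFormer's own kernels, including how they handle clip_grad and adamw_mu_dtype, may differ. The sketch also assumes grads contains no None leaves.

```python
from __future__ import annotations

import dataclasses
from typing import Any, Callable, Mapping

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class MiniStageLocalMetadata:
    """Hypothetical, trimmed-down mirror of some StageLocalOptimizerMetadata fields."""

    scheduler: Callable[[Any], Any]
    weight_decay: float = 0.0
    weight_decay_mask: Any | None = None
    adamw_b1: float | None = None
    adamw_b2: float | None = None
    adamw_eps: float | None = None
    adamw_eps_root: float | None = None
    # A mutable default needs a factory; Sphinx renders it as "<factory>".
    extra_kwargs: Mapping[str, Any] = dataclasses.field(default_factory=dict)


def adamw_leaf_step(p, g, mu, nu, count, decay, meta):
    """One AdamW step on a single leaf, using Optax-style AdamW math (assumed)."""
    b1, b2 = meta.adamw_b1, meta.adamw_b2
    mu = b1 * mu + (1.0 - b1) * g
    nu = b2 * nu + (1.0 - b2) * jnp.square(g)
    t = count + 1  # bias correction uses the incremented step count
    mu_hat = mu / (1.0 - b1**t)
    nu_hat = nu / (1.0 - b2**t)
    update = mu_hat / (jnp.sqrt(nu_hat + meta.adamw_eps_root) + meta.adamw_eps)
    if decay:
        # Decoupled weight decay, scaled by the learning rate below.
        update = update + meta.weight_decay * p
    return p - meta.scheduler(count) * update, mu, nu


def stage_local_adamw(params, grads, mu, nu, count, meta):
    """Apply adamw_leaf_step to each local leaf; no reduction crosses leaves."""
    if meta.adamw_b1 is None:
        raise ValueError("metadata does not describe an AdamW-like optimizer")
    p_leaves, treedef = jax.tree_util.tree_flatten(params)
    mask = meta.weight_decay_mask
    if mask is None:
        # None means every parameter is decayed.
        mask_leaves = [True] * len(p_leaves)
    else:
        # The mask may be a pytree of bools or a callable producing one.
        mask = mask(params) if callable(mask) else mask
        mask_leaves = treedef.flatten_up_to(mask)
    results = [
        adamw_leaf_step(p, g, m, v, count, d, meta)
        for p, g, m, v, d in zip(
            p_leaves,
            treedef.flatten_up_to(grads),
            treedef.flatten_up_to(mu),
            treedef.flatten_up_to(nu),
            mask_leaves,
        )
    ]
    new_p, new_mu, new_nu = (treedef.unflatten(list(xs)) for xs in zip(*results))
    return new_p, new_mu, new_nu, count + 1
```

Each leaf is updated from its own parameter, gradient, and moment buffers only, which is what lets a pipeline stage run the step without data from other stages.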
- eformer.optimizers._stage_local.make_stage_local_gradient_transformation(tx: optax.GradientTransformation, metadata: StageLocalOptimizerMetadata | None = None, apply_fn: StageLocalApplyFn | None = None) → StageLocalGradientTransformation
Attach an explicit stage-local MPMD apply path to an Optax transform.
This function wraps an existing optax.GradientTransformation so that it also exposes StageLocalGradientTransformation.apply_gradients_stage_local(). The normal tx.init and tx.update callables are forwarded verbatim, while the extra apply_fn is stored as an attribute on the internal update function for later retrieval.
Optimizer builders pass their optimizer-specific stage-local kernel directly. When metadata is also provided, this helper injects that metadata and applies factory-level gradient clipping before calling the kernel.
- Parameters
tx – Base optax transformation to wrap.
metadata – Hyperparameter metadata used by eFormer-provided stage-local kernels. Optional for fully custom apply_fn callables.
apply_fn – Explicit stage-local apply callable.
- Returns
A StageLocalGradientTransformation that behaves like tx for standard Optax calls but also supports stage-local updates.
- Raises
ValueError – If apply_fn is None.
- eformer.optimizers._stage_local.make_unsupported_stage_local_gradient_transformation(tx: GradientTransformation, *, optimizer_type: str, reason: str | None = None) → StageLocalGradientTransformation
Expose a clear pipeline-parallel (PP) error while preserving the normal Optax update path.
Use this helper when an optimizer builder does not override OptimizerBuilder.build_mpmd(). The returned transformation still works through the standard optax.GradientTransformation.update(), but calling StageLocalGradientTransformation.apply_gradients_stage_local() raises a descriptive NotImplementedError that names the unsupported optimizer and explains how to fix the problem.
- Parameters
tx – Base optax transformation to wrap.
optimizer_type – Registered name of the unsupported optimizer (used only for the error message).
reason – Optional extra detail appended to the error message.
- Returns
A StageLocalGradientTransformation whose stage-local path always raises NotImplementedError.