eformer.optimizers._base

class eformer.optimizers._base.OptimizerBuilder(config: Any)

Bases: ABC

Abstract base class for optimizer builders.

Optimizer builders encapsulate the configuration and construction logic for creating optax GradientTransformation objects.

config

Optimizer-specific configuration object.

Type

Any

build(scheduler)

Creates the base optimizer transformation.

validate()

Optional validation hook called before building.

abstract build(scheduler: optax.Schedule) → optax.GradientTransformation

Build the base optimizer transformation.

Parameters

scheduler – Learning rate schedule to use.

Returns

The optimizer transformation.

Return type

optax.GradientTransformation

build_mpmd(scheduler: optax.Schedule, *, optimizer: optax.GradientTransformation, **tx_kwargs: Any) → optax.GradientTransformation

Build the MPMD/pipeline-parallel optimizer transformation.

Registered optimizers can override this hook to expose an explicit stage-local update API for scheduled pipeline-parallel training while preserving the normal build() path for regular Optax use.

When overridden, this method should return a StageLocalGradientTransformation (or any optax.GradientTransformation whose update callable carries a _eformer_stage_local_apply attribute) so that the factory can dispatch stage-local gradient applications during pipeline-parallel (PP) training.

Parameters
  • scheduler – Learning rate schedule paired with the optimizer.

  • optimizer – The fully assembled optimizer chain (clip, base, weight decay, multi-step) produced by the factory.

  • **tx_kwargs – Factory-level transformation options such as weight_decay, weight_decay_mask, gradient_accumulation_steps, and clip_grad.

Returns

An optax.GradientTransformation that supports the standard update path and, when appropriate, the stage-local apply_gradients_stage_local path.

Raises

NotImplementedError – Raised by the default implementation, indicating that the optimizer does not yet provide PP stage-local semantics.

config: Any

validate() → None

Optional validation hook called before building the optimizer.

Raises

ValueError – If the configuration is invalid.
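The following sketch shows a validate() override that checks the config and raises ValueError before build() is called. AdamWLikeConfig and AdamWLikeBuilder are hypothetical and use only the standard library. The checked ranges are ordinary AdamW constraints chosen for illustration, not rules taken from eformer.

```python
from dataclasses import dataclass


@dataclass
class AdamWLikeConfig:
    """Hypothetical config with a few AdamW-style hyperparameters."""

    b1: float = 0.9
    b2: float = 0.999
    eps: float = 1e-8
    weight_decay: float = 0.01


class AdamWLikeBuilder:
    """Stand-in for an OptimizerBuilder subclass; only validate() is shown."""

    def __init__(self, config: AdamWLikeConfig):
        self.config = config

    def validate(self) -> None:
        c = self.config
        # Exponential decay rates must lie in [0, 1).
        for name in ("b1", "b2"):
            value = getattr(c, name)
            if not 0.0 <= value < 1.0:
                raise ValueError(f"{name} must be in [0, 1), got {value}")
        if c.eps <= 0.0:
            raise ValueError(f"eps must be positive, got {c.eps}")
        if c.weight_decay < 0.0:
            raise ValueError(f"weight_decay must be non-negative, got {c.weight_decay}")


AdamWLikeBuilder(AdamWLikeConfig()).validate()  # valid config: returns None

error_message = None
try:
    AdamWLikeBuilder(AdamWLikeConfig(b2=1.0)).validate()
except ValueError as err:
    error_message = str(err)
print(error_message)  # b2 must be in [0, 1), got 1.0
```

Failing in validate() surfaces a bad config with a readable message instead of a NaN or shape error deep inside the optax chain.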

class eformer.optimizers._base.SchedulerBuilder(config: SchedulerConfig)

Bases: ABC

Abstract base class for scheduler builders.

Scheduler builders encapsulate the configuration and construction logic for creating optax Schedule objects.

config

Scheduler configuration object.

Type

eformer.optimizers._config.SchedulerConfig

build()

Creates the learning rate schedule.

abstract build() → optax.Schedule

Build the learning rate schedule.

Returns

The learning rate schedule.

Return type

optax.Schedule

config: SchedulerConfig

eformer.optimizers._base.register_optimizer(name: str) → Callable[[type[OptimizerBuilder]], type[OptimizerBuilder]]

Decorator to register an optimizer builder class.

Parameters

name – Name to register the optimizer under.

Returns

Decorator function that registers the class.

Example

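    from dataclasses import dataclass

    import optax

    from eformer.optimizers._base import OptimizerBuilder, register_optimizer

    @register_optimizer("adamw")
    @dataclass
    class AdamWOptimizer(OptimizerBuilder):
        config: AdamWConfig  # optimizer-specific config class

        def build(self, scheduler):
            return optax.adamw(learning_rate=scheduler, ...)  # remaining arguments elided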

eformer.optimizers._base.register_scheduler(name: str) → Callable[[type[SchedulerBuilder]], type[SchedulerBuilder]]

Decorator to register a scheduler builder class.

Parameters

name – Name to register the scheduler under.

Returns

Decorator function that registers the class.

Example

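    from dataclasses import dataclass

    import optax

    from eformer.optimizers._base import SchedulerBuilder, register_scheduler
    from eformer.optimizers._config import SchedulerConfig

    @register_scheduler("cosine")
    @dataclass
    class CosineSchedulerBuilder(SchedulerBuilder):
        config: SchedulerConfig

        def build(self):
            return optax.cosine_decay_schedule(...)  # arguments elided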