eformer.optimizers._base
- class eformer.optimizers._base.OptimizerBuilder(config: Any)
Bases: ABC
Abstract base class for optimizer builders.
Optimizer builders encapsulate the configuration and construction logic for creating optax GradientTransformation objects.
- config
Optimizer-specific configuration object.
- Type
Any
- abstract build(scheduler: optax.Schedule) → optax.GradientTransformation
Build the base optimizer transformation.
- Parameters
scheduler – Learning rate schedule to use.
- Returns
The optimizer transformation.
- Return type
optax.GradientTransformation
- build_mpmd(scheduler: optax.Schedule, *, optimizer: optax.GradientTransformation, **tx_kwargs: Any) → optax.GradientTransformation
Build the MPMD/pipeline-parallel optimizer transformation.
Registered optimizers can override this hook to expose an explicit stage-local update API for scheduled pipeline-parallel training while preserving the normal build() path for regular Optax use.
When overridden, this method should return a StageLocalGradientTransformation (or any optax.GradientTransformation whose update callable carries a _eformer_stage_local_apply attribute) so that the factory can dispatch stage-local gradient applications during pipeline-parallel (PP) training.
- Parameters
scheduler – Learning rate schedule paired with the optimizer.
optimizer – The fully assembled optimizer chain (clip, base, weight decay, multi-step) produced by the factory.
**tx_kwargs – Factory-level transformation options such as weight_decay, weight_decay_mask, gradient_accumulation_steps, and clip_grad.
- Returns
An optax.GradientTransformation that supports the standard update path and, when appropriate, the stage-local apply_gradients_stage_local path.
- Raises
NotImplementedError – Raised by the default implementation, indicating that the optimizer does not yet provide PP stage-local semantics.
- config: Any
- class eformer.optimizers._base.SchedulerBuilder(config: SchedulerConfig)
Bases: ABC
Abstract base class for scheduler builders.
Scheduler builders encapsulate the configuration and construction logic for creating optax Schedule objects.
- config
Scheduler configuration object.
- abstract build() → optax.Schedule
Build the learning rate schedule.
- Returns
The learning rate schedule.
- Return type
optax.Schedule
- config: SchedulerConfig
- eformer.optimizers._base.register_optimizer(name: str) → Callable[[type[eformer.optimizers._base.OptimizerBuilder]], type[eformer.optimizers._base.OptimizerBuilder]]
Decorator to register an optimizer builder class.
- Parameters
name – Name to register the optimizer under.
- Returns
Decorator function that registers the class.
Example

    @register_optimizer("adamw")
    @dataclass
    class AdamWOptimizer(OptimizerBuilder):
        config: AdamWConfig

        def build(self, scheduler):
            return optax.adamw(learning_rate=scheduler, …)
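For readers unfamiliar with the pattern, a name-keyed registry decorator can be sketched in a few lines. This illustrates the general technique only: register_optimizer_sketch, the _OPTIMIZERS dict, and the get_optimizer_builder lookup are hypothetical and do not describe how eformer stores or resolves registered builders. The sketch returns the class it receives, which fits the documented type-to-type signature and keeps the decorated class usable under its own name.

```python
from typing import Callable, TypeVar

T = TypeVar("T", bound=type)

# Hypothetical storage; eformer's internal registry is not documented here.
_OPTIMIZERS: dict[str, type] = {}


def register_optimizer_sketch(name: str) -> Callable[[T], T]:
    """Return a decorator that records the class under `name` and hands it back."""

    def decorator(cls: T) -> T:
        _OPTIMIZERS[name] = cls
        return cls  # returning the class keeps the decorated name usable

    return decorator


def get_optimizer_builder(name: str) -> type:
    """Hypothetical lookup helper a factory could use to resolve a name."""
    try:
        return _OPTIMIZERS[name]
    except KeyError:
        known = ", ".join(sorted(_OPTIMIZERS))
        raise KeyError(f"unknown optimizer {name!r}; registered: {known}") from None


@register_optimizer_sketch("adamw")
class AdamWOptimizer:
    def __init__(self, config):
        self.config = config
```

A factory built this way can construct builders from configuration strings, for example get_optimizer_builder("adamw")(config).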
- eformer.optimizers._base.register_scheduler(name: str) → Callable[[type[eformer.optimizers._base.SchedulerBuilder]], type[eformer.optimizers._base.SchedulerBuilder]]
Decorator to register a scheduler builder class.
- Parameters
name – Name to register the scheduler under.
- Returns
Decorator function that registers the class.
Example

    @register_scheduler("cosine")
    @dataclass
    class CosineSchedulerBuilder(SchedulerBuilder):
        config: SchedulerConfig

        def build(self):
            return optax.cosine_decay_schedule(…)