# Source code for eformer.optimizers._base

# Copyright 2026 The EasyDeL/eFormer Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import dataclasses
import typing as tp
from abc import ABC, abstractmethod

import optax

from ._config import SchedulerConfig


@dataclasses.dataclass
class OptimizerBuilder(ABC):
    """Base interface for objects that construct optax optimizers.

    A builder bundles an optimizer's configuration together with the logic
    that turns it into an ``optax.GradientTransformation``.

    Attributes:
        config: Optimizer-specific configuration object.

    Methods:
        build: Create the base optimizer transformation (abstract).
        build_mpmd: Optional hook for an MPMD/pipeline-parallel variant;
            raises ``NotImplementedError`` unless overridden.
        validate: Optional validation hook meant to run before building.
    """

    # Typed loosely here; concrete builders re-declare it with their own config type.
    config: tp.Any

    @abstractmethod
    def build(self, scheduler: optax.Schedule) -> optax.GradientTransformation:
        """Create the base optimizer transformation.

        Args:
            scheduler: Learning rate schedule to use.

        Returns:
            optax.GradientTransformation: The optimizer transformation.
        """

    def build_mpmd(
        self,
        scheduler: optax.Schedule,
        *,
        optimizer: optax.GradientTransformation,
        **tx_kwargs: tp.Any,
    ) -> optax.GradientTransformation:
        """Create the MPMD/pipeline-parallel optimizer transformation.

        Registered optimizers may override this hook to expose an explicit
        stage-local update API for scheduled pipeline-parallel (PP) training,
        while the regular :meth:`build` path stays available for normal Optax
        use.

        An override is expected to return a
        :class:`StageLocalGradientTransformation` (or any
        :class:`optax.GradientTransformation` whose ``update`` callable carries
        a ``_eformer_stage_local_apply`` attribute) so that the factory can
        dispatch stage-local gradient applications during PP training.

        Args:
            scheduler: Learning rate schedule paired with the optimizer.
            optimizer: The fully assembled optimizer chain (clip, base, weight
                decay, multi-step) produced by the factory.
            **tx_kwargs: Factory-level transformation options such as
                ``weight_decay``, ``weight_decay_mask``,
                ``gradient_accumulation_steps``, and ``clip_grad``.

        Returns:
            A :class:`optax.GradientTransformation` that supports both the
            standard ``update`` path and the stage-local
            ``apply_gradients_stage_local`` path when appropriate.

        Raises:
            NotImplementedError: Always, in this default implementation,
                signalling that the optimizer does not yet provide PP
                stage-local semantics.
        """
        # The default implementation uses none of its arguments; drop them explicitly.
        del scheduler, optimizer, tx_kwargs
        message = (
            f"{self.__class__.__name__} does not implement build_mpmd(...). "
            "Override OptimizerBuilder.build_mpmd to provide PP stage-local optimizer semantics."
        )
        raise NotImplementedError(message)

    def validate(self) -> None:  # noqa
        """Check the configuration before the optimizer is built.

        The default implementation does nothing and returns ``None``;
        subclasses may override it.

        Raises:
            ValueError: If the configuration is invalid (raised by overriding
                implementations; the default never raises).
        """
@dataclasses.dataclass
class SchedulerBuilder(ABC):
    """Base interface for objects that construct learning rate schedules.

    A builder bundles a scheduler's configuration together with the logic
    that turns it into an ``optax.Schedule``.

    Attributes:
        config: Scheduler configuration object.

    Methods:
        build: Create the learning rate schedule (abstract).
    """

    config: SchedulerConfig

    @abstractmethod
    def build(self) -> optax.Schedule:
        """Create the learning rate schedule.

        Returns:
            optax.Schedule: The learning rate schedule.
        """
# Module-level registries mapping a registered name to its builder class.
# They start empty and are filled by the ``register_optimizer`` /
# ``register_scheduler`` decorators defined in this module.
_OPTIMIZER_BUILDER_REGISTRY: dict[str, type[OptimizerBuilder]] = {}
_SCHEDULER_BUILDER_REGISTRY: dict[str, type[SchedulerBuilder]] = {}
def register_optimizer(name: str) -> tp.Callable[[type[OptimizerBuilder]], type[OptimizerBuilder]]:
    """Return a class decorator that registers an optimizer builder under ``name``.

    Args:
        name: Name to register the optimizer under.

    Returns:
        A decorator that stores the decorated class in the optimizer registry
        and hands the class back unchanged.

    Raises:
        ValueError: Raised by the returned decorator when ``name`` is already
            present in the registry.

    Example:
        @register_optimizer("adamw")
        @dataclass
        class AdamWOptimizer(OptimizerBuilder):
            config: AdamWConfig

            def build(self, scheduler):
                return optax.adamw(learning_rate=scheduler, ...)
    """

    def _register(builder_cls: type[OptimizerBuilder]) -> type[OptimizerBuilder]:
        # Duplicate names are rejected rather than silently overwritten.
        already_taken = name in _OPTIMIZER_BUILDER_REGISTRY
        if already_taken:
            raise ValueError(f"Optimizer '{name}' is already registered")
        _OPTIMIZER_BUILDER_REGISTRY[name] = builder_cls
        return builder_cls

    return _register
def register_scheduler(name: str) -> tp.Callable[[type[SchedulerBuilder]], type[SchedulerBuilder]]:
    """Return a class decorator that registers a scheduler builder under ``name``.

    Args:
        name: Name to register the scheduler under.

    Returns:
        A decorator that stores the decorated class in the scheduler registry
        and hands the class back unchanged.

    Raises:
        ValueError: Raised by the returned decorator when ``name`` is already
            present in the registry.

    Example:
        @register_scheduler("cosine")
        @dataclass
        class CosineSchedulerBuilder(SchedulerBuilder):
            config: SchedulerConfig

            def build(self):
                return optax.cosine_decay_schedule(...)
    """

    def _register(builder_cls: type[SchedulerBuilder]) -> type[SchedulerBuilder]:
        # Duplicate names are rejected rather than silently overwritten.
        already_taken = name in _SCHEDULER_BUILDER_REGISTRY
        if already_taken:
            raise ValueError(f"Scheduler '{name}' is already registered")
        _SCHEDULER_BUILDER_REGISTRY[name] = builder_cls
        return builder_cls

    return _register