Source code for eformer.optimizers._base
# Copyright 2025 The EasyDeL/eFormer Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import typing as tp
from abc import ABC, abstractmethod
import optax
from ._config import SchedulerConfig
@dataclasses.dataclass
class OptimizerBuilder(ABC):
    """
    Base interface for objects that turn a configuration into an optax optimizer.

    A concrete builder keeps its optimizer-specific settings in ``config`` and
    implements :meth:`build`, which receives a learning-rate schedule and returns
    an ``optax.GradientTransformation``.

    Attributes:
        config: Optimizer-specific configuration. Declared as ``Any`` here;
            subclasses re-declare it with their own config type.

    Methods:
        build: Abstract; construct the base optimizer transformation.
        validate: Configuration-check hook; the base version accepts everything.
    """

    # Subclasses narrow this field to their concrete config type.
    config: tp.Any

    @abstractmethod
    def build(self, scheduler: optax.Schedule) -> optax.GradientTransformation:
        """
        Construct the base optimizer transformation.

        Args:
            scheduler: Learning-rate schedule that drives the optimizer.

        Returns:
            optax.GradientTransformation: The optimizer transformation.
        """

    def validate(self) -> None:  # noqa
        """
        Hook for checking the configuration, intended to run before building.

        The base implementation performs no checks and returns ``None``;
        subclasses override it to reject bad configurations.

        Raises:
            ValueError: In overriding implementations, if the configuration is invalid.
        """
        return None
@dataclasses.dataclass
class SchedulerBuilder(ABC):
    """
    Base interface for objects that turn a ``SchedulerConfig`` into an optax schedule.

    A concrete builder reads its settings from ``config`` and implements
    :meth:`build`, which returns an ``optax.Schedule``.

    Attributes:
        config: The scheduler configuration the builder works from.

    Methods:
        build: Abstract; construct the learning-rate schedule.
    """

    # Shared configuration type for every scheduler builder.
    config: SchedulerConfig

    @abstractmethod
    def build(self) -> optax.Schedule:
        """
        Construct the learning-rate schedule.

        Returns:
            optax.Schedule: The learning-rate schedule.
        """
# Name -> builder-class lookup tables. They start empty and are filled in by the
# ``register_optimizer`` / ``register_scheduler`` decorators in this module.
_OPTIMIZER_BUILDER_REGISTRY: dict[str, type[OptimizerBuilder]] = dict()
_SCHEDULER_BUILDER_REGISTRY: dict[str, type[SchedulerBuilder]] = dict()
# Type variable for the decorated class, so the decorator preserves the concrete
# subclass type for static type checkers instead of widening it to the base class.
_OptimizerBuilderT = tp.TypeVar("_OptimizerBuilderT", bound=OptimizerBuilder)


def register_optimizer(
    name: str,
) -> tp.Callable[[type[_OptimizerBuilderT]], type[_OptimizerBuilderT]]:
    """
    Decorator to register an optimizer builder class.

    The decorated class is stored in ``_OPTIMIZER_BUILDER_REGISTRY`` under
    ``name`` and returned unchanged. The signature is generic in the decorated
    class, so type checkers see e.g. ``AdamWOptimizer`` rather than plain
    ``OptimizerBuilder`` after decoration.

    Args:
        name: Name to register the optimizer under.

    Returns:
        Decorator function that registers the class and returns it unchanged.

    Raises:
        ValueError: Raised by the returned decorator when it is applied and
            ``name`` is already present in the registry.

    Example:
        @register_optimizer("adamw")
        @dataclass
        class AdamWOptimizer(OptimizerBuilder):
            config: AdamWConfig

            def build(self, scheduler):
                return optax.adamw(learning_rate=scheduler, ...)
    """

    def decorator(cls: type[_OptimizerBuilderT]) -> type[_OptimizerBuilderT]:
        # Refuse to overwrite an existing entry; a silent overwrite would let a
        # later class shadow an earlier one under the same name.
        if name in _OPTIMIZER_BUILDER_REGISTRY:
            raise ValueError(f"Optimizer '{name}' is already registered")
        _OPTIMIZER_BUILDER_REGISTRY[name] = cls
        return cls

    return decorator
# Type variable for the decorated class, so the decorator preserves the concrete
# subclass type for static type checkers instead of widening it to the base class.
_SchedulerBuilderT = tp.TypeVar("_SchedulerBuilderT", bound=SchedulerBuilder)


def register_scheduler(
    name: str,
) -> tp.Callable[[type[_SchedulerBuilderT]], type[_SchedulerBuilderT]]:
    """
    Decorator to register a scheduler builder class.

    The decorated class is stored in ``_SCHEDULER_BUILDER_REGISTRY`` under
    ``name`` and returned unchanged. The signature is generic in the decorated
    class, so type checkers see e.g. ``CosineSchedulerBuilder`` rather than plain
    ``SchedulerBuilder`` after decoration.

    Args:
        name: Name to register the scheduler under.

    Returns:
        Decorator function that registers the class and returns it unchanged.

    Raises:
        ValueError: Raised by the returned decorator when it is applied and
            ``name`` is already present in the registry.

    Example:
        @register_scheduler("cosine")
        @dataclass
        class CosineSchedulerBuilder(SchedulerBuilder):
            config: SchedulerConfig

            def build(self):
                return optax.cosine_decay_schedule(...)
    """

    def decorator(cls: type[_SchedulerBuilderT]) -> type[_SchedulerBuilderT]:
        # Refuse to overwrite an existing entry; a silent overwrite would let a
        # later class shadow an earlier one under the same name.
        if name in _SCHEDULER_BUILDER_REGISTRY:
            raise ValueError(f"Scheduler '{name}' is already registered")
        _SCHEDULER_BUILDER_REGISTRY[name] = cls
        return cls

    return decorator