# Copyright 2025 The EasyDeL/eFormer Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import difflib
import inspect
import typing as tp
from dataclasses import fields
import jax
import optax
# Import builders to trigger registration (side-effect import)
# Import custom optimizers to trigger registration (side-effect import)
from . import _builders, _tx # noqa
# Import base classes and registries
from ._base import _OPTIMIZER_BUILDER_REGISTRY
from ._config import (
AdafactorConfig,
AdamWConfig,
KronConfig,
LionConfig,
MarsConfig,
MuonConfig,
RMSPropConfig,
SchedulerConfig,
SerializationMixin,
SoapConfig,
WhiteKronConfig,
)
from ._tx import optax_add_scheduled_weight_decay
# Union of every optimizer configuration class imported from `._config`.
# Used as the type annotation for the `optimizer_config` argument of
# `OptimizerFactory.create`.
TxConfigs = (
AdafactorConfig
| AdamWConfig
| KronConfig
| LionConfig
| MarsConfig
| MuonConfig
| RMSPropConfig
| SoapConfig
| WhiteKronConfig
)
class SchedulerFactory:
    """
    Factory class for creating learning rate schedulers.

    This class provides methods to create schedulers based on a configuration object (`SchedulerConfig`).
    It supports linear and cosine schedulers with optional warmup steps.

    Methods:
        create_scheduler: Creates a scheduler based on the provided configuration.
        _create_linear: Creates a linear scheduler with optional warmup.
        _create_cosine: Creates a cosine scheduler with optional warmup.
    """

    @staticmethod
    def create_scheduler(
        config: SchedulerConfig,
        custom_scheduler: tp.Callable[[int], optax.Schedule] | None = None,
    ) -> optax.Schedule:
        """
        Create a scheduler based on the provided configuration.

        Resolution order: a `custom_scheduler` (if given) always wins; otherwise a missing
        `scheduler_type` yields a constant schedule; otherwise the type is dispatched to the
        matching private builder.

        Args:
            config (SchedulerConfig): Configuration object for the scheduler.
            custom_scheduler (Optional[Callable[[int], optax.Schedule]]): Custom scheduler function, called
                with `config.steps`. Defaults to None.

        Returns:
            optax.Schedule: The created scheduler.

        Raises:
            ValueError: If `config.steps` is missing where it is required, or an unsupported
                scheduler type is provided.
        """
        if custom_scheduler is not None:
            if config.steps is None:
                raise ValueError("Custom schedulers require steps configuration")
            return custom_scheduler(config.steps)

        # No scheduler type means a flat learning rate; `steps` is not needed for that.
        if config.scheduler_type is None:
            return optax.constant_schedule(config.learning_rate)

        if config.steps is None:
            raise ValueError("Steps must be specified for configured schedulers")

        if config.scheduler_type == "linear":
            return SchedulerFactory._create_linear(config)
        elif config.scheduler_type == "cosine":
            return SchedulerFactory._create_cosine(config)
        else:
            raise ValueError(f"Unsupported scheduler type: {config.scheduler_type}")

    @staticmethod
    def _create_linear(config: SchedulerConfig) -> optax.Schedule:
        """
        Create a linear scheduler with optional warmup.

        Without warmup the schedule moves linearly from `learning_rate` to `learning_rate_end`
        over `steps`. With warmup, a linear ramp from 1e-8 to `learning_rate` over `warmup_steps`
        is joined in front of that decay at boundary `warmup_steps`.

        Args:
            config (SchedulerConfig): Configuration object for the scheduler.

        Returns:
            optax.Schedule: The created linear scheduler.
        """
        base_scheduler = optax.linear_schedule(
            init_value=config.learning_rate,
            end_value=config.learning_rate_end,
            transition_steps=config.steps,
        )
        if config.warmup_steps:
            warmup = optax.linear_schedule(
                init_value=1e-8,
                end_value=config.learning_rate,
                transition_steps=config.warmup_steps,
            )
            # NOTE(review): the decay phase keeps transition_steps=config.steps, so the joined
            # schedule presumably spans warmup_steps + steps in total -- confirm this is intended.
            return optax.join_schedules(
                schedules=[warmup, base_scheduler],
                boundaries=[config.warmup_steps],
            )
        return base_scheduler

    @staticmethod
    def _create_cosine(config: SchedulerConfig) -> optax.Schedule:
        """
        Create a cosine scheduler with optional warmup.

        In both variants the schedule covers `config.steps` steps in total, finishes at the
        absolute learning rate `learning_rate_end` (0.0 when unset) and honours `config.exponent`.

        Args:
            config (SchedulerConfig): Configuration object for the scheduler.

        Returns:
            optax.Schedule: The created cosine scheduler.
        """
        end_value = config.learning_rate_end or 0.0

        if config.warmup_steps:
            return optax.warmup_cosine_decay_schedule(
                init_value=1e-8,
                peak_value=config.learning_rate,
                warmup_steps=config.warmup_steps,
                # optax defines `decay_steps` as the total schedule length *including* warmup,
                # so the warmup must not be subtracted here (optax subtracts it internally).
                decay_steps=config.steps,
                end_value=end_value,
                exponent=config.exponent,
            )

        # `cosine_decay_schedule` takes `alpha` as a fraction of `init_value`, not an absolute
        # learning rate, so convert the configured end value (guarding a zero peak).
        peak_value = config.learning_rate
        alpha = 0.0 if peak_value == 0.0 else end_value / peak_value
        return optax.cosine_decay_schedule(
            init_value=peak_value,
            decay_steps=config.steps,
            alpha=alpha,
            exponent=config.exponent,
        )
class OptimizerFactory:
    """
    Factory class for creating optimizers with validated configurations.

    This class provides methods to create optimizers based on a configuration object.
    Optimizer builders are looked up by name in `_OPTIMIZER_BUILDER_REGISTRY`.

    Methods:
        create: Creates an optimizer with validated configuration.
        generate_template: Generates a configuration template for the specified optimizer.
        serialize_config: Serializes configuration to different formats.
        deserialize_config: Deserializes configuration from different formats.

    Private Methods:
        _get_config_class: Gets the configuration class for an optimizer type.
        _convert_dtypes: Converts string dtype representations to JAX dtypes.
        _validate_kwargs: Validates additional parameters for the optimizer.
        _build_optimizer_chain: Constructs the final optimizer chain.
    """

    @staticmethod
    def _get_config_class(optimizer_type: str) -> type:
        """
        Get the configuration class for an optimizer type.

        The class is read from the type annotation of the `config` dataclass field on the
        registered builder class.

        Args:
            optimizer_type: Name of the optimizer.

        Returns:
            Configuration class for the optimizer.

        Raises:
            ValueError: If optimizer type is not registered, or the builder has no usable
                `config` field annotation.
        """
        if optimizer_type not in _OPTIMIZER_BUILDER_REGISTRY:
            available = sorted(_OPTIMIZER_BUILDER_REGISTRY.keys())
            raise ValueError(f"Unsupported optimizer: {optimizer_type}. Available: {available}")

        builder_cls = _OPTIMIZER_BUILDER_REGISTRY[optimizer_type]
        # Get the config type from the builder's type hint
        config_field = next((f for f in fields(builder_cls) if f.name == "config"), None)
        config_type = config_field.type if config_field is not None else None

        if isinstance(config_type, str):
            # Dataclasses store string / postponed annotations verbatim; resolve them to the
            # real class so callers can instantiate and isinstance-check against it.
            try:
                config_type = tp.get_type_hints(builder_cls).get("config")
            except (NameError, TypeError) as err:
                raise ValueError(
                    f"Builder class for '{optimizer_type}' does not have a valid config field"
                ) from err

        if config_type:
            return config_type
        raise ValueError(f"Builder class for '{optimizer_type}' does not have a valid config field")

    @classmethod
    def create(
        cls,
        optimizer_type: str,
        scheduler_config: SchedulerConfig | None = None,
        optimizer_config: TxConfigs | None = None,
        *,
        weight_decay: float = 0.0,
        weight_decay_mask: tp.Any | None = None,
        gradient_accumulation_steps: int = 1,
        clip_grad: float | None = None,
        custom_scheduler: tp.Callable[[int], optax.Schedule] | None = None,
        **kwargs,
    ) -> tuple[optax.GradientTransformation, optax.Schedule]:
        """
        Create an optimizer with validated configuration.

        Args:
            optimizer_type (str): One of the registered optimizer types.
            scheduler_config (SchedulerConfig): Configured scheduler parameters. A default
                `SchedulerConfig()` is used when None.
            optimizer_config: Optimizer-specific configuration (one of the `TxConfigs` classes).
                When None, the config class registered for `optimizer_type` is instantiated
                with its defaults.
            weight_decay (float): Global weight decay rate. Defaults to 0.0.
            weight_decay_mask (Optional[Any]): Mask for weight decay application. Defaults to None.
            gradient_accumulation_steps (int): Steps for gradient accumulation. Defaults to 1.
            clip_grad (Optional[float]): Global clip gradient norm value. Defaults to None.
            custom_scheduler (Optional[Callable[[int], optax.Schedule]]): Optional custom scheduler function.
                Defaults to None.
            **kwargs: Additional optimizer-specific parameters. Only consulted when
                `optimizer_config` is None: keys matching a constructor parameter of the config
                class are set on the default config, all other keys are ignored.

        Returns:
            Tuple[optax.GradientTransformation, optax.Schedule]: A tuple containing the optimizer and scheduler.

        Raises:
            ValueError: If the optimizer type is unsupported or the configuration is invalid.
            TypeError: If `optimizer_config` is not an instance of the config class registered
                for `optimizer_type`.
        """
        # Get the appropriate config class
        config_cls = cls._get_config_class(optimizer_type)

        # Create default config if none provided, overriding fields from matching kwargs
        if optimizer_config is None:
            optimizer_config = config_cls()
            accepted = inspect.signature(type(optimizer_config)).parameters
            for key in list(kwargs):
                if key in accepted:
                    setattr(optimizer_config, key, kwargs.pop(key))

        if scheduler_config is None:
            scheduler_config = SchedulerConfig()

        # Validate config type first, so a wrong-typed config is rejected with a clear message
        # and is never mutated by the dtype conversion below.
        if not isinstance(optimizer_config, config_cls):
            raise TypeError(
                f"Invalid config type {type(optimizer_config)} for optimizer {optimizer_type}. Expected {config_cls}"
            )

        # Convert string dtypes to JAX dtypes
        cls._convert_dtypes(optimizer_config)

        # Validate scheduler config
        if scheduler_config.scheduler_type is None and scheduler_config.warmup_steps:
            raise ValueError("Warmup steps require specifying a scheduler type")

        # Create scheduler
        scheduler = SchedulerFactory.create_scheduler(scheduler_config, custom_scheduler)

        # Create optimizer using builder pattern
        builder_cls = _OPTIMIZER_BUILDER_REGISTRY[optimizer_type]
        builder = builder_cls(config=optimizer_config)
        builder.validate()
        base_optimizer = builder.build(scheduler)

        # Build the full optimizer chain (clip, base optimizer, weight decay, multi-step)
        return cls._build_optimizer_chain(
            base_optimizer=base_optimizer,
            scheduler=scheduler,
            weight_decay=weight_decay,
            weight_decay_mask=weight_decay_mask,
            gradient_accumulation_steps=gradient_accumulation_steps,
            clip_grad=clip_grad,
        )

    @staticmethod
    def _convert_dtypes(config: tp.Any):
        """
        Automatically convert string dtype representations to JAX dtypes.

        Every dataclass field whose name contains "dtype" and whose current value is a string
        is replaced, in place, by the attribute of that name on `jax.numpy`.

        Args:
            config (Any): Configuration object (a dataclass instance).

        Raises:
            ValueError: If an invalid dtype is specified.
        """
        for field in fields(config):
            if "dtype" not in field.name:
                continue
            value = getattr(config, field.name)
            if not isinstance(value, str):
                continue
            dtype = getattr(jax.numpy, value, None)
            if dtype is None:
                raise ValueError(f"Invalid dtype specified: {value}")
            setattr(config, field.name, dtype)

    @classmethod
    def _validate_kwargs(cls, config: tp.Any, kwargs: dict[str, tp.Any]):
        """
        Validate additional parameters with helpful error messages.

        Each key is checked against the constructor parameters of `config`'s class; the first
        unknown key raises, with close-match suggestions when any exist. This helper is not
        invoked by `create`.

        Args:
            config (Any): Configuration object.
            kwargs (Dict[str, Any]): Additional parameters.

        Raises:
            ValueError: If unexpected parameters are provided.
        """
        valid_params = inspect.signature(config.__class__).parameters
        for kwarg in kwargs:
            if kwarg not in valid_params:
                suggestions = ", ".join(difflib.get_close_matches(kwarg, valid_params.keys()))
                msg = (
                    f"Unexpected parameter '{kwarg}' for {config.__class__.__name__}. "
                    f"Valid parameters: {list(valid_params.keys())}"
                )
                if suggestions:
                    msg += f". Did you mean: {suggestions}?"
                raise ValueError(msg)

    @staticmethod
    def _build_optimizer_chain(
        base_optimizer: optax.GradientTransformation,
        scheduler: optax.Schedule,
        weight_decay: float = 0.0,
        weight_decay_mask: tp.Any | None = None,
        gradient_accumulation_steps: int = 1,
        clip_grad: float | None = None,
    ) -> tuple[optax.GradientTransformation, optax.Schedule]:
        """
        Construct the final optimizer chain with gradient clipping, weight decay, and accumulation.

        Order of the chain: optional global-norm clipping, the base optimizer, then optional
        scheduled weight decay. The whole chain is wrapped in `optax.MultiSteps` when
        `gradient_accumulation_steps` is greater than 1.

        Args:
            base_optimizer: Base optimizer transformation.
            scheduler: Learning rate scheduler.
            weight_decay: Weight decay coefficient. Defaults to 0.0.
            weight_decay_mask: Mask for weight decay application. Defaults to None.
            gradient_accumulation_steps: Steps for gradient accumulation. Defaults to 1.
            clip_grad: Global gradient norm clipping value. Defaults to None.

        Returns:
            Tuple[optax.GradientTransformation, optax.Schedule]: A tuple containing the optimizer chain and scheduler.
        """
        chain = []

        # Add gradient clipping if specified
        if clip_grad:
            chain.append(optax.clip_by_global_norm(clip_grad))

        # Add base optimizer
        chain.append(base_optimizer)

        # Add weight decay if specified; the decay coefficient follows the learning-rate
        # schedule, negated and scaled by `weight_decay`.
        if weight_decay != 0.0:
            chain.append(
                optax_add_scheduled_weight_decay(
                    lambda step: -scheduler(step) * weight_decay,
                    weight_decay_mask,
                )
            )

        # Chain all transformations
        tx = optax.chain(*chain)

        # Add gradient accumulation if specified
        if gradient_accumulation_steps > 1:
            tx = optax.MultiSteps(tx, gradient_accumulation_steps)

        return tx, scheduler

    @classmethod
    def generate_template(cls, optimizer_type: str) -> str:
        """
        Generate a configuration template for the specified optimizer.

        The template lists every dataclass field of the optimizer's config class as
        `name: type`, followed by ` = default` when the field has a plain default value.

        Args:
            optimizer_type (str): Name of the optimizer.

        Returns:
            str: Configuration template.

        Raises:
            ValueError: If the optimizer type is unknown.
        """
        config_cls = cls._get_config_class(optimizer_type)
        # Resolve the annotations once rather than once per field.
        type_hints = tp.get_type_hints(config_cls)

        fields_list = []
        for field in dataclasses.fields(config_cls):
            field_type = type_hints[field.name]
            default = "" if field.default is dataclasses.MISSING else f" = {field.default}"
            if hasattr(field_type, "__name__"):
                type_name = field_type.__name__
            else:
                type_name = str(field_type)
            fields_list.append(f" {field.name}: {type_name}{default}")
        return f"{config_cls.__name__}(\n" + "\n".join(fields_list) + "\n)"

    @classmethod
    def serialize_config(
        cls,
        config: SerializationMixin,
        format: str = "dict",  # noqa:A002
    ) -> dict | str:
        """
        Serialize configuration to different formats.

        Args:
            config (SerializationMixin): Configuration object.
            format (str): Serialization format. Supported formats: 'dict', 'json'.

        Returns:
            Union[Dict, str]: Serialized configuration (`config.to_dict()` or `config.to_json()`).

        Raises:
            ValueError: If the format is unsupported.
        """
        if format not in ["dict", "json"]:
            raise ValueError("Supported formats: 'dict', 'json'")
        if format == "dict":
            return config.to_dict()
        return config.to_json()

    @classmethod
    def deserialize_config(
        cls,
        optimizer_type: str,
        data: dict | str,
        format: str = "dict",  # noqa:A002
    ) -> SerializationMixin:
        """
        Deserialize configuration from different formats.

        Args:
            optimizer_type (str): Name of the optimizer.
            data (Union[Dict, str]): Serialized configuration data.
            format (str): Serialization format. Supported formats: 'dict', 'json'.

        Returns:
            SerializationMixin: Deserialized configuration object.

        Raises:
            ValueError: If the optimizer type is unknown or the format is unsupported.
            TypeError: If the input data type is invalid.
        """
        config_cls = cls._get_config_class(optimizer_type)

        if format == "json":
            if not isinstance(data, str):
                raise TypeError("Expected string input for JSON format")
            return config_cls.from_json(data)

        if format == "dict":
            if not isinstance(data, dict):
                raise TypeError("Expected dictionary input for dict format")
            return config_cls.from_dict(data)

        raise ValueError("Unsupported format")