# Copyright 2025 The EasyDeL/eFormer Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing as tp
import warnings
import chex
import jax
import optax
from jax import numpy as jnp
[docs]class OptaxScheduledWeightDecayState(tp.NamedTuple):
"""State for the scheduled weight decay optimizer.
This named tuple holds the state required by the scheduled weight decay
transformation, tracking the current step count for schedule evaluation.
Attributes:
count (chex.Array): Integer array tracking the current optimization step.
Used to evaluate the weight decay schedule function.
"""
count: chex.Array
[docs]def optax_add_scheduled_weight_decay(
schedule_fn: tp.Callable[[chex.Array], chex.Array],
mask: chex.ArrayTree | None = None,
) -> optax.GradientTransformation:
"""
Create an optax optimizer that applies weight decay on a schedule.
This function is similar to `optax.add_decayed_weights`, but it allows for
the weight decay rate to be scheduled over training steps.
Args:
schedule_fn: A function that takes the current step count as input
and returns the weight decay rate.
mask: A PyTree with the same structure as the parameters.
A value of True at a particular location indicates that weight
decay should be applied to that parameter.
Returns:
An `optax.GradientTransformation` object representing the optimizer.
"""
def init_fn(params: chex.ArrayTree) -> OptaxScheduledWeightDecayState:
"""Initialize the state of the scheduled weight decay optimizer.
Args:
params (chex.ArrayTree): Parameter tree (unused, but required by optax interface).
Returns:
OptaxScheduledWeightDecayState: Initial state with step count set to zero.
"""
del params
return OptaxScheduledWeightDecayState(count=jnp.zeros([], jnp.int32))
def update_fn(
updates: chex.ArrayTree,
state: OptaxScheduledWeightDecayState,
params: chex.ArrayTree | None = None,
) -> tuple[chex.ArrayTree, OptaxScheduledWeightDecayState]:
"""Apply scheduled weight decay to the gradient updates.
Computes the weight decay rate from the schedule function and adds
the scaled parameters to the gradient updates.
Args:
updates (chex.ArrayTree): Gradient updates to be modified.
state (OptaxScheduledWeightDecayState): Current optimizer state containing step count.
params (chex.ArrayTree | None): Model parameters for weight decay computation.
Returns:
tuple[chex.ArrayTree, OptaxScheduledWeightDecayState]: Tuple containing:
- Modified gradient updates with weight decay applied.
- Updated state with incremented step count.
Raises:
ValueError: If params is None, as weight decay requires parameter values.
"""
if params is None:
raise ValueError("Params cannot be None for weight decay!")
weight_decay = schedule_fn(state.count)
updates = jax.tree_util.tree_map(lambda g, p: g + weight_decay * p, updates, params)
return updates, OptaxScheduledWeightDecayState(count=optax.safe_int32_increment(state.count))
if mask is not None:
return optax.masked(optax.GradientTransformation(init_fn, update_fn), mask)
return optax.GradientTransformation(init_fn, update_fn)
[docs]def create_linear_scheduler(
steps: int,
learning_rate_start: float,
learning_rate_end: float,
warmup_steps: int | None = None,
) -> optax.Schedule:
"""
Creates a linear learning rate scheduler with optional warmup.
Args:
steps (int): Total number of training steps.
learning_rate_start (float): Initial learning rate.
learning_rate_end (float): Final learning rate.
warmup_steps (tp.Optional[int]): Number of warmup steps.
Returns:
optax.Schedule: The configured scheduler.
"""
if warmup_steps is not None:
scheduler_warmup = optax.linear_schedule(
init_value=5e-8,
end_value=learning_rate_start,
transition_steps=warmup_steps,
)
scheduler_decay = optax.linear_schedule(
init_value=learning_rate_start,
end_value=learning_rate_end,
transition_steps=steps - warmup_steps,
)
return optax.join_schedules(schedules=[scheduler_warmup, scheduler_decay], boundaries=[warmup_steps])
else:
return optax.linear_schedule(
init_value=learning_rate_start,
end_value=learning_rate_end,
transition_steps=steps,
)
[docs]def create_cosine_scheduler(
steps: int,
learning_rate: float,
learning_rate_end: float | None = None,
warmup_steps: int | None = None,
exponent: float = 1.0,
) -> optax.Schedule:
"""
Creates a cosine learning rate scheduler with optional warmup.
Args:
steps (int): Total number of training steps.
learning_rate (float): Peak learning rate.
learning_rate_end (tp.Optional[float]): Final learning rate.
warmup_steps (tp.Optional[int]): Number of warmup steps.
exponent (float): Exponent for the cosine decay.
Returns:
optax.Schedule: The configured scheduler.
"""
if warmup_steps is not None:
return optax.warmup_cosine_decay_schedule(
init_value=0.5e-7,
peak_value=learning_rate,
warmup_steps=warmup_steps,
decay_steps=steps,
end_value=learning_rate_end or 1e-5,
exponent=exponent,
)
else:
return optax.cosine_decay_schedule(init_value=learning_rate, decay_steps=steps, alpha=learning_rate_end or 0.0)
[docs]def get_base_optimizer(
optimizer_type: str,
scheduler: optax.Schedule,
optimizer_kwargs: dict,
weight_decay: float = 0.0,
weight_decay_mask: tp.Any | None = None,
gradient_accumulation_steps: int = 1,
clip_grad: float | None = None,
**kwargs,
) -> optax.GradientTransformation:
"""
Base function to create an optimizer with a given scheduler.
Args:
optimizer_type (str): Type of optimizer ('adafactor', 'adamw', 'lion', 'rmsprop').
scheduler (optax.Schedule): Learning rate scheduler.
optimizer_kwargs (dict): Arguments specific to the optimizer.
weight_decay (float): Weight decay factor.
weight_decay_mask (tp.Optional[tp.Any]): Mask for weight decay.
gradient_accumulation_steps (int): Number of steps to accumulate gradients.
clip_grad (tp.Optional[float]): If provided, gradients will be clipped to this maximum norm.
Returns:
optax.GradientTransformation: The configured optimizer.
"""
for kwarg in kwargs.keys():
warnings.warn(f"Key {kwarg} is not used for optimizer.", stacklevel=1)
if optimizer_type == "adafactor":
optimizer = optax.adafactor(learning_rate=scheduler, **optimizer_kwargs)
elif optimizer_type == "adamw":
optimizer = optax.adamw(learning_rate=scheduler, **optimizer_kwargs)
elif optimizer_type == "lion":
optimizer = optax.lion(learning_rate=scheduler, **optimizer_kwargs)
elif optimizer_type == "rmsprop":
optimizer = optax.rmsprop(learning_rate=scheduler, **optimizer_kwargs)
else:
raise ValueError(f"Unsupported optimizer type: {optimizer_type}")
chain = [optimizer]
if clip_grad is not None:
chain.insert(0, optax.clip_by_global_norm(clip_grad))
if weight_decay != 0.0:
chain.append(
optax_add_scheduled_weight_decay(
lambda step: -scheduler(step) * weight_decay,
weight_decay_mask,
)
)
tx = optax.chain(*chain)
if gradient_accumulation_steps > 1:
tx = optax.MultiSteps(tx, gradient_accumulation_steps)
return tx