eformer.optimizers._tx.utils
- class eformer.optimizers._tx.utils.OptaxScheduledWeightDecayState(count: Union[Array, ndarray, bool, number])
Bases: NamedTuple
State for the scheduled weight decay optimizer.
This named tuple holds the state required by the scheduled weight decay transformation, tracking the current step count for schedule evaluation.
- count
Integer array tracking the current optimization step. Used to evaluate the weight decay schedule function.
- Type
chex.Array
- eformer.optimizers._tx.utils.create_cosine_scheduler(steps: int, learning_rate: float, learning_rate_end: float | None = None, warmup_steps: int | None = None, exponent: float = 1.0) -> Callable[[Union[Array, ndarray, bool, number, bool, int, float, complex]], Union[Array, ndarray, bool, number, bool, int, float, complex]]
Creates a cosine learning rate scheduler with optional warmup.
- Parameters
steps (int) – Total number of training steps.
learning_rate (float) – Peak learning rate.
learning_rate_end (tp.Optional[float]) – Final learning rate.
warmup_steps (tp.Optional[int]) – Number of warmup steps.
exponent (float) – Exponent for the cosine decay.
- Returns
The configured scheduler.
- Return type
optax.Schedule
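The shape of such a schedule can be sketched in plain Python. This is only an illustration, not eformer's implementation: `cosine_with_warmup` is a hypothetical name, and the sketch assumes that warmup ramps linearly from 0 to `learning_rate`, that warmup steps count toward `steps`, that `learning_rate_end=None` means decay to 0, and that `exponent` is applied as a power on the cosine factor.

```python
import math
from typing import Callable, Optional


def cosine_with_warmup(
    steps: int,
    learning_rate: float,
    learning_rate_end: Optional[float] = None,
    warmup_steps: Optional[int] = None,
    exponent: float = 1.0,
) -> Callable[[int], float]:
    """Hypothetical stand-in for a warmup + cosine decay schedule."""
    end = 0.0 if learning_rate_end is None else learning_rate_end  # assumed default
    warmup = warmup_steps or 0
    decay_steps = max(steps - warmup, 1)  # assumption: warmup is part of `steps`

    def schedule(step: int) -> float:
        if step < warmup:
            # Linear ramp from 0 up to the peak learning rate.
            return learning_rate * step / warmup
        # Fraction of the decay phase completed, clamped to [0, 1].
        progress = min((step - warmup) / decay_steps, 1.0)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return end + (learning_rate - end) * cosine**exponent

    return schedule


cosine_lr = cosine_with_warmup(
    steps=1000, learning_rate=3e-4, learning_rate_end=3e-5, warmup_steps=100
)
for step in (0, 50, 100, 550, 1000):
    print(step, cosine_lr(step))
```

Under these assumptions the rate peaks at step 100, passes the midpoint between peak and final value at step 550, and settles at `learning_rate_end` at step 1000.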
- eformer.optimizers._tx.utils.create_linear_scheduler(steps: int, learning_rate_start: float, learning_rate_end: float, warmup_steps: int | None = None) -> Callable[[Union[Array, ndarray, bool, number, bool, int, float, complex]], Union[Array, ndarray, bool, number, bool, int, float, complex]]
Creates a linear learning rate scheduler with optional warmup.
- Parameters
steps (int) – Total number of training steps.
learning_rate_start (float) – Initial learning rate.
learning_rate_end (float) – Final learning rate.
warmup_steps (tp.Optional[int]) – Number of warmup steps.
- Returns
The configured scheduler.
- Return type
optax.Schedule
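A plain-Python sketch of a linear schedule with warmup follows. It is an illustration only: `linear_with_warmup` is a hypothetical name, and the warmup behavior (a linear ramp from 0 up to `learning_rate_start`, followed by linear interpolation to `learning_rate_end` over the remaining steps) is an assumption rather than something this reference states.

```python
from typing import Callable, Optional


def linear_with_warmup(
    steps: int,
    learning_rate_start: float,
    learning_rate_end: float,
    warmup_steps: Optional[int] = None,
) -> Callable[[int], float]:
    """Hypothetical stand-in for a warmup + linear decay schedule."""
    warmup = warmup_steps or 0
    transition = max(steps - warmup, 1)  # assumption: warmup is part of `steps`

    def schedule(step: int) -> float:
        if step < warmup:
            # Assumed warmup: linear ramp from 0 to learning_rate_start.
            return learning_rate_start * step / warmup
        # Fraction of the linear phase completed, clamped to [0, 1].
        progress = min((step - warmup) / transition, 1.0)
        return learning_rate_start + (learning_rate_end - learning_rate_start) * progress

    return schedule


linear_lr = linear_with_warmup(
    steps=1000, learning_rate_start=1e-3, learning_rate_end=1e-4, warmup_steps=200
)
for step in (0, 200, 600, 1000):
    print(step, linear_lr(step))
```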
- eformer.optimizers._tx.utils.get_base_optimizer(optimizer_type: str, scheduler: Callable[[Union[Array, ndarray, bool, number, bool, int, float, complex]], Union[Array, ndarray, bool, number, bool, int, float, complex]], optimizer_kwargs: dict, weight_decay: float = 0.0, weight_decay_mask: Any | None = None, gradient_accumulation_steps: int = 1, clip_grad: float | None = None, **kwargs) -> GradientTransformation
Base function to create an optimizer with a given scheduler.
- Parameters
optimizer_type (str) – Type of optimizer (‘adafactor’, ‘adamw’, ‘lion’, ‘rmsprop’).
scheduler (optax.Schedule) – Learning rate scheduler.
optimizer_kwargs (dict) – Arguments specific to the optimizer.
weight_decay (float) – Weight decay factor.
weight_decay_mask (tp.Optional[tp.Any]) – Mask for weight decay.
gradient_accumulation_steps (int) – Number of steps to accumulate gradients.
clip_grad (tp.Optional[float]) – If provided, gradients will be clipped to this maximum norm.
- Returns
The configured optimizer.
- Return type
optax.GradientTransformation
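To make the `clip_grad` parameter concrete, here is a plain-Python sketch of clipping by global norm. Whether eformer clips by the global norm across all parameters, as assumed here, or per parameter is not stated in this reference; the helper name and the dict-of-lists gradient layout are hypothetical.

```python
import math
from typing import Dict, List


def clip_by_global_norm(
    grads: Dict[str, List[float]], max_norm: float
) -> Dict[str, List[float]]:
    """Rescale all gradients so their joint L2 norm is at most max_norm."""
    # L2 norm computed over every gradient entry of every parameter.
    global_norm = math.sqrt(sum(g * g for leaf in grads.values() for g in leaf))
    # Leave gradients untouched when they are already within the limit.
    scale = min(1.0, max_norm / global_norm) if global_norm > 0 else 1.0
    return {name: [g * scale for g in leaf] for name, leaf in grads.items()}


grads = {"w": [3.0, 0.0], "b": [4.0]}  # global norm is 5.0
clipped = clip_by_global_norm(grads, max_norm=1.0)
print(clipped)
```

All entries are scaled by the same factor, so the direction of the overall gradient is preserved and only its magnitude is reduced.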
- eformer.optimizers._tx.utils.optax_add_scheduled_weight_decay(schedule_fn: Callable[[Union[Array, ndarray, bool, number]], Union[Array, ndarray, bool, number]], mask: Optional[Union[Array, ndarray, bool, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]] = None) -> GradientTransformation
Create an optax gradient transformation that applies weight decay on a schedule.
This function is similar to optax.add_decayed_weights, but it allows for the weight decay rate to be scheduled over training steps.
- Parameters
schedule_fn – A function that takes the current step count as input and returns the weight decay rate.
mask – A PyTree with the same structure as the parameters. A value of True at a particular location indicates that weight decay should be applied to that parameter.
- Returns
An optax.GradientTransformation that applies the scheduled weight decay.
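The update rule can be sketched in plain Python over a flat dict of scalar parameters. This is a simplified illustration, not eformer's code: the names are hypothetical, and it assumes the same convention as optax.add_decayed_weights (the decay rate times the parameter is added to the incoming update) and that `mask=None` means every parameter is decayed.

```python
from typing import Callable, Dict, NamedTuple, Optional, Tuple


class ScheduledWeightDecayState(NamedTuple):
    count: int  # current optimization step, fed to the schedule


def scheduled_weight_decay_update(
    updates: Dict[str, float],
    state: ScheduledWeightDecayState,
    params: Dict[str, float],
    schedule_fn: Callable[[int], float],
    mask: Optional[Dict[str, bool]] = None,
) -> Tuple[Dict[str, float], ScheduledWeightDecayState]:
    """Add schedule_fn(count) * param to each masked update, then advance count."""
    decay = schedule_fn(state.count)
    new_updates = {}
    for name, update in updates.items():
        # Assumption: with no mask, weight decay applies to every parameter.
        apply_decay = True if mask is None else mask[name]
        new_updates[name] = update + decay * params[name] if apply_decay else update
    return new_updates, ScheduledWeightDecayState(count=state.count + 1)


# Hypothetical schedule: decay rate 0.1 for the first 10 steps, then none.
decay_schedule = lambda step: 0.1 if step < 10 else 0.0
params = {"kernel": 2.0, "bias": 1.0}
wd_grads = {"kernel": 0.5, "bias": 0.5}
wd_mask = {"kernel": True, "bias": False}  # decay the kernel only

wd_state = ScheduledWeightDecayState(count=0)
wd_updates, wd_state = scheduled_weight_decay_update(
    wd_grads, wd_state, params, decay_schedule, wd_mask
)
print(wd_updates, wd_state)
```

Keeping the step counter in the state is what lets the decay rate vary during training, which is the difference from a fixed-rate decay.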