eformer.optimizers._tx.utils

class eformer.optimizers._tx.utils.OptaxScheduledWeightDecayState(count: Union[Array, ndarray, bool, number])

Bases: NamedTuple

State for the scheduled weight decay transformation.

This named tuple holds the state required by the scheduled weight decay transformation, tracking the current step count for schedule evaluation.

count

Integer array tracking the current optimization step. Used to evaluate the weight decay schedule function.

Type

chex.Array

count: Union[Array, ndarray, bool, number]

Alias for field number 0

eformer.optimizers._tx.utils.create_cosine_scheduler(steps: int, learning_rate: float, learning_rate_end: float | None = None, warmup_steps: int | None = None, exponent: float = 1.0) → Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]]

Creates a cosine learning rate scheduler with optional warmup.

Parameters
  • steps (int) – Total number of training steps.

  • learning_rate (float) – Peak learning rate.

  • learning_rate_end (tp.Optional[float]) – Final learning rate. Defaults to None.

  • warmup_steps (tp.Optional[int]) – Number of warmup steps. Defaults to None.

  • exponent (float) – Exponent for the cosine decay. Defaults to 1.0.

Returns

The configured scheduler.

Return type

optax.Schedule
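To make the shape of such a schedule concrete, here is a minimal pure-Python sketch of a cosine decay with linear warmup. It is not eformer's implementation. It assumes that warmup rises linearly from 0 to learning_rate, that the decay then runs over the remaining steps down to learning_rate_end (taken as 0.0 when None), and that exponent is applied to the cosine factor. The function name cosine_with_warmup is hypothetical.

```python
import math


def cosine_with_warmup(step, steps, learning_rate,
                       learning_rate_end=None, warmup_steps=None, exponent=1.0):
    """Illustrative schedule value at `step`; the behaviour here is assumed, not eformer's."""
    end = 0.0 if learning_rate_end is None else learning_rate_end
    warmup = warmup_steps or 0

    # Assumed warmup: linear ramp from 0 up to the peak learning rate.
    if warmup > 0 and step < warmup:
        return learning_rate * step / warmup

    # Fraction of the decay phase completed, clamped to [0, 1].
    decay_steps = max(steps - warmup, 1)
    progress = min(max(step - warmup, 0) / decay_steps, 1.0)

    # Cosine factor falls from 1 to 0; the exponent reshapes the curve.
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end + (learning_rate - end) * cosine ** exponent


if __name__ == "__main__":
    for s in (0, 10, 55, 100):
        print(s, cosine_with_warmup(s, 100, 1e-3, learning_rate_end=1e-4, warmup_steps=10))
```

With exponent left at 1.0 the decay phase is a plain half-cosine; under these assumptions, larger exponents make the rate fall off faster early in the decay.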

eformer.optimizers._tx.utils.create_linear_scheduler(steps: int, learning_rate_start: float, learning_rate_end: float, warmup_steps: int | None = None) → Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]]

Creates a linear learning rate scheduler with optional warmup.

Parameters
  • steps (int) – Total number of training steps.

  • learning_rate_start (float) – Initial learning rate.

  • learning_rate_end (float) – Final learning rate.

  • warmup_steps (tp.Optional[int]) – Number of warmup steps. Defaults to None.

Returns

The configured scheduler.

Return type

optax.Schedule
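The linear variant can be pictured the same way. The sketch below is a pure-Python illustration, not the library code. It assumes that warmup ramps linearly from 0 to learning_rate_start and that the rate then moves in a straight line to learning_rate_end over the remaining steps. The name linear_with_warmup is hypothetical.

```python
def linear_with_warmup(step, steps, learning_rate_start, learning_rate_end,
                       warmup_steps=None):
    """Illustrative schedule value at `step`; the behaviour here is assumed, not eformer's."""
    warmup = warmup_steps or 0

    # Assumed warmup: linear ramp from 0 up to the starting learning rate.
    if warmup > 0 and step < warmup:
        return learning_rate_start * step / warmup

    # Straight-line interpolation from start to end over the remaining steps.
    span = max(steps - warmup, 1)
    progress = min(max(step - warmup, 0) / span, 1.0)
    return learning_rate_start + (learning_rate_end - learning_rate_start) * progress


if __name__ == "__main__":
    for s in (0, 10, 55, 100):
        print(s, linear_with_warmup(s, 100, 1e-3, 1e-4, warmup_steps=10))
```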

eformer.optimizers._tx.utils.get_base_optimizer(optimizer_type: str, scheduler: Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]], optimizer_kwargs: dict, weight_decay: float = 0.0, weight_decay_mask: Any | None = None, gradient_accumulation_steps: int = 1, clip_grad: float | None = None, **kwargs) → GradientTransformation

Base function to create an optimizer with a given scheduler.

Parameters
  • optimizer_type (str) – Type of optimizer ('adafactor', 'adamw', 'lion', 'rmsprop').

  • scheduler (optax.Schedule) – Learning rate scheduler.

  • optimizer_kwargs (dict) – Arguments specific to the optimizer.

  • weight_decay (float) – Weight decay factor. Defaults to 0.0.

  • weight_decay_mask (tp.Optional[tp.Any]) – Mask for weight decay. Defaults to None.

  • gradient_accumulation_steps (int) – Number of steps to accumulate gradients. Defaults to 1.

  • clip_grad (tp.Optional[float]) – If provided, gradients will be clipped to this maximum norm. Defaults to None.

Returns

The configured optimizer.

Return type

optax.GradientTransformation
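The clip_grad parameter describes clipping gradients to a maximum norm. One common reading, assumed here and not confirmed by this page, is clipping by the global norm across all gradient leaves. The pure-Python sketch below illustrates that idea, using a flat dict of float lists as a stand-in for a gradient PyTree. The function name clip_by_global_norm_sketch is hypothetical.

```python
import math


def clip_by_global_norm_sketch(grads, max_norm):
    """Rescale all gradients so their combined L2 norm is at most `max_norm`.

    `grads` is a dict of lists of floats, a simplified stand-in for a PyTree.
    Global-norm clipping is an assumption about what clip_grad does.
    """
    global_norm = math.sqrt(sum(g * g for leaf in grads.values() for g in leaf))

    # Within the limit: return an unchanged copy.
    if global_norm <= max_norm:
        return {name: list(leaf) for name, leaf in grads.items()}

    # Over the limit: scale every leaf by the same factor.
    scale = max_norm / global_norm
    return {name: [g * scale for g in leaf] for name, leaf in grads.items()}


if __name__ == "__main__":
    print(clip_by_global_norm_sketch({"w": [3.0], "b": [4.0]}, max_norm=1.0))
```

Scaling every leaf by one shared factor preserves the direction of the overall gradient, which is why global-norm clipping is often preferred over clipping each leaf independently.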

eformer.optimizers._tx.utils.optax_add_scheduled_weight_decay(schedule_fn: Callable[[Union[Array, ndarray, bool, number]], Union[Array, ndarray, bool, number]], mask: Optional[Union[Array, ndarray, bool, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]] = None) → GradientTransformation

Create an optax gradient transformation that applies weight decay on a schedule.

This function is similar to optax.add_decayed_weights, but it allows for the weight decay rate to be scheduled over training steps.

Parameters
  • schedule_fn – A function that takes the current step count as input and returns the weight decay rate.

  • mask – A PyTree with the same structure as the parameters. A value of True at a particular location indicates that weight decay should be applied to that parameter. Defaults to None.

Returns

An optax.GradientTransformation object that applies the scheduled weight decay.
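The interplay between the step count, schedule_fn, and mask can be sketched without any dependencies. The pure-Python example below is an illustration under stated assumptions, not eformer's code: it uses a flat dict of floats in place of a PyTree, assumes the decay term is added to each update as rate * param (as optax.add_decayed_weights does with a fixed rate), and assumes a mask of None means decay applies everywhere. The names DecayState and scheduled_weight_decay_sketch are hypothetical.

```python
from typing import NamedTuple


class DecayState(NamedTuple):
    """Hypothetical stand-in for OptaxScheduledWeightDecayState."""
    count: int


def scheduled_weight_decay_sketch(schedule_fn, mask=None):
    """Return (init, update) functions mimicking a gradient transformation."""

    def init(params):
        # The only state is the step counter used to query the schedule.
        return DecayState(count=0)

    def update(updates, state, params):
        rate = schedule_fn(state.count)
        new_updates = {
            # Add rate * param where the mask allows it (assumed: None means all).
            name: g + rate * params[name] if (mask is None or mask[name]) else g
            for name, g in updates.items()
        }
        return new_updates, DecayState(count=state.count + 1)

    return init, update


if __name__ == "__main__":
    init, update = scheduled_weight_decay_sketch(
        lambda step: 0.1 * step, mask={"w": True, "b": False}
    )
    params = {"w": 2.0, "b": 1.0}
    state = init(params)
    for _ in range(2):
        updates, state = update({"w": 0.5, "b": 0.5}, state, params)
        print(state.count, updates)
```

Keeping the counter inside the transformation's own state means the decay schedule advances with each update call, independently of whatever counter the learning rate schedule uses.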