eformer.optimizers._tx.mars

class eformer.optimizers._tx.mars.ScaleByMarsState(count: Union[Array, ndarray, bool, number], mu: Union[Array, ndarray, bool, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], nu: Union[Array, ndarray, bool, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], mog: Union[Array, ndarray, bool, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]])

Bases: NamedTuple

State for the Mars (Make vAriance Reduction Shine) algorithm.

This named tuple holds the optimizer state required by the Mars algorithm, tracking moments and gradient history for variance reduction.

count

Integer array tracking the current optimization step. Used for bias correction of the moment estimates.

Type

chex.Array

mu

First moment estimates (exponential moving average of gradients). Has the same structure as the model parameters.

Type

optax.Updates

nu

Second moment estimates (exponential moving average of squared gradients). Has the same structure as the model parameters.

Type

optax.Updates

mog

Momentum of gradients from the previous step. Used for variance reduction in the Mars algorithm.

Type

optax.Updates

count: Union[Array, ndarray, bool, number]

Alias for field number 0

mog: Union[Array, ndarray, bool, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]

Alias for field number 3

mu: Union[Array, ndarray, bool, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]

Alias for field number 1

nu: Union[Array, ndarray, bool, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]

Alias for field number 2
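
To make the field layout concrete, the following is a minimal sketch of a state with the same four fields in the same order. `SketchMarsState` and `init_state` are hypothetical names written for illustration, plain Python floats stand in for arrays, and the zero initialisation is an assumption rather than something this page confirms.

```python
from typing import Any, NamedTuple


class SketchMarsState(NamedTuple):
    """Hypothetical stand-in mirroring the documented field order."""

    count: int  # field 0: optimization step counter
    mu: Any     # field 1: first moment estimates
    nu: Any     # field 2: second moment estimates
    mog: Any    # field 3: gradient information from the previous step


def init_state(params: dict) -> SketchMarsState:
    # Assumption: every accumulator starts at zero and mirrors the params structure.
    zeros = {name: 0.0 for name in params}
    return SketchMarsState(count=0, mu=dict(zeros), nu=dict(zeros), mog=dict(zeros))


state = init_state({"w": 0.5, "b": -1.0})
print(state._fields)          # field names in positional order
print(state[3] is state.mog)  # positional and named access reach the same object
```

Because the state is a named tuple, it is immutable: an update step would build a new state rather than modify the old one in place.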

eformer.optimizers._tx.mars.mars(learning_rate: float | collections.abc.Callable[[Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, float, int]], Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, float, int]], **kwargs) → GradientTransformation

Mars (Make vAriance Reduction Shine) optimizer.

Complete Mars optimizer that combines the Mars gradient scaling of scale_by_mars with learning-rate scaling. Mars applies a variance reduction technique that incorporates gradient momentum from the previous step, which is intended to improve convergence over Adam.

Reference: https://arxiv.org/abs/2411.10438

Parameters
  • learning_rate (float | optax.Schedule) – Learning rate value or schedule function. Can be a constant float or an optax.Schedule that takes step count as input.

  • **kwargs – Additional keyword arguments passed to scale_by_mars. See scale_by_mars for the available options, including:

      - b1 (float): Decay rate for first moment. Defaults to 0.9.
      - b2 (float): Decay rate for second moment. Defaults to 0.999.
      - gamma (float): Variance reduction strength. Defaults to 0.05.
      - eps (float): Numerical stability constant. Defaults to 1e-8.
      - max_grad_norm (float): Gradient clipping norm. Defaults to 0.0.
      - mu_dtype: Data type for moment accumulators.
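
The learning_rate parameter accepts either a constant or a schedule, and the sketch below illustrates that distinction in plain Python. `warmup_cosine` and `resolve_learning_rate` are hypothetical helpers written for this example. The schedule shape (linear warmup followed by cosine decay) is an assumption and is not claimed to match any optax schedule exactly.

```python
import math
from typing import Callable, Union

Schedule = Callable[[int], float]


def warmup_cosine(init_value: float, peak_value: float,
                  warmup_steps: int, decay_steps: int,
                  end_value: float = 0.0) -> Schedule:
    """Hypothetical schedule: linear warmup, then cosine decay to end_value."""

    def schedule(step: int) -> float:
        if step < warmup_steps:
            frac = step / warmup_steps
            return init_value + frac * (peak_value - init_value)
        # Fraction of the decay phase completed, clamped to [0, 1].
        span = max(decay_steps - warmup_steps, 1)
        progress = min((step - warmup_steps) / span, 1.0)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return end_value + (peak_value - end_value) * cosine

    return schedule


def resolve_learning_rate(learning_rate: Union[float, Schedule], step: int) -> float:
    # A callable is treated as a schedule and evaluated at the current step.
    return learning_rate(step) if callable(learning_rate) else learning_rate


schedule = warmup_cosine(1e-7, 1e-4, warmup_steps=1000, decay_steps=10000)
print(resolve_learning_rate(1e-4, step=500))      # constant: same at every step
print(resolve_learning_rate(schedule, step=500))  # partway through warmup
```

A constant is simply returned as-is, while a schedule is queried with the step count, which is why the optimizer needs to track the current step.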

Returns

The Mars optimizer ready for use with optax.apply_updates.

Return type

optax.GradientTransformation

Example

>>> import optax
>>> from eformer.optimizers._tx.mars import mars
>>> # With constant learning rate
>>> optimizer = mars(learning_rate=1e-4, b1=0.95, b2=0.99)
>>> # With learning rate schedule
>>> schedule = optax.warmup_cosine_decay_schedule(
...     init_value=1e-7, peak_value=1e-4, warmup_steps=1000, decay_steps=10000
... )
>>> optimizer = mars(learning_rate=schedule, gamma=0.025)

eformer.optimizers._tx.mars.scale_by_mars(b1: float = 0.9, b2: float = 0.999, gamma: float = 0.05, eps: float = 1e-08, eps_root: float = 0.0, max_grad_norm: float = 0.0, mu_dtype: Any | None = None) → GradientTransformation

Rescale updates according to the Mars algorithm.

Mars uses a variance reduction technique that incorporates gradient momentum from the previous step, improving upon standard Adam-style optimizers.

Reference: https://arxiv.org/abs/2411.10438
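
The following scalar sketch illustrates the kind of update the referenced paper describes: a gradient corrected by its difference to the previous step's gradient, followed by Adam-style moments. `mars_step` is a hypothetical function written for this example. It assumes the paper's approximate variant (reusing the stored previous gradient), omits clipping, and is not confirmed to match eformer's implementation line for line.

```python
import math


def mars_step(g, g_prev, mu, nu, count,
              b1=0.9, b2=0.999, gamma=0.05, eps=1e-8, eps_root=0.0):
    """One scalar step of a MARS-style rescaling (illustrative sketch only)."""
    # Variance-reduced gradient: add a scaled difference to the previous gradient.
    c = g + gamma * (b1 / (1.0 - b1)) * (g - g_prev)
    # Adam-style exponential moving averages of c and c**2.
    mu = b1 * mu + (1.0 - b1) * c
    nu = b2 * nu + (1.0 - b2) * c * c
    count += 1
    # Bias correction compensates for the zero-initialised accumulators.
    mu_hat = mu / (1.0 - b1 ** count)
    nu_hat = nu / (1.0 - b2 ** count)
    # eps_root sits inside the square root, eps outside it.
    update = mu_hat / (math.sqrt(nu_hat + eps_root) + eps)
    return update, mu, nu, count


# Two steps on a scalar; the current gradient becomes g_prev for the next step.
update, mu, nu, count = mars_step(g=1.0, g_prev=0.0, mu=0.0, nu=0.0, count=0)
update, mu, nu, count = mars_step(g=0.8, g_prev=1.0, mu=mu, nu=nu, count=count)
print(update, count)
```

With gamma set to 0 the correction term vanishes and the sketch reduces to a plain Adam-style rescaling, which is one way to see what the variance reduction term adds.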

Parameters
  • b1 (float) – Decay rate for the exponentially weighted average of gradients. Controls how quickly the first moment estimate adapts to new gradients. Defaults to 0.9.

  • b2 (float) – Decay rate for the exponentially weighted average of squared gradients. Controls how quickly the second moment estimate adapts. Defaults to 0.999.

  • gamma (float) – Coefficient of the variance reduction term, which is computed from the gradient information of the previous step. Larger values strengthen the correction. Defaults to 0.05.

  • eps (float) – Small constant added to the denominator to improve numerical stability. Prevents division by zero. Defaults to 1e-8.

  • eps_root (float) – Small constant added inside the square-root to improve numerical stability when backpropagating gradients through the rescaling. Defaults to 0.0.

  • max_grad_norm (float) – Maximum gradient norm for clipping. If > 0, gradients are clipped to this norm before computing moment estimates. Defaults to 0.0 (no clipping).

  • mu_dtype (Any | None) – Optional dtype for the first moment accumulator. If None, dtype is inferred from params and updates. Defaults to None.
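
The max_grad_norm behaviour can be illustrated with a small clipping sketch. `clip_by_global_norm` is a hypothetical helper. Using one global L2 norm across all parameters (rather than a per-parameter norm) and the 1e-6 guard in the denominator are assumptions made for this example.

```python
import math


def clip_by_global_norm(grads, max_grad_norm):
    """Scale all gradients so their joint L2 norm is at most max_grad_norm."""
    if max_grad_norm <= 0.0:
        return grads  # 0.0 disables clipping, matching the documented default
    global_norm = math.sqrt(sum(g * g for leaf in grads.values() for g in leaf))
    # Leave small gradients untouched; shrink large ones proportionally.
    scale = min(1.0, max_grad_norm / (global_norm + 1e-6))
    return {name: [g * scale for g in leaf] for name, leaf in grads.items()}


grads = {"w": [3.0, 0.0], "b": [4.0]}  # global norm is 5.0
clipped = clip_by_global_norm(grads, max_grad_norm=1.0)
print(clipped)
```

Scaling every entry by the same factor preserves the direction of the gradient and only limits its length.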

Returns

A gradient transformation that rescales updates according to the Mars algorithm.

Return type

optax.GradientTransformation

Example

>>> import optax
>>> from eformer.optimizers._tx.mars import scale_by_mars
>>> optimizer = optax.chain(
...     scale_by_mars(b1=0.95, b2=0.99, gamma=0.025),
...     optax.scale_by_learning_rate(1e-4),
... )
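
The chaining in the example above can be pictured with a stateless plain-Python sketch. Everything here is a hypothetical stand-in: `chain` only mimics the left-to-right ordering of optax.chain, `unit_scale` replaces the Mars rescaling with a simple sign function, and real optax transformations additionally carry state through init and update functions.

```python
from typing import Callable

Transform = Callable[[float], float]


def chain(*transforms: Transform) -> Transform:
    """Apply transforms left to right, feeding each output into the next."""

    def apply(update: float) -> float:
        for transform in transforms:
            update = transform(update)
        return update

    return apply


def unit_scale(update: float) -> float:
    # Hypothetical stand-in for the Mars rescaling: keep only the sign.
    return float((update > 0) - (update < 0))


def scale_by_learning_rate(learning_rate: float) -> Transform:
    # Multiply by -learning_rate so that adding the result to a parameter descends.
    return lambda update: -learning_rate * update


optimizer = chain(unit_scale, scale_by_learning_rate(1e-4))
param = 1.0
param = param + optimizer(2.5)  # gradient 2.5 -> direction 1.0 -> step -1e-4
print(param)
```

The order matters: the rescaling runs first so that the learning rate is applied to the normalised direction rather than to the raw gradient.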