eformer.optimizers._tx.mars
- class eformer.optimizers._tx.mars.ScaleByMarsState(count: chex.Array, mu: optax.Updates, nu: optax.Updates, mog: optax.Updates)
Bases: NamedTuple
State for the Mars (Make vAriance Reduction Shine) algorithm.
This named tuple holds the optimizer state required by the Mars algorithm, tracking moments and gradient history for variance reduction.
- count
Integer array tracking the current optimization step. Used for bias correction of the moment estimates.
- Type
chex.Array
- mu
First moment estimates (exponential moving average of gradients). Has the same structure as the model parameters.
- Type
optax.Updates
- nu
Second moment estimates (exponential moving average of squared gradients). Has the same structure as the model parameters.
- Type
optax.Updates
- mog
Momentum of gradients from the previous step. Used for variance reduction in the Mars algorithm.
- Type
optax.Updates
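For readers more used to plain Python, the sketch below mirrors the four fields with lists of floats in place of JAX pytrees. MarsStateSketch and init_mars_state are hypothetical names, and zero-initializing every accumulator (including mog) is an assumption based on common optax practice rather than something this page states.

```python
from typing import List, NamedTuple


class MarsStateSketch(NamedTuple):
    """Hypothetical stand-in for ScaleByMarsState using flat lists of floats."""

    count: int  # optimization steps taken so far (drives bias correction)
    mu: List[float]  # first-moment EMA, one entry per parameter
    nu: List[float]  # second-moment EMA, one entry per parameter
    mog: List[float]  # gradient information kept from the previous step


def init_mars_state(params: List[float]) -> MarsStateSketch:
    # Assumption: every accumulator starts at zero and mirrors the params.
    n = len(params)
    return MarsStateSketch(count=0, mu=[0.0] * n, nu=[0.0] * n, mog=[0.0] * n)


state = init_mars_state([0.5, -1.0, 2.0])
print(state.count, state.mu)  # 0 [0.0, 0.0, 0.0]
print(state._fields.index("mog"))  # 3: mog is field number 3
```

Being a NamedTuple, the state cannot have its fields reassigned, which fits the functional optax convention of returning a fresh state from every update.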
- eformer.optimizers._tx.mars.mars(learning_rate: float | optax.Schedule, **kwargs) -> optax.GradientTransformation
Mars (Make vAriance Reduction Shine) optimizer.
Complete Mars optimizer: applies Mars gradient scaling and then scales the result by the learning rate, which may be constant or scheduled. Mars reduces the variance of the update by incorporating gradient information from the previous step, with the aim of converging faster than Adam.
Reference: https://arxiv.org/abs/2411.10438
- Parameters
learning_rate (float | optax.Schedule) – Either a constant float or an optax.Schedule, i.e. a callable that maps the step count to a learning rate.
**kwargs – Additional keyword arguments forwarded to scale_by_mars. See scale_by_mars for details; the available options include:
- b1 (float): Decay rate for the first moment. Defaults to 0.9.
- b2 (float): Decay rate for the second moment. Defaults to 0.999.
- gamma (float): Variance-reduction strength. Defaults to 0.05.
- eps (float): Numerical stability constant added to the denominator. Defaults to 1e-8.
- eps_root (float): Numerical stability constant added inside the square root. Defaults to 0.0.
- max_grad_norm (float): Gradient clipping norm; 0.0 disables clipping. Defaults to 0.0.
- mu_dtype: Data type for the first-moment accumulator. Defaults to None.
- Returns
- The Mars optimizer, ready for use with optax.apply_updates.
- Return type
optax.GradientTransformation
Example
>>> import optax
>>> from eformer.optimizers._tx.mars import mars
>>> # With a constant learning rate
>>> optimizer = mars(learning_rate=1e-4, b1=0.95, b2=0.99)
>>> # With a learning rate schedule
>>> schedule = optax.warmup_cosine_decay_schedule(
...     init_value=1e-7, peak_value=1e-4, warmup_steps=1000, decay_steps=10000
... )
>>> optimizer = mars(learning_rate=schedule, gamma=0.025)
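The learning-rate stage of mars reduces to simple arithmetic. Assuming mars(learning_rate, **kwargs) behaves like the optax.chain(scale_by_mars(...), optax.scale_by_learning_rate(...)) pattern shown in the scale_by_mars example, the rate is looked up (querying the schedule with the step count if a schedule was given) and the Mars direction is multiplied by its negative; optax.scale_by_learning_rate flips the sign by default. The plain-Python sketch below covers only that stage; resolve_lr, apply_learning_rate and linear_warmup are hypothetical names, not eformer or optax APIs.

```python
from typing import Callable, List, Union

# Mirrors the shape of optax.Schedule: step count in, learning rate out.
Schedule = Callable[[int], float]


def resolve_lr(learning_rate: Union[float, Schedule], count: int) -> float:
    """Return the learning rate to use at step `count`."""
    if callable(learning_rate):
        return float(learning_rate(count))  # schedule: query with step count
    return float(learning_rate)  # constant: use as-is


def apply_learning_rate(
    direction: List[float], learning_rate: Union[float, Schedule], count: int
) -> List[float]:
    """Scale a Mars direction by -lr so that params + result descends."""
    lr = resolve_lr(learning_rate, count)
    return [-lr * d for d in direction]


def linear_warmup(step: int) -> float:
    """Hypothetical schedule: ramp linearly up to 1e-4 over 1000 steps."""
    return 1e-4 * min(1.0, (step + 1) / 1000)


direction = [0.5, -1.0]  # e.g. the output of a scale_by_mars step
constant_updates = apply_learning_rate(direction, 1e-4, count=0)
warmup_updates = apply_learning_rate(direction, linear_warmup, count=0)
```

Keeping the learning rate out of the scaling transform is what lets the same Mars direction be reused with any constant or scheduled rate.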
- eformer.optimizers._tx.mars.scale_by_mars(b1: float = 0.9, b2: float = 0.999, gamma: float = 0.05, eps: float = 1e-08, eps_root: float = 0.0, max_grad_norm: float = 0.0, mu_dtype: Any | None = None) -> optax.GradientTransformation
Rescale updates according to the Mars algorithm.
Mars adds a variance-reduction term, built from the difference between the current gradient and the gradient of the previous step, before the Adam-style moment updates. It is designed to improve on standard Adam-style optimizers.
Reference: https://arxiv.org/abs/2411.10438
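Written out, the per-step computation from the reference paper looks roughly as follows, in the practical form that reuses the gradient from the previous step (the role suggested by the mog field). Here g_t is the current gradient, g_{t-1} the stored one, β1 = b1, β2 = b2, γ = gamma, τ = max_grad_norm, ε = eps, ε_root = eps_root, t is the incremented step count, and squares and square roots are element-wise. Exactly where clipping and eps_root enter is an assumption based on the parameter descriptions, not a confirmed detail of this implementation; u_t is the update before any learning-rate scaling.

```latex
\begin{aligned}
c_t &= g_t + \gamma\,\frac{\beta_1}{1-\beta_1}\,\bigl(g_t - g_{t-1}\bigr) \\
\tilde{c}_t &= c_t \cdot \min\!\Bigl(1,\ \frac{\tau}{\lVert c_t \rVert_2}\Bigr) \qquad \text{(applied only when } \tau > 0\text{)} \\
m_t &= \beta_1\, m_{t-1} + (1-\beta_1)\,\tilde{c}_t \\
v_t &= \beta_2\, v_{t-1} + (1-\beta_2)\,\tilde{c}_t^{\,2} \\
\hat{m}_t &= \frac{m_t}{1-\beta_1^{\,t}}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^{\,t}} \\
u_t &= \frac{\hat{m}_t}{\sqrt{\hat{v}_t + \epsilon_{\mathrm{root}}} + \epsilon}
\end{aligned}
```

With γ = 0 the correction vanishes, c_t = g_t, and the rule is the usual Adam direction.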
- Parameters
b1 (float) – Decay rate for the exponentially weighted average of gradients. Controls how quickly the first moment estimate adapts to new gradients. Defaults to 0.9.
b2 (float) – Decay rate for the exponentially weighted average of squared gradients. Controls how quickly the second moment estimate adapts. Defaults to 0.999.
gamma (float) – Strength of the variance-reduction correction, which scales the difference between the current gradient and the gradient from the previous step before it enters the moment estimates. Defaults to 0.05.
eps (float) – Small constant added to the denominator to improve numerical stability. Prevents division by zero. Defaults to 1e-8.
eps_root (float) – Small constant added inside the square-root to improve numerical stability when backpropagating gradients through the rescaling. Defaults to 0.0.
max_grad_norm (float) – Maximum gradient norm for clipping. If > 0, gradients are clipped to this norm before computing moment estimates. Defaults to 0.0 (no clipping).
mu_dtype (Any | None) – Optional dtype for the first moment accumulator. If None, dtype is inferred from params and updates. Defaults to None.
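To make the roles of these parameters concrete, here is a plain-Python sketch of a single step on a flat list of parameters. It follows the reference paper's update in its practical form, with the stored previous gradient supplying the correction term; the order of operations (correction, then clipping, then moments), the clipping formula, and the zero-initialized state are assumptions, not a transcription of the eformer source. MarsState and mars_step are hypothetical names, and mu_dtype has no analogue here.

```python
import math
from typing import List, NamedTuple, Tuple


class MarsState(NamedTuple):
    """Hypothetical flat-list version of ScaleByMarsState."""

    count: int
    mu: List[float]
    nu: List[float]
    mog: List[float]


def mars_step(
    grads: List[float],
    state: MarsState,
    b1: float = 0.9,
    b2: float = 0.999,
    gamma: float = 0.05,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    max_grad_norm: float = 0.0,
) -> Tuple[List[float], MarsState]:
    # Variance-reduction correction: push the gradient further along its
    # change since the previous step (state.mog holds the previous gradient).
    scale = gamma * b1 / (1.0 - b1)
    c = [g + scale * (g - og) for g, og in zip(grads, state.mog)]

    # Optional clipping of the corrected gradient to a global L2 norm.
    if max_grad_norm > 0.0:
        norm = math.sqrt(sum(x * x for x in c))
        if norm > max_grad_norm:
            c = [x * (max_grad_norm / norm) for x in c]

    # Adam-style exponential moving averages of c and c**2.
    mu = [b1 * m + (1.0 - b1) * x for m, x in zip(state.mu, c)]
    nu = [b2 * v + (1.0 - b2) * x * x for v, x in zip(state.nu, c)]
    count = state.count + 1

    # Bias correction, then the normalized update direction.
    mu_hat = [m / (1.0 - b1**count) for m in mu]
    nu_hat = [v / (1.0 - b2**count) for v in nu]
    updates = [m / (math.sqrt(v + eps_root) + eps) for m, v in zip(mu_hat, nu_hat)]

    # Store the raw gradient so the next step can form g_t - g_{t-1}.
    return updates, MarsState(count, mu, nu, list(grads))


state = MarsState(count=0, mu=[0.0, 0.0], nu=[0.0, 0.0], mog=[0.0, 0.0])
updates, state = mars_step([1.0, -2.0], state)
```

Setting gamma=0.0 removes the correction term and leaves a plain Adam-style direction, which makes a quick sanity check for the rest of the step.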
- Returns
- A gradient transformation that rescales updates according to the Mars algorithm.
- Return type
optax.GradientTransformation
Example
>>> import optax
>>> from eformer.optimizers._tx.mars import scale_by_mars
>>> optimizer = optax.chain(
...     scale_by_mars(b1=0.95, b2=0.99, gamma=0.025),
...     optax.scale_by_learning_rate(1e-4),
... )