Source code for eformer.optimizers._builders

# Copyright 2026 The EasyDeL/eFormer Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import dataclasses
import typing as tp

import optax

from ._base import OptimizerBuilder, SchedulerBuilder, register_optimizer, register_scheduler
from ._config import (
    AdafactorConfig,
    AdamWConfig,
    LionConfig,
    MarsConfig,
    MuonConfig,
    RMSPropConfig,
    SchedulerConfig,
    WhiteKronConfig,
)
from ._stage_local import (
    StageLocalOptimizerMetadata,
    _apply_adafactor_stage_local,
    _apply_adamw_stage_local,
    _apply_lion_stage_local,
    _apply_mars_stage_local,
    _apply_muon_stage_local,
    _apply_quad_stage_local,
    _apply_rmsprop_stage_local,
    _apply_skew_stage_local,
    make_stage_local_gradient_transformation,
)
from ._tx import mars, quad, skew


def _build_metadata_mpmd_optimizer(
    config: tp.Any,
    scheduler: optax.Schedule,
    *,
    optimizer: optax.GradientTransformation,
    stage_local_apply: tp.Callable[..., tuple[optax.Params, optax.OptState]],
    **tx_kwargs: tp.Any,
) -> optax.GradientTransformation:
    """Wrap an optimizer chain with stage-local metadata for PP training.

    This helper constructs a :class:`StageLocalOptimizerMetadata` dataclass
    from the builder's configuration and the factory-level hyperparameters,
    then attaches it to the optimizer chain via
    :func:`make_stage_local_gradient_transformation`. The resulting
    transformation supports both normal Optax updates and explicit
    stage-local applications inside pipeline-parallel training loops.

    Args:
        config: The optimizer-specific configuration object (e.g. ``AdamWConfig``).
        scheduler: Learning rate schedule paired with the optimizer.
        optimizer: Fully assembled optimizer chain from the factory.
        stage_local_apply: Optimizer-specific stage-local update kernel.
        **tx_kwargs: Factory-level transform options such as ``weight_decay``,
            ``weight_decay_mask``, ``gradient_accumulation_steps``, and
            ``clip_grad``.

    Returns:
        A :class:`StageLocalGradientTransformation` wrapping ``optimizer``.
    """

    known_tx_kwargs = {
        "weight_decay",
        "weight_decay_mask",
        "gradient_accumulation_steps",
        "clip_grad",
    }
    extra_kwargs = {key: value for key, value in tx_kwargs.items() if key not in known_tx_kwargs}
    metadata = StageLocalOptimizerMetadata(
        scheduler=scheduler,
        weight_decay=float(tx_kwargs.get("weight_decay", 0.0)),
        weight_decay_mask=tx_kwargs.get("weight_decay_mask"),
        gradient_accumulation_steps=int(tx_kwargs.get("gradient_accumulation_steps", 1)),
        clip_grad=tx_kwargs.get("clip_grad"),
        adamw_b1=getattr(config, "b1", None),
        adamw_b2=getattr(config, "b2", None),
        adamw_eps=getattr(config, "eps", None),
        adamw_eps_root=getattr(config, "eps_root", None),
        adamw_mu_dtype=getattr(config, "mu_dtype", None),
        optimizer_config=config,
        extra_kwargs=extra_kwargs,
    )
    return make_stage_local_gradient_transformation(optimizer, metadata=metadata, apply_fn=stage_local_apply)


[docs]@register_scheduler("constant") @dataclasses.dataclass class ConstantSchedulerBuilder(SchedulerBuilder): """Builder for constant learning rate schedule. This builder creates a scheduler that maintains a fixed learning rate throughout training. Attributes: config (SchedulerConfig): Configuration object containing the learning rate. Example: >>> from eformer.optimizers import SchedulerConfig >>> config = SchedulerConfig(learning_rate=1e-4) >>> builder = ConstantSchedulerBuilder(config=config) >>> scheduler = builder.build() >>> scheduler(0) # Returns 1e-4 """ config: SchedulerConfig
[docs] def build(self) -> optax.Schedule: """Build a constant learning rate schedule. Returns: optax.Schedule: A schedule function that returns the configured learning rate regardless of the step count. """ return optax.constant_schedule(self.config.learning_rate)
[docs]@register_scheduler("linear") @dataclasses.dataclass class LinearSchedulerBuilder(SchedulerBuilder): """Builder for linear learning rate schedule with optional warmup. This builder creates a scheduler that linearly decays the learning rate from an initial value to an end value over a specified number of steps. Optionally, a warmup phase can be added at the beginning of training. Attributes: config (SchedulerConfig): Configuration object containing learning rate parameters, steps, and optional warmup settings. Example: >>> from eformer.optimizers import SchedulerConfig >>> config = SchedulerConfig( ... scheduler_type="linear", ... learning_rate=1e-4, ... learning_rate_end=1e-6, ... steps=10000, ... warmup_steps=1000, ... ) >>> builder = LinearSchedulerBuilder(config=config) >>> scheduler = builder.build() """ config: SchedulerConfig
[docs] def build(self) -> optax.Schedule: """Build a linear learning rate schedule with optional warmup. Creates a linear decay schedule from `learning_rate` to `learning_rate_end`. If warmup_steps is specified, prepends a linear warmup phase from a very small value (1e-8) to the initial learning rate. Returns: optax.Schedule: A schedule function that returns the learning rate for a given step count. Raises: ValueError: If learning_rate_end is not specified in the config. """ if self.config.learning_rate_end is None: raise ValueError("Linear scheduler requires learning_rate_end") decay_steps = self.config.steps if self.config.warmup_steps: decay_steps = self.config.steps - self.config.warmup_steps base_scheduler = optax.linear_schedule( init_value=self.config.learning_rate, end_value=self.config.learning_rate_end, transition_steps=decay_steps, ) if self.config.warmup_steps: warmup = optax.linear_schedule( init_value=1e-8, end_value=self.config.learning_rate, transition_steps=self.config.warmup_steps, ) return optax.join_schedules( schedules=[warmup, base_scheduler], boundaries=[self.config.warmup_steps], ) return base_scheduler
[docs]@register_scheduler("cosine") @dataclasses.dataclass class CosineSchedulerBuilder(SchedulerBuilder): """Builder for cosine learning rate schedule with optional warmup. This builder creates a scheduler that decays the learning rate following a cosine curve. This is a popular choice for training neural networks as it provides smooth decay with a "warm restart" capability. Attributes: config (SchedulerConfig): Configuration object containing learning rate parameters, steps, warmup settings, and cosine decay exponent. Example: >>> from eformer.optimizers import SchedulerConfig >>> config = SchedulerConfig( ... scheduler_type="cosine", ... learning_rate=1e-4, ... learning_rate_end=1e-6, ... steps=10000, ... warmup_steps=1000, ... exponent=1.0, ... ) >>> builder = CosineSchedulerBuilder(config=config) >>> scheduler = builder.build() """ config: SchedulerConfig
[docs] def build(self) -> optax.Schedule: """Build a cosine learning rate schedule with optional warmup. Creates a cosine decay schedule that smoothly decreases the learning rate from the peak value to the end value. If warmup_steps is specified, includes a linear warmup phase from a very small value (1e-8) to the peak learning rate before the cosine decay begins. Returns: optax.Schedule: A schedule function that returns the learning rate for a given step count, following a cosine decay pattern. """ if self.config.warmup_steps: end_value = self.config.learning_rate_end or 0.0 return optax.warmup_cosine_decay_schedule( init_value=1e-8, peak_value=self.config.learning_rate, warmup_steps=self.config.warmup_steps, decay_steps=self.config.steps, end_value=end_value, exponent=self.config.exponent, ) if self.config.learning_rate_end is not None and self.config.learning_rate <= 0: raise ValueError("learning_rate must be greater than 0 when learning_rate_end is set") cosine_alpha = ( 0.0 if self.config.learning_rate_end is None else self.config.learning_rate_end / self.config.learning_rate ) return optax.cosine_decay_schedule( init_value=self.config.learning_rate, decay_steps=self.config.steps, alpha=cosine_alpha, )
[docs]@register_optimizer("adamw") @dataclasses.dataclass class AdamWOptimizer(OptimizerBuilder): """Builder for AdamW optimizer. AdamW is a variant of Adam that decouples weight decay from the gradient update, which often leads to better generalization. It is one of the most widely used optimizers for training transformers and other deep learning models. Attributes: config (AdamWConfig): Configuration object containing AdamW hyperparameters including momentum coefficients (b1, b2), epsilon values, and data type. Example: >>> from eformer.optimizers import AdamWConfig >>> import optax >>> config = AdamWConfig(b1=0.9, b2=0.999, eps=1e-8) >>> builder = AdamWOptimizer(config=config) >>> scheduler = optax.constant_schedule(1e-4) >>> optimizer = builder.build(scheduler) """ config: AdamWConfig
[docs] def build(self, scheduler: optax.Schedule) -> optax.GradientTransformation: """Build the AdamW optimizer transformation. Args: scheduler (optax.Schedule): Learning rate schedule to use for the optimizer. Returns: optax.GradientTransformation: The AdamW optimizer transformation that can be used with optax.apply_updates to update model parameters. """ return optax.adamw( learning_rate=scheduler, b1=self.config.b1, b2=self.config.b2, eps=self.config.eps, eps_root=self.config.eps_root, mu_dtype=self.config.mu_dtype, weight_decay=0.0, )
[docs] def build_mpmd( self, scheduler: optax.Schedule, *, optimizer: optax.GradientTransformation, **tx_kwargs: tp.Any, ) -> optax.GradientTransformation: """Build the AdamW stage-local optimizer for pipeline-parallel training. Delegates to :func:`_build_metadata_mpmd_optimizer` to attach stage-local metadata and a leafwise AdamW update kernel to the optimizer chain. Args: scheduler: Learning rate schedule. optimizer: Fully assembled optimizer chain from the factory. **tx_kwargs: Factory-level transformation options. Returns: A :class:`StageLocalGradientTransformation` supporting both normal Optax updates and stage-local PP updates. """ return _build_metadata_mpmd_optimizer( self.config, scheduler, optimizer=optimizer, stage_local_apply=_apply_adamw_stage_local, **tx_kwargs, )
[docs]@register_optimizer("adafactor") @dataclasses.dataclass class AdafactorOptimizer(OptimizerBuilder): """Builder for Adafactor optimizer. Adafactor is a memory-efficient adaptive learning rate optimizer designed for training large models. It uses factored second-moment estimation to reduce memory usage while maintaining adaptive learning rate capabilities. This optimizer is particularly useful for training large language models where memory constraints are significant. Attributes: config (AdafactorConfig): Configuration object containing Adafactor hyperparameters including factorization settings, decay rates, and clipping thresholds. Example: >>> from eformer.optimizers import AdafactorConfig >>> import optax >>> config = AdafactorConfig(decay_rate=0.8, factored=True) >>> builder = AdafactorOptimizer(config=config) >>> scheduler = optax.constant_schedule(1e-4) >>> optimizer = builder.build(scheduler) """ config: AdafactorConfig
[docs] def build(self, scheduler: optax.Schedule) -> optax.GradientTransformation: """Build the Adafactor optimizer transformation. Args: scheduler (optax.Schedule): Learning rate schedule to use for the optimizer. Returns: optax.GradientTransformation: The Adafactor optimizer transformation configured with factored second-moment estimation for memory efficiency. """ return optax.adafactor( learning_rate=scheduler, min_dim_size_to_factor=self.config.min_dim_size_to_factor, decay_rate=self.config.decay_rate, decay_offset=self.config.decay_offset, multiply_by_parameter_scale=self.config.multiply_by_parameter_scale, clipping_threshold=self.config.clipping_threshold, momentum=self.config.momentum, dtype_momentum=self.config.dtype_momentum, weight_decay_rate=self.config.weight_decay_rate, eps=self.config.eps, factored=self.config.factored, )
[docs] def build_mpmd( self, scheduler: optax.Schedule, *, optimizer: optax.GradientTransformation, **tx_kwargs: tp.Any, ) -> optax.GradientTransformation: """Build the Adafactor stage-local optimizer for pipeline-parallel training. Delegates to :func:`_build_metadata_mpmd_optimizer` to attach stage-local metadata and a leafwise Adafactor update kernel to the optimizer chain. Args: scheduler: Learning rate schedule. optimizer: Fully assembled optimizer chain from the factory. **tx_kwargs: Factory-level transformation options. Returns: A :class:`StageLocalGradientTransformation` supporting both normal Optax updates and stage-local PP updates. """ return _build_metadata_mpmd_optimizer( self.config, scheduler, optimizer=optimizer, stage_local_apply=_apply_adafactor_stage_local, **tx_kwargs, )
[docs]@register_optimizer("lion") @dataclasses.dataclass class LionOptimizer(OptimizerBuilder): """Builder for Lion (Evolved Sign Momentum) optimizer. Lion is an optimizer discovered through neural architecture search that uses sign-based updates with momentum. It often achieves better generalization than Adam with fewer hyperparameters to tune. Reference: https://arxiv.org/abs/2302.06675 Attributes: config (LionConfig): Configuration object containing Lion hyperparameters including momentum coefficients (b1, b2) and data type for momentum. Example: >>> from eformer.optimizers import LionConfig >>> import optax >>> config = LionConfig(b1=0.9, b2=0.99) >>> builder = LionOptimizer(config=config) >>> scheduler = optax.constant_schedule(1e-4) >>> optimizer = builder.build(scheduler) """ config: LionConfig
[docs] def build(self, scheduler: optax.Schedule) -> optax.GradientTransformation: """Build the Lion optimizer transformation. Args: scheduler (optax.Schedule): Learning rate schedule to use for the optimizer. Returns: optax.GradientTransformation: The Lion optimizer transformation that uses sign-based updates with momentum for efficient parameter updates. """ return optax.lion( learning_rate=scheduler, b1=self.config.b1, b2=self.config.b2, mu_dtype=self.config.mu_dtype, )
[docs] def build_mpmd( self, scheduler: optax.Schedule, *, optimizer: optax.GradientTransformation, **tx_kwargs: tp.Any, ) -> optax.GradientTransformation: """Build the Lion stage-local optimizer for pipeline-parallel training. Delegates to :func:`_build_metadata_mpmd_optimizer` to attach stage-local metadata and a leafwise Lion update kernel to the optimizer chain. Args: scheduler: Learning rate schedule. optimizer: Fully assembled optimizer chain from the factory. **tx_kwargs: Factory-level transformation options. Returns: A :class:`StageLocalGradientTransformation` supporting both normal Optax updates and stage-local PP updates. """ return _build_metadata_mpmd_optimizer( self.config, scheduler, optimizer=optimizer, stage_local_apply=_apply_lion_stage_local, **tx_kwargs, )
[docs]@register_optimizer("rmsprop") @dataclasses.dataclass class RMSPropOptimizer(OptimizerBuilder): """Builder for RMSProp (Root Mean Square Propagation) optimizer. RMSProp is an adaptive learning rate optimizer that divides the gradient by a running average of the magnitude of recent gradients. It is effective for training recurrent neural networks and other models with non-stationary objectives. Attributes: config (RMSPropConfig): Configuration object containing RMSProp hyperparameters including decay rate, epsilon, momentum, and Nesterov momentum settings. Example: >>> from eformer.optimizers import RMSPropConfig >>> import optax >>> config = RMSPropConfig(decay=0.9, eps=1e-8, momentum=0.9) >>> builder = RMSPropOptimizer(config=config) >>> scheduler = optax.constant_schedule(1e-4) >>> optimizer = builder.build(scheduler) """ config: RMSPropConfig
[docs] def build(self, scheduler: optax.Schedule) -> optax.GradientTransformation: """Build the RMSProp optimizer transformation. Args: scheduler (optax.Schedule): Learning rate schedule to use for the optimizer. Returns: optax.GradientTransformation: The RMSProp optimizer transformation that adapts the learning rate based on a moving average of squared gradients. """ return optax.rmsprop( learning_rate=scheduler, decay=self.config.decay, eps=self.config.eps, initial_scale=self.config.initial_scale, centered=False, momentum=self.config.momentum, nesterov=self.config.nesterov, )
[docs] def build_mpmd( self, scheduler: optax.Schedule, *, optimizer: optax.GradientTransformation, **tx_kwargs: tp.Any, ) -> optax.GradientTransformation: """Build the RMSProp stage-local optimizer for pipeline-parallel training. Delegates to :func:`_build_metadata_mpmd_optimizer` to attach stage-local metadata and a leafwise RMSProp update kernel to the optimizer chain. Args: scheduler: Learning rate schedule. optimizer: Fully assembled optimizer chain from the factory. **tx_kwargs: Factory-level transformation options. Returns: A :class:`StageLocalGradientTransformation` supporting both normal Optax updates and stage-local PP updates. """ return _build_metadata_mpmd_optimizer( self.config, scheduler, optimizer=optimizer, stage_local_apply=_apply_rmsprop_stage_local, **tx_kwargs, )
[docs]@register_optimizer("muon") @dataclasses.dataclass class MuonOptimizer(OptimizerBuilder): """Builder for Muon (Momentum Orthogonalized by Newton-schulz) optimizer. Muon is designed specifically for 2D parameters (matrices) and uses the Newton-Schulz method to orthogonalize momentum. Non-2D parameters are processed through an Adam optimizer fallback. This makes it particularly effective for training models with large matrix parameters. The optimizer maintains orthogonality of the momentum, which can lead to more stable training and better convergence for certain architectures. Attributes: config (MuonConfig): Configuration object containing Muon hyperparameters including Newton-Schulz coefficients, number of steps, momentum parameters, and Adam fallback settings. Example: >>> from eformer.optimizers import MuonConfig >>> import optax >>> config = MuonConfig(ns_steps=5, beta=0.95, nesterov=True) >>> builder = MuonOptimizer(config=config) >>> scheduler = optax.constant_schedule(1e-4) >>> optimizer = builder.build(scheduler) """ config: MuonConfig
[docs] def build(self, scheduler: optax.Schedule) -> optax.GradientTransformation: """Build the Muon optimizer transformation. Args: scheduler (optax.Schedule): Learning rate schedule to use for the optimizer. Returns: optax.GradientTransformation: The Muon optimizer transformation that uses Newton-Schulz orthogonalization for 2D parameters and Adam for others. """ return optax.contrib.muon( learning_rate=scheduler, ns_steps=self.config.ns_steps, ns_coeffs=self.config.ns_coeffs, beta=self.config.beta, eps=self.config.eps, weight_decay=self.config.weight_decay, weight_decay_mask=self.config.weight_decay_mask, mu_dtype=self.config.mu_dtype, nesterov=self.config.nesterov, adaptive=self.config.adaptive, adam_b1=self.config.adam_b1, adam_b2=self.config.adam_b2, adam_eps_root=self.config.adam_eps_root, )
[docs] def build_mpmd( self, scheduler: optax.Schedule, *, optimizer: optax.GradientTransformation, **tx_kwargs: tp.Any, ) -> optax.GradientTransformation: """Build the Muon stage-local optimizer for pipeline-parallel training. Delegates to :func:`_build_metadata_mpmd_optimizer` to attach stage-local metadata and a leafwise Muon update kernel to the optimizer chain. Args: scheduler: Learning rate schedule. optimizer: Fully assembled optimizer chain from the factory. **tx_kwargs: Factory-level transformation options. Returns: A :class:`StageLocalGradientTransformation` supporting both normal Optax updates and stage-local PP updates. """ return _build_metadata_mpmd_optimizer( self.config, scheduler, optimizer=optimizer, stage_local_apply=_apply_muon_stage_local, **tx_kwargs, )
[docs]@register_optimizer("quad") @dataclasses.dataclass class QuadOptimizer(OptimizerBuilder): """Builder for Quad (White Kron with QUAD update) optimizer. Quad is a Kronecker-factored preconditioned optimizer that uses the QUAD preconditioner update style. It provides efficient second-order optimization by approximating the inverse Fisher information matrix using Kronecker products. This optimizer is particularly effective for training deep neural networks, especially transformers, where second-order information can significantly improve convergence. Attributes: config (WhiteKronConfig): Configuration object containing Quad optimizer hyperparameters including preconditioner settings, block size, sharding configurations, and numerical stability parameters. Example: >>> from eformer.optimizers import WhiteKronConfig >>> import optax >>> config = WhiteKronConfig(b1=0.95, preconditioner_lr=0.7) >>> builder = QuadOptimizer(config=config) >>> scheduler = optax.constant_schedule(1e-4) >>> optimizer = builder.build(scheduler) """ config: WhiteKronConfig
[docs] def build(self, scheduler: optax.Schedule) -> optax.GradientTransformation: """Build the Quad optimizer transformation. Args: scheduler (optax.Schedule): Learning rate schedule to use for the optimizer. Returns: optax.GradientTransformation: The Quad optimizer transformation using QUAD-style Kronecker-factored preconditioning for efficient second-order optimization. """ return quad( learning_rate=scheduler, lr_style=self.config.lr_style, b1=self.config.b1, weight_decay=self.config.weight_decay, weight_decay_mask=self.config.weight_decay_mask, normalize_grads=self.config.normalize_grads, max_size_dense=self.config.max_size_dense, preconditioner_lr=self.config.preconditioner_lr, preconditioner_init_scale=self.config.preconditioner_init_scale, dtype=self.config.dtype, scanned_layers=self.config.scanned_layers, block_size=self.config.block_size, pipeline_axis_name=self.config.pipeline_axis_name, pipeline_axis_size=self.config.pipeline_axis_size, params_partition_specs=self.config.params_partition_specs, noise_scale=self.config.noise_scale, )
[docs] def build_mpmd( self, scheduler: optax.Schedule, *, optimizer: optax.GradientTransformation, **tx_kwargs: tp.Any, ) -> optax.GradientTransformation: """Build the Quad stage-local optimizer for pipeline-parallel training. Delegates to :func:`_build_metadata_mpmd_optimizer` to attach stage-local metadata and a leafwise WhiteKron (Quad) update kernel to the optimizer chain. Args: scheduler: Learning rate schedule. optimizer: Fully assembled optimizer chain from the factory. **tx_kwargs: Factory-level transformation options. Returns: A :class:`StageLocalGradientTransformation` supporting both normal Optax updates and stage-local PP updates. """ return _build_metadata_mpmd_optimizer( self.config, scheduler, optimizer=optimizer, stage_local_apply=_apply_quad_stage_local, **tx_kwargs, )
[docs]@register_optimizer("skew") @dataclasses.dataclass class SkewOptimizer(OptimizerBuilder): """Builder for Skew (White Kron with skew update) optimizer. Skew is a Kronecker-factored preconditioned optimizer that uses the skew preconditioner update style. It provides efficient second-order optimization with a different update rule compared to the QUAD variant. The skew update uses a Procrustes step to maintain orthogonality of the preconditioner, which can lead to more stable training in certain scenarios. Attributes: config (WhiteKronConfig): Configuration object containing Skew optimizer hyperparameters including preconditioner settings, block size, sharding configurations, and numerical stability parameters. Example: >>> from eformer.optimizers import WhiteKronConfig >>> import optax >>> config = WhiteKronConfig(b1=0.95, preconditioner_lr=0.7) >>> builder = SkewOptimizer(config=config) >>> scheduler = optax.constant_schedule(1e-4) >>> optimizer = builder.build(scheduler) """ config: WhiteKronConfig
[docs] def build(self, scheduler: optax.Schedule) -> optax.GradientTransformation: """Build the Skew optimizer transformation. Args: scheduler (optax.Schedule): Learning rate schedule to use for the optimizer. Returns: optax.GradientTransformation: The Skew optimizer transformation using skew-style Kronecker-factored preconditioning with Procrustes orthogonalization for efficient second-order optimization. """ return skew( learning_rate=scheduler, lr_style=self.config.lr_style, b1=self.config.b1, weight_decay=self.config.weight_decay, weight_decay_mask=self.config.weight_decay_mask, normalize_grads=self.config.normalize_grads, max_size_dense=self.config.max_size_dense, preconditioner_lr=self.config.preconditioner_lr, preconditioner_init_scale=self.config.preconditioner_init_scale, dtype=self.config.dtype, scanned_layers=self.config.scanned_layers, block_size=self.config.block_size, pipeline_axis_name=self.config.pipeline_axis_name, pipeline_axis_size=self.config.pipeline_axis_size, params_partition_specs=self.config.params_partition_specs, noise_scale=self.config.noise_scale, )
[docs] def build_mpmd( self, scheduler: optax.Schedule, *, optimizer: optax.GradientTransformation, **tx_kwargs: tp.Any, ) -> optax.GradientTransformation: """Build the Skew stage-local optimizer for pipeline-parallel training. Delegates to :func:`_build_metadata_mpmd_optimizer` to attach stage-local metadata and a leafwise WhiteKron (Skew) update kernel to the optimizer chain. Args: scheduler: Learning rate schedule. optimizer: Fully assembled optimizer chain from the factory. **tx_kwargs: Factory-level transformation options. Returns: A :class:`StageLocalGradientTransformation` supporting both normal Optax updates and stage-local PP updates. """ return _build_metadata_mpmd_optimizer( self.config, scheduler, optimizer=optimizer, stage_local_apply=_apply_skew_stage_local, **tx_kwargs, )
[docs]@register_optimizer("mars") @dataclasses.dataclass class MarsOptimizer(OptimizerBuilder): """Builder for Mars (Matrix-wise Adaptive Regularized Scaling) optimizer. Mars improves upon Adam by using a variance reduction technique with gradient momentum from the previous step. This can lead to improved convergence and better generalization, particularly for training large language models. Reference: https://arxiv.org/abs/2411.10438 Attributes: config (MarsConfig): Configuration object containing Mars hyperparameters including beta coefficients, gamma for gradient momentum, epsilon for numerical stability, and gradient clipping threshold. Example: >>> from eformer.optimizers import MarsConfig >>> import optax >>> config = MarsConfig(beta1=0.95, beta2=0.99, gamma=0.025) >>> builder = MarsOptimizer(config=config) >>> scheduler = optax.constant_schedule(1e-4) >>> optimizer = builder.build(scheduler) """ config: MarsConfig
[docs] def build(self, scheduler: optax.Schedule) -> optax.GradientTransformation: """Build the Mars optimizer transformation. Args: scheduler (optax.Schedule): Learning rate schedule to use for the optimizer. Returns: optax.GradientTransformation: The Mars optimizer transformation that uses variance reduction with gradient momentum for improved convergence. """ return mars( learning_rate=scheduler, b1=self.config.beta1, b2=self.config.beta2, gamma=self.config.gamma, eps=self.config.epsilon, max_grad_norm=self.config.max_grad_norm, )
[docs] def build_mpmd( self, scheduler: optax.Schedule, *, optimizer: optax.GradientTransformation, **tx_kwargs: tp.Any, ) -> optax.GradientTransformation: """Build the Mars stage-local optimizer for pipeline-parallel training. Delegates to :func:`_build_metadata_mpmd_optimizer` to attach stage-local metadata and a leafwise Mars update kernel to the optimizer chain. Args: scheduler: Learning rate schedule. optimizer: Fully assembled optimizer chain from the factory. **tx_kwargs: Factory-level transformation options. Returns: A :class:`StageLocalGradientTransformation` supporting both normal Optax updates and stage-local PP updates. """ return _build_metadata_mpmd_optimizer( self.config, scheduler, optimizer=optimizer, stage_local_apply=_apply_mars_stage_local, **tx_kwargs, )