eformer.optimizers._config

Contents

eformer.optimizers._config#

class eformer.optimizers._config.AdafactorConfig(min_dim_size_to_factor: int = 128, decay_rate: float = 0.8, decay_offset: int = 0, multiply_by_parameter_scale: bool = True, clipping_threshold: float | None = 1.0, momentum: float | None = None, dtype_momentum: ~numpy.dtype = <class 'jax.numpy.float32'>, weight_decay_rate: float | None = None, eps: float = 1e-30, factored: bool = True)[source]#

Bases: SerializationMixin

Configuration class for the Adafactor optimizer.

min_dim_size_to_factor#

Minimum dimension size for factoring. Defaults to 128.

Type

int

decay_rate#

Decay rate for second-moment estimator. Defaults to 0.8.

Type

float

decay_offset#

Decay offset. Defaults to 0.

Type

int

multiply_by_parameter_scale#

Whether to multiply by parameter scale. Defaults to True.

Type

bool

clipping_threshold#

Clipping threshold for updates. Defaults to 1.0.

Type

Optional[float]

momentum#

Momentum factor. Defaults to None.

Type

Optional[float]

dtype_momentum#

Data type for momentum. Defaults to jnp.float32.

Type

jnp.dtype

weight_decay_rate#

Weight decay rate. Defaults to None.

Type

Optional[float]

eps#

Small constant for numerical stability. Defaults to 1e-30.

Type

float

factored#

Whether to use factored second-moment estimates. Defaults to True.

Type

bool

clipping_threshold: float | None = 1.0#
decay_offset: int = 0#
decay_rate: float = 0.8#
dtype_momentum#

alias of float32

eps: float = 1e-30#
factored: bool = True#
min_dim_size_to_factor: int = 128#
momentum: float | None = None#
multiply_by_parameter_scale: bool = True#
weight_decay_rate: float | None = None#
class eformer.optimizers._config.AdamWConfig(b1: float = 0.9, b2: float = 0.999, eps: float = 1e-08, eps_root: float = 0.0, mu_dtype: numpy.dtype | None = None)[source]#

Bases: SerializationMixin

Configuration class for the AdamW optimizer.

b1#

Exponential decay rate for the first moment estimates. Defaults to 0.9.

Type

float

b2#

Exponential decay rate for the second moment estimates. Defaults to 0.999.

Type

float

eps#

Small constant for numerical stability. Defaults to 1e-8.

Type

float

eps_root#

Small constant for root calculations. Defaults to 0.0.

Type

float

mu_dtype#

Data type for momentum. Defaults to None.

Type

Optional[jnp.dtype]

b1: float = 0.9#
b2: float = 0.999#
eps: float = 1e-08#
eps_root: float = 0.0#
mu_dtype: numpy.dtype | None = None#
class eformer.optimizers._config.KronConfig(beta1: float = 0.9, weight_decay: float = 0.1, max_grad_norm: float | None = 1.0, normalize_grads: bool = False, preconditioner_update_probability: float = 0.05, update_prob_flat_start: int = 500, max_size_triangular: int = 25000, min_ndim_triangular: int = 2, memory_save_mode: str | None = None, preconditioner_lr: float = 0.1, preconditioner_init_scale: float = 1.0, mu_dtype: numpy.dtype | None = None, precond_dtype: numpy.dtype | None = None, precond_update_precision: str | None = 'tensorfloat32', precond_grads_precision: str | None = None, lax_map_scanned_layers: bool = False, lax_map_batch_size: int = 8, merge_small_dims: bool = True, target_merged_dim_size: int = 8192, partition_grads_into_blocks: bool = True, block_size: int = 256)[source]#

Bases: SerializationMixin

Configuration class for the Kron (PSGD Kron) optimizer.

Kron uses Kronecker-factored preconditioners for efficient second-order optimization, particularly effective for neural network training.

beta1#

Momentum parameter. Common values are 0.9 or 0.95. Defaults to 0.9.

Type

float

weight_decay#

Weight decay coefficient. Defaults to 0.1.

Type

float

max_grad_norm#

Optional gradient norm clipping value. Defaults to 1.0.

Type

float | None

normalize_grads#

Whether to normalize incoming gradients to unit norm layer-wise. Can help with stability. Defaults to False.

Type

bool

preconditioner_update_probability#

Final probability of updating the preconditioner. Defaults to 0.05 (update every 20 steps on average).

Type

float

update_prob_flat_start#

Number of steps to keep update probability at 1.0 before annealing. Defaults to 500.

Type

int

max_size_triangular#

Maximum size for triangular factorization. Defaults to 25000.

Type

int

min_ndim_triangular#

Minimum number of dimensions for triangular factorization. Defaults to 2.

Type

int

memory_save_mode#

Memory saving mode. Can be None, “one_diag”, or “all_diag”. Defaults to None.

Type

str | None

preconditioner_lr#

Learning rate for preconditioner updates. Defaults to 0.1.

Type

float

preconditioner_init_scale#

Initial scale for preconditioner. Defaults to 1.0.

Type

float

mu_dtype#

Data type for momentum computation. Defaults to None.

Type

jnp.dtype | None

precond_dtype#

Data type for preconditioners. Defaults to None.

Type

jnp.dtype | None

precond_update_precision#

Precision for preconditioner updates. Defaults to “tensorfloat32”.

Type

str | None

precond_grads_precision#

Precision for gradient preconditioning. Defaults to None.

Type

str | None

lax_map_scanned_layers#

Whether to use lax.map for scanned layers. Defaults to False.

Type

bool

lax_map_batch_size#

Batch size for lax.map. Defaults to 8.

Type

int

merge_small_dims#

Whether to merge small dimensions. Defaults to True.

Type

bool

target_merged_dim_size#

Target size for merged dimensions. Defaults to 8192.

Type

int

partition_grads_into_blocks#

Whether to partition gradients into blocks. Defaults to True.

Type

bool

block_size#

Block size for gradient partitioning. Defaults to 256.

Type

int

beta1: float = 0.9#
block_size: int = 256#
lax_map_batch_size: int = 8#
lax_map_scanned_layers: bool = False#
max_grad_norm: float | None = 1.0#
max_size_triangular: int = 25000#
memory_save_mode: str | None = None#
merge_small_dims: bool = True#
min_ndim_triangular: int = 2#
mu_dtype: numpy.dtype | None = None#
normalize_grads: bool = False#
partition_grads_into_blocks: bool = True#
precond_dtype: numpy.dtype | None = None#
precond_grads_precision: str | None = None#
precond_update_precision: str | None = 'tensorfloat32'#
preconditioner_init_scale: float = 1.0#
preconditioner_lr: float = 0.1#
preconditioner_update_probability: float = 0.05#
target_merged_dim_size: int = 8192#
update_prob_flat_start: int = 500#
weight_decay: float = 0.1#
class eformer.optimizers._config.LionConfig(b1: float = 0.9, b2: float = 0.99, mu_dtype: numpy.dtype | None = None)[source]#

Bases: SerializationMixin

Configuration class for the Lion optimizer.

b1#

Exponential decay rate for the first moment estimates. Defaults to 0.9.

Type

float

b2#

Exponential decay rate for the second moment estimates. Defaults to 0.99.

Type

float

mu_dtype#

Data type for momentum. Defaults to None.

Type

Optional[jnp.dtype]

b1: float = 0.9#
b2: float = 0.99#
mu_dtype: numpy.dtype | None = None#
class eformer.optimizers._config.MarsConfig(weight_decay: float = 0.1, beta1: float = 0.95, beta2: float = 0.99, gamma: float = 0.025, epsilon: float = 1e-08, max_grad_norm: float | None = 1.0)[source]#

Bases: SerializationMixin

Configuration class for the Mars optimizer.

Mars (Make vAriance Reduction Shine) optimizer improves upon Adam-style preconditioned methods by adding a variance-reduction correction to the gradient used in the momentum update.

Reference: https://arxiv.org/abs/2411.10438

weight_decay#

Weight decay coefficient. Defaults to 0.1.

Type

float

beta1#

Exponential decay rate for first moment estimates. Defaults to 0.95.

Type

float

beta2#

Exponential decay rate for second moment estimates. Defaults to 0.99.

Type

float

gamma#

Decay rate for exponentially weighted average of gradient from previous step. Defaults to 0.025.

Type

float

epsilon#

Small constant for numerical stability. Defaults to 1e-8.

Type

float

max_grad_norm#

Maximum gradient norm for clipping. Defaults to 1.0.

Type

float | None

beta1: float = 0.95#
beta2: float = 0.99#
epsilon: float = 1e-08#
gamma: float = 0.025#
max_grad_norm: float | None = 1.0#
weight_decay: float = 0.1#
class eformer.optimizers._config.MuonConfig(ns_coeffs: tuple[float, float, float] = (3.4445, -4.775, 2.0315), ns_steps: int = 5, beta: float = 0.95, eps: float = 1e-08, weight_decay: float = 0.0, weight_decay_mask: Any | None = None, mu_dtype: numpy.dtype | None = None, nesterov: bool = True, adaptive: bool = False, adam_b1: float = 0.9, adam_b2: float = 0.999, adam_eps_root: float = 0.0, adam_weight_decay: float = 0.0)[source]#

Bases: SerializationMixin

Configuration class for the Muon (Momentum Orthogonalized by Newton-Schulz) optimizer.

Muon is designed for 2D parameters (matrices) and uses the Newton-Schulz method to orthogonalize momentum. Non-2D parameters are processed through an Adam optimizer.

ns_coeffs#

Coefficients for the Newton-Schulz method. Defaults to (3.4445, -4.775, 2.0315).

Type

tuple[float, float, float]

ns_steps#

Number of Newton-Schulz iterations. Defaults to 5.

Type

int

beta#

Decay rate for the exponentially weighted average of grads. Defaults to 0.95.

Type

float

eps#

Term added to the denominator to improve numerical stability. Defaults to 1e-8.

Type

float

weight_decay#

Strength of the weight decay regularization. Defaults to 0.0.

Type

float

weight_decay_mask#

Weight decay mask. Defaults to None.

Type

Any | None

mu_dtype#

Data type for momentum computation. Defaults to None.

Type

jnp.dtype | None

nesterov#

Whether to use Nesterov momentum. Defaults to True.

Type

bool

adaptive#

Whether to scale the updates by the dual norm of the original updates. Defaults to False.

Type

bool

adam_b1#

Exponential decay rate for first moment estimates in Adam (for non-2D params). Defaults to 0.9.

Type

float

adam_b2#

Exponential decay rate for second moment estimates in Adam (for non-2D params). Defaults to 0.999.

Type

float

adam_eps_root#

Small constant for root calculations in Adam. Defaults to 0.0.

Type

float

adam_weight_decay#

Weight decay for Adam optimizer (for non-2D params). Defaults to 0.0.

Type

float

adam_b1: float = 0.9#
adam_b2: float = 0.999#
adam_eps_root: float = 0.0#
adam_weight_decay: float = 0.0#
adaptive: bool = False#
beta: float = 0.95#
eps: float = 1e-08#
mu_dtype: numpy.dtype | None = None#
nesterov: bool = True#
ns_coeffs: tuple[float, float, float] = (3.4445, -4.775, 2.0315)#
ns_steps: int = 5#
weight_decay: float = 0.0#
weight_decay_mask: Any | None = None#
class eformer.optimizers._config.RMSPropConfig(decay: float = 0.9, initial_scale: float = 0.0, momentum: float | None = None, nesterov: bool = False, eps: float = 1e-08)[source]#

Bases: SerializationMixin

Configuration class for the RMSProp optimizer.

decay#

Decay rate for the moving average. Defaults to 0.9.

Type

float

initial_scale#

Initial scale for the moving average. Defaults to 0.0.

Type

float

momentum#

Momentum factor. Defaults to None.

Type

Optional[float]

nesterov#

Whether to use Nesterov momentum. Defaults to False.

Type

bool

eps#

Small constant for numerical stability. Defaults to 1e-8.

Type

float

decay: float = 0.9#
eps: float = 1e-08#
initial_scale: float = 0.0#
momentum: float | None = None#
nesterov: bool = False#
class eformer.optimizers._config.SchedulerConfig(scheduler_type: Optional[Literal['linear', 'cosine']] = None, learning_rate: float = 5e-05, learning_rate_end: float | None = None, warmup_steps: int | None = None, steps: int | None = None, exponent: float = 1.0)[source]#

Bases: SerializationMixin

Configuration class for learning rate schedulers.

scheduler_type#

Type of scheduler to use. Defaults to None.

Type

Optional[Literal["linear", "cosine"]]

learning_rate#

Initial learning rate. Defaults to 5e-5.

Type

float

learning_rate_end#

Final learning rate for linear scheduler. Defaults to None.

Type

Optional[float]

warmup_steps#

Number of warmup steps. Defaults to None.

Type

Optional[int]

steps#

Total number of steps. Required for non-constant schedulers.

Type

Optional[int]

exponent#

Exponent for polynomial decay. Defaults to 1.0.

Type

float

__post_init__()[source]#

Validates the configuration after initialization.

_validate()[source]#

Performs validation checks on the configuration.

exponent: float = 1.0#
learning_rate: float = 5e-05#
learning_rate_end: float | None = None#
scheduler_type: Optional[Literal['linear', 'cosine']] = None#
steps: int | None = None#
warmup_steps: int | None = None#
class eformer.optimizers._config.ScionConfig(momentum: float = 0.95, backend_steps: int = 10, beta1: float = 0.9, epsilon: float = 1e-08, unconstrained: bool = False, spectral_radius: float = 50, sign_radius: float = 3000)[source]#

Bases: SerializationMixin

Configuration class for the Scion optimizer.

Scion combines spectral normalization with sign-based updates for different parameter types, providing efficient training for neural networks.

Reference: https://arxiv.org/abs/2502.07529

momentum#

Momentum parameter for both spectral and sign methods. Defaults to 0.95.

Type

float

backend_steps#

Number of steps for Newton-Schulz orthogonalization. Defaults to 10.

Type

int

beta1#

Beta1 parameter for sign method. Defaults to 0.9.

Type

float

epsilon#

Small constant for numerical stability. Defaults to 1e-8.

Type

float

unconstrained#

Whether to use unconstrained version. Defaults to False.

Type

bool

spectral_radius#

Scaling factor for spectral method. Defaults to 50.

Type

float

sign_radius#

Scaling factor for sign method. Defaults to 3000.

Type

float

backend_steps: int = 10#
beta1: float = 0.9#
epsilon: float = 1e-08#
momentum: float = 0.95#
sign_radius: float = 3000#
spectral_radius: float = 50#
unconstrained: bool = False#
class eformer.optimizers._config.SerializationMixin[source]#

Bases: object

Mixin class providing serialization capabilities for configuration classes.

This class provides methods to convert instances to and from dictionaries and JSON strings, making it easy to serialize and deserialize configuration objects.

to_dict()[source]#

Convert the instance to a dictionary, filtering out private fields.

from_dict()[source]#

Create an instance from a dictionary with error checking.

to_json()[source]#

Serialize the instance to a JSON string.

from_json()[source]#

Create an instance from a JSON string.

classmethod from_dict(data: dict[str, Any]) T[source]#

Create an instance from a dictionary with error checking.

Parameters

data (dict) – A dictionary containing the data to populate the instance.

Returns

An instance of the class populated with the provided data.

Return type

T

Raises

Warning – If unexpected keys are present in the input dictionary.

classmethod from_json(json_str: str) T[source]#

Create an instance from a JSON string.

Parameters

json_str (str) – A JSON string containing the data to populate the instance.

Returns

An instance of the class populated with the data from the JSON string.

Return type

T

to_dict() dict[str, Any][source]#

Convert the instance to a dictionary, filtering out private fields.

Returns

A dictionary representation of the instance, excluding private fields.

Return type

dict

to_json() str[source]#

Serialize the instance to a JSON string.

Returns

A JSON string representation of the instance.

Return type

str

class eformer.optimizers._config.SoapConfig(weight_decay: float = 0.0, beta1: float = 0.95, beta2: float = 0.95, shampoo_beta: float = 0.95, epsilon: float = 1e-08, max_grad_norm: float | None = 1.0, haps: list[int] | None = None, schedule_list: list[str] | None = None, precondition_frequency: int = 10, max_precond_dim: int = 10000, merge_small_dims: bool = True, one_diag: bool = False, target_merged_dim_size: int = 2048, mu_dtype: numpy.dtype | None = None, precond_dtype: numpy.dtype | None = None, partition_grads_into_blocks: bool = True, block_size: int = 256)[source]#

Bases: SerializationMixin

Configuration class for the SOAP (ShampoO with Adam in the Preconditioner's eigenbasis) optimizer.

SOAP combines Shampoo’s second-order optimization with orthogonal preconditioning and adaptive scheduling for improved convergence.

weight_decay#

Weight decay coefficient. Defaults to 0.0.

Type

float

beta1#

Momentum parameter for first moment estimates. Defaults to 0.95.

Type

float

beta2#

Momentum parameter for second moment estimates. Defaults to 0.95.

Type

float

shampoo_beta#

Beta parameter for Shampoo preconditioning. Defaults to 0.95.

Type

float

epsilon#

Small constant for numerical stability. Defaults to 1e-8.

Type

float

max_grad_norm#

Maximum gradient norm for clipping. Defaults to 1.0.

Type

float | None

haps#

HAP schedule parameters. Defaults to None.

Type

list[int] | None

schedule_list#

Schedule list. Defaults to None.

Type

list[str] | None

precondition_frequency#

Frequency of preconditioner updates. Defaults to 10.

Type

int

max_precond_dim#

Maximum dimension for preconditioning. Defaults to 10000.

Type

int

merge_small_dims#

Whether to merge small dimensions. Defaults to True.

Type

bool

one_diag#

Whether to use diagonal preconditioning only. Defaults to False.

Type

bool

target_merged_dim_size#

Target size for merged dimensions. Defaults to 2048.

Type

int

mu_dtype#

Data type for momentum computation. Defaults to None.

Type

jnp.dtype | None

precond_dtype#

Data type for preconditioners. Defaults to None.

Type

jnp.dtype | None

partition_grads_into_blocks#

Whether to partition gradients into blocks. Defaults to True.

Type

bool

block_size#

Block size for gradient partitioning. Defaults to 256.

Type

int

beta1: float = 0.95#
beta2: float = 0.95#
block_size: int = 256#
epsilon: float = 1e-08#
haps: list[int] | None = None#
max_grad_norm: float | None = 1.0#
max_precond_dim: int = 10000#
merge_small_dims: bool = True#
mu_dtype: numpy.dtype | None = None#
one_diag: bool = False#
partition_grads_into_blocks: bool = True#
precond_dtype: numpy.dtype | None = None#
precondition_frequency: int = 10#
schedule_list: list[str] | None = None#
shampoo_beta: float = 0.95#
target_merged_dim_size: int = 2048#
weight_decay: float = 0.0#
class eformer.optimizers._config.WhiteKronConfig(lr_style: str | None = 'adam', b1: float = 0.95, normalize_grads: bool = False, max_size_dense: int = 16384, preconditioner_lr: float = 0.7, preconditioner_init_scale: float = 1.0, dtype: str | numpy.dtype = <class 'jax.numpy.bfloat16'>, scanned_layers: Any | None = None, block_size: int = 256, pipeline_axis_name: str | None = None, pipeline_axis_size: int = 1, params_partition_specs: Any | None = None, noise_scale: float = 1e-09, weight_decay: float = 0.1, weight_decay_mask: Any | None = None)[source]#

Bases: SerializationMixin

Configuration class for the White Kron optimizer.

White Kron is a Kronecker-factored preconditioned optimizer that uses different update styles (skew or quad) for efficient second-order optimization.

lr_style#

Learning rate style. Defaults to “adam”.

Type

str | None

b1#

Exponential decay rate for first moment estimates. Defaults to 0.95.

Type

float

normalize_grads#

Whether to normalize gradients. Defaults to False.

Type

bool

max_size_dense#

Maximum size for dense preconditioning. Defaults to 16384.

Type

int

preconditioner_lr#

Learning rate for preconditioner updates. Defaults to 0.7.

Type

float

preconditioner_init_scale#

Initial scale for preconditioner. Defaults to 1.0.

Type

float

dtype#

Data type for computations. Defaults to jnp.bfloat16.

Type

str | jnp.dtype

scanned_layers#

Scanned layers configuration. Defaults to None.

Type

Any | None

block_size#

Block size for matrix operations. Defaults to 256.

Type

int

pipeline_axis_name#

Name of pipeline axis for sharding. Defaults to None.

Type

str | None

pipeline_axis_size#

Size of pipeline axis. Defaults to 1.

Type

int

params_partition_specs#

Parameter partition specifications. Defaults to None.

Type

Any | None

noise_scale#

Scale of noise added for numerical stability. Defaults to 1e-9.

Type

float

weight_decay#

Weight decay coefficient. Defaults to 0.1.

Type

float

weight_decay_mask#

Weight decay mask. Defaults to None.

Type

Any | None

b1: float = 0.95#
block_size: int = 256#
dtype#

alias of bfloat16

lr_style: str | None = 'adam'#
max_size_dense: int = 16384#
noise_scale: float = 1e-09#
normalize_grads: bool = False#
params_partition_specs: Any | None = None#
pipeline_axis_name: str | None = None#
pipeline_axis_size: int = 1#
preconditioner_init_scale: float = 1.0#
preconditioner_lr: float = 0.7#
scanned_layers: Any | None = None#
weight_decay: float = 0.1#
weight_decay_mask: Any | None = None#