eformer.optimizers._config
- class eformer.optimizers._config.AdafactorConfig(min_dim_size_to_factor: int = 128, decay_rate: float = 0.8, decay_offset: int = 0, multiply_by_parameter_scale: bool = True, clipping_threshold: float | None = 1.0, momentum: float | None = None, dtype_momentum: numpy.dtype = jax.numpy.float32, weight_decay_rate: float | None = None, eps: float = 1e-30, factored: bool = True)
Bases: SerializationMixin
Configuration class for the Adafactor optimizer.
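- min_dim_size_to_factor
Minimum dimension size for factoring the second-moment estimate: statistics are factored only when the two largest dimensions of a parameter are both at least this size. Defaults to 128.
- Type
int
- decay_rate
Exponent that controls the decay schedule of the second-moment estimator. Defaults to 0.8.
- Type
float
- decay_offset
Step offset for the decay schedule; when fine-tuning, it can be set to the step at which fine-tuning starts. Defaults to 0.
- Type
int
- multiply_by_parameter_scale
If True, the learning rate is scaled by the parameter's scale (its root-mean-square); if False, the learning rate is used as an absolute step size. Defaults to True.
- Type
bool
- clipping_threshold
Threshold for update clipping; None disables clipping. Defaults to 1.0.
- Type
Optional[float]
- momentum
Optional momentum factor; None disables momentum and avoids storing a momentum buffer. Defaults to None.
- Type
Optional[float]
- dtype_momentum
Data type of the momentum buffer. Defaults to jnp.float32.
- Type
jnp.dtype
- weight_decay_rate
Optional weight decay rate; None disables weight decay. Defaults to None.
- Type
Optional[float]
- eps
Small regularization constant for the root-mean-square gradient estimate, for numerical stability. Defaults to 1e-30.
- Type
float
- factored
Whether to use factored second-moment estimates; if False, the full second-moment estimate is stored. Defaults to True.
- Type
bool
- clipping_threshold: float | None = 1.0
- decay_offset: int = 0
- decay_rate: float = 0.8
- dtype_momentum: numpy.dtype = jax.numpy.float32
- eps: float = 1e-30
- factored: bool = True
- min_dim_size_to_factor: int = 128
- momentum: float | None = None
- multiply_by_parameter_scale: bool = True
- weight_decay_rate: float | None = None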
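The factoring rule and rank-1 second-moment estimate governed by `min_dim_size_to_factor` and `factored`, together with the decay schedule set by `decay_rate` and `decay_offset`, can be sketched in NumPy. This is a minimal illustration under stated assumptions: the factoring rule mirrors the one in optax's `adafactor`, and the schedule follows the Adafactor paper, beta2_t = 1 - t^(-decay_rate). It is not eformer's implementation, and the helper names are hypothetical.

```python
import numpy as np


def factored_dims(shape, factored=True, min_dim_size_to_factor=128):
    """Return the two axes to factor over, or None to keep a full second moment."""
    if not factored or len(shape) < 2:
        return None
    order = np.argsort(shape)  # axes sorted by size, smallest first
    if shape[order[-2]] < min_dim_size_to_factor:
        return None  # second-largest axis too small: keep the full estimate
    return int(order[-2]), int(order[-1])


def factored_second_moment(grad_sq):
    """Rank-1 estimate of a 2D matrix of squared gradients from row/column means."""
    row = grad_sq.mean(axis=1)  # one statistic per row
    col = grad_sq.mean(axis=0)  # one statistic per column
    return np.outer(row, col) / row.mean()


def second_moment_decay(step, decay_rate=0.8, decay_offset=0):
    """Assumed schedule beta2_t = 1 - t**(-decay_rate), with t counted from 1."""
    t = step - decay_offset + 1
    return 1.0 - t ** (-decay_rate)


g = np.random.default_rng(0).normal(size=(256, 512))
# Only `row` (256 values) and `col` (512 values) would need to be stored.
v_hat = factored_second_moment(g**2)
print(factored_dims(g.shape), v_hat.shape)  # (0, 1) (256, 512)
```

The rank-1 reconstruction is exact when the squared gradients are themselves an outer product of positive vectors, which is why storing only row and column statistics loses little for many weight matrices.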
- class eformer.optimizers._config.AdamWConfig(b1: float = 0.9, b2: float = 0.999, eps: float = 1e-08, eps_root: float = 0.0, mu_dtype: numpy.dtype | None = None)
Bases: SerializationMixin
Configuration class for the AdamW optimizer.
- b1
Exponential decay rate for the first-moment estimates. Defaults to 0.9.
- Type
float
- b2
Exponential decay rate for the second-moment estimates. Defaults to 0.999.
- Type
float
- eps
Small constant added to the denominator, outside the square root, for numerical stability. Defaults to 1e-8.
- Type
float
- eps_root
Small constant added inside the square root of the second-moment estimate in the denominator. Defaults to 0.0.
- Type
float
- mu_dtype
Data type of the first-moment (momentum) accumulator; if None, it is inferred from the parameters. Defaults to None.
- Type
Optional[jnp.dtype]
- b1: float = 0.9
- b2: float = 0.999
- eps: float = 1e-08
- eps_root: float = 0.0
- mu_dtype: numpy.dtype | None = None
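Where `eps` and `eps_root` enter the update can be sketched for a scalar parameter in plain Python, following the standard bias-corrected Adam formulation (as in optax's `scale_by_adam`). This is a sketch, not eformer's code: the learning rate and decoupled weight decay are applied elsewhere and are omitted, `mu_dtype` has no scalar analogue, and `adam_step` is a hypothetical helper.

```python
import math


def adam_step(g, m, v, t, b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0):
    """One Adam-style update for a scalar parameter; `t` is the 1-based step count.

    Returns (update, new_m, new_v). Learning rate and weight decay are not applied.
    """
    m = b1 * m + (1 - b1) * g          # first-moment EMA
    v = b2 * v + (1 - b2) * g * g      # second-moment EMA
    m_hat = m / (1 - b1**t)            # bias correction
    v_hat = v / (1 - b2**t)
    # eps_root sits inside the square root, eps outside it.
    update = m_hat / (math.sqrt(v_hat + eps_root) + eps)
    return update, m, v


u, m, v = adam_step(g=0.5, m=0.0, v=0.0, t=1)
print(round(u, 6))  # 1.0: the first bias-corrected step is about sign(g)
```

A nonzero `eps_root` keeps the square root away from zero, which matters when gradients are differentiated through the optimizer update.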
- class eformer.optimizers._config.KronConfig(beta1: float = 0.9, weight_decay: float = 0.1, max_grad_norm: float | None = 1.0, normalize_grads: bool = False, preconditioner_update_probability: float = 0.05, update_prob_flat_start: int = 500, max_size_triangular: int = 25000, min_ndim_triangular: int = 2, memory_save_mode: str | None = None, preconditioner_lr: float = 0.1, preconditioner_init_scale: float = 1.0, mu_dtype: numpy.dtype | None = None, precond_dtype: numpy.dtype | None = None, precond_update_precision: str | None = 'tensorfloat32', precond_grads_precision: str | None = None, lax_map_scanned_layers: bool = False, lax_map_batch_size: int = 8, merge_small_dims: bool = True, target_merged_dim_size: int = 8192, partition_grads_into_blocks: bool = True, block_size: int = 256)
Bases: SerializationMixin
Configuration class for the Kron (PSGD Kron) optimizer.
Kron preconditions gradients with Kronecker-factored preconditioners, one factor per tensor dimension, approximating second-order optimization at a cost that remains tractable for neural network training.
- beta1
Momentum parameter. Common values are 0.9 or 0.95. Defaults to 0.9.
- Type
float
- weight_decay
Weight decay coefficient. Defaults to 0.1.
- Type
float
- max_grad_norm
Optional gradient norm clipping value. Defaults to 1.0.
- Type
float | None
- normalize_grads
Whether to normalize incoming gradients to unit norm layer-wise, which can help with stability. Defaults to False.
- Type
bool
- preconditioner_update_probability
Final probability of updating the preconditioner at a given step, reached after annealing. Defaults to 0.05 (one update every 20 steps on average).
- Type
float
- update_prob_flat_start
Number of initial steps during which the update probability is held at 1.0 before it anneals toward preconditioner_update_probability. Defaults to 500.
- Type
int
- max_size_triangular
Maximum size of a dimension for its preconditioner factor to be triangular (dense); larger dimensions get a diagonal factor. Defaults to 25000.
- Type
int
- min_ndim_triangular
Minimum number of dimensions a parameter must have to receive triangular preconditioner factors. Defaults to 2.
- Type
int
- memory_save_mode
Memory-saving mode. None keeps triangular factors wherever allowed, "one_diag" makes one dimension per layer (the largest) use a diagonal factor, and "all_diag" makes all factors diagonal. Defaults to None.
- Type
str | None
- preconditioner_lr
Learning rate for preconditioner updates. Defaults to 0.1.
- Type
float
- preconditioner_init_scale
Initial scale for the preconditioner. Defaults to 1.0.
- Type
float
- mu_dtype
Data type of the momentum buffer. Defaults to None.
- Type
jnp.dtype | None
- precond_dtype
Data type of the preconditioner factors. Defaults to None.
- Type
jnp.dtype | None
- precond_update_precision
Matmul precision used when updating the preconditioner (for example "bfloat16", "tensorfloat32", or "float32"). Defaults to "tensorfloat32".
- Type
str | None
- precond_grads_precision
Matmul precision used when applying the preconditioner to gradients. Defaults to None.
- Type
str | None
- lax_map_scanned_layers
Whether to use lax.map instead of vmap over scanned (stacked) layers, which can reduce memory use for large models. Defaults to False.
- Type
bool
- lax_map_batch_size
Batch size for lax.map. Defaults to 8.
- Type
int
- merge_small_dims
Whether to merge small dimensions to make preconditioning more efficient. Defaults to True.
- Type
bool
- target_merged_dim_size
Target size for merged dimensions. Defaults to 8192.
- Type
int
- partition_grads_into_blocks
Whether to partition gradients into blocks of size block_size. Defaults to True.
- Type
bool
- block_size
Block size for gradient partitioning. Defaults to 256.
- Type
int
- beta1: float = 0.9
- block_size: int = 256
- lax_map_batch_size: int = 8
- lax_map_scanned_layers: bool = False
- max_grad_norm: float | None = 1.0
- max_size_triangular: int = 25000
- memory_save_mode: str | None = None
- merge_small_dims: bool = True
- min_ndim_triangular: int = 2
- mu_dtype: numpy.dtype | None = None
- normalize_grads: bool = False
- partition_grads_into_blocks: bool = True
- precond_dtype: numpy.dtype | None = None
- precond_grads_precision: str | None = None
- precond_update_precision: str | None = 'tensorfloat32'
- preconditioner_init_scale: float = 1.0
- preconditioner_lr: float = 0.1
- preconditioner_update_probability: float = 0.05
- target_merged_dim_size: int = 8192
- update_prob_flat_start: int = 500
- weight_decay: float = 0.1
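The annealed update probability described by `preconditioner_update_probability` and `update_prob_flat_start` can be sketched in plain Python. The exponential decay form and its rate `decay=0.001` are assumptions modeled on common PSGD Kron implementations rather than anything stated in this config, and both function names are hypothetical.

```python
import math
import random


def update_probability(step, flat_start=500, min_prob=0.05, max_prob=1.0, decay=0.001):
    """Preconditioner-update probability at `step` (assumed exponential annealing)."""
    # Before flat_start the exponent is positive, so clipping holds the value at max_prob.
    prob = max_prob * math.exp(-decay * (step - flat_start))
    return min(max(prob, min_prob), max_prob)


def should_update(step, rng):
    """Randomly decide whether to update the preconditioner on this step."""
    return rng.random() < update_probability(step)


for s in (0, 500, 1500, 5000):
    print(s, round(update_probability(s), 3))  # 1.0, 1.0, 0.368, 0.05
```

Updating the preconditioner stochastically amortizes its cost: once annealed, the default final probability of 0.05 spends preconditioner work on roughly one step in twenty.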
- class eformer.optimizers._config.LionConfig(b1: float = 0.9, b2: float = 0.99, mu_dtype: numpy.dtype | None = None)
Bases: SerializationMixin
Configuration class for the Lion optimizer.
- b1
Interpolation rate between the momentum and the current gradient when computing the sign-based update direction. Defaults to 0.9.
- Type
float
- b2
Exponential decay rate of the momentum (an exponential moving average of past gradients). Lion keeps no second-moment estimate. Defaults to 0.99.
- Type
float
- mu_dtype
Data type of the momentum buffer. Defaults to None.
- Type
Optional[jnp.dtype]
- b1: float = 0.9
- b2: float = 0.99
- mu_dtype: numpy.dtype | None = None
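A single Lion step, showing the different roles of `b1` and `b2`, can be sketched in NumPy. The learning rate and weight decay are omitted, and `lion_step` is a hypothetical helper, not eformer's code.

```python
import numpy as np


def lion_step(grad, mu, b1=0.9, b2=0.99):
    """Return (update direction, new momentum) for one Lion step."""
    # b1 blends momentum and the current gradient; only the sign is kept.
    update = np.sign(b1 * mu + (1 - b1) * grad)
    # b2 controls the moving average carried to the next step.
    new_mu = b2 * mu + (1 - b2) * grad
    return update, new_mu


update, mu = lion_step(np.array([0.3, -2.0, 0.0]), np.zeros(3))
print(update)  # [ 1. -1.  0.]
```

Every nonzero coordinate of the update has magnitude 1, so the step size is set entirely by the learning rate, independent of the gradient's scale.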
- class eformer.optimizers._config.MarsConfig(weight_decay: float = 0.1, beta1: float = 0.95, beta2: float = 0.99, gamma: float = 0.025, epsilon: float = 1e-08, max_grad_norm: float | None = 1.0)
Bases: SerializationMixin
Configuration class for the MARS optimizer.
MARS (Make vAriance Reduction Shine) combines a scaled, variance-reduced gradient estimate with AdamW-style preconditioned updates.
Reference: https://arxiv.org/abs/2411.10438
- weight_decay
Weight decay coefficient. Defaults to 0.1.
- Type
float
- beta1
Exponential decay rate for the first-moment estimates. Defaults to 0.95.
- Type
float
- beta2
Exponential decay rate for the second-moment estimates. Defaults to 0.99.
- Type
float
- gamma
Scaling factor for the variance-reduction term, which corrects the current gradient using its difference from the previous step's gradient. Defaults to 0.025.
- Type
float
- epsilon
Small constant for numerical stability. Defaults to 1e-8.
- Type
float
- max_grad_norm
Maximum gradient norm for clipping. Defaults to 1.0.
- Type
float | None
- beta1: float = 0.95
- beta2: float = 0.99
- epsilon: float = 1e-08
- gamma: float = 0.025
- max_grad_norm: float | None = 1.0
- weight_decay: float = 0.1
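The role of `gamma` can be sketched as the gradient correction from the referenced paper, c_t = g_t + gamma * beta1 / (1 - beta1) * (g_t - g_{t-1}), whose result then feeds AdamW-style moment estimates. The paper also rescales c_t to unit norm when that norm exceeds 1, shown in the second helper. This config's documentation does not say whether `max_grad_norm` plays that role. Both helper names are hypothetical.

```python
import numpy as np


def mars_corrected_grad(grad, prev_grad, beta1=0.95, gamma=0.025):
    """Variance-reduced gradient: g_t + gamma * beta1 / (1 - beta1) * (g_t - g_{t-1})."""
    return grad + gamma * (beta1 / (1.0 - beta1)) * (grad - prev_grad)


def clip_to_unit_norm(c):
    """Rescale c to unit L2 norm only when its norm exceeds 1."""
    norm = np.linalg.norm(c)
    return c / norm if norm > 1.0 else c


g = np.array([1.0, 0.0])
c = mars_corrected_grad(g, prev_grad=np.zeros(2))
print(c)  # about [1.475, 0.]: 1 + 0.025 * 19
```

With the defaults, the correction weight is gamma * beta1 / (1 - beta1) = 0.475, so a change in the gradient between steps is amplified before it enters the moment estimates.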
- class eformer.optimizers._config.MuonConfig(ns_coeffs: tuple[float, float, float] = (3.4445, -4.775, 2.0315), ns_steps: int = 5, beta: float = 0.95, eps: float = 1e-08, weight_decay: float = 0.0, weight_decay_mask: Any | None = None, mu_dtype: numpy.dtype | None = None, nesterov: bool = True, adaptive: bool = False, adam_b1: float = 0.9, adam_b2: float = 0.999, adam_eps_root: float = 0.0, adam_weight_decay: float = 0.0)
Bases: SerializationMixin
Configuration class for the Muon (MomentUm Orthogonalized by Newton-Schulz) optimizer.
Muon is designed for 2D parameters (matrices) and uses the Newton-Schulz iteration to orthogonalize the momentum. Parameters that are not 2D are updated with Adam instead, configured by the adam_* fields.
- ns_coeffs
Coefficients (a, b, c) of the quintic Newton-Schulz iteration. Defaults to (3.4445, -4.775, 2.0315).
- Type
tuple[float, float, float]
- ns_steps
Number of Newton-Schulz iterations. Defaults to 5.
- Type
int
- beta
Decay rate for the exponentially weighted average of gradients (momentum). Defaults to 0.95.
- Type
float
- eps
Term added to the denominator to improve numerical stability. Defaults to 1e-8.
- Type
float
- weight_decay
Strength of the weight decay regularization. Defaults to 0.0.
- Type
float
- weight_decay_mask
Optional mask selecting which parameters receive weight decay. Defaults to None.
- Type
Any | None
- mu_dtype
Data type for momentum computation. Defaults to None.
- Type
jnp.dtype | None
- nesterov
Whether to use Nesterov momentum. Defaults to True.
- Type
bool
- adaptive
Whether to scale the updates by the dual norm of the original updates. Defaults to False.
- Type
bool
- adam_b1
Exponential decay rate for the first-moment estimates in Adam (for non-2D parameters). Defaults to 0.9.
- Type
float
- adam_b2
Exponential decay rate for the second-moment estimates in Adam (for non-2D parameters). Defaults to 0.999.
- Type
float
- adam_eps_root
Small constant added inside the square root of Adam's denominator (for non-2D parameters). Defaults to 0.0.
- Type
float
- adam_weight_decay
Weight decay for the Adam optimizer (for non-2D parameters). Defaults to 0.0.
- Type
float
- adam_b1: float = 0.9
- adam_b2: float = 0.999
- adam_eps_root: float = 0.0
- adam_weight_decay: float = 0.0
- adaptive: bool = False
- beta: float = 0.95
- eps: float = 1e-08
- mu_dtype: numpy.dtype | None = None
- nesterov: bool = True
- ns_coeffs: tuple[float, float, float] = (3.4445, -4.775, 2.0315)
- ns_steps: int = 5
- weight_decay: float = 0.0
- weight_decay_mask: Any | None = None
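The orthogonalization step that `ns_coeffs` and `ns_steps` control can be sketched in NumPy with the widely used quintic Newton-Schulz iteration. This is illustrative only: momentum, Nesterov, weight decay, and any shape-dependent rescaling of the update are omitted, and `newton_schulz` is a hypothetical helper rather than eformer's function.

```python
import numpy as np


def newton_schulz(G, coeffs=(3.4445, -4.775, 2.0315), steps=5, eps=1e-8):
    """Approximately orthogonalize a 2D matrix by pushing its singular values toward 1."""
    a, b, c = coeffs
    X = G / (np.linalg.norm(G) + eps)  # Frobenius normalization: singular values <= 1
    transposed = X.shape[0] > X.shape[1]
    if transposed:  # work with the wide orientation so X @ X.T is the smaller Gram matrix
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X  # X <- a X + b (X X^T) X + c (X X^T)^2 X
    return X.T if transposed else X


S = np.linalg.svd(newton_schulz(np.diag([1.0, 0.5, 0.25])), compute_uv=False)
print(S)  # each singular value lands roughly between 0.7 and 1.2
```

These coefficients do not make the iteration converge to exact orthogonality. They trade precision for speed: a few matmul-only steps leave all singular values in a band around 1, which is sufficient for the update direction.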
- class eformer.optimizers._config.RMSPropConfig(decay: float = 0.9, initial_scale: float = 0.0, momentum: float | None = None, nesterov: bool = False, eps: float = 1e-08)
Bases: SerializationMixin
Configuration class for the RMSProp optimizer.
- decay
Decay rate for the moving average of squared gradients. Defaults to 0.9.
- Type
float
- initial_scale
Initial value of the squared-gradient moving average. Defaults to 0.0.
- Type
float
- momentum
Momentum factor; None disables momentum. Defaults to None.
- Type
Optional[float]
- nesterov
Whether to use Nesterov momentum (only relevant when momentum is set). Defaults to False.
- Type
bool
- eps
Small constant for numerical stability. Defaults to 1e-8.
- Type
float
- decay: float = 0.9
- eps: float = 1e-08
- initial_scale: float = 0.0
- momentum: float | None = None
- nesterov: bool = False
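One RMSProp step for a scalar parameter can be sketched in plain Python. This shows how `initial_scale` seeds the moving average and how `momentum` and `nesterov` layer a momentum buffer on top. The sketch places `eps` inside the square root, but implementations differ on this detail, so treat it as an assumption. `rmsprop_step` is a hypothetical helper, and the learning rate is omitted.

```python
import math


def rmsprop_step(g, nu, trace, decay=0.9, eps=1e-8, momentum=None, nesterov=False):
    """One RMSProp step for a scalar parameter, before the learning rate is applied.

    Start `nu` at `initial_scale` and `trace` at 0.0; returns (update, nu, trace).
    """
    nu = decay * nu + (1 - decay) * g * g   # moving average of squared gradients
    update = g / math.sqrt(nu + eps)         # eps inside the sqrt (an assumption)
    if momentum is not None:
        trace = update + momentum * trace    # momentum buffer
        update = update + momentum * trace if nesterov else trace
    return update, nu, trace


# initial_scale=0.0 makes the first step about 1/sqrt(1 - decay) times larger than initial_scale=1.0.
print(rmsprop_step(1.0, nu=0.0, trace=0.0)[0], rmsprop_step(1.0, nu=1.0, trace=0.0)[0])
```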
- class eformer.optimizers._config.SchedulerConfig(scheduler_type: Optional[Literal['linear', 'cosine']] = None, learning_rate: float = 5e-05, learning_rate_end: float | None = None, warmup_steps: int | None = None, steps: int | None = None, exponent: float = 1.0)
Bases: SerializationMixin
Configuration class for learning rate schedulers.
- scheduler_type
Type of scheduler to use ("linear" or "cosine"), or None. Defaults to None.
- Type
Optional[Literal["linear", "cosine"]]
- learning_rate
Initial learning rate. Defaults to 5e-5.
- Type
float
- learning_rate_end
Final learning rate for the linear scheduler. Defaults to None.
- Type
Optional[float]
- warmup_steps
Number of warmup steps. Defaults to None.
- Type
Optional[int]
- steps
Total number of training steps. Required for non-constant schedulers. Defaults to None.
- Type
Optional[int]
- exponent
Exponent for polynomial decay. Defaults to 1.0.
- Type
float
- exponent: float = 1.0
- learning_rate: float = 5e-05
- learning_rate_end: float | None = None
- scheduler_type: Optional[Literal['linear', 'cosine']] = None
- steps: int | None = None
- warmup_steps: int | None = None
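One plausible reading of these fields is sketched in plain Python: linear warmup from 0 to `learning_rate` over `warmup_steps`, then either polynomial decay ("linear", shaped by `exponent`) or cosine decay to `learning_rate_end` at `steps`. The warmup shape, both decay formulas, and treating a missing `learning_rate_end` as 0.0 are assumptions, not documented eformer behavior. `learning_rate_at` is a hypothetical helper.

```python
import math


def learning_rate_at(step, scheduler_type=None, learning_rate=5e-5,
                     learning_rate_end=None, warmup_steps=None, steps=None, exponent=1.0):
    """Learning rate at `step` under assumed warmup-then-decay semantics."""
    if scheduler_type is None:
        return learning_rate  # constant schedule
    if steps is None:
        raise ValueError("steps is required for non-constant schedulers")
    end = 0.0 if learning_rate_end is None else learning_rate_end  # assumed default
    warmup = warmup_steps or 0
    if step < warmup:
        return learning_rate * step / warmup  # linear warmup from 0
    decay_steps = max(steps - warmup, 1)
    frac = min(step - warmup, decay_steps) / decay_steps  # progress through decay, 0..1
    if scheduler_type == "linear":
        # Polynomial decay; exponent=1.0 gives a straight line.
        return (learning_rate - end) * (1 - frac) ** exponent + end
    if scheduler_type == "cosine":
        return end + 0.5 * (learning_rate - end) * (1 + math.cos(math.pi * frac))
    raise ValueError(f"unknown scheduler_type: {scheduler_type!r}")


for s in (0, 10, 60, 110):
    print(s, learning_rate_at(s, "linear", 1e-3, 0.0, warmup_steps=10, steps=110))
```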
- class eformer.optimizers._config.ScionConfig(momentum: float = 0.95, backend_steps: int = 10, beta1: float = 0.9, epsilon: float = 1e-08, unconstrained: bool = False, spectral_radius: float = 50, sign_radius: float = 3000)
Bases: SerializationMixin
Configuration class for the Scion optimizer.
Scion updates parameters with norm-constrained steps, using a spectral (orthogonalized) update for some parameter types and a sign-based update for others.
Reference: https://arxiv.org/abs/2502.07529
- momentum
Momentum parameter for both the spectral and sign methods. Defaults to 0.95.
- Type
float
- backend_steps
Number of steps for Newton-Schulz orthogonalization. Defaults to 10.
- Type
int
- beta1
Beta1 parameter for the sign method. Defaults to 0.9.
- Type
float
- epsilon
Small constant for numerical stability. Defaults to 1e-8.
- Type
float
- unconstrained
Whether to use the unconstrained version. Defaults to False.
- Type
bool
- spectral_radius
Scaling factor for the spectral method. Defaults to 50.
- Type
float
- sign_radius
Scaling factor for the sign method. Defaults to 3000.
- Type
float
- backend_steps: int = 10
- beta1: float = 0.9
- epsilon: float = 1e-08
- momentum: float = 0.95
- sign_radius: float = 3000
- spectral_radius: float = 50
- unconstrained: bool = False
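Following the referenced paper, the two update directions can be sketched as linear minimization oracles (LMOs) over norm balls. The spectral LMO returns -radius * U V^T, and the sign LMO returns -radius * sign(d). This sketch uses an exact SVD where the optimizer uses `backend_steps` Newton-Schulz iterations, and it omits momentum and any per-layer dimension scaling of the radius. It also assumes `unconstrained` selects between a plain step and a convex combination with the current weights. All helper names are hypothetical.

```python
import numpy as np


def spectral_lmo(d, radius=50.0):
    """LMO over a spectral-norm ball of the given radius: -radius * U @ V^T."""
    u, _, vt = np.linalg.svd(d, full_matrices=False)  # exact SVD stands in for Newton-Schulz
    return -radius * (u @ vt)


def sign_lmo(d, radius=3000.0):
    """LMO over an l-infinity ball of the given radius: -radius * sign(d)."""
    return -radius * np.sign(d)


def apply_lmo(x, direction, lr, unconstrained=False):
    """Take one step using an LMO output `direction`."""
    if unconstrained:
        return x + lr * direction  # plain step along the LMO output
    return (1 - lr) * x + lr * direction  # convex combination keeps x inside the ball


d = np.random.default_rng(0).normal(size=(4, 3))
print(np.linalg.svd(spectral_lmo(d), compute_uv=False))  # every singular value equals the radius
```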
- class eformer.optimizers._config.SerializationMixin
Bases: object
Mixin class providing serialization capabilities for configuration classes.
This class provides methods to convert instances to and from dictionaries and JSON strings, making it easy to serialize and deserialize configuration objects.
- classmethod from_dict(data: dict[str, Any]) -> T
Create an instance from a dictionary with error checking.
- Parameters
data (dict) – A dictionary containing the data to populate the instance.
- Returns
An instance of the class populated with the provided data.
- Return type
T
- Raises
Warning – If unexpected keys are present in the input dictionary.
- classmethod from_json(json_str: str) -> T
Create an instance from a JSON string.
- Parameters
json_str (str) – A JSON string containing the data to populate the instance.
- Returns
An instance of the class populated with the data from the JSON string.
- Return type
T
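A minimal sketch of how such a mixin can work with dataclass-based configs is shown below. `SketchSerializationMixin`, `ExampleLionConfig`, and the `to_dict`/`to_json` method names are hypothetical; only `from_dict` and `from_json` are documented. The sketch warns about unexpected keys and drops them, which is one reading of the `Warning` listed above.

```python
from __future__ import annotations

import dataclasses
import json
import warnings
from typing import Any, TypeVar

T = TypeVar("T", bound="SketchSerializationMixin")


class SketchSerializationMixin:
    """Hypothetical stand-in for SerializationMixin, for dataclass configs."""

    def to_dict(self) -> dict[str, Any]:
        return dataclasses.asdict(self)

    def to_json(self) -> str:
        return json.dumps(self.to_dict())

    @classmethod
    def from_dict(cls: type[T], data: dict[str, Any]) -> T:
        known = {f.name for f in dataclasses.fields(cls)}
        unexpected = set(data) - known
        if unexpected:
            # Warn and ignore keys that are not fields of the config.
            warnings.warn(f"Ignoring unexpected keys: {sorted(unexpected)}")
        return cls(**{k: v for k, v in data.items() if k in known})

    @classmethod
    def from_json(cls: type[T], json_str: str) -> T:
        return cls.from_dict(json.loads(json_str))


@dataclasses.dataclass
class ExampleLionConfig(SketchSerializationMixin):  # hypothetical mirror of two LionConfig fields
    b1: float = 0.9
    b2: float = 0.99


cfg = ExampleLionConfig.from_json('{"b1": 0.95}')
print(cfg)  # ExampleLionConfig(b1=0.95, b2=0.99)
```

With the documented classes, the equivalent call is, for example, `LionConfig.from_json('{"b1": 0.95}')`.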
- class eformer.optimizers._config.SoapConfig(weight_decay: float = 0.0, beta1: float = 0.95, beta2: float = 0.95, shampoo_beta: float = 0.95, epsilon: float = 1e-08, max_grad_norm: float | None = 1.0, haps: list[int] | None = None, schedule_list: list[str] | None = None, precondition_frequency: int = 10, max_precond_dim: int = 10000, merge_small_dims: bool = True, one_diag: bool = False, target_merged_dim_size: int = 2048, mu_dtype: numpy.dtype | None = None, precond_dtype: numpy.dtype | None = None, partition_grads_into_blocks: bool = True, block_size: int = 256)
Bases: SerializationMixin
Configuration class for the SOAP (ShampoO with Adam in the Preconditioner's eigenbasis) optimizer.
SOAP runs an Adam-style update in the eigenbasis of Shampoo's Kronecker-factored preconditioner and refreshes that eigenbasis only periodically (see precondition_frequency).
- weight_decay
Weight decay coefficient. Defaults to 0.0.
- Type
float
- beta1
Exponential decay rate for the first-moment (momentum) estimates. Defaults to 0.95.
- Type
float
- beta2
Exponential decay rate for the second-moment estimates. Defaults to 0.95.
- Type
float
- shampoo_beta
Exponential decay rate for the Shampoo preconditioner statistics. Defaults to 0.95.
- Type
float
- epsilon
Small constant for numerical stability. Defaults to 1e-8.
- Type
float
- max_grad_norm
Maximum gradient norm for clipping. Defaults to 1.0.
- Type
float | None
- haps
HAP schedule parameters. Defaults to None.
- Type
list[int] | None
- schedule_list
Schedule list. Defaults to None.
- Type
list[str] | None
- precondition_frequency
Number of steps between preconditioner (eigenbasis) updates. Defaults to 10.
- Type
int
- max_precond_dim
Maximum dimension for preconditioning. Defaults to 10000.
- Type
int
- merge_small_dims
Whether to merge small dimensions. Defaults to True.
- Type
bool
- one_diag
Whether to use diagonal preconditioning only. Defaults to False.
- Type
bool
- target_merged_dim_size
Target size for merged dimensions. Defaults to 2048.
- Type
int
- mu_dtype
Data type of the momentum buffer. Defaults to None.
- Type
jnp.dtype | None
- precond_dtype
Data type of the preconditioners. Defaults to None.
- Type
jnp.dtype | None
- partition_grads_into_blocks
Whether to partition gradients into blocks of size block_size. Defaults to True.
- Type
bool
- block_size
Block size for gradient partitioning. Defaults to 256.
- Type
int
- beta1: float = 0.95
- beta2: float = 0.95
- block_size: int = 256
- epsilon: float = 1e-08
- haps: list[int] | None = None
- max_grad_norm: float | None = 1.0
- max_precond_dim: int = 10000
- merge_small_dims: bool = True
- mu_dtype: numpy.dtype | None = None
- one_diag: bool = False
- partition_grads_into_blocks: bool = True
- precond_dtype: numpy.dtype | None = None
- precondition_frequency: int = 10
- schedule_list: list[str] | None = None
- shampoo_beta: float = 0.95
- target_merged_dim_size: int = 2048
- weight_decay: float = 0.0
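The core idea, Adam applied in the eigenbasis of Shampoo's left and right statistics, can be sketched for a single 2D gradient in NumPy. This is a simplified illustration under stated assumptions:

- the eigenbasis is recomputed on every call instead of every `precondition_frequency` steps;
- the moments stay in the rotated basis without re-projection when the basis changes;
- bias correction, blocking, dimension merging, clipping, and weight decay are omitted.

`soap_like_step` is a hypothetical name.

```python
import numpy as np


def soap_like_step(G, L, R, m, v, beta1=0.95, beta2=0.95, shampoo_beta=0.95, eps=1e-8):
    """One simplified SOAP-style step for a 2D gradient G.

    L (rows x rows) and R (cols x cols) are Shampoo statistics; m and v are Adam
    moments kept in the rotated basis. Returns (update, L, R, m, v).
    """
    # Shampoo statistics: moving averages of G G^T (left) and G^T G (right).
    L = shampoo_beta * L + (1 - shampoo_beta) * G @ G.T
    R = shampoo_beta * R + (1 - shampoo_beta) * G.T @ G
    _, QL = np.linalg.eigh(L)  # eigenvectors of the symmetric statistics
    _, QR = np.linalg.eigh(R)
    G_rot = QL.T @ G @ QR  # gradient expressed in the eigenbasis
    m = beta1 * m + (1 - beta1) * G_rot  # Adam moments in that basis
    v = beta2 * v + (1 - beta2) * G_rot**2
    update = QL @ (m / (np.sqrt(v) + eps)) @ QR.T  # rotate back to the parameter basis
    return update, L, R, m, v


rng = np.random.default_rng(0)
G = rng.normal(size=(4, 3))
update, *_ = soap_like_step(G, np.zeros((4, 4)), np.zeros((3, 3)), np.zeros((4, 3)), np.zeros((4, 3)))
print(update.shape)  # (4, 3)
```

Refreshing QL and QR only periodically, as `precondition_frequency` does, avoids running an eigendecomposition on every step, since that is the expensive part.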
- class eformer.optimizers._config.WhiteKronConfig(lr_style: str | None = 'adam', b1: float = 0.95, normalize_grads: bool = False, max_size_dense: int = 16384, preconditioner_lr: float = 0.7, preconditioner_init_scale: float = 1.0, dtype: str | numpy.dtype = jax.numpy.bfloat16, scanned_layers: Any | None = None, block_size: int = 256, pipeline_axis_name: str | None = None, pipeline_axis_size: int = 1, params_partition_specs: Any | None = None, noise_scale: float = 1e-09, weight_decay: float = 0.1, weight_decay_mask: Any | None = None)
Bases: SerializationMixin
Configuration class for the White Kron optimizer.
White Kron is a Kronecker-factored preconditioned optimizer that uses different update styles (skew or quad) for efficient second-order optimization.
- lr_style
Learning rate style. Defaults to "adam".
- Type
str | None
- b1
Exponential decay rate for the first-moment estimates. Defaults to 0.95.
- Type
float
- normalize_grads
Whether to normalize gradients. Defaults to False.
- Type
bool
- max_size_dense
Maximum size for dense preconditioning. Defaults to 16384.
- Type
int
- preconditioner_lr
Learning rate for preconditioner updates. Defaults to 0.7.
- Type
float
- preconditioner_init_scale
Initial scale for the preconditioner. Defaults to 1.0.
- Type
float
- dtype
Data type for computations. Defaults to jnp.bfloat16.
- Type
str | jnp.dtype
- scanned_layers
Scanned layers configuration. Defaults to None.
- Type
Any | None
- block_size
Block size for matrix operations. Defaults to 256.
- Type
int
- pipeline_axis_name
Name of the pipeline axis for sharding. Defaults to None.
- Type
str | None
- pipeline_axis_size
Size of the pipeline axis. Defaults to 1.
- Type
int
- params_partition_specs
Parameter partition specifications. Defaults to None.
- Type
Any | None
- noise_scale
Scale of the noise added for numerical stability. Defaults to 1e-9.
- Type
float
- weight_decay
Weight decay coefficient. Defaults to 0.1.
- Type
float
- weight_decay_mask
Weight decay mask. Defaults to None.
- Type
Any | None
- b1: float = 0.95
- block_size: int = 256
- dtype: str | numpy.dtype = jax.numpy.bfloat16
- lr_style: str | None = 'adam'
- max_size_dense: int = 16384
- noise_scale: float = 1e-09
- normalize_grads: bool = False
- params_partition_specs: Any | None = None
- pipeline_axis_name: str | None = None
- pipeline_axis_size: int = 1
- preconditioner_init_scale: float = 1.0
- preconditioner_lr: float = 0.7
- scanned_layers: Any | None = None
- weight_decay: float = 0.1
- weight_decay_mask: Any | None = None