eformer.optimizers._factory
- class eformer.optimizers._factory.OptimizerFactory
Bases: object
Factory class for creating optimizers with validated configurations.
This class provides methods to create optimizers based on a configuration object. All optimizers are registered using the @register_optimizer decorator pattern.
- Private Methods:
_get_config_class: Gets the configuration class for an optimizer type.
_convert_dtypes: Converts string dtype representations to JAX dtypes.
_validate_kwargs: Validates additional parameters for the optimizer.
_build_optimizer_chain: Constructs the final optimizer chain.
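The decorator-based registration mentioned above can be pictured with a small, library-free sketch. Everything here is hypothetical and for illustration only: `register`, `_REGISTRY`, `ToySGDConfig`, and `create` are assumed names, not eformer's actual `@register_optimizer` implementation, whose internals this page does not show.

```python
from dataclasses import dataclass
from typing import Callable

# Hypothetical registry: maps an optimizer name to (config class, builder).
_REGISTRY: dict[str, tuple[type, Callable]] = {}


def register(name: str, config_cls: type):
    """Decorator that records a builder function under `name`."""
    def decorator(builder: Callable) -> Callable:
        if name in _REGISTRY:
            raise ValueError(f"Optimizer {name!r} is already registered")
        _REGISTRY[name] = (config_cls, builder)
        return builder
    return decorator


@dataclass
class ToySGDConfig:  # illustrative config, not an eformer class
    learning_rate: float = 0.1


@register("toy_sgd", ToySGDConfig)
def build_toy_sgd(config: ToySGDConfig) -> Callable[[float, float], float]:
    # Returns a plain update rule: param <- param - lr * grad.
    return lambda param, grad: param - config.learning_rate * grad


def create(name: str, config=None):
    """Look up `name`, check the config type, and build the optimizer."""
    if name not in _REGISTRY:
        raise ValueError(f"Unsupported optimizer type: {name!r}")
    config_cls, builder = _REGISTRY[name]
    config = config_cls() if config is None else config
    if not isinstance(config, config_cls):
        raise TypeError(
            f"Expected {config_cls.__name__}, got {type(config).__name__}"
        )
    return builder(config)
```

In this sketch an unknown name raises ValueError and a mismatched config raises TypeError, mirroring the error split documented for `create`.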
- classmethod create(optimizer_type: str, scheduler_config: eformer.optimizers._config.SchedulerConfig | None = None, optimizer_config: eformer.optimizers._config.AdafactorConfig | eformer.optimizers._config.AdamWConfig | eformer.optimizers._config.KronConfig | eformer.optimizers._config.LionConfig | eformer.optimizers._config.MarsConfig | eformer.optimizers._config.MuonConfig | eformer.optimizers._config.RMSPropConfig | eformer.optimizers._config.SoapConfig | eformer.optimizers._config.WhiteKronConfig | None = None, *, weight_decay: float = 0.0, weight_decay_mask: Any | None = None, gradient_accumulation_steps: int = 1, clip_grad: float | None = None, custom_scheduler: Optional[Callable[[int], Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]]]] = None, **kwargs) → tuple[optax._src.base.GradientTransformation, collections.abc.Callable[[Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, float, int]], Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, float, int]]]
Create an optimizer with validated configuration.
- Parameters
optimizer_type (str) – One of the registered optimizer types.
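scheduler_config (Optional[SchedulerConfig]) – Configured scheduler parameters. Defaults to None.
optimizer_config (Optional[Union[AdafactorConfig, AdamWConfig, KronConfig, LionConfig, MarsConfig, MuonConfig, RMSPropConfig, SoapConfig, WhiteKronConfig]]) – Optimizer-specific configuration. Defaults to None.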
weight_decay (float) – Global weight decay rate. Defaults to 0.0.
weight_decay_mask (Optional[Any]) – Mask for weight decay application. Defaults to None.
gradient_accumulation_steps (int) – Steps for gradient accumulation. Defaults to 1.
clip_grad (Optional[float]) – Maximum global gradient norm used for clipping. Defaults to None.
custom_scheduler (Optional[Callable[[int], optax.Schedule]]) – Optional custom scheduler function. Defaults to None.
**kwargs – Additional optimizer-specific parameters.
- Returns
A tuple containing the optimizer and scheduler.
- Return type
Tuple[optax.GradientTransformation, optax.Schedule]
- Raises
ValueError – If the optimizer type is unsupported or the configuration is invalid.
TypeError – If the configuration type is invalid.
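To make the `clip_grad` and `gradient_accumulation_steps` parameters concrete, here is a plain-Python sketch of what global-norm clipping and gradient accumulation compute in general. It is not eformer's code: `clip_by_global_norm` and `Accumulator` are hypothetical helpers operating on flat lists of floats, and how the factory actually wires these steps into its optimizer chain is not shown on this page.

```python
import math


def clip_by_global_norm(grads: list[float], max_norm: float) -> list[float]:
    """Rescale grads so their joint L2 norm does not exceed max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    scale = max_norm / norm
    return [g * scale for g in grads]


class Accumulator:
    """Averages gradients over `steps` micro-batches before releasing them."""

    def __init__(self, steps: int):
        if steps < 1:
            raise ValueError("steps must be >= 1")
        self.steps = steps
        self._sum = None
        self._count = 0

    def add(self, grads: list[float]):
        """Return the averaged gradients on every `steps`-th call, else None."""
        if self._sum is None:
            self._sum = [0.0] * len(grads)
        self._sum = [s + g for s, g in zip(self._sum, grads)]
        self._count += 1
        if self._count < self.steps:
            return None
        averaged = [s / self.steps for s in self._sum]
        # Reset so the next call starts a fresh accumulation window.
        self._sum, self._count = None, 0
        return averaged
```

With `steps=1` the accumulator releases every gradient immediately, which matches the documented default of `gradient_accumulation_steps=1` meaning no accumulation.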
- classmethod deserialize_config(optimizer_type: str, data: dict | str, format: str = 'dict') → SerializationMixin
Deserialize configuration from different formats.
- Parameters
optimizer_type (str) – Name of the optimizer.
data (Union[Dict, str]) – Serialized configuration data.
format (str) – Serialization format. Supported formats: 'dict', 'json'. Defaults to 'dict'.
- Returns
Deserialized configuration object.
- Return type
SerializationMixin
- Raises
ValueError – If the optimizer type is unknown or the format is unsupported.
TypeError – If the input data type is invalid.
- classmethod generate_template(optimizer_type: str) → str
Generate a configuration template for the specified optimizer.
- Parameters
optimizer_type (str) – Name of the optimizer.
- Returns
Configuration template.
- Return type
str
- Raises
ValueError – If the optimizer type is unknown.
- classmethod serialize_config(config: SerializationMixin, format: str = 'dict') → dict | str
Serialize configuration to different formats.
- Parameters
config (SerializationMixin) – Configuration object.
format (str) – Serialization format. Supported formats: 'dict', 'json'. Defaults to 'dict'.
- Returns
Serialized configuration.
- Return type
Union[Dict, str]
- Raises
ValueError – If the format is unsupported.
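The 'dict' and 'json' formats documented for `serialize_config` and `deserialize_config` suggest a round trip along the following lines. This is a minimal sketch using only the standard library: `ToyAdamWConfig`, `serialize`, and `deserialize` are hypothetical stand-ins, and the real `SerializationMixin` may behave differently (for example in how it handles dtypes or nested fields).

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class ToyAdamWConfig:  # hypothetical stand-in for a real config class
    learning_rate: float = 1e-3
    b1: float = 0.9
    b2: float = 0.999


def serialize(config, format: str = "dict"):
    """Turn a dataclass config into a dict or a JSON string."""
    data = asdict(config)
    if format == "dict":
        return data
    if format == "json":
        return json.dumps(data, sort_keys=True)
    raise ValueError(f"Unsupported format: {format!r}")


def deserialize(config_cls, data, format: str = "dict"):
    """Rebuild a config object from a dict or a JSON string."""
    if format == "json":
        if not isinstance(data, str):
            raise TypeError("JSON input must be a str")
        data = json.loads(data)
    elif format == "dict":
        if not isinstance(data, dict):
            raise TypeError("dict input must be a dict")
    else:
        raise ValueError(f"Unsupported format: {format!r}")
    return config_cls(**data)
```

The sketch keeps the same error split as the documented methods: an unknown format raises ValueError, while input of the wrong Python type raises TypeError.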
- class eformer.optimizers._factory.SchedulerFactory
Bases: object
Factory class for creating learning rate schedulers.
This class provides methods to create schedulers based on a configuration object (SchedulerConfig). It supports linear and cosine schedulers with optional warmup steps.
- static create_scheduler(config: SchedulerConfig, custom_scheduler: Optional[Callable[[int], Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]]]] = None) → Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]]
Create a scheduler based on the provided configuration.
- Parameters
config (SchedulerConfig) – Configuration object for the scheduler.
custom_scheduler (Optional[Callable[[int], optax.Schedule]]) – Custom scheduler function. Defaults to None.
- Returns
The created scheduler.
- Return type
optax.Schedule
- Raises
ValueError – If the configuration is invalid or an unsupported scheduler type is provided.
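As an illustration of what a linear or cosine schedule with optional warmup computes, the sketch below builds a step-to-learning-rate function in plain Python. The function `make_schedule` and its parameters (`kind`, `peak_lr`, `total_steps`, `warmup_steps`, `end_lr`) are assumptions made for this example; the fields of `SchedulerConfig` are not listed on this page, and the real factory returns an `optax.Schedule` rather than this closure.

```python
import math
from typing import Callable


def make_schedule(
    kind: str,
    peak_lr: float,
    total_steps: int,
    warmup_steps: int = 0,
    end_lr: float = 0.0,
) -> Callable[[int], float]:
    """Build a step -> learning-rate function with optional linear warmup."""
    if kind not in ("linear", "cosine"):
        raise ValueError(f"Unsupported scheduler type: {kind!r}")
    if not 0 <= warmup_steps < total_steps:
        raise ValueError("warmup_steps must satisfy 0 <= warmup_steps < total_steps")
    decay_steps = total_steps - warmup_steps

    def schedule(step: int) -> float:
        if step < warmup_steps:
            # Linear ramp from 0 up to peak_lr during warmup.
            return peak_lr * step / warmup_steps
        # Fraction of the decay phase completed, capped at 1.
        progress = min((step - warmup_steps) / decay_steps, 1.0)
        if kind == "linear":
            return peak_lr + (end_lr - peak_lr) * progress
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return end_lr + (peak_lr - end_lr) * cosine

    return schedule
```

Both variants reach `peak_lr` at the end of warmup and `end_lr` at `total_steps`; they differ only in the shape of the decay in between.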