eformer.optimizers._factory

class eformer.optimizers._factory.OptimizerFactory

Bases: object

Factory class for creating optimizers with validated configurations.

This class provides methods to create optimizers based on a configuration object. All optimizers are registered using the @register_optimizer decorator pattern.

create()

Creates an optimizer with validated configuration.

generate_template()

Generates a configuration template for the specified optimizer.

serialize_config()

Serializes configuration to different formats.

deserialize_config()

Deserializes configuration from different formats.

Private Methods:
  • _get_config_class: Gets the configuration class for an optimizer type.

  • _convert_dtypes: Converts string dtype representations to JAX dtypes.

  • _validate_kwargs: Validates additional parameters for the optimizer.

  • _build_optimizer_chain: Constructs the final optimizer chain.
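The decorator-based registration mentioned in the class description can be illustrated with a small, dependency-free sketch. Everything here is an assumption made for illustration: the registry dict, the toy config class, and the builder are hypothetical stand-ins, and the real register_optimizer may take different arguments and store different objects.

```python
from dataclasses import dataclass
from typing import Callable

# Hypothetical registry; the library's internal storage may differ.
_OPTIMIZER_REGISTRY: dict[str, Callable] = {}


def register_optimizer(name: str) -> Callable[[Callable], Callable]:
    """Return a decorator that stores a builder function under `name`."""

    def decorator(builder: Callable) -> Callable:
        if name in _OPTIMIZER_REGISTRY:
            raise ValueError(f"Optimizer '{name}' is already registered")
        _OPTIMIZER_REGISTRY[name] = builder
        return builder

    return decorator


@dataclass
class ToyAdamWConfig:  # hypothetical stand-in for a real config class
    b1: float = 0.9
    b2: float = 0.999


@register_optimizer("adamw")
def build_adamw(config: ToyAdamWConfig) -> dict:
    # A real builder would return an optax.GradientTransformation;
    # a plain dict keeps this sketch free of third-party dependencies.
    return {"name": "adamw", "b1": config.b1, "b2": config.b2}


def create(optimizer_type: str, config: ToyAdamWConfig) -> dict:
    """Look up the builder by name, rejecting unknown optimizer types."""
    if optimizer_type not in _OPTIMIZER_REGISTRY:
        raise ValueError(f"Unsupported optimizer type: {optimizer_type!r}")
    return _OPTIMIZER_REGISTRY[optimizer_type](config)
```

The factory then only needs a name lookup, which is why an unknown optimizer_type can be reported as a ValueError before any optimizer is built.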

classmethod create(optimizer_type: str, scheduler_config: eformer.optimizers._config.SchedulerConfig | None = None, optimizer_config: eformer.optimizers._config.AdafactorConfig | eformer.optimizers._config.AdamWConfig | eformer.optimizers._config.KronConfig | eformer.optimizers._config.LionConfig | eformer.optimizers._config.MarsConfig | eformer.optimizers._config.MuonConfig | eformer.optimizers._config.RMSPropConfig | eformer.optimizers._config.SoapConfig | eformer.optimizers._config.WhiteKronConfig | None = None, *, weight_decay: float = 0.0, weight_decay_mask: Any | None = None, gradient_accumulation_steps: int = 1, clip_grad: float | None = None, custom_scheduler: Optional[Callable[[int], optax.Schedule]] = None, **kwargs) -> tuple[optax.GradientTransformation, optax.Schedule]

Create an optimizer with validated configuration.

Parameters
  • optimizer_type (str) – One of the registered optimizer types.

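  • scheduler_config (Optional[SchedulerConfig]) – Configured scheduler parameters. Defaults to None.

  • optimizer_config (Optional[Union[AdafactorConfig, AdamWConfig, KronConfig, LionConfig, MarsConfig, MuonConfig, RMSPropConfig, SoapConfig, WhiteKronConfig]]) – Optimizer-specific configuration. Defaults to None.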

  • weight_decay (float) – Global weight decay rate. Defaults to 0.0.

  • weight_decay_mask (Optional[Any]) – Mask for weight decay application. Defaults to None.

  • gradient_accumulation_steps (int) – Steps for gradient accumulation. Defaults to 1.

  • clip_grad (Optional[float]) – Global clip gradient norm value. Defaults to None.

  • custom_scheduler (Optional[Callable[[int], optax.Schedule]]) – Optional custom scheduler function. Defaults to None.

  • **kwargs – Additional optimizer-specific parameters.

Returns

A tuple containing the optimizer and scheduler.

Return type

Tuple[optax.GradientTransformation, optax.Schedule]

Raises
  • ValueError – If the optimizer type is unsupported or the configuration is invalid.

  • TypeError – If the configuration type is invalid.
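The clip_grad parameter is described as a global gradient-norm clip. Assuming it follows the usual global-norm rule (the one implemented by optax.clip_by_global_norm), the effect can be sketched in plain Python. The function below is an illustrative stand-in operating on nested lists, not the library's code.

```python
import math


def clip_by_global_norm(grads: list[list[float]], max_norm: float) -> list[list[float]]:
    """Scale every gradient by one shared factor so the joint L2 norm is at most max_norm."""
    # The norm is taken over all leaves together, not per parameter.
    global_norm = math.sqrt(sum(g * g for leaf in grads for g in leaf))
    if global_norm <= max_norm:
        return [list(leaf) for leaf in grads]  # already within bounds; copy unchanged
    scale = max_norm / global_norm
    return [[g * scale for g in leaf] for leaf in grads]
```

Because one factor is shared across all parameters, the direction of the overall update is preserved and only its magnitude is reduced.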

classmethod deserialize_config(optimizer_type: str, data: dict | str, format: str = 'dict') -> SerializationMixin

Deserialize configuration from different formats.

Parameters
  • optimizer_type (str) – Name of the optimizer.

  • data (Union[Dict, str]) – Serialized configuration data.

  • format (str) – Serialization format. Supported formats: ‘dict’, ‘json’.

Returns

Deserialized configuration object.

Return type

SerializationMixin

Raises
  • ValueError – If the optimizer type is unknown or the format is unsupported.

  • TypeError – If the input data type is invalid.

classmethod generate_template(optimizer_type: str) -> str

Generate a configuration template for the specified optimizer.

Parameters

optimizer_type (str) – Name of the optimizer.

Returns

Configuration template.

Return type

str

Raises

ValueError – If the optimizer type is unknown.

classmethod serialize_config(config: SerializationMixin, format: str = 'dict') -> dict | str

Serialize configuration to different formats.

Parameters
  • config (SerializationMixin) – Configuration object.

  • format (str) – Serialization format. Supported formats: ‘dict’, ‘json’.

Returns

Serialized configuration.

Return type

Union[Dict, str]

Raises

ValueError – If the format is unsupported.
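The 'dict' and 'json' formats and the documented error cases can be illustrated with a round trip on a toy dataclass. This is a minimal sketch under stated assumptions: ToyConfig and both functions are hypothetical stand-ins built on the standard library, and the real SerializationMixin may serialize differently.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class ToyConfig:  # hypothetical stand-in for a SerializationMixin subclass
    learning_rate: float = 1e-3
    b1: float = 0.9


def serialize_config(config: ToyConfig, format: str = "dict"):
    """Return a dict or a JSON string, depending on `format`."""
    if format == "dict":
        return asdict(config)
    if format == "json":
        return json.dumps(asdict(config))
    raise ValueError(f"Unsupported format: {format!r}")


def deserialize_config(data, format: str = "dict") -> ToyConfig:
    """Rebuild the config, checking that `data` matches the declared format."""
    if format == "dict":
        if not isinstance(data, dict):
            raise TypeError("format='dict' expects a dict")
        return ToyConfig(**data)
    if format == "json":
        if not isinstance(data, str):
            raise TypeError("format='json' expects a str")
        return ToyConfig(**json.loads(data))
    raise ValueError(f"Unsupported format: {format!r}")
```

In this sketch a mismatch between data and format is a TypeError, while an unrecognized format name is a ValueError, mirroring the two error cases listed for deserialize_config.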

class eformer.optimizers._factory.SchedulerFactory

Bases: object

Factory class for creating learning rate schedulers.

This class provides methods to create schedulers based on a configuration object (SchedulerConfig). It supports linear and cosine schedulers with optional warmup steps.

create_scheduler()

Creates a scheduler based on the provided configuration.

_create_linear()

Creates a linear scheduler with optional warmup.

_create_cosine()

Creates a cosine scheduler with optional warmup.

static create_scheduler(config: SchedulerConfig, custom_scheduler: Optional[Callable[[int], optax.Schedule]] = None) -> optax.Schedule

Create a scheduler based on the provided configuration.

Parameters
  • config (SchedulerConfig) – Configuration object for the scheduler.

  • custom_scheduler (Optional[Callable[[int], optax.Schedule]]) – Custom scheduler function. Defaults to None.

Returns

The created scheduler.

Return type

optax.Schedule

Raises

ValueError – If the configuration is invalid or an unsupported scheduler type is provided.
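The two supported shapes, linear and cosine decay with optional warmup, can be written as plain functions of the step count. This is a sketch under stated assumptions: the parameter names (peak_lr, end_lr, warmup_steps, total_steps) are hypothetical and are not taken from SchedulerConfig, warmup is assumed to ramp linearly from zero, and the library itself returns optax schedules rather than these functions.

```python
import math


def linear_with_warmup(step: int, peak_lr: float, end_lr: float,
                       warmup_steps: int, total_steps: int) -> float:
    """Ramp linearly from 0 to peak_lr, then decay linearly to end_lr."""
    if warmup_steps > 0 and step < warmup_steps:
        return peak_lr * step / warmup_steps
    decay_steps = max(total_steps - warmup_steps, 1)
    progress = min((step - warmup_steps) / decay_steps, 1.0)
    return peak_lr + (end_lr - peak_lr) * progress


def cosine_with_warmup(step: int, peak_lr: float, end_lr: float,
                       warmup_steps: int, total_steps: int) -> float:
    """Ramp linearly from 0 to peak_lr, then follow a half-cosine down to end_lr."""
    if warmup_steps > 0 and step < warmup_steps:
        return peak_lr * step / warmup_steps
    decay_steps = max(total_steps - warmup_steps, 1)
    progress = min((step - warmup_steps) / decay_steps, 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes from 1 to 0
    return end_lr + (peak_lr - end_lr) * cosine
```

Both functions share the warmup branch and differ only in how the remaining steps interpolate between peak_lr and end_lr. Setting warmup_steps to 0 skips the ramp entirely.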