eformer.optimizers._tx.white_kron
- class eformer.optimizers._tx.white_kron.DenseState(Ql: Array, Qr: Array, Ll: Array, Lr: Array, valid_rows: Array, valid_cols: Array, valid_count: int, block_size: int)
Bases: object
State container for dense Kronecker-factored preconditioner blocks.
This class stores the concatenated preconditioner matrices for all dense parameter blocks in the model. Dense blocks are small enough to use full-matrix Kronecker factors rather than diagonal approximations.
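Whatever exact form the left and right factors take, a Kronecker-factored preconditioner acts on an m × n gradient block by multiplying on the left and on the right. Below is a minimal sketch, using random stand-in matrices rather than real Ql/Qr values. It checks numerically that this two-sided product equals one (mn) × (mn) Kronecker matrix applied to the flattened block, and it shows why storing the two factors is so much cheaper.

```python
import jax
import jax.numpy as jnp

def apply_kron_factors(G, left, right):
    """Apply a Kronecker-factored operator to an m x n block: left @ G @ right."""
    return left @ G @ right

m, n = 4, 3
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
G = jax.random.normal(k1, (m, n))       # one gradient block
left = jax.random.normal(k2, (m, m))    # stand-in for a left factor (random values)
right = jax.random.normal(k3, (n, n))   # stand-in for a right factor (random values)

factored = apply_kron_factors(G, left, right)

# The same operator written as one (m*n) x (m*n) matrix acting on the
# row-major flattening of G: vec(left @ G @ right) == kron(left, right.T) @ vec(G).
full = jnp.kron(left, right.T) @ G.reshape(-1)

print("factored entries:", m * m + n * n)   # 25
print("full entries:", (m * n) ** 2)        # 144
```

For a 256 × 256 block the gap is 131,072 stored entries for the two factors versus about 4.3 billion for the explicit Kronecker matrix.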
- Ql
Left Kronecker factors, shape [num_blocks, block_size, block_size]. These are orthogonal or near-orthogonal matrices used for left preconditioning.
- Type
Array
- Qr
Right Kronecker factors, shape [num_blocks, block_size, block_size]. These are orthogonal or near-orthogonal matrices used for right preconditioning.
- Type
Array
- Ll
Lipschitz estimates for the left factors, shape [num_blocks]. Used for adaptive learning rate scaling.
- Type
Array
- Lr
Lipschitz estimates for the right factors, shape [num_blocks]. Used for adaptive learning rate scaling.
- Type
Array
- valid_rows
Number of valid rows in each block, shape [num_blocks]. Accounts for padding when the original dimensions are not multiples of block_size.
- Type
Array
- valid_cols
Number of valid columns in each block, shape [num_blocks]. Accounts for padding when the original dimensions are not multiples of block_size.
- Type
Array
- valid_count
Number of actual (non-padding) blocks in the state.
- Type
int
- block_size
Size of each square block in the Kronecker factorization.
- Type
int
- Ll: Array
- Lr: Array
- Ql: Array
- Qr: Array
- block_size: int
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- replace(**kwargs)
Creates a new instance with the specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- valid_cols: Array
- valid_count: int
- valid_rows: Array
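The valid_rows and valid_cols fields exist because matrices whose sides are not multiples of block_size have to be padded before they can be cut into square blocks. Below is a minimal sketch of one way to produce such blocks. The helper `to_blocks`, the zero padding, and the row-block-major ordering are assumptions for illustration, not the white_kron code path.

```python
import math
import jax.numpy as jnp

def to_blocks(W, block_size):
    """Hypothetical helper: zero-pad W to multiples of block_size, split it into
    square blocks, and record how many rows/cols of each block are real data."""
    m, n = W.shape
    nr, nc = math.ceil(m / block_size), math.ceil(n / block_size)
    padded = jnp.pad(W, ((0, nr * block_size - m), (0, nc * block_size - n)))
    # [nr*bs, nc*bs] -> [nr, bs, nc, bs] -> [nr, nc, bs, bs] -> [nr*nc, bs, bs]
    blocks = (
        padded.reshape(nr, block_size, nc, block_size)
        .transpose(0, 2, 1, 3)
        .reshape(nr * nc, block_size, block_size)
    )
    # Only the last row block / last column block can be partially padded.
    rows = jnp.array([min(block_size, m - i * block_size) for i in range(nr)])
    cols = jnp.array([min(block_size, n - j * block_size) for j in range(nc)])
    # Blocks are ordered row-block-major, so block k corresponds to (i, j) with k = i * nc + j.
    valid_rows = jnp.repeat(rows, nc)
    valid_cols = jnp.tile(cols, nr)
    return blocks, valid_rows, valid_cols

W = jnp.arange(30.0).reshape(5, 6)
blocks, valid_rows, valid_cols = to_blocks(W, block_size=4)
print(blocks.shape)          # (4, 4, 4)
print(valid_rows.tolist())   # [4, 4, 1, 1]
print(valid_cols.tolist())   # [4, 2, 4, 2]
```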
- class eformer.optimizers._tx.white_kron.LeafState(kind: int, scanned: int, B: int, shape: tuple[int, ...] | None = None, merged: tuple[int, ...] | None = None, nr: int | None = None, nc: int | None = None, block_size: int | None = None, diag_left: bool | None = None, diag_right: bool | None = None, stack: int | None = None, Ql: jax.jaxlib._jax.Array | None = None, Qr: jax.jaxlib._jax.Array | None = None, Ll: jax.jaxlib._jax.Array | None = None, Lr: jax.jaxlib._jax.Array | None = None, valid_rows: jax.jaxlib._jax.Array | None = None, valid_cols: jax.jaxlib._jax.Array | None = None)
Bases: object
State container for a single parameter leaf in the White Kron optimizer.
This class stores the preconditioner state for an individual parameter tensor and supports different processing paths depending on the parameter's size and shape.
The optimizer handles three types of parameters:
  - DENSE_PATH: small 2D parameters use full dense Kronecker factors.
  - LARGE_PATH: large 2D parameters use mixed dense/diagonal factors.
  - ONE_D_PATH: 1D parameters use diagonal-only preconditioners.
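For a side of length d, a dense factor stores d × d entries and a diagonal factor stores only d, so very large dimensions are a poor fit for dense factors. The sketch below counts stored entries for an m × n matrix. The rule in `choose_sides` is hypothetical: it sends a side to a diagonal factor once that side exceeds max_size_dense ("Max dimension for dense factors"). That is one reading of the parameter, not confirmed white_kron logic.

```python
def factor_entries(m, n, diag_left, diag_right):
    """Stored preconditioner entries for an m x n matrix: a dense side costs
    d * d entries, a diagonal side costs d entries."""
    left = m if diag_left else m * m
    right = n if diag_right else n * n
    return left + right

def choose_sides(m, n, max_size_dense=16384):
    """Hypothetical rule: a side uses a diagonal factor when its dimension
    exceeds max_size_dense."""
    return m > max_size_dense, n > max_size_dense

for m, n in [(1024, 4096), (50000, 1024)]:
    diag_left, diag_right = choose_sides(m, n)
    print((m, n), diag_left, diag_right, factor_entries(m, n, diag_left, diag_right))
# (1024, 4096) False False 17825792
# (50000, 1024) True False 1098576
```

Keeping the 50000-long side dense would instead require 2,501,048,576 entries.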
- kind
Processing path type (DENSE_PATH, LARGE_PATH, or ONE_D_PATH).
- Type
int
- scanned
Whether this parameter is part of a scanned layer (0 or 1).
- Type
int
- B
Batch dimension size (number of stacked parameter matrices).
- Type
int
- shape
Original parameter shape (excluding batch dim).
- Type
tuple[int, …] | None
- merged
Shape after merging dimensions to 2D (m, n).
- Type
tuple[int, …] | None
- nr
Number of row blocks when using blocked processing.
- Type
int | None
- nc
Number of column blocks when using blocked processing.
- Type
int | None
- block_size
Block size for blocked processing.
- Type
int | None
- diag_left
Whether left factor uses diagonal approximation.
- Type
bool | None
- diag_right
Whether right factor uses diagonal approximation.
- Type
bool | None
- stack
Total number of stacked blocks for parallel processing.
- Type
int | None
- B: int
- Ll: jax.jaxlib._jax.Array | None = None
- Lr: jax.jaxlib._jax.Array | None = None
- Ql: jax.jaxlib._jax.Array | None = None
- Qr: jax.jaxlib._jax.Array | None = None
- block_size: int | None = None
- diag_left: bool | None = None
- diag_right: bool | None = None
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- kind: int
- merged: tuple[int, ...] | None = None
- nc: int | None = None
- nr: int | None = None
- replace(**kwargs)
Creates a new instance with the specified fields replaced.
- scanned: int
- shape: tuple[int, ...] | None = None
- stack: int | None = None
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- valid_cols: jax.jaxlib._jax.Array | None = None
- valid_rows: jax.jaxlib._jax.Array | None = None
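The metadata fields B, shape, merged, nr, nc and stack are tied together by simple arithmetic. The sketch below shows one plausible relationship, inferred from the attribute descriptions above rather than taken from white_kron. Three pieces are assumptions:

- the helper `leaf_layout` itself;
- the merge rule that keeps the first dimension as rows and folds the rest into columns;
- the formula stack = B * nr * nc.

```python
import math

def leaf_layout(shape, scanned, block_size=256):
    """Hypothetical bookkeeping for one leaf; keys mirror LeafState field names."""
    shape = tuple(shape)
    if scanned:
        # A scanned layer stacks B copies of a weight along a leading axis.
        B, shape = shape[0], shape[1:]
    else:
        B = 1
    # Assumed merge rule: first dim -> rows, remaining dims folded into columns.
    merged = (shape[0], math.prod(shape[1:]))
    nr = math.ceil(merged[0] / block_size)
    nc = math.ceil(merged[1] / block_size)
    return {"B": B, "shape": shape, "merged": merged,
            "nr": nr, "nc": nc, "stack": B * nr * nc}

# 12 scanned layers of a 768 x 3072 weight: stack = 12 * 3 * 12 = 432 blocks.
print(leaf_layout((12, 768, 3072), scanned=1))
# A non-scanned 4-D conv kernel merged to (3, 24576): stack = 1 * 1 * 96 = 96 blocks.
print(leaf_layout((3, 3, 64, 128), scanned=0))
```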
- eformer.optimizers._tx.white_kron.get_opt_state_partition_specs(params, **quad_kwargs)
Generate partition specs for the White Kron optimizer state.
This utility function creates JAX partition specifications for the optimizer state so that the state can be sharded across devices in distributed training.
- Parameters
params – Model parameters, used to infer the state structure and shapes.
**quad_kwargs – Keyword arguments for the Quad/Skew optimizer, including:
  - lr_style: Learning rate scaling style
  - b1: Momentum coefficient
  - normalize_grads: Whether to normalize gradients
  - max_size_dense: Max dimension for dense factors
  - preconditioner_lr: Preconditioner learning rate
  - preconditioner_init_scale: Initial preconditioner scale
  - dtype: Storage dtype
  - scanned_layers: Scanned layer indicators
  - block_size: Block size for matrix partitioning
  - pipeline_axis_name: Pipeline axis name for sharding
  - pipeline_axis_size: Pipeline axis size
  - params_partition_specs: Parameter partition specs
  - noise_scale: Noise scale for stability
  - weight_decay: Weight decay coefficient (used to determine the state structure)
- Returns
Partition specs for the optimizer state. The structure depends on whether weight decay is enabled:
  - With weight_decay > 0: (precond_specs, None, None)
  - Without weight decay: (precond_specs, None)
- Return type
tuple
Example
>>> import jax.numpy as jnp
>>> from eformer.optimizers._tx.white_kron import get_opt_state_partition_specs
>>> params = {"layer1": jnp.zeros((128, 64)), "layer2": jnp.zeros((64, 32))}
>>> specs = get_opt_state_partition_specs(
...     params,
...     pipeline_axis_name="dp",
...     pipeline_axis_size=4,
...     weight_decay=0.1,
... )
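The returned specs describe how each piece of optimizer state is laid out across a device mesh, and a common next step is to turn them into concrete shardings. The sketch below shows that conversion under three assumptions:

- the placeholder tree only mirrors the documented (precond_specs, None, None) layout, and its contents are invented;
- the leaves are assumed to be PartitionSpec objects;
- the single-device "dp" mesh is used so the code runs anywhere.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-device mesh with a "dp" axis; a real run would use many devices.
mesh = Mesh(np.array(jax.devices()[:1]), ("dp",))

# Invented placeholder shaped like the weight_decay > 0 return value.
specs = ({"layer1": P(), "layer2": P("dp")}, None, None)

# Treat each PartitionSpec as a leaf; None entries are empty subtrees and stay None.
shardings = jax.tree_util.tree_map(
    lambda s: NamedSharding(mesh, s),
    specs,
    is_leaf=lambda x: isinstance(x, P),
)
print(shardings[0]["layer2"])
```

A tree like this could then be passed as out_shardings to jax.jit around the optimizer's init function, provided it matches the structure of the actual optimizer state.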
- eformer.optimizers._tx.white_kron.quad(learning_rate: float | collections.abc.Callable[[int], float] = 0.001, lr_style: str | None = 'adam', b1: float = 0.95, weight_decay: float = 0.1, weight_decay_mask: Any | collections.abc.Callable[[Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], Any] | None = None, normalize_grads: bool = False, max_size_dense: int = 16384, preconditioner_lr: float = 0.7, preconditioner_init_scale: float = 1.0, dtype: str | numpy.dtype = jax.numpy.bfloat16, scanned_layers: Optional[Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, Iterable[ArrayTree], Mapping[Any, ArrayTree]]] = None, block_size: int = 256, pipeline_axis_name: str | None = None, pipeline_axis_size: int = 1, params_partition_specs: jax.sharding.PartitionSpec | list | tuple | dict | None = None, noise_scale: float = 1e-09) → GradientTransformation
Complete Quad optimizer with weight decay and learning rate scheduling.
Quad is a Kronecker-factored preconditioned optimizer whose preconditioner updates minimize a quadratic loss function. The Kronecker factorization gives second-order (curvature-aware) preconditioning at a memory and compute cost far below that of a full-matrix preconditioner.
- Parameters
learning_rate (float | Callable[[int], float]) – Learning rate or schedule. Defaults to 0.001.
lr_style (str | None) – Learning rate scaling style. Defaults to “adam”.
b1 (float) – Momentum coefficient. Defaults to 0.95.
weight_decay (float) – Weight decay coefficient. Defaults to 0.1.
weight_decay_mask – Mask for selective weight decay. Defaults to None.
normalize_grads (bool) – Whether to normalize gradients. Defaults to False.
max_size_dense (int) – Max dimension for dense factors. Defaults to 16384.
preconditioner_lr (float) – Preconditioner learning rate. Defaults to 0.7.
preconditioner_init_scale (float) – Initial preconditioner scale. Defaults to 1.0.
dtype (str | jnp.dtype) – Storage dtype. Defaults to jnp.bfloat16.
scanned_layers (base.Params | None) – Scanned layer indicators. Defaults to None.
block_size (int) – Block size for matrix partitioning. Defaults to 256.
pipeline_axis_name (str | None) – Pipeline axis name. Defaults to None.
pipeline_axis_size (int) – Pipeline axis size. Defaults to 1.
params_partition_specs – Parameter partition specs. Defaults to None.
noise_scale (float) – Noise scale for stability. Defaults to 1e-9.
- Returns
Complete Quad optimizer transformation.
- Return type
base.GradientTransformation
Example
>>> from eformer.optimizers._tx import quad
>>> # With constant learning rate
>>> optimizer = quad(learning_rate=1e-4, b1=0.95, weight_decay=0.1)
>>> # With learning rate schedule
>>> import optax
>>> schedule = optax.cosine_decay_schedule(1e-4, 10000)
>>> optimizer = quad(learning_rate=schedule)
- eformer.optimizers._tx.white_kron.scale_by_quad(lr_style: str | None = 'adam', b1: float = 0.95, normalize_grads: bool = False, max_size_dense: int = 16384, preconditioner_lr: float = 0.7, preconditioner_init_scale: float = 1.0, dtype: str | numpy.dtype = jax.numpy.bfloat16, scanned_layers: Optional[Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, Iterable[ArrayTree], Mapping[Any, ArrayTree]]] = None, block_size: int = 256, pipeline_axis_name: str | None = None, pipeline_axis_size: int = 1, params_partition_specs: jax.sharding.PartitionSpec | list | tuple | dict | None = None, noise_scale: float = 1e-09) → GradientTransformation
Create a gradient scaling transformation using QUAD-style preconditioner updates.
The QUAD variant of White Kron uses quadratic preconditioner updates that directly minimize a quadratic loss in the preconditioner space.
- Parameters
lr_style (str | None) – Learning rate scaling style. “adam” for Adam-like scaling. Defaults to “adam”.
b1 (float) – Momentum coefficient. Defaults to 0.95.
normalize_grads (bool) – Whether to normalize gradients. Defaults to False.
max_size_dense (int) – Max dimension for dense factors. Defaults to 16384.
preconditioner_lr (float) – Preconditioner learning rate. Defaults to 0.7.
preconditioner_init_scale (float) – Initial preconditioner scale. Defaults to 1.0.
dtype (str | jnp.dtype) – Storage dtype. Defaults to jnp.bfloat16.
scanned_layers (base.Params | None) – Scanned layer indicators. Defaults to None.
block_size (int) – Block size for matrix partitioning. Defaults to 256.
pipeline_axis_name (str | None) – Pipeline axis name. Defaults to None.
pipeline_axis_size (int) – Pipeline axis size. Defaults to 1.
params_partition_specs – Parameter partition specs. Defaults to None.
noise_scale (float) – Noise scale for stability. Defaults to 1e-9.
- Returns
QUAD-style preconditioned gradient transformation.
- Return type
base.GradientTransformation
Example
>>> import optax
>>> from eformer.optimizers._tx import scale_by_quad
>>> optimizer = optax.chain(
...     scale_by_quad(b1=0.95),
...     optax.add_decayed_weights(0.1),
...     optax.scale_by_learning_rate(1e-4),
... )
- eformer.optimizers._tx.white_kron.scale_by_skew(lr_style: str | None = 'adam', b1: float = 0.95, normalize_grads: bool = False, max_size_dense: int = 16384, preconditioner_lr: float = 0.7, preconditioner_init_scale: float = 1.0, dtype: str | numpy.dtype = jax.numpy.bfloat16, scanned_layers: Optional[Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, Iterable[ArrayTree], Mapping[Any, ArrayTree]]] = None, block_size: int = 256, pipeline_axis_name: str | None = None, pipeline_axis_size: int = 1, params_partition_specs: jax.sharding.PartitionSpec | list | tuple | dict | None = None, noise_scale: float = 1e-09) → GradientTransformation
Create a gradient scaling transformation using skew-style preconditioner updates.
The skew variant of White Kron uses Procrustes orthogonalization to maintain near-orthogonal preconditioner matrices, which can provide more stable training.
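Procrustes orthogonalization replaces a matrix A with the orthogonal matrix Q closest to it in Frobenius norm. The textbook solution takes the SVD A = U Σ Vᵀ and sets Q = U Vᵀ, the polar factor of A. Below is a minimal sketch of that textbook form. white_kron may use a cheaper iterative approximation or apply the step to a different quantity; that detail is not documented here.

```python
import jax
import jax.numpy as jnp

def nearest_orthogonal(A):
    """Orthogonal Procrustes / polar factor: the orthogonal Q minimizing ||Q - A||_F."""
    U, _, Vh = jnp.linalg.svd(A, full_matrices=False)
    return U @ Vh

# A slightly perturbed identity, i.e. a "near-orthogonal" matrix that has drifted.
A = jnp.eye(4) + 0.1 * jax.random.normal(jax.random.PRNGKey(0), (4, 4))
Q = nearest_orthogonal(A)

# Deviation of Q^T Q from the identity; should be at float32 rounding level.
print(jnp.abs(Q.T @ Q - jnp.eye(4)).max())
```

Applying the function to an already orthogonal matrix returns that matrix unchanged, so the projection is idempotent.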
- Parameters
lr_style (str | None) – Learning rate scaling style. “adam” for Adam-like scaling. Defaults to “adam”.
b1 (float) – Momentum coefficient. Defaults to 0.95.
normalize_grads (bool) – Whether to normalize gradients. Defaults to False.
max_size_dense (int) – Max dimension for dense factors. Defaults to 16384.
preconditioner_lr (float) – Preconditioner learning rate. Defaults to 0.7.
preconditioner_init_scale (float) – Initial preconditioner scale. Defaults to 1.0.
dtype (str | jnp.dtype) – Storage dtype. Defaults to jnp.bfloat16.
scanned_layers (base.Params | None) – Scanned layer indicators. Defaults to None.
block_size (int) – Block size for matrix partitioning. Defaults to 256.
pipeline_axis_name (str | None) – Pipeline axis name. Defaults to None.
pipeline_axis_size (int) – Pipeline axis size. Defaults to 1.
params_partition_specs – Parameter partition specs. Defaults to None.
noise_scale (float) – Noise scale for stability. Defaults to 1e-9.
- Returns
Skew-style preconditioned gradient transformation.
- Return type
base.GradientTransformation
Example
>>> import optax
>>> from eformer.optimizers._tx import scale_by_skew
>>> optimizer = optax.chain(
...     scale_by_skew(b1=0.95),
...     optax.add_decayed_weights(0.1),
...     optax.scale_by_learning_rate(1e-4),
... )
- eformer.optimizers._tx.white_kron.skew(learning_rate: float | collections.abc.Callable[[int], float] = 0.001, lr_style: str | None = 'adam', b1: float = 0.95, weight_decay: float = 0.1, weight_decay_mask: Any | collections.abc.Callable[[Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], Any] | None = None, normalize_grads: bool = False, max_size_dense: int = 16384, preconditioner_lr: float = 0.7, preconditioner_init_scale: float = 1.0, dtype: str | numpy.dtype = jax.numpy.bfloat16, scanned_layers: Optional[Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, Iterable[ArrayTree], Mapping[Any, ArrayTree]]] = None, block_size: int = 256, pipeline_axis_name: str | None = None, pipeline_axis_size: int = 1, params_partition_specs: jax.sharding.PartitionSpec | list | tuple | dict | None = None, noise_scale: float = 1e-09) → GradientTransformation
Complete Skew optimizer with weight decay and learning rate scheduling.
Skew is a Kronecker-factored preconditioned optimizer that uses Procrustes orthogonalization to keep its preconditioner matrices near-orthogonal, which can make training more stable. The Kronecker factorization gives second-order (curvature-aware) preconditioning at a memory and compute cost far below that of a full-matrix preconditioner.
- Parameters
learning_rate (float | Callable[[int], float]) – Learning rate or schedule. Defaults to 0.001.
lr_style (str | None) – Learning rate scaling style. Defaults to “adam”.
b1 (float) – Momentum coefficient. Defaults to 0.95.
weight_decay (float) – Weight decay coefficient. Defaults to 0.1.
weight_decay_mask – Mask for selective weight decay. Defaults to None.
normalize_grads (bool) – Whether to normalize gradients. Defaults to False.
max_size_dense (int) – Max dimension for dense factors. Defaults to 16384.
preconditioner_lr (float) – Preconditioner learning rate. Defaults to 0.7.
preconditioner_init_scale (float) – Initial preconditioner scale. Defaults to 1.0.
dtype (str | jnp.dtype) – Storage dtype. Defaults to jnp.bfloat16.
scanned_layers (base.Params | None) – Scanned layer indicators. Defaults to None.
block_size (int) – Block size for matrix partitioning. Defaults to 256.
pipeline_axis_name (str | None) – Pipeline axis name. Defaults to None.
pipeline_axis_size (int) – Pipeline axis size. Defaults to 1.
params_partition_specs – Parameter partition specs. Defaults to None.
noise_scale (float) – Noise scale for stability. Defaults to 1e-9.
- Returns
Complete Skew optimizer transformation.
- Return type
base.GradientTransformation
Example
>>> from eformer.optimizers._tx import skew
>>> # With constant learning rate
>>> optimizer = skew(learning_rate=1e-4, b1=0.95, weight_decay=0.1)
>>> # With learning rate schedule
>>> import optax
>>> schedule = optax.cosine_decay_schedule(1e-4, 10000)
>>> optimizer = skew(learning_rate=schedule)