eformer.optimizers._tx.white_kron#

class eformer.optimizers._tx.white_kron.DenseState(Ql: Array, Qr: Array, Ll: Array, Lr: Array, valid_rows: Array, valid_cols: Array, valid_count: int, block_size: int)[source]#

Bases: object

State container for dense Kronecker-factored preconditioner blocks.

This class stores the concatenated preconditioner matrices for all dense parameter blocks in the model. Dense blocks are small enough to use full matrix Kronecker factors rather than diagonal approximations.

Ql#

Left Kronecker factors, shape [num_blocks, block_size, block_size]. These are orthogonal/near-orthogonal matrices for left preconditioning.

Type

jax.Array

Qr#

Right Kronecker factors, shape [num_blocks, block_size, block_size]. These are orthogonal/near-orthogonal matrices for right preconditioning.

Type

jax.Array

Ll#

Lipschitz estimates for left factors, shape [num_blocks]. Used for adaptive learning rate scaling.

Type

jax.Array

Lr#

Lipschitz estimates for right factors, shape [num_blocks]. Used for adaptive learning rate scaling.

Type

jax.Array

valid_rows#

Number of valid rows in each block, shape [num_blocks]. Accounts for padding when original dimensions are not multiples of block_size.

Type

jax.Array

valid_cols#

Number of valid columns in each block, shape [num_blocks]. Accounts for padding when original dimensions are not multiples of block_size.

Type

jax.Array

valid_count#

Number of actual (non-padding) blocks in the state.

Type

int

block_size#

Size of each square block in the Kronecker factorization.

Type

int
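The padding bookkeeping behind valid_rows and valid_cols can be sketched in plain Python. This is only an illustration: the helper name block_layout and the row-major block ordering are assumptions, not taken from the library. What comes from the descriptions above is only that edge blocks hold fewer valid rows or columns when a dimension is not a multiple of block_size.

```python
import math


def block_layout(m: int, n: int, block_size: int):
    """Hypothetical helper: tile an (m, n) matrix into square blocks.

    Returns the number of row blocks, the number of column blocks, and the
    count of valid (non-padding) rows and columns in each block.
    """
    nr = math.ceil(m / block_size)
    nc = math.ceil(n / block_size)
    valid_rows, valid_cols = [], []
    for i in range(nr):          # assumed row-major ordering of blocks
        for j in range(nc):
            valid_rows.append(min(block_size, m - i * block_size))
            valid_cols.append(min(block_size, n - j * block_size))
    return nr, nc, valid_rows, valid_cols


nr, nc, valid_rows, valid_cols = block_layout(300, 256, block_size=256)
print(nr, nc)        # 2 1
print(valid_rows)    # [256, 44]
print(valid_cols)    # [256, 256]
```

In this sketch the second block is padded from 44 to 256 rows, which is the situation valid_rows is meant to record.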

Ll: Array#
Lr: Array#
Ql: Array#
Qr: Array#
block_size: int#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.
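The intent of replace, to_dict and to_json can be illustrated with a standard-library stand-in. ToyState below is hypothetical and is not the library's DenseState; it only carries the two integer fields, because how array fields are serialized is not described here.

```python
import dataclasses
import json


@dataclasses.dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in with the same method names as the state classes."""

    valid_count: int
    block_size: int

    def replace(self, **kwargs):
        # Build a new instance; the original object is left untouched.
        return dataclasses.replace(self, **kwargs)

    def to_dict(self):
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data):
        return cls(**data)

    def to_json(self, **kwargs):
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str):
        return cls.from_dict(json.loads(json_str))


state = ToyState(valid_count=3, block_size=256)
bigger = state.replace(block_size=512)            # new instance, state unchanged
restored = ToyState.from_json(state.to_json())    # JSON round trip
```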

valid_cols: Array#
valid_count: int#
valid_rows: Array#
class eformer.optimizers._tx.white_kron.LeafState(kind: int, scanned: int, B: int, shape: tuple[int, ...] | None = None, merged: tuple[int, ...] | None = None, nr: int | None = None, nc: int | None = None, block_size: int | None = None, diag_left: bool | None = None, diag_right: bool | None = None, stack: int | None = None, Ql: jax.jaxlib._jax.Array | None = None, Qr: jax.jaxlib._jax.Array | None = None, Ll: jax.jaxlib._jax.Array | None = None, Lr: jax.jaxlib._jax.Array | None = None, valid_rows: jax.jaxlib._jax.Array | None = None, valid_cols: jax.jaxlib._jax.Array | None = None)[source]#

Bases: object

State container for a single parameter leaf in the White Kron optimizer.

This class stores the preconditioner state for individual parameter tensors, supporting different processing paths based on parameter size and shape.

The optimizer handles three types of parameters:
  • DENSE_PATH: Small 2D parameters use full dense Kronecker factors

  • LARGE_PATH: Large 2D parameters use mixed dense/diagonal factors

  • ONE_D_PATH: 1D parameters use diagonal-only preconditioners
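The three-way split above might be pictured roughly as follows. This is a sketch under stated assumptions: the integer codes, the function name choose_path, and the exact rule (comparing the largest dimension against max_size_dense) are hypothetical. Only the three path names and the max_size_dense parameter (documented for quad and skew as the max dimension for dense factors) come from this page.

```python
# Hypothetical integer codes; the library's actual values are not documented here.
DENSE_PATH, LARGE_PATH, ONE_D_PATH = 0, 1, 2


def choose_path(shape, max_size_dense=16384):
    """Pick a processing path for a parameter with a 1D or 2D shape."""
    if len(shape) <= 1:
        return ONE_D_PATH        # diagonal-only preconditioner
    if max(shape) > max_size_dense:
        return LARGE_PATH        # mixed dense/diagonal factors (assumed rule)
    return DENSE_PATH            # full dense Kronecker factors


bias_path = choose_path((4096,))
kernel_path = choose_path((1024, 4096))
embedding_path = choose_path((50000, 1024))
```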

kind#

Processing path type (DENSE_PATH, LARGE_PATH, or ONE_D_PATH).

Type

int

scanned#

Whether this parameter is part of a scanned layer (0 or 1).

Type

int

B#

Batch dimension size (number of stacked parameter matrices).

Type

int

shape#

Original parameter shape (excluding batch dim).

Type

tuple[int, …] | None

merged#

Shape after merging dimensions to 2D (m, n).

Type

tuple[int, …] | None

nr#

Number of row blocks when using blocked processing.

Type

int | None

nc#

Number of column blocks when using blocked processing.

Type

int | None

block_size#

Block size for blocked processing.

Type

int | None

diag_left#

Whether left factor uses diagonal approximation.

Type

bool | None

diag_right#

Whether right factor uses diagonal approximation.

Type

bool | None

stack#

Total number of stacked blocks for parallel processing.

Type

int | None

Ql#

Left Kronecker factor(s).

Type

jax.Array | None

Qr#

Right Kronecker factor(s).

Type

jax.Array | None

Ll#

Left Lipschitz estimates.

Type

jax.Array | None

Lr#

Right Lipschitz estimates.

Type

jax.Array | None

valid_rows#

Valid row counts per block.

Type

jax.Array | None

valid_cols#

Valid column counts per block.

Type

jax.Array | None

B: int#
Ll: jax.jaxlib._jax.Array | None = None#
Lr: jax.jaxlib._jax.Array | None = None#
Ql: jax.jaxlib._jax.Array | None = None#
Qr: jax.jaxlib._jax.Array | None = None#
block_size: int | None = None#
diag_left: bool | None = None#
diag_right: bool | None = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

kind: int#
merged: tuple[int, ...] | None = None#
nc: int | None = None#
nr: int | None = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

scanned: int#
shape: tuple[int, ...] | None = None#
stack: int | None = None#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

valid_cols: jax.jaxlib._jax.Array | None = None#
valid_rows: jax.jaxlib._jax.Array | None = None#
eformer.optimizers._tx.white_kron.get_opt_state_partition_specs(params, **quad_kwargs)[source]#

Generate partition specs for White Kron optimizer state.

This utility creates JAX partition specifications for the optimizer state so that the state can be sharded across devices in distributed training.

Parameters
  • params – Model parameters, used to infer state structure and shapes.

  • **quad_kwargs – Keyword arguments for the Quad/Skew optimizer, including:
      - lr_style: Learning rate scaling style
      - b1: Momentum coefficient
      - normalize_grads: Whether to normalize gradients
      - max_size_dense: Max dimension for dense factors
      - preconditioner_lr: Preconditioner learning rate
      - preconditioner_init_scale: Initial preconditioner scale
      - dtype: Storage dtype
      - scanned_layers: Scanned layer indicators
      - block_size: Block size for matrix partitioning
      - pipeline_axis_name: Pipeline axis name for sharding
      - pipeline_axis_size: Pipeline axis size
      - params_partition_specs: Parameter partition specs
      - noise_scale: Noise scale for stability
      - weight_decay: Weight decay coefficient (used to determine state structure)

Returns

Partition specs for the optimizer state. The structure depends on whether weight decay is enabled:
  - With weight decay > 0: (precond_specs, None, None)
  - Without weight decay: (precond_specs, None)

Return type

tuple

Example

>>> import jax.numpy as jnp
>>> from eformer.optimizers._tx.white_kron import get_opt_state_partition_specs
>>> params = {"layer1": jnp.zeros((128, 64)), "layer2": jnp.zeros((64, 32))}
>>> specs = get_opt_state_partition_specs(
...     params,
...     pipeline_axis_name="dp",
...     pipeline_axis_size=4,
...     weight_decay=0.1,
... )
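The two tuple layouts listed under Returns can be mirrored in a few lines of plain Python. The function spec_layout and the placeholder spec values are hypothetical; the sketch only reproduces the documented shape of the result, not how the library builds precond_specs.

```python
def spec_layout(precond_specs, weight_decay):
    """Mirror the documented tuple layout of the returned partition specs."""
    if weight_decay > 0:
        return (precond_specs, None, None)
    return (precond_specs, None)


# Placeholder object standing in for the real preconditioner specs.
precond = {"layer1": "<spec>", "layer2": "<spec>"}

with_decay = spec_layout(precond, weight_decay=0.1)
without_decay = spec_layout(precond, weight_decay=0.0)

# The first entry is the preconditioner specs in both layouts, so
# star-unpacking works regardless of whether weight decay is enabled.
precond_specs, *rest = with_decay
```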
eformer.optimizers._tx.white_kron.quad(learning_rate: float | collections.abc.Callable[[int], float] = 0.001, lr_style: str | None = 'adam', b1: float = 0.95, weight_decay: float = 0.1, weight_decay_mask: Any | collections.abc.Callable[[Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], Any] | None = None, normalize_grads: bool = False, max_size_dense: int = 16384, preconditioner_lr: float = 0.7, preconditioner_init_scale: float = 1.0, dtype: str | numpy.dtype = jax.numpy.bfloat16, scanned_layers: Optional[Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]] = None, block_size: int = 256, pipeline_axis_name: str | None = None, pipeline_axis_size: int = 1, params_partition_specs: jax.sharding.PartitionSpec | list | tuple | dict | None = None, noise_scale: float = 1e-09) GradientTransformation[source]#

Complete Quad optimizer with weight decay and learning rate scheduling.

Quad is a Kronecker-factored preconditioned optimizer using quadratic preconditioner updates that minimize a quadratic loss function. This provides efficient second-order optimization.

Parameters
  • learning_rate (float | Callable[[int], float]) – Learning rate or schedule. Defaults to 0.001.

  • lr_style (str | None) – Learning rate scaling style. Defaults to “adam”.

  • b1 (float) – Momentum coefficient. Defaults to 0.95.

  • weight_decay (float) – Weight decay coefficient. Defaults to 0.1.

  • weight_decay_mask – Mask for selective weight decay. Defaults to None.

  • normalize_grads (bool) – Whether to normalize gradients. Defaults to False.

  • max_size_dense (int) – Max dimension for dense factors. Defaults to 16384.

  • preconditioner_lr (float) – Preconditioner learning rate. Defaults to 0.7.

  • preconditioner_init_scale (float) – Initial preconditioner scale. Defaults to 1.0.

  • dtype (str | jnp.dtype) – Storage dtype. Defaults to jnp.bfloat16.

  • scanned_layers (base.Params | None) – Scanned layer indicators. Defaults to None.

  • block_size (int) – Block size for matrix partitioning. Defaults to 256.

  • pipeline_axis_name (str | None) – Pipeline axis name. Defaults to None.

  • pipeline_axis_size (int) – Pipeline axis size. Defaults to 1.

  • params_partition_specs – Parameter partition specs. Defaults to None.

  • noise_scale (float) – Noise scale for stability. Defaults to 1e-9.

Returns

Complete Quad optimizer transformation.

Return type

base.GradientTransformation

Example

>>> from eformer.optimizers._tx import quad
>>> # With constant learning rate
>>> optimizer = quad(learning_rate=1e-4, b1=0.95, weight_decay=0.1)
>>> # With learning rate schedule
>>> import optax
>>> schedule = optax.cosine_decay_schedule(1e-4, 10000)
>>> optimizer = quad(learning_rate=schedule)
eformer.optimizers._tx.white_kron.scale_by_quad(lr_style: str | None = 'adam', b1: float = 0.95, normalize_grads: bool = False, max_size_dense: int = 16384, preconditioner_lr: float = 0.7, preconditioner_init_scale: float = 1.0, dtype: str | numpy.dtype = jax.numpy.bfloat16, scanned_layers: Optional[Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]] = None, block_size: int = 256, pipeline_axis_name: str | None = None, pipeline_axis_size: int = 1, params_partition_specs: jax.sharding.PartitionSpec | list | tuple | dict | None = None, noise_scale: float = 1e-09) GradientTransformation[source]#

Create a gradient scaling transformation using QUAD-style preconditioner updates.

The QUAD variant of White Kron uses quadratic preconditioner updates that directly minimize a quadratic loss in the preconditioner space.

Parameters
  • lr_style (str | None) – Learning rate scaling style. “adam” for Adam-like scaling. Defaults to “adam”.

  • b1 (float) – Momentum coefficient. Defaults to 0.95.

  • normalize_grads (bool) – Whether to normalize gradients. Defaults to False.

  • max_size_dense (int) – Max dimension for dense factors. Defaults to 16384.

  • preconditioner_lr (float) – Preconditioner learning rate. Defaults to 0.7.

  • preconditioner_init_scale (float) – Initial preconditioner scale. Defaults to 1.0.

  • dtype (str | jnp.dtype) – Storage dtype. Defaults to jnp.bfloat16.

  • scanned_layers (base.Params | None) – Scanned layer indicators. Defaults to None.

  • block_size (int) – Block size for matrix partitioning. Defaults to 256.

  • pipeline_axis_name (str | None) – Pipeline axis name. Defaults to None.

  • pipeline_axis_size (int) – Pipeline axis size. Defaults to 1.

  • params_partition_specs – Parameter partition specs. Defaults to None.

  • noise_scale (float) – Noise scale for stability. Defaults to 1e-9.

Returns

QUAD-style preconditioned gradient transformation.

Return type

base.GradientTransformation

Example

>>> import optax
>>> from eformer.optimizers._tx import scale_by_quad
>>> optimizer = optax.chain(
...     scale_by_quad(b1=0.95),
...     optax.add_decayed_weights(0.1),
...     optax.scale_by_learning_rate(1e-4),
... )
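What the three chained stages in the example above do to a single scalar can be walked through in plain Python. The preconditioned gradient is taken as a given input here, since the internals of scale_by_quad are not shown on this page; the other two stages follow the usual optax behaviour (add weight_decay * param, then multiply by the negative learning rate). The function name chained_update is hypothetical.

```python
def chained_update(precond_grad, param, weight_decay=0.1, learning_rate=1e-4):
    """Scalar walk-through of the three chained stages."""
    # Stage 1 (scale_by_quad) is assumed to have produced precond_grad already.
    update = precond_grad
    # Stage 2: optax.add_decayed_weights adds weight_decay * param.
    update = update + weight_decay * param
    # Stage 3: optax.scale_by_learning_rate multiplies by -learning_rate.
    update = -learning_rate * update
    return update


param = 2.0
update = chained_update(precond_grad=0.5, param=param)
new_param = param + update   # same effect as optax.apply_updates on a scalar
```

Because weight decay is added after the preconditioner in this chain, the decay term is not preconditioned; it is only scaled by the learning rate.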
eformer.optimizers._tx.white_kron.scale_by_skew(lr_style: str | None = 'adam', b1: float = 0.95, normalize_grads: bool = False, max_size_dense: int = 16384, preconditioner_lr: float = 0.7, preconditioner_init_scale: float = 1.0, dtype: str | numpy.dtype = jax.numpy.bfloat16, scanned_layers: Optional[Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]] = None, block_size: int = 256, pipeline_axis_name: str | None = None, pipeline_axis_size: int = 1, params_partition_specs: jax.sharding.PartitionSpec | list | tuple | dict | None = None, noise_scale: float = 1e-09) GradientTransformation[source]#

Create a gradient scaling transformation using skew-style preconditioner updates.

The skew variant of White Kron uses Procrustes orthogonalization to maintain near-orthogonal preconditioner matrices, which can provide more stable training.

Parameters
  • lr_style (str | None) – Learning rate scaling style. “adam” for Adam-like scaling. Defaults to “adam”.

  • b1 (float) – Momentum coefficient. Defaults to 0.95.

  • normalize_grads (bool) – Whether to normalize gradients. Defaults to False.

  • max_size_dense (int) – Max dimension for dense factors. Defaults to 16384.

  • preconditioner_lr (float) – Preconditioner learning rate. Defaults to 0.7.

  • preconditioner_init_scale (float) – Initial preconditioner scale. Defaults to 1.0.

  • dtype (str | jnp.dtype) – Storage dtype. Defaults to jnp.bfloat16.

  • scanned_layers (base.Params | None) – Scanned layer indicators. Defaults to None.

  • block_size (int) – Block size for matrix partitioning. Defaults to 256.

  • pipeline_axis_name (str | None) – Pipeline axis name. Defaults to None.

  • pipeline_axis_size (int) – Pipeline axis size. Defaults to 1.

  • params_partition_specs – Parameter partition specs. Defaults to None.

  • noise_scale (float) – Noise scale for stability. Defaults to 1e-9.

Returns

Skew-style preconditioned gradient transformation.

Return type

base.GradientTransformation

Example

>>> import optax
>>> from eformer.optimizers._tx import scale_by_skew
>>> optimizer = optax.chain(
...     scale_by_skew(b1=0.95),
...     optax.add_decayed_weights(0.1),
...     optax.scale_by_learning_rate(1e-4),
... )
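To give a feel for what keeping a matrix near-orthogonal means, the sketch below pulls a slightly non-orthogonal 2x2 matrix back to its nearest orthogonal matrix, which is the solution of the orthogonal Procrustes problem. It uses a Newton-Schulz iteration in plain Python as a swapped-in technique for illustration only; this page does not say how scale_by_skew performs its orthogonalization, and all helper names are hypothetical.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


def newton_schulz_orthogonalize(q, steps=10):
    """Iterate X <- 1.5 X - 0.5 X X^T X.

    Converges to the orthogonal polar factor of q when the singular
    values of q lie in (0, sqrt(3)).
    """
    x = [row[:] for row in q]
    for _ in range(steps):
        xxtx = matmul(matmul(x, transpose(x)), x)
        x = [[1.5 * x[i][j] - 0.5 * xxtx[i][j] for j in range(len(x[0]))]
             for i in range(len(x))]
    return x


q = [[1.0, 0.2], [-0.1, 0.9]]              # close to, but not, orthogonal
q_orth = newton_schulz_orthogonalize(q)
gram = matmul(transpose(q_orth), q_orth)   # close to the identity matrix
```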
eformer.optimizers._tx.white_kron.skew(learning_rate: float | collections.abc.Callable[[int], float] = 0.001, lr_style: str | None = 'adam', b1: float = 0.95, weight_decay: float = 0.1, weight_decay_mask: Any | collections.abc.Callable[[Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], Any] | None = None, normalize_grads: bool = False, max_size_dense: int = 16384, preconditioner_lr: float = 0.7, preconditioner_init_scale: float = 1.0, dtype: str | numpy.dtype = jax.numpy.bfloat16, scanned_layers: Optional[Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]] = None, block_size: int = 256, pipeline_axis_name: str | None = None, pipeline_axis_size: int = 1, params_partition_specs: jax.sharding.PartitionSpec | list | tuple | dict | None = None, noise_scale: float = 1e-09) GradientTransformation[source]#

Complete Skew optimizer with weight decay and learning rate scheduling.

Skew is a Kronecker-factored preconditioned optimizer using Procrustes orthogonalization to maintain near-orthogonal preconditioner matrices. This provides efficient second-order optimization with stable training.

Parameters
  • learning_rate (float | Callable[[int], float]) – Learning rate or schedule. Defaults to 0.001.

  • lr_style (str | None) – Learning rate scaling style. Defaults to “adam”.

  • b1 (float) – Momentum coefficient. Defaults to 0.95.

  • weight_decay (float) – Weight decay coefficient. Defaults to 0.1.

  • weight_decay_mask – Mask for selective weight decay. Defaults to None.

  • normalize_grads (bool) – Whether to normalize gradients. Defaults to False.

  • max_size_dense (int) – Max dimension for dense factors. Defaults to 16384.

  • preconditioner_lr (float) – Preconditioner learning rate. Defaults to 0.7.

  • preconditioner_init_scale (float) – Initial preconditioner scale. Defaults to 1.0.

  • dtype (str | jnp.dtype) – Storage dtype. Defaults to jnp.bfloat16.

  • scanned_layers (base.Params | None) – Scanned layer indicators. Defaults to None.

  • block_size (int) – Block size for matrix partitioning. Defaults to 256.

  • pipeline_axis_name (str | None) – Pipeline axis name. Defaults to None.

  • pipeline_axis_size (int) – Pipeline axis size. Defaults to 1.

  • params_partition_specs – Parameter partition specs. Defaults to None.

  • noise_scale (float) – Noise scale for stability. Defaults to 1e-9.

Returns

Complete Skew optimizer transformation.

Return type

base.GradientTransformation

Example

>>> from eformer.optimizers._tx import skew
>>> # With constant learning rate
>>> optimizer = skew(learning_rate=1e-4, b1=0.95, weight_decay=0.1)
>>> # With learning rate schedule
>>> import optax
>>> schedule = optax.cosine_decay_schedule(1e-4, 10000)
>>> optimizer = skew(learning_rate=schedule)
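The signature allows weight_decay_mask to be a callable. Assuming it follows the usual optax convention (the callable receives the parameter tree and returns a matching tree of booleans), a mask that applies decay only to matrices might look like the sketch below. The function name decay_matrices_only and the stand-in leaves are hypothetical; passing it as skew(weight_decay_mask=decay_matrices_only) is an assumption, not something confirmed on this page.

```python
from types import SimpleNamespace


def decay_matrices_only(params):
    """Return a tree of booleans: True where weight decay should apply."""
    if isinstance(params, dict):
        return {name: decay_matrices_only(child) for name, child in params.items()}
    # Leaves only need an ndim attribute, as JAX and NumPy arrays have.
    return params.ndim >= 2


# Stand-in leaves so the sketch runs without JAX installed.
params = {
    "dense": {"kernel": SimpleNamespace(ndim=2), "bias": SimpleNamespace(ndim=1)},
    "norm": {"scale": SimpleNamespace(ndim=1)},
}
mask = decay_matrices_only(params)
```

Here biases and normalization scales are excluded from decay, a common choice, while 2D kernels are decayed.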