# Copyright 2026 The EasyDeL/eFormer Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable
from typing import Any, Literal
import jax
import jax.numpy as jnp
import numpy as np
from jax import vmap
from jax.lax import with_sharding_constraint
from jax.sharding import PartitionSpec
from optax import tree_utils as otu
from optax._src import base, transform
from optax._src.combine import chain
from optax._src.numerics import safe_int32_increment
from optax._src.utils import canonicalize_dtype
from eformer.pytree import auto_pytree, field
# Path type constants for parameter classification
DENSE_PATH = 0 # Parameters processed with full dense Kronecker factors
LARGE_PATH = 1 # Large parameters processed with mixed dense/diagonal factors
ONE_D_PATH = 2 # 1D parameters processed with diagonal-only factors
[docs]@auto_pytree
class DenseState:
"""State container for dense Kronecker-factored preconditioner blocks.
This class stores the concatenated preconditioner matrices for all dense
parameter blocks in the model. Dense blocks are small enough to use full
matrix Kronecker factors rather than diagonal approximations.
Attributes:
Ql (jax.Array): Left Kronecker factors, shape [num_blocks, block_size, block_size].
These are orthogonal/near-orthogonal matrices for left preconditioning.
Qr (jax.Array): Right Kronecker factors, shape [num_blocks, block_size, block_size].
These are orthogonal/near-orthogonal matrices for right preconditioning.
Ll (jax.Array): Lipschitz estimates for left factors, shape [num_blocks].
Used for adaptive learning rate scaling.
Lr (jax.Array): Lipschitz estimates for right factors, shape [num_blocks].
Used for adaptive learning rate scaling.
valid_rows (jax.Array): Number of valid rows in each block, shape [num_blocks].
Accounts for padding when original dimensions are not multiples of block_size.
valid_cols (jax.Array): Number of valid columns in each block, shape [num_blocks].
Accounts for padding when original dimensions are not multiples of block_size.
valid_count (int): Number of actual (non-padding) blocks in the state.
block_size (int): Size of each square block in the Kronecker factorization.
"""
Ql: jax.Array
Qr: jax.Array
Ll: jax.Array
Lr: jax.Array
valid_rows: jax.Array
valid_cols: jax.Array
valid_count: int = field(pytree_node=False)
block_size: int = field(pytree_node=False)
[docs]@auto_pytree
class LeafState:
"""State container for a single parameter leaf in the White Kron optimizer.
This class stores the preconditioner state for individual parameter tensors,
supporting different processing paths based on parameter size and shape.
The optimizer handles three types of parameters:
- DENSE_PATH: Small 2D parameters use full dense Kronecker factors
- LARGE_PATH: Large 2D parameters use mixed dense/diagonal factors
- ONE_D_PATH: 1D parameters use diagonal-only preconditioners
Attributes:
kind (int): Processing path type (DENSE_PATH, LARGE_PATH, or ONE_D_PATH).
scanned (int): Whether this parameter is part of a scanned layer (0 or 1).
B (int): Batch dimension size (number of stacked parameter matrices).
shape (tuple[int, ...] | None): Original parameter shape (excluding batch dim).
merged (tuple[int, ...] | None): Shape after merging dimensions to 2D (m, n).
nr (int | None): Number of row blocks when using blocked processing.
nc (int | None): Number of column blocks when using blocked processing.
block_size (int | None): Block size for blocked processing.
diag_left (bool | None): Whether left factor uses diagonal approximation.
diag_right (bool | None): Whether right factor uses diagonal approximation.
stack (int | None): Total number of stacked blocks for parallel processing.
Ql (jax.Array | None): Left Kronecker factor(s).
Qr (jax.Array | None): Right Kronecker factor(s).
Ll (jax.Array | None): Left Lipschitz estimates.
Lr (jax.Array | None): Right Lipschitz estimates.
valid_rows (jax.Array | None): Valid row counts per block.
valid_cols (jax.Array | None): Valid column counts per block.
"""
kind: int = field(pytree_node=False)
scanned: int = field(pytree_node=False)
B: int = field(pytree_node=False)
shape: tuple[int, ...] | None = field(pytree_node=False, default=None)
merged: tuple[int, ...] | None = field(pytree_node=False, default=None)
nr: int | None = field(pytree_node=False, default=None)
nc: int | None = field(pytree_node=False, default=None)
block_size: int | None = field(pytree_node=False, default=None)
diag_left: bool | None = field(pytree_node=False, default=None)
diag_right: bool | None = field(pytree_node=False, default=None)
stack: int | None = field(pytree_node=False, default=None)
Ql: jax.Array | None = None
Qr: jax.Array | None = None
Ll: jax.Array | None = None
Lr: jax.Array | None = None
valid_rows: jax.Array | None = None
valid_cols: jax.Array | None = None
def _def_scale(
lr_style: str | None = "adam",
b1: float = 0.95,
normalize_grads: bool = False,
max_size_dense: int = 16384,
preconditioner_lr: float = 0.7,
preconditioner_init_scale: float = 1.0,
preconditioner_update_style: Literal["QUAD", "skew"] = "skew",
dtype: str | jnp.dtype = jnp.bfloat16,
scanned_layers: base.Params | None = None,
block_size: int = 256,
pipeline_axis_name: str | None = None,
pipeline_axis_size: int = 1,
params_partition_specs: PartitionSpec | list | tuple | dict | None = None,
noise_scale: float = 1e-9,
) -> base.GradientTransformation:
"""Create a White Kron gradient scaling transformation.
This is the core implementation for White Kron optimizers (Quad and Skew).
It applies Kronecker-factored preconditioning to gradients using an online
learning approach to maintain the preconditioner matrices.
The optimizer classifies parameters into three categories:
- Dense: Small 2D parameters (both dims <= max_size_dense) use full matrix factors
- Large: Large 2D parameters use mixed dense/diagonal factors
- 1D: Scalar and 1D parameters use diagonal-only preconditioning
Args:
lr_style (str | None): Learning rate scaling style. "adam" divides updates by 5.0
for Adam-like behavior. None uses raw preconditioned gradients. Defaults to "adam".
b1 (float): Momentum coefficient for exponential moving average of gradients.
Set to 0 to disable momentum. Defaults to 0.95.
normalize_grads (bool): Whether to normalize gradients to unit norm before
preconditioning. Can help with stability. Defaults to False.
max_size_dense (int): Maximum dimension size for dense Kronecker factors.
Parameters with both dimensions <= this value use full matrix factors.
Larger parameters use diagonal approximations. Defaults to 16384.
preconditioner_lr (float): Learning rate for preconditioner updates.
Controls how quickly the preconditioner adapts. Defaults to 0.7.
preconditioner_init_scale (float): Initial scale for preconditioner matrices.
Identity matrices are scaled by this value. Defaults to 1.0.
preconditioner_update_style (Literal["QUAD", "skew"]): Update rule for
preconditioners. "QUAD" uses quadratic updates, "skew" uses Procrustes
orthogonalization. Defaults to "skew".
dtype (str | jnp.dtype): Data type for preconditioner storage. Must be
bfloat16 or float32. Defaults to jnp.bfloat16.
scanned_layers (base.Params | None): PyTree indicating which parameters are
from scanned layers (True) vs regular layers (False). Defaults to None.
block_size (int): Block size for blocking large matrices. Controls memory
and compute tradeoffs. Defaults to 256.
pipeline_axis_name (str | None): Name of the pipeline parallel axis for
sharding preconditioner state. Defaults to None.
pipeline_axis_size (int): Size of the pipeline parallel axis. Defaults to 1.
params_partition_specs (PartitionSpec | list | tuple | dict | None): Partition
specs for model parameters. Used to maintain sharding of momentum. Defaults to None.
noise_scale (float): Scale of random noise added to gradients for numerical
stability. Defaults to 1e-9.
Returns:
base.GradientTransformation: Gradient transformation implementing White Kron
preconditioning.
Raises:
ValueError: If dtype is not bfloat16 or float32.
ValueError: If preconditioner_update_style is not "QUAD" or "skew".
"""
dtype = canonicalize_dtype(dtype)
if dtype not in (jnp.bfloat16, jnp.float32):
raise ValueError("dtype must be bfloat16 or float32")
if preconditioner_update_style not in ("QUAD", "skew"):
raise ValueError("preconditioner_update_style must be QUAD or skew")
def init_fn(params):
params_unboxed = jax.tree.map(lambda x: x, params, is_leaf=lambda x: False)
mu = None
if b1 > 0:
mu = jax.tree.map(lambda p: jnp.zeros_like(p, dtype=dtype), params_unboxed)
if params_partition_specs is not None:
mu = with_sharding_constraint(mu, params_partition_specs)
scanned_flags = scanned_layers if scanned_layers is not None else jax.tree.map(lambda _: False, params_unboxed)
dense_Ql_list: list[jax.Array] = []
dense_Qr_list: list[jax.Array] = []
dense_Ll_list: list[jax.Array] = []
dense_Lr_list: list[jax.Array] = []
dense_valid_rc: list[tuple[int, int]] = []
large_state: list[Any] = []
leaves, _tdef = jax.tree.flatten(params_unboxed)
flags, _ = jax.tree.flatten(scanned_flags)
for leaf, scanned in zip(leaves, flags, strict=False):
p = leaf if scanned else leaf[None, ...]
B = p.shape[0]
shape_wo = p.shape[1:]
merged = _merge_dims(shape_wo)
if len(merged) <= 1:
m_flat = int(np.prod(shape_wo)) if len(shape_wo) > 0 else 1
Ql = jnp.ones((B, m_flat), dtype=dtype) * preconditioner_init_scale
Ll = jnp.zeros((B,), jnp.float32)
large_state.append(
LeafState(kind=ONE_D_PATH, scanned=int(scanned), B=B, shape=shape_wo, merged=(m_flat,), Ql=Ql, Ll=Ll)
)
continue
m, n = merged
is_large_m = m > max_size_dense
is_large_n = n > max_size_dense
is_dense = (not is_large_m) and (not is_large_n)
if is_dense:
nr, nc = (m + block_size - 1) // block_size, (n + block_size - 1) // block_size
row_sizes = [block_size] * (nr - 1) + [m - block_size * (nr - 1) if nr > 0 else 0]
col_sizes = [block_size] * (nc - 1) + [n - block_size * (nc - 1) if nc > 0 else 0]
row_sizes = [rs if rs > 0 else block_size for rs in row_sizes]
col_sizes = [cs if cs > 0 else block_size for cs in col_sizes]
for _b in range(B):
for ri in range(nr):
for cj in range(nc):
vr, vc = row_sizes[ri], col_sizes[cj]
Ql = _identity_padded(block_size, vr, dtype) * preconditioner_init_scale
Qr = _identity_padded(block_size, vc, dtype) * preconditioner_init_scale
dense_Ql_list.append(Ql)
dense_Qr_list.append(Qr)
dense_Ll_list.append(jnp.zeros([], jnp.float32))
dense_Lr_list.append(jnp.zeros([], jnp.float32))
dense_valid_rc.append((vr, vc))
large_state.append(
LeafState(
kind=DENSE_PATH, scanned=int(scanned), B=B, merged=(m, n), nr=nr, nc=nc, block_size=block_size
)
)
else:
diag_left = is_large_m
diag_right = is_large_n
if diag_left and diag_right:
Ql = jnp.ones((B, m), dtype=dtype) * preconditioner_init_scale
Qr = jnp.ones((B, n), dtype=dtype) * preconditioner_init_scale
Ll = jnp.zeros((B,), jnp.float32)
Lr = jnp.zeros((B,), jnp.float32)
large_state.append(
LeafState(
kind=LARGE_PATH,
scanned=int(scanned),
B=B,
merged=(m, n),
diag_left=True,
diag_right=True,
Ql=Ql,
Qr=Qr,
Ll=Ll,
Lr=Lr,
stack=B,
)
)
elif diag_left != diag_right:
block_rows = not diag_left
dim_to_block = m if block_rows else n
other_dim = n if block_rows else m
num_blocks_per_sample = (dim_to_block + block_size - 1) // block_size
stack = B * num_blocks_per_sample
Q_diag = jnp.broadcast_to(
jnp.ones((1, other_dim), dtype=dtype) * preconditioner_init_scale, (stack, other_dim)
)
Q_blocked_blocks = []
for _ in range(B):
for i in range(num_blocks_per_sample):
v = (
block_size
if i < num_blocks_per_sample - 1
else (
dim_to_block - block_size * (num_blocks_per_sample - 1)
if num_blocks_per_sample > 0
else block_size
)
)
v = v if v > 0 else block_size
Q_blocked_blocks.append(_identity_padded(block_size, v, dtype) * preconditioner_init_scale)
Q_blocked = jnp.stack(Q_blocked_blocks, axis=0)
Ql = Q_blocked if block_rows else Q_diag
Qr = Q_diag if block_rows else Q_blocked
Ll = jnp.zeros((stack,), jnp.float32)
Lr = jnp.zeros((stack,), jnp.float32)
large_state.append(
LeafState(
kind=LARGE_PATH,
scanned=int(scanned),
B=B,
merged=(m, n),
diag_left=diag_left,
diag_right=diag_right,
Ql=Ql,
Qr=Qr,
Ll=Ll,
Lr=Lr,
stack=stack,
nr=num_blocks_per_sample if block_rows else None,
nc=num_blocks_per_sample if not block_rows else None,
block_size=block_size,
)
)
else:
raise AssertionError("unexpected large case.")
if dense_Ql_list:
Ql_cat = jnp.stack(dense_Ql_list, axis=0)
Qr_cat = jnp.stack(dense_Qr_list, axis=0)
Ll_cat = jnp.stack(dense_Ll_list, axis=0)
Lr_cat = jnp.stack(dense_Lr_list, axis=0)
valid_rows = jnp.array([vr for (vr, _) in dense_valid_rc], dtype=jnp.int32)
valid_cols = jnp.array([vc for (_, vc) in dense_valid_rc], dtype=jnp.int32)
valid_count = Ql_cat.shape[0]
if pipeline_axis_size > 1:
pad = (-valid_count) % pipeline_axis_size
else:
pad = 0
if pad > 0:
eye = jnp.eye(block_size, dtype=dtype)
Ql_pad = jnp.broadcast_to(eye, (pad, block_size, block_size))
Qr_pad = jnp.broadcast_to(eye, (pad, block_size, block_size))
Ll_pad = jnp.ones((pad,), jnp.float32)
Lr_pad = jnp.ones((pad,), jnp.float32)
Ql_cat = jnp.concatenate([Ql_cat, Ql_pad], axis=0)
Qr_cat = jnp.concatenate([Qr_cat, Qr_pad], axis=0)
Ll_cat = jnp.concatenate([Ll_cat, Ll_pad], axis=0)
Lr_cat = jnp.concatenate([Lr_cat, Lr_pad], axis=0)
valid_rows = jnp.concatenate([valid_rows, jnp.full((pad,), block_size, jnp.int32)], axis=0)
valid_cols = jnp.concatenate([valid_cols, jnp.full((pad,), block_size, jnp.int32)], axis=0)
if pipeline_axis_name is not None:
Ql_cat = with_sharding_constraint(Ql_cat, PartitionSpec(pipeline_axis_name))
Qr_cat = with_sharding_constraint(Qr_cat, PartitionSpec(pipeline_axis_name))
Ll_cat = with_sharding_constraint(Ll_cat, PartitionSpec(pipeline_axis_name))
Lr_cat = with_sharding_constraint(Lr_cat, PartitionSpec(pipeline_axis_name))
valid_rows = with_sharding_constraint(valid_rows, PartitionSpec(pipeline_axis_name))
valid_cols = with_sharding_constraint(valid_cols, PartitionSpec(pipeline_axis_name))
dense_state = DenseState(
Ql=Ql_cat,
Qr=Qr_cat,
Ll=Ll_cat,
Lr=Lr_cat,
valid_rows=valid_rows,
valid_cols=valid_cols,
valid_count=int(valid_count),
block_size=int(block_size),
)
else:
dense_state = None
for i, st in enumerate(large_state):
if st.kind != LARGE_PATH:
continue
updates = {}
current_Ql, current_Qr, current_Ll, current_Lr = st.Ql, st.Qr, st.Ll, st.Lr
current_stack = st.stack
m, n = st.merged
if pipeline_axis_size > 1:
pad = (-current_stack) % pipeline_axis_size
else:
pad = 0
if pad > 0:
if st.diag_left and st.diag_right:
current_Ql = jnp.pad(st.Ql, ((0, pad), (0, 0)), constant_values=1.0)
current_Qr = jnp.pad(st.Qr, ((0, pad), (0, 0)), constant_values=1.0)
elif st.diag_left and (not st.diag_right):
eye = jnp.eye(st.block_size, dtype=dtype)
current_Ql = jnp.pad(st.Ql, ((0, pad), (0, 0)), constant_values=1.0)
current_Qr = jnp.concatenate([st.Qr, jnp.broadcast_to(eye, (pad, eye.shape[0], eye.shape[1]))], 0)
elif (not st.diag_left) and st.diag_right:
eye = jnp.eye(st.block_size, dtype=dtype)
current_Ql = jnp.concatenate([st.Ql, jnp.broadcast_to(eye, (pad, eye.shape[0], eye.shape[1]))], 0)
current_Qr = jnp.pad(st.Qr, ((0, pad), (0, 0)), constant_values=1.0)
else:
raise AssertionError
current_Ll = jnp.pad(st.Ll, ((0, pad),), constant_values=1.0)
current_Lr = jnp.pad(st.Lr, ((0, pad),), constant_values=1.0)
current_stack += pad
updates["Ql"] = current_Ql
updates["Qr"] = current_Qr
updates["Ll"] = current_Ll
updates["Lr"] = current_Lr
updates["stack"] = current_stack
if st.diag_left and st.diag_right:
current_valid_rows = jnp.full((current_stack,), m, jnp.int32)
current_valid_cols = jnp.full((current_stack,), n, jnp.int32)
elif st.diag_left != st.diag_right:
block_rows = not st.diag_left
num_blocks_per_sample = st.nr if block_rows else st.nc
dim_to_block = m if block_rows else n
other_dim = n if block_rows else m
if num_blocks_per_sample and num_blocks_per_sample > 0:
last_block_v = dim_to_block - st.block_size * (num_blocks_per_sample - 1)
v_one_sample = (
jnp.full((num_blocks_per_sample,), st.block_size, dtype=jnp.int32).at[-1].set(last_block_v)
)
else:
v_one_sample = jnp.array([], dtype=jnp.int32)
v_all_samples = jnp.tile(v_one_sample, st.B)
if v_all_samples.shape[0] < current_stack:
p = current_stack - v_all_samples.shape[0]
pad_vals = jnp.full((p,), st.block_size, dtype=jnp.int32)
v_all_samples = jnp.concatenate([v_all_samples, pad_vals], axis=0)
other_dim_arr = jnp.full_like(v_all_samples, other_dim)
if block_rows:
current_valid_rows = v_all_samples
current_valid_cols = other_dim_arr
else:
current_valid_rows = other_dim_arr
current_valid_cols = v_all_samples
else:
raise AssertionError
if pipeline_axis_name is not None:
updates["Ql"] = with_sharding_constraint(updates["Ql"], PartitionSpec(pipeline_axis_name))
updates["Qr"] = with_sharding_constraint(updates["Qr"], PartitionSpec(pipeline_axis_name))
updates["Ll"] = with_sharding_constraint(updates["Ll"], PartitionSpec(pipeline_axis_name))
updates["Lr"] = with_sharding_constraint(updates["Lr"], PartitionSpec(pipeline_axis_name))
current_valid_rows = with_sharding_constraint(current_valid_rows, PartitionSpec(pipeline_axis_name))
current_valid_cols = with_sharding_constraint(current_valid_cols, PartitionSpec(pipeline_axis_name))
updates["valid_rows"] = current_valid_rows
updates["valid_cols"] = current_valid_cols
large_state[i] = st.replace(**updates)
opt_state = dict(count=jnp.zeros([], jnp.int32), mu=mu, large=large_state)
if dense_state is not None:
opt_state["dense"] = dense_state
return opt_state
def update_fn(updates: base.Updates, state: dict, params: base.Params | None = None):
step = safe_int32_increment(state["count"])
plr = jnp.maximum(preconditioner_lr * jax.lax.rsqrt(1.0 + step / 10000.0), 0.4)
balance = jnp.equal(step % 100, 0)
if preconditioner_update_style == "QUAD":
dense_update_fn = _dense_update
diag_update_fn = _diag_update
elif preconditioner_update_style == "skew":
dense_update_fn = _dense_update_q0p5eq1p5
diag_update_fn = _diag_update_q0p5eq1p5
else:
raise ValueError(f"Unknown preconditioner_update_style: {preconditioner_update_style}")
mu = state["mu"]
mupd = updates
if mu is not None and b1 > 0:
mu = otu.tree_update_moment(updates, mu, b1, 1)
if params_partition_specs is not None:
mu = with_sharding_constraint(mu, params_partition_specs)
mupd = otu.tree_bias_correction(mu, b1, step)
mu = otu.tree_cast(mu, dtype) if mu is not None else None
mupd = otu.tree_cast(mupd, dtype)
if normalize_grads:
mupd = jax.tree.map(lambda g: g / (jnp.linalg.norm(g) + 1e-6), mupd)
leaves_u, tdef_u = jax.tree.flatten(mupd)
perleaf_state: list[Any] = state["large"]
dense_state = state.get("dense")
pg_dense_blocks: jax.Array | None = None
dense_block_count = 0
if dense_state is not None:
blocks_list = []
for leaf, st in zip(leaves_u, perleaf_state, strict=False):
if st.kind != DENSE_PATH:
continue
B = st.B
m, n = st.merged
nr, nc = st.nr, st.nc
x2d = jnp.reshape(leaf, (B, m, n))
current_block_size = dense_state.block_size
blocks, _ = _block2d(x2d, current_block_size)
blocks_list.append(blocks)
if blocks_list:
grads_cat = jnp.concatenate(blocks_list, axis=0)
dense_block_count = grads_cat.shape[0]
state_len = dense_state.Ql.shape[0]
if dense_block_count < state_len:
pad = state_len - dense_block_count
grads_cat = jnp.concatenate(
[grads_cat, jnp.ones((pad, current_block_size, current_block_size), grads_cat.dtype)], axis=0
)
elif dense_block_count > state_len:
raise ValueError(
"dense concatenation produced more blocks than q state. check block_size/grouping consistency."
)
if pipeline_axis_name is not None:
grads_cat = with_sharding_constraint(grads_cat, PartitionSpec(pipeline_axis_name))
key_dense = jax.random.fold_in(jax.random.PRNGKey(42), step)
keys = jax.random.split(key_dense, grads_cat.shape[0])
if pipeline_axis_name is not None:
keys = with_sharding_constraint(keys, PartitionSpec(pipeline_axis_name))
diag_left = False
diag_right = False
valid_shape_dense = jnp.stack([dense_state.valid_rows, dense_state.valid_cols], axis=1)
if pipeline_axis_name is not None:
valid_shape_dense = with_sharding_constraint(valid_shape_dense, PartitionSpec(pipeline_axis_name))
Ql_c = with_sharding_constraint(dense_state.Ql, PartitionSpec(pipeline_axis_name))
Qr_c = with_sharding_constraint(dense_state.Qr, PartitionSpec(pipeline_axis_name))
Ll_in = with_sharding_constraint(dense_state.Ll, PartitionSpec(pipeline_axis_name))
Lr_in = with_sharding_constraint(dense_state.Lr, PartitionSpec(pipeline_axis_name))
else:
Ql_c = dense_state.Ql
Qr_c = dense_state.Qr
Ll_in = dense_state.Ll
Lr_in = dense_state.Lr
Ql_in, Qr_in = jax.lax.cond(balance, lambda p: _balance_qs(p[0], p[1]), lambda p: p, (Ql_c, Qr_c))
Ql_new, Qr_new, Ll_new, Lr_new, Pg_cat = vmap(
_preconditioning, in_axes=(0, 0, 0, 0, 0, 0, 0, None, None, None, None, None, None)
)(
keys,
Ql_in,
Qr_in,
Ll_in,
Lr_in,
grads_cat,
valid_shape_dense,
diag_left,
diag_right,
plr,
noise_scale,
diag_update_fn,
dense_update_fn,
)
if pipeline_axis_name is not None:
Pg_cat = with_sharding_constraint(Pg_cat, PartitionSpec(pipeline_axis_name))
state["dense"] = dense_state.replace(
Ql=(
with_sharding_constraint(otu.tree_cast(Ql_new, dtype), PartitionSpec(pipeline_axis_name))
if pipeline_axis_name is not None
else otu.tree_cast(Ql_new, dtype)
),
Qr=(
with_sharding_constraint(otu.tree_cast(Qr_new, dtype), PartitionSpec(pipeline_axis_name))
if pipeline_axis_name is not None
else otu.tree_cast(Qr_new, dtype)
),
Ll=(
with_sharding_constraint(otu.tree_cast(Ll_new, jnp.float32), PartitionSpec(pipeline_axis_name))
if pipeline_axis_name is not None
else otu.tree_cast(Ll_new, jnp.float32)
),
Lr=(
with_sharding_constraint(otu.tree_cast(Lr_new, jnp.float32), PartitionSpec(pipeline_axis_name))
if pipeline_axis_name is not None
else otu.tree_cast(Lr_new, jnp.float32)
),
)
valid_count = dense_state.valid_count
Pg_cat = Pg_cat[:valid_count]
pg_dense_blocks = Pg_cat
start_idx = 0
for leaf_idx, (leaf, st) in enumerate(zip(leaves_u, perleaf_state, strict=False)):
if st.kind != DENSE_PATH:
continue
B = st.B
m, n = st.merged
nr, nc = st.nr, st.nc
nb = B * nr * nc
blocks = pg_dense_blocks[start_idx : start_idx + nb]
start_idx += nb
rec = _unblock2d(blocks, (nr, nc, m, n), dense_state.block_size)
leaves_u[leaf_idx] = jnp.reshape(rec, leaf.shape)
for leaf_idx, (leaf, st) in enumerate(zip(leaves_u, perleaf_state, strict=False)):
if st.kind != LARGE_PATH:
continue
B = st.B
m, n = st.merged
diag_left = st.diag_left
diag_right = st.diag_right
p2d = jnp.reshape(leaf, (B, m, n))
if diag_left and diag_right:
Gs = p2d
stack = st.stack
if Gs.shape[0] < stack:
pad = stack - Gs.shape[0]
Gs = jnp.concatenate([Gs, jnp.ones((pad, m, n), Gs.dtype)], axis=0)
if pipeline_axis_name is not None:
Gs = with_sharding_constraint(Gs, PartitionSpec(pipeline_axis_name))
key = jax.random.fold_in(jax.random.PRNGKey(43), step)
keys = jax.random.split(key, stack)
if pipeline_axis_name is not None:
keys = with_sharding_constraint(keys, PartitionSpec(pipeline_axis_name))
if pipeline_axis_name is not None:
Ql_c = with_sharding_constraint(st.Ql, PartitionSpec(pipeline_axis_name))
Qr_c = with_sharding_constraint(st.Qr, PartitionSpec(pipeline_axis_name))
Ll_in = with_sharding_constraint(st.Ll, PartitionSpec(pipeline_axis_name))
Lr_in = with_sharding_constraint(st.Lr, PartitionSpec(pipeline_axis_name))
else:
Ql_c, Qr_c, Ll_in, Lr_in = st.Ql, st.Qr, st.Ll, st.Lr
Ql_in, Qr_in = jax.lax.cond(balance, lambda p: _balance_qs(p[0], p[1]), lambda p: p, (Ql_c, Qr_c))
valid_shape_large = jnp.stack([st.valid_rows, st.valid_cols], axis=1)
if pipeline_axis_name is not None:
valid_shape_large = with_sharding_constraint(valid_shape_large, PartitionSpec(pipeline_axis_name))
Ql_new, Qr_new, Ll_new, Lr_new, Pg = vmap(
_preconditioning, in_axes=(0, 0, 0, 0, 0, 0, 0, None, None, None, None, None, None)
)(
keys,
Ql_in,
Qr_in,
Ll_in,
Lr_in,
Gs,
valid_shape_large,
True,
True,
plr,
noise_scale,
diag_update_fn,
dense_update_fn,
)
if pipeline_axis_name is not None:
Pg = with_sharding_constraint(Pg, PartitionSpec(pipeline_axis_name))
state["large"][leaf_idx] = st.replace(
Ql=(
with_sharding_constraint(otu.tree_cast(Ql_new, dtype), PartitionSpec(pipeline_axis_name))
if pipeline_axis_name is not None
else otu.tree_cast(Ql_new, dtype)
),
Qr=(
with_sharding_constraint(otu.tree_cast(Qr_new, dtype), PartitionSpec(pipeline_axis_name))
if pipeline_axis_name is not None
else otu.tree_cast(Qr_new, dtype)
),
Ll=(
with_sharding_constraint(otu.tree_cast(Ll_new, jnp.float32), PartitionSpec(pipeline_axis_name))
if pipeline_axis_name is not None
else otu.tree_cast(Ll_new, jnp.float32)
),
Lr=(
with_sharding_constraint(otu.tree_cast(Lr_new, jnp.float32), PartitionSpec(pipeline_axis_name))
if pipeline_axis_name is not None
else otu.tree_cast(Lr_new, jnp.float32)
),
)
Pg = Pg[:B]
leaves_u[leaf_idx] = jnp.reshape(Pg, leaf.shape)
elif diag_left != diag_right:
block_rows = not diag_left
unblock_fn_batched = _unblock_rows if block_rows else _unblock_cols
num_blocks_per_sample = st.nr if block_rows else st.nc
other_dim = n if block_rows else m
if block_rows:
Gs, meta = _block_rows(p2d, block_size)
else:
Gs, meta = _block_cols(p2d, block_size)
stack = st.stack
if Gs.shape[0] < stack:
pad = stack - Gs.shape[0]
pad_shape = (pad, block_size, other_dim) if block_rows else (pad, other_dim, block_size)
Gs = jnp.concatenate([Gs, jnp.ones(pad_shape, Gs.dtype)], axis=0)
if pipeline_axis_name is not None:
Gs = with_sharding_constraint(Gs, PartitionSpec(pipeline_axis_name))
key_val = 45 if block_rows else 44
key = jax.random.fold_in(jax.random.PRNGKey(key_val), step)
keys = jax.random.split(key, stack)
if pipeline_axis_name is not None:
keys = with_sharding_constraint(keys, PartitionSpec(pipeline_axis_name))
valid_shape_large = jnp.stack([st.valid_rows, st.valid_cols], axis=1)
if pipeline_axis_name is not None:
valid_shape_large = with_sharding_constraint(valid_shape_large, PartitionSpec(pipeline_axis_name))
Ql_c = with_sharding_constraint(st.Ql, PartitionSpec(pipeline_axis_name))
Qr_c = with_sharding_constraint(st.Qr, PartitionSpec(pipeline_axis_name))
Ll_in = with_sharding_constraint(st.Ll, PartitionSpec(pipeline_axis_name))
Lr_in = with_sharding_constraint(st.Lr, PartitionSpec(pipeline_axis_name))
else:
Ql_c, Qr_c, Ll_in, Lr_in = st.Ql, st.Qr, st.Ll, st.Lr
Ql_in, Qr_in = jax.lax.cond(balance, lambda p: _balance_qs(p[0], p[1]), lambda p: p, (Ql_c, Qr_c))
Ql_new, Qr_new, Ll_new, Lr_new, Pg = vmap(
_preconditioning, in_axes=(0, 0, 0, 0, 0, 0, 0, None, None, None, None, None, None)
)(
keys,
Ql_in,
Qr_in,
Ll_in,
Lr_in,
Gs,
valid_shape_large,
diag_left,
diag_right,
plr,
noise_scale,
diag_update_fn,
dense_update_fn,
)
if pipeline_axis_name is not None:
Pg = with_sharding_constraint(Pg, PartitionSpec(pipeline_axis_name))
state["large"][leaf_idx] = st.replace(
Ql=(
with_sharding_constraint(otu.tree_cast(Ql_new, dtype), PartitionSpec(pipeline_axis_name))
if pipeline_axis_name is not None
else otu.tree_cast(Ql_new, dtype)
),
Qr=(
with_sharding_constraint(otu.tree_cast(Qr_new, dtype), PartitionSpec(pipeline_axis_name))
if pipeline_axis_name is not None
else otu.tree_cast(Qr_new, dtype)
),
Ll=(
with_sharding_constraint(otu.tree_cast(Ll_new, jnp.float32), PartitionSpec(pipeline_axis_name))
if pipeline_axis_name is not None
else otu.tree_cast(Ll_new, jnp.float32)
),
Lr=(
with_sharding_constraint(otu.tree_cast(Lr_new, jnp.float32), PartitionSpec(pipeline_axis_name))
if pipeline_axis_name is not None
else otu.tree_cast(Lr_new, jnp.float32)
),
)
Pg = Pg[: (B * num_blocks_per_sample)]
rec = unblock_fn_batched(Pg, meta, block_size, B)
leaves_u[leaf_idx] = jnp.reshape(rec, leaf.shape)
leaves_mupd, _ = jax.tree.flatten(mupd)
for leaf_idx, (leaf, st) in enumerate(zip(leaves_u, perleaf_state, strict=False)): # noqa: B007
if st.kind != ONE_D_PATH:
continue
B = st.B
g = leaves_mupd[leaf_idx].astype(dtype)
g2d = jnp.reshape(g, (B, -1))
key = jax.random.fold_in(jax.random.PRNGKey(46), step)
keys = jax.random.split(key, B)
Ql_new, Ll_new, Pg_flat = vmap(_preconditioning_one_d, in_axes=(0, 0, 0, 0, None, None, None))(
keys, st.Ql, st.Ll, g2d, plr, noise_scale, diag_update_fn
)
state["large"][leaf_idx] = st.replace(Ql=otu.tree_cast(Ql_new, dtype), Ll=otu.tree_cast(Ll_new, jnp.float32))
leaves_u[leaf_idx] = jnp.reshape(Pg_flat, g.shape)
precond_all = tdef_u.unflatten(leaves_u)
if params_partition_specs is not None:
precond_all = with_sharding_constraint(precond_all, params_partition_specs)
precond_all = jax.tree.map(
lambda g: g * (1.1 / jnp.maximum(jnp.sqrt(jnp.mean(jnp.square(g))), 1.1)), precond_all
)
if lr_style == "adam":
precond_all = jax.tree.map(lambda g: g / jnp.array(5.0, g.dtype), precond_all)
state["count"] = step
state["mu"] = mu
return precond_all, state
return base.GradientTransformation(init_fn, update_fn)
[docs]def scale_by_skew(
lr_style: str | None = "adam",
b1: float = 0.95,
normalize_grads: bool = False,
max_size_dense: int = 16384,
preconditioner_lr: float = 0.7,
preconditioner_init_scale: float = 1.0,
dtype: str | jnp.dtype = jnp.bfloat16,
scanned_layers: base.Params | None = None,
block_size: int = 256,
pipeline_axis_name: str | None = None,
pipeline_axis_size: int = 1,
params_partition_specs: PartitionSpec | list | tuple | dict | None = None,
noise_scale: float = 1e-9,
) -> base.GradientTransformation:
"""Create a gradient scaling transformation using skew-style preconditioner updates.
The skew variant of White Kron uses Procrustes orthogonalization to maintain
near-orthogonal preconditioner matrices, which can provide more stable training.
Args:
lr_style (str | None): Learning rate scaling style. "adam" for Adam-like scaling.
Defaults to "adam".
b1 (float): Momentum coefficient. Defaults to 0.95.
normalize_grads (bool): Whether to normalize gradients. Defaults to False.
max_size_dense (int): Max dimension for dense factors. Defaults to 16384.
preconditioner_lr (float): Preconditioner learning rate. Defaults to 0.7.
preconditioner_init_scale (float): Initial preconditioner scale. Defaults to 1.0.
dtype (str | jnp.dtype): Storage dtype. Defaults to jnp.bfloat16.
scanned_layers (base.Params | None): Scanned layer indicators. Defaults to None.
block_size (int): Block size for matrix partitioning. Defaults to 256.
pipeline_axis_name (str | None): Pipeline axis name. Defaults to None.
pipeline_axis_size (int): Pipeline axis size. Defaults to 1.
params_partition_specs: Parameter partition specs. Defaults to None.
noise_scale (float): Noise scale for stability. Defaults to 1e-9.
Returns:
base.GradientTransformation: Skew-style preconditioned gradient transformation.
Example:
>>> import optax
>>> from eformer.optimizers._tx import scale_by_skew
>>> optimizer = optax.chain(
... scale_by_skew(b1=0.95),
... optax.add_decayed_weights(0.1),
... optax.scale_by_learning_rate(1e-4),
... )
"""
return _def_scale(
lr_style=lr_style,
b1=b1,
normalize_grads=normalize_grads,
max_size_dense=max_size_dense,
preconditioner_lr=preconditioner_lr,
preconditioner_init_scale=preconditioner_init_scale,
preconditioner_update_style="skew",
dtype=dtype,
scanned_layers=scanned_layers,
block_size=block_size,
pipeline_axis_name=pipeline_axis_name,
pipeline_axis_size=pipeline_axis_size,
params_partition_specs=params_partition_specs,
noise_scale=noise_scale,
)
[docs]def scale_by_quad(
lr_style: str | None = "adam",
b1: float = 0.95,
normalize_grads: bool = False,
max_size_dense: int = 16384,
preconditioner_lr: float = 0.7,
preconditioner_init_scale: float = 1.0,
dtype: str | jnp.dtype = jnp.bfloat16,
scanned_layers: base.Params | None = None,
block_size: int = 256,
pipeline_axis_name: str | None = None,
pipeline_axis_size: int = 1,
params_partition_specs: PartitionSpec | list | tuple | dict | None = None,
noise_scale: float = 1e-9,
) -> base.GradientTransformation:
"""Create a gradient scaling transformation using QUAD-style preconditioner updates.
The QUAD variant of White Kron uses quadratic preconditioner updates that
directly minimize a quadratic loss in the preconditioner space.
Args:
lr_style (str | None): Learning rate scaling style. "adam" for Adam-like scaling.
Defaults to "adam".
b1 (float): Momentum coefficient. Defaults to 0.95.
normalize_grads (bool): Whether to normalize gradients. Defaults to False.
max_size_dense (int): Max dimension for dense factors. Defaults to 16384.
preconditioner_lr (float): Preconditioner learning rate. Defaults to 0.7.
preconditioner_init_scale (float): Initial preconditioner scale. Defaults to 1.0.
dtype (str | jnp.dtype): Storage dtype. Defaults to jnp.bfloat16.
scanned_layers (base.Params | None): Scanned layer indicators. Defaults to None.
block_size (int): Block size for matrix partitioning. Defaults to 256.
pipeline_axis_name (str | None): Pipeline axis name. Defaults to None.
pipeline_axis_size (int): Pipeline axis size. Defaults to 1.
params_partition_specs: Parameter partition specs. Defaults to None.
noise_scale (float): Noise scale for stability. Defaults to 1e-9.
Returns:
base.GradientTransformation: QUAD-style preconditioned gradient transformation.
Example:
>>> import optax
>>> from eformer.optimizers._tx import scale_by_quad
>>> optimizer = optax.chain(
... scale_by_quad(b1=0.95),
... optax.add_decayed_weights(0.1),
... optax.scale_by_learning_rate(1e-4),
... )
"""
return _def_scale(
lr_style=lr_style,
b1=b1,
normalize_grads=normalize_grads,
max_size_dense=max_size_dense,
preconditioner_lr=preconditioner_lr,
preconditioner_init_scale=preconditioner_init_scale,
preconditioner_update_style="QUAD",
dtype=dtype,
scanned_layers=scanned_layers,
block_size=block_size,
pipeline_axis_name=pipeline_axis_name,
pipeline_axis_size=pipeline_axis_size,
params_partition_specs=params_partition_specs,
noise_scale=noise_scale,
)
[docs]def skew(
learning_rate: float | Callable[[int], float] = 0.001,
lr_style: str | None = "adam",
b1: float = 0.95,
weight_decay: float = 0.1,
weight_decay_mask: Any | Callable[[base.Params], Any] | None = None,
normalize_grads: bool = False,
max_size_dense: int = 16384,
preconditioner_lr: float = 0.7,
preconditioner_init_scale: float = 1.0,
dtype: str | jnp.dtype = jnp.bfloat16,
scanned_layers: base.Params | None = None,
block_size: int = 256,
pipeline_axis_name: str | None = None,
pipeline_axis_size: int = 1,
params_partition_specs: PartitionSpec | list | tuple | dict | None = None,
noise_scale: float = 1e-9,
) -> base.GradientTransformation:
"""Complete Skew optimizer with weight decay and learning rate scheduling.
Skew is a Kronecker-factored preconditioned optimizer using Procrustes
orthogonalization to maintain near-orthogonal preconditioner matrices.
This provides efficient second-order optimization with stable training.
Args:
learning_rate (float | Callable[[int], float]): Learning rate or schedule.
Defaults to 0.001.
lr_style (str | None): Learning rate scaling style. Defaults to "adam".
b1 (float): Momentum coefficient. Defaults to 0.95.
weight_decay (float): Weight decay coefficient. Defaults to 0.1.
weight_decay_mask: Mask for selective weight decay. Defaults to None.
normalize_grads (bool): Whether to normalize gradients. Defaults to False.
max_size_dense (int): Max dimension for dense factors. Defaults to 16384.
preconditioner_lr (float): Preconditioner learning rate. Defaults to 0.7.
preconditioner_init_scale (float): Initial preconditioner scale. Defaults to 1.0.
dtype (str | jnp.dtype): Storage dtype. Defaults to jnp.bfloat16.
scanned_layers (base.Params | None): Scanned layer indicators. Defaults to None.
block_size (int): Block size for matrix partitioning. Defaults to 256.
pipeline_axis_name (str | None): Pipeline axis name. Defaults to None.
pipeline_axis_size (int): Pipeline axis size. Defaults to 1.
params_partition_specs: Parameter partition specs. Defaults to None.
noise_scale (float): Noise scale for stability. Defaults to 1e-9.
Returns:
base.GradientTransformation: Complete Skew optimizer transformation.
Example:
>>> from eformer.optimizers._tx import skew
>>> # With constant learning rate
>>> optimizer = skew(learning_rate=1e-4, b1=0.95, weight_decay=0.1)
>>> # With learning rate schedule
>>> import optax
>>> schedule = optax.cosine_decay_schedule(1e-4, 10000)
>>> optimizer = skew(learning_rate=schedule)
"""
tx = [
scale_by_skew(
lr_style=lr_style,
b1=b1,
normalize_grads=normalize_grads,
max_size_dense=max_size_dense,
preconditioner_lr=preconditioner_lr,
preconditioner_init_scale=preconditioner_init_scale,
dtype=dtype,
scanned_layers=scanned_layers,
block_size=block_size,
pipeline_axis_name=pipeline_axis_name,
pipeline_axis_size=pipeline_axis_size,
params_partition_specs=params_partition_specs,
noise_scale=noise_scale,
)
]
if weight_decay > 0.0:
tx.append(transform.add_decayed_weights(weight_decay, weight_decay_mask))
tx.append(transform.scale_by_learning_rate(learning_rate))
return chain(*tx)
[docs]def quad(
learning_rate: float | Callable[[int], float] = 0.001,
lr_style: str | None = "adam",
b1: float = 0.95,
weight_decay: float = 0.1,
weight_decay_mask: Any | Callable[[base.Params], Any] | None = None,
normalize_grads: bool = False,
max_size_dense: int = 16384,
preconditioner_lr: float = 0.7,
preconditioner_init_scale: float = 1.0,
dtype: str | jnp.dtype = jnp.bfloat16,
scanned_layers: base.Params | None = None,
block_size: int = 256,
pipeline_axis_name: str | None = None,
pipeline_axis_size: int = 1,
params_partition_specs: PartitionSpec | list | tuple | dict | None = None,
noise_scale: float = 1e-9,
) -> base.GradientTransformation:
"""Complete Quad optimizer with weight decay and learning rate scheduling.
Quad is a Kronecker-factored preconditioned optimizer using quadratic
preconditioner updates that minimize a quadratic loss function.
This provides efficient second-order optimization.
Args:
learning_rate (float | Callable[[int], float]): Learning rate or schedule.
Defaults to 0.001.
lr_style (str | None): Learning rate scaling style. Defaults to "adam".
b1 (float): Momentum coefficient. Defaults to 0.95.
weight_decay (float): Weight decay coefficient. Defaults to 0.1.
weight_decay_mask: Mask for selective weight decay. Defaults to None.
normalize_grads (bool): Whether to normalize gradients. Defaults to False.
max_size_dense (int): Max dimension for dense factors. Defaults to 16384.
preconditioner_lr (float): Preconditioner learning rate. Defaults to 0.7.
preconditioner_init_scale (float): Initial preconditioner scale. Defaults to 1.0.
dtype (str | jnp.dtype): Storage dtype. Defaults to jnp.bfloat16.
scanned_layers (base.Params | None): Scanned layer indicators. Defaults to None.
block_size (int): Block size for matrix partitioning. Defaults to 256.
pipeline_axis_name (str | None): Pipeline axis name. Defaults to None.
pipeline_axis_size (int): Pipeline axis size. Defaults to 1.
params_partition_specs: Parameter partition specs. Defaults to None.
noise_scale (float): Noise scale for stability. Defaults to 1e-9.
Returns:
base.GradientTransformation: Complete Quad optimizer transformation.
Example:
>>> from eformer.optimizers._tx import quad
>>> # With constant learning rate
>>> optimizer = quad(learning_rate=1e-4, b1=0.95, weight_decay=0.1)
>>> # With learning rate schedule
>>> import optax
>>> schedule = optax.cosine_decay_schedule(1e-4, 10000)
>>> optimizer = quad(learning_rate=schedule)
"""
tx = [
scale_by_quad(
lr_style=lr_style,
b1=b1,
normalize_grads=normalize_grads,
max_size_dense=max_size_dense,
preconditioner_lr=preconditioner_lr,
preconditioner_init_scale=preconditioner_init_scale,
dtype=dtype,
scanned_layers=scanned_layers,
block_size=block_size,
pipeline_axis_name=pipeline_axis_name,
pipeline_axis_size=pipeline_axis_size,
params_partition_specs=params_partition_specs,
noise_scale=noise_scale,
)
]
if weight_decay > 0.0:
tx.append(transform.add_decayed_weights(weight_decay, weight_decay_mask))
tx.append(transform.scale_by_learning_rate(learning_rate))
return chain(*tx)
[docs]def get_opt_state_partition_specs(params, **quad_kwargs):
"""Generate partition specs for White Kron optimizer state.
This utility function creates JAX partition specifications for the optimizer
state, enabling proper sharding of the optimizer state across devices in
distributed training scenarios.
Args:
params: Model parameters, used to infer state structure and shapes.
**quad_kwargs: Keyword arguments for the Quad/Skew optimizer, including:
- lr_style: Learning rate scaling style
- b1: Momentum coefficient
- normalize_grads: Whether to normalize gradients
- max_size_dense: Max dimension for dense factors
- preconditioner_lr: Preconditioner learning rate
- preconditioner_init_scale: Initial preconditioner scale
- dtype: Storage dtype
- scanned_layers: Scanned layer indicators
- block_size: Block size for matrix partitioning
- pipeline_axis_name: Pipeline axis name for sharding
- pipeline_axis_size: Pipeline axis size
- params_partition_specs: Parameter partition specs
- noise_scale: Noise scale for stability
- weight_decay: Weight decay coefficient (used to determine state structure)
Returns:
tuple: Partition specs for the optimizer state. Structure depends on
whether weight decay is enabled:
- With weight decay > 0: (precond_specs, None, None)
- Without weight decay: (precond_specs, None)
Example:
>>> import jax.numpy as jnp
>>> from eformer.optimizers._tx.white_kron import get_opt_state_partition_specs
>>> params = {"layer1": jnp.zeros((128, 64)), "layer2": jnp.zeros((64, 32))}
>>> specs = get_opt_state_partition_specs(
... params,
... pipeline_axis_name="dp",
... pipeline_axis_size=4,
... weight_decay=0.1,
... )
"""
_allowed = {
"lr_style",
"b1",
"normalize_grads",
"max_size_dense",
"preconditioner_lr",
"preconditioner_init_scale",
"dtype",
"scanned_layers",
"block_size",
"pipeline_axis_name",
"pipeline_axis_size",
"params_partition_specs",
"noise_scale",
}
precond_kwargs = {k: v for k, v in quad_kwargs.items() if k in _allowed}
weight_decay = float(quad_kwargs.get("weight_decay", 0.0) or 0.0)
_no_constraint_kwargs = dict(precond_kwargs)
_no_constraint_kwargs["params_partition_specs"] = None
_no_constraint_kwargs["pipeline_axis_name"] = None
tx = _def_scale(**_no_constraint_kwargs)
state_shape = jax.eval_shape(tx.init, params)
pipeline_axis_name = precond_kwargs.get("pipeline_axis_name", None)
b1 = precond_kwargs.get("b1", 0.95)
params_partition_specs = precond_kwargs.get("params_partition_specs", None)
replicated = PartitionSpec()
def _leading_axis_spec(ndim: int) -> PartitionSpec:
if pipeline_axis_name is None or ndim == 0:
return replicated
return PartitionSpec(*([pipeline_axis_name] + [None] * (ndim - 1)))
if b1 and b1 > 0:
if params_partition_specs is not None:
mu_specs = params_partition_specs
else:
def _param_spec(p):
try:
return p.sharding.spec
except Exception:
return replicated
mu_specs = jax.tree.map(_param_spec, params)
else:
mu_specs = None
def _to_specs(x, key_path: tuple[Any, ...] = ()):
if isinstance(x, jax.ShapeDtypeStruct):
return _leading_axis_spec(x.ndim)
if isinstance(x, LeafState):
if x.kind == ONE_D_PATH:
return x.replace(
Ql=replicated if isinstance(x.Ql, jax.ShapeDtypeStruct) else None,
Qr=None,
Ll=replicated if isinstance(x.Ll, jax.ShapeDtypeStruct) else None,
Lr=None,
valid_rows=replicated if isinstance(x.valid_rows, jax.ShapeDtypeStruct) else None,
valid_cols=replicated if isinstance(x.valid_cols, jax.ShapeDtypeStruct) else None,
)
return x.replace(
Ql=_to_specs(x.Ql),
Qr=_to_specs(x.Qr),
Ll=_to_specs(x.Ll),
Lr=_to_specs(x.Lr),
valid_rows=_to_specs(x.valid_rows),
valid_cols=_to_specs(x.valid_cols),
)
if isinstance(x, DenseState):
return x.replace(
Ql=_to_specs(x.Ql),
Qr=_to_specs(x.Qr),
Ll=_to_specs(x.Ll),
Lr=_to_specs(x.Lr),
valid_rows=_to_specs(x.valid_rows),
valid_cols=_to_specs(x.valid_cols),
)
if isinstance(x, dict):
out = {}
for k, v in x.items():
if k == "mu" and mu_specs is not None:
out[k] = mu_specs
else:
out[k] = _to_specs(v, (*key_path, k))
return out
if isinstance(x, list | tuple):
mapped = [_to_specs(v, (*key_path, i)) for i, v in enumerate(x)]
return type(x)(mapped)
return None
precond_specs = _to_specs(state_shape)
if weight_decay > 0.0:
return (precond_specs, None, None)
else:
return (precond_specs, None)
# Lipschitz estimate decay rate for preconditioner updates
betaL = 0.95
def _diag_update(term1, term2, L, Q, lr_precond):
"""Update diagonal preconditioner using QUAD-style update.
Args:
term1: Diagonal term from gradient outer product.
term2: Normalization term (total elements / dimension).
L: Current Lipschitz estimate.
Q: Current diagonal preconditioner.
lr_precond: Preconditioner learning rate.
Returns:
tuple: (updated_Q, updated_L) - New preconditioner and Lipschitz estimate.
"""
ell = jnp.max(term1) + term2
L = jnp.maximum(betaL * L + (1 - betaL) * ell, ell)
z = (lr_precond / (2.0 * L)).astype(Q.dtype)
gain = 1.0 - z * (term1 - term2)
Qn = Q * (gain * gain)
return Qn, L
def _diag_update_q0p5eq1p5(term1, term2, L, Q, lr_precond):
"""Update diagonal preconditioner using skew-style update.
This variant uses a linear gain update rather than quadratic, providing
different stability characteristics.
Args:
term1: Diagonal term from gradient outer product.
term2: Normalization term (total elements / dimension).
L: Current Lipschitz estimate.
Q: Current diagonal preconditioner.
lr_precond: Preconditioner learning rate.
Returns:
tuple: (updated_Q, updated_L) - New preconditioner and Lipschitz estimate.
"""
ell = jnp.max(term1) + term2
L = jnp.maximum(betaL * L + (1 - betaL) * ell, ell)
z = (lr_precond / L).astype(Q.dtype)
gain = 1.0 - z * (term1 - term2)
Qn = Q * gain
return Qn, L
def _norm_lower_bound(key, A, k=4, iters=5, skh=False):
"""Estimate a lower bound on the spectral norm of a matrix.
Uses randomized power iteration to efficiently estimate the spectral norm
without computing a full SVD.
Args:
key: JAX random key for initialization.
A: Input matrix to estimate norm of.
k: Number of random vectors to use. Defaults to 4.
iters: Number of power iterations. Defaults to 5.
skh: If True, use max absolute value scaling; otherwise use diagonal scaling.
Defaults to False.
Returns:
float: Lower bound estimate of the spectral norm of A.
"""
if skh:
scale = jnp.max(jnp.abs(A))
else:
scale = jnp.max(jnp.diag(A))
A /= scale
mean_energies = jnp.mean(A * A, axis=1, keepdims=False)
j = jnp.argmax(mean_energies)
power = jax.lax.dynamic_index_in_dim(mean_energies, j, 0, keepdims=False)
max_vec = jax.lax.dynamic_index_in_dim(A, j, 0, keepdims=False)
x = (max_vec * jax.lax.rsqrt(power) + jax.random.normal(key, (k, A.shape[1]), A.dtype)) @ A
for _ in range(iters):
x = x / jnp.max(jnp.abs(x))
x = x @ A
x = (x / jnp.linalg.vector_norm(x, axis=1, keepdims=True)) @ A
return jnp.max(jnp.linalg.vector_norm(x, axis=1, keepdims=False)) * scale
def _dense_update(key, term1, term2, L, Q, lr_precond):
"""Update dense preconditioner using QUAD-style update.
Performs a symmetric update step that preserves positive definiteness
of the preconditioner matrix.
Args:
key: JAX random key for norm estimation.
term1: Matrix term from gradient outer product.
term2: Normalization term.
L: Current Lipschitz estimate.
Q: Current dense preconditioner matrix.
lr_precond: Preconditioner learning rate.
Returns:
tuple: (updated_Q, updated_L) - New preconditioner and Lipschitz estimate.
"""
ell = _norm_lower_bound(key, term1) + term2
L = jnp.maximum(betaL * L + (1 - betaL) * ell, ell)
z = (lr_precond / (2.0 * L)).astype(Q.dtype)
P = Q - z * (term1 @ Q - term2 * Q)
P = P - z * (P @ term1 - P * term2)
Qn = (P + P.T) / 2.0
return Qn, L
def _dense_update_q0p5eq1p5(key, term1, term2, L, Q, lr_precond):
"""Update dense preconditioner using skew-style update with Procrustes step.
This variant uses a Procrustes orthogonalization step to maintain near-orthogonal
preconditioner matrices, which can provide more stable training.
Args:
key: JAX random key for norm estimation and Procrustes step.
term1: Matrix term from gradient outer product.
term2: Normalization term.
L: Current Lipschitz estimate.
Q: Current dense preconditioner matrix.
lr_precond: Preconditioner learning rate.
Returns:
tuple: (updated_Q, updated_L) - New preconditioner and Lipschitz estimate.
"""
key1, key2 = jax.random.split(key)
ell = _norm_lower_bound(key1, term1) + term2
L = jnp.maximum(betaL * L + (1 - betaL) * ell, ell)
z = (lr_precond / L).astype(Q.dtype)
Q_updated = Q - z * (term1 @ Q - term2 * Q)
Qn = _procrustes_step(key2, Q_updated)
return Qn, L
def _procrustes_step(key, Q, max_step_size=1 / 8):
"""Perform a Procrustes orthogonalization step on a matrix.
Projects Q towards the nearest orthogonal matrix using a gradient step
on the skew-symmetric component of Q.
Args:
key: JAX random key for norm estimation.
Q: Input matrix to orthogonalize.
max_step_size: Maximum step size for the rotation update. Defaults to 1/8.
Returns:
jax.Array: Updated matrix closer to orthogonal.
"""
R = Q.T - Q
max_abs = jnp.max(jnp.abs(R))
def inner(R):
R = R / max_abs
RQ = R @ Q
tr_RQ = jnp.trace(RQ)
def do_rotation():
a = max_step_size / _norm_lower_bound(key, R, skh=True)
RRQ = R @ RQ
tr_RRQ = jnp.trace(RRQ)
a = jnp.where(tr_RRQ < 0, jnp.minimum(a, -tr_RQ / tr_RRQ), a)
return Q + a * (RQ + 0.5 * a * RRQ)
return jax.lax.cond(tr_RQ > 0, do_rotation, lambda: Q)
return jax.lax.cond(max_abs > jnp.finfo(Q.dtype).tiny, lambda: inner(R), lambda: Q)
def _preconditioning(
key: jax.Array,
Ql: jax.Array,
Qr: jax.Array,
Ll: jax.Array,
Lr: jax.Array,
G: jax.Array,
valid_shape: jax.Array,
diag_left: bool,
diag_right: bool,
lr_precond: jax.Array,
noise_scale: float,
diag_update_fn: Callable,
dense_update_fn: Callable,
):
"""Apply Kronecker-factored preconditioning to a gradient matrix.
This is the core preconditioning function that handles all combinations of
dense/diagonal left and right factors.
Args:
key: JAX random key for noise injection and norm estimation.
Ql: Left Kronecker factor (matrix or diagonal vector).
Qr: Right Kronecker factor (matrix or diagonal vector).
Ll: Left Lipschitz estimate.
Lr: Right Lipschitz estimate.
G: Gradient matrix to precondition.
valid_shape: Array of [valid_rows, valid_cols] for handling padding.
diag_left: Whether left factor is diagonal.
diag_right: Whether right factor is diagonal.
lr_precond: Preconditioner learning rate.
noise_scale: Scale of noise added for numerical stability.
diag_update_fn: Function for diagonal factor updates.
dense_update_fn: Function for dense factor updates.
Returns:
tuple: (Ql_new, Qr_new, Ll_new, Lr_new, Pg_out) containing updated factors,
Lipschitz estimates, and the preconditioned gradient.
"""
key1, key2 = jax.random.split(key)
m, n = valid_shape[0], valid_shape[1]
noise = jax.random.normal(key2, G.shape, G.dtype) * noise_scale
rows = jnp.arange(G.shape[0], dtype=jnp.int32) < m
cols = jnp.arange(G.shape[1], dtype=jnp.int32) < n
mask = rows[:, None] & cols[None, :]
Gn = G + noise * mask
m, n = jnp.asarray(m, dtype=G.dtype), jnp.asarray(n, dtype=G.dtype)
total_numel = m * n
if not diag_left and not diag_right:
Pg = jax.numpy.linalg.multi_dot([Ql.T, Ql, Gn, Qr.T, Qr])
key3, key4 = jax.random.split(key1)
term1L = Pg @ Pg.T
term2L = total_numel / m
Ql_new, Ll_new = dense_update_fn(key3, term1L, term2L, Ll, Ql, lr_precond)
term1R = Pg.T @ Pg
term2R = total_numel / n
Qr_new, Lr_new = dense_update_fn(key4, term1R, term2R, Lr, Qr, lr_precond)
Pg_out = jax.numpy.linalg.multi_dot([Ql_new.T, Ql_new, G, Qr_new.T, Qr_new])
elif diag_left and not diag_right:
Pg = (Ql * Ql)[:, None] * jax.numpy.linalg.multi_dot([Gn, Qr.T, Qr])
term1L = jnp.sum(Pg * Pg, axis=1)
term2L = total_numel / m
Ql_new, Ll_new = diag_update_fn(term1L, term2L, Ll, Ql, lr_precond)
term1R = Pg.T @ Pg
term2R = total_numel / n
Qr_new, Lr_new = dense_update_fn(key1, term1R, term2R, Lr, Qr, lr_precond)
Pg_out = (Ql_new * Ql_new)[:, None] * jax.numpy.linalg.multi_dot([G, Qr_new.T, Qr_new])
elif not diag_left and diag_right:
Pg = jax.numpy.linalg.multi_dot([Ql.T, Ql, Gn]) * (Qr * Qr)[None, :]
term1L = Pg @ Pg.T
term2L = total_numel / m
Ql_new, Ll_new = dense_update_fn(key1, term1L, term2L, Ll, Ql, lr_precond)
term1R = jnp.sum(Pg * Pg, axis=0)
term2R = total_numel / n
Qr_new, Lr_new = diag_update_fn(term1R, term2R, Lr, Qr, lr_precond)
Pg_out = jax.numpy.linalg.multi_dot([Ql_new.T, Ql_new, G]) * (Qr_new * Qr_new)[None, :]
else:
Pg = (Ql * Ql)[:, None] * Gn * (Qr * Qr)[None, :]
term1L = jnp.sum(Pg * Pg, axis=1)
term2L = total_numel / m
Ql_new, Ll_new = diag_update_fn(term1L, term2L, Ll, Ql, lr_precond)
term1R = jnp.sum(Pg * Pg, axis=0)
term2R = total_numel / n
Qr_new, Lr_new = diag_update_fn(term1R, term2R, Lr, Qr, lr_precond)
Pg_out = (Ql_new * Ql_new)[:, None] * G * (Qr_new * Qr_new)[None, :]
return Ql_new, Qr_new, Ll_new, Lr_new, Pg_out
def _preconditioning_one_d(key, Q, L, G, lr_precond, noise_scale, diag_update_fn):
"""Apply diagonal preconditioning to a 1D gradient vector.
Simplified preconditioning for 1D parameters using only diagonal factors.
Args:
key: JAX random key for noise injection.
Q: Diagonal preconditioner vector.
L: Current Lipschitz estimate.
G: 1D gradient vector to precondition.
lr_precond: Preconditioner learning rate.
noise_scale: Scale of noise added for stability.
diag_update_fn: Function for diagonal factor updates.
Returns:
tuple: (Q_new, L_new, Pg_out) containing updated preconditioner,
Lipschitz estimate, and preconditioned gradient.
"""
noise = jax.random.normal(key, G.shape, G.dtype) * noise_scale
Gn = G + noise
Pg = Q * Q * Gn
term1 = Pg * Pg
term2 = 1.0
Qn, Ln = diag_update_fn(term1, term2, L, Q, lr_precond)
Pg_out = Qn * Qn * G
return Qn, Ln, Pg_out
def _balance_qs(Ql, Qr):
"""Balance the scales of left and right Kronecker factors.
Normalizes the factors so that their maximum absolute values are equal,
preventing numerical issues from unbalanced factor scales.
Args:
Ql: Left Kronecker factors, shape [batch, ...].
Qr: Right Kronecker factors, shape [batch, ...].
Returns:
tuple: (balanced_Ql, balanced_Qr) with equalized scales.
"""
@vmap
def _balance_sample(ql, qr):
nl = jnp.max(jnp.abs(ql))
nr = jnp.max(jnp.abs(qr))
geometric_mean = jnp.sqrt(nl * nr)
sL = geometric_mean / nl
sR = geometric_mean / nr
return ql * sL, qr * sR
return _balance_sample(Ql, Qr)
def _block2d(x, block_size):
"""Block each [m, n] matrix in a [B, m, n] tensor into fixed-size blocks.
Partitions each 2D matrix into non-overlapping blocks of size [block_size, block_size],
padding as necessary. Useful for parallel processing of large matrices.
Args:
x: Input tensor of shape [B, m, n] containing B matrices.
block_size: Size of square blocks to partition into.
Returns:
tuple: (blocks, meta) where:
- blocks: Array of shape [B * nr * nc, block_size, block_size]
- meta: Tuple (nr, nc, m, n) with grid dimensions and original shape
"""
B, m, n = x.shape
nr, nc = (m + block_size - 1) // block_size, (n + block_size - 1) // block_size
pm, pn = nr * block_size, nc * block_size
dm, dn = pm - m, pn - n
xpad = jnp.pad(x, ((0, 0), (0, dm), (0, dn)))
x5 = xpad.reshape(B, nr, block_size, nc, block_size).transpose(0, 1, 3, 2, 4)
blocks = x5.reshape(B * nr * nc, block_size, block_size)
return blocks, (nr, nc, m, n)
def _unblock2d(blocks, meta, block_size):
"""Reconstruct matrices from blocked representation (inverse of _block2d).
Args:
blocks: Blocked tensor of shape [B * nr * nc, block_size, block_size].
meta: Tuple (nr, nc, m, n) from _block2d containing grid dimensions and original shape.
block_size: Size of square blocks.
Returns:
jax.Array: Reconstructed tensor of shape [B, m, n] with padding removed.
"""
nr, nc, m, n = meta
bs = block_size
B = blocks.shape[0] // (nr * nc)
x5 = blocks.reshape(B, nr, nc, bs, bs).transpose(0, 1, 3, 2, 4)
x = x5.reshape(B, nr * bs, nc * bs)
return x[:, :m, :n]
def _block_rows(x, block_size):
"""Block matrices along the row dimension only.
Partitions each matrix into horizontal strips of height block_size.
Args:
x: Input tensor of shape [B, m, n].
block_size: Height of row blocks.
Returns:
tuple: (blocks, meta) where:
- blocks: Array of shape [B * nr, block_size, n]
- meta: Tuple (nr, m, n) with number of row blocks and original shape
"""
B, m, n = x.shape
nr = (m + block_size - 1) // block_size
pm = nr * block_size
dm = pm - m
xpad = jnp.pad(x, ((0, 0), (0, dm), (0, 0)))
x3 = xpad.reshape(B, nr, block_size, n)
blocks = x3.reshape(B * nr, block_size, n)
return blocks, (nr, m, n)
def _unblock_rows(blocks, meta, block_size, B):
"""Reconstruct matrices from row-blocked representation (inverse of _block_rows).
Args:
blocks: Row-blocked tensor of shape [B * nr, block_size, n].
meta: Tuple (nr, m, n) from _block_rows.
block_size: Height of row blocks.
B: Batch size.
Returns:
jax.Array: Reconstructed tensor of shape [B, m, n].
"""
nr, m, n = meta
x3 = blocks.reshape(B, nr, block_size, n)
x = x3.reshape(B, nr * block_size, n)
return x[:, :m, :n]
def _block_cols(x, block_size):
"""Block matrices along the column dimension only.
Partitions each matrix into vertical strips of width block_size.
Args:
x: Input tensor of shape [B, m, n].
block_size: Width of column blocks.
Returns:
tuple: (blocks, meta) where:
- blocks: Array of shape [B * nc, m, block_size]
- meta: Tuple (nc, m, n) with number of column blocks and original shape
"""
B, m, n = x.shape
nc = (n + block_size - 1) // block_size
pn = nc * block_size
dn = pn - n
xpad = jnp.pad(x, ((0, 0), (0, 0), (0, dn)))
x4 = xpad.reshape(B, m, nc, block_size).transpose(0, 2, 1, 3)
blocks = x4.reshape(B * nc, m, block_size)
return blocks, (nc, m, n)
def _unblock_cols(blocks, meta, block_size, B):
"""Reconstruct matrices from column-blocked representation (inverse of _block_cols).
Args:
blocks: Column-blocked tensor of shape [B * nc, m, block_size].
meta: Tuple (nc, m, n) from _block_cols.
block_size: Width of column blocks.
B: Batch size.
Returns:
jax.Array: Reconstructed tensor of shape [B, m, n].
"""
nc, m, n = meta
x4 = blocks.reshape(B, nc, m, block_size).transpose(0, 2, 1, 3)
x = x4.reshape(B, m, nc * block_size)
return x[:, :, :n]
def _merge_dims(shape):
"""Merge a multi-dimensional shape into a 2D shape for Kronecker factorization.
Finds the optimal split point to reshape an N-dimensional tensor into a 2D matrix
with balanced dimensions, minimizing the aspect ratio.
Args:
shape: Original tensor shape as a tuple.
Returns:
tuple: Merged 2D shape. Returns original shape if already 1D or 2D,
or if the tensor is effectively 1D (all but one dim are 1).
"""
if len(shape) < 2:
return shape
if np.prod(shape) == np.max(shape):
return (np.max(shape),)
if len(shape) == 2:
return shape
dims = list(shape)
best_ratio, best_split = float("inf"), 1
for s in range(1, len(dims)):
lp, rp = np.prod(dims[:s]), np.prod(dims[s:])
ratio = max(lp, rp) / min(lp, rp)
if ratio < best_ratio:
best_ratio, best_split = ratio, s
return (np.prod(dims[:best_split]), np.prod(dims[best_split:]))
def _identity_padded(block_size, valid, dtype):
"""Create a padded identity matrix for initializing preconditioners.
Creates an identity matrix of the specified valid size, padded with zeros
to reach the full block_size.
Args:
block_size: Target matrix size after padding.
valid: Number of valid rows/columns (size of identity submatrix).
dtype: Data type for the output matrix.
Returns:
jax.Array: Identity matrix of shape [block_size, block_size] with
identity in top-left [valid, valid] corner and zeros elsewhere.
"""
if valid >= block_size:
return jnp.eye(block_size, dtype=dtype)
eye = jnp.eye(valid, dtype=dtype)
return jnp.pad(eye, ((0, block_size - valid), (0, block_size - valid)), constant_values=0)