eformer.mpric.loss_scaling.loss_scaler

Loss scaling implementations for mixed precision training.

This module provides loss scaling utilities to prevent gradient underflow and overflow when training with low-precision dtypes like float16. Dynamic loss scaling automatically adjusts the scale factor based on gradient health.
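
To see why scaling matters, the sketch below casts a small value to float16 with and without scaling. It is an illustration only: the value 1e-8 and the scale 2^15 are assumptions chosen for the demonstration, and the casts stand in for a real float16 backward pass. The smallest positive float16 value is roughly 6e-8, so anything much smaller rounds to zero.

```python
import jax.numpy as jnp

# Hypothetical gradient value that is too small for float16.
true_grad = jnp.asarray(1e-8, dtype=jnp.float32)
loss_scale = jnp.asarray(2.0**15, dtype=jnp.float32)

# Direct cast to float16: the value rounds to zero (underflow).
naive = true_grad.astype(jnp.float16)

# Scale first, cast to float16, then unscale in float32.
scaled = (true_grad * loss_scale).astype(jnp.float16)
recovered = scaled.astype(jnp.float32) / loss_scale

print(float(naive))      # 0.0
print(float(recovered))  # close to 1e-8, up to float16 rounding error
```

The scaled value survives the float16 cast because it lands in the normal float16 range; dividing by the same scale in float32 restores its magnitude.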

class eformer.mpric.loss_scaling.loss_scaler.DynamicLossScale(loss_scale: Array, counter: Array = <factory>, period: int = 2000, factor: int = 2, min_loss_scale: Array = <factory>)

Bases: object

Dynamic loss scaling for mixed precision training.

This class implements dynamic loss scaling, which automatically adjusts the loss scale factor during training to prevent gradient underflow and overflow. This is essential for stable training with low-precision dtypes like float16.

The algorithm works as follows:

1. Scale the loss by the current loss_scale before computing gradients.
2. After computing gradients, unscale them by dividing by loss_scale.
3. Check whether the gradients are finite (no NaN or Inf values).
4. If finite: increment the counter and increase the scale after period steps.
5. If not finite: reduce the scale by factor and reset the counter.

This class is immutable (frozen dataclass), so adjust() returns a new instance rather than modifying in place. This design is compatible with JAX’s functional programming paradigm.
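
The rebinding pattern this implies can be sketched with a small frozen dataclass. ToyLossScale is a hypothetical stand-in written for this illustration, not the eformer class, and its adjust() is deliberately simplified (no period, no minimum).

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyLossScale:
    """Hypothetical stand-in for illustration; not the eformer class."""

    loss_scale: float
    counter: int = 0

    def adjust(self, grads_finite: bool) -> "ToyLossScale":
        # Build and return a new instance; self is never modified.
        if grads_finite:
            return dataclasses.replace(self, counter=self.counter + 1)
        return dataclasses.replace(self, loss_scale=self.loss_scale / 2, counter=0)


before = ToyLossScale(loss_scale=1024.0)
after = before.adjust(False)

try:
    before.loss_scale = 1.0  # in-place mutation is rejected
    mutation_allowed = True
except dataclasses.FrozenInstanceError:
    mutation_allowed = False

print(before.loss_scale, after.loss_scale, mutation_allowed)  # 1024.0 512.0 False
```

Callers therefore always rebind the name, as in scaler = scaler.adjust(grads_finite); discarding the return value leaves the old state in place.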

loss_scale

Current loss scale value as a JAX array.

Type

jax.jaxlib._jax.Array

counter

Number of consecutive steps with finite gradients.

Type

jax.jaxlib._jax.Array

period

Steps between scale increases when gradients are stable.

Type

int

factor

Multiplicative factor for scale adjustments.

Type

int

min_loss_scale

Minimum allowed loss scale to prevent underflow.

Type

jax.jaxlib._jax.Array

Example

Basic usage in a training loop:

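import jax
import jax.numpy as jnp
from eformer.mpric.loss_scaling.loss_scaler import DynamicLossScale

scaler = DynamicLossScale(
    loss_scale=jnp.array(2.0**15),  # a float array, since the scale is later divided by factor
    period=2000,
    factor=2,
    min_loss_scale=jnp.array(1.0)
)

for batch in data_loader:
    # Scale the loss inside the differentiated function so that the
    # gradients come out scaled too. loss_fn is a user-defined placeholder.
    def scaled_loss_fn(p):
        return scaler.scale(loss_fn(p, batch))

    scaled_grads = jax.grad(scaled_loss_fn)(params)

    # Undo the scaling to recover the true gradients
    unscaled_grads = scaler.unscale(scaled_grads)

    # Check gradient health and update scaler
    grads_finite = check_finite(unscaled_grads)
    scaler = scaler.adjust(grads_finite)

    # Only update params if gradients are valid
    if grads_finite:
        params = update_params(params, unscaled_grads)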
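
The training-loop example calls a check_finite helper that this page does not define. One possible version is sketched below; the name and implementation are assumptions for illustration, and the library may ship its own utility for this.

```python
import jax
import jax.numpy as jnp


def check_finite(tree):
    """Return a scalar boolean array: True if every leaf is free of NaN/Inf."""
    finite = jnp.asarray(True)
    for leaf in jax.tree_util.tree_leaves(tree):
        finite = jnp.logical_and(finite, jnp.all(jnp.isfinite(leaf)))
    return finite


good = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
bad = {"w": jnp.array([1.0, jnp.inf]), "b": jnp.zeros(2)}

print(bool(check_finite(good)))  # True
print(bool(check_finite(bad)))   # False
```

The result is a 0-d boolean array, which matches the grads_finite argument that adjust() expects.
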
adjust(grads_finite: Array) -> DynamicLossScale

Adjust the loss scale based on gradient health.

This method implements the core dynamic scaling logic:

- If gradients are finite: increment counter, optionally increase scale
- If gradients are non-finite: decrease scale, reset counter

The scale is increased when the counter reaches (period - 1), indicating that gradients have been stable for period consecutive steps. The scale is decreased immediately when non-finite gradients are detected.
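
The counter rule can be traced with a plain-Python sketch. It follows the description above but is not the library's code: the small period of 3 is chosen to keep the trace short, and resetting the counter to 0 right after a scale increase is an assumption.

```python
def adjust_step(loss_scale, counter, grads_finite, period=3, factor=2, min_loss_scale=1.0):
    """Plain-Python sketch of one adjustment; returns (loss_scale, counter)."""
    if not grads_finite:
        # Shrink immediately, never below the floor, and reset the streak.
        return max(min_loss_scale, loss_scale / factor), 0
    if counter == period - 1:
        # This is the period-th consecutive finite step: grow and start over
        # (restarting the counter at 0 is an assumption of this sketch).
        return loss_scale * factor, 0
    return loss_scale, counter + 1


state = (1024.0, 0)
trace = []
for finite in [True, True, True, False, True]:
    state = adjust_step(*state, finite)
    trace.append(state)

print(trace)
# [(1024.0, 1), (1024.0, 2), (2048.0, 0), (1024.0, 0), (1024.0, 1)]
```

The third finite step in a row finds the counter at period - 1 and doubles the scale; the non-finite step that follows halves it again at once.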

Parameters

grads_finite – A boolean JAX array indicating whether all gradients are finite (True) or contain NaN/Inf values (False).

Returns

A new DynamicLossScale instance with updated loss_scale and counter values.

Return type

DynamicLossScale

Example

>>> scaler = DynamicLossScale(loss_scale=jnp.array(1024.0))
>>> # Finite gradients - counter increases
>>> scaler = scaler.adjust(jnp.array(True))
>>> scaler.counter
Array(1, dtype=int32)
>>> # Non-finite gradients - scale decreases
>>> scaler = scaler.adjust(jnp.array(False))
>>> scaler.loss_scale
Array(512., dtype=float32)

Note

The method uses JAX’s lax.select for XLA-compatible conditional logic, ensuring the operation can be traced and JIT-compiled.
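
A minimal sketch of why this matters: inside a jitted function, grads_finite is a traced value, so a Python if on it would fail at trace time, whereas lax.select evaluates both branches and picks one elementwise. The function below is an illustrative assumption covering only the "decrease" branch, not the library's adjust().

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def shrink_if_non_finite(loss_scale, grads_finite):
    # Both operands are computed; lax.select chooses by the traced predicate.
    return lax.select(grads_finite, loss_scale, loss_scale / 2.0)


scale = jnp.asarray(1024.0, dtype=jnp.float32)
kept = shrink_if_non_finite(scale, jnp.asarray(True))
halved = shrink_if_non_finite(scale, jnp.asarray(False))

print(float(kept), float(halved))  # 1024.0 512.0
```

Note that lax.select requires both value operands to share shape and dtype, which is one reason to keep the loss scale as a floating-point array.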

counter: Array
factor: int = 2
loss_scale: Array
min_loss_scale: Array
period: int = 2000
scale(tree: T) -> T

Scale values by multiplying with the loss scale factor.

This method multiplies all values in the input PyTree by the current loss scale. Typically applied to the loss before gradient computation to prevent gradient underflow.

Parameters

tree – A PyTree of arrays to scale. Can be a single array, dict, list, tuple, or any nested structure of arrays.

Returns

A new PyTree with the same structure where all arrays have been multiplied by the loss_scale.

Example

>>> scaler = DynamicLossScale(loss_scale=jnp.array(1024.0))
>>> loss = jnp.array(0.001)
>>> scaled_loss = scaler.scale(loss)
>>> scaled_loss
Array(1.024, dtype=float32)
unscale(tree: T) -> T

Unscale values by dividing by the loss scale factor.

This method divides all values in the input PyTree by the current loss scale. Typically applied to gradients after computation to restore them to their true (unscaled) values.

Parameters

tree – A PyTree of scaled arrays to unscale. Can be a single array, dict, list, tuple, or any nested structure of arrays.

Returns

A new PyTree with the same structure where all arrays have been divided by the loss_scale.

Example

>>> scaler = DynamicLossScale(loss_scale=jnp.array(1024.0))
>>> scaled_grads = {"w": jnp.array(1024.0)}
>>> unscaled_grads = scaler.unscale(scaled_grads)
>>> unscaled_grads["w"]
Array(1., dtype=float32)
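
Conceptually, scale() and unscale() map a multiplication or division over every leaf of the PyTree. The sketch below expresses that with jax.tree_util.tree_map; scale_tree and unscale_tree are hypothetical helper names, and the library's methods may differ in detail (for example in dtype handling).

```python
import jax
import jax.numpy as jnp

loss_scale = jnp.asarray(1024.0, dtype=jnp.float32)


def scale_tree(tree):
    # Multiply every array leaf by the loss scale, keeping the structure.
    return jax.tree_util.tree_map(lambda x: x * loss_scale, tree)


def unscale_tree(tree):
    # Divide every array leaf by the loss scale, keeping the structure.
    return jax.tree_util.tree_map(lambda x: x / loss_scale, tree)


grads = {"dense": {"w": jnp.ones((2, 2)), "b": jnp.array([0.5, -0.25])}}
scaled = scale_tree(grads)
round_trip = unscale_tree(scaled)

print(scaled["dense"]["b"])      # [ 512. -256.]
print(round_trip["dense"]["b"])  # [ 0.5  -0.25]
```

Because the scale here is a power of two, the round trip is exact in floating point; with other scale values a small rounding error is possible.
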
class eformer.mpric.loss_scaling.loss_scaler.LossScaleConfig(initial_scale: float = 32768, growth_interval: int = 2000, scale_factor: int = 2, min_scale: float = 1.0)

Bases: object

Configuration parameters for dynamic loss scaling behavior.

This dataclass holds the hyperparameters that control how dynamic loss scaling behaves during training. The defaults are tuned for typical mixed precision training scenarios.

initial_scale

The starting loss scale value. Higher values provide more headroom for gradients but risk overflow. Default is 2^15 (32768), which works well for most models.

Type

float

growth_interval

Number of consecutive steps with finite gradients required before increasing the loss scale. Default is 2000 steps.

Type

int

scale_factor

Multiplicative factor for scaling adjustments. The scale is multiplied by this factor when increasing and divided by it when decreasing. Default is 2.

Type

int

min_scale

Minimum allowed loss scale value. Prevents the scale from becoming too small, which could cause gradient underflow. Default is 1.0.

Type

float

Example

Default configuration:

config = LossScaleConfig()
# initial_scale=32768, growth_interval=2000, scale_factor=2, min_scale=1.0

Aggressive scaling for stable training:

config = LossScaleConfig(
    initial_scale=2**16,
    growth_interval=1000,
    scale_factor=2,
    min_scale=1.0
)

Conservative scaling for unstable models:

config = LossScaleConfig(
    initial_scale=2**10,
    growth_interval=5000,
    scale_factor=2,
    min_scale=1.0
)
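
To get a feel for what these hyperparameters imply, the arithmetic below counts how many consecutive non-finite steps take the scale from initial_scale down to min_scale, and how many finite steps are needed to climb back. It is a back-of-the-envelope sketch that assumes the scale is divided by scale_factor on each non-finite step and multiplied by it once per growth_interval finite steps, as described above; the helper name is hypothetical.

```python
def non_finite_steps_to_floor(initial_scale, scale_factor, min_scale):
    """Count consecutive non-finite steps needed to reach min_scale."""
    scale, steps = float(initial_scale), 0
    while scale > min_scale:
        scale = max(min_scale, scale / scale_factor)
        steps += 1
    return steps


default_drops = non_finite_steps_to_floor(32768, 2, 1.0)
conservative_drops = non_finite_steps_to_floor(2**10, 2, 1.0)

# Growing back by one factor needs growth_interval consecutive finite steps.
default_regrow_steps = default_drops * 2000
conservative_regrow_steps = conservative_drops * 5000

print(default_drops, default_regrow_steps)            # 15 30000
print(conservative_drops, conservative_regrow_steps)  # 10 50000
```

The asymmetry is deliberate: the scale backs off within a handful of steps but recovers slowly, so a longer growth_interval trades recovery speed for fewer overflow events.
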
growth_interval: int = 2000
initial_scale: float = 32768
min_scale: float = 1.0
scale_factor: int = 2
class eformer.mpric.loss_scaling.loss_scaler.NoOpLossScale

Bases: object

No-operation loss scaler that passes values through unchanged.

This class implements the loss scaler interface but performs no actual scaling. It is used when loss scaling is disabled, such as for full precision (float32) training where gradient underflow is not a concern.

Using NoOpLossScale instead of conditionally removing loss scaling logic simplifies code by maintaining a consistent interface regardless of whether scaling is enabled.
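
The benefit of a shared interface can be sketched with two toy scalers driven by the same step function. Both classes are hypothetical stand-ins written for this illustration; they mirror the documented method names (scale, unscale, adjust) but are not the eformer implementations.

```python
class ToyNoOpScale:
    """Hypothetical stand-in mirroring the documented no-op interface."""

    loss_scale = 1

    def scale(self, tree):
        return tree

    def unscale(self, tree):
        return tree

    def adjust(self, grads_finite):
        return self


class ToyFixedScale:
    """Hypothetical scaler with a constant scale, for contrast."""

    def __init__(self, loss_scale):
        self.loss_scale = loss_scale

    def scale(self, tree):
        return tree * self.loss_scale

    def unscale(self, tree):
        return tree / self.loss_scale

    def adjust(self, grads_finite):
        return self


def step(scaler, loss_value):
    # The same code path runs whether or not scaling is enabled.
    scaled = scaler.scale(loss_value)
    restored = scaler.unscale(scaled)
    return restored, scaler.adjust(True)


for scaler in (ToyNoOpScale(), ToyFixedScale(1024.0)):
    restored, next_scaler = step(scaler, 0.5)
    print(type(scaler).__name__, restored)  # both print 0.5
```

No branch on "is scaling enabled" appears in step(); switching precision modes only changes which scaler object is passed in.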

loss_scale

Always returns 1 (no scaling applied).

Example

>>> scaler = NoOpLossScale()
>>> loss = jnp.array(0.5)
>>> scaled = scaler.scale(loss)
>>> bool(scaled == loss)
True
>>> scaler.adjust(True) is scaler
True
adjust(grads_finite: Array)

Adjust the loss scale based on gradient health (no-op).

Parameters

grads_finite – A boolean array indicating whether gradients are finite (contains no NaN or Inf values). Ignored by NoOpLossScale.

Returns

Returns self unchanged since no adjustment is needed.

Return type

NoOpLossScale

property loss_scale

Return the loss scale value.

Returns

Always returns 1, indicating no scaling is applied.

Return type

int

scale(tree: T) -> T

Scale values by the loss scale factor (no-op).

Parameters

tree – A PyTree of arrays to scale.

Returns

The input tree unchanged.

unscale(tree: T) -> T

Unscale values by dividing by the loss scale factor (no-op).

Parameters

tree – A PyTree of scaled arrays to unscale.

Returns

The input tree unchanged.