eformer.mpric.loss_scaling.loss_scaler#
Loss scaling implementations for mixed precision training.
This module provides loss scaling utilities to prevent gradient underflow and overflow when training with low-precision dtypes like float16. Dynamic loss scaling automatically adjusts the scale factor based on gradient health.
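The underflow and overflow mentioned above can be made concrete with a minimal sketch. The values used here (a gradient-like value of 1e-8 and a scale of 2**15) are illustrative assumptions, not taken from the library.

```python
import jax.numpy as jnp

# Illustrative values only; not taken from eformer.
tiny_grad = jnp.float32(1e-8)        # below float16's smallest subnormal (~6e-8)
loss_scale = jnp.float32(2.0**15)

# Without scaling, the value rounds to zero when cast to float16.
underflowed = tiny_grad.astype(jnp.float16)

# With scaling, the value survives the float16 cast and can be
# recovered by dividing in float32.
scaled = (tiny_grad * loss_scale).astype(jnp.float16)
recovered = scaled.astype(jnp.float32) / loss_scale

# A scale that is too large overflows instead: float16 tops out at 65504.
overflowed = (jnp.float32(10.0) * loss_scale).astype(jnp.float16)
```

Dynamic scaling exists to stay between these two failure modes without hand-tuning the scale.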
- class eformer.mpric.loss_scaling.loss_scaler.DynamicLossScale(loss_scale: jax.jaxlib._jax.Array, counter: jax.jaxlib._jax.Array = <factory>, period: int = 2000, factor: int = 2, min_loss_scale: jax.jaxlib._jax.Array = <factory>)[source]#
Bases: object
Dynamic loss scaling for mixed precision training.
This class implements dynamic loss scaling, which automatically adjusts the loss scale factor during training to prevent gradient underflow and overflow. This is essential for stable training with low-precision dtypes like float16.
The algorithm works as follows:
1. Scale the loss by the current loss_scale before computing gradients.
2. After computing gradients, unscale them by dividing by loss_scale.
3. Check if gradients are finite (no NaN or Inf values).
4. If finite: increment counter, increase scale after period steps.
5. If not finite: reduce scale by factor, reset counter.
This class is immutable (frozen dataclass), so adjust() returns a new instance rather than modifying in place. This design is compatible with JAX’s functional programming paradigm.
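The consequence of immutability for calling code can be illustrated with a small sketch. `FrozenScale` below is a hypothetical stand-in written for this example, not the eformer class, and its update rule is deliberately simplified.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class FrozenScale:
    """Hypothetical stand-in for a frozen loss-scale dataclass."""

    loss_scale: float
    counter: int = 0

    def adjust(self, grads_finite: bool) -> "FrozenScale":
        # Return a modified copy; the original instance is never touched.
        if grads_finite:
            return dataclasses.replace(self, counter=self.counter + 1)
        return dataclasses.replace(self, loss_scale=self.loss_scale / 2, counter=0)


old = FrozenScale(loss_scale=1024.0)
new = old.adjust(False)

# Direct assignment is rejected on a frozen dataclass.
try:
    old.loss_scale = 1.0
    mutated = True
except dataclasses.FrozenInstanceError:
    mutated = False
```

Because `adjust` returns a copy, callers need to rebind the result (`scaler = scaler.adjust(...)`); discarding the return value leaves the scale unchanged.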
- loss_scale#
Current loss scale value as a JAX array.
- Type
jax.jaxlib._jax.Array
- counter#
Number of consecutive steps with finite gradients.
- Type
jax.jaxlib._jax.Array
- period#
Steps between scale increases when gradients are stable.
- Type
int
- factor#
Multiplicative factor for scale adjustments.
- Type
int
- min_loss_scale#
Minimum allowed loss scale to prevent underflow.
- Type
jax.jaxlib._jax.Array
Example
Basic usage in a training loop:
    scaler = DynamicLossScale(
        loss_scale=jnp.array(2**15),
        period=2000,
        factor=2,
        min_loss_scale=jnp.array(1.0)
    )

    for batch in data_loader:
        # Compute loss and gradients
        loss, grads = compute_loss_and_grads(params, batch)

        # Scale and unscale
        scaled_loss = scaler.scale(loss)
        unscaled_grads = scaler.unscale(grads)

        # Check gradient health and update scaler
        grads_finite = check_finite(unscaled_grads)
        scaler = scaler.adjust(grads_finite)

        # Only update params if gradients are valid
        if grads_finite:
            params = update_params(params, unscaled_grads)
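The helpers in the loop above are placeholders. In a typical setup the loss is scaled inside the function being differentiated, so that the gradients themselves carry the scale. A minimal sketch under that assumption follows; the toy linear model and the names `loss_fn`, `all_finite`, and `train_step` are hypothetical, and a plain array stands in for the scaler's loss_scale.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Toy linear model, used only for illustration.
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)


def all_finite(tree):
    # True only if every leaf of the PyTree is free of NaN and Inf.
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.all(jnp.stack([jnp.all(jnp.isfinite(leaf)) for leaf in leaves]))


def train_step(params, loss_scale, x, y, lr=0.1):
    # Differentiate the *scaled* loss so the gradients are scaled too.
    scaled_grads = jax.grad(lambda p: loss_fn(p, x, y) * loss_scale)(params)
    # Divide the scale back out to recover the true gradients.
    grads = jax.tree_util.tree_map(lambda g: g / loss_scale, scaled_grads)
    finite = all_finite(grads)
    # Keep the old parameters when any gradient is non-finite.
    new_params = jax.tree_util.tree_map(
        lambda p, g: jnp.where(finite, p - lr * g, p), params, grads
    )
    return new_params, finite


params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}
x = jnp.array([1.0, 2.0, 3.0])
y = 2.0 * x
new_params, finite = train_step(params, jnp.array(1024.0), x, y)
```

Using `jnp.where` rather than a Python `if` keeps the step traceable under `jax.jit`; the returned `finite` flag is what would be passed to `adjust`.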
- adjust(grads_finite: Array) DynamicLossScale[source]#
Adjust the loss scale based on gradient health.
This method implements the core dynamic scaling logic:
- If gradients are finite: increment counter, optionally increase scale
- If gradients are non-finite: decrease scale, reset counter
The scale is increased when the counter reaches (period - 1), indicating that gradients have been stable for period consecutive steps. The scale is decreased immediately when non-finite gradients are detected.
- Parameters
grads_finite – A boolean JAX array indicating whether all gradients are finite (True) or contain NaN/Inf values (False).
- Returns
A new DynamicLossScale instance with updated loss_scale and counter values.
- Return type
DynamicLossScale
Example
>>> scaler = DynamicLossScale(loss_scale=jnp.array(1024.0))
>>> # Finite gradients - counter increases
>>> scaler = scaler.adjust(jnp.array(True))
>>> scaler.counter
Array(1, dtype=int32)
>>> # Non-finite gradients - scale decreases
>>> scaler = scaler.adjust(jnp.array(False))
>>> scaler.loss_scale
Array(512., dtype=float32)
Note
The method uses JAX’s lax.select for XLA-compatible conditional logic, ensuring the operation can be traced and JIT-compiled.
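The update rule described for this method can be sketched as a standalone function. `adjust_sketch` is a hypothetical stand-in, not the library's implementation; in particular, wrapping the counter back to zero after a growth step and clamping the reduced scale at `min_loss_scale` are assumptions consistent with the description above.

```python
import jax.numpy as jnp
from jax import lax


def adjust_sketch(loss_scale, counter, grads_finite,
                  period=2000, factor=2, min_loss_scale=1.0):
    """Hypothetical stand-in for the described update rule."""
    # Finite branch: grow only once the counter has reached period - 1.
    grown = lax.select(counter == period - 1, loss_scale * factor, loss_scale)
    # Non-finite branch: shrink, but never below the minimum (assumed clamp).
    shrunk = jnp.maximum(min_loss_scale, loss_scale / factor)
    new_scale = lax.select(grads_finite, grown, shrunk)
    # Counter advances (wrapping after a growth step) or resets to zero.
    new_counter = lax.select(
        grads_finite, (counter + 1) % period, jnp.zeros_like(counter)
    )
    return new_scale, new_counter


scale, count = jnp.array(8.0), jnp.array(0, dtype=jnp.int32)
for _ in range(3):  # three finite steps with period=3 give one growth step
    scale, count = adjust_sketch(scale, count, jnp.array(True), period=3)
grown_scale = float(scale)

# A single non-finite step halves the scale immediately.
scale, count = adjust_sketch(scale, count, jnp.array(False), period=3)
```

Both branches are computed and then selected, which is what allows the whole update to be traced without Python control flow.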
- counter: Array#
- factor: int = 2#
- loss_scale: Array#
- min_loss_scale: Array#
- period: int = 2000#
- scale(tree: T) T[source]#
Scale values by multiplying with the loss scale factor.
This method multiplies all values in the input PyTree by the current loss scale. Typically applied to the loss before gradient computation to prevent gradient underflow.
- Parameters
tree – A PyTree of arrays to scale. Can be a single array, dict, list, tuple, or any nested structure of arrays.
- Returns
A new PyTree with the same structure where all arrays have been multiplied by the loss_scale.
Example
>>> scaler = DynamicLossScale(loss_scale=jnp.array(1024.0))
>>> loss = jnp.array(0.001)
>>> scaled_loss = scaler.scale(loss)
>>> scaled_loss
Array(1.024, dtype=float32)
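Scaling a nested PyTree, as opposed to a single array, can be sketched with `jax.tree_util.tree_map`. The helper `scale_tree` and the `losses` structure are hypothetical and only illustrate the behavior described above; whether the library uses `tree_map` internally is an assumption.

```python
import jax
import jax.numpy as jnp


def scale_tree(tree, loss_scale):
    # Hypothetical helper: multiply every array leaf by the loss scale.
    return jax.tree_util.tree_map(lambda x: x * loss_scale, tree)


# A nested structure mixing a dict and a list of arrays.
losses = {"main": jnp.array(0.5), "aux": [jnp.array(0.25), jnp.array(2.0)]}
scaled = scale_tree(losses, jnp.array(1024.0))
```

The result has the same dict-and-list layout as the input, with each leaf multiplied by 1024.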
- unscale(tree: T) T[source]#
Unscale values by dividing by the loss scale factor.
This method divides all values in the input PyTree by the current loss scale. Typically applied to gradients after computation to restore them to their true (unscaled) values.
- Parameters
tree – A PyTree of scaled arrays to unscale. Can be a single array, dict, list, tuple, or any nested structure of arrays.
- Returns
A new PyTree with the same structure where all arrays have been divided by the loss_scale.
Example
>>> scaler = DynamicLossScale(loss_scale=jnp.array(1024.0))
>>> scaled_grads = {"w": jnp.array(1024.0)}
>>> unscaled_grads = scaler.unscale(scaled_grads)
>>> unscaled_grads["w"]
Array(1., dtype=float32)
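Unscaling does not repair overflowed gradients: an Inf divided by the loss scale is still Inf, which is why a finiteness check follows the unscale step. A minimal sketch, with `unscale_tree` as a hypothetical helper rather than the library's method:

```python
import jax
import jax.numpy as jnp


def unscale_tree(tree, loss_scale):
    # Hypothetical helper: divide every array leaf by the loss scale.
    return jax.tree_util.tree_map(lambda g: g / loss_scale, tree)


loss_scale = jnp.array(1024.0)

# Healthy scaled gradients come back at their true magnitude.
good = unscale_tree({"w": jnp.array([2048.0, 512.0])}, loss_scale)

# An overflowed entry stays non-finite after unscaling.
bad = unscale_tree({"w": jnp.array([jnp.inf, 512.0])}, loss_scale)

good_is_finite = jnp.isfinite(good["w"]).all()
bad_is_finite = jnp.isfinite(bad["w"]).all()
```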
- class eformer.mpric.loss_scaling.loss_scaler.LossScaleConfig(initial_scale: float = 32768, growth_interval: int = 2000, scale_factor: int = 2, min_scale: float = 1.0)[source]#
Bases: object
Configuration parameters for dynamic loss scaling behavior.
This dataclass holds the hyperparameters that control how dynamic loss scaling behaves during training. The defaults are tuned for typical mixed precision training scenarios.
- initial_scale#
The starting loss scale value. Higher values provide more headroom for gradients but risk overflow. Default is 2^15 (32768), which works well for most models.
- Type
float
- growth_interval#
Number of consecutive steps with finite gradients required before increasing the loss scale. Default is 2000 steps.
- Type
int
- scale_factor#
Multiplicative factor for scaling adjustments. The scale is multiplied by this factor when increasing and divided by it when decreasing. Default is 2.
- Type
int
- min_scale#
Minimum allowed loss scale value. Prevents the scale from becoming too small, which could cause gradient underflow. Default is 1.0.
- Type
float
Example
Default configuration:
    config = LossScaleConfig()
    # initial_scale=32768, growth_interval=2000, scale_factor=2, min_scale=1.0
Aggressive scaling for stable training:
    config = LossScaleConfig(
        initial_scale=2**16,
        growth_interval=1000,
        scale_factor=2,
        min_scale=1.0
    )
Conservative scaling for unstable models:
    config = LossScaleConfig(
        initial_scale=2**10,
        growth_interval=5000,
        scale_factor=2,
        min_scale=1.0
    )
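To compare configurations like these, it can help to estimate how long the scale takes to grow. The sketch below is plain arithmetic on the documented fields and assumes every step produces finite gradients; `stable_steps_to_reach` is a hypothetical helper, not part of the library.

```python
def stable_steps_to_reach(initial_scale, target_scale,
                          scale_factor=2, growth_interval=2000):
    """Finite-gradient steps needed to grow from initial_scale to target_scale.

    Assumes no non-finite step ever occurs, so the scale grows by
    scale_factor once every growth_interval steps.
    """
    steps, scale = 0, initial_scale
    while scale < target_scale:
        scale *= scale_factor
        steps += growth_interval
    return steps


# Conservative config: 2**10 -> 2**15 needs five doublings of 5000 steps each.
conservative = stable_steps_to_reach(2**10, 2**15, 2, 5000)

# Default config: one doubling from 2**15 to 2**16 takes 2000 steps.
default = stable_steps_to_reach(2**15, 2**16, 2, 2000)
```

Under these assumptions the conservative configuration needs 25000 stable steps to reach the default starting scale, which is worth keeping in mind for short training runs.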
- growth_interval: int = 2000#
- initial_scale: float = 32768#
- min_scale: float = 1.0#
- scale_factor: int = 2#
- class eformer.mpric.loss_scaling.loss_scaler.NoOpLossScale[source]#
Bases: object
No-operation loss scaler that passes values through unchanged.
This class implements the loss scaler interface but performs no actual scaling. It is used when loss scaling is disabled, such as for full precision (float32) training where gradient underflow is not a concern.
Using NoOpLossScale instead of conditionally removing loss scaling logic simplifies code by maintaining a consistent interface regardless of whether scaling is enabled.
- loss_scale#
Always returns 1 (no scaling applied).
Example
>>> scaler = NoOpLossScale()
>>> loss = jnp.array(0.5)
>>> scaled = scaler.scale(loss)
>>> scaled == loss
Array(True, dtype=bool)
>>> scaler.adjust(True) is scaler
True
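The benefit of a consistent interface can be sketched with a pass-through object that a training step uses without any branching. `PassThroughScale` and `run_step` are hypothetical stand-ins written for this example; the `unscale` method is assumed to exist by analogy with DynamicLossScale.

```python
import jax.numpy as jnp


class PassThroughScale:
    """Hypothetical stand-in mirroring the no-op scaler interface."""

    @property
    def loss_scale(self):
        return 1

    def scale(self, tree):
        return tree

    def unscale(self, tree):
        return tree

    def adjust(self, grads_finite):
        return self


def run_step(scaler, loss):
    # The same code path works for any scaler exposing this interface.
    scaled = scaler.scale(loss)
    restored = scaler.unscale(scaled)
    return restored, scaler.adjust(jnp.array(True))


scaler = PassThroughScale()
restored, next_scaler = run_step(scaler, jnp.array(0.5))
```

Swapping in a dynamic scaler would change only the object passed to `run_step`, not the step itself.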
- adjust(grads_finite: Array)[source]#
Adjust the loss scale based on gradient health (no-op).
- Parameters
grads_finite – A boolean array indicating whether gradients are finite (contains no NaN or Inf values). Ignored by NoOpLossScale.
- Returns
Returns self unchanged since no adjustment is needed.
- Return type
NoOpLossScale
- property loss_scale#
Return the loss scale value.
- Returns
Always returns 1, indicating no scaling is applied.
- Return type
int