eformer.mpric.handler.precision_handler
Precision handler implementation for mixed precision operations.
This module provides the PrecisionHandler class that manages dtype casting and loss scaling for mixed precision training and inference in JAX.
- class eformer.mpric.handler.precision_handler.PrecisionHandler(policy: str | eformer.mpric.policy.policy.Policy, use_dynamic_scale: bool = True, loss_scale_config: LossScaleConfig = None)
Bases: object
Handles mixed precision operations for training and inference.
This class provides a unified interface for managing mixed precision training and inference in JAX. It combines a precision policy (defining dtypes for parameters, computations, and outputs) with optional dynamic loss scaling to prevent gradient underflow in low-precision training.
The handler can wrap training step functions and inference functions to automatically handle dtype casting and loss scaling, making it easier to implement mixed precision workflows.
- policy
The precision policy defining dtypes for different operations.
- loss_scale_config
Configuration for loss scaling behavior.
- loss_scaler
The loss scaler instance (DynamicLossScale or NoOpLossScale).
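The policy strings used on this page (for example "p=f32,c=bf16,o=f32") appear to assign dtypes to the p(arameter), c(ompute) and o(utput) roles. The sketch below shows one way such a string could be turned into the param_dtype, compute_dtype and output_dtype roles; `parse_policy` and both lookup tables are hypothetical illustrations, not the parser in eformer.mpric.policy, which may accept other abbreviations.

```python
import jax.numpy as jnp

# Hypothetical abbreviation table; the real Policy parser may accept other names.
DTYPE_ABBREVIATIONS = {"f32": jnp.float32, "f16": jnp.float16, "bf16": jnp.bfloat16}

# Hypothetical mapping from policy keys to the role each dtype plays.
POLICY_KEYS = {"p": "param_dtype", "c": "compute_dtype", "o": "output_dtype"}


def parse_policy(spec: str) -> dict:
    """Turn a string like "p=f32,c=bf16,o=f32" into {role: dtype}."""
    policy = {}
    for entry in spec.split(","):
        key, _, name = entry.strip().partition("=")
        if key not in POLICY_KEYS or name not in DTYPE_ABBREVIATIONS:
            raise ValueError(f"unrecognized policy entry: {entry!r}")
        policy[POLICY_KEYS[key]] = DTYPE_ABBREVIATIONS[name]
    return policy


policy = parse_policy("p=f32,c=bf16,o=f32")
# Parameters are stored in float32 while the computation runs in bfloat16.
params = {"w": jnp.ones((2, 2), dtype=policy["param_dtype"])}
activations = params["w"].astype(policy["compute_dtype"])
```

Keeping the parsed result as a plain mapping from role to dtype makes each cast a single `astype` call.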
Example
Basic mixed precision training setup:
```python
from eformer.mpric import PrecisionHandler

# Create handler with bfloat16 compute, float32 params
handler = PrecisionHandler(
    policy="p=f32,c=bf16,o=f32",
    use_dynamic_scale=True,
)

# Wrap your training step
# (compute_loss_and_grads, update_params, params and data_loader are user-defined)
@handler.training_step_wrapper
def train_step(params, batch):
    loss, grads = compute_loss_and_grads(params, batch)
    return loss, grads

# Training loop
for batch in data_loader:
    loss, grads, grads_finite = train_step(params, batch)
    if grads_finite:
        params = update_params(params, grads)
```
- cast_for_compute(x: Any) -> Any
Cast input arrays to the computation dtype.
This method is JIT-compiled for performance and casts all floating-point arrays in the input PyTree to the computation dtype specified by the policy.
- Parameters
x – A JAX PyTree containing arrays to cast. Can be a single array, a nested dict, list, tuple, or any valid PyTree structure.
- Returns
A new PyTree with the same structure where all floating-point arrays have been cast to the policy’s compute_dtype.
Note
This method is decorated with @jax.jit for efficient execution. Non-floating point arrays are returned unchanged.
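As a rough sketch of the casting described above, the hypothetical helper `cast_floating` maps over a PyTree with `jax.tree_util.tree_map`, casts only floating-point leaves, and is jitted with the target dtype marked static. It assumes, without reference to the eformer source, that `cast_for_compute` behaves this way; the real method reads the dtype from the policy rather than taking it as an argument.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="dtype")
def cast_floating(tree, dtype):
    """Cast every floating-point leaf of `tree` to `dtype`; leave other leaves unchanged."""

    def cast_leaf(x):
        if jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(dtype)
        return x  # integer/bool arrays keep their dtype

    return jax.tree_util.tree_map(cast_leaf, tree)


batch = {"x": jnp.ones((2, 3), dtype=jnp.float32), "labels": jnp.arange(2)}
out = cast_floating(batch, dtype=jnp.bfloat16)  # "x" -> bfloat16, "labels" untouched
```

Marking the dtype as static means each distinct target dtype triggers one compilation, after which repeated calls reuse the compiled function.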
- cast_for_output(x: Any) -> Any
Cast arrays to the output dtype.
This method is JIT-compiled for performance and casts all floating-point arrays in the input PyTree to the output dtype specified by the policy.
- Parameters
x – A JAX PyTree containing arrays to cast. Can be a single array, a nested dict, list, tuple, or any valid PyTree structure.
- Returns
A new PyTree with the same structure where all floating-point arrays have been cast to the policy’s output_dtype.
Note
This method is decorated with @jax.jit for efficient execution. Non-floating point arrays are returned unchanged.
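Casting to the output dtype changes how results are stored, but it cannot restore precision already lost during low-precision compute. The sketch below uses plain `jnp` casts (not the handler) to show a float32 value that does not survive a round trip through bfloat16, which keeps only 8 significand bits.

```python
import jax.numpy as jnp

# 1 + 2**-10 is exactly representable in float32.
x = jnp.array(1.0 + 2.0 ** -10, dtype=jnp.float32)

# bfloat16 spacing near 1.0 is 2**-7, so the 2**-10 term rounds away.
low = x.astype(jnp.bfloat16)

# Widening back to float32 cannot recover the dropped bits.
back = low.astype(jnp.float32)
```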
- cast_params(params: Any) -> Any
Cast model parameters to the parameter dtype.
This method casts all floating-point arrays in the parameters PyTree to the parameter dtype specified by the policy. It is typically used to keep stored parameters in high precision (e.g. float32).
- Parameters
params – A PyTree of model parameters. Typically a nested dict containing weight and bias arrays.
- Returns
A new PyTree with the same structure where all floating-point arrays have been cast to the policy’s param_dtype.
Example
```python
>>> params = {"layer1": {"weights": jnp.ones((3, 3), dtype=jnp.float16)}}
>>> casted_params = handler.cast_params(params)
>>> casted_params["layer1"]["weights"].dtype
dtype('float32')  # if param_dtype is float32
```
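One reason to keep stored parameters in a wider type is that small optimizer updates can vanish in a narrow one. The illustrative numbers below (plain `jnp` arithmetic with a made-up update size of 1e-4, not taken from eformer) show a float16 weight ignoring an update that a float32 weight retains.

```python
import jax.numpy as jnp

update = 1e-4  # illustrative learning-rate * gradient magnitude

w16 = jnp.array(1.0, dtype=jnp.float16)
w32 = jnp.array(1.0, dtype=jnp.float32)

# float16 spacing near 1.0 is 2**-10 (about 9.8e-4), so a 1e-4 step rounds away.
w16_new = w16 + jnp.float16(update)

# float32 spacing near 1.0 is 2**-23, so the same step is kept.
w32_new = w32 + jnp.float32(update)
```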
- inference_wrapper(inference_fn)
Wrap an inference function with precision handling.
This decorator wraps an inference function to automatically handle dtype casting for inputs and outputs according to the precision policy. Unlike training_step_wrapper, this does not perform loss scaling as gradients are not computed during inference.
- Parameters
inference_fn – A callable that performs inference. The function can have any signature and return any PyTree structure.
- Returns
A wrapped function with the same signature as the input function, where inputs are cast to compute_dtype and outputs are cast to output_dtype.
Example
```python
>>> handler = PrecisionHandler("p=f32,c=f16,o=f32")
>>>
>>> def my_inference(params, inputs):
...     return model.apply(params, inputs)
>>>
>>> wrapped_inference = handler.inference_wrapper(my_inference)
>>> outputs = wrapped_inference(params, inputs)
```
Note
This wrapper is suitable for both single-sample inference and batched inference. All floating-point arrays in both args and kwargs are cast to the computation dtype before calling the wrapped function.
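The casting behavior described for inference_wrapper can be approximated with an ordinary decorator. The sketch below is a hypothetical stand-in (`make_inference_wrapper` and `_cast_floating` are not eformer names): it casts floating-point leaves of args and kwargs to a compute dtype, calls the function, and casts the result to an output dtype, leaving integer arrays and non-array values alone.

```python
import functools

import jax
import jax.numpy as jnp


def _cast_floating(tree, dtype):
    # Cast floating-point array leaves; leave integers and non-arrays untouched.
    def cast_leaf(x):
        if hasattr(x, "dtype") and jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(dtype)
        return x

    return jax.tree_util.tree_map(cast_leaf, tree)


def make_inference_wrapper(compute_dtype, output_dtype):
    """Hypothetical decorator factory mimicking the documented casting behavior."""

    def decorator(fn):
        @functools.wraps(fn)
        def wrapped(*args, **kwargs):
            args = _cast_floating(args, compute_dtype)
            kwargs = _cast_floating(kwargs, compute_dtype)
            return _cast_floating(fn(*args, **kwargs), output_dtype)

        return wrapped

    return decorator


@make_inference_wrapper(jnp.float16, jnp.float32)
def predict(params, inputs):
    # Inside the body both operands are float16.
    return inputs @ params["w"]


params = {"w": jnp.ones((3, 2), dtype=jnp.float32)}
out = predict(params, jnp.ones((4, 3), dtype=jnp.float32))  # float32 result
```

Using `functools.wraps` keeps the wrapped function's name and docstring, which matches the claim that the wrapper preserves the original signature.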
- training_step_wrapper(training_step_fn)
Wrap a training step function with precision and loss scaling handling.
This decorator wraps a training step function to automatically handle:
1. Casting inputs to the computation dtype before the forward/backward pass
2. Scaling the loss and unscaling gradients for numerical stability
3. Checking gradient finiteness and adjusting the loss scale accordingly
4. Casting outputs to the output dtype
The wrapped function expects the original training step to return a tuple of (loss, grads) and will return (loss, grads, grads_finite).
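Step 2 relies on the standard loss-scaling technique: multiply the loss by a scale factor before differentiation so small gradients stay representable in low precision, then divide the gradients (and the loss) by the same factor. The hypothetical helper `scaled_value_and_grad` below shows that pattern with `jax.value_and_grad`; it is a generic illustration, not the wrapper's internal code. It runs in float32 with a power-of-two scale, so scaling and unscaling are exact and the result matches the unscaled gradients.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Hypothetical model: one linear layer with mean squared error.
    pred = x @ params["w"]
    return jnp.mean((pred - y) ** 2)


def scaled_value_and_grad(fn, scale):
    """Differentiate fn * scale, then divide loss and gradients by scale."""

    def scaled_loss(*args):
        return fn(*args) * scale

    grad_fn = jax.value_and_grad(scaled_loss)

    def step(*args):
        scaled, grads = grad_fn(*args)
        grads = jax.tree_util.tree_map(lambda g: g / scale, grads)
        return scaled / scale, grads

    return step


params = {"w": jnp.ones((3, 1))}
x = jnp.arange(6.0).reshape(2, 3)
y = jnp.zeros((2, 1))

loss_s, grads_s = scaled_value_and_grad(loss_fn, 2.0 ** 10)(params, x, y)
loss_r, grads_r = jax.value_and_grad(loss_fn)(params, x, y)
```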
- Parameters
training_step_fn – A callable that takes model inputs and returns a tuple of (loss, gradients). The function signature should be:
def training_step(*args, **kwargs) -> Tuple[Array, PyTree]
- Returns
A wrapped function with signature
def wrapped_step(*args, **kwargs) -> Tuple[Array, PyTree, bool]
returning (loss, gradients, grads_finite), where grads_finite indicates whether all gradients are finite (no NaN or Inf values).
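A grads_finite flag like the one returned above amounts to a reduction over every gradient leaf. A minimal sketch, assuming a plain `jnp.isfinite` check (the helper name `all_finite` is hypothetical, not an eformer function):

```python
import jax
import jax.numpy as jnp


def all_finite(tree):
    """Return a scalar bool array: True iff no leaf contains NaN or Inf."""
    finite = jnp.array(True)
    for leaf in jax.tree_util.tree_leaves(tree):
        finite = finite & jnp.all(jnp.isfinite(leaf))
    return finite


ok_grads = {"w": jnp.ones(3), "b": jnp.zeros(2)}
inf_grads = {"w": jnp.ones(3), "b": jnp.array([1.0, jnp.inf])}
nan_grads = {"w": jnp.array([jnp.nan])}
```

Because the result stays a JAX array rather than a Python bool, the check also works inside `jax.jit`.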
Example
```python
>>> handler = PrecisionHandler("p=f32,c=f16,o=f32")
>>>
>>> def my_train_step(params, batch):
...     loss, grads = jax.value_and_grad(loss_fn)(params, batch)
...     return loss, grads
>>>
>>> wrapped_step = handler.training_step_wrapper(my_train_step)
>>> loss, grads, grads_ok = wrapped_step(params, batch)
>>> if grads_ok:
...     params = apply_gradients(params, grads)
```
Note
The loss scaler state is updated internally after each call. If gradients contain NaN or Inf values, the loss scale is reduced. After a period of stable gradients, the scale is increased.
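The update rule in the note above can be sketched as a small pure function. The names `ScaleState` and `update_scale`, and the values of `period`, `factor`, and `min_scale`, are illustrative assumptions, not the parameters of eformer's DynamicLossScale or LossScaleConfig. In this sketch the scale is divided by `factor` whenever gradients are non-finite and multiplied by `factor` after `period` consecutive finite steps.

```python
from typing import NamedTuple

import jax.numpy as jnp


class ScaleState(NamedTuple):
    scale: jnp.ndarray       # current loss scale
    good_steps: jnp.ndarray  # consecutive steps with finite gradients


def update_scale(state, grads_finite, period=2000, factor=2.0, min_scale=1.0):
    """Shrink the scale on overflow; grow it after `period` finite steps in a row."""
    good = jnp.where(grads_finite, state.good_steps + 1, 0)
    grow = good >= period
    new_scale = jnp.where(
        grads_finite,
        jnp.where(grow, state.scale * factor, state.scale),
        jnp.maximum(state.scale / factor, min_scale),
    )
    good = jnp.where(grow, 0, good)  # restart the counter after growing
    return ScaleState(new_scale, good)


state = ScaleState(scale=jnp.array(2.0 ** 15), good_steps=jnp.array(0))
state = update_scale(state, jnp.array(False))  # overflow: scale halves
after_overflow = float(state.scale)
for _ in range(3):
    state = update_scale(state, jnp.array(True), period=3)  # third step doubles it
```

Using `jnp.where` instead of Python `if` keeps the update traceable, so it can run inside a jitted training step.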