eformer.mpric.policy.policy

Mixed precision policy implementation.

This module provides the Policy dataclass for defining precision configurations used in mixed precision training and inference.

class eformer.mpric.policy.policy.Policy(param_dtype: dtype, compute_dtype: dtype, output_dtype: dtype)

Bases: object

Mixed precision policy defining casting behavior for different operations.

This immutable dataclass defines the dtypes used for three distinct aspects of mixed precision computation:

  • param_dtype: The dtype for storing model parameters. Typically float32 to maintain precision during optimization.

  • compute_dtype: The dtype for forward/backward pass computations. Lower precision (float16, bfloat16) can speed up computation.

  • output_dtype: The dtype for function outputs. Often matches param_dtype for loss computation accuracy.
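As a rough illustration of why parameters are typically kept in float32, the sketch below rounds a small update through half precision and single precision using only the standard library `struct` module. It does not use eformer or JAX; the values `weight` and `update` and the helper `round_trip` are hypothetical and chosen for illustration.

```python
import struct


def round_trip(value, fmt):
    """Round a Python float through the given struct format and back."""
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


weight = 1.0
update = 1e-4  # hypothetical small optimizer step

# "e" is IEEE 754 half precision (float16); "f" is single precision (float32).
as_f16 = round_trip(weight + update, "e")
as_f32 = round_trip(weight + update, "f")

print(as_f16 == weight)  # True: the update is rounded away in float16
print(as_f32 > weight)   # True: the update survives in float32
```

Near 1.0, adjacent float16 values are about 0.001 apart, so an update of 1e-4 is lost entirely. bfloat16 has fewer mantissa bits than float16, so the same effect applies there as well.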

The policy is frozen (immutable) to ensure consistency during training and to allow safe usage as a static argument in JIT-compiled functions.
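The properties that make a frozen dataclass suitable as a static argument can be sketched without JAX. `jax.jit` requires static arguments to be hashable and expects them not to change after compilation; a frozen dataclass provides both. `SketchPolicy` below is a hypothetical stand-in that stores dtype names as plain strings, and the dictionary only imitates a compilation cache keyed on the policy.

```python
from dataclasses import FrozenInstanceError, dataclass


@dataclass(frozen=True)
class SketchPolicy:
    """Hypothetical stand-in for Policy; dtypes are plain strings here."""

    param_dtype: str
    compute_dtype: str
    output_dtype: str


a = SketchPolicy("float32", "bfloat16", "float32")
b = SketchPolicy("float32", "bfloat16", "float32")

# Equal field values give equal instances with equal hashes, so a cache
# keyed on the policy treats them as the same entry.
same_key = (a == b) and (hash(a) == hash(b))
cache = {a: "compiled-for-bf16"}
hit = cache[b]

# Frozen instances reject attribute assignment.
try:
    a.compute_dtype = "float16"
    mutated = True
except FrozenInstanceError:
    mutated = False
```

To use a different precision, construct a new policy rather than modifying an existing one.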

param_dtype

JAX numpy dtype for model parameters.

Type

numpy.dtype

compute_dtype

JAX numpy dtype for computations.

Type

numpy.dtype

output_dtype

JAX numpy dtype for outputs.

Type

numpy.dtype

Example

Creating a policy for TPU training with bfloat16 compute:

import jax.numpy as jnp
from eformer.mpric.policy.policy import Policy

policy = Policy(
    param_dtype=jnp.float32,
    compute_dtype=jnp.bfloat16,
    output_dtype=jnp.float32,
)

Creating from string specification:

policy = Policy.from_string("p=f32,c=bf16,o=f32")

Note

For TPU training, bfloat16 is typically the preferred compute_dtype because TPUs support it natively. For GPU training, float16 may perform better on tensor cores.

classmethod from_string(policy_str: str) -> Policy

Create a Policy from a string specification.

This factory method parses a string specification to create a Policy instance. It supports both simple (single dtype) and detailed (per-operation dtype) specifications.

Parameters

policy_str

A string specifying the precision policy. Supported formats:

Simple format (single dtype for all operations):
  • "f32", "float32": Use float32 for all operations

  • "bf16", "bfloat16": Use bfloat16 for all operations

  • "f16", "float16": Use float16 for all operations

  • "half": Use platform-specific half precision (bfloat16 on TPU, float16 on GPU/CPU)

Detailed format (comma-separated key=value pairs):
  • "p=f32,c=bf16,o=f32": Explicit dtypes for each operation

  • Keys: p/params, c/compute, o/output

  • Values: Any supported dtype string from DTYPE_MAPPING

Returns

A Policy instance with the specified dtypes.

Raises

ValueError – If an unknown dtype string is provided.

Example

>>> policy = Policy.from_string("p=f32,c=f16,o=f32")
>>> policy.param_dtype
dtype('float32')
>>> policy.compute_dtype
dtype('float16')
>>> policy = Policy.from_string("bf16")
>>> policy.param_dtype == policy.compute_dtype == policy.output_dtype
True
>>> policy = Policy.from_string("half")  # Platform-specific
>>> # On TPU: bfloat16; on GPU/CPU: float16

Note

When using the detailed format, if compute_dtype is not specified, it defaults to param_dtype. If output_dtype is not specified, it defaults to compute_dtype.
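The two formats and the defaulting rules can be sketched as a small standalone parser. This is not the library's implementation: `parse_policy_string`, `lookup`, `ALIASES`, `KEYS`, and the `platform` parameter are hypothetical names, the alias table covers only the strings listed above (the real DTYPE_MAPPING may contain more), and the sketch returns dtype names as strings rather than a Policy. The source does not say what happens when the params entry is omitted, so the sketch rejects that case.

```python
# Hypothetical alias table standing in for the library's DTYPE_MAPPING.
ALIASES = {
    "f32": "float32", "float32": "float32",
    "bf16": "bfloat16", "bfloat16": "bfloat16",
    "f16": "float16", "float16": "float16",
}
KEYS = {
    "p": "params", "params": "params",
    "c": "compute", "compute": "compute",
    "o": "output", "output": "output",
}


def lookup(name, platform="cpu"):
    """Resolve a dtype string to a canonical name, or raise ValueError."""
    name = name.strip()  # whitespace tolerance is a leniency of this sketch
    if name == "half":
        return "bfloat16" if platform == "tpu" else "float16"
    if name not in ALIASES:
        raise ValueError(f"unknown dtype string: {name!r}")
    return ALIASES[name]


def parse_policy_string(policy_str, platform="cpu"):
    """Return (param, compute, output) dtype names for a policy string."""
    if "=" not in policy_str:
        # Simple format: one dtype for all three roles.
        dtype = lookup(policy_str, platform)
        return dtype, dtype, dtype

    # Detailed format: comma-separated key=value pairs.
    parsed = {}
    for part in policy_str.split(","):
        key, _, value = part.partition("=")
        key = key.strip()
        if key not in KEYS:
            raise ValueError(f"unknown key: {key!r}")
        parsed[KEYS[key]] = lookup(value, platform)

    if "params" not in parsed:
        # Sketch-only choice: the source does not document this case.
        raise ValueError("this sketch requires a params (p=...) entry")

    # Defaulting rules from the note above.
    compute = parsed.get("compute", parsed["params"])
    output = parsed.get("output", compute)
    return parsed["params"], compute, output


print(parse_policy_string("p=f32,c=f16"))
```

With compute given but output omitted, the output dtype follows the compute dtype, so the call above yields float32 for parameters and float16 for both compute and output.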
