eformer.mpric.dtypes.precision_types
Precision type definitions and dtype utilities.
This module provides mappings between string identifiers and JAX numpy dtypes, as well as utility functions for dtype conversion. It supports a wide range of floating-point precisions including standard IEEE formats and newer FP8 variants.
- Module Attributes:
  - STRING_TO_DTYPE_MAP (dict): Comprehensive mapping from string identifiers to jnp.dtype objects. Supports multiple aliases for each dtype (e.g., “bf16” and “bfloat16” both map to jnp.bfloat16).
  - DTYPE_TO_STRING_MAP (dict): Reverse mapping from jnp.dtype objects to their canonical string representations.
  - DTYPE_MAPPING (dict): Simplified mapping used by Policy.from_string() for parsing policy specifications. Supports short forms like “f32” and long forms like “float32”.
- Supported Dtypes:
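  - float16 (fp16, f16): IEEE 754 half precision
  - bfloat16 (bf16): Brain floating point (Google TPU format)
  - float32 (fp32, f32): IEEE 754 single precision
  - float64 (fp64, f64): IEEE 754 double precision
  - float8_e4m3fn (fp8_e4m3fn): 8-bit float with 4 exponent bits and 3 mantissa bits
  - float8_e5m2 (fp8_e5m2, fp8): 8-bit float with 5 exponent bits and 2 mantissa bits
  - Additional FP8 variants: e4m3fnuz, e4m3b11fnuz, e5m2fnuz
The alias sets differ between the two lookup maps: the fp* aliases and the additional FP8 variants are keys of STRING_TO_DTYPE_MAP only, while the f16/f32/f64 short forms and f8_e4m3/f8_e5m2 are keys of DTYPE_MAPPING only.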
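The bit layouts listed above can be checked against JAX's own dtype metadata. This is a minimal sketch, assuming a JAX install recent enough to expose the FP8 scalar types (they are backed by the ml_dtypes package); `describe` is a hypothetical helper, not part of eformer.

```python
import numpy as np
import jax.numpy as jnp


def describe(scalar_type):
    """Hypothetical helper: (exponent bits, mantissa bits, largest finite value)."""
    # Normalize the JAX scalar type to a NumPy dtype before querying finfo.
    info = jnp.finfo(np.dtype(scalar_type))
    return info.nexp, info.nmant, float(info.max)


for name, t in [
    ("float16", jnp.float16),
    ("bfloat16", jnp.bfloat16),
    ("float8_e4m3fn", jnp.float8_e4m3fn),
    ("float8_e5m2", jnp.float8_e5m2),
]:
    nexp, nmant, max_val = describe(t)
    print(f"{name:>14}: {nexp} exponent bits, {nmant} mantissa bits, max {max_val:g}")
```

The output makes the range-versus-precision trade-off concrete: float8_e4m3fn tops out at 448 while float8_e5m2 reaches 57344, and bfloat16 keeps float32's 8 exponent bits where float16 has only 5 (maximum 65504).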
- eformer.mpric.dtypes.precision_types.DTYPE_MAPPING = {'bf16': <class 'jax.numpy.bfloat16'>, 'bfloat16': <class 'jax.numpy.bfloat16'>, 'f16': <class 'jax.numpy.float16'>, 'f32': <class 'jax.numpy.float32'>, 'f64': <class 'jax.numpy.float64'>, 'f8_e4m3': <class 'jax.numpy.float8_e4m3fn'>, 'f8_e5m2': <class 'jax.numpy.float8_e5m2'>, 'float16': <class 'jax.numpy.float16'>, 'float32': <class 'jax.numpy.float32'>, 'float64': <class 'jax.numpy.float64'>, 'float8_e4m3': <class 'jax.numpy.float8_e4m3fn'>, 'float8_e5m2': <class 'jax.numpy.float8_e5m2'>}
Simplified dtype mapping used primarily by Policy.from_string(). Supports short forms (f32) and long forms (float32) for convenience.
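A token lookup against this mapping can be sketched as follows. The dict is a local copy of the documented value above; `parse_dtype` is hypothetical, and the whitespace and case normalization it applies is an assumption, not documented behavior of Policy.from_string(). Note that "fp32" is not a key of this mapping, so the sketch rejects it.

```python
import jax.numpy as jnp

# Local copy of the documented DTYPE_MAPPING value.
DTYPE_MAPPING = {
    "bf16": jnp.bfloat16, "bfloat16": jnp.bfloat16,
    "f16": jnp.float16, "float16": jnp.float16,
    "f32": jnp.float32, "float32": jnp.float32,
    "f64": jnp.float64, "float64": jnp.float64,
    "f8_e4m3": jnp.float8_e4m3fn, "float8_e4m3": jnp.float8_e4m3fn,
    "f8_e5m2": jnp.float8_e5m2, "float8_e5m2": jnp.float8_e5m2,
}


def parse_dtype(name: str):
    """Hypothetical helper: resolve a policy token such as 'f32' or ' BF16 '."""
    key = name.strip().lower()  # assumed normalization, not from eformer
    try:
        return DTYPE_MAPPING[key]
    except KeyError:
        valid = ", ".join(sorted(DTYPE_MAPPING))
        raise ValueError(f"Unknown dtype {name!r}; expected one of: {valid}") from None


print(parse_dtype("f32"), parse_dtype("float8_e4m3"))
```

Raising a ValueError that lists the accepted keys makes a typo in a policy string easy to diagnose.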
- eformer.mpric.dtypes.precision_types.DTYPE_TO_STRING_MAP = {<class 'jax.numpy.bfloat16'>: 'bf16', <class 'jax.numpy.float16'>: 'fp16', <class 'jax.numpy.float32'>: 'fp32', <class 'jax.numpy.float64'>: 'fp64', <class 'jax.numpy.float8_e5m2'>: 'fp8_e5m2', <class 'jax.numpy.float8_e4m3fn'>: 'fp8_e4m3fn', <class 'jax.numpy.float8_e4m3fnuz'>: 'fp8_e4m3fnuz', <class 'jax.numpy.float8_e4m3b11fnuz'>: 'fp8_e4m3b11fnuz', <class 'jax.numpy.float8_e5m2fnuz'>: 'fp8_e5m2fnuz'}
Reverse mapping from JAX numpy dtypes to their canonical string representations. Useful for logging, serialization, and display purposes.
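The keys shown in the repr above are JAX scalar types such as jnp.bfloat16, while an array's `.dtype` is a NumPy dtype instance. The two compare equal but do not necessarily hash alike, so a direct dict lookup with `x.dtype` may miss. A minimal sketch, assuming a local copy of part of the map, re-keys it on `np.dtype` so both forms resolve; `canonical_name` is a hypothetical helper.

```python
import numpy as np
import jax.numpy as jnp

# Local copy of part of DTYPE_TO_STRING_MAP (keys are JAX scalar types).
DTYPE_TO_STRING_MAP = {
    jnp.bfloat16: "bf16",
    jnp.float16: "fp16",
    jnp.float32: "fp32",
    jnp.float64: "fp64",
    jnp.float8_e5m2: "fp8_e5m2",
    jnp.float8_e4m3fn: "fp8_e4m3fn",
}

# Re-key on np.dtype so scalar types (jnp.float32) and array dtypes
# (x.dtype, an np.dtype instance) map to the same dictionary key.
_BY_NP_DTYPE = {np.dtype(k): v for k, v in DTYPE_TO_STRING_MAP.items()}


def canonical_name(dtype_like) -> str:
    """Hypothetical helper: canonical short name for a dtype or scalar type."""
    return _BY_NP_DTYPE[np.dtype(dtype_like)]


x = jnp.ones((2,), dtype=jnp.bfloat16)
print(canonical_name(x.dtype), canonical_name(jnp.float32))
```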
- eformer.mpric.dtypes.precision_types.STRING_TO_DTYPE_MAP = {'bf16': <class 'jax.numpy.bfloat16'>, 'bfloat16': <class 'jax.numpy.bfloat16'>, 'float16': <class 'jax.numpy.float16'>, 'float32': <class 'jax.numpy.float32'>, 'float64': <class 'jax.numpy.float64'>, 'float8_e4m3b11fnuz': <class 'jax.numpy.float8_e4m3b11fnuz'>, 'float8_e4m3fn': <class 'jax.numpy.float8_e4m3fn'>, 'float8_e4m3fnuz': <class 'jax.numpy.float8_e4m3fnuz'>, 'float8_e5m2': <class 'jax.numpy.float8_e5m2'>, 'float8_e5m2fnuz': <class 'jax.numpy.float8_e5m2fnuz'>, 'fp16': <class 'jax.numpy.float16'>, 'fp32': <class 'jax.numpy.float32'>, 'fp64': <class 'jax.numpy.float64'>, 'fp8': <class 'jax.numpy.float8_e5m2'>, 'fp8_e4m3b11fnuz': <class 'jax.numpy.float8_e4m3b11fnuz'>, 'fp8_e4m3fn': <class 'jax.numpy.float8_e4m3fn'>, 'fp8_e4m3fnuz': <class 'jax.numpy.float8_e4m3fnuz'>, 'fp8_e5m2': <class 'jax.numpy.float8_e5m2'>, 'fp8_e5m2fnuz': <class 'jax.numpy.float8_e5m2fnuz'>}
Comprehensive mapping from string dtype identifiers to JAX numpy dtypes. Supports multiple aliases for each dtype for user convenience.
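Because several aliases map to the same dtype, chaining this map with the reverse map collapses any alias to one canonical name. A minimal sketch, assuming local copies of part of both documented maps; `canonicalize` and `aliases_by_canonical` are hypothetical helpers.

```python
import jax.numpy as jnp

# Local copies of part of STRING_TO_DTYPE_MAP and DTYPE_TO_STRING_MAP.
STRING_TO_DTYPE_MAP = {
    "bf16": jnp.bfloat16, "bfloat16": jnp.bfloat16,
    "fp16": jnp.float16, "float16": jnp.float16,
    "fp32": jnp.float32, "float32": jnp.float32,
    "fp8": jnp.float8_e5m2, "fp8_e5m2": jnp.float8_e5m2, "float8_e5m2": jnp.float8_e5m2,
    "fp8_e4m3fn": jnp.float8_e4m3fn, "float8_e4m3fn": jnp.float8_e4m3fn,
}
DTYPE_TO_STRING_MAP = {
    jnp.bfloat16: "bf16", jnp.float16: "fp16", jnp.float32: "fp32",
    jnp.float8_e5m2: "fp8_e5m2", jnp.float8_e4m3fn: "fp8_e4m3fn",
}


def canonicalize(alias: str) -> str:
    """Hypothetical helper: alias -> dtype -> canonical string."""
    return DTYPE_TO_STRING_MAP[STRING_TO_DTYPE_MAP[alias]]


def aliases_by_canonical() -> dict:
    """Group every alias under the canonical name of the dtype it selects."""
    groups = {}
    for alias, dt in STRING_TO_DTYPE_MAP.items():
        groups.setdefault(DTYPE_TO_STRING_MAP[dt], []).append(alias)
    return {name: sorted(aliases) for name, aliases in groups.items()}


print(canonicalize("bfloat16"), canonicalize("fp8"))
print(aliases_by_canonical())
```

A KeyError from `canonicalize` would flag an alias whose dtype has no canonical name, which makes the grouping a quick consistency check between the two maps.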
- eformer.mpric.dtypes.precision_types.get_platform_default_half() → dtype
Get the platform-specific default half-precision dtype.
This function returns the recommended half-precision dtype for the current hardware platform. Different accelerators have different optimal half-precision formats:
TPU: Returns bfloat16, which has better hardware support on TPUs and a larger dynamic range than float16.
GPU/CPU: Returns float16, which is widely supported and has good tensor core acceleration on NVIDIA GPUs.
- Returns
Either jnp.bfloat16 (for TPU) or jnp.float16 (for GPU/CPU).
- Return type
jnp.dtype
Example
>>> import jax.numpy as jnp
>>> dtype = get_platform_default_half()
>>> # On TPU:
>>> dtype == jnp.bfloat16
True
>>> # On GPU or CPU:
>>> dtype == jnp.float16
True
Note
This function queries the JAX backend at runtime, so the result depends on the actual hardware available, not just the installed JAX version.
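One plausible way to express the documented rule is to branch on `jax.default_backend()`, which is resolved at runtime. This is a sketch under that assumption; `platform_default_half` is a hypothetical stand-in, and the real eformer function may detect the platform differently.

```python
import jax
import jax.numpy as jnp


def platform_default_half():
    """Sketch of the documented behavior: bfloat16 on TPU, float16 otherwise."""
    # jax.default_backend() names the default backend's platform at runtime,
    # e.g. "cpu", "gpu", or "tpu".
    if jax.default_backend() == "tpu":
        return jnp.bfloat16
    return jnp.float16


# Usage: cast a small, hypothetical parameter pytree to the chosen half dtype.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros((2,))}
half = platform_default_half()
half_params = jax.tree_util.tree_map(lambda x: x.astype(half), params)
print({k: str(v.dtype) for k, v in half_params.items()})
```

Selecting the dtype once and passing it to `tree_map` keeps the cast consistent across every leaf of the parameter tree.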