Source code for eformer.aparser._aparser

# Copyright 2026 The EasyDeL/eFormer Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import dataclasses
import json
import os
import sys
import types
import typing as tp
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError
from copy import copy
from enum import Enum
from inspect import isclass
from pathlib import Path

import yaml

from eformer.paths import ePath

# Annotation-only aliases: both are `typing.NewType` wrappers around `tp.Any`,
# so they add no runtime checking and exist purely to make signatures readable.
# `DataClassType` is used below to annotate dataclass *classes* handed to the parser.
# NOTE(review): `DataClass` is presumably the matching alias for dataclass
# *instances*; it is not referenced in the code visible here — confirm usage.
DataClass = tp.NewType("DataClass", tp.Any)
DataClassType = tp.NewType("DataClassType", tp.Any)


[docs]def string_to_bool(v: str | bool) -> bool: """Convert a string to a boolean. Accepts various string representations for truthy and falsy values. Case-insensitive matching is used. Args: v: Value to convert. Can be a string or already a boolean. Returns: Boolean value corresponding to the input. Raises: ArgumentTypeError: If the string cannot be interpreted as boolean. Example: >>> string_to_bool("yes") True >>> string_to_bool("false") False >>> string_to_bool("1") True Accepted truthy values: "yes", "true", "t", "y", "1" Accepted falsy values: "no", "false", "f", "n", "0" """ if isinstance(v, bool): return v lower_v = v.lower() if lower_v in ("yes", "true", "t", "y", "1"): return True elif lower_v in ("no", "false", "f", "n", "0"): return False raise ArgumentTypeError( f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)." )
def make_choice_type_function(choices: list[tp.Any]) -> tp.Callable[[str], tp.Any]:
    """Build an argparse ``type=`` callable that maps strings back to choices.

    argparse hands every raw argument over as a string, while the allowed
    choices of a ``Literal`` or ``Enum`` field may be non-string values. The
    returned callable looks the raw string up by the ``str()`` form of each
    choice and yields the original choice object.

    Args:
        choices: The allowed values.

    Returns:
        A one-argument callable. It returns the choice whose ``str()`` equals
        the argument, or the argument itself when nothing matches.

    Example:
        >>> converter = make_choice_type_function([1, 2, 3])
        >>> converter("2")
        2
        >>> converter = make_choice_type_function(["small", "large"])
        >>> converter("small")
        'small'
    """
    by_text: dict[str, tp.Any] = {}
    for option in choices:
        # Later duplicates (same str() form) overwrite earlier ones.
        by_text[str(option)] = option
    return lambda raw: by_text.get(raw, raw)
[docs]def Argu( *, aliases: str | list[str] | None = None, help: str | None = None, # noqa default: tp.Any = dataclasses.MISSING, default_factory: tp.Callable[[], tp.Any] = dataclasses.MISSING, metadata: dict | None = None, **kwargs, ) -> dataclasses.Field: """Create a dataclass field with argument parsing metadata. A convenience wrapper around dataclasses.field() that adds metadata for command-line argument generation. Use this to specify help text, aliases, and other argparse options for dataclass fields. Args: aliases: Alternative command-line names for this argument. Can be a single string or list of strings. Example: aliases=["-lr", "--rate"] help: Help text displayed in --help output. default: Default value for the field. default_factory: Factory function for mutable default values. metadata: Additional metadata dictionary to extend. **kwargs: Additional arguments passed to dataclasses.field(). Returns: A dataclass Field object with argument metadata. Example: >>> from dataclasses import dataclass >>> from eformer.aparser import Argu, DataClassArgumentParser >>> >>> @dataclass >>> class Config: ... learning_rate: float = Argu( ... default=1e-4, ... aliases=["-lr"], ... help="Learning rate for optimizer" ... ) ... output_dir: str = Argu( ... default="./output", ... help="Directory for saving outputs" ... ) >>> >>> parser = DataClassArgumentParser(Config) >>> # Can use: --learning-rate 0.01 OR -lr 0.01 """ if metadata is None: metadata = {} if aliases is not None: metadata["aliases"] = aliases if help is not None: metadata["help"] = help return dataclasses.field(metadata=metadata, default=default, default_factory=default_factory, **kwargs)
[docs]class DataClassArgumentParser(ArgumentParser): """ArgumentParser that generates arguments from dataclass type hints. This class extends argparse.ArgumentParser to automatically create command-line arguments from dataclass field definitions. It supports multiple dataclasses, various field types, and configuration loading from files. Supported field types: - Basic types: str, int, float, bool - Optional types: Optional[T] / T | None - Literal types: Literal["a", "b", "c"] - Enum types: MyEnum with automatic choices - List types: list[T] with nargs="+" Special handling: - Boolean fields get --no-{field} variants when default is True - Fields with underscores also accept hyphenated names - Aliases can be specified via Argu() metadata Attributes: dataclass_types: List of dataclass types to generate arguments for. Example: >>> from dataclasses import dataclass >>> from typing import Literal >>> >>> @dataclass >>> class TrainConfig: ... batch_size: int = 32 ... learning_rate: float = 1e-4 ... optimizer: Literal["adam", "sgd"] = "adam" ... use_amp: bool = True >>> >>> parser = DataClassArgumentParser(TrainConfig) >>> config, = parser.parse_args_into_dataclasses(["--batch-size", "64"]) >>> print(config.batch_size) 64 Multiple dataclasses: >>> parser = DataClassArgumentParser([TrainConfig, ModelConfig]) >>> train_cfg, model_cfg = parser.parse_args_into_dataclasses() """ dataclass_types: tp.Iterable[DataClassType] def __init__( self, dataclass_types: DataClassType | tp.Iterable[DataClassType], **kwargs: tp.Any, ) -> None: """Initialize the parser with one or more dataclass types. Args: dataclass_types: A single dataclass type or iterable of dataclass types. Arguments are generated from all provided dataclasses. **kwargs: Additional arguments passed to ArgumentParser. Defaults to ArgumentDefaultsHelpFormatter if not specified. 
""" if "formatter_class" not in kwargs: kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter super().__init__(**kwargs) if dataclasses.is_dataclass(dataclass_types): dataclass_types = [dataclass_types] self.dataclass_types = list(dataclass_types) for dtype in self.dataclass_types: self._add_dataclass_arguments(dtype) @staticmethod def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field) -> None: """Convert a dataclass field into a corresponding argparse argument. Handles type conversion, default values, choices, and special cases like boolean flags and Optional types. Args: parser: ArgumentParser or argument group to add the argument to. field: Dataclass field to convert. Raises: RuntimeError: If the field has an unresolved string type annotation. ValueError: If the field has an unsupported Union type. """ long_options = [f"--{field.name}"] if "_" in field.name: long_options.append(f"--{field.name.replace('_', '-')}") kwargs = field.metadata.copy() if isinstance(field.type, str): raise RuntimeError( f"Unresolved type detected for field '{field.name}'. Ensure that type annotations are fully resolved." ) aliases = kwargs.pop("aliases", []) if isinstance(aliases, str): aliases = [aliases] origin_type = getattr(field.type, "__origin__", None) if origin_type in (tp.Union, getattr(types, "UnionType", None)): union_args = field.type.__args__ if len(union_args) == 2 and type(None) in union_args: field.type = next(arg for arg in union_args if arg is not type(None)) origin_type = getattr(field.type, "__origin__", None) else: raise ValueError( f"Only tp.Optional types (tp.Union[T, None]) are supported for " f"field '{field.name}', got {field.type}." ) bool_kwargs: dict[str, tp.Any] = {} if field.type is bool: bool_kwargs = copy(kwargs) kwargs["type"] = string_to_bool default_val = False if field.default is dataclasses.MISSING else field.default kwargs["default"] = default_val kwargs["nargs"] = "?" 
kwargs["const"] = True elif origin_type is tp.Literal or (isinstance(field.type, type) and issubclass(field.type, Enum)): if origin_type is tp.Literal: kwargs["choices"] = field.type.__args__ else: kwargs["choices"] = [member.value for member in field.type] kwargs["type"] = make_choice_type_function(kwargs["choices"]) if field.default is not dataclasses.MISSING: kwargs["default"] = field.default else: kwargs["required"] = True elif isclass(field.type) and issubclass(field.type, list): kwargs["type"] = field.type.__args__[0] kwargs["nargs"] = "+" if field.default_factory is not dataclasses.MISSING: kwargs["default"] = field.default_factory() elif field.default is dataclasses.MISSING: kwargs["required"] = True else: kwargs["type"] = field.type if field.default is not dataclasses.MISSING: kwargs["default"] = field.default elif field.default_factory is not dataclasses.MISSING: kwargs["default"] = field.default_factory() else: kwargs["required"] = True current_type = kwargs["type"] type_args = tp.get_args(current_type) if type_args and type(None) in type_args: non_none_args = [arg for arg in type_args if arg is not type(None)] if len(non_none_args) == 1: kwargs["type"] = non_none_args[0] elif len(non_none_args) > 1: kwargs["type"] = tp.Union[tuple(non_none_args)] # noqa:UP007 parser.add_argument(*long_options, *aliases, **kwargs) if field.type is bool and field.default is True: bool_kwargs["default"] = False parser.add_argument( f"--no_{field.name}", f"--no-{field.name.replace('_', '-')}", action="store_false", dest=field.name, **bool_kwargs, ) def _add_dataclass_arguments(self, dtype: DataClassType) -> None: """Add arguments for all init-enabled fields of a dataclass. Fields with init=False are skipped. If the dataclass has an _argument_group_name attribute, arguments are added to a named group. Args: dtype: Dataclass type to add arguments for. Raises: RuntimeError: If type hints cannot be resolved for the dataclass. 
""" group_name = getattr(dtype, "_argument_group_name", None) parser = self.add_argument_group(group_name) if group_name else self try: type_hints: dict[str, type] = tp.get_type_hints(dtype) except NameError as e: raise RuntimeError( f"Type resolution failed for {dtype}. Consider declaring the class in global scope or disabling " "PEP 563 (postponed evaluation of annotations)." ) from e except TypeError as ex: if sys.version_info < (3, 10) and "unsupported operand type(s) for |" in str(ex): python_version = ".".join(map(str, sys.version_info[:3])) raise RuntimeError( f"Type resolution failed for {dtype} on Python {python_version}. " "Please use typing.tp.Union and typing.tp.Optional instead of the | syntax for union types." ) from ex raise for field in dataclasses.fields(dtype): if not field.init: continue field.type = type_hints[field.name] self._parse_dataclass_field(parser, field)
[docs] def parse_args_into_dataclasses( self, args: list[str] | None = None, return_remaining_strings: bool = False, look_for_args_file: bool = True, args_filename: str | None = None, args_file_flag: str | None = None, ) -> tuple[tp.Any, ...]: """Parse command-line arguments into dataclass instances. Parses arguments and constructs instances of all registered dataclass types. Supports loading additional arguments from files. Args: args: List of argument strings. If None, uses sys.argv[1:]. return_remaining_strings: If True, include unparsed arguments in output. look_for_args_file: If True, look for a .args file matching the script name. args_filename: Explicit path to an args file to load. args_file_flag: Command-line flag for specifying args file(s). Returns: Tuple of dataclass instances, one per registered dataclass type. If return_remaining_strings is True, the last element is a list of unparsed argument strings. Raises: ValueError: If there are unknown arguments and return_remaining_strings is False. Example: >>> parser = DataClassArgumentParser([TrainConfig, ModelConfig]) >>> train, model = parser.parse_args_into_dataclasses() >>> >>> # With remaining args >>> train, model, remaining = parser.parse_args_into_dataclasses( ... return_remaining_strings=True ... 
) """ if args_file_flag or args_filename or (look_for_args_file and sys.argv): args_files: list[Path] = [] if args_filename: args_files.append(Path(args_filename)) elif look_for_args_file and sys.argv: args_files.append(Path(sys.argv[0]).with_suffix(".args")) if args_file_flag: args_file_parser = ArgumentParser(add_help=False) args_file_parser.add_argument(args_file_flag, type=str, action="append") cfg, args = args_file_parser.parse_known_args(args=args) cmd_args_file_paths = getattr(cfg, args_file_flag.lstrip("-"), None) if cmd_args_file_paths: args_files.extend(Path(p) for p in cmd_args_file_paths) file_args: list[str] = [] for args_file in args_files: if args_file.exists(): file_args.extend(args_file.read_text(encoding="utf-8").split()) if args is None: args = sys.argv[1:] args = file_args + args namespace, remaining_args = self.parse_known_args(args=args) outputs = [] for dtype in self.dataclass_types: field_names = {f.name for f in dataclasses.fields(dtype) if f.init} init_args = {k: v for k, v in vars(namespace).items() if k in field_names} for key in init_args: delattr(namespace, key) outputs.append(dtype(**init_args)) if namespace.__dict__: outputs.append(namespace) if return_remaining_strings: return (*outputs, remaining_args) elif remaining_args: raise ValueError(f"Some arguments were not used by DataClassArgumentParser: {remaining_args}") return tuple(outputs)
[docs] def parse_dict(self, args: dict[str, tp.Any], allow_extra_keys: bool = False) -> tuple[tp.Any, ...]: """Parse a dictionary of configuration values into dataclass instances. Useful for programmatic configuration or loading from config files. Args: args: Dictionary with keys matching dataclass field names. allow_extra_keys: If True, ignore keys that don't match any field. If False, raise ValueError for unknown keys. Returns: Tuple of dataclass instances, one per registered dataclass type. Raises: ValueError: If allow_extra_keys is False and unknown keys are present. Example: >>> parser = DataClassArgumentParser(TrainConfig) >>> config, = parser.parse_dict({ ... "learning_rate": 0.001, ... "batch_size": 64 ... }) """ unused_keys = set(args.keys()) outputs = [] for dtype in self.dataclass_types: field_names = {f.name for f in dataclasses.fields(dtype) if f.init} init_args = {k: v for k, v in args.items() if k in field_names} unused_keys -= init_args.keys() outputs.append(dtype(**init_args)) if not allow_extra_keys and unused_keys: raise ValueError(f"Unused keys in configuration: {sorted(unused_keys)}") return tuple(outputs)
[docs] def parse_json_file( self, json_file: str | os.PathLike, allow_extra_keys: bool = False, ) -> tuple[tp.Any, ...]: """Load a JSON file and parse it into dataclass instances. Args: json_file: Path to the JSON configuration file. Supports both local paths and GCS paths (gs://). allow_extra_keys: If True, ignore keys that don't match any field. Returns: Tuple of dataclass instances, one per registered dataclass type. Raises: FileNotFoundError: If the JSON file doesn't exist. json.JSONDecodeError: If the file contains invalid JSON. ValueError: If allow_extra_keys is False and unknown keys are present. Example: >>> parser = DataClassArgumentParser(TrainConfig) >>> config, = parser.parse_json_file("config.json") """ data = json.loads(ePath(json_file).read_text()) return self.parse_dict(data, allow_extra_keys=allow_extra_keys)
[docs] def parse_yaml_file( self, yaml_file: str | os.PathLike, allow_extra_keys: bool = False, ) -> tuple[tp.Any, ...]: """Load a YAML file and parse it into dataclass instances. Args: yaml_file: Path to the YAML configuration file. allow_extra_keys: If True, ignore keys that don't match any field. Returns: Tuple of dataclass instances, one per registered dataclass type. Raises: FileNotFoundError: If the YAML file doesn't exist. yaml.YAMLError: If the file contains invalid YAML. ValueError: If allow_extra_keys is False and unknown keys are present. Example: >>> parser = DataClassArgumentParser(TrainConfig) >>> config, = parser.parse_yaml_file("config.yaml") """ yaml_text = Path(yaml_file).read_text(encoding="utf-8") data = yaml.safe_load(yaml_text) return self.parse_dict(data, allow_extra_keys=allow_extra_keys)