diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3470872..60d40ff 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -43,6 +43,14 @@ jobs: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Test with pytest + - name: Test without pydantic run: | pytest + - name: Test with pydantic v1 + run: | + python -m pip install "pydantic < 2" + pytest + - name: Test with pydantic v2 + run: | + python -m pip install "pydantic >= 2" + pytest diff --git a/README.md b/README.md index 997f9a3..1ad6851 100644 --- a/README.md +++ b/README.md @@ -666,7 +666,7 @@ from tap import tapify class Squarer: """Squarer with a number to square. - :param num: The number to square. + :param num: The number to square. """ num: float @@ -681,6 +681,94 @@ if __name__ == '__main__': Running `python square_dataclass.py --num -1` prints `The square of your number is 1.0.`. +
+Argument descriptions + +For dataclasses, the argument's description (which is displayed in the `-h` help message) can be specified either in the +class docstring or under the `description` key of the field's `metadata`. If both are specified, the description from the +docstring is used. In the example below, the description is provided in `metadata`. + +```python +# square_dataclass.py +from dataclasses import dataclass, field + +from tap import tapify + +@dataclass +class Squarer: + """Squarer with a number to square. + """ + num: float = field(metadata={"description": "The number to square."}) + + def get_square(self) -> float: + """Get the square of the number.""" + return self.num ** 2 + +if __name__ == '__main__': + squarer = tapify(Squarer) + print(f'The square of your number is {squarer.get_square()}.') +``` + +
+ +#### Pydantic + +Pydantic [Models](https://docs.pydantic.dev/latest/concepts/models/) and +[dataclasses](https://docs.pydantic.dev/latest/concepts/dataclasses/) can be `tapify`d. + +```python +# square_pydantic.py +from pydantic import BaseModel, Field + +from tap import tapify + +class Squarer(BaseModel): + """Squarer with a number to square. + """ + num: float = Field(description="The number to square.") + + def get_square(self) -> float: + """Get the square of the number.""" + return self.num ** 2 + +if __name__ == '__main__': + squarer = tapify(Squarer) + print(f'The square of your number is {squarer.get_square()}.') +``` + +
+Argument descriptions + +For Pydantic v2 models and dataclasses, the argument's description (which is displayed in the `-h` help message) can +be specified either in the class docstring or in the field's `description`. If both are specified, the description from +the docstring is used. In the example below, the description is provided in the docstring. + +For Pydantic v1 models and dataclasses, the argument's description must be provided in the class docstring: + +```python +# square_pydantic.py +from pydantic import BaseModel + +from tap import tapify + +class Squarer(BaseModel): + """Squarer with a number to square. + + :param num: The number to square. + """ + num: float + + def get_square(self) -> float: + """Get the square of the number.""" + return self.num ** 2 + +if __name__ == '__main__': + squarer = tapify(Squarer) + print(f'The square of your number is {squarer.get_square()}.') +``` + +
+ ### tapify help The help string on the command line is set based on the docstring for the function or class. For example, running `python square_function.py -h` will print: @@ -751,4 +839,57 @@ if __name__ == '__main__': Running `python person.py --name Jesse --age 1` prints `My name is Jesse.` followed by `My age is 1.`. Without `known_only=True`, the `tapify` calls would raise an error due to the extra argument. ### Explicit boolean arguments + Tapify supports explicit specification of boolean arguments (see [bool](#bool) for more details). By default, `explicit_bool=False` and it can be set with `tapify(..., explicit_bool=True)`. + +## Convert to a `Tap` class + +`to_tap_class` turns a function or class into a `Tap` class. The returned class can be [subclassed](#subclassing) to add +special argument behavior. For example, you can override [`configure`](#configuring-arguments) and +[`process_args`](#argument-processing). + +If the object can be `tapify`d, then it can be `to_tap_class`d, and vice-versa. `to_tap_class` provides full control +over argument parsing. + +### Examples + +#### Simple + +```python +# main.py +""" +My script description +""" + +from pydantic import BaseModel + +from tap import to_tap_class + +class Project(BaseModel): + package: str + is_cool: bool = True + stars: int = 5 + +if __name__ == "__main__": + ProjectTap = to_tap_class(Project) + tap = ProjectTap(description=__doc__) # from the top of this script + args = tap.parse_args() + project = Project(**args.as_dict()) + print(f"Project instance: {project}") +``` + +Running `python main.py --package tap` will print `Project instance: package='tap' is_cool=True stars=5`. 
+ +#### Complex + +The general pattern is: + +```python +from tap import to_tap_class + +class MyCustomTap(to_tap_class(my_class_or_function)): + # Special argument behavior, e.g., override configure and/or process_args +``` + +Please see `demo_data_model.py` for an example of overriding [`configure`](#configuring-arguments) and +[`process_args`](#argument-processing). diff --git a/demo_data_model.py b/demo_data_model.py new file mode 100644 index 0000000..6542d99 --- /dev/null +++ b/demo_data_model.py @@ -0,0 +1,96 @@ +""" +Works for Pydantic v1 and v2. + +Example commands: + +python demo_data_model.py -h + +python demo_data_model.py \ + --arg_int 1 \ + --arg_list x y z \ + --argument_with_really_long_name 3 + +python demo_data_model.py \ + --arg_int 1 \ + --arg_list x y z \ + --arg_bool \ + -arg 3.14 +""" +from typing import List, Literal, Optional, Union + +from pydantic import BaseModel, Field +from tap import tapify, to_tap_class, Tap + + +class Model(BaseModel): + """ + My Pydantic Model which contains script args. + """ + + arg_int: int = Field(description="some integer") + arg_bool: bool = Field(default=True) + arg_list: Optional[List[str]] = Field(default=None, description="some list of strings") + + +def main(model: Model) -> None: + print("Parsed args into Model:") + print(model) + + +def to_number(string: str) -> Union[float, int]: + return float(string) if "." in string else int(string) + + +class ModelTap(to_tap_class(Model)): + # You can supply additional arguments here + argument_with_really_long_name: Union[float, int] = 3 + "This argument has a long name and will be aliased with a short one" + + def configure(self) -> None: + # You can still add special argument behavior + self.add_argument("-arg", "--argument_with_really_long_name", type=to_number) + + def process_args(self) -> None: + # You can still validate and modify arguments + # (You should do this in the Pydantic Model. 
I'm just demonstrating that this functionality is still possible) + if self.argument_with_really_long_name > 4: + raise ValueError("argument_with_really_long_name cannot be > 4") + + # No auto-complete (and other niceties) for the super class attributes b/c this is a dynamic subclass. Sorry + if self.arg_bool and self.arg_list is not None: + self.arg_list.append("processed") + + +# class SubparserA(Tap): +# bar: int # bar help + + +# class SubparserB(Tap): +# baz: Literal["X", "Y", "Z"] # baz help + + +# class ModelTapWithSubparsing(to_tap_class(Model)): +# foo: bool = False # foo help + +# def configure(self): +# self.add_subparsers(help="sub-command help") +# self.add_subparser("a", SubparserA, help="a help", description="Description (a)") +# self.add_subparser("b", SubparserB, help="b help") + + +if __name__ == "__main__": + # You don't have to subclass to_tap_class(Model) if you just want a plain argument parser: + # ModelTap = to_tap_class(Model) + args = ModelTap(description="Script description").parse_args() + # args = ModelTapWithSubparsing(description="Script description").parse_args() + print("Parsed args:") + print(args) + # Run the main function + model = Model(**args.as_dict()) + main(model) + + +# tapify works with Model. 
It immediately returns a Model instance instead of a Tap class +# if __name__ == "__main__": +# model = tapify(Model) +# print(model) diff --git a/setup.py b/setup.py index c405119..f4d3b3c 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,11 @@ with open("README.md", encoding="utf-8") as f: long_description = f.read() +test_requirements = [ + "pydantic >= 2.5.0", + "pytest", +] + setup( name="typed-argument-parser", version=__version__, @@ -26,7 +31,8 @@ packages=find_packages(), package_data={"tap": ["py.typed"]}, install_requires=["typing-inspect >= 0.7.1", "docstring-parser >= 0.15"], - tests_require=["pytest"], + tests_require=test_requirements, + extras_require={"dev": test_requirements}, python_requires=">=3.8", classifiers=[ "Programming Language :: Python :: 3", diff --git a/tap/__init__.py b/tap/__init__.py index b10e9d4..ea7b844 100644 --- a/tap/__init__.py +++ b/tap/__init__.py @@ -1,6 +1,13 @@ from argparse import ArgumentError, ArgumentTypeError from tap._version import __version__ from tap.tap import Tap -from tap.tapify import tapify +from tap.tapify import tapify, to_tap_class -__all__ = ["ArgumentError", "ArgumentTypeError", "Tap", "tapify", "__version__"] +__all__ = [ + "ArgumentError", + "ArgumentTypeError", + "Tap", + "tapify", + "to_tap_class", + "__version__", +] diff --git a/tap/tapify.py b/tap/tapify.py index 39e6cd7..58f8c34 100644 --- a/tap/tapify.py +++ b/tap/tapify.py @@ -1,106 +1,334 @@ -"""Tapify module, which can initialize a class or run a function by parsing arguments from the command line.""" -from inspect import signature, Parameter -from typing import Any, Callable, List, Optional, Type, TypeVar, Union +""" +`tapify`: initialize a class or run a function by parsing arguments from the command line. 
-from docstring_parser import parse +`to_tap_class`: convert a class or function into a `Tap` class, which can then be subclassed to add special argument +handling +""" + +import dataclasses +import inspect +from typing import Any, Callable, Dict, List, Optional, Sequence, Type, TypeVar, Union + +from docstring_parser import Docstring, parse + +try: + import pydantic +except ModuleNotFoundError: + _IS_PYDANTIC_V1 = None + # These are "empty" types. isinstance and issubclass will always be False + BaseModel = type("BaseModel", (object,), {}) + _PydanticField = type("_PydanticField", (object,), {}) + _PYDANTIC_FIELD_TYPES = () +else: + _IS_PYDANTIC_V1 = pydantic.__version__ < "2.0.0" + from pydantic import BaseModel + from pydantic.fields import FieldInfo as PydanticFieldBaseModel + from pydantic.dataclasses import FieldInfo as PydanticFieldDataclass + + _PydanticField = Union[PydanticFieldBaseModel, PydanticFieldDataclass] + # typing.get_args(_PydanticField) is an empty tuple for some reason. Just repeat + _PYDANTIC_FIELD_TYPES = (PydanticFieldBaseModel, PydanticFieldDataclass) from tap import Tap InputType = TypeVar("InputType") OutputType = TypeVar("OutputType") +_ClassOrFunction = Union[Callable[[InputType], OutputType], Type[OutputType]] -def tapify( - class_or_function: Union[Callable[[InputType], OutputType], Type[OutputType]], - known_only: bool = False, - command_line_args: Optional[List[str]] = None, - explicit_bool: bool = False, - **func_kwargs, -) -> OutputType: - """Tapify initializes a class or runs a function by parsing arguments from the command line. - :param class_or_function: The class or function to run with the provided arguments. - :param known_only: If true, ignores extra arguments and only parses known arguments. - :param command_line_args: A list of command line style arguments to parse (e.g., ['--arg', 'value']). - If None, arguments are parsed from the command line (default behavior). 
- :param explicit_bool: Booleans can be specified on the command line as "--arg True" or "--arg False" - rather than "--arg". Additionally, booleans can be specified by prefixes of True and False - with any capitalization as well as 1 or 0. - :param func_kwargs: Additional keyword arguments for the function. These act as default values when - parsing the command line arguments and overwrite the function defaults but - are overwritten by the parsed command line arguments. +@dataclasses.dataclass +class _ArgData: + """ + Data about an argument which is sufficient to inform a Tap variable/argument. """ - # Get signature from class or function - sig = signature(class_or_function) - # Parse class or function docstring in one line - if isinstance(class_or_function, type) and class_or_function.__init__.__doc__ is not None: - doc = class_or_function.__init__.__doc__ + name: str + + annotation: Type + "The type of values this argument accepts" + + is_required: bool + "Whether or not the argument must be passed in" + + default: Any + "Value of the argument if the argument isn't passed in. This gets ignored if is_required" + + description: Optional[str] = "" + "Human-readable description of the argument" + + +@dataclasses.dataclass(frozen=True) +class _TapData: + """ + Data about a class' or function's arguments which are sufficient to inform a Tap class. 
+ """ + + args_data: List[_ArgData] + "List of data about each argument in the class or function" + + has_kwargs: bool + "True if you can pass variable/extra kwargs to the class or function (as in **kwargs), else False" + + known_only: bool + "If true, ignore extra arguments and only parse known arguments" + + +def _is_pydantic_base_model(obj: Union[Type[Any], Any]) -> bool: + if inspect.isclass(obj): # issubclass requires that obj is a class + return issubclass(obj, BaseModel) else: - doc = class_or_function.__doc__ + return isinstance(obj, BaseModel) - # Parse docstring - docstring = parse(doc) - # Get the description of each argument in the class init or function - param_to_description = {param.arg_name: param.description for param in docstring.params} +def _is_pydantic_dataclass(obj: Union[Type[Any], Any]) -> bool: + if _IS_PYDANTIC_V1: + # There's no public function in v1. This is a somewhat safe but linear check + return dataclasses.is_dataclass(obj) and any(key.startswith("__pydantic") for key in obj.__dict__) + else: + return pydantic.dataclasses.is_pydantic_dataclass(obj) - # Create a Tap object with a description from the docstring of the function or class - description = "\n".join(filter(None, (docstring.short_description, docstring.long_description))) - tap = Tap(description=description, explicit_bool=explicit_bool) - # Keep track of whether **kwargs was provided +def _tap_data_from_data_model( + data_model: Any, func_kwargs: Dict[str, Any], param_to_description: Dict[str, str] = None +) -> _TapData: + """ + Currently only works when `data_model` is a: + - builtin dataclass (class or instance) + - Pydantic dataclass (class or instance) + - Pydantic BaseModel (class or instance). + + The advantage of this function over :func:`_tap_data_from_class_or_function` is that field/argument descriptions are + extracted, b/c this function looks at the fields of the data model. 
+ + Note + ---- + Deletes redundant keys from `func_kwargs` + """ + param_to_description = param_to_description or {} + + def arg_data_from_dataclass(name: str, field: dataclasses.Field) -> _ArgData: + def is_required(field: dataclasses.Field) -> bool: + return field.default is dataclasses.MISSING and field.default_factory is dataclasses.MISSING + + description = param_to_description.get(name, field.metadata.get("description")) + return _ArgData( + name, + field.type, + is_required(field), + field.default, + description, + ) + + def arg_data_from_pydantic(name: str, field: _PydanticField, annotation: Optional[Type] = None) -> _ArgData: + annotation = field.annotation if annotation is None else annotation + # Prefer the description from param_to_description (from the data model / class docstring) over the + # field.description b/c a docstring can be modified on the fly w/o causing real issues + description = param_to_description.get(name, field.description) + return _ArgData(name, annotation, field.is_required(), field.default, description) + + # Determine what type of data model it is and extract fields accordingly + if dataclasses.is_dataclass(data_model): + name_to_field = {field.name: field for field in dataclasses.fields(data_model)} + has_kwargs = False + known_only = False + elif _is_pydantic_base_model(data_model): + name_to_field = data_model.model_fields + # For backwards compatibility, only allow new kwargs to get assigned if the model is explicitly configured to do + # so via extra="allow". See https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.extra + is_extra_ok = data_model.model_config.get("extra", "ignore") == "allow" + has_kwargs = is_extra_ok + known_only = is_extra_ok + else: + raise TypeError( + "data_model must be a builtin or Pydantic dataclass (instance or class) or " + f"a Pydantic BaseModel (instance or class). 
Got {type(data_model)}" + ) + + # It's possible to mix fields w/ classes, e.g., use pydantic Fields in a (builtin) dataclass, or use (builtin) + # dataclass fields in a pydantic BaseModel. It's also possible to use (builtin) dataclass fields and pydantic Fields + # in the same data model. Therefore, the type of the data model doesn't determine the type of each field. The + # solution is to iterate through the fields and check each type. + args_data: List[_ArgData] = [] + for name, field in name_to_field.items(): + if isinstance(field, dataclasses.Field): + # Idiosyncrasy: if a pydantic Field is used in a pydantic dataclass, then field.default is a FieldInfo + # object instead of the field's default value. Furthermore, field.annotation is always NoneType. Luckily, + # the actual type of the field is stored in field.type + if isinstance(field.default, _PYDANTIC_FIELD_TYPES): + arg_data = arg_data_from_pydantic(name, field.default, annotation=field.type) + else: + arg_data = arg_data_from_dataclass(name, field) + elif isinstance(field, _PYDANTIC_FIELD_TYPES): + arg_data = arg_data_from_pydantic(name, field) + else: + raise TypeError(f"Each field must be a dataclass or Pydantic field. Got {type(field)}") + # Handle case where func_kwargs is supplied + if name in func_kwargs: + arg_data.default = func_kwargs[name] + arg_data.is_required = False + del func_kwargs[name] + args_data.append(arg_data) + return _TapData(args_data, has_kwargs, known_only) + + +def _tap_data_from_class_or_function( + class_or_function: _ClassOrFunction, func_kwargs: Dict[str, Any], param_to_description: Dict[str, str] +) -> _TapData: + """ + Extract data by inspecting the signature of `class_or_function`. 
+ + Note + ---- + Deletes redundant keys from `func_kwargs` + """ + args_data: List[_ArgData] = [] has_kwargs = False + known_only = False - # Add arguments of class init or function to the Tap object - for param_name, param in sig.parameters.items(): - tap_kwargs = {} + sig = inspect.signature(class_or_function) + for param_name, param in sig.parameters.items(): # Skip **kwargs - if param.kind == Parameter.VAR_KEYWORD: + if param.kind == inspect.Parameter.VAR_KEYWORD: has_kwargs = True known_only = True continue - # Get type of the argument - if param.annotation != Parameter.empty: - # Any type defaults to str (needed for dataclasses where all non-default attributes must have a type) - if param.annotation is Any: - tap._annotations[param.name] = str - # Otherwise, get the type of the argument - else: - tap._annotations[param.name] = param.annotation + if param.annotation != inspect.Parameter.empty: + annotation = param.annotation + else: + annotation = Any - # Get the default or required of the argument if param.name in func_kwargs: - tap_kwargs["default"] = func_kwargs[param.name] + is_required = False + default = func_kwargs[param.name] del func_kwargs[param.name] - elif param.default != Parameter.empty: - tap_kwargs["default"] = param.default + elif param.default != inspect.Parameter.empty: + is_required = False + default = param.default else: - tap_kwargs["required"] = True + is_required = True + default = inspect.Parameter.empty # Can be set to anything. 
It'll be ignored + + arg_data = _ArgData( + name=param_name, + annotation=annotation, + is_required=is_required, + default=default, + description=param_to_description.get(param.name), + ) + args_data.append(arg_data) + return _TapData(args_data, has_kwargs, known_only) + + +def _is_data_model(obj: Union[Type[Any], Any]) -> bool: + return dataclasses.is_dataclass(obj) or _is_pydantic_base_model(obj) + + +def _docstring(class_or_function) -> Docstring: + is_function = not inspect.isclass(class_or_function) + if is_function or _is_pydantic_base_model(class_or_function): + doc = class_or_function.__doc__ + else: + doc = class_or_function.__init__.__doc__ or class_or_function.__doc__ + return parse(doc) + + +def _tap_data(class_or_function: _ClassOrFunction, param_to_description: Dict[str, str], func_kwargs) -> _TapData: + """ + Controls how :class:`_TapData` is extracted from `class_or_function`. + """ + is_pydantic_v1_data_model = _IS_PYDANTIC_V1 and ( + _is_pydantic_base_model(class_or_function) or _is_pydantic_dataclass(class_or_function) + ) + if _is_data_model(class_or_function) and not is_pydantic_v1_data_model: + # Data models from Pydantic v1 don't lend themselves well to _tap_data_from_data_model. + # _tap_data_from_data_model looks at the data model's fields. In Pydantic v1, the field.type_ attribute stores + # the field's annotation/type. But (in Pydantic v1) there's a bug where field.type_ is set to the inner-most + # type of a subscripted type. For example, annotating a field with list[str] causes field.type_ to be str, not + # list[str]. 
To get around this, we'll extract _TapData by looking at the signature of the data model + return _tap_data_from_data_model(class_or_function, func_kwargs, param_to_description) + # TODO: allow passing func_kwargs to a Pydantic BaseModel + return _tap_data_from_class_or_function(class_or_function, func_kwargs, param_to_description) + + +def _tap_class(args_data: Sequence[_ArgData]) -> Type[Tap]: + class ArgParser(Tap): + # Overwriting configure would force a user to remember to call super().configure if they want to overwrite it + # Instead, overwrite _configure + def _configure(self): + for arg_data in args_data: + variable = arg_data.name + self._annotations[variable] = str if arg_data.annotation is Any else arg_data.annotation + self.class_variables[variable] = {"comment": arg_data.description or ""} + if arg_data.is_required: + kwargs = {} + else: + kwargs = dict(required=False, default=arg_data.default) + self.add_argument(f"--{variable}", **kwargs) + + super()._configure() + + return ArgParser + + +def to_tap_class(class_or_function: _ClassOrFunction) -> Type[Tap]: + """Creates a `Tap` class from `class_or_function`. This can be subclassed to add custom argument handling and + instantiated to create a typed argument parser. + + :param class_or_function: The class or function to run with the provided arguments. 
+ """ + docstring = _docstring(class_or_function) + param_to_description = {param.arg_name: param.description for param in docstring.params} + # TODO: add func_kwargs + tap_data = _tap_data(class_or_function, param_to_description, func_kwargs={}) + return _tap_class(tap_data.args_data) - # Get the help string of the argument - if param.name in param_to_description: - tap.class_variables[param.name] = {"comment": param_to_description[param.name]} - # Add the argument to the Tap object - tap._add_argument(f"--{param_name}", **tap_kwargs) +def tapify( + class_or_function: Union[Callable[[InputType], OutputType], Type[OutputType]], + known_only: bool = False, + command_line_args: Optional[List[str]] = None, + explicit_bool: bool = False, + **func_kwargs, +) -> OutputType: + """Tapify initializes a class or runs a function by parsing arguments from the command line. + + :param class_or_function: The class or function to run with the provided arguments. + :param known_only: If true, ignores extra arguments and only parses known arguments. + :param command_line_args: A list of command line style arguments to parse (e.g., ['--arg', 'value']). + If None, arguments are parsed from the command line (default behavior). + :param explicit_bool: Booleans can be specified on the command line as "--arg True" or "--arg False" + rather than "--arg". Additionally, booleans can be specified by prefixes of True and False + with any capitalization as well as 1 or 0. + :param func_kwargs: Additional keyword arguments for the function. These act as default values when + parsing the command line arguments and overwrite the function defaults but + are overwritten by the parsed command line arguments. 
+ """ + # We don't directly call to_tap_class b/c we need tap_data, not just tap_class + docstring = _docstring(class_or_function) + param_to_description = {param.arg_name: param.description for param in docstring.params} + tap_data = _tap_data(class_or_function, param_to_description, func_kwargs) + tap_class = _tap_class(tap_data.args_data) + # Create a Tap object with a description from the docstring of the class or function + description = "\n".join(filter(None, (docstring.short_description, docstring.long_description))) + tap = tap_class(description=description, explicit_bool=explicit_bool) # If any func_kwargs remain, they are not used in the function, so raise an error + known_only = known_only or tap_data.known_only if func_kwargs and not known_only: raise ValueError(f"Unknown keyword arguments: {func_kwargs}") # Parse command line arguments - command_line_args = tap.parse_args(args=command_line_args, known_only=known_only) + command_line_args: Tap = tap.parse_args(args=command_line_args, known_only=known_only) # Get command line arguments as a dictionary command_line_args_dict = command_line_args.as_dict() # Get **kwargs from extra command line arguments - if has_kwargs: + if tap_data.has_kwargs: kwargs = {tap.extra_args[i].lstrip("-"): tap.extra_args[i + 1] for i in range(0, len(tap.extra_args), 2)} - command_line_args_dict.update(kwargs) # Initialize the class or run the function with the parsed arguments diff --git a/tests/test_tapify.py b/tests/test_tapify.py index d121c6d..f2aa387 100644 --- a/tests/test_tapify.py +++ b/tests/test_tapify.py @@ -1,3 +1,7 @@ +""" +Tests `tap.tapify`. Currently requires Pydantic v2. 
+""" + import contextlib from dataclasses import dataclass import io @@ -9,6 +13,14 @@ from tap import tapify +try: + import pydantic +except ModuleNotFoundError: + _IS_PYDANTIC_V1 = None +else: + _IS_PYDANTIC_V1 = pydantic.__version__ < "2.0.0" + + # Suppress prints from SystemExit class DevNull: def write(self, msg): @@ -22,7 +34,7 @@ class Person: def __init__(self, name: str): self.name = name - def __str__(self) -> str: + def __repr__(self) -> str: return f"Person({self.name})" @@ -31,7 +43,7 @@ def __init__(self, problem_1: str, problem_2): self.problem_1 = problem_1 self.problem_2 = problem_2 - def __str__(self) -> str: + def __repr__(self) -> str: return f"Problems({self.problem_1}, {self.problem_2})" @@ -49,7 +61,22 @@ class PieDataclass: def __eq__(self, other: float) -> bool: return other == pie() - for class_or_function in [pie, Pie, PieDataclass]: + if _IS_PYDANTIC_V1 is not None: + + @pydantic.dataclasses.dataclass + class PieDataclassPydantic: + def __eq__(self, other: float) -> bool: + return other == pie() + + class PieModel(pydantic.BaseModel): + def __eq__(self, other: float) -> bool: + return other == pie() + + pydantic_data_models = [PieDataclassPydantic, PieModel] + else: + pydantic_data_models = [] + + for class_or_function in [pie, Pie, PieDataclass] + pydantic_data_models: self.assertEqual(tapify(class_or_function, command_line_args=[]), 3.14) def test_tapify_simple_types(self): @@ -74,7 +101,34 @@ class ConcatDataclass: def __eq__(self, other: str) -> bool: return other == concat(self.a, self.simple, self.test, self.of, self.types) - for class_or_function in [concat, Concat, ConcatDataclass]: + if _IS_PYDANTIC_V1 is not None: + + @pydantic.dataclasses.dataclass + class ConcatDataclassPydantic: + a: int + simple: str + test: float + of: float + types: bool + + def __eq__(self, other: str) -> bool: + return other == concat(self.a, self.simple, self.test, self.of, self.types) + + class ConcatModel(pydantic.BaseModel): + a: int + simple: str + 
test: float + of: float + types: bool + + def __eq__(self, other: str) -> bool: + return other == concat(self.a, self.simple, self.test, self.of, self.types) + + pydantic_data_models = [ConcatDataclassPydantic, ConcatModel] + else: + pydantic_data_models = [] + + for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models: output = tapify( class_or_function, command_line_args=["--a", "1", "--simple", "simple", "--test", "3.14", "--of", "2.718", "--types"], @@ -107,7 +161,37 @@ class ConcatDataclass: def __eq__(self, other: str) -> bool: return other == concat(self.a, self.simple, self.test, self.of, self.types, self.wow) - for class_or_function in [concat, Concat, ConcatDataclass]: + if _IS_PYDANTIC_V1 is not None: + + @pydantic.dataclasses.dataclass + class ConcatDataclassPydantic: + a: int + simple: str + test: float + of: float = -0.3 + types: bool = False + wow: str = "abc" + + def __eq__(self, other: str) -> bool: + return other == concat(self.a, self.simple, self.test, self.of, self.types, self.wow) + + class ConcatModel(pydantic.BaseModel): + a: int + simple: str + test: float + of: float = -0.3 + types: bool = False + wow: str = "abc" + + def __eq__(self, other: str) -> bool: + return other == concat(self.a, self.simple, self.test, self.of, self.types, self.wow) + + pydantic_data_models = [ConcatDataclassPydantic, ConcatModel] + else: + pydantic_data_models = [] + + for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models: + print(class_or_function.__name__) output = tapify( class_or_function, command_line_args=["--a", "1", "--simple", "simple", "--test", "3.14", "--types", "--wow", "wee"], @@ -135,7 +219,38 @@ class ConcatDataclass: def __eq__(self, other: str) -> bool: return other == concat(self.complexity, self.requires, self.intelligence) - for class_or_function in [concat, Concat, ConcatDataclass]: + if _IS_PYDANTIC_V1 is not None: + + 
@pydantic.dataclasses.dataclass(config=dict(arbitrary_types_allowed=True)) # for Person + class ConcatDataclassPydantic: + complexity: List[str] + requires: Tuple[int, int] + intelligence: Person + + def __eq__(self, other: str) -> bool: + return other == concat(self.complexity, self.requires, self.intelligence) + + class ConcatModel(pydantic.BaseModel): + if _IS_PYDANTIC_V1: + + class Config: + arbitrary_types_allowed = True # for Person + + else: + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) # for Person + + complexity: List[str] + requires: Tuple[int, int] + intelligence: Person + + def __eq__(self, other: str) -> bool: + return other == concat(self.complexity, self.requires, self.intelligence) + + pydantic_data_models = [ConcatDataclassPydantic, ConcatModel] + else: + pydantic_data_models = [] + + for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models: output = tapify( class_or_function, command_line_args=[ @@ -176,10 +291,51 @@ class ConcatDataclass: def __eq__(self, other: str) -> bool: return other == concat(self.complexity, self.requires, self.intelligence) - for class_or_function in [concat, Concat, ConcatDataclass]: + if _IS_PYDANTIC_V1 is not None: + + @pydantic.dataclasses.dataclass(config=dict(arbitrary_types_allowed=True)) # for Person + class ConcatDataclassPydantic: + complexity: list[int] + requires: tuple[int, int] + intelligence: Person + + def __eq__(self, other: str) -> bool: + return other == concat(self.complexity, self.requires, self.intelligence) + + class ConcatModel(pydantic.BaseModel): + if _IS_PYDANTIC_V1: + + class Config: + arbitrary_types_allowed = True # for Person + + else: + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) # for Person + + complexity: list[int] + requires: tuple[int, int] + intelligence: Person + + def __eq__(self, other: str) -> bool: + return other == concat(self.complexity, self.requires, self.intelligence) + + pydantic_data_models = 
[ConcatDataclassPydantic, ConcatModel] + else: + pydantic_data_models = [] + + for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models: output = tapify( class_or_function, - command_line_args=["--complexity", "1", "2", "3", "--requires", "1", "0", "--intelligence", "jesse",], + command_line_args=[ + "--complexity", + "1", + "2", + "3", + "--requires", + "1", + "0", + "--intelligence", + "jesse", + ], ) self.assertEqual(output, "1 2 3 1 0 Person(jesse)") @@ -225,7 +381,42 @@ class ConcatDataclass: def __eq__(self, other: str) -> bool: return other == concat(self.complexity, self.requires, self.intelligence, self.maybe, self.possibly) - for class_or_function in [concat, Concat, ConcatDataclass]: + if _IS_PYDANTIC_V1 is not None: + + @pydantic.dataclasses.dataclass(config=dict(arbitrary_types_allowed=True)) # for Person + class ConcatDataclassPydantic: + complexity: List[str] + requires: Tuple[int, int] = (2, 5) + intelligence: Person = Person("kyle") + maybe: Optional[str] = None + possibly: Optional[str] = None + + def __eq__(self, other: str) -> bool: + return other == concat(self.complexity, self.requires, self.intelligence, self.maybe, self.possibly) + + class ConcatModel(pydantic.BaseModel): + if _IS_PYDANTIC_V1: + + class Config: + arbitrary_types_allowed = True # for Person + + else: + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) # for Person + + complexity: List[str] + requires: Tuple[int, int] = (2, 5) + intelligence: Person = Person("kyle") + maybe: Optional[str] = None + possibly: Optional[str] = None + + def __eq__(self, other: str) -> bool: + return other == concat(self.complexity, self.requires, self.intelligence, self.maybe, self.possibly) + + pydantic_data_models = [ConcatDataclassPydantic, ConcatModel] + else: + pydantic_data_models = [] + + for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models: output = tapify( class_or_function, command_line_args=[ @@ -263,7 +454,30 @@ class 
ConcatDataclass: def __eq__(self, other: str) -> bool: return other == concat(self.so, self.many, self.args) - for class_or_function in [concat, Concat, ConcatDataclass]: + if _IS_PYDANTIC_V1 is not None: + + @pydantic.dataclasses.dataclass + class ConcatDataclassPydantic: + so: int + many: float + args: str + + def __eq__(self, other: str) -> bool: + return other == concat(self.so, self.many, self.args) + + class ConcatModel(pydantic.BaseModel): + so: int + many: float + args: str + + def __eq__(self, other: str) -> bool: + return other == concat(self.so, self.many, self.args) + + pydantic_data_models = [ConcatDataclassPydantic, ConcatModel] + else: + pydantic_data_models = [] + + for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models: with self.assertRaises(SystemExit): tapify(class_or_function, command_line_args=["--so", "23", "--many", "9.3"]) @@ -286,7 +500,28 @@ class ConcatDataclass: def __eq__(self, other: str) -> bool: return other == concat(self.so, self.few) - for class_or_function in [concat, Concat, ConcatDataclass]: + if _IS_PYDANTIC_V1 is not None: + + @pydantic.dataclasses.dataclass + class ConcatDataclassPydantic: + so: int + few: float + + def __eq__(self, other: str) -> bool: + return other == concat(self.so, self.few) + + class ConcatModel(pydantic.BaseModel): + so: int + few: float + + def __eq__(self, other: str) -> bool: + return other == concat(self.so, self.few) + + pydantic_data_models = [ConcatDataclassPydantic, ConcatModel] + else: + pydantic_data_models = [] + + for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models: with self.assertRaises(SystemExit): tapify(class_or_function, command_line_args=["--so", "23", "--few", "9.3", "--args", "wow"]) @@ -309,7 +544,28 @@ class ConcatDataclass: def __eq__(self, other: str) -> bool: return other == concat(self.so, self.few) - for class_or_function in [concat, Concat, ConcatDataclass]: + if _IS_PYDANTIC_V1 is not None: + + 
@pydantic.dataclasses.dataclass + class ConcatDataclassPydantic: + so: int + few: float + + def __eq__(self, other: str) -> bool: + return other == concat(self.so, self.few) + + class ConcatModel(pydantic.BaseModel): + so: int + few: float + + def __eq__(self, other: str) -> bool: + return other == concat(self.so, self.few) + + pydantic_data_models = [ConcatDataclassPydantic, ConcatModel] + else: + pydantic_data_models = [] + + for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models: output = tapify( class_or_function, command_line_args=["--so", "23", "--few", "9.3", "--args", "wow"], known_only=True ) @@ -339,10 +595,46 @@ class ConcatDataclass: def __eq__(self, other: str) -> bool: return other == concat(self.i, self.like, self.k, self.w, self.args, self.always) - for class_or_function in [concat, Concat, ConcatDataclass]: + if _IS_PYDANTIC_V1 is not None: + + @pydantic.dataclasses.dataclass + class ConcatDataclassPydantic: + i: int + like: float + k: int + w: str = "w" + args: str = "argy" + always: bool = False + + def __eq__(self, other: str) -> bool: + return other == concat(self.i, self.like, self.k, self.w, self.args, self.always) + + class ConcatModel(pydantic.BaseModel): + i: int + like: float + k: int + w: str = "w" + args: str = "argy" + always: bool = False + + def __eq__(self, other: str) -> bool: + return other == concat(self.i, self.like, self.k, self.w, self.args, self.always) + + pydantic_data_models = [ConcatDataclassPydantic, ConcatModel] + else: + pydantic_data_models = [] + + for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models: output = tapify( class_or_function, - command_line_args=["--i", "23", "--args", "wow", "--like", "3.03",], + command_line_args=[ + "--i", + "23", + "--args", + "wow", + "--like", + "3.03", + ], known_only=True, w="hello", k=5, @@ -375,11 +667,47 @@ class ConcatDataclass: def __eq__(self, other: str) -> bool: return other == concat(self.i, self.like, self.k, self.w, 
self.args, self.always) - for class_or_function in [concat, Concat, ConcatDataclass]: + if _IS_PYDANTIC_V1 is not None: + + @pydantic.dataclasses.dataclass + class ConcatDataclassPydantic: + i: int + like: float + k: int + w: str = "w" + args: str = "argy" + always: bool = False + + def __eq__(self, other: str) -> bool: + return other == concat(self.i, self.like, self.k, self.w, self.args, self.always) + + class ConcatModel(pydantic.BaseModel): + i: int + like: float + k: int + w: str = "w" + args: str = "argy" + always: bool = False + + def __eq__(self, other: str) -> bool: + return other == concat(self.i, self.like, self.k, self.w, self.args, self.always) + + pydantic_data_models = [ConcatDataclassPydantic, ConcatModel] + else: + pydantic_data_models = [] + + for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models: with self.assertRaises(ValueError): tapify( class_or_function, - command_line_args=["--i", "23", "--args", "wow", "--like", "3.03",], + command_line_args=[ + "--i", + "23", + "--args", + "wow", + "--like", + "3.03", + ], w="hello", k=5, like=3.4, @@ -404,7 +732,34 @@ class ConcatDataclass: def __eq__(self, other: str) -> bool: return other == concat(self.problems) - for class_or_function in [concat, Concat, ConcatDataclass]: + if _IS_PYDANTIC_V1 is not None: + + @pydantic.dataclasses.dataclass(config=dict(arbitrary_types_allowed=True)) + class ConcatDataclassPydantic: + problems: Problems + + def __eq__(self, other: str) -> bool: + return other == concat(self.problems) + + class ConcatModel(pydantic.BaseModel): + if _IS_PYDANTIC_V1: + + class Config: + arbitrary_types_allowed = True # for Problems + + else: + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) # for Problems + + problems: Problems + + def __eq__(self, other: str) -> bool: + return other == concat(self.problems) + + pydantic_data_models = [ConcatDataclassPydantic, ConcatModel] + else: + pydantic_data_models = [] + + for class_or_function in 
[concat, Concat, ConcatDataclass] + pydantic_data_models: output = tapify(class_or_function, command_line_args=[], problems=Problems("oh", "no!")) self.assertEqual(output, "Problems(oh, no!)") @@ -448,7 +803,40 @@ def __eq__(self, other: str) -> bool: self.untyped_1, self.typed_1, self.untyped_2, self.typed_2, self.untyped_3, self.typed_3 ) - for class_or_function in [concat, Concat, ConcatDataclass]: + if _IS_PYDANTIC_V1 is not None: + + @pydantic.dataclasses.dataclass + class ConcatDataclassPydantic: + untyped_1: Any + typed_1: int + untyped_2: Any = 5 + typed_2: str = "now" + untyped_3: Any = "hi" + typed_3: bool = False + + def __eq__(self, other: str) -> bool: + return other == concat( + self.untyped_1, self.typed_1, self.untyped_2, self.typed_2, self.untyped_3, self.typed_3 + ) + + class ConcatModel(pydantic.BaseModel): + untyped_1: Any + typed_1: int + untyped_2: Any = 5 + typed_2: str = "now" + untyped_3: Any = "hi" + typed_3: bool = False + + def __eq__(self, other: str) -> bool: + return other == concat( + self.untyped_1, self.typed_1, self.untyped_2, self.typed_2, self.untyped_3, self.typed_3 + ) + + pydantic_data_models = [ConcatDataclassPydantic, ConcatModel] + else: + pydantic_data_models = [] + + for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models: output = tapify( class_or_function, command_line_args=[ @@ -491,7 +879,33 @@ class ConcatDataclass: def __eq__(self, other: str) -> bool: return other == concat(self.a, self.b, self.c) - for class_or_function in [concat, Concat, ConcatDataclass]: + if _IS_PYDANTIC_V1 is not None: + + @pydantic.dataclasses.dataclass + class ConcatDataclassPydantic: + """Concatenate three numbers.""" + + a: int + b: int + c: int + + def __eq__(self, other: str) -> bool: + return other == concat(self.a, self.b, self.c) + + class ConcatModel(pydantic.BaseModel): + """Concatenate three numbers.""" + + a: int + b: int + c: int + + def __eq__(self, other: str) -> bool: + return other == 
concat(self.a, self.b, self.c) + + pydantic_data_models = [ConcatDataclassPydantic, ConcatModel] + else: + pydantic_data_models = [] + for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models: output_1 = tapify(class_or_function, command_line_args=["--a", "1", "--b", "2", "--c", "3"]) output_2 = tapify(class_or_function, command_line_args=["--a", "4", "--b", "5", "--c", "6"]) @@ -555,16 +969,54 @@ class ConcatDataclass: def __eq__(self, other: str) -> bool: return other == concat(self.a, self.b, self.c) - for class_or_function in [concat, Concat, ConcatDataclass]: + if _IS_PYDANTIC_V1 is not None: + + @pydantic.dataclasses.dataclass + class ConcatDataclassPydantic: + """Concatenate three numbers. + + :param a: The first number. + :param b: The second number. + :param c: The third number. + """ + + a: int + b: int + c: int + + def __eq__(self, other: str) -> bool: + return other == concat(self.a, self.b, self.c) + + class ConcatModel(pydantic.BaseModel): + """Concatenate three numbers. + + :param a: The first number. + :param b: The second number. + :param c: The third number. 
+ """ + + a: int + b: int + c: int + + def __eq__(self, other: str) -> bool: + return other == concat(self.a, self.b, self.c) + + pydantic_data_models = [ConcatDataclassPydantic, ConcatModel] + else: + pydantic_data_models = [] + + for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models: f = io.StringIO() with contextlib.redirect_stdout(f): with self.assertRaises(SystemExit): tapify(class_or_function, command_line_args=["-h"]) - self.assertIn("Concatenate three numbers.", f.getvalue()) - self.assertIn("--a A (int, required) The first number.", f.getvalue()) - self.assertIn("--b B (int, required) The second number.", f.getvalue()) - self.assertIn("--c C (int, required) The third number.", f.getvalue()) + stdout = f.getvalue() + self.assertIn("Concatenate three numbers.", stdout) + self.assertIn("--a A (int, required) The first number.", stdout) + self.assertIn("--b B (int, required) The second number.", stdout) + self.assertIn("--c C (int, required) The third number.", stdout) class TestTapifyExplicitBool(unittest.TestCase): @@ -630,7 +1082,45 @@ def concat(a: int, b: int = 2, **kwargs) -> str: """ return f'{a}_{b}_{"-".join(f"{k}={v}" for k, v in kwargs.items())}' - self.concat_function = concat + if _IS_PYDANTIC_V1 is not None: + + class ConcatModel(pydantic.BaseModel): + """Concatenate three numbers. + + :param a: The first number. + :param b: The second number. 
+ """ + + if _IS_PYDANTIC_V1: + + class Config: + extra = pydantic.Extra.allow # by default, pydantic ignores extra arguments + + else: + model_config = pydantic.ConfigDict(extra="allow") # by default, pydantic ignores extra arguments + + a: int + b: int = 2 + + def __eq__(self, other: str) -> bool: + if _IS_PYDANTIC_V1: + # Get the kwarg names in the correct order by parsing other + kwargs_str = other.split("_")[-1] + if not kwargs_str: + kwarg_names = [] + else: + kwarg_names = [kv_str.split("=")[0] for kv_str in kwargs_str.split("-")] + kwargs = {name: getattr(self, name) for name in kwarg_names} + # Need to explicitly check that the extra names from other are identical to what's stored in self + # Checking other == concat(...) isn't sufficient b/c self could have more extra fields + assert set(kwarg_names) == set(self.__dict__.keys()) - set(self.__fields__.keys()) + else: + kwargs = self.model_extra + return other == concat(self.a, self.b, **kwargs) + + pydantic_data_models = [ConcatModel] + else: + pydantic_data_models = [] class Concat: def __init__(self, a: int, b: int = 2, **kwargs: Dict[str, str]): @@ -646,22 +1136,22 @@ def __init__(self, a: int, b: int = 2, **kwargs: Dict[str, str]): def __eq__(self, other: str) -> bool: return other == concat(self.a, self.b, **self.kwargs) - self.concat_class = Concat + self.class_or_functions = [concat, Concat] + pydantic_data_models def test_tapify_empty_kwargs(self) -> None: - for class_or_function in [self.concat_function, self.concat_class]: + for class_or_function in self.class_or_functions: output = tapify(class_or_function, command_line_args=["--a", "1"]) self.assertEqual(output, "1_2_") def test_tapify_has_kwargs(self) -> None: - for class_or_function in [self.concat_function, self.concat_class]: + for class_or_function in self.class_or_functions: output = tapify(class_or_function, command_line_args=["--a", "1", "--c", "3", "--d", "4"]) self.assertEqual(output, "1_2_c=3-d=4") def 
test_tapify_has_kwargs_replace_default(self) -> None: - for class_or_function in [self.concat_function, self.concat_class]: + for class_or_function in self.class_or_functions: output = tapify(class_or_function, command_line_args=["--a", "1", "--c", "3", "--b", "5", "--d", "4"]) self.assertEqual(output, "1_5_c=3-d=4") diff --git a/tests/test_to_tap_class.py b/tests/test_to_tap_class.py new file mode 100644 index 0000000..730e4ed --- /dev/null +++ b/tests/test_to_tap_class.py @@ -0,0 +1,553 @@ +""" +Tests `tap.to_tap_class`. +""" + +from contextlib import redirect_stdout, redirect_stderr +import dataclasses +import io +import re +import sys +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union + +import pytest + +from tap import to_tap_class, Tap +from tap.utils import type_to_str + + +try: + import pydantic +except ModuleNotFoundError: + _IS_PYDANTIC_V1 = None +else: + _IS_PYDANTIC_V1 = pydantic.__version__ < "2.0.0" + + +# To properly test the help message, we need to know how argparse formats it. It changed from 3.8 -> 3.9 -> 3.10 +_OPTIONS_TITLE = "options" if not sys.version_info < (3, 10) else "optional arguments" +_ARG_LIST_DOTS = "..." if not sys.version_info < (3, 9) else "[ARG_LIST ...]" + + +@dataclasses.dataclass +class _Args: + """ + These are the arguments which every type of class or function must contain. + """ + + arg_int: int = dataclasses.field(metadata=dict(description="some integer")) + arg_bool: bool = True + arg_list: Optional[List[str]] = dataclasses.field(default=None, metadata=dict(description="some list of strings")) + + +def _monkeypatch_eq(cls): + """ + Monkey-patches `cls.__eq__` to check that the attribute values are equal to a dataclass representation of them. 
+ """ + + def _equality(self, other: _Args) -> bool: + return _Args(self.arg_int, arg_bool=self.arg_bool, arg_list=self.arg_list) == other + + cls.__eq__ = _equality + return cls + + +# Define a few different classes or functions which all take the same arguments (same by name, annotation, and default +# if not required) + + +def function(arg_int: int, arg_bool: bool = True, arg_list: Optional[List[str]] = None) -> _Args: + """ + :param arg_int: some integer + :param arg_list: some list of strings + """ + return _Args(arg_int, arg_bool=arg_bool, arg_list=arg_list) + + +@_monkeypatch_eq +class Class: + def __init__(self, arg_int: int, arg_bool: bool = True, arg_list: Optional[List[str]] = None): + """ + :param arg_int: some integer + :param arg_list: some list of strings + """ + self.arg_int = arg_int + self.arg_bool = arg_bool + self.arg_list = arg_list + + +DataclassBuiltin = _Args + + +if _IS_PYDANTIC_V1 is None: + pass # will raise NameError if attempting to use DataclassPydantic or Model later +elif _IS_PYDANTIC_V1: + # For Pydantic v1 data models, we rely on the docstring to get descriptions + + @_monkeypatch_eq + @pydantic.dataclasses.dataclass + class DataclassPydantic: + """ + Dataclass (pydantic v1) + + :param arg_int: some integer + :param arg_list: some list of strings + """ + + arg_int: int + arg_bool: bool = True + arg_list: Optional[List[str]] = None + + @_monkeypatch_eq + class Model(pydantic.BaseModel): + """ + Pydantic model (pydantic v1) + + :param arg_int: some integer + :param arg_list: some list of strings + """ + + arg_int: int + arg_bool: bool = True + arg_list: Optional[List[str]] = None + +else: + # For pydantic v2 data models, we check the docstring and Field for the description + + @_monkeypatch_eq + @pydantic.dataclasses.dataclass + class DataclassPydantic: + """ + Dataclass (pydantic) + + :param arg_list: some list of strings + """ + + # Mixing field types should be ok + arg_int: int = pydantic.dataclasses.Field(description="some 
integer") + arg_bool: bool = dataclasses.field(default=True) + arg_list: Optional[List[str]] = pydantic.Field(default=None) + + @_monkeypatch_eq + class Model(pydantic.BaseModel): + """ + Pydantic model + + :param arg_int: some integer + """ + + # Mixing field types should be ok + arg_int: int + arg_bool: bool = dataclasses.field(default=True) + arg_list: Optional[List[str]] = pydantic.dataclasses.Field(default=None, description="some list of strings") + + +@pytest.fixture( + scope="module", + params=[ + function, + Class, + DataclassBuiltin, + DataclassBuiltin( + 1, arg_bool=False, arg_list=["these", "values", "don't", "matter"] + ), # to_tap_class also works on instances of data models. It ignores the attribute values + ] + + ([] if _IS_PYDANTIC_V1 is None else [DataclassPydantic, Model]), + # NOTE: instances of DataclassPydantic and Model can be tested for pydantic v2 but not v1 +) +def class_or_function_(request: pytest.FixtureRequest): + """ + Parametrized class_or_function. + """ + return request.param + + +# Define some functions which take a class or function and call `tap.to_tap_class` on it to create a `tap.Tap` +# subclass (class, not instance) + + +def subclasser_simple(class_or_function: Any) -> Type[Tap]: + """ + Plain subclass, does nothing extra. + """ + return to_tap_class(class_or_function) + + +def subclasser_complex(class_or_function): + """ + It's conceivable that someone has a data model, but they want to add more arguments or handling when running a + script. + """ + + def to_number(string: str) -> Union[float, int]: + return float(string) if "." 
in string else int(string) + + class TapSubclass(to_tap_class(class_or_function)): + # You can supply additional arguments here + argument_with_really_long_name: Union[float, int] = 3 + "This argument has a long name and will be aliased with a short one" + + def configure(self) -> None: + # You can still add special argument behavior + self.add_argument("-arg", "--argument_with_really_long_name", type=to_number) + + def process_args(self) -> None: + # You can still validate and modify arguments + if self.argument_with_really_long_name > 4: + raise ValueError("argument_with_really_long_name cannot be > 4") + + # No auto-complete (and other niceties) for the super class attributes b/c this is a dynamic subclass. Sorry + if self.arg_bool and self.arg_list is not None: + self.arg_list.append("processed") + + return TapSubclass + + +def subclasser_subparser(class_or_function): + class SubparserA(Tap): + bar: int # bar help + + class SubparserB(Tap): + baz: Literal["X", "Y", "Z"] # baz help + + class TapSubclass(to_tap_class(class_or_function)): + foo: bool = False # foo help + + def configure(self): + self.add_subparsers(help="sub-command help") + self.add_subparser("a", SubparserA, help="a help", description="Description (a)") + self.add_subparser("b", SubparserB, help="b help") + + return TapSubclass + + +# Test that the subclasser parses the args correctly or raises the correct error. +# The subclassers are tested separately b/c the parametrization of args_string_and_arg_to_expected_value depends on the +# subclasser. +# First, some helper functions. 
+ + +def _test_raises_system_exit(tap: Tap, args_string: str) -> str: + is_help = ( + args_string.endswith("-h") + or args_string.endswith("--help") + or " -h " in args_string + or " --help " in args_string + ) + f = io.StringIO() + with redirect_stdout(f) if is_help else redirect_stderr(f): + with pytest.raises(SystemExit): + tap.parse_args(args_string.split()) + + return f.getvalue() + + +def _test_subclasser( + subclasser: Callable[[Any], Type[Tap]], + class_or_function: Any, + args_string_and_arg_to_expected_value: Tuple[str, Union[Dict[str, Any], BaseException]], + test_call: bool = True, +): + """ + Tests that the `subclasser` converts `class_or_function` to a `Tap` class which parses the argument string + correctly. + + Setting `test_call=True` additionally tests that calling the `class_or_function` on the parsed arguments works. + """ + args_string, arg_to_expected_value = args_string_and_arg_to_expected_value + TapSubclass = subclasser(class_or_function) + assert issubclass(TapSubclass, Tap) + tap = TapSubclass(description="Script description") + + if isinstance(arg_to_expected_value, SystemExit): + stderr = _test_raises_system_exit(tap, args_string) + assert str(arg_to_expected_value) in stderr + elif isinstance(arg_to_expected_value, BaseException): + expected_exception = arg_to_expected_value.__class__ + expected_error_message = str(arg_to_expected_value) or None + with pytest.raises(expected_exception=expected_exception, match=expected_error_message): + args = tap.parse_args(args_string.split()) + else: + # args_string is a valid argument combo + # Test that parsing works correctly + args = tap.parse_args(args_string.split()) + assert arg_to_expected_value == args.as_dict() + if test_call and callable(class_or_function): + result = class_or_function(**args.as_dict()) + assert result == _Args(**arg_to_expected_value) + + +def _test_subclasser_message( + subclasser: Callable[[Any], Type[Tap]], + class_or_function: Any, + message_expected: str, + 
description: str = "Script description", + args_string: str = "-h", +): + """ + Tests that:: + + subclasser(class_or_function)(description=description).parse_args(args_string.split()) + + outputs `message_expected` to stdout, ignoring differences in whitespaces/newlines/tabs. + """ + + def replace_whitespace(string: str) -> str: + return re.sub(r"\s+", " ", string).strip() # FYI this line was written by an LLM + + TapSubclass = subclasser(class_or_function) + tap = TapSubclass(description=description) + message = _test_raises_system_exit(tap, args_string) + # Standardize to ignore trivial differences due to terminal settings + assert replace_whitespace(message) == replace_whitespace(message_expected) + + +# Test subclasser_simple + + +@pytest.mark.parametrize( + "args_string_and_arg_to_expected_value", + [ + ( + "--arg_int 1 --arg_list x y z", + {"arg_int": 1, "arg_bool": True, "arg_list": ["x", "y", "z"]}, + ), + ( + "--arg_int 1 --arg_bool", + {"arg_int": 1, "arg_bool": False, "arg_list": None}, + ), + # The rest are invalid argument combos, as indicated by the 2nd elt being a BaseException instance + ( + "--arg_list x y z --arg_bool", + SystemExit("error: the following arguments are required: --arg_int"), + ), + ( + "--arg_int not_an_int --arg_list x y z --arg_bool", + SystemExit("error: argument --arg_int: invalid int value: 'not_an_int'"), + ), + ], +) +def test_subclasser_simple( + class_or_function_: Any, args_string_and_arg_to_expected_value: Tuple[str, Union[Dict[str, Any], BaseException]] +): + _test_subclasser(subclasser_simple, class_or_function_, args_string_and_arg_to_expected_value) + + +# @pytest.mark.skipif(sys.version_info < (3, 10), reason="argparse is different. 
Need to fix help_message_expected") +def test_subclasser_simple_help_message(class_or_function_: Any): + description = "Script description" + help_message_expected = f""" +usage: pytest --arg_int ARG_INT [--arg_bool] [--arg_list [ARG_LIST {_ARG_LIST_DOTS}]] [-h] + +{description} + +{_OPTIONS_TITLE}: + --arg_int ARG_INT (int, required) some integer + --arg_bool (bool, default=True) + --arg_list [ARG_LIST {_ARG_LIST_DOTS}] + ({type_to_str(Optional[List[str]])}, default=None) some list of strings + -h, --help show this help message and exit +""".lstrip( + "\n" + ) + _test_subclasser_message(subclasser_simple, class_or_function_, help_message_expected, description=description) + + +# Test subclasser_complex + + +@pytest.mark.parametrize( + "args_string_and_arg_to_expected_value", + [ + ( + "--arg_int 1 --arg_list x y z", + { + "arg_int": 1, + "arg_bool": True, + "arg_list": ["x", "y", "z", "processed"], + "argument_with_really_long_name": 3, + }, + ), + ( + "--arg_int 1 --arg_list x y z -arg 2", + { + "arg_int": 1, + "arg_bool": True, + "arg_list": ["x", "y", "z", "processed"], + "argument_with_really_long_name": 2, + }, + ), + ( + "--arg_int 1 --arg_bool --argument_with_really_long_name 2.3", + { + "arg_int": 1, + "arg_bool": False, + "arg_list": None, + "argument_with_really_long_name": 2.3, + }, + ), + # The rest are invalid argument combos, as indicated by the 2nd elt being a BaseException instance + ( + "--arg_list x y z --arg_bool", + SystemExit("error: the following arguments are required: --arg_int"), + ), + ( + "--arg_int 1 --arg_list x y z -arg not_a_float_or_int", + SystemExit( + "error: argument -arg/--argument_with_really_long_name: invalid to_number value: 'not_a_float_or_int'" + ), + ), + ( + "--arg_int 1 --arg_list x y z -arg 5", # Wrong value arg (aliases argument_with_really_long_name) + ValueError("argument_with_really_long_name cannot be > 4"), + ), + ], +) +def test_subclasser_complex( + class_or_function_: Any, 
args_string_and_arg_to_expected_value: Tuple[str, Union[Dict[str, Any], BaseException]] +): + # Currently setting test_call=False b/c all data models except the pydantic Model don't accept extra args + _test_subclasser(subclasser_complex, class_or_function_, args_string_and_arg_to_expected_value, test_call=False) + + +# @pytest.mark.skipif(sys.version_info < (3, 10), reason="argparse is different. Need to fix help_message_expected") +def test_subclasser_complex_help_message(class_or_function_: Any): + description = "Script description" + help_message_expected = f""" +usage: pytest [-arg ARGUMENT_WITH_REALLY_LONG_NAME] --arg_int ARG_INT [--arg_bool] [--arg_list [ARG_LIST {_ARG_LIST_DOTS}]] [-h] + +{description} + +{_OPTIONS_TITLE}: + -arg ARGUMENT_WITH_REALLY_LONG_NAME, --argument_with_really_long_name ARGUMENT_WITH_REALLY_LONG_NAME + (Union[float, int], default=3) This argument has a long name and will be aliased with a short one + --arg_int ARG_INT (int, required) some integer + --arg_bool (bool, default=True) + --arg_list [ARG_LIST {_ARG_LIST_DOTS}] + ({type_to_str(Optional[List[str]])}, default=None) some list of strings + -h, --help show this help message and exit +""".lstrip( + "\n" + ) + _test_subclasser_message(subclasser_complex, class_or_function_, help_message_expected, description=description) + + +# Test subclasser_subparser + + +@pytest.mark.parametrize( + "args_string_and_arg_to_expected_value", + [ + ( + "--arg_int 1", + {"arg_int": 1, "arg_bool": True, "arg_list": None, "foo": False}, + ), + ( + "--arg_int 1 a --bar 2", + {"arg_int": 1, "arg_bool": True, "arg_list": None, "bar": 2, "foo": False}, + ), + ( + "--arg_int 1 --foo a --bar 2", + {"arg_int": 1, "arg_bool": True, "arg_list": None, "bar": 2, "foo": True}, + ), + ( + "--arg_int 1 b --baz X", + {"arg_int": 1, "arg_bool": True, "arg_list": None, "baz": "X", "foo": False}, + ), + ( + "--foo --arg_bool --arg_list x y z --arg_int 1 b --baz Y", + {"arg_int": 1, "arg_bool": False, "arg_list": ["x", 
"y", "z"], "baz": "Y", "foo": True}, + ), + # The rest are invalid argument combos, as indicated by the 2nd elt being a BaseException instance + ( + "a --bar 1", + SystemExit("error: the following arguments are required: --arg_int"), + ), + ( + "--arg_int not_an_int a --bar 1", + SystemExit("error: argument --arg_int: invalid int value: 'not_an_int'"), + ), + ( + "--arg_int 1 --baz X --foo b", + SystemExit( + "error: argument {a,b}: invalid choice: 'X' (choose from 'a', 'b')" + if sys.version_info >= (3, 9) + else "error: invalid choice: 'X' (choose from 'a', 'b')" + ), + ), + ( + "--arg_int 1 b --baz X --foo", + SystemExit("error: unrecognized arguments: --foo"), + ), + ( + "--arg_int 1 --foo b --baz A", + SystemExit("""error: argument --baz: Value for variable "baz" must be one of ['X', 'Y', 'Z']."""), + ), + ], +) +def test_subclasser_subparser( + class_or_function_: Any, args_string_and_arg_to_expected_value: Tuple[str, Union[Dict[str, Any], BaseException]] +): + # Currently setting test_call=False b/c all data models except the pydantic Model don't accept extra args + _test_subclasser(subclasser_subparser, class_or_function_, args_string_and_arg_to_expected_value, test_call=False) + + +# @pytest.mark.skipif(sys.version_info < (3, 10), reason="argparse is different. Need to fix help_message_expected") +@pytest.mark.parametrize( + "args_string_and_description_and_expected_message", + [ + ( + "-h", + "Script description", + # foo help likely missing b/c class nesting. In a demo in a Python 3.8 env, foo help appears in -h + f""" +usage: pytest [--foo] --arg_int ARG_INT [--arg_bool] [--arg_list [ARG_LIST {_ARG_LIST_DOTS}]] [-h] {{a,b}} ... 
+ +Script description + +positional arguments: + {{a,b}} sub-command help + a a help + b b help + +{_OPTIONS_TITLE}: + --foo (bool, default=False) {'' if sys.version_info < (3, 9) else 'foo help'} + --arg_int ARG_INT (int, required) some integer + --arg_bool (bool, default=True) + --arg_list [ARG_LIST {_ARG_LIST_DOTS}] + ({type_to_str(Optional[List[str]])}, default=None) some list of strings + -h, --help show this help message and exit +""", + ), + ( + "a -h", + "Description (a)", + f""" +usage: pytest a --bar BAR [-h] + +Description (a) + +{_OPTIONS_TITLE}: + --bar BAR (int, required) bar help + -h, --help show this help message and exit +""", + ), + ( + "b -h", + "", + f""" +usage: pytest b --baz {{X,Y,Z}} [-h] + +{_OPTIONS_TITLE}: + --baz {{X,Y,Z}} (Literal['X', 'Y', 'Z'], required) baz help + -h, --help show this help message and exit +""", + ), + ], +) +def test_subclasser_subparser_help_message( + class_or_function_: Any, args_string_and_description_and_expected_message: Tuple[str, str] +): + args_string, description, expected_message = args_string_and_description_and_expected_message + _test_subclasser_message( + subclasser_subparser, class_or_function_, expected_message, description=description, args_string=args_string + )