diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 3470872..60d40ff 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -43,6 +43,14 @@ jobs:
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- - name: Test with pytest
+ - name: Test without pydantic
run: |
pytest
+ - name: Test with pydantic v1
+ run: |
+ python -m pip install "pydantic < 2"
+ pytest
+ - name: Test with pydantic v2
+ run: |
+ python -m pip install "pydantic >= 2"
+ pytest
diff --git a/README.md b/README.md
index 997f9a3..1ad6851 100644
--- a/README.md
+++ b/README.md
@@ -666,7 +666,7 @@ from tap import tapify
class Squarer:
"""Squarer with a number to square.
- :param num: The number to square.
+ :param num: The number to square.
"""
num: float
@@ -681,6 +681,94 @@ if __name__ == '__main__':
Running `python square_dataclass.py --num -1` prints `The square of your number is 1.0.`.
+
+Argument descriptions
+
+For dataclasses, the argument's description (which is displayed in the `-h` help message) can either be specified in the
+class docstring or the field's description in `metadata`. If both are specified, the description from the docstring is
+used. In the example below, the description is provided in `metadata`.
+
+```python
+# square_dataclass.py
+from dataclasses import dataclass, field
+
+from tap import tapify
+
+@dataclass
+class Squarer:
+ """Squarer with a number to square.
+ """
+ num: float = field(metadata={"description": "The number to square."})
+
+ def get_square(self) -> float:
+ """Get the square of the number."""
+ return self.num ** 2
+
+if __name__ == '__main__':
+ squarer = tapify(Squarer)
+ print(f'The square of your number is {squarer.get_square()}.')
+```
+
+
+
+#### Pydantic
+
+Pydantic [Models](https://docs.pydantic.dev/latest/concepts/models/) and
+[dataclasses](https://docs.pydantic.dev/latest/concepts/dataclasses/) can be `tapify`d.
+
+```python
+# square_pydantic.py
+from pydantic import BaseModel, Field
+
+from tap import tapify
+
+class Squarer(BaseModel):
+ """Squarer with a number to square.
+ """
+ num: float = Field(description="The number to square.")
+
+ def get_square(self) -> float:
+ """Get the square of the number."""
+ return self.num ** 2
+
+if __name__ == '__main__':
+ squarer = tapify(Squarer)
+ print(f'The square of your number is {squarer.get_square()}.')
+```
+
+
+Argument descriptions
+
+For Pydantic v2 models and dataclasses, the argument's description (which is displayed in the `-h` help message) can
+either be specified in the class docstring or the field's `description`. If both are specified, the description from the
+docstring is used. In the example below, the description is provided in the docstring.
+
+For Pydantic v1 models and dataclasses, the argument's description must be provided in the class docstring:
+
+```python
+# square_pydantic.py
+from pydantic import BaseModel
+
+from tap import tapify
+
+class Squarer(BaseModel):
+ """Squarer with a number to square.
+
+ :param num: The number to square.
+ """
+ num: float
+
+ def get_square(self) -> float:
+ """Get the square of the number."""
+ return self.num ** 2
+
+if __name__ == '__main__':
+ squarer = tapify(Squarer)
+ print(f'The square of your number is {squarer.get_square()}.')
+```
+
+
+
### tapify help
The help string on the command line is set based on the docstring for the function or class. For example, running `python square_function.py -h` will print:
@@ -751,4 +839,57 @@ if __name__ == '__main__':
Running `python person.py --name Jesse --age 1` prints `My name is Jesse.` followed by `My age is 1.`. Without `known_only=True`, the `tapify` calls would raise an error due to the extra argument.
### Explicit boolean arguments
+
Tapify supports explicit specification of boolean arguments (see [bool](#bool) for more details). By default, `explicit_bool=False` and it can be set with `tapify(..., explicit_bool=True)`.
+
+## Convert to a `Tap` class
+
+`to_tap_class` turns a function or class into a `Tap` class. The returned class can be [subclassed](#subclassing) to add
+special argument behavior. For example, you can override [`configure`](#configuring-arguments) and
+[`process_args`](#argument-processing).
+
+If the object can be `tapify`d, then it can be `to_tap_class`d, and vice-versa. `to_tap_class` provides full control
+over argument parsing.
+
+### Examples
+
+#### Simple
+
+```python
+# main.py
+"""
+My script description
+"""
+
+from pydantic import BaseModel
+
+from tap import to_tap_class
+
+class Project(BaseModel):
+ package: str
+ is_cool: bool = True
+ stars: int = 5
+
+if __name__ == "__main__":
+ ProjectTap = to_tap_class(Project)
+ tap = ProjectTap(description=__doc__) # from the top of this script
+ args = tap.parse_args()
+ project = Project(**args.as_dict())
+ print(f"Project instance: {project}")
+```
+
+Running `python main.py --package tap` will print `Project instance: package='tap' is_cool=True stars=5`.
+
+### Complex
+
+The general pattern is:
+
+```python
+from tap import to_tap_class
+
+class MyCustomTap(to_tap_class(my_class_or_function)):
+ # Special argument behavior, e.g., override configure and/or process_args
+```
+
+Please see `demo_data_model.py` for an example of overriding [`configure`](#configuring-arguments) and
+[`process_args`](#argument-processing).
diff --git a/demo_data_model.py b/demo_data_model.py
new file mode 100644
index 0000000..6542d99
--- /dev/null
+++ b/demo_data_model.py
@@ -0,0 +1,96 @@
+"""
+Works for Pydantic v1 and v2.
+
+Example commands:
+
+python demo_data_model.py -h
+
+python demo_data_model.py \
+ --arg_int 1 \
+ --arg_list x y z \
+ --argument_with_really_long_name 3
+
+python demo_data_model.py \
+ --arg_int 1 \
+ --arg_list x y z \
+ --arg_bool \
+ -arg 3.14
+"""
+from typing import List, Literal, Optional, Union
+
+from pydantic import BaseModel, Field
+from tap import tapify, to_tap_class, Tap
+
+
+class Model(BaseModel):
+ """
+ My Pydantic Model which contains script args.
+ """
+
+ arg_int: int = Field(description="some integer")
+ arg_bool: bool = Field(default=True)
+ arg_list: Optional[List[str]] = Field(default=None, description="some list of strings")
+
+
+def main(model: Model) -> None:
+ print("Parsed args into Model:")
+ print(model)
+
+
+def to_number(string: str) -> Union[float, int]:
+ return float(string) if "." in string else int(string)
+
+
+class ModelTap(to_tap_class(Model)):
+ # You can supply additional arguments here
+ argument_with_really_long_name: Union[float, int] = 3
+ "This argument has a long name and will be aliased with a short one"
+
+ def configure(self) -> None:
+ # You can still add special argument behavior
+ self.add_argument("-arg", "--argument_with_really_long_name", type=to_number)
+
+ def process_args(self) -> None:
+ # You can still validate and modify arguments
+ # (You should do this in the Pydantic Model. I'm just demonstrating that this functionality is still possible)
+ if self.argument_with_really_long_name > 4:
+ raise ValueError("argument_with_really_long_name cannot be > 4")
+
+ # No auto-complete (and other niceties) for the super class attributes b/c this is a dynamic subclass. Sorry
+ if self.arg_bool and self.arg_list is not None:
+ self.arg_list.append("processed")
+
+
+# class SubparserA(Tap):
+# bar: int # bar help
+
+
+# class SubparserB(Tap):
+# baz: Literal["X", "Y", "Z"] # baz help
+
+
+# class ModelTapWithSubparsing(to_tap_class(Model)):
+# foo: bool = False # foo help
+
+# def configure(self):
+# self.add_subparsers(help="sub-command help")
+# self.add_subparser("a", SubparserA, help="a help", description="Description (a)")
+# self.add_subparser("b", SubparserB, help="b help")
+
+
+if __name__ == "__main__":
+ # You don't have to subclass tap_class_from_data_model(Model) if you just want a plain argument parser:
+ # ModelTap = to_tap_class(Model)
+ args = ModelTap(description="Script description").parse_args()
+ # args = ModelTapWithSubparsing(description="Script description").parse_args()
+ print("Parsed args:")
+ print(args)
+ # Run the main function
+ model = Model(**args.as_dict())
+ main(model)
+
+
+# tapify works with Model. It immediately returns a Model instance instead of a Tap class
+# if __name__ == "__main__":
+# model = tapify(Model)
+# print(model)
diff --git a/setup.py b/setup.py
index c405119..f4d3b3c 100644
--- a/setup.py
+++ b/setup.py
@@ -12,6 +12,11 @@
with open("README.md", encoding="utf-8") as f:
long_description = f.read()
+test_requirements = [
+ "pydantic >= 2.5.0",
+ "pytest",
+]
+
setup(
name="typed-argument-parser",
version=__version__,
@@ -26,7 +31,8 @@
packages=find_packages(),
package_data={"tap": ["py.typed"]},
install_requires=["typing-inspect >= 0.7.1", "docstring-parser >= 0.15"],
- tests_require=["pytest"],
+ tests_require=test_requirements,
+ extras_require={"dev": test_requirements},
python_requires=">=3.8",
classifiers=[
"Programming Language :: Python :: 3",
diff --git a/tap/__init__.py b/tap/__init__.py
index b10e9d4..ea7b844 100644
--- a/tap/__init__.py
+++ b/tap/__init__.py
@@ -1,6 +1,13 @@
from argparse import ArgumentError, ArgumentTypeError
from tap._version import __version__
from tap.tap import Tap
-from tap.tapify import tapify
+from tap.tapify import tapify, to_tap_class
-__all__ = ["ArgumentError", "ArgumentTypeError", "Tap", "tapify", "__version__"]
+__all__ = [
+ "ArgumentError",
+ "ArgumentTypeError",
+ "Tap",
+ "tapify",
+ "to_tap_class",
+ "__version__",
+]
diff --git a/tap/tapify.py b/tap/tapify.py
index 39e6cd7..58f8c34 100644
--- a/tap/tapify.py
+++ b/tap/tapify.py
@@ -1,106 +1,334 @@
-"""Tapify module, which can initialize a class or run a function by parsing arguments from the command line."""
-from inspect import signature, Parameter
-from typing import Any, Callable, List, Optional, Type, TypeVar, Union
+"""
+`tapify`: initialize a class or run a function by parsing arguments from the command line.
-from docstring_parser import parse
+`to_tap_class`: convert a class or function into a `Tap` class, which can then be subclassed to add special argument
+handling
+"""
+
+import dataclasses
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Sequence, Type, TypeVar, Union
+
+from docstring_parser import Docstring, parse
+
+try:
+ import pydantic
+except ModuleNotFoundError:
+ _IS_PYDANTIC_V1 = None
+ # These are "empty" types. isinstance and issubclass will always be False
+ BaseModel = type("BaseModel", (object,), {})
+ _PydanticField = type("_PydanticField", (object,), {})
+ _PYDANTIC_FIELD_TYPES = ()
+else:
+ _IS_PYDANTIC_V1 = pydantic.__version__ < "2.0.0"
+ from pydantic import BaseModel
+ from pydantic.fields import FieldInfo as PydanticFieldBaseModel
+ from pydantic.dataclasses import FieldInfo as PydanticFieldDataclass
+
+ _PydanticField = Union[PydanticFieldBaseModel, PydanticFieldDataclass]
+ # typing.get_args(_PydanticField) is an empty tuple for some reason. Just repeat
+ _PYDANTIC_FIELD_TYPES = (PydanticFieldBaseModel, PydanticFieldDataclass)
from tap import Tap
InputType = TypeVar("InputType")
OutputType = TypeVar("OutputType")
+_ClassOrFunction = Union[Callable[[InputType], OutputType], Type[OutputType]]
-def tapify(
- class_or_function: Union[Callable[[InputType], OutputType], Type[OutputType]],
- known_only: bool = False,
- command_line_args: Optional[List[str]] = None,
- explicit_bool: bool = False,
- **func_kwargs,
-) -> OutputType:
- """Tapify initializes a class or runs a function by parsing arguments from the command line.
- :param class_or_function: The class or function to run with the provided arguments.
- :param known_only: If true, ignores extra arguments and only parses known arguments.
- :param command_line_args: A list of command line style arguments to parse (e.g., ['--arg', 'value']).
- If None, arguments are parsed from the command line (default behavior).
- :param explicit_bool: Booleans can be specified on the command line as "--arg True" or "--arg False"
- rather than "--arg". Additionally, booleans can be specified by prefixes of True and False
- with any capitalization as well as 1 or 0.
- :param func_kwargs: Additional keyword arguments for the function. These act as default values when
- parsing the command line arguments and overwrite the function defaults but
- are overwritten by the parsed command line arguments.
+@dataclasses.dataclass
+class _ArgData:
+ """
+ Data about an argument which is sufficient to inform a Tap variable/argument.
"""
- # Get signature from class or function
- sig = signature(class_or_function)
- # Parse class or function docstring in one line
- if isinstance(class_or_function, type) and class_or_function.__init__.__doc__ is not None:
- doc = class_or_function.__init__.__doc__
+ name: str
+
+ annotation: Type
+ "The type of values this argument accepts"
+
+ is_required: bool
+ "Whether or not the argument must be passed in"
+
+ default: Any
+ "Value of the argument if the argument isn't passed in. This gets ignored if is_required"
+
+ description: Optional[str] = ""
+ "Human-readable description of the argument"
+
+
+@dataclasses.dataclass(frozen=True)
+class _TapData:
+ """
+ Data about a class' or function's arguments which are sufficient to inform a Tap class.
+ """
+
+ args_data: List[_ArgData]
+ "List of data about each argument in the class or function"
+
+ has_kwargs: bool
+ "True if you can pass variable/extra kwargs to the class or function (as in **kwargs), else False"
+
+ known_only: bool
+ "If true, ignore extra arguments and only parse known arguments"
+
+
+def _is_pydantic_base_model(obj: Union[Type[Any], Any]) -> bool:
+ if inspect.isclass(obj): # issublcass requires that obj is a class
+ return issubclass(obj, BaseModel)
else:
- doc = class_or_function.__doc__
+ return isinstance(obj, BaseModel)
- # Parse docstring
- docstring = parse(doc)
- # Get the description of each argument in the class init or function
- param_to_description = {param.arg_name: param.description for param in docstring.params}
+def _is_pydantic_dataclass(obj: Union[Type[Any], Any]) -> bool:
+ if _IS_PYDANTIC_V1:
+ # There's no public function in v1. This is a somewhat safe but linear check
+ return dataclasses.is_dataclass(obj) and any(key.startswith("__pydantic") for key in obj.__dict__)
+ else:
+ return pydantic.dataclasses.is_pydantic_dataclass(obj)
- # Create a Tap object with a description from the docstring of the function or class
- description = "\n".join(filter(None, (docstring.short_description, docstring.long_description)))
- tap = Tap(description=description, explicit_bool=explicit_bool)
- # Keep track of whether **kwargs was provided
+def _tap_data_from_data_model(
+ data_model: Any, func_kwargs: Dict[str, Any], param_to_description: Dict[str, str] = None
+) -> _TapData:
+ """
+ Currently only works when `data_model` is a:
+ - builtin dataclass (class or instance)
+ - Pydantic dataclass (class or instance)
+ - Pydantic BaseModel (class or instance).
+
+ The advantage of this function over :func:`_tap_data_from_class_or_function` is that field/argument descriptions are
+ extracted, b/c this function look at the fields of the data model.
+
+ Note
+ ----
+ Deletes redundant keys from `func_kwargs`
+ """
+ param_to_description = param_to_description or {}
+
+ def arg_data_from_dataclass(name: str, field: dataclasses.Field) -> _ArgData:
+ def is_required(field: dataclasses.Field) -> bool:
+ return field.default is dataclasses.MISSING and field.default_factory is dataclasses.MISSING
+
+ description = param_to_description.get(name, field.metadata.get("description"))
+ return _ArgData(
+ name,
+ field.type,
+ is_required(field),
+ field.default,
+ description,
+ )
+
+ def arg_data_from_pydantic(name: str, field: _PydanticField, annotation: Optional[Type] = None) -> _ArgData:
+ annotation = field.annotation if annotation is None else annotation
+ # Prefer the description from param_to_description (from the data model / class docstring) over the
+ # field.description b/c a docstring can be modified on the fly w/o causing real issues
+ description = param_to_description.get(name, field.description)
+ return _ArgData(name, annotation, field.is_required(), field.default, description)
+
+ # Determine what type of data model it is and extract fields accordingly
+ if dataclasses.is_dataclass(data_model):
+ name_to_field = {field.name: field for field in dataclasses.fields(data_model)}
+ has_kwargs = False
+ known_only = False
+ elif _is_pydantic_base_model(data_model):
+ name_to_field = data_model.model_fields
+ # For backwards compatibility, only allow new kwargs to get assigned if the model is explicitly configured to do
+ # so via extra="allow". See https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.extra
+ is_extra_ok = data_model.model_config.get("extra", "ignore") == "allow"
+ has_kwargs = is_extra_ok
+ known_only = is_extra_ok
+ else:
+ raise TypeError(
+ "data_model must be a builtin or Pydantic dataclass (instance or class) or "
+ f"a Pydantic BaseModel (instance or class). Got {type(data_model)}"
+ )
+
+ # It's possible to mix fields w/ classes, e.g., use pydantic Fields in a (builtin) dataclass, or use (builtin)
+ # dataclass fields in a pydantic BaseModel. It's also possible to use (builtin) dataclass fields and pydantic Fields
+ # in the same data model. Therefore, the type of the data model doesn't determine the type of each field. The
+ # solution is to iterate through the fields and check each type.
+ args_data: List[_ArgData] = []
+ for name, field in name_to_field.items():
+ if isinstance(field, dataclasses.Field):
+ # Idiosyncrasy: if a pydantic Field is used in a pydantic dataclass, then field.default is a FieldInfo
+ # object instead of the field's default value. Furthermore, field.annotation is always NoneType. Luckily,
+ # the actual type of the field is stored in field.type
+ if isinstance(field.default, _PYDANTIC_FIELD_TYPES):
+ arg_data = arg_data_from_pydantic(name, field.default, annotation=field.type)
+ else:
+ arg_data = arg_data_from_dataclass(name, field)
+ elif isinstance(field, _PYDANTIC_FIELD_TYPES):
+ arg_data = arg_data_from_pydantic(name, field)
+ else:
+ raise TypeError(f"Each field must be a dataclass or Pydantic field. Got {type(field)}")
+ # Handle case where func_kwargs is supplied
+ if name in func_kwargs:
+ arg_data.default = func_kwargs[name]
+ arg_data.is_required = False
+ del func_kwargs[name]
+ args_data.append(arg_data)
+ return _TapData(args_data, has_kwargs, known_only)
+
+
+def _tap_data_from_class_or_function(
+ class_or_function: _ClassOrFunction, func_kwargs: Dict[str, Any], param_to_description: Dict[str, str]
+) -> _TapData:
+ """
+ Extract data by inspecting the signature of `class_or_function`.
+
+ Note
+ ----
+ Deletes redundant keys from `func_kwargs`
+ """
+ args_data: List[_ArgData] = []
has_kwargs = False
+ known_only = False
- # Add arguments of class init or function to the Tap object
- for param_name, param in sig.parameters.items():
- tap_kwargs = {}
+ sig = inspect.signature(class_or_function)
+ for param_name, param in sig.parameters.items():
# Skip **kwargs
- if param.kind == Parameter.VAR_KEYWORD:
+ if param.kind == inspect.Parameter.VAR_KEYWORD:
has_kwargs = True
known_only = True
continue
- # Get type of the argument
- if param.annotation != Parameter.empty:
- # Any type defaults to str (needed for dataclasses where all non-default attributes must have a type)
- if param.annotation is Any:
- tap._annotations[param.name] = str
- # Otherwise, get the type of the argument
- else:
- tap._annotations[param.name] = param.annotation
+ if param.annotation != inspect.Parameter.empty:
+ annotation = param.annotation
+ else:
+ annotation = Any
- # Get the default or required of the argument
if param.name in func_kwargs:
- tap_kwargs["default"] = func_kwargs[param.name]
+ is_required = False
+ default = func_kwargs[param.name]
del func_kwargs[param.name]
- elif param.default != Parameter.empty:
- tap_kwargs["default"] = param.default
+ elif param.default != inspect.Parameter.empty:
+ is_required = False
+ default = param.default
else:
- tap_kwargs["required"] = True
+ is_required = True
+ default = inspect.Parameter.empty # Can be set to anything. It'll be ignored
+
+ arg_data = _ArgData(
+ name=param_name,
+ annotation=annotation,
+ is_required=is_required,
+ default=default,
+ description=param_to_description.get(param.name),
+ )
+ args_data.append(arg_data)
+ return _TapData(args_data, has_kwargs, known_only)
+
+
+def _is_data_model(obj: Union[Type[Any], Any]) -> bool:
+ return dataclasses.is_dataclass(obj) or _is_pydantic_base_model(obj)
+
+
+def _docstring(class_or_function) -> Docstring:
+ is_function = not inspect.isclass(class_or_function)
+ if is_function or _is_pydantic_base_model(class_or_function):
+ doc = class_or_function.__doc__
+ else:
+ doc = class_or_function.__init__.__doc__ or class_or_function.__doc__
+ return parse(doc)
+
+
+def _tap_data(class_or_function: _ClassOrFunction, param_to_description: Dict[str, str], func_kwargs) -> _TapData:
+ """
+ Controls how :class:`_TapData` is extracted from `class_or_function`.
+ """
+ is_pydantic_v1_data_model = _IS_PYDANTIC_V1 and (
+ _is_pydantic_base_model(class_or_function) or _is_pydantic_dataclass(class_or_function)
+ )
+ if _is_data_model(class_or_function) and not is_pydantic_v1_data_model:
+ # Data models from Pydantic v1 don't lend themselves well to _tap_data_from_data_model.
+ # _tap_data_from_data_model looks at the data model's fields. In Pydantic v1, the field.type_ attribute stores
+ # the field's annotation/type. But (in Pydantic v1) there's a bug where field.type_ is set to the inner-most
+ # type of a subscripted type. For example, annotating a field with list[str] causes field.type_ to be str, not
+ # list[str]. To get around this, we'll extract _TapData by looking at the signature of the data model
+ return _tap_data_from_data_model(class_or_function, func_kwargs, param_to_description)
+ # TODO: allow passing func_kwargs to a Pydantic BaseModel
+ return _tap_data_from_class_or_function(class_or_function, func_kwargs, param_to_description)
+
+
+def _tap_class(args_data: Sequence[_ArgData]) -> Type[Tap]:
+ class ArgParser(Tap):
+ # Overwriting configure would force a user to remember to call super().configure if they want to overwrite it
+ # Instead, overwrite _configure
+ def _configure(self):
+ for arg_data in args_data:
+ variable = arg_data.name
+ self._annotations[variable] = str if arg_data.annotation is Any else arg_data.annotation
+ self.class_variables[variable] = {"comment": arg_data.description or ""}
+ if arg_data.is_required:
+ kwargs = {}
+ else:
+ kwargs = dict(required=False, default=arg_data.default)
+ self.add_argument(f"--{variable}", **kwargs)
+
+ super()._configure()
+
+ return ArgParser
+
+
+def to_tap_class(class_or_function: _ClassOrFunction) -> Type[Tap]:
+ """Creates a `Tap` class from `class_or_function`. This can be subclassed to add custom argument handling and
+ instantiated to create a typed argument parser.
+
+ :param class_or_function: The class or function to run with the provided arguments.
+ """
+ docstring = _docstring(class_or_function)
+ param_to_description = {param.arg_name: param.description for param in docstring.params}
+ # TODO: add func_kwargs
+ tap_data = _tap_data(class_or_function, param_to_description, func_kwargs={})
+ return _tap_class(tap_data.args_data)
- # Get the help string of the argument
- if param.name in param_to_description:
- tap.class_variables[param.name] = {"comment": param_to_description[param.name]}
- # Add the argument to the Tap object
- tap._add_argument(f"--{param_name}", **tap_kwargs)
+def tapify(
+ class_or_function: Union[Callable[[InputType], OutputType], Type[OutputType]],
+ known_only: bool = False,
+ command_line_args: Optional[List[str]] = None,
+ explicit_bool: bool = False,
+ **func_kwargs,
+) -> OutputType:
+ """Tapify initializes a class or runs a function by parsing arguments from the command line.
+
+ :param class_or_function: The class or function to run with the provided arguments.
+ :param known_only: If true, ignores extra arguments and only parses known arguments.
+ :param command_line_args: A list of command line style arguments to parse (e.g., ['--arg', 'value']).
+ If None, arguments are parsed from the command line (default behavior).
+ :param explicit_bool: Booleans can be specified on the command line as "--arg True" or "--arg False"
+ rather than "--arg". Additionally, booleans can be specified by prefixes of True and False
+ with any capitalization as well as 1 or 0.
+ :param func_kwargs: Additional keyword arguments for the function. These act as default values when
+ parsing the command line arguments and overwrite the function defaults but
+ are overwritten by the parsed command line arguments.
+ """
+ # We don't directly call to_tap_class b/c we need tap_data, not just tap_class
+ docstring = _docstring(class_or_function)
+ param_to_description = {param.arg_name: param.description for param in docstring.params}
+ tap_data = _tap_data(class_or_function, param_to_description, func_kwargs)
+ tap_class = _tap_class(tap_data.args_data)
+ # Create a Tap object with a description from the docstring of the class or function
+ description = "\n".join(filter(None, (docstring.short_description, docstring.long_description)))
+ tap = tap_class(description=description, explicit_bool=explicit_bool)
# If any func_kwargs remain, they are not used in the function, so raise an error
+ known_only = known_only or tap_data.known_only
if func_kwargs and not known_only:
raise ValueError(f"Unknown keyword arguments: {func_kwargs}")
# Parse command line arguments
- command_line_args = tap.parse_args(args=command_line_args, known_only=known_only)
+ command_line_args: Tap = tap.parse_args(args=command_line_args, known_only=known_only)
# Get command line arguments as a dictionary
command_line_args_dict = command_line_args.as_dict()
# Get **kwargs from extra command line arguments
- if has_kwargs:
+ if tap_data.has_kwargs:
kwargs = {tap.extra_args[i].lstrip("-"): tap.extra_args[i + 1] for i in range(0, len(tap.extra_args), 2)}
-
command_line_args_dict.update(kwargs)
# Initialize the class or run the function with the parsed arguments
diff --git a/tests/test_tapify.py b/tests/test_tapify.py
index d121c6d..f2aa387 100644
--- a/tests/test_tapify.py
+++ b/tests/test_tapify.py
@@ -1,3 +1,7 @@
+"""
+Tests `tap.tapify`. Currently requires Pydantic v2.
+"""
+
import contextlib
from dataclasses import dataclass
import io
@@ -9,6 +13,14 @@
from tap import tapify
+try:
+ import pydantic
+except ModuleNotFoundError:
+ _IS_PYDANTIC_V1 = None
+else:
+ _IS_PYDANTIC_V1 = pydantic.__version__ < "2.0.0"
+
+
# Suppress prints from SystemExit
class DevNull:
def write(self, msg):
@@ -22,7 +34,7 @@ class Person:
def __init__(self, name: str):
self.name = name
- def __str__(self) -> str:
+ def __repr__(self) -> str:
return f"Person({self.name})"
@@ -31,7 +43,7 @@ def __init__(self, problem_1: str, problem_2):
self.problem_1 = problem_1
self.problem_2 = problem_2
- def __str__(self) -> str:
+ def __repr__(self) -> str:
return f"Problems({self.problem_1}, {self.problem_2})"
@@ -49,7 +61,22 @@ class PieDataclass:
def __eq__(self, other: float) -> bool:
return other == pie()
- for class_or_function in [pie, Pie, PieDataclass]:
+ if _IS_PYDANTIC_V1 is not None:
+
+ @pydantic.dataclasses.dataclass
+ class PieDataclassPydantic:
+ def __eq__(self, other: float) -> bool:
+ return other == pie()
+
+ class PieModel(pydantic.BaseModel):
+ def __eq__(self, other: float) -> bool:
+ return other == pie()
+
+ pydantic_data_models = [PieDataclassPydantic, PieModel]
+ else:
+ pydantic_data_models = []
+
+ for class_or_function in [pie, Pie, PieDataclass] + pydantic_data_models:
self.assertEqual(tapify(class_or_function, command_line_args=[]), 3.14)
def test_tapify_simple_types(self):
@@ -74,7 +101,34 @@ class ConcatDataclass:
def __eq__(self, other: str) -> bool:
return other == concat(self.a, self.simple, self.test, self.of, self.types)
- for class_or_function in [concat, Concat, ConcatDataclass]:
+ if _IS_PYDANTIC_V1 is not None:
+
+ @pydantic.dataclasses.dataclass
+ class ConcatDataclassPydantic:
+ a: int
+ simple: str
+ test: float
+ of: float
+ types: bool
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.a, self.simple, self.test, self.of, self.types)
+
+ class ConcatModel(pydantic.BaseModel):
+ a: int
+ simple: str
+ test: float
+ of: float
+ types: bool
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.a, self.simple, self.test, self.of, self.types)
+
+ pydantic_data_models = [ConcatDataclassPydantic, ConcatModel]
+ else:
+ pydantic_data_models = []
+
+ for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models:
output = tapify(
class_or_function,
command_line_args=["--a", "1", "--simple", "simple", "--test", "3.14", "--of", "2.718", "--types"],
@@ -107,7 +161,37 @@ class ConcatDataclass:
def __eq__(self, other: str) -> bool:
return other == concat(self.a, self.simple, self.test, self.of, self.types, self.wow)
- for class_or_function in [concat, Concat, ConcatDataclass]:
+ if _IS_PYDANTIC_V1 is not None:
+
+ @pydantic.dataclasses.dataclass
+ class ConcatDataclassPydantic:
+ a: int
+ simple: str
+ test: float
+ of: float = -0.3
+ types: bool = False
+ wow: str = "abc"
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.a, self.simple, self.test, self.of, self.types, self.wow)
+
+ class ConcatModel(pydantic.BaseModel):
+ a: int
+ simple: str
+ test: float
+ of: float = -0.3
+ types: bool = False
+ wow: str = "abc"
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.a, self.simple, self.test, self.of, self.types, self.wow)
+
+ pydantic_data_models = [ConcatDataclassPydantic, ConcatModel]
+ else:
+ pydantic_data_models = []
+
+ for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models:
+ print(class_or_function.__name__)
output = tapify(
class_or_function,
command_line_args=["--a", "1", "--simple", "simple", "--test", "3.14", "--types", "--wow", "wee"],
@@ -135,7 +219,38 @@ class ConcatDataclass:
def __eq__(self, other: str) -> bool:
return other == concat(self.complexity, self.requires, self.intelligence)
- for class_or_function in [concat, Concat, ConcatDataclass]:
+ if _IS_PYDANTIC_V1 is not None:
+
+ @pydantic.dataclasses.dataclass(config=dict(arbitrary_types_allowed=True)) # for Person
+ class ConcatDataclassPydantic:
+ complexity: List[str]
+ requires: Tuple[int, int]
+ intelligence: Person
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.complexity, self.requires, self.intelligence)
+
+ class ConcatModel(pydantic.BaseModel):
+ if _IS_PYDANTIC_V1:
+
+ class Config:
+ arbitrary_types_allowed = True # for Person
+
+ else:
+ model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) # for Person
+
+ complexity: List[str]
+ requires: Tuple[int, int]
+ intelligence: Person
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.complexity, self.requires, self.intelligence)
+
+ pydantic_data_models = [ConcatDataclassPydantic, ConcatModel]
+ else:
+ pydantic_data_models = []
+
+ for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models:
output = tapify(
class_or_function,
command_line_args=[
@@ -176,10 +291,51 @@ class ConcatDataclass:
def __eq__(self, other: str) -> bool:
return other == concat(self.complexity, self.requires, self.intelligence)
- for class_or_function in [concat, Concat, ConcatDataclass]:
+ if _IS_PYDANTIC_V1 is not None:
+
+ @pydantic.dataclasses.dataclass(config=dict(arbitrary_types_allowed=True)) # for Person
+ class ConcatDataclassPydantic:
+ complexity: list[int]
+ requires: tuple[int, int]
+ intelligence: Person
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.complexity, self.requires, self.intelligence)
+
+ class ConcatModel(pydantic.BaseModel):
+ if _IS_PYDANTIC_V1:
+
+ class Config:
+ arbitrary_types_allowed = True # for Person
+
+ else:
+ model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) # for Person
+
+ complexity: list[int]
+ requires: tuple[int, int]
+ intelligence: Person
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.complexity, self.requires, self.intelligence)
+
+ pydantic_data_models = [ConcatDataclassPydantic, ConcatModel]
+ else:
+ pydantic_data_models = []
+
+ for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models:
output = tapify(
class_or_function,
- command_line_args=["--complexity", "1", "2", "3", "--requires", "1", "0", "--intelligence", "jesse",],
+ command_line_args=[
+ "--complexity",
+ "1",
+ "2",
+ "3",
+ "--requires",
+ "1",
+ "0",
+ "--intelligence",
+ "jesse",
+ ],
)
self.assertEqual(output, "1 2 3 1 0 Person(jesse)")
@@ -225,7 +381,42 @@ class ConcatDataclass:
def __eq__(self, other: str) -> bool:
return other == concat(self.complexity, self.requires, self.intelligence, self.maybe, self.possibly)
- for class_or_function in [concat, Concat, ConcatDataclass]:
+ if _IS_PYDANTIC_V1 is not None:
+
+ @pydantic.dataclasses.dataclass(config=dict(arbitrary_types_allowed=True)) # for Person
+ class ConcatDataclassPydantic:
+ complexity: List[str]
+ requires: Tuple[int, int] = (2, 5)
+ intelligence: Person = Person("kyle")
+ maybe: Optional[str] = None
+ possibly: Optional[str] = None
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.complexity, self.requires, self.intelligence, self.maybe, self.possibly)
+
+ class ConcatModel(pydantic.BaseModel):
+ if _IS_PYDANTIC_V1:
+
+ class Config:
+ arbitrary_types_allowed = True # for Person
+
+ else:
+ model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) # for Person
+
+ complexity: List[str]
+ requires: Tuple[int, int] = (2, 5)
+ intelligence: Person = Person("kyle")
+ maybe: Optional[str] = None
+ possibly: Optional[str] = None
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.complexity, self.requires, self.intelligence, self.maybe, self.possibly)
+
+ pydantic_data_models = [ConcatDataclassPydantic, ConcatModel]
+ else:
+ pydantic_data_models = []
+
+ for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models:
output = tapify(
class_or_function,
command_line_args=[
@@ -263,7 +454,30 @@ class ConcatDataclass:
def __eq__(self, other: str) -> bool:
return other == concat(self.so, self.many, self.args)
- for class_or_function in [concat, Concat, ConcatDataclass]:
+ if _IS_PYDANTIC_V1 is not None:
+
+ @pydantic.dataclasses.dataclass
+ class ConcatDataclassPydantic:
+ so: int
+ many: float
+ args: str
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.so, self.many, self.args)
+
+ class ConcatModel(pydantic.BaseModel):
+ so: int
+ many: float
+ args: str
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.so, self.many, self.args)
+
+ pydantic_data_models = [ConcatDataclassPydantic, ConcatModel]
+ else:
+ pydantic_data_models = []
+
+ for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models:
with self.assertRaises(SystemExit):
tapify(class_or_function, command_line_args=["--so", "23", "--many", "9.3"])
@@ -286,7 +500,28 @@ class ConcatDataclass:
def __eq__(self, other: str) -> bool:
return other == concat(self.so, self.few)
- for class_or_function in [concat, Concat, ConcatDataclass]:
+ if _IS_PYDANTIC_V1 is not None:
+
+ @pydantic.dataclasses.dataclass
+ class ConcatDataclassPydantic:
+ so: int
+ few: float
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.so, self.few)
+
+ class ConcatModel(pydantic.BaseModel):
+ so: int
+ few: float
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.so, self.few)
+
+ pydantic_data_models = [ConcatDataclassPydantic, ConcatModel]
+ else:
+ pydantic_data_models = []
+
+ for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models:
with self.assertRaises(SystemExit):
tapify(class_or_function, command_line_args=["--so", "23", "--few", "9.3", "--args", "wow"])
@@ -309,7 +544,28 @@ class ConcatDataclass:
def __eq__(self, other: str) -> bool:
return other == concat(self.so, self.few)
- for class_or_function in [concat, Concat, ConcatDataclass]:
+ if _IS_PYDANTIC_V1 is not None:
+
+ @pydantic.dataclasses.dataclass
+ class ConcatDataclassPydantic:
+ so: int
+ few: float
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.so, self.few)
+
+ class ConcatModel(pydantic.BaseModel):
+ so: int
+ few: float
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.so, self.few)
+
+ pydantic_data_models = [ConcatDataclassPydantic, ConcatModel]
+ else:
+ pydantic_data_models = []
+
+ for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models:
output = tapify(
class_or_function, command_line_args=["--so", "23", "--few", "9.3", "--args", "wow"], known_only=True
)
@@ -339,10 +595,46 @@ class ConcatDataclass:
def __eq__(self, other: str) -> bool:
return other == concat(self.i, self.like, self.k, self.w, self.args, self.always)
- for class_or_function in [concat, Concat, ConcatDataclass]:
+ if _IS_PYDANTIC_V1 is not None:
+
+ @pydantic.dataclasses.dataclass
+ class ConcatDataclassPydantic:
+ i: int
+ like: float
+ k: int
+ w: str = "w"
+ args: str = "argy"
+ always: bool = False
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.i, self.like, self.k, self.w, self.args, self.always)
+
+ class ConcatModel(pydantic.BaseModel):
+ i: int
+ like: float
+ k: int
+ w: str = "w"
+ args: str = "argy"
+ always: bool = False
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.i, self.like, self.k, self.w, self.args, self.always)
+
+ pydantic_data_models = [ConcatDataclassPydantic, ConcatModel]
+ else:
+ pydantic_data_models = []
+
+ for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models:
output = tapify(
class_or_function,
- command_line_args=["--i", "23", "--args", "wow", "--like", "3.03",],
+ command_line_args=[
+ "--i",
+ "23",
+ "--args",
+ "wow",
+ "--like",
+ "3.03",
+ ],
known_only=True,
w="hello",
k=5,
@@ -375,11 +667,47 @@ class ConcatDataclass:
def __eq__(self, other: str) -> bool:
return other == concat(self.i, self.like, self.k, self.w, self.args, self.always)
- for class_or_function in [concat, Concat, ConcatDataclass]:
+ if _IS_PYDANTIC_V1 is not None:
+
+ @pydantic.dataclasses.dataclass
+ class ConcatDataclassPydantic:
+ i: int
+ like: float
+ k: int
+ w: str = "w"
+ args: str = "argy"
+ always: bool = False
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.i, self.like, self.k, self.w, self.args, self.always)
+
+ class ConcatModel(pydantic.BaseModel):
+ i: int
+ like: float
+ k: int
+ w: str = "w"
+ args: str = "argy"
+ always: bool = False
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.i, self.like, self.k, self.w, self.args, self.always)
+
+ pydantic_data_models = [ConcatDataclassPydantic, ConcatModel]
+ else:
+ pydantic_data_models = []
+
+ for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models:
with self.assertRaises(ValueError):
tapify(
class_or_function,
- command_line_args=["--i", "23", "--args", "wow", "--like", "3.03",],
+ command_line_args=[
+ "--i",
+ "23",
+ "--args",
+ "wow",
+ "--like",
+ "3.03",
+ ],
w="hello",
k=5,
like=3.4,
@@ -404,7 +732,34 @@ class ConcatDataclass:
def __eq__(self, other: str) -> bool:
return other == concat(self.problems)
- for class_or_function in [concat, Concat, ConcatDataclass]:
+ if _IS_PYDANTIC_V1 is not None:
+
+ @pydantic.dataclasses.dataclass(config=dict(arbitrary_types_allowed=True))
+ class ConcatDataclassPydantic:
+ problems: Problems
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.problems)
+
+ class ConcatModel(pydantic.BaseModel):
+ if _IS_PYDANTIC_V1:
+
+ class Config:
+ arbitrary_types_allowed = True # for Problems
+
+ else:
+ model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) # for Problems
+
+ problems: Problems
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.problems)
+
+ pydantic_data_models = [ConcatDataclassPydantic, ConcatModel]
+ else:
+ pydantic_data_models = []
+
+ for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models:
output = tapify(class_or_function, command_line_args=[], problems=Problems("oh", "no!"))
self.assertEqual(output, "Problems(oh, no!)")
@@ -448,7 +803,40 @@ def __eq__(self, other: str) -> bool:
self.untyped_1, self.typed_1, self.untyped_2, self.typed_2, self.untyped_3, self.typed_3
)
- for class_or_function in [concat, Concat, ConcatDataclass]:
+ if _IS_PYDANTIC_V1 is not None:
+
+ @pydantic.dataclasses.dataclass
+ class ConcatDataclassPydantic:
+ untyped_1: Any
+ typed_1: int
+ untyped_2: Any = 5
+ typed_2: str = "now"
+ untyped_3: Any = "hi"
+ typed_3: bool = False
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(
+ self.untyped_1, self.typed_1, self.untyped_2, self.typed_2, self.untyped_3, self.typed_3
+ )
+
+ class ConcatModel(pydantic.BaseModel):
+ untyped_1: Any
+ typed_1: int
+ untyped_2: Any = 5
+ typed_2: str = "now"
+ untyped_3: Any = "hi"
+ typed_3: bool = False
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(
+ self.untyped_1, self.typed_1, self.untyped_2, self.typed_2, self.untyped_3, self.typed_3
+ )
+
+ pydantic_data_models = [ConcatDataclassPydantic, ConcatModel]
+ else:
+ pydantic_data_models = []
+
+ for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models:
output = tapify(
class_or_function,
command_line_args=[
@@ -491,7 +879,33 @@ class ConcatDataclass:
def __eq__(self, other: str) -> bool:
return other == concat(self.a, self.b, self.c)
- for class_or_function in [concat, Concat, ConcatDataclass]:
+ if _IS_PYDANTIC_V1 is not None:
+
+ @pydantic.dataclasses.dataclass
+ class ConcatDataclassPydantic:
+ """Concatenate three numbers."""
+
+ a: int
+ b: int
+ c: int
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.a, self.b, self.c)
+
+ class ConcatModel(pydantic.BaseModel):
+ """Concatenate three numbers."""
+
+ a: int
+ b: int
+ c: int
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.a, self.b, self.c)
+
+ pydantic_data_models = [ConcatDataclassPydantic, ConcatModel]
+ else:
+ pydantic_data_models = []
+ for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models:
output_1 = tapify(class_or_function, command_line_args=["--a", "1", "--b", "2", "--c", "3"])
output_2 = tapify(class_or_function, command_line_args=["--a", "4", "--b", "5", "--c", "6"])
@@ -555,16 +969,54 @@ class ConcatDataclass:
def __eq__(self, other: str) -> bool:
return other == concat(self.a, self.b, self.c)
- for class_or_function in [concat, Concat, ConcatDataclass]:
+ if _IS_PYDANTIC_V1 is not None:
+
+ @pydantic.dataclasses.dataclass
+ class ConcatDataclassPydantic:
+ """Concatenate three numbers.
+
+ :param a: The first number.
+ :param b: The second number.
+ :param c: The third number.
+ """
+
+ a: int
+ b: int
+ c: int
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.a, self.b, self.c)
+
+ class ConcatModel(pydantic.BaseModel):
+ """Concatenate three numbers.
+
+ :param a: The first number.
+ :param b: The second number.
+ :param c: The third number.
+ """
+
+ a: int
+ b: int
+ c: int
+
+ def __eq__(self, other: str) -> bool:
+ return other == concat(self.a, self.b, self.c)
+
+ pydantic_data_models = [ConcatDataclassPydantic, ConcatModel]
+ else:
+ pydantic_data_models = []
+
+ for class_or_function in [concat, Concat, ConcatDataclass] + pydantic_data_models:
f = io.StringIO()
with contextlib.redirect_stdout(f):
with self.assertRaises(SystemExit):
tapify(class_or_function, command_line_args=["-h"])
- self.assertIn("Concatenate three numbers.", f.getvalue())
- self.assertIn("--a A (int, required) The first number.", f.getvalue())
- self.assertIn("--b B (int, required) The second number.", f.getvalue())
- self.assertIn("--c C (int, required) The third number.", f.getvalue())
+ stdout = f.getvalue()
+ self.assertIn("Concatenate three numbers.", stdout)
+ self.assertIn("--a A (int, required) The first number.", stdout)
+ self.assertIn("--b B (int, required) The second number.", stdout)
+ self.assertIn("--c C (int, required) The third number.", stdout)
class TestTapifyExplicitBool(unittest.TestCase):
@@ -630,7 +1082,45 @@ def concat(a: int, b: int = 2, **kwargs) -> str:
"""
return f'{a}_{b}_{"-".join(f"{k}={v}" for k, v in kwargs.items())}'
- self.concat_function = concat
+ if _IS_PYDANTIC_V1 is not None:
+
+ class ConcatModel(pydantic.BaseModel):
+ """Concatenate three numbers.
+
+ :param a: The first number.
+ :param b: The second number.
+ """
+
+ if _IS_PYDANTIC_V1:
+
+ class Config:
+ extra = pydantic.Extra.allow # by default, pydantic ignores extra arguments
+
+ else:
+ model_config = pydantic.ConfigDict(extra="allow") # by default, pydantic ignores extra arguments
+
+ a: int
+ b: int = 2
+
+ def __eq__(self, other: str) -> bool:
+ if _IS_PYDANTIC_V1:
+ # Get the kwarg names in the correct order by parsing other
+ kwargs_str = other.split("_")[-1]
+ if not kwargs_str:
+ kwarg_names = []
+ else:
+ kwarg_names = [kv_str.split("=")[0] for kv_str in kwargs_str.split("-")]
+ kwargs = {name: getattr(self, name) for name in kwarg_names}
+ # Need to explictly check that the extra names from other are identical to what's stored in self
+ # Checking other == concat(...) isn't sufficient b/c self could have more extra fields
+ assert set(kwarg_names) == set(self.__dict__.keys()) - set(self.__fields__.keys())
+ else:
+ kwargs = self.model_extra
+ return other == concat(self.a, self.b, **kwargs)
+
+ pydantic_data_models = [ConcatModel]
+ else:
+ pydantic_data_models = []
class Concat:
def __init__(self, a: int, b: int = 2, **kwargs: Dict[str, str]):
@@ -646,22 +1136,22 @@ def __init__(self, a: int, b: int = 2, **kwargs: Dict[str, str]):
def __eq__(self, other: str) -> bool:
return other == concat(self.a, self.b, **self.kwargs)
- self.concat_class = Concat
+ self.class_or_functions = [concat, Concat] + pydantic_data_models
def test_tapify_empty_kwargs(self) -> None:
- for class_or_function in [self.concat_function, self.concat_class]:
+ for class_or_function in self.class_or_functions:
output = tapify(class_or_function, command_line_args=["--a", "1"])
self.assertEqual(output, "1_2_")
def test_tapify_has_kwargs(self) -> None:
- for class_or_function in [self.concat_function, self.concat_class]:
+ for class_or_function in self.class_or_functions:
output = tapify(class_or_function, command_line_args=["--a", "1", "--c", "3", "--d", "4"])
self.assertEqual(output, "1_2_c=3-d=4")
def test_tapify_has_kwargs_replace_default(self) -> None:
- for class_or_function in [self.concat_function, self.concat_class]:
+ for class_or_function in self.class_or_functions:
output = tapify(class_or_function, command_line_args=["--a", "1", "--c", "3", "--b", "5", "--d", "4"])
self.assertEqual(output, "1_5_c=3-d=4")
diff --git a/tests/test_to_tap_class.py b/tests/test_to_tap_class.py
new file mode 100644
index 0000000..730e4ed
--- /dev/null
+++ b/tests/test_to_tap_class.py
@@ -0,0 +1,553 @@
+"""
+Tests `tap.to_tap_class`.
+"""
+
+from contextlib import redirect_stdout, redirect_stderr
+import dataclasses
+import io
+import re
+import sys
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
+
+import pytest
+
+from tap import to_tap_class, Tap
+from tap.utils import type_to_str
+
+
+try:
+ import pydantic
+except ModuleNotFoundError:
+ _IS_PYDANTIC_V1 = None
+else:
+ _IS_PYDANTIC_V1 = pydantic.__version__ < "2.0.0"
+
+
+# To properly test the help message, we need to know how argparse formats it. It changed from 3.8 -> 3.9 -> 3.10
+_OPTIONS_TITLE = "options" if not sys.version_info < (3, 10) else "optional arguments"
+_ARG_LIST_DOTS = "..." if not sys.version_info < (3, 9) else "[ARG_LIST ...]"
+
+
+@dataclasses.dataclass
+class _Args:
+ """
+ These are the arguments which every type of class or function must contain.
+ """
+
+ arg_int: int = dataclasses.field(metadata=dict(description="some integer"))
+ arg_bool: bool = True
+ arg_list: Optional[List[str]] = dataclasses.field(default=None, metadata=dict(description="some list of strings"))
+
+
+def _monkeypatch_eq(cls):
+ """
+ Monkey-patches `cls.__eq__` to check that the attribute values are equal to a dataclass representation of them.
+ """
+
+ def _equality(self, other: _Args) -> bool:
+ return _Args(self.arg_int, arg_bool=self.arg_bool, arg_list=self.arg_list) == other
+
+ cls.__eq__ = _equality
+ return cls
+
+
+# Define a few different classes or functions which all take the same arguments (same by name, annotation, and default
+# if not required)
+
+
+def function(arg_int: int, arg_bool: bool = True, arg_list: Optional[List[str]] = None) -> _Args:
+ """
+ :param arg_int: some integer
+ :param arg_list: some list of strings
+ """
+ return _Args(arg_int, arg_bool=arg_bool, arg_list=arg_list)
+
+
+@_monkeypatch_eq
+class Class:
+ def __init__(self, arg_int: int, arg_bool: bool = True, arg_list: Optional[List[str]] = None):
+ """
+ :param arg_int: some integer
+ :param arg_list: some list of strings
+ """
+ self.arg_int = arg_int
+ self.arg_bool = arg_bool
+ self.arg_list = arg_list
+
+
+DataclassBuiltin = _Args
+
+
+if _IS_PYDANTIC_V1 is None:
+ pass # will raise NameError if attempting to use DataclassPydantic or Model later
+elif _IS_PYDANTIC_V1:
+ # For Pydantic v1 data models, we rely on the docstring to get descriptions
+
+ @_monkeypatch_eq
+ @pydantic.dataclasses.dataclass
+ class DataclassPydantic:
+ """
+ Dataclass (pydantic v1)
+
+ :param arg_int: some integer
+ :param arg_list: some list of strings
+ """
+
+ arg_int: int
+ arg_bool: bool = True
+ arg_list: Optional[List[str]] = None
+
+ @_monkeypatch_eq
+ class Model(pydantic.BaseModel):
+ """
+ Pydantic model (pydantic v1)
+
+ :param arg_int: some integer
+ :param arg_list: some list of strings
+ """
+
+ arg_int: int
+ arg_bool: bool = True
+ arg_list: Optional[List[str]] = None
+
+else:
+ # For pydantic v2 data models, we check the docstring and Field for the description
+
+ @_monkeypatch_eq
+ @pydantic.dataclasses.dataclass
+ class DataclassPydantic:
+ """
+ Dataclass (pydantic)
+
+ :param arg_list: some list of strings
+ """
+
+ # Mixing field types should be ok
+ arg_int: int = pydantic.dataclasses.Field(description="some integer")
+ arg_bool: bool = dataclasses.field(default=True)
+ arg_list: Optional[List[str]] = pydantic.Field(default=None)
+
+ @_monkeypatch_eq
+ class Model(pydantic.BaseModel):
+ """
+ Pydantic model
+
+ :param arg_int: some integer
+ """
+
+ # Mixing field types should be ok
+ arg_int: int
+ arg_bool: bool = dataclasses.field(default=True)
+ arg_list: Optional[List[str]] = pydantic.dataclasses.Field(default=None, description="some list of strings")
+
+
+@pytest.fixture(
+ scope="module",
+ params=[
+ function,
+ Class,
+ DataclassBuiltin,
+ DataclassBuiltin(
+ 1, arg_bool=False, arg_list=["these", "values", "don't", "matter"]
+ ), # to_tap_class also works on instances of data models. It ignores the attribute values
+ ]
+ + ([] if _IS_PYDANTIC_V1 is None else [DataclassPydantic, Model]),
+ # NOTE: instances of DataclassPydantic and Model can be tested for pydantic v2 but not v1
+)
+def class_or_function_(request: pytest.FixtureRequest):
+ """
+ Parametrized class_or_function.
+ """
+ return request.param
+
+
+# Define some functions which take a class or function and calls `tap.to_tap_class` on it to create a `tap.Tap`
+# subclass (class, not instance)
+
+
+def subclasser_simple(class_or_function: Any) -> Type[Tap]:
+ """
+ Plain subclass, does nothing extra.
+ """
+ return to_tap_class(class_or_function)
+
+
+def subclasser_complex(class_or_function):
+ """
+ It's conceivable that someone has a data model, but they want to add more arguments or handling when running a
+ script.
+ """
+
+ def to_number(string: str) -> Union[float, int]:
+ return float(string) if "." in string else int(string)
+
+ class TapSubclass(to_tap_class(class_or_function)):
+ # You can supply additional arguments here
+ argument_with_really_long_name: Union[float, int] = 3
+ "This argument has a long name and will be aliased with a short one"
+
+ def configure(self) -> None:
+ # You can still add special argument behavior
+ self.add_argument("-arg", "--argument_with_really_long_name", type=to_number)
+
+ def process_args(self) -> None:
+ # You can still validate and modify arguments
+ if self.argument_with_really_long_name > 4:
+ raise ValueError("argument_with_really_long_name cannot be > 4")
+
+ # No auto-complete (and other niceties) for the super class attributes b/c this is a dynamic subclass. Sorry
+ if self.arg_bool and self.arg_list is not None:
+ self.arg_list.append("processed")
+
+ return TapSubclass
+
+
+def subclasser_subparser(class_or_function):
+ class SubparserA(Tap):
+ bar: int # bar help
+
+ class SubparserB(Tap):
+ baz: Literal["X", "Y", "Z"] # baz help
+
+ class TapSubclass(to_tap_class(class_or_function)):
+ foo: bool = False # foo help
+
+ def configure(self):
+ self.add_subparsers(help="sub-command help")
+ self.add_subparser("a", SubparserA, help="a help", description="Description (a)")
+ self.add_subparser("b", SubparserB, help="b help")
+
+ return TapSubclass
+
+
+# Test that the subclasser parses the args correctly or raises the correct error.
+# The subclassers are tested separately b/c the parametrizaiton of args_string_and_arg_to_expected_value depends on the
+# subclasser.
+# First, some helper functions.
+
+
+def _test_raises_system_exit(tap: Tap, args_string: str) -> str:
+ is_help = (
+ args_string.endswith("-h")
+ or args_string.endswith("--help")
+ or " -h " in args_string
+ or " --help " in args_string
+ )
+ f = io.StringIO()
+ with redirect_stdout(f) if is_help else redirect_stderr(f):
+ with pytest.raises(SystemExit):
+ tap.parse_args(args_string.split())
+
+ return f.getvalue()
+
+
+def _test_subclasser(
+ subclasser: Callable[[Any], Type[Tap]],
+ class_or_function: Any,
+ args_string_and_arg_to_expected_value: Tuple[str, Union[Dict[str, Any], BaseException]],
+ test_call: bool = True,
+):
+ """
+ Tests that the `subclasser` converts `class_or_function` to a `Tap` class which parses the argument string
+ correctly.
+
+ Setting `test_call=True` additionally tests that calling the `class_or_function` on the parsed arguments works.
+ """
+ args_string, arg_to_expected_value = args_string_and_arg_to_expected_value
+ TapSubclass = subclasser(class_or_function)
+ assert issubclass(TapSubclass, Tap)
+ tap = TapSubclass(description="Script description")
+
+ if isinstance(arg_to_expected_value, SystemExit):
+ stderr = _test_raises_system_exit(tap, args_string)
+ assert str(arg_to_expected_value) in stderr
+ elif isinstance(arg_to_expected_value, BaseException):
+ expected_exception = arg_to_expected_value.__class__
+ expected_error_message = str(arg_to_expected_value) or None
+ with pytest.raises(expected_exception=expected_exception, match=expected_error_message):
+ args = tap.parse_args(args_string.split())
+ else:
+ # args_string is a valid argument combo
+ # Test that parsing works correctly
+ args = tap.parse_args(args_string.split())
+ assert arg_to_expected_value == args.as_dict()
+ if test_call and callable(class_or_function):
+ result = class_or_function(**args.as_dict())
+ assert result == _Args(**arg_to_expected_value)
+
+
+def _test_subclasser_message(
+ subclasser: Callable[[Any], Type[Tap]],
+ class_or_function: Any,
+ message_expected: str,
+ description: str = "Script description",
+ args_string: str = "-h",
+):
+ """
+ Tests that::
+
+ subclasser(class_or_function)(description=description).parse_args(args_string.split())
+
+ outputs `message_expected` to stdout, ignoring differences in whitespaces/newlines/tabs.
+ """
+
+ def replace_whitespace(string: str) -> str:
+ return re.sub(r"\s+", " ", string).strip() # FYI this line was written by an LLM
+
+ TapSubclass = subclasser(class_or_function)
+ tap = TapSubclass(description=description)
+ message = _test_raises_system_exit(tap, args_string)
+ # Standardize to ignore trivial differences due to terminal settings
+ assert replace_whitespace(message) == replace_whitespace(message_expected)
+
+
+# Test sublcasser_simple
+
+
+@pytest.mark.parametrize(
+ "args_string_and_arg_to_expected_value",
+ [
+ (
+ "--arg_int 1 --arg_list x y z",
+ {"arg_int": 1, "arg_bool": True, "arg_list": ["x", "y", "z"]},
+ ),
+ (
+ "--arg_int 1 --arg_bool",
+ {"arg_int": 1, "arg_bool": False, "arg_list": None},
+ ),
+ # The rest are invalid argument combos, as indicated by the 2nd elt being a BaseException instance
+ (
+ "--arg_list x y z --arg_bool",
+ SystemExit("error: the following arguments are required: --arg_int"),
+ ),
+ (
+ "--arg_int not_an_int --arg_list x y z --arg_bool",
+ SystemExit("error: argument --arg_int: invalid int value: 'not_an_int'"),
+ ),
+ ],
+)
+def test_subclasser_simple(
+ class_or_function_: Any, args_string_and_arg_to_expected_value: Tuple[str, Union[Dict[str, Any], BaseException]]
+):
+ _test_subclasser(subclasser_simple, class_or_function_, args_string_and_arg_to_expected_value)
+
+
+# @pytest.mark.skipif(sys.version_info < (3, 10), reason="argparse is different. Need to fix help_message_expected")
+def test_subclasser_simple_help_message(class_or_function_: Any):
+ description = "Script description"
+ help_message_expected = f"""
+usage: pytest --arg_int ARG_INT [--arg_bool] [--arg_list [ARG_LIST {_ARG_LIST_DOTS}]] [-h]
+
+{description}
+
+{_OPTIONS_TITLE}:
+ --arg_int ARG_INT (int, required) some integer
+ --arg_bool (bool, default=True)
+ --arg_list [ARG_LIST {_ARG_LIST_DOTS}]
+ ({type_to_str(Optional[List[str]])}, default=None) some list of strings
+ -h, --help show this help message and exit
+""".lstrip(
+ "\n"
+ )
+ _test_subclasser_message(subclasser_simple, class_or_function_, help_message_expected, description=description)
+
+
+# Test subclasser_complex
+
+
+@pytest.mark.parametrize(
+ "args_string_and_arg_to_expected_value",
+ [
+ (
+ "--arg_int 1 --arg_list x y z",
+ {
+ "arg_int": 1,
+ "arg_bool": True,
+ "arg_list": ["x", "y", "z", "processed"],
+ "argument_with_really_long_name": 3,
+ },
+ ),
+ (
+ "--arg_int 1 --arg_list x y z -arg 2",
+ {
+ "arg_int": 1,
+ "arg_bool": True,
+ "arg_list": ["x", "y", "z", "processed"],
+ "argument_with_really_long_name": 2,
+ },
+ ),
+ (
+ "--arg_int 1 --arg_bool --argument_with_really_long_name 2.3",
+ {
+ "arg_int": 1,
+ "arg_bool": False,
+ "arg_list": None,
+ "argument_with_really_long_name": 2.3,
+ },
+ ),
+ # The rest are invalid argument combos, as indicated by the 2nd elt being a BaseException instance
+ (
+ "--arg_list x y z --arg_bool",
+ SystemExit("error: the following arguments are required: --arg_int"),
+ ),
+ (
+ "--arg_int 1 --arg_list x y z -arg not_a_float_or_int",
+ SystemExit(
+ "error: argument -arg/--argument_with_really_long_name: invalid to_number value: 'not_a_float_or_int'"
+ ),
+ ),
+ (
+ "--arg_int 1 --arg_list x y z -arg 5", # Wrong value arg (aliases argument_with_really_long_name)
+ ValueError("argument_with_really_long_name cannot be > 4"),
+ ),
+ ],
+)
+def test_subclasser_complex(
+ class_or_function_: Any, args_string_and_arg_to_expected_value: Tuple[str, Union[Dict[str, Any], BaseException]]
+):
+ # Currently setting test_call=False b/c all data models except the pydantic Model don't accept extra args
+ _test_subclasser(subclasser_complex, class_or_function_, args_string_and_arg_to_expected_value, test_call=False)
+
+
+# @pytest.mark.skipif(sys.version_info < (3, 10), reason="argparse is different. Need to fix help_message_expected")
+def test_subclasser_complex_help_message(class_or_function_: Any):
+ description = "Script description"
+ help_message_expected = f"""
+usage: pytest [-arg ARGUMENT_WITH_REALLY_LONG_NAME] --arg_int ARG_INT [--arg_bool] [--arg_list [ARG_LIST {_ARG_LIST_DOTS}]] [-h]
+
+{description}
+
+{_OPTIONS_TITLE}:
+ -arg ARGUMENT_WITH_REALLY_LONG_NAME, --argument_with_really_long_name ARGUMENT_WITH_REALLY_LONG_NAME
+ (Union[float, int], default=3) This argument has a long name and will be aliased with a short one
+ --arg_int ARG_INT (int, required) some integer
+ --arg_bool (bool, default=True)
+ --arg_list [ARG_LIST {_ARG_LIST_DOTS}]
+ ({type_to_str(Optional[List[str]])}, default=None) some list of strings
+ -h, --help show this help message and exit
+""".lstrip(
+ "\n"
+ )
+ _test_subclasser_message(subclasser_complex, class_or_function_, help_message_expected, description=description)
+
+
+# Test subclasser_subparser
+
+
+@pytest.mark.parametrize(
+ "args_string_and_arg_to_expected_value",
+ [
+ (
+ "--arg_int 1",
+ {"arg_int": 1, "arg_bool": True, "arg_list": None, "foo": False},
+ ),
+ (
+ "--arg_int 1 a --bar 2",
+ {"arg_int": 1, "arg_bool": True, "arg_list": None, "bar": 2, "foo": False},
+ ),
+ (
+ "--arg_int 1 --foo a --bar 2",
+ {"arg_int": 1, "arg_bool": True, "arg_list": None, "bar": 2, "foo": True},
+ ),
+ (
+ "--arg_int 1 b --baz X",
+ {"arg_int": 1, "arg_bool": True, "arg_list": None, "baz": "X", "foo": False},
+ ),
+ (
+ "--foo --arg_bool --arg_list x y z --arg_int 1 b --baz Y",
+ {"arg_int": 1, "arg_bool": False, "arg_list": ["x", "y", "z"], "baz": "Y", "foo": True},
+ ),
+ # The rest are invalid argument combos, as indicated by the 2nd elt being a BaseException instance
+ (
+ "a --bar 1",
+ SystemExit("error: the following arguments are required: --arg_int"),
+ ),
+ (
+ "--arg_int not_an_int a --bar 1",
+ SystemExit("error: argument --arg_int: invalid int value: 'not_an_int'"),
+ ),
+ (
+ "--arg_int 1 --baz X --foo b",
+ SystemExit(
+ "error: argument {a,b}: invalid choice: 'X' (choose from 'a', 'b')"
+ if sys.version_info >= (3, 9)
+ else "error: invalid choice: 'X' (choose from 'a', 'b')"
+ ),
+ ),
+ (
+ "--arg_int 1 b --baz X --foo",
+ SystemExit("error: unrecognized arguments: --foo"),
+ ),
+ (
+ "--arg_int 1 --foo b --baz A",
+ SystemExit("""error: argument --baz: Value for variable "baz" must be one of ['X', 'Y', 'Z']."""),
+ ),
+ ],
+)
+def test_subclasser_subparser(
+ class_or_function_: Any, args_string_and_arg_to_expected_value: Tuple[str, Union[Dict[str, Any], BaseException]]
+):
+ # Currently setting test_call=False b/c all data models except the pydantic Model don't accept extra args
+ _test_subclasser(subclasser_subparser, class_or_function_, args_string_and_arg_to_expected_value, test_call=False)
+
+
+# @pytest.mark.skipif(sys.version_info < (3, 10), reason="argparse is different. Need to fix help_message_expected")
+@pytest.mark.parametrize(
+ "args_string_and_description_and_expected_message",
+ [
+ (
+ "-h",
+ "Script description",
+ # foo help likely missing b/c class nesting. In a demo in a Python 3.8 env, foo help appears in -h
+ f"""
+usage: pytest [--foo] --arg_int ARG_INT [--arg_bool] [--arg_list [ARG_LIST {_ARG_LIST_DOTS}]] [-h] {{a,b}} ...
+
+Script description
+
+positional arguments:
+ {{a,b}} sub-command help
+ a a help
+ b b help
+
+{_OPTIONS_TITLE}:
+ --foo (bool, default=False) {'' if sys.version_info < (3, 9) else 'foo help'}
+ --arg_int ARG_INT (int, required) some integer
+ --arg_bool (bool, default=True)
+ --arg_list [ARG_LIST {_ARG_LIST_DOTS}]
+ ({type_to_str(Optional[List[str]])}, default=None) some list of strings
+ -h, --help show this help message and exit
+""",
+ ),
+ (
+ "a -h",
+ "Description (a)",
+ f"""
+usage: pytest a --bar BAR [-h]
+
+Description (a)
+
+{_OPTIONS_TITLE}:
+ --bar BAR (int, required) bar help
+ -h, --help show this help message and exit
+""",
+ ),
+ (
+ "b -h",
+ "",
+ f"""
+usage: pytest b --baz {{X,Y,Z}} [-h]
+
+{_OPTIONS_TITLE}:
+ --baz {{X,Y,Z}} (Literal['X', 'Y', 'Z'], required) baz help
+ -h, --help show this help message and exit
+""",
+ ),
+ ],
+)
+def test_subclasser_subparser_help_message(
+ class_or_function_: Any, args_string_and_description_and_expected_message: Tuple[str, str]
+):
+ args_string, description, expected_message = args_string_and_description_and_expected_message
+ _test_subclasser_message(
+ subclasser_subparser, class_or_function_, expected_message, description=description, args_string=args_string
+ )