diff --git a/CHANGELOG.md b/CHANGELOG.md index 5007db3..c1db751 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,7 @@ # Changelog -## [0.10.0] 06-06-2022 +## [0.10.1] 25-07-2022 -### NEW +### FIX -- Adding a new `Password` type that will hide the default value when printing help with the Rich Formatter. +- Fix parsing issue of `Optional` types. diff --git a/piou/utils.py b/piou/utils.py index 296c33a..727c9b5 100644 --- a/piou/utils.py +++ b/piou/utils.py @@ -12,7 +12,7 @@ from pathlib import Path from typing import ( Any, Optional, get_args, get_origin, get_type_hints, - Literal, TypeVar, Generic, Callable + Literal, TypeVar, Generic, Callable, Union ) from uuid import UUID @@ -29,36 +29,50 @@ class Password(str): T = TypeVar('T', str, int, float, dt.date, dt.datetime, Path, dict, list, Password) +def extract_optional_type(t: Any): + if get_origin(t) is Union: + types = tuple(x for x in get_args(t) if x is not type(None)) + return Union[types] # type: ignore + return t + + def convert_to_type(data_type: Any, value: str, *, case_sensitive: bool = True): """ Converts `value` to `data_type`, if not possible raises the appropriate error """ + _data_type = extract_optional_type(data_type) - if data_type is Any or data_type is bool: + if _data_type is Any or _data_type is bool: return value - elif data_type is str or data_type is Password: + elif _data_type is str or _data_type is Password: return str(value) - elif data_type is int: + elif _data_type is int: return int(value) - elif data_type is float: + elif _data_type is float: return float(value) - elif data_type is UUID: + elif _data_type is UUID: return UUID(value) - elif data_type is dt.date: + elif _data_type is dt.date: return dt.date.fromisoformat(value) - elif data_type is dt.datetime: + elif _data_type is dt.datetime: return dt.datetime.fromisoformat(value) - elif data_type is Path: + elif _data_type is Path: p = Path(value) if not p.exists(): raise FileNotFoundError(f'File not found: "{value}"') return p - elif data_type is dict: + elif _data_type is dict: return json.loads(value) - elif get_origin(data_type) is Literal: - possible_fields = get_args(data_type) + elif inspect.isclass(_data_type) and issubclass(_data_type, Enum): + return _data_type[value].value + elif _data_type is list or get_origin(_data_type) is list: + list_type = get_args(_data_type) + return [convert_to_type(list_type[0] if list_type else str, + x) for x in value.split(' ')] + elif get_origin(_data_type) is Literal: + possible_fields = get_args(_data_type) _possible_fields_case = possible_fields if not case_sensitive: _possible_fields_case = [x.lower() for x in possible_fields] + [ @@ -67,12 +81,10 @@ def convert_to_type(data_type: Any, value: str, possible_fields = ', '.join(possible_fields) raise ValueError(f'"{value}" is not a valid value for Literal[{possible_fields}]') return value - elif data_type is list or get_origin(data_type) is list: - list_type = get_args(data_type) + elif _data_type is list or get_origin(_data_type) is list: + list_type = get_args(_data_type) return [convert_to_type(list_type[0] if list_type else str, x) for x in value.split(' ')] - elif issubclass(data_type, Enum): - return data_type[value].value else: raise NotImplementedError(f'No parser implemented for data type "{data_type}"') diff --git a/pyproject.toml b/pyproject.toml index d1c19fd..7e303bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "piou" -version = "0.10.0" +version = "0.10.1" description = "A CLI toolkit" authors = ["Julien Brayere "] license = "MIT" diff --git a/tests/test_cli.py b/tests/test_cli.py index 9c9766b..cf59203 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -3,7 +3,7 @@ import re from enum import Enum from pathlib import Path -from typing import Literal +from typing import Literal, Optional from uuid import UUID import pytest @@ -35,10 +35,24 @@ def test_command_option(cmd, is_required, is_positional): assert cmd.is_positional_arg == is_positional +@pytest.mark.parametrize('data_type, expected', [ + (str, str), + (Optional[str], str), + (list, list), + (list[str], list[str]), + (Optional[list[int]], list[int]), + ]) +def test_extract_optional_type(data_type, expected): + from piou.utils import extract_optional_type + + assert extract_optional_type(data_type) == expected + + @pytest.mark.parametrize('data_type, value, expected', [ (str, '123', '123'), (str, 'foo bar', 'foo bar'), (int, '123', 123), + (int, '123', 123), (float, '123', 123), (float, '0.123', 0.123), # (bytes, 'foo'.encode('utf-8'), b'foo'), @@ -56,6 +70,7 @@ def test_command_option(cmd, is_required, is_positional): def test_convert_to_type(data_type, value, expected): from piou.utils import convert_to_type assert convert_to_type(data_type, value) == expected + assert convert_to_type(Optional[data_type], value) == expected def testing_case_sensitivity():