Skip to content

Commit

Permalink
fix: issue when parsing optional field
Browse files Browse the repository at this point in the history
  • Loading branch information
Andarius committed Jul 25, 2022
1 parent b254d4e commit be8f26c
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 21 deletions.
6 changes: 3 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Changelog

## [0.10.0] 06-06-2022
## [0.10.1] 25-07-2022

### NEW
### FIX

- Adding a new `Password` type that will hide the default value when printing help with the Rich Formatter.
- Fix parsing issue of `Optional` types.
44 changes: 28 additions & 16 deletions piou/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pathlib import Path
from typing import (
Any, Optional, get_args, get_origin, get_type_hints,
Literal, TypeVar, Generic, Callable
Literal, TypeVar, Generic, Callable, Union
)
from uuid import UUID

Expand All @@ -29,36 +29,50 @@ class Password(str):
T = TypeVar('T', str, int, float, dt.date, dt.datetime, Path, dict, list, Password)


def extract_optional_type(t: Any):
if get_origin(t) is Union:
types = tuple(x for x in get_args(t) if x is not type(None))
return Union[types] # type: ignore
return t


def convert_to_type(data_type: Any, value: str,
*,
case_sensitive: bool = True):
"""
Converts `value` to `data_type`, if not possible raises the appropriate error
"""
_data_type = extract_optional_type(data_type)

if data_type is Any or data_type is bool:
if _data_type is Any or _data_type is bool:
return value
elif data_type is str or data_type is Password:
elif _data_type is str or _data_type is Password:
return str(value)
elif data_type is int:
elif _data_type is int:
return int(value)
elif data_type is float:
elif _data_type is float:
return float(value)
elif data_type is UUID:
elif _data_type is UUID:
return UUID(value)
elif data_type is dt.date:
elif _data_type is dt.date:
return dt.date.fromisoformat(value)
elif data_type is dt.datetime:
elif _data_type is dt.datetime:
return dt.datetime.fromisoformat(value)
elif data_type is Path:
elif _data_type is Path:
p = Path(value)
if not p.exists():
raise FileNotFoundError(f'File not found: "{value}"')
return p
elif data_type is dict:
elif _data_type is dict:
return json.loads(value)
elif get_origin(data_type) is Literal:
possible_fields = get_args(data_type)
elif inspect.isclass(_data_type) and issubclass(_data_type, Enum):
return _data_type[value].value
elif _data_type is list or get_origin(_data_type) is list:
list_type = get_args(_data_type)
return [convert_to_type(list_type[0] if list_type else str,
x) for x in value.split(' ')]
elif get_origin(_data_type) is Literal:
possible_fields = get_args(_data_type)
_possible_fields_case = possible_fields
if not case_sensitive:
_possible_fields_case = [x.lower() for x in possible_fields] + [
Expand All @@ -67,12 +81,10 @@ def convert_to_type(data_type: Any, value: str,
possible_fields = ', '.join(possible_fields)
raise ValueError(f'"{value}" is not a valid value for Literal[{possible_fields}]')
return value
elif data_type is list or get_origin(data_type) is list:
list_type = get_args(data_type)
elif _data_type is list or get_origin(_data_type) is list:
list_type = get_args(_data_type)
return [convert_to_type(list_type[0] if list_type else str,
x) for x in value.split(' ')]
elif issubclass(data_type, Enum):
return data_type[value].value
else:
raise NotImplementedError(f'No parser implemented for data type "{data_type}"')

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "piou"
version = "0.10.0"
version = "0.10.1"
description = "A CLI toolkit"
authors = ["Julien Brayere <julien.brayere@gmail.com>"]
license = "MIT"
Expand Down
17 changes: 16 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
from enum import Enum
from pathlib import Path
from typing import Literal
from typing import Literal, Optional
from uuid import UUID

import pytest
Expand Down Expand Up @@ -35,10 +35,24 @@ def test_command_option(cmd, is_required, is_positional):
assert cmd.is_positional_arg == is_positional


@pytest.mark.parametrize('data_type, expected', [
(str, str),
(Optional[str], str),
(list, list),
(list[str], list[str]),
(Optional[list[int]], list[int]),
])
def test_extract_optional_type(data_type, expected):
from piou.utils import extract_optional_type

assert extract_optional_type(data_type) == expected


@pytest.mark.parametrize('data_type, value, expected', [
(str, '123', '123'),
(str, 'foo bar', 'foo bar'),
(int, '123', 123),
(int, '123', 123),
(float, '123', 123),
(float, '0.123', 0.123),
# (bytes, 'foo'.encode('utf-8'), b'foo'),
Expand All @@ -56,6 +70,7 @@ def test_command_option(cmd, is_required, is_positional):
def test_convert_to_type(data_type, value, expected):
from piou.utils import convert_to_type
assert convert_to_type(data_type, value) == expected
assert convert_to_type(Optional[data_type], value) == expected


def testing_case_sensitivity():
Expand Down

0 comments on commit be8f26c

Please sign in to comment.