Skip to content

Commit

Permalink
[cdd/shared/parse/utils/parser_utils.py] Correctly use merge_params
Browse files Browse the repository at this point in the history
… and fix its impl to return ; [cdd/sqlalchemy/utils/shared_utils.py] Prepare for increased test coverage ; [cdd/tests/{test_parse/test_parser_utils.py,test_sqlalchemy/test_emit_sqlalchemy_utils.py}] Increase test coverage ; [cdd/__init__.py] Bump version
  • Loading branch information
SamuelMarks committed Mar 16, 2024
1 parent 4608569 commit 145c49a
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 60 deletions.
2 changes: 1 addition & 1 deletion cdd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from logging import getLogger as get_logger

__author__ = "Samuel Marks" # type: str
__version__ = "0.0.99rc43" # type: str
__version__ = "0.0.99rc44" # type: str
__description__ = (
"Open API to/fro routes, models, and tests. "
"Convert between docstrings, classes, methods, argparse, pydantic, and SQLalchemy."
Expand Down
6 changes: 2 additions & 4 deletions cdd/shared/parse/utils/parser_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,7 @@ def ir_merge(target, other):
target["params"] = other["params"]
elif other["params"]:
target_params, other_params = map(itemgetter("params"), (target, other))

merge_params(other_params, target_params)

target["params"] = target_params
target["params"] = merge_params(other_params, target_params)

if "return_type" not in (target.get("returns") or iter(())):
target["returns"] = other["returns"]
Expand Down Expand Up @@ -110,6 +107,7 @@ def merge_params(other_params, target_params):
merge_present_params(other_params[name], target_params[name])
for name in other_params.keys() - target_params.keys():
target_params[name] = other_params[name]
return target_params


def merge_present_params(other_param, target_param):
Expand Down
88 changes: 50 additions & 38 deletions cdd/sqlalchemy/utils/shared_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

import ast
from ast import Call, Expr, Load, Name, Subscript, Tuple, keyword
from ast import Call, Expr, Load, Name, Subscript, Tuple, expr, keyword
from operator import attrgetter
from typing import Optional, cast

Expand Down Expand Up @@ -82,13 +82,14 @@ def update_args_infer_typ_sqlalchemy(_param, args, name, nullable, x_typ_sql):
return _param.get("default") == cdd.shared.ast_utils.NoneStr, None
elif _param["typ"].startswith("Optional["):
_param["typ"] = _param["typ"][len("Optional[") : -1]
nullable = True
nullable: bool = True
if "Literal[" in _param["typ"]:
parsed_typ: Call = cast(
Call, cdd.shared.ast_utils.get_value(ast.parse(_param["typ"]).body[0])
)
if parsed_typ.value.id != "Literal":
return nullable, parsed_typ.value
assert parsed_typ.value.id == "Literal", "Expected `Literal` got: {!r}".format(
parsed_typ.value.id
)
val = cdd.shared.ast_utils.get_value(parsed_typ.slice)
(
args.append(
Expand All @@ -112,7 +113,7 @@ def update_args_infer_typ_sqlalchemy(_param, args, name, nullable, x_typ_sql):
else _update_args_infer_typ_sqlalchemy_for_scalar(_param, args, x_typ_sql)
)
elif _param["typ"].startswith("List["):
after_generic = _param["typ"][len("List[") :]
after_generic: str = _param["typ"][len("List[") :]
if "struct" in after_generic: # "," in after_generic or
name: Name = Name(id="JSON", ctx=Load(), lineno=None, col_offset=None)
else:
Expand Down Expand Up @@ -175,42 +176,53 @@ def update_args_infer_typ_sqlalchemy(_param, args, name, nullable, x_typ_sql):
)
)
elif _param.get("typ").startswith("Union["):
# Hack to remove the union type. Enum parse seems to be incorrect?
union_typ: Subscript = cast(Subscript, ast.parse(_param["typ"]).body[0])
assert isinstance(
union_typ.value, Subscript
), "Expected `Subscript` got `{type_name}`".format(
type_name=type(union_typ.value).__name__
)
union_typ_tuple = (
union_typ.value.slice if PY_GTE_3_9 else union_typ.value.slice.value
)
assert isinstance(
union_typ_tuple, Tuple
), "Expected `Tuple` got `{type_name}`".format(
type_name=type(union_typ_tuple).__name__
)
assert (
len(union_typ_tuple.elts) == 2
), "Expected length of 2 got `{tuple_len}`".format(
tuple_len=len(union_typ_tuple.elts)
)
left, right = map(attrgetter("id"), union_typ_tuple.elts)
args.append(
Name(
(
cdd.sqlalchemy.utils.emit_utils.typ2column_type[right]
if right in cdd.sqlalchemy.utils.emit_utils.typ2column_type
else cdd.sqlalchemy.utils.emit_utils.typ2column_type.get(left, left)
),
Load(),
lineno=None,
col_offset=None,
)
)
args.append(_handle_union_of_length_2(_param["typ"]))
else:
_update_args_infer_typ_sqlalchemy_for_scalar(_param, args, x_typ_sql)
return nullable, None


def _handle_union_of_length_2(typ):
"""
Internal function to turn `str` to `Name`
:param typ: `str` which evaluates to `ast.Subscript`
:type typ: ```str```
:return: Parsed out name
:rtype: ```Name```
"""
# Hack to remove the union type. Enum parse seems to be incorrect?
union_typ: Subscript = cast(Subscript, ast.parse(typ).body[0])
assert isinstance(
union_typ.value, Subscript
), "Expected `Subscript` got `{type_name}`".format(
type_name=type(union_typ.value).__name__
)
union_typ_tuple: expr = (
union_typ.value.slice if PY_GTE_3_9 else union_typ.value.slice.value
)
assert isinstance(
union_typ_tuple, Tuple
), "Expected `Tuple` got `{type_name}`".format(
type_name=type(union_typ_tuple).__name__
)
assert (
len(union_typ_tuple.elts) == 2
), "Expected length of 2 got `{tuple_len}`".format(
tuple_len=len(union_typ_tuple.elts)
)
left, right = map(attrgetter("id"), union_typ_tuple.elts)
return Name(
(
cdd.sqlalchemy.utils.emit_utils.typ2column_type[right]
if right in cdd.sqlalchemy.utils.emit_utils.typ2column_type
else cdd.sqlalchemy.utils.emit_utils.typ2column_type.get(left, left)
),
Load(),
lineno=None,
col_offset=None,
)


__all__ = ["update_args_infer_typ_sqlalchemy"]
41 changes: 25 additions & 16 deletions cdd/tests/test_parse/test_parser_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,23 @@
class TestParserUtils(TestCase):
"""Test class for parser_utils"""

def test_get_source_raises(self) -> None:
"""Tests that `get_source` raises an exception"""
with self.assertRaises(TypeError):
get_source(None)

def raise_os_error(_):
"""raise_OSError"""
raise OSError

with patch("inspect.getsourcelines", raise_os_error), self.assertRaises(
OSError
):
get_source(min)

with patch("inspect.getsourcefile", lambda _: None):
self.assertIsNone(get_source(raise_os_error))

def test_ir_merge_empty(self) -> None:
"""Tests for `ir_merge` when both are empty"""
target = {"params": OrderedDict(), "returns": None}
Expand Down Expand Up @@ -250,22 +267,14 @@ def test_infer_raise(self) -> None:
with self.assertRaises(NotImplementedError):
cdd.shared.parse.utils.parser_utils.infer(None)

def test_get_source_raises(self) -> None:
"""Tests that `get_source` raises an exception"""
with self.assertRaises(TypeError):
get_source(None)

def raise_os_error(_):
"""raise_OSError"""
raise OSError

with patch("inspect.getsourcelines", raise_os_error), self.assertRaises(
OSError
):
get_source(min)

with patch("inspect.getsourcefile", lambda _: None):
self.assertIsNone(get_source(raise_os_error))
def test_merge_params(self) -> None:
"""Tests `merge_params` works"""
d0 = {"foo": "bar"}
d1 = {"can": "haz"}
self.assertDictEqual(
cdd.shared.parse.utils.parser_utils.merge_params(deepcopy(d0), d1),
{"foo": "bar", "can": "haz"},
)


unittest_main()
42 changes: 41 additions & 1 deletion cdd/tests/test_sqlalchemy/test_emit_sqlalchemy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import ast
import json
from ast import (
AST,
Assign,
Call,
ClassDef,
Expand All @@ -19,8 +20,10 @@
)
from collections import OrderedDict
from copy import deepcopy
from functools import partial
from os import mkdir, path
from tempfile import TemporaryDirectory
from typing import Callable, List, Optional, Tuple, Union
from unittest import TestCase
from unittest.mock import patch

Expand All @@ -29,7 +32,10 @@
from cdd.shared.ast_utils import set_value
from cdd.shared.source_transformer import to_code
from cdd.shared.types import IntermediateRepr
from cdd.sqlalchemy.utils.shared_utils import update_args_infer_typ_sqlalchemy
from cdd.sqlalchemy.utils.shared_utils import (
_handle_union_of_length_2,
update_args_infer_typ_sqlalchemy,
)
from cdd.tests.mocks.ir import (
intermediate_repr_empty,
intermediate_repr_no_default_doc,
Expand Down Expand Up @@ -296,6 +302,27 @@ def test_update_args_infer_typ_sqlalchemy_when_simple_array_in_typ(self) -> None
# gold=Name(id="Small", ctx=Load(), lineno=None, col_offset=None),
# )

def test_update_args_infer_typ_sqlalchemy_early_exit(self) -> None:
"""Tests that `update_args_infer_typ_sqlalchemy` exits early"""
_update_args_infer_typ_sqlalchemy: Callable[
[dict], Tuple[bool, Optional[Union[List[AST], Tuple[AST]]]]
] = partial(
update_args_infer_typ_sqlalchemy,
args=[],
name="",
nullable=True,
x_typ_sql={},
)
self.assertTupleEqual(
_update_args_infer_typ_sqlalchemy({"typ": None}), (False, None)
)
self.assertTupleEqual(
_update_args_infer_typ_sqlalchemy(
{"typ": None, "default": cdd.shared.ast_utils.NoneStr},
),
(True, None),
)

def test_update_with_imports_from_columns(self) -> None:
"""
Tests basic `cdd.sqlalchemy.utils.emit_utils.update_with_imports_from_columns` usage
Expand Down Expand Up @@ -573,5 +600,18 @@ def test_rewrite_fk(self) -> None:
gold=column_fk_gold,
)

def test__handle_union_of_length_2(self) -> None:
"""Tests that `_handle_union_of_length_2` works"""
run_ast_test(
self,
gen_ast=_handle_union_of_length_2("Union[int, float]"),
gold=Name(
"Float",
Load(),
lineno=None,
col_offset=None,
),
)


unittest_main()

0 comments on commit 145c49a

Please sign in to comment.