Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a bug with attribute defaults and add default_factory parameter #649

Merged
merged 1 commit into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 28 additions & 10 deletions src/fundus/parser/base_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,28 +78,35 @@ def __repr__(self):


class Attribute(RegisteredFunction):
def __init__(self, func: Callable[[object], Any], priority: Optional[int], validate: bool):
def __init__(
self,
func: Callable[[object], Any],
priority: Optional[int],
validate: bool,
default_factory: Optional[Callable[[], Any]],
):
self.validate = validate
self.default_factory = default_factory
super(Attribute, self).__init__(func=func, priority=priority)

@functools.cached_property
def __default__(self):
if self.default_factory is not None:
return self.default_factory()

annotation = self.__annotations__["return"]
origin = get_origin(annotation)
args = get_args(annotation)

if not (origin or args):
try:
default = annotation()
except TypeError:
default = None
elif callable(origin):
default = origin()
default = annotation()
elif origin == Union:
if type(None) in args:
default = None
else:
raise NotImplementedError(f"Unsupported args {args}")
raise NotImplementedError(f"Cannot determine default for {origin!r} with args {args!r}")
elif isinstance(origin, type):
default = origin()
else:
raise NotImplementedError(f"Unsupported origin {origin}")
return default
Expand All @@ -122,8 +129,15 @@ def wrapper(func):
return wrapper(cls)


def attribute(cls=None, /, *, priority: Optional[int] = None, validate: bool = True):
return _register(cls, factory=Attribute, priority=priority, validate=validate)
def attribute(
cls=None,
/,
*,
priority: Optional[int] = None,
validate: bool = True,
default_factory: Optional[Callable[[], Any]] = None,
):
return _register(cls, factory=Attribute, priority=priority, validate=validate, default_factory=default_factory)


def function(cls=None, /, *, priority: Optional[int] = None):
Expand Down Expand Up @@ -232,6 +246,10 @@ def parse(self, html: str, error_handling: Literal["suppress", "catch", "raise"]
except Exception as err:
if error_handling == "suppress":
parsed_data[attribute_name] = func.__default__
logger.info(
f"Couldn't parse attribute {attribute_name!r} for "
f"{self.precomputed.meta.get('og:url')!r}: {err}"
)
elif error_handling == "catch":
parsed_data[attribute_name] = err
elif error_handling == "raise":
Expand Down
2 changes: 1 addition & 1 deletion src/fundus/scraping/article.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __getattr__(self, item: str):

@property
def plaintext(self) -> Optional[str]:
return str(self.body) or None
return str(self.body) or None if not isinstance(self.body, Exception) else None

@property
def lang(self) -> Optional[str]:
Expand Down
50 changes: 49 additions & 1 deletion tests/test_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datetime
from typing import List
from typing import Any, Dict, List, Optional, Tuple, Union

import lxml.html
import pytest
Expand Down Expand Up @@ -75,6 +75,54 @@ def unvalidated(self) -> str:
assert (funcs := list(unvalidated)) != [parser.unvalidated]
assert funcs[0].__func__ == parser.unvalidated.__func__

def test_default_values_for_attributes(self):
    """Verify the fallback value each attribute exposes through ``__default__``.

    Every attribute below raises on purpose; only the default derived from its
    return annotation (or from an explicit ``default_factory``) is inspected:

    * ``Optional[str]``                -> ``None``
    * ``Tuple[str, ...]``              -> ``tuple()``
    * ``List[Tuple[str, str]]``        -> ``list()``
    * ``bool``                         -> ``False``
    * ``default_factory`` given        -> the factory's return value, even for
      an annotation (``Union[str, bool]``) that has no derivable default
    * ``Union[str, bool]`` w/o factory -> ``NotImplementedError``
    """

    class Parser(BaseParser):
        @attribute
        def test_optional(self) -> Optional[str]:
            raise Exception

        @attribute
        def test_collection(self) -> Tuple[str, ...]:
            raise Exception

        @attribute
        def test_nested_collection(self) -> List[Tuple[str, str]]:
            raise Exception

        @attribute(default_factory=lambda: "This is a default")
        def test_default_factory(self) -> Union[str, bool]:
            raise Exception

        @attribute
        def test_boolean(self) -> bool:
            raise Exception

    parser = Parser()

    default_values = {attr.__name__: attr.__default__ for attr in parser.attributes()}

    expected_values: Dict[str, Any] = {
        "test_optional": None,
        "test_collection": tuple(),
        "test_nested_collection": list(),
        "test_default_factory": "This is a default",
        "test_boolean": False,
        # NOTE(review): not declared on ``Parser`` above -- presumably inherited
        # from ``BaseParser``; confirm against the base class.
        "free_access": False,
    }

    # Compare the key sets first: iterating only over ``default_values`` would let
    # the test pass vacuously if an attribute was never collected at all.
    assert set(default_values) == set(expected_values)

    for name, expected in expected_values.items():
        assert default_values[name] == expected, name

    class ParserWithUnion(BaseParser):
        @attribute
        def this_should_fail(self) -> Union[str, bool]:
            raise Exception

    parser_with_union = ParserWithUnion()

    # A non-optional Union without a default_factory has no derivable default.
    with pytest.raises(NotImplementedError):
        for attr in parser_with_union.attributes():
            _ = attr.__default__


class TestParserProxy:
def test_empty_proxy(self, empty_parser_proxy):
Expand Down