Skip to content

Commit

Permalink
Docstring fix for inherited fields + test
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Nov 27, 2021
1 parent 7ff0e00 commit 311a66b
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 32 deletions.
74 changes: 43 additions & 31 deletions dcargs/_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ class _FieldData:


@dataclasses.dataclass
class _Tokenization:
class _ClassTokenization:
tokens: List[_Token]
tokens_from_line: Dict[int, List[_Token]]
field_data_from_name: Dict[str, _FieldData]

@staticmethod
@functools.lru_cache(maxsize=4)
def make(cls) -> "_Tokenization":
@functools.lru_cache(maxsize=8)
def make(cls) -> "_ClassTokenization":
"""Parse the source code of a class, and cache some tokenization information."""
readline = io.BytesIO(inspect.getsource(cls).encode("utf-8")).readline

Expand Down Expand Up @@ -66,34 +66,57 @@ def make(cls) -> "_Tokenization":
)
prev_field_line_number = token.line_number

return _Tokenization(
return _ClassTokenization(
tokens=tokens,
tokens_from_line=tokens_from_line,
field_data_from_name=field_data_from_name,
)


def get_class_tokenization_with_field(
    cls: Type, field_name: str
) -> Optional[_ClassTokenization]:
    """Walk `cls` and its superclasses (MRO order) looking for tokenization data
    that contains `field_name`.

    Returns None when class source is unavailable (e.g. dynamic dataclasses);
    otherwise returns the tokenization of the first dataclass ancestor that
    defines the field."""
    found_field: bool = False
    for candidate in cls.mro():
        # Generic aliases need to be unwrapped to their origin class first.
        unwrapped = get_origin(candidate)
        if unwrapped is not None:
            candidate = unwrapped

        # Only dataclass ancestors can contribute field docstrings.
        if not dataclasses.is_dataclass(candidate):
            continue

        try:
            tokenization = _ClassTokenization.make(candidate)  # type: ignore
        except OSError as e:
            # Dynamic dataclasses will result in an OSError -- this is fine, we just assume
            # there's no docstring.
            assert "could not find class definition" in e.args[0]
            return None

        # Stop at the first ancestor whose source actually declares the field.
        if field_name in tokenization.field_data_from_name:
            found_field = True
            break

    assert (
        found_field
    ), "Docstring parsing error -- this usually means that there are multiple \
    dataclasses in the same file with the same name but different scopes."

    return tokenization


def get_field_docstring(cls: Type, field_name: str) -> Optional[str]:
"""Get docstring for a field in a class."""

origin_cls = get_origin(cls)
if origin_cls is not None:
cls = origin_cls

assert dataclasses.is_dataclass(cls)
try:
tokenization = _Tokenization.make(cls) # type: ignore
except OSError as e:
# Dynamic dataclasses will result in an OSError -- this is fine, we just assume
# there's no docstring.
assert "could not find class definition" in e.args[0]
tokenization = get_class_tokenization_with_field(cls, field_name)
if tokenization is None: # Currently only happens for dynamic dataclasses.
return None

# Grab field-specific tokenization data.
assert (
field_name in tokenization.field_data_from_name
), "Docstring parsing error -- this usually means that there are multiple \
dataclasses in the same file with the same name but different scopes."
field_data = tokenization.field_data_from_name[field_name]

# Check for docstring-style comment.
Expand Down Expand Up @@ -126,17 +149,6 @@ def get_field_docstring(cls: Type, field_name: str) -> Optional[str]:
break

line_number += 1
# if (
# field_data.line_number + 1 in tokenization.tokens_from_line
# and len(tokenization.tokens_from_line[field_data.line_number + 1]) > 0
# ):
# first_token_on_next_line = tokenization.tokens_from_line[
# field_data.line_number + 1
# ][0]
# if first_token_on_next_line.token_type == tokenize.STRING:
# docstring = first_token_on_next_line.token.strip()
# assert docstring.endswith('"""') and docstring.startswith('"""')
# return _strings.dedent(docstring[3:-3])

# Check for comment on the same line as the field.
final_token_on_line = tokenization.tokens_from_line[field_data.line_number][-1]
Expand Down
4 changes: 3 additions & 1 deletion dcargs/_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,11 @@ class DataclassDumper(yaml.Dumper):

contained_types = list(_get_contained_special_types_from_instance(instance))
contained_type_names = list(map(lambda cls: cls.__name__, contained_types))

# Note: this assert is currently stricter than necessary.
assert len(set(contained_type_names)) == len(
contained_type_names
), f"Contained dataclass type names must all be unique, but got {contained_type_names}"
), f"Contained dataclass/enum names must all be unique, but got {contained_type_names}"

dumper: yaml.Dumper
data: Any
Expand Down
55 changes: 55 additions & 0 deletions tests/test_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,61 @@ class HelptextHardString:
)


def test_helptext_with_inheritance():
    """Field helptext (docstring-style comments) should be inherited by child
    dataclasses that don't redeclare the field.

    NOTE(review): the exact source layout of the nested classes below is the
    fixture itself -- the docstring parser tokenizes their source -- so no
    comments are added inside the class bodies.
    """

    @dataclasses.dataclass
    class Parent:
        # fmt: off
        x: str = (
            "This docstring may be tougher to parse!"
        )
        """Helptext."""
        # fmt: on

    @dataclasses.dataclass
    class Child(Parent):
        pass

    # --help prints usage and exits with SystemExit; capture stdout to inspect it.
    f = io.StringIO()
    with pytest.raises(SystemExit):
        with contextlib.redirect_stdout(f):
            dcargs.parse(Child, args=["--help"])
    helptext = f.getvalue()
    # Child declares no fields of its own, so both the helptext and the default
    # must come from Parent.x.
    assert (
        "--x STR Helptext. (default: This docstring may be tougher to parse!)\n"
        in helptext
    )


def test_helptext_with_inheritance_overriden():
    """When a child dataclass redeclares an inherited field, the child's own
    helptext and default should win over the parent's.

    NOTE(review): the exact source layout of the nested classes below is the
    fixture itself -- the docstring parser tokenizes their source -- so no
    comments are added inside the class bodies.
    """

    @dataclasses.dataclass
    class Parent2:
        # fmt: off
        x: str = (
            "This docstring may be tougher to parse!"
        )
        """Helptext."""
        # fmt: on

    @dataclasses.dataclass
    class Child2(Parent2):
        # fmt: off
        x: str = (
            "This docstring may be tougher to parse?"
        )
        """Helptext."""
        # fmt: on

    # --help prints usage and exits with SystemExit; capture stdout to inspect it.
    f = io.StringIO()
    with pytest.raises(SystemExit):
        with contextlib.redirect_stdout(f):
            dcargs.parse(Child2, args=["--help"])
    helptext = f.getvalue()
    # The '?' default proves the override in Child2 was picked up, not Parent2's.
    assert (
        "--x STR Helptext. (default: This docstring may be tougher to parse?)\n"
        in helptext
    )


def test_tuple_helptext():
@dataclasses.dataclass
class TupleHelptext:
Expand Down

0 comments on commit 311a66b

Please sign in to comment.