Skip to content

Commit

Permalink
fix: initialize doc with dataclass obj and kwargs (#694)
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-charlotte authored Oct 28, 2022
1 parent 0ceb397 commit 030f5b3
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 2 deletions.
16 changes: 14 additions & 2 deletions docarray/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy as cp
import dataclasses
from dataclasses import fields
from functools import lru_cache
from typing import TYPE_CHECKING, Optional, Tuple, Dict
Expand Down Expand Up @@ -39,7 +40,15 @@ def __init__(

if kwargs:
try:
self._data = self._data_class(self, **kwargs)
if self._data is not None:
if self._unresolved_fields_dest in kwargs.keys():
getattr(self, self._unresolved_fields_dest).update(
kwargs[self._unresolved_fields_dest]
)
kwargs.pop(self._unresolved_fields_dest)
self._data = dataclasses.replace(self._data, **kwargs)
else:
self._data = self._data_class(self, **kwargs)
except TypeError as ex:
if unknown_fields_handler == 'raise':
raise AttributeError(f'unknown attributes') from ex
Expand All @@ -58,7 +67,10 @@ def __init__(
for k in _unresolved:
kwargs.pop(k)

self._data = self._data_class(self, **kwargs)
if self._data is not None:
self._data = dataclasses.replace(self._data, **kwargs)
else:
self._data = self._data_class(self, **kwargs)

if _unknown_kwargs and unknown_fields_handler == 'catch':
getattr(self, self._unresolved_fields_dest).update(
Expand Down
71 changes: 71 additions & 0 deletions tests/unit/document/test_multi_modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,77 @@ def test_set_multimodal_nested(serialization, nested_mmdoc):
assert new_inner_list_doc in d.other_doc_list['@.[heading]']


def test_initialize_document_with_dataclass_and_additional_text_attr():
    """Build a Document from a dataclass instance plus a ``text`` keyword.

    Checks that the keyword value is readable as ``text`` on the document
    and that the dataclass field is readable as ``chunk_text.text``.
    """

    @dataclass
    class MyDoc:
        chunk_text: Text

    mm_obj = MyDoc(chunk_text='chunk level text')
    doc = Document(mm_obj, text='top level text')

    assert doc.text == 'top level text'
    assert doc.chunk_text.text == 'chunk level text'


def test_initialize_document_with_dataclass_and_additional_unknown_attributes():
    """Build a Document from a dataclass instance plus an unknown keyword.

    Checks that the unknown keyword ``hello`` ends up in ``tags`` while the
    dataclass field remains readable as ``chunk_text.text``.
    """

    @dataclass
    class MyDoc:
        chunk_text: Text

    mm_obj = MyDoc(chunk_text='chunk level text')
    doc = Document(mm_obj, hello='top level text')

    assert doc.tags['hello'] == 'top level text'
    assert doc.chunk_text.text == 'chunk level text'


def test_doc_with_dataclass_with_str_attr_and_additional_unknown_attribute():
    """Combine a ``str`` dataclass field with an unknown keyword argument.

    Checks that both the dataclass field value and the unknown keyword
    value can be read back from ``tags``.
    """

    @dataclass
    class MyDoc:
        name_mydoc: str

    mm_obj = MyDoc(name_mydoc='mydoc')
    doc = Document(mm_obj, name_doc='doc')

    expected_tags = {'name_mydoc': 'mydoc', 'name_doc': 'doc'}
    for key, expected in expected_tags.items():
        assert doc.tags[key] == expected


def test_doc_with_dataclass_with_str_attr_and_additional_tags_arg():
    """Combine a ``str`` dataclass field with an explicit ``tags`` argument.

    Checks that ``tags`` holds both the dataclass field value and the
    entry supplied through the ``tags`` keyword.
    """

    @dataclass
    class MyDoc:
        name_mydoc: str

    mm_obj = MyDoc(name_mydoc='mydoc')
    extra_tags = {'name_doc': 'doc'}
    doc = Document(mm_obj, tags=extra_tags)

    assert doc.tags['name_mydoc'] == 'mydoc'
    assert doc.tags['name_doc'] == 'doc'


def test_doc_with_dataclass_with_str_and_additional_tags_arg_and_unknown_attribute():
    """Combine a ``str`` dataclass field, a ``tags`` argument and an unknown keyword.

    Checks that all three values can be read back from ``tags``.
    """

    @dataclass
    class MyDoc:
        name_mydoc: str

    mm_obj = MyDoc(name_mydoc='mydoc')
    doc = Document(mm_obj, tags={'name_doc': 'doc'}, something_else='hello')

    expected_tags = {
        'name_mydoc': 'mydoc',
        'name_doc': 'doc',
        'something_else': 'hello',
    }
    for key, expected in expected_tags.items():
        assert doc.tags[key] == expected


def test_doc_with_dataclass_with_str_attr_and_additional_unknown_attr_with_same_name():
    """Pass a keyword whose name collides with a ``str`` dataclass field.

    Checks that ``tags['name']`` holds the keyword value (``'doc'``), not
    the value given to the dataclass field.
    """

    @dataclass
    class MyDoc:
        name: str

    mm_obj = MyDoc(name='mydoc')
    doc = Document(mm_obj, name='doc')

    assert doc.tags['name'] == 'doc'


def test_empty_list_dataclass():
@dataclass()
class A:
Expand Down

0 comments on commit 030f5b3

Please sign in to comment.