Skip to content

Commit

Permalink
fix(array): fix text getter (#142)
Browse files Browse the repository at this point in the history
* fix(array): fix text getter

* fix(array): tests

* fix(plot): plot local audio and video

Co-authored-by: numb3r3 <wangfelix87@gmail.com>
  • Loading branch information
hanxiao and numb3r3 authored Feb 25, 2022
1 parent 01b3976 commit 354d2e4
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 20 deletions.
23 changes: 10 additions & 13 deletions docarray/array/mixins/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,8 @@ def tensors(self) -> Optional['ArrayType']:
:return: a :class:`ArrayType` of tensors
"""
if self and self[0].content_type == 'tensor':
if self:
return unravel(self, 'tensor')
if self:
return unravel(self, 'tensor')

@tensors.setter
def tensors(self, value: 'ArrayType'):
Expand All @@ -82,9 +81,8 @@ def texts(self) -> Optional[List[str]]:
:return: a list of texts
"""
if self and self[0].content_type == 'text':
if self:
return [d.text for d in self]
if self:
return [d.text for d in self]

@texts.setter
def texts(self, value: Sequence[str]):
Expand All @@ -106,9 +104,8 @@ def blobs(self) -> Optional[List[bytes]]:
:return: a list of blobs
"""
if self and self[0].content_type == 'blob':
if self:
return [d.blob for d in self]
if self:
return [d.blob for d in self]

@blobs.setter
def blobs(self, value: List[bytes]):
Expand All @@ -133,9 +130,9 @@ def contents(self) -> Optional[Union[Sequence['DocumentContentType'], 'ArrayType
:return: a list of texts, blobs or :class:`ArrayType`
"""
if self:
content_type = self[0].content_type
content_type = self[0].content_type or self[-1].content_type
if content_type:
return getattr(self, f'{self[0].content_type}s')
return getattr(self, f'{content_type}s')

@contents.setter
def contents(
Expand All @@ -146,9 +143,9 @@ def contents(
:param value: a list of texts, blobs or :class:`ArrayType`
"""
if self:
content_type = self[0].content_type
content_type = self[0].content_type or self[-1].content_type
if content_type:
setattr(self, f'{self[0].content_type}s', value)
setattr(self, f'{content_type}s', value)


def _get_len(value):
Expand Down
20 changes: 18 additions & 2 deletions docarray/document/mixins/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ def display(self):

if self.uri:
if self.mime_type.startswith('audio'):
_html5_audio_player(self.uri)
uri = _convert_display_uri(self.uri, self.mime_type)
_html5_audio_player(uri)
elif self.mime_type.startswith('video'):
_html5_video_player(self.uri)
uri = _convert_display_uri(self.uri, self.mime_type)
_html5_video_player(uri)
else:
display(Image(self.uri))
elif self.tensor is not None:
Expand All @@ -58,7 +60,20 @@ def display(self):
plot = deprecate_by(display, removed_at='0.5')


def _convert_display_uri(uri, mime_type):
import urllib
from .helper import _to_datauri, _uri_to_blob

scheme = urllib.parse.urlparse(uri).scheme

if scheme not in ['data', 'http', 'https']:
blob = _uri_to_blob(uri)
return _to_datauri(mime_type, blob)
return uri


def _html5_video_player(uri):
from IPython.display import display
from IPython.core.display import HTML # noqa

src = f'''
Expand All @@ -73,6 +88,7 @@ def _html5_video_player(uri):


def _html5_audio_player(uri):
from IPython.display import display
from IPython.core.display import HTML # noqa

src = f'''
Expand Down
11 changes: 8 additions & 3 deletions tests/unit/array/mixins/test_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,15 @@ def test_content_empty(da_len, da_cls, config, start_weaviate):
da = da_cls.empty(da_len, config=config)
else:
da = da_cls.empty(da_len)
assert not da.texts

assert not da.contents
assert not da.tensors
assert not da.blobs
if da_len == 0:
assert not da.texts
assert not da.blobs
else:
assert da.texts == [''] * da_len
assert da.blobs == [b''] * da_len

da.texts = ['hello'] * da_len
if da_len == 0:
Expand All @@ -102,7 +107,7 @@ def test_content_empty(da_len, da_cls, config, start_weaviate):
assert da.contents == ['hello'] * da_len
assert da.texts == ['hello'] * da_len
assert not da.tensors
assert not da.blobs
assert da.blobs == [b''] * da_len


@pytest.mark.parametrize('da_len', [0, 1, 2])
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/array/mixins/test_getset.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def test_texts_getter_da(docs, config, da_cls, start_weaviate):

# unfortunately protobuf does not distinguish None and '' on string
# so non-set str field in Pb is ''
assert not da.texts
assert set(da.texts) == set([''])


@pytest.mark.parametrize(
Expand Down Expand Up @@ -282,7 +282,7 @@ def test_blobs_getter_setter(docs, da_cls, config, start_weaviate):

# unfortunately protobuf does not distinguish None and '' on string
# so non-set str field in Pb is ''
assert not da.blobs
assert set(da.blobs) == set([b''])


@pytest.mark.parametrize(
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/document/test_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,19 @@ def test_plot_image():
d.uri = None

d.display()


def test_plot_audio():
d = Document(uri=os.path.join(cur_dir, 'toydata/hello.wav'))
d.display()

d.convert_uri_to_datauri()
d.display()


def test_plot_video():
d = Document(uri=os.path.join(cur_dir, 'toydata/mov_bbb.mp4'))
d.display()

d.convert_uri_to_datauri()
d.display()

0 comments on commit 354d2e4

Please sign in to comment.