Skip to content

Commit

Permalink
feat: multi modal document (#188)
Browse files Browse the repository at this point in the history
Co-authored-by: winstonww <winstonwongww2@gmail.com>
  • Loading branch information
alaeddine-13 and winstonww authored Mar 23, 2022
1 parent 60d1665 commit 5a8544a
Show file tree
Hide file tree
Showing 17 changed files with 1,125 additions and 23 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ jobs:
python -m pip install --upgrade pip
python -m pip install wheel
pip install --no-cache-dir ".[full,test]"
sudo apt-get install libsndfile1
- name: Test
id: test
run: |
Expand Down
110 changes: 97 additions & 13 deletions docarray/array/mixins/traverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,73 @@
Optional,
Callable,
Tuple,
Dict,
List,
)

if TYPE_CHECKING:
from ... import DocumentArray, Document
from ...types import T


ATTRIBUTES_SEPARATOR = ','
PATHS_SEPARATOR = ','

SLICE_BASE = r'[-\d:]+'
WRAPPED_SLICE_BASE = r'\[[-\d:]+\]'

SLICE = rf'({SLICE_BASE}|{WRAPPED_SLICE_BASE})?'
SLICE_TAGGED = rf'(?P<slice>{SLICE})'

ATTRIBUTE_NAME = r'[a-zA-Z][a-zA-Z0-9]*'

# accepts both syntaxes: '.[att]' or '.att'
# However, this makes the grammar ambiguous. E.g:
# 'r.attr' should it be parsed into tokens 'r', '.', 'attr' or 'r', '.', 'att', 'r' ?
ATTRIBUTE = rf'\.(\[({ATTRIBUTE_NAME}({ATTRIBUTES_SEPARATOR}{ATTRIBUTE_NAME})*)\]|{ATTRIBUTE_NAME})'
ATTRIBUTE_TAGGED = rf'\.(\[(?P<attributes>{ATTRIBUTE_NAME}({ATTRIBUTES_SEPARATOR}{ATTRIBUTE_NAME})*)\]|(?P<attribute>{ATTRIBUTE_NAME}))'

SELECTOR = rf'(r|c|m|{ATTRIBUTE})'
SELECTOR_TAGGED = rf'(?P<selector>r|c|m|{ATTRIBUTE_TAGGED})'

REMAINDER = rf'({SELECTOR}{SLICE})*'
REMAINDER_TAGGED = rf'(?P<remainder>({SELECTOR}{SLICE})*)'

TRAVERSAL_PATH = rf'{SELECTOR}{SLICE}{REMAINDER}'
TRAVERSAL_PATH_TAGGED = rf'(?P<path>{SELECTOR_TAGGED}{SLICE_TAGGED}){REMAINDER_TAGGED}'


PATHS_REMAINDER_TAGGED = rf'(?P<paths_remainder>({PATHS_SEPARATOR}{TRAVERSAL_PATH})*)'

TRAVERSAL_PATH_LIST_TAGGED = (
rf'^(?P<traversal_path>{TRAVERSAL_PATH}){PATHS_REMAINDER_TAGGED}$'
)

ATTRIBUTE_REGEX = re.compile(rf'^{ATTRIBUTE}$')
TRAVERSAL_PATH_REGEX = re.compile(rf'^{TRAVERSAL_PATH_TAGGED}$')
TRAVERSAL_PATH_LIST_REGEX = re.compile(TRAVERSAL_PATH_LIST_TAGGED)


def _re_traversal_path_split(path: str) -> List[str]:
res = []
remainder = path
while True:
m = TRAVERSAL_PATH_LIST_REGEX.match(remainder)
if not m:
raise ValueError(
f'`path`:{path} is invalid, please refer to https://docarray.jina.ai/fundamentals/documentarray/access-elements/#index-by-nested-structure'
)
group_dict = m.groupdict()
current, remainder = group_dict['traversal_path'], group_dict['paths_remainder']
res.append(current)
if not remainder:
break
else:
remainder = remainder[1:]

return res


class TraverseMixin:
"""
A mixin used for traversing :class:`DocumentArray`.
Expand All @@ -36,13 +96,16 @@ def traverse(
- `r`: docs in this TraversableSequence
- `m`: all match-documents at adjacency 1
- `c`: all child-documents at granularity 1
- `r.[attribute]`: access attribute of a multi modal document
- `cc`: all child-documents at granularity 2
- `mm`: all match-documents at adjacency 2
- `cm`: all match-document at adjacency 1 and granularity 1
- `r,c`: docs in this TraversableSequence and all child-documents at granularity 1
- `r[start:end]`: access sub document array using slice
"""
for p in traversal_paths.split(','):
traversal_paths = re.sub(r'\s+', '', traversal_paths)
for p in _re_traversal_path_split(traversal_paths):
yield from self._traverse(self, p, filter_fn=filter_fn)

@staticmethod
Expand All @@ -53,21 +116,33 @@ def _traverse(
):
path = re.sub(r'\s+', '', path)
if path:
cur_loc, cur_slice, _left = _parse_path_string(path)
group_dict = _parse_path_string(path)
cur_loc = group_dict['selector']
cur_slice = group_dict['slice']
remainder = group_dict['remainder']

if cur_loc == 'r':
yield from TraverseMixin._traverse(
docs[cur_slice], _left, filter_fn=filter_fn
docs[cur_slice], remainder, filter_fn=filter_fn
)
elif cur_loc == 'm':
for d in docs:
yield from TraverseMixin._traverse(
d.matches[cur_slice], _left, filter_fn=filter_fn
d.matches[cur_slice], remainder, filter_fn=filter_fn
)
elif cur_loc == 'c':
for d in docs:
yield from TraverseMixin._traverse(
d.chunks[cur_slice], _left, filter_fn=filter_fn
d.chunks[cur_slice], remainder, filter_fn=filter_fn
)
elif ATTRIBUTE_REGEX.match(cur_loc):
for d in docs:
for attribute in group_dict['attributes']:
yield from TraverseMixin._traverse(
d.get_multi_modal_attribute(attribute)[cur_slice],
remainder,
filter_fn=filter_fn,
)
else:
raise ValueError(
f'`path`:{path} is invalid, please refer to https://docarray.jina.ai/fundamentals/documentarray/access-elements/#index-by-nested-structure'
Expand All @@ -92,7 +167,8 @@ def traverse_flat_per_path(
:param filter_fn: function to filter docs during traversal
:yield: :class:``TraversableSequence`` containing the document of all leaves per path.
"""
for p in traversal_paths.split(','):
traversal_paths = re.sub(r'\s+', '', traversal_paths)
for p in _re_traversal_path_split(traversal_paths):
yield self._flatten(self._traverse(self, p, filter_fn=filter_fn))

def traverse_flat(
Expand Down Expand Up @@ -159,23 +235,31 @@ def _flatten(sequence) -> 'DocumentArray':
return DocumentArray(list(itertools.chain.from_iterable(sequence)))


def _parse_path_string(p: str) -> Tuple[str, slice, str]:
g = re.match(r'^([rcm])([-\d:]+)?([rcm].*)?$', p)
_this = g.group(1)
slice_str = g.group(2)
_next = g.group(3)
return _this, _parse_slice(slice_str or ':'), _next or ''
def _parse_path_string(p: str) -> Dict[str, str]:
g = TRAVERSAL_PATH_REGEX.match(p)
group_dict = g.groupdict()
group_dict['remainder'] = group_dict.get('remainder') or ''
group_dict['slice'] = _parse_slice(group_dict.get('slice') or ':')
if group_dict.get('attributes'):
group_dict['attributes'] = group_dict['attributes'].split(ATTRIBUTES_SEPARATOR)
elif group_dict.get('attribute'):
group_dict['attributes'] = [group_dict.get('attribute')]

return group_dict


def _parse_slice(value):
"""
Parses a `slice()` from string, like `start:stop:step`.
"""
if re.match(WRAPPED_SLICE_BASE, value):
value = value[1:-1]

if value:
parts = value.split(':')
if len(parts) == 1:
# slice(stop)
parts = [None, parts[0]]
parts = [parts[0], str(int(parts[0]) + 1)]
# else: slice(start, stop[, step])
else:
# slice()
Expand Down
4 changes: 3 additions & 1 deletion docarray/document/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
uri='',
mime_type='',
tags=dict,
_metadata=dict,
offset=0.0,
location=list,
modality='',
Expand Down Expand Up @@ -51,6 +52,7 @@ class DocumentData:
weight: Optional[float] = None
uri: Optional[str] = None
tags: Optional[Dict[str, 'StructValueType']] = None
_metadata: Optional[Dict[str, 'StructValueType']] = None
offset: Optional[float] = None
location: Optional[List[float]] = None
embedding: Optional['ArrayType'] = field(default=None, hash=False, compare=False)
Expand All @@ -65,7 +67,7 @@ def _non_empty_fields(self) -> Tuple[str]:
r = []
for f in fields(self):
f_name = f.name
if not f_name.startswith('_'):
if not f_name.startswith('_') or f_name == '_metadata':
v = getattr(self, f_name)
if v is not None:
if f_name not in default_values:
Expand Down
2 changes: 2 additions & 0 deletions docarray/document/mixins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .featurehash import FeatureHashMixin
from .image import ImageDataMixin
from .mesh import MeshDataMixin
from .multimodal import MultiModalMixin
from .plot import PlotMixin
from .porting import PortingMixin
from .property import PropertyMixin
Expand Down Expand Up @@ -37,6 +38,7 @@ class AllMixins(
PortingMixin,
FeatureHashMixin,
GetAttributesMixin,
MultiModalMixin,
):
"""All plugins that can be used in :class:`Document`. """

Expand Down
9 changes: 9 additions & 0 deletions docarray/document/mixins/_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,15 @@ def tags(self) -> Optional[Dict[str, 'StructValueType']]:
def tags(self, value: Dict[str, 'StructValueType']):
self._data.tags = value

@property
def _metadata(self) -> Optional[Dict[str, 'StructValueType']]:
self._data._set_default_value_if_none('_metadata')
return self._data._metadata

@_metadata.setter
def _metadata(self, value: Dict[str, 'StructValueType']):
self._data._metadata = value

@property
def offset(self) -> Optional[float]:
self._data._set_default_value_if_none('offset')
Expand Down
140 changes: 140 additions & 0 deletions docarray/document/mixins/multimodal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import base64

import typing
from enum import Enum

from docarray.types.multimodal import Image, Text, Field, is_dataclass
from docarray.types.multimodal import TYPES_REGISTRY

if typing.TYPE_CHECKING:
from docarray import Document, DocumentArray


class AttributeType(str, Enum):
DOCUMENT = 'document'
PRIMITIVE = 'primitive'
ITERABLE_PRIMITIVE = 'iterable_primitive'
ITERABLE_DOCUMENT = 'iterable_document'
NESTED = 'nested'
ITERABLE_NESTED = 'iterable_nested'


class MultiModalMixin:
@classmethod
def from_dataclass(cls, obj):
if not is_dataclass(obj):
raise ValueError(f'Object {obj.__name__} is not a dataclass instance')

from docarray import Document

root = Document()
tags = {}
multi_modal_schema = {}
for key, field in obj.__dataclass_fields__.items():
attribute = getattr(obj, key)
if field.type in [str, int, float, bool] and not isinstance(field, Field):
tags[key] = attribute
multi_modal_schema[key] = {
'attribute_type': AttributeType.PRIMITIVE,
'type': field.type.__name__,
}

elif field.type == bytes and not isinstance(field, Field):
tags[key] = base64.b64encode(attribute).decode()
multi_modal_schema[key] = {
'attribute_type': AttributeType.PRIMITIVE,
'type': field.type.__name__,
}
elif isinstance(field.type, typing._GenericAlias):
if field.type._name in ['List', 'Iterable']:
sub_type = field.type.__args__[0]
if sub_type in [str, int, float, bool]:
tags[key] = attribute
multi_modal_schema[key] = {
'attribute_type': AttributeType.ITERABLE_PRIMITIVE,
'type': f'{field.type._name}[{sub_type.__name__}]',
}

else:
chunk = Document()
for element in attribute:
doc, attribute_type = cls._from_obj(
element, sub_type, field
)
if attribute_type == AttributeType.DOCUMENT:
attribute_type = AttributeType.ITERABLE_DOCUMENT
elif attribute_type == AttributeType.NESTED:
attribute_type = AttributeType.ITERABLE_NESTED
else:
raise ValueError(
f'Unsupported type annotation inside Iterable: {sub_type}'
)
chunk.chunks.append(doc)
multi_modal_schema[key] = {
'attribute_type': attribute_type,
'type': f'{field.type._name}[{sub_type.__name__}]',
'position': len(root.chunks),
}
root.chunks.append(chunk)
else:
raise ValueError(f'Unsupported type annotation {field.type._name}')
else:
doc, attribute_type = cls._from_obj(attribute, field.type, field)
multi_modal_schema[key] = {
'attribute_type': attribute_type,
'type': field.type.__name__,
'position': len(root.chunks),
}
root.chunks.append(doc)

# TODO: may have to modify this?
root.tags = tags
root._metadata['multi_modal_schema'] = multi_modal_schema

return root

def get_multi_modal_attribute(self, attribute: str) -> 'DocumentArray':
from docarray import DocumentArray

if 'multi_modal_schema' not in self._metadata:
raise ValueError(
'the Document does not correspond to a Multi Modal Document'
)

if attribute not in self._metadata['multi_modal_schema']:
raise ValueError(
f'the Document schema does not contain attribute {attribute}'
)

attribute_type = self._metadata['multi_modal_schema'][attribute][
'attribute_type'
]
position = self._metadata['multi_modal_schema'][attribute].get('position')

if attribute_type in [AttributeType.DOCUMENT, AttributeType.NESTED]:
return DocumentArray([self.chunks[position]])
elif attribute_type in [
AttributeType.ITERABLE_DOCUMENT,
AttributeType.ITERABLE_NESTED,
]:
return self.chunks[position].chunks
else:
raise ValueError(
f'Invalid attribute {attribute}: must a Document attribute or nested dataclass'
)

@classmethod
def _from_obj(cls, obj, obj_type, field) -> typing.Tuple['Document', AttributeType]:
from docarray import Document

attribute_type = AttributeType.DOCUMENT

if is_dataclass(obj_type):
doc = cls.from_dataclass(obj)
attribute_type = AttributeType.NESTED
elif isinstance(field, Field):
doc = Document()
field.serializer(obj, field.name, doc)
else:
raise ValueError(f'Unsupported type annotation')
return doc, attribute_type
1 change: 1 addition & 0 deletions docarray/document/pydantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class PydanticDocument(BaseModel):
weight: Optional[float]
uri: Optional[str]
tags: Optional[Dict[str, '_StructValueType']]
_metadata: Optional[Dict[str, '_StructValueType']]
offset: Optional[float]
location: Optional[List[float]]
embedding: Optional[Any]
Expand Down
Loading

0 comments on commit 5a8544a

Please sign in to comment.