Skip to content

Commit

Permalink
Merge pull request #1353 from syucream/fix/typed-join-attrs
Browse files Browse the repository at this point in the history
Make join attr types more type-safe
  • Loading branch information
hinashi authored Jan 22, 2025
2 parents 27fa5ff + 56c05d6 commit 8571155
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 22 deletions.
17 changes: 16 additions & 1 deletion entry/api_v2/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from django.db.models import Prefetch
from drf_spectacular.utils import extend_schema_field, extend_schema_serializer
from pydantic import BaseModel
from pydantic import BaseModel, RootModel
from rest_framework import serializers
from rest_framework.exceptions import PermissionDenied, ValidationError
from typing_extensions import TypedDict
Expand Down Expand Up @@ -140,6 +140,21 @@ class EntryAttributeType(TypedDict):
schema: EntityAttributeType


class AdvancedSearchJoinAttrAttrInfo(BaseModel):
name: str
keyword: str | None = None
filter_key: FilterKey | None = None


class AdvancedSearchJoinAttrInfo(BaseModel):
name: str
offset: int = 0
attrinfo: list[AdvancedSearchJoinAttrAttrInfo] = []


AdvancedSearchJoinAttrInfoList = RootModel[list[AdvancedSearchJoinAttrInfo]]


class EntityAttributeTypeSerializer(serializers.Serializer):
id = serializers.IntegerField()
name = serializers.CharField()
Expand Down
41 changes: 20 additions & 21 deletions entry/api_v2/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from entity.models import Entity, EntityAttr
from entry.api_v2.pagination import EntryReferralPagination
from entry.api_v2.serializers import (
AdvancedSearchJoinAttrInfo,
AdvancedSearchJoinAttrInfoList,
AdvancedSearchResultExportSerializer,
AdvancedSearchResultSerializer,
AdvancedSearchSerializer,
Expand Down Expand Up @@ -265,10 +267,12 @@ def post(self, request: Request) -> Response:
is_all_entities = serializer.validated_data["is_all_entities"]
entry_limit = serializer.validated_data["entry_limit"]
entry_offset = serializer.validated_data["entry_offset"]
join_attrs = serializer.validated_data.get("join_attrs", [])
join_attrs = AdvancedSearchJoinAttrInfoList.model_validate(
serializer.validated_data.get("join_attrs", [])
).root

def _get_joined_resp(
prev_results: list[AdvancedSearchResultRecord], join_attr: dict
prev_results: list[AdvancedSearchResultRecord], join_attr: AdvancedSearchJoinAttrInfo
) -> tuple[bool, dict]:
"""
This is a helper method for join_attrs that will get specified attr values
Expand All @@ -285,7 +289,7 @@ def _get_joined_resp(
Prefetch(
"attrs",
queryset=EntityAttr.objects.filter(
name=join_attr["name"], is_active=True
name=join_attr.name, is_active=True
).prefetch_related(
Prefetch(
"referral", queryset=Entity.objects.filter(is_active=True).only("id")
Expand All @@ -300,7 +304,7 @@ def _get_joined_resp(
if entity is None:
continue

attr = next((a for a in entity.attrs.all() if a.name == join_attr["name"]), None)
attr = next((a for a in entity.attrs.all() if a.name == join_attr.name), None)
if attr is None:
continue

Expand All @@ -309,7 +313,7 @@ def _get_joined_resp(
hint_entity_ids.extend([x.id for x in attr.referral.all()])

# set Item name
attrinfo = result.attrs[join_attr["name"]]
attrinfo = result.attrs[join_attr.name]

if attr.type == AttrType.OBJECT and attrinfo["value"]["name"] not in item_names:
item_names.append(attrinfo["value"]["name"])
Expand All @@ -332,24 +336,19 @@ def _get_joined_resp(

# set parameters to filter joining search results
hint_attrs: list[AttrHint] = []
for info in join_attr.get("attrinfo", []):
for info in join_attr.attrinfo:
hint_attrs.append(
AttrHint(
name=info["name"],
keyword=info.get("keyword"),
filter_key=info.get("filter_key"),
name=info.name,
keyword=info.keyword,
filter_key=info.filter_key,
)
)

# search Items from elasticsearch to join
return (
# This represents whether user want to narrow down results by keyword of joined attr
any(
[
x.get("keyword") or x.get("filter_key", 0) > 0
for x in join_attr.get("attrinfo", [])
]
),
any([x.keyword or (x.filter_key or 0) > 0 for x in join_attr.attrinfo]),
AdvancedSearchService.search_entries(
request.user,
hint_entity_ids=list(set(hint_entity_ids)), # this removes depulicated IDs
Expand All @@ -359,7 +358,7 @@ def _get_joined_resp(
hint_referral=None,
is_output_all=is_output_all,
hint_referral_entity_id=None,
offset=join_attr.get("offset", 0),
offset=join_attr.offset,
).dict(),
)

Expand Down Expand Up @@ -447,20 +446,20 @@ def _get_ref_id_from_es_result(attrinfo):
(will_filter_by_joined_attr, joined_resp) = _get_joined_resp(resp.ret_values, join_attr)
# This is needed to set result as blank value
blank_joining_info = {
"%s.%s" % (join_attr["name"], k["name"]): {
"%s.%s" % (join_attr.name, k.name): {
"is_readable": True,
"type": AttrType.STRING,
"value": "",
}
for k in join_attr["attrinfo"]
for k in join_attr.attrinfo
}

# convert search result to dict to be able to handle it without loop
joined_resp_info = {
x["entry"]["id"]: {
"%s.%s" % (join_attr["name"], k): v
"%s.%s" % (join_attr.name, k): v
for k, v in x["attrs"].items()
if any(_x["name"] == k for _x in join_attr["attrinfo"])
if any(_x.name == k for _x in join_attr.attrinfo)
}
for x in joined_resp["ret_values"]
}
Expand All @@ -470,7 +469,7 @@ def _get_ref_id_from_es_result(attrinfo):
joined_ret_values = []
for resp_result in resp.ret_values:
# joining search result to original one
ref_info = resp_result.attrs.get(join_attr["name"])
ref_info = resp_result.attrs.get(join_attr.name)

# This get referral Item-ID from joined search result
ref_list = _get_ref_id_from_es_result(ref_info)
Expand Down

0 comments on commit 8571155

Please sign in to comment.