Skip to content

Commit

Permalink
Update Elastic bulk index and schemas to pydantic v2
Browse files Browse the repository at this point in the history
  • Loading branch information
prrao87 committed Jul 17, 2023
1 parent 4211ad8 commit 4dd10f4
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 58 deletions.
31 changes: 13 additions & 18 deletions dbs/elasticsearch/schemas/retriever.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,9 @@
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict


class FullTextSearch(BaseModel):
id: int
country: str
title: str
description: str | None
points: int
price: float | str | None
variety: str | None
winery: str | None

class Config:
schema_extra = {
model_config = ConfigDict(
json_schema_extra={
"example": {
"id": 3845,
"country": "Italy",
Expand All @@ -24,6 +15,16 @@ class Config:
"winery": "Castellinuzza e Piuca",
}
}
)

id: int
country: str
title: str
description: str | None
points: int
price: float | str | None
variety: str | None
winery: str | None


class TopWinesByCountry(BaseModel):
Expand All @@ -36,9 +37,6 @@ class TopWinesByCountry(BaseModel):
variety: str | None
winery: str | None

class Config:
validate_assignment = True


class TopWinesByProvince(BaseModel):
id: int
Expand All @@ -51,9 +49,6 @@ class TopWinesByProvince(BaseModel):
variety: str | None
winery: str | None

class Config:
validate_assignment = True


class MostWinesByVariety(BaseModel):
country: str
Expand Down
91 changes: 55 additions & 36 deletions dbs/elasticsearch/schemas/wine.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,13 @@
from pydantic import BaseModel, root_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator


class Wine(BaseModel):
id: int
points: int
title: str
description: str | None
price: float | None
variety: str | None
winery: str | None
vineyard: str | None
country: str | None
province: str | None
region_1: str | None
region_2: str | None
taster_name: str | None
taster_twitter_handle: str | None

class Config:
allow_population_by_field_name = True
validate_assignment = True
schema_extra = {
model_config = ConfigDict(
populate_by_name=True,
validate_assignment=True,
extra="allow",
str_strip_whitespace=True,
json_schema_extra={
"example": {
"id": 45100,
"points": 85,
Expand All @@ -37,26 +24,58 @@ class Config:
"taster_name": "Michael Schachner",
"taster_twitter_handle": "@wineschach",
}
}
},
)

# Copies the incoming `id` value into an `_id` key of the values dict and
# returns the dict.
# NOTE(review): this uses `root_validator`, while a `_create_id` validator
# built on `model_validator` also appears further down; presumably this is
# the older (pre-migration) form of the same hook -- confirm before relying on it.
@root_validator
def _create_id_field(cls, values):
"Elastic needs an _id field to create unique documents, so we just use the existing id field"
# Raises KeyError if `values` has no "id" key.
values["_id"] = values["id"]
return values

# Moves the raw `designation` value to the `vineyard` key.
# NOTE(review): the field list on this page declares `vineyard` with
# `alias="designation"`, which presumably supersedes this validator -- confirm.
@root_validator(pre=True)
def _get_vineyard(cls, values):
"Rename designation to vineyard"
# `pop` removes "designation" from the input; None is used when it is absent.
vineyard = values.pop("designation", None)
# Only a truthy designation is copied over; it is stripped of surrounding whitespace.
if vineyard:
values["vineyard"] = vineyard.strip()
return values
id: int
points: int
title: str
description: str | None
price: float | None
variety: str | None
winery: str | None
vineyard: str | None = Field(..., alias="designation")
country: str | None
province: str | None
region_1: str | None
region_2: str | None
taster_name: str | None
taster_twitter_handle: str | None

@model_validator(mode="before")
@classmethod
def _fill_country_unknowns(cls, values):
    """Replace a missing country with "Unknown" so the field stays queryable.

    Runs on the raw input before field validation. A country counts as
    missing when it is absent, ``None``, an empty string, or the literal
    string ``"null"``.

    Args:
        values: Raw input passed to the model, normally a dict.

    Returns:
        The same input, with ``"country"`` set to ``"Unknown"`` when it
        was missing.
    """
    # "before" validators receive whatever the caller passed in; only a
    # dict can be patched here, anything else goes to pydantic untouched.
    if not isinstance(values, dict):
        return values
    country = values.get("country")
    # `not country` also covers "", which the earlier `if not country`
    # check treated as missing.
    if not country or country == "null":
        values["country"] = "Unknown"
    return values

@model_validator(mode="before")
@classmethod
def _create_id(cls, values):
    """Mirror ``id`` into ``_id``, which Elastic uses as the primary key.

    Runs on the raw input before field validation.

    Args:
        values: Raw input passed to the model, normally a dict.

    Returns:
        The same input, with ``"_id"`` set to the value of ``"id"`` when
        an ``"id"`` key is present.
    """
    # When "id" is absent, leave the input alone so that the required-field
    # check on `id` reports the problem instead of a bare KeyError here.
    if isinstance(values, dict) and "id" in values:
        values["_id"] = values["id"]
    return values



if __name__ == "__main__":
    from pprint import pprint

    # Manual smoke test: build one record that exercises the model's
    # coercion, alias, whitespace-stripping and country-fill behaviour,
    # then print the validated result.
    sample_record = {
        "id": 45100,
        "points": 85,
        "title": "Balduzzi 2012 Reserva Merlot (Maule Valley)",
        "description": "Ripe in color and aromas, this chunky wine delivers heavy baked-berry and raisin aromas in front of a jammy, extracted palate. Raisin and cooked berry flavors finish plump, with earthy notes.",
        "price": 10,  # int on purpose: should come back as a float
        "variety": "Merlot",
        "winery": "Balduzzi",
        "designation": "Reserva",  # should surface under `vineyard`
        "country": "null",  # should be replaced by the unknown-country filler
        "province": " Maule Valley ",  # surrounding spaces should be stripped
        "region_1": "null",
        "region_2": "null",
        "taster_name": "Michael Schachner",
        "taster_twitter_handle": "@wineschach",
    }

    validated = Wine(**sample_record)
    # Keep the model's own field order in the printout.
    pprint(validated.model_dump(), sort_dicts=False)
6 changes: 2 additions & 4 deletions dbs/elasticsearch/scripts/bulk_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import srsly
from dotenv import load_dotenv
from elasticsearch import AsyncElasticsearch, helpers
from pydantic.main import ModelMetaclass

sys.path.insert(1, os.path.realpath(Path(__file__).resolve().parents[1]))
from api.config import Settings
Expand Down Expand Up @@ -59,15 +58,14 @@ def get_json_data(data_dir: Path, filename: str) -> list[JsonBlob]:

def validate(
    data: tuple[JsonBlob],
    exclude_none: bool = False,
) -> list[JsonBlob]:
    """Validate raw records against the ``Wine`` model and return plain dicts.

    Args:
        data: Raw records; each item is unpacked as keyword arguments
            to ``Wine``.
        exclude_none: When true, fields whose value is ``None`` are left
            out of each returned dict.

    Returns:
        One dict per input record, produced by ``Wine.model_dump``.
    """
    # `model_dump` is the pydantic v2 replacement for the deprecated `.dict()`.
    return [Wine(**item).model_dump(exclude_none=exclude_none) for item in data]


def process_chunks(data: list[JsonBlob]) -> list[JsonBlob]:
    """Validate one chunk of raw records, dropping ``None``-valued fields.

    Args:
        data: Raw records for a single chunk.

    Returns:
        The list of validated record dicts returned by ``validate``.
    """
    # The return annotation now matches what is actually returned: a list
    # of dicts, not a (list, str) tuple.
    return validate(data, exclude_none=True)


Expand Down

0 comments on commit 4dd10f4

Please sign in to comment.