Skip to content

Commit

Permalink
DH-4357 Creates nl_question repository
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjc712 committed Aug 3, 2023
1 parent e1dfdb3 commit d90e49a
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 5 deletions.
8 changes: 5 additions & 3 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dataherald.db_scanner.repository.base import DBScannerRepository
from dataherald.eval import Evaluation, Evaluator
from dataherald.repositories.base import NLQueryResponseRepository
from dataherald.repositories.nl_question import NLQuestionRepository
from dataherald.smart_cache import SmartCache
from dataherald.sql_database.base import SQLDatabase
from dataherald.sql_database.models.types import DatabaseConnection, SSHSettings
Expand Down Expand Up @@ -79,9 +80,10 @@ def answer_question(self, question: str, db_alias: str) -> NLQueryResponse:
context_store = self.system.instance(ContextStore)

user_question = NLQuery(question=question, db_alias=db_alias)
user_question.id = self.storage.insert_one(
"nl_question", user_question.dict(exclude={"id"})
)

nl_question_repository = NLQuestionRepository(self.storage)
user_question = nl_question_repository.insert(user_question)

db_connection = self.storage.find_one(
"database_connection", {"alias": db_alias}
)
Expand Down
7 changes: 5 additions & 2 deletions dataherald/context_store/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dataherald.config import System
from dataherald.context_store import ContextStore
from dataherald.repositories.base import NLQueryResponseRepository
from dataherald.repositories.nl_question import NLQuestionRepository
from dataherald.types import NLQuery, NLQueryResponse

logger = logging.getLogger(__name__)
Expand All @@ -30,12 +31,13 @@ def retrieve_context_for_question(
)

samples = []
nl_question_repository = NLQuestionRepository(self.db)
nl_query_response_repository = NLQueryResponseRepository(self.db)
for question in closest_questions:
golden_query = nl_query_response_repository.find_one(
{"nl_question_id": ObjectId(question["id"])}
)
associated_nl_question = self.db.find_by_id("nl_question", question["id"])
associated_nl_question = nl_question_repository.find_by_id(question["id"])
if golden_query is not None and associated_nl_question is not None:
samples.append(
{
Expand All @@ -57,7 +59,8 @@ def add_golden_records(self, golden_records: List) -> bool:
tables = Parser(record["sql"]).tables
question = record["nl_question"]
user_question = NLQuery(question=question, db_alias=record["db"])
user_question.id = self.db.insert_one("nl_question", user_question.dict())
nl_query_repository = NLQuestionRepository(self.db)
user_question = nl_query_repository.insert(user_question)
self.vector_store.add_record(
documents=question,
collection=self.golden_record_collection,
Expand Down
8 changes: 8 additions & 0 deletions dataherald/repositories/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from bson.objectid import ObjectId

from dataherald.types import NLQueryResponse

# Collection name handed to the storage backend by this module's repository
# when it looks up NLQueryResponse rows.
DB_COLLECTION = "nl_query_response"
Expand All @@ -18,3 +20,9 @@ def find_one(self, query: dict) -> NLQueryResponse | None:
if not row:
return None
return NLQueryResponse(**row)

def find_by_id(self, id: str) -> NLQueryResponse | None:
    """Look up a single response by its identifier.

    ``id`` is wrapped in ``ObjectId`` and matched against the ``_id`` field
    of ``DB_COLLECTION``. The ``ObjectId`` conversion is not guarded, so any
    error it raises for an unusable ``id`` propagates to the caller.

    Returns an ``NLQueryResponse`` built from the stored row, or ``None``
    when storage yields nothing (or an empty row).
    """
    lookup = {"_id": ObjectId(id)}
    record = self.storage.find_one(DB_COLLECTION, lookup)
    return NLQueryResponse(**record) if record else None
28 changes: 28 additions & 0 deletions dataherald/repositories/nl_question.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from bson.objectid import ObjectId

from dataherald.types import NLQuery

# Collection name handed to the storage backend for every NLQuery
# insert and lookup performed by NLQuestionRepository.
DB_COLLECTION = "nl_question"


class NLQuestionRepository:
    """Read/write access to ``NLQuery`` records stored in ``DB_COLLECTION``.

    All persistence is delegated to the ``storage`` object supplied at
    construction; this class only calls its ``insert_one`` and ``find_one``
    methods.
    """

    def __init__(self, storage):
        """Remember the storage backend used by every repository call."""
        self.storage = storage

    def insert(self, nl_query: NLQuery) -> NLQuery:
        """Persist ``nl_query`` and stamp it with the identifier storage returns.

        The ``id`` field is left out of the stored payload. The object passed
        in is mutated (its ``id`` is overwritten) and that same object is
        returned.
        """
        payload = nl_query.dict(exclude={"id"})
        nl_query.id = self.storage.insert_one(DB_COLLECTION, payload)
        return nl_query

    def find_one(self, query: dict) -> NLQuery | None:
        """Return the first record matching ``query``, or ``None`` if storage
        yields nothing (or an empty row)."""
        # NOTE(review): the row is forwarded to NLQuery unchanged; confirm the
        # model ends up with its id populated from what storage returns.
        record = self.storage.find_one(DB_COLLECTION, query)
        return NLQuery(**record) if record else None

    def find_by_id(self, id: str) -> NLQuery | None:
        """Return the record whose ``_id`` equals ``ObjectId(id)``, or ``None``.

        The ``ObjectId`` conversion is not guarded, so any error it raises
        for an unusable ``id`` propagates to the caller.
        """
        lookup = {"_id": ObjectId(id)}
        record = self.storage.find_one(DB_COLLECTION, lookup)
        return NLQuery(**record) if record else None

0 comments on commit d90e49a

Please sign in to comment.