Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fb/DH-3408 this commit adds a new endpoint to run an SQL query agains… #48

Merged
merged 2 commits into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions dataherald/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any, List

from dataherald.api.types import Query
from dataherald.config import Component
from dataherald.eval import Evaluation
from dataherald.sql_database.models.types import SSHSettings
Expand Down Expand Up @@ -34,3 +35,7 @@ def connect_database(
@abstractmethod
def add_golden_records(self, golden_records: List) -> bool:
pass

@abstractmethod
def execute_query(self, query: Query) -> tuple[str, dict]:
pass
15 changes: 15 additions & 0 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
from overrides import override

from dataherald.api import API
from dataherald.api.types import Query
from dataherald.config import DBConnectionConfigSettings, System
from dataherald.context_store import ContextStore
from dataherald.db import DB
from dataherald.eval import Evaluation, Evaluator
from dataherald.smart_cache import SmartCache
from dataherald.sql_database.base import SQLDatabase
from dataherald.sql_database.models.types import DatabaseConnection, SSHSettings
from dataherald.sql_generator import SQLGenerator
from dataherald.types import DataDefinitionType, NLQuery, NLQueryResponse
Expand Down Expand Up @@ -118,3 +120,16 @@ def add_golden_records(self, golden_records: List) -> bool:
"""Takes in a list of NL <> SQL pairs and stores them to be used in prompts to the LLM"""
context_store = self.system.instance(ContextStore)
return context_store.add_golden_records(golden_records)

@override
def execute_query(self, query: Query) -> tuple[str, dict]:
"""Executes a SQL query against the database and returns the results"""
db_connection = self.storage.find_one(
"database_connection", {"alias": query.db_alias}
)
if not db_connection:
raise HTTPException(status_code=404, detail="Database connection not found")
database_connection = DatabaseConnection(**db_connection)
database = SQLDatabase.get_sql_engine(database_connection)
print(type(database.run_sql(query.sql_statement)))
return database.run_sql(query.sql_statement)
6 changes: 6 additions & 0 deletions dataherald/api/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from pydantic import BaseModel


class Query(BaseModel):
sql_statement: str
db_alias: str
7 changes: 7 additions & 0 deletions dataherald/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fastapi.routing import APIRoute

import dataherald
from dataherald.api.types import Query
from dataherald.config import Settings
from dataherald.eval import Evaluation
from dataherald.sql_database.models.types import DatabaseConnection, SSHSettings
Expand Down Expand Up @@ -54,6 +55,8 @@ def __init__(self, settings: Settings):
"/api/v1/data-definition", self.add_data_definition, methods=["POST"]
)

self.router.add_api_route("/api/v1/query", self.execute_query, methods=["POST"])

self._app.include_router(self.router)
use_route_names_as_operation_ids(self._app)

Expand Down Expand Up @@ -94,3 +97,7 @@ def add_golden_records(self, golden_records: List) -> bool:
def add_data_definition(self, uri: str, type: DataDefinitionType) -> bool:
"""Takes in an English question and answers it based on content from the registered databases"""
return self._api.add_data_definition(type, uri)

def execute_query(self, query: Query) -> tuple[str, dict]:
"""Executes a query on the given db_alias"""
return self._api.execute_query(query)