chore(wren-ai-service): refactor #936

Merged: 6 commits, Nov 21, 2024
1,948 changes: 974 additions & 974 deletions wren-ai-service/demo/sample_dataset/ecommerce_duckdb_mdl.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion wren-ai-service/src/core/pipeline.py
@@ -4,7 +4,7 @@
 from dataclasses import dataclass
 from typing import Any, Dict

-from hamilton.experimental.h_async import AsyncDriver
+from hamilton.async_driver import AsyncDriver
 from haystack import Pipeline

 from src.core.engine import Engine
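
Note: this import-path change tracks Hamilton's promotion of the async driver out of its experimental namespace. For codebases that must tolerate both older and newer Hamilton releases, a small compatibility shim is one option (a sketch, not part of this PR):

    # Sketch of a compatibility shim: prefer the stable import path, fall back
    # to the pre-promotion experimental module on older Hamilton releases.
    try:
        from hamilton.async_driver import AsyncDriver  # newer Hamilton
    except ImportError:
        from hamilton.experimental.h_async import AsyncDriver  # older Hamilton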
31 changes: 25 additions & 6 deletions wren-ai-service/src/pipelines/common.py
@@ -13,7 +13,8 @@
     add_quotes,
     clean_generation_result,
 )
-from src.web.v1.services.ask import AskConfigurations
+from src.core.pipeline import BasicPipeline
+from src.web.v1.services import Configuration

 logger = logging.getLogger("wren-ai-service")

@@ -444,16 +445,16 @@ async def _task(result: Dict[str, str]):
     """


-def construct_instructions(configurations: AskConfigurations | None):
+def construct_instructions(configuration: Configuration | None):
     instructions = ""
-    if configurations:
-        if configurations.fiscal_year:
-            instructions += f"- For calendar year related computation, it should be started from {configurations.fiscal_year.start} to {configurations.fiscal_year.end}"
+    if configuration:
+        if configuration.fiscal_year:
+            instructions += f"- For calendar year related computation, it should be started from {configuration.fiscal_year.start} to {configuration.fiscal_year.end}"

     return instructions


-def show_current_time(timezone: AskConfigurations.Timezone):
+def show_current_time(timezone: Configuration.Timezone):
     # Get the current time in the specified timezone
     tz = pytz.timezone(
         timezone.name
@@ -485,3 +486,21 @@ def build_table_ddl(
         + ",\n ".join(columns_ddl)
         + "\n);"
     )
+
+
+def dry_run_pipeline(pipeline_cls: BasicPipeline, pipeline_name: str, **kwargs):
+    from langfuse.decorators import langfuse_context
+
+    from src.config import settings
+    from src.core.pipeline import async_validate
+    from src.providers import generate_components
+    from src.utils import init_langfuse
+
+    pipe_components = generate_components(settings.components)
+    pipeline = pipeline_cls(**pipe_components[pipeline_name])
+    init_langfuse()
+
+    pipeline.visualize(**kwargs)
+    async_validate(lambda: pipeline.run(**kwargs))
+
+    langfuse_context.flush()
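
The new dry_run_pipeline helper centralizes what each pipeline module previously duplicated in its __main__ block: component wiring from settings, Langfuse initialization, visualization, and a validated dry run. Each migrated module below reduces to a call like this (mirroring the data_assistance diff that follows):

    # Pattern adopted by the migrated __main__ blocks below: the pipeline class
    # plus its component key in settings.components, then pipeline-specific kwargs.
    if __name__ == "__main__":
        from src.pipelines.common import dry_run_pipeline

        dry_run_pipeline(
            DataAssistance,        # pipeline class defined in the module
            "data_assistance",     # key into settings.components
            query="show me the dataset",
            db_schemas=[],
            language="English",
        )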
28 changes: 9 additions & 19 deletions wren-ai-service/src/pipelines/generation/data_assistance.py
@@ -5,7 +5,7 @@
 from typing import Any, Optional

 from hamilton import base
-from hamilton.experimental.h_async import AsyncDriver
+from hamilton.async_driver import AsyncDriver
 from haystack.components.builders.prompt_builder import PromptBuilder
 from langfuse.decorators import observe
 from pydantic import BaseModel
@@ -170,22 +170,12 @@ async def run(


 if __name__ == "__main__":
-    from langfuse.decorators import langfuse_context
-
-    from src.core.engine import EngineConfig
-    from src.core.pipeline import async_validate
-    from src.providers import init_providers
-    from src.utils import init_langfuse, load_env_vars
-
-    load_env_vars()
-    init_langfuse()
-
-    llm_provider, _, _, _ = init_providers(engine_config=EngineConfig())
-    pipeline = DataAssistance(
-        llm_provider=llm_provider,
+    from src.pipelines.common import dry_run_pipeline
+
+    dry_run_pipeline(
+        DataAssistance,
+        "data_assistance",
+        query="show me the dataset",
+        db_schemas=[],
+        language="English",
     )
-
-    pipeline.visualize("show me the dataset", [], "English")
-    async_validate(lambda: pipeline.run("show me the dataset", [], "English"))
-
-    langfuse_context.flush()
53 changes: 18 additions & 35 deletions wren-ai-service/src/pipelines/generation/followup_sql_generation.py
@@ -4,7 +4,7 @@
 from typing import Any, List

 from hamilton import base
-from hamilton.experimental.h_async import AsyncDriver
+from hamilton.async_driver import AsyncDriver
 from haystack.components.builders.prompt_builder import PromptBuilder
 from langfuse.decorators import observe
 from pydantic import BaseModel
@@ -20,7 +20,8 @@
     sql_generation_system_prompt,
 )
 from src.utils import async_timer, timer
-from src.web.v1.services.ask import AskConfigurations, AskHistory
+from src.web.v1.services import Configuration
+from src.web.v1.services.ask import AskHistory
Comment on lines +23 to +24 (Member):
    nitpick: we could consider moving these into the pipeline layer in the future; I think it would reduce the bidirectional dependency between the pipeline and service layers.
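
A hypothetical sketch of the reviewer's suggestion: hosting the shared request models in the pipeline layer so that services import from pipelines rather than the other way around. The module path is invented, and the field shapes are inferred from the AskHistory(...) call sites in this diff:

    # Hypothetical module src/pipelines/models.py (invented path): placing
    # shared request models here would let the service layer import from the
    # pipeline layer, removing the pipeline -> service import seen above.
    from pydantic import BaseModel


    class AskHistory(BaseModel):
        sql: str
        summary: str
        steps: list = []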


 logger = logging.getLogger("wren-ai-service")

@@ -115,16 +116,16 @@ def prompt(
     documents: List[str],
     history: AskHistory,
     alert: str,
-    configurations: AskConfigurations,
+    configuration: Configuration,
     prompt_builder: PromptBuilder,
 ) -> dict:
     return prompt_builder.run(
         query=query,
         documents=documents,
         history=history,
         alert=alert,
-        instructions=construct_instructions(configurations),
-        current_time=show_current_time(configurations.timezone),
+        instructions=construct_instructions(configuration),
+        current_time=show_current_time(configuration.timezone),
     )


@@ -199,7 +200,7 @@ def visualize(
         query: str,
         contexts: List[str],
         history: AskHistory,
-        configurations: AskConfigurations = AskConfigurations(),
+        configuration: Configuration = Configuration(),
         project_id: str | None = None,
     ) -> None:
         destination = "outputs/pipelines/generation"
@@ -214,7 +215,7 @@
                 "documents": contexts,
                 "history": history,
                 "project_id": project_id,
-                "configurations": configurations,
+                "configuration": configuration,
                 **self._components,
                 **self._configs,
             },
@@ -229,7 +230,7 @@ async def run(
         query: str,
         contexts: List[str],
         history: AskHistory,
-        configurations: AskConfigurations = AskConfigurations(),
+        configuration: Configuration = Configuration(),
         project_id: str | None = None,
     ):
         logger.info("Follow-Up SQL Generation pipeline is running...")
@@ -240,38 +241,20 @@
                 "documents": contexts,
                 "history": history,
                 "project_id": project_id,
-                "configurations": configurations,
+                "configuration": configuration,
                 **self._components,
                 **self._configs,
             },
         )


 if __name__ == "__main__":
-    from langfuse.decorators import langfuse_context
-
-    from src.core.engine import EngineConfig
-    from src.core.pipeline import async_validate
-    from src.providers import init_providers
-    from src.utils import init_langfuse, load_env_vars
-
-    load_env_vars()
-    init_langfuse()
-
-    llm_provider, _, _, engine = init_providers(engine_config=EngineConfig())
-    pipeline = FollowUpSQLGeneration(llm_provider=llm_provider, engine=engine)
-
-    pipeline.visualize(
-        "this is a test query",
-        [],
-        AskHistory(sql="SELECT * FROM table", summary="Summary", steps=[]),
+    from src.pipelines.common import dry_run_pipeline
+
+    dry_run_pipeline(
+        FollowUpSQLGeneration,
+        "followup_sql_generation",
+        query="show me the dataset",
+        contexts=[],
+        history=AskHistory(sql="SELECT * FROM table", summary="Summary", steps=[]),
     )
-    async_validate(
-        lambda: pipeline.run(
-            "this is a test query",
-            [],
-            AskHistory(sql="SELECT * FROM table", summary="Summary", steps=[]),
-        )
-    )
-
-    langfuse_context.flush()
27 changes: 6 additions & 21 deletions wren-ai-service/src/pipelines/generation/intent_classification.py
@@ -6,7 +6,7 @@

 import orjson
 from hamilton import base
-from hamilton.experimental.h_async import AsyncDriver
+from hamilton.async_driver import AsyncDriver
 from haystack import Document
 from haystack.components.builders.prompt_builder import PromptBuilder
 from langfuse.decorators import observe
@@ -273,25 +273,10 @@ async def run(self, query: str, id: Optional[str] = None):


 if __name__ == "__main__":
-    from langfuse.decorators import langfuse_context
+    from src.pipelines.common import dry_run_pipeline

-    from src.core.engine import EngineConfig
-    from src.core.pipeline import async_validate
-    from src.providers import init_providers
-    from src.utils import init_langfuse, load_env_vars
-
-    load_env_vars()
-    init_langfuse()
-
-    llm_provider, _, document_store_provider, _ = init_providers(
-        engine_config=EngineConfig()
-    )
-    pipeline = IntentClassification(
-        document_store_provider=document_store_provider,
-        llm_provider=llm_provider,
+    dry_run_pipeline(
+        IntentClassification,
+        "intent_classification",
+        query="show me the dataset",
     )
-
-    pipeline.visualize("this is a query")
-    async_validate(lambda: pipeline.run("this is a query"))
-
-    langfuse_context.flush()
52 changes: 18 additions & 34 deletions wren-ai-service/src/pipelines/generation/question_recommendation.py
@@ -1,4 +1,3 @@
-import json
 import logging
 import sys
 from datetime import datetime
@@ -7,12 +6,12 @@

 import orjson
 from hamilton import base
-from hamilton.experimental.h_async import AsyncDriver
+from hamilton.async_driver import AsyncDriver
 from haystack.components.builders.prompt_builder import PromptBuilder
 from langfuse.decorators import observe
 from pydantic import BaseModel

-from src.core.pipeline import BasicPipeline, async_validate
+from src.core.pipeline import BasicPipeline
 from src.core.provider import LLMProvider

 logger = logging.getLogger("wren-ai-service")
@@ -30,7 +29,7 @@ def prompt(
     prompt_builder: PromptBuilder,
 ) -> dict:
     return prompt_builder.run(
-        models=mdl["models"],
+        models=mdl.get("models", []),
         previous_questions=previous_questions,
         language=language,
         current_date=current_date,
@@ -203,6 +202,7 @@ def visualize(
                 "current_date": current_date,
                 "max_questions": max_questions,
                 "max_categories": max_categories,
+                **self._components,
             },
             show_legend=True,
             orient="LR",
@@ -215,7 +215,7 @@ async def run(
         previous_questions: list[str] = [],
         categories: list[str] = [],
         language: str = "English",
-        current_date: str = datetime.now(),
+        current_date: str = datetime.now().strftime("%Y-%m-%d %A %H:%M:%S"),
         max_questions: int = 5,
         max_categories: int = 3,
         **_,
@@ -237,32 +237,16 @@


 if __name__ == "__main__":
-    from langfuse.decorators import langfuse_context
-
-    from src.core.engine import EngineConfig
-    from src.core.pipeline import async_validate
-    from src.providers import init_providers
-    from src.utils import init_langfuse, load_env_vars
-
-    load_env_vars()
-    init_langfuse()
-
-    llm_provider, _, _, _ = init_providers(EngineConfig())
-    pipeline = QuestionRecommendation(llm_provider=llm_provider)
-
-    with open("sample/ecommerce_duckdb_mdl.json", "r") as file:
-        mdl = json.load(file)
-
-    input = {
-        "mdl": mdl,
-        "previous_questions": [],
-        "categories": ["Customer Insights", "Product Performance"],
-        "language": "English",
-        "max_questions": 5,
-        "max_categories": 2,
-    }
-
-    # pipeline.visualize(**input)
-    async_validate(lambda: pipeline.run(**input))
-
-    langfuse_context.flush()
+    from src.pipelines.common import dry_run_pipeline
+
+    dry_run_pipeline(
+        QuestionRecommendation,
+        "question_recommendation",
+        mdl={},
+        previous_questions=[],
+        categories=[],
+        language="English",
+        current_date=datetime.now().strftime("%Y-%m-%d %A %H:%M:%S"),
+        max_questions=5,
+        max_categories=3,
+    )
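
Side note on the current_date default: the old default datetime.now() returned a datetime object for a str-annotated parameter, and Python evaluates default arguments once at import time, so a long-lived worker would also keep serving a stale timestamp. The new strftime default fixes the type mismatch but is still evaluated once; a common alternative, sketched here under the assumption that per-call freshness matters, is a None sentinel:

    from datetime import datetime
    from typing import Optional


    # Sketch only, not part of this PR: a None sentinel yields a fresh timestamp
    # per call, since Python evaluates default arguments once at import time.
    def run(current_date: Optional[str] = None, **_):
        if current_date is None:
            current_date = datetime.now().strftime("%Y-%m-%d %A %H:%M:%S")
        return current_date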