Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(wren-ai-service): column-based batch to generate semantics #923

Merged
merged 6 commits into the base branch
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions wren-ai-service/src/pipelines/generation/semantics_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,29 @@
## Start of Pipeline
@observe(capture_input=False)
def picked_models(mdl: dict, selected_models: list[str]) -> list[dict]:
def remove_relation_columns(columns: list[dict]) -> list[dict]:
# remove columns that have a relationship property
return [column for column in columns if "relationship" not in column]
def relation_filter(column: dict) -> bool:
    # Keep only real data columns: relationship columns carry a
    # "relationship" key and should not be sent for description.
    return "relationship" not in column

def column_formatter(columns: list[dict]) -> list[dict]:
    """Project each non-relationship column down to the fields the
    semantics pipeline needs: name, type, and any existing description.

    Columns that have no "properties" dict (common in raw MDL input)
    are treated as having an empty description instead of raising
    KeyError.
    """
    return [
        {
            "name": column["name"],
            "type": column["type"],
            "properties": {
                # `or {}` also guards against an explicit null value.
                "description": (column.get("properties") or {}).get(
                    "description", ""
                ),
            },
        }
        for column in columns
        if relation_filter(column)
    ]

def extract(model: dict) -> dict:
    """Reduce a model to the minimal shape the generation prompt needs:
    its name, its formatted (non-relationship) columns, and its
    existing description.

    A model without a "properties" dict yields an empty description
    instead of raising KeyError.
    """
    # The rendered original carried duplicate "columns"/"properties"
    # keys left over from the old implementation; only the
    # column_formatter-based values are kept (the later keys won
    # anyway under Python's duplicate-key semantics).
    return {
        "name": model["name"],
        "columns": column_formatter(model["columns"]),
        "properties": {
            "description": (model.get("properties") or {}).get(
                "description", ""
            ),
        },
    }

return [
Expand Down Expand Up @@ -117,11 +131,11 @@ class SemanticResult(BaseModel):
```
[
{'name': 'model', 'columns': [
{'name': 'column_1', 'type': 'type', 'notNull': True, 'properties': {}
{'name': 'column_1', 'type': 'type', 'properties': {}
},
{'name': 'column_2', 'type': 'type', 'notNull': True, 'properties': {}
{'name': 'column_2', 'type': 'type', 'properties': {}
},
{'name': 'column_3', 'type': 'type', 'notNull': False, 'properties': {}
{'name': 'column_3', 'type': 'type', 'properties': {}
}
], 'properties': {}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
- Generates a new relationship recommendation
- Request body: PostRequest
{
"mdl": "{ ... }" # JSON string of the MDL (Model Definition Language)
"mdl": "{ ... }", # JSON string of the MDL (Model Definition Language)
"project_id": "project-id" # Optional project ID
}
- Response: PostResponse
{
Expand Down Expand Up @@ -62,6 +63,7 @@

class PostRequest(BaseModel):
    """Request body for the relationship-recommendation endpoint."""

    # JSON string of the MDL (Model Definition Language).
    mdl: str
    # Optional project identifier; defaults to None when omitted.
    project_id: Optional[str] = None


class PostResponse(BaseModel):
Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/src/web/v1/routers/semantics_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"selected_models": ["model1", "model2"], # List of model names to describe
"user_prompt": "Describe these models", # User's instruction for description
"mdl": "{ ... }", # JSON string of the MDL (Model Definition Language)
"project_id": "project-id", # Optional project ID
"configuration": { # Optional configuration settings
"language": "English" # Optional language, defaults to "English"
}
Expand Down Expand Up @@ -89,6 +90,7 @@ class PostRequest(BaseModel):
selected_models: list[str]
user_prompt: str
mdl: str
project_id: Optional[str] = None
configuration: Optional[Configuration] = Configuration()


Expand Down
28 changes: 23 additions & 5 deletions wren-ai-service/src/web/v1/services/semantics_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,27 +57,45 @@ def _handle_exception(
logger.error(error_message)

def _chunking(
self, mdl_dict: dict, request: Input, chunk_size: int = 1
self, mdl_dict: dict, request: Input, chunk_size: int = 50
) -> list[dict]:
template = {
"user_prompt": request.user_prompt,
"mdl": mdl_dict,
"language": request.configuration.language,
}

chunks = [
{
**model,
"columns": model["columns"][i : i + chunk_size],
}
for model in mdl_dict["models"]
if model["name"] in request.selected_models
for i in range(0, len(model["columns"]), chunk_size)
]

return [
{
**template,
"selected_models": request.selected_models[i : i + chunk_size],
"mdl": {"models": [chunk]},
"selected_models": [chunk["name"]],
}
for i in range(0, len(request.selected_models), chunk_size)
for chunk in chunks
]

async def _generate_task(self, request_id: str, chunk: dict):
    """Run the description pipeline on one chunk and merge its output
    into the accumulated response for ``request_id``.

    Chunks of the same model are merged by extending that model's
    ``columns`` list; models seen for the first time are inserted
    as-is. The stale ``current.response.update(...)`` call left over
    from the previous implementation is removed: it stored the same
    dict objects that the loop below then extended, duplicating every
    column of the first chunk.
    """
    resp = await self._pipelines["semantics_description"].run(**chunk)
    # assumes "normalize" is always present in the pipeline result —
    # TODO confirm; a missing key would raise AttributeError below.
    normalize = resp.get("normalize")

    current = self[request_id]
    current.response = current.response or {}

    for key, value in normalize.items():
        if key not in current.response:
            # First chunk for this model: take it wholesale.
            current.response[key] = value
            continue

        # Subsequent chunk of the same model: append its columns.
        current.response[key]["columns"].extend(value["columns"])

@observe(name="Generate Semantics Description")
@trace_metadata
Expand Down
22 changes: 12 additions & 10 deletions wren-ai-service/tests/pytest/services/test_semantics_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ async def test_generate_semantics_description(
id="test_id",
user_prompt="Describe the model",
selected_models=["model1"],
mdl='{"models": [{"name": "model1", "columns": []}]}',
mdl='{"models": [{"name": "model1", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}]}',
)

await service.generate(request)
Expand Down Expand Up @@ -80,7 +80,7 @@ async def test_generate_semantics_description_with_exception(
id="test_id",
user_prompt="Describe the model",
selected_models=["model1"],
mdl='{"models": [{"name": "model1", "columns": []}]}',
mdl='{"models": [{"name": "model1", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}]}',
)

service._pipelines["semantics_description"].run.side_effect = Exception(
Expand Down Expand Up @@ -136,7 +136,7 @@ async def test_batch_processing_with_multiple_models(
id="test_id",
user_prompt="Describe the models",
selected_models=["model1", "model2", "model3"],
mdl='{"models": [{"name": "model1"}, {"name": "model2"}, {"name": "model3"}]}',
mdl='{"models": [{"name": "model1", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}, {"name": "model2", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}, {"name": "model3", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}]}',
)

# Mock pipeline responses for each chunk
Expand Down Expand Up @@ -172,16 +172,18 @@ def test_batch_processing_with_custom_chunk_size(
id="test_id",
user_prompt="Describe the models",
selected_models=["model1", "model2", "model3", "model4"],
mdl='{"models": [{"name": "model1"}, {"name": "model2"}, {"name": "model3"}, {"name": "model4"}]}',
mdl='{"models": [{"name": "model1", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}, {"name": "model2", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}, {"name": "model3", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}, {"name": "model4", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}]}',
)

# Test chunking with custom chunk size
chunks = service._chunking(orjson.loads(request.mdl), request, chunk_size=2)

assert len(chunks) == 2 # Should create 2 chunks with size 2
assert [len(chunk["selected_models"]) for chunk in chunks] == [2, 2]
assert chunks[0]["selected_models"] == ["model1", "model2"]
assert chunks[1]["selected_models"] == ["model3", "model4"]
assert len(chunks) == 4
assert [len(chunk["selected_models"]) for chunk in chunks] == [1, 1, 1, 1]
assert chunks[0]["selected_models"] == ["model1"]
assert chunks[1]["selected_models"] == ["model2"]
assert chunks[2]["selected_models"] == ["model3"]
assert chunks[3]["selected_models"] == ["model4"]


@pytest.mark.asyncio
Expand All @@ -193,7 +195,7 @@ async def test_batch_processing_partial_failure(
id="test_id",
user_prompt="Describe the models",
selected_models=["model1", "model2"],
mdl='{"models": [{"name": "model1"}, {"name": "model2"}]}',
mdl='{"models": [{"name": "model1", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}, {"name": "model2", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}]}',
)

# Mock first chunk succeeds, second chunk fails
Expand Down Expand Up @@ -222,7 +224,7 @@ async def test_concurrent_updates_no_race_condition(
id=test_id,
user_prompt="Test concurrent updates",
selected_models=["model1", "model2", "model3", "model4", "model5"],
mdl='{"models": [{"name": "model1"}, {"name": "model2"}, {"name": "model3"}, {"name": "model4"}, {"name": "model5"}]}',
mdl='{"models": [{"name": "model1", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}, {"name": "model2", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}, {"name": "model3", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}, {"name": "model4", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}, {"name": "model5", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}]}',
)

# Mock pipeline responses with delays to simulate concurrent execution
Expand Down
Loading