# Helpers used by ``picked_models`` to trim an MDL model down to the fields
# that are sent on for description generation.


def relation_filter(column: dict) -> bool:
    """Return True for ordinary columns and False for relationship columns.

    A column counts as a relationship column when it carries a
    ``"relationship"`` key, whatever the value.
    """
    return "relationship" not in column


def column_formatter(columns: list[dict]) -> list[dict]:
    """Reduce every non-relationship column to name, type and description.

    Args:
        columns: Raw column dicts. ``"name"`` and ``"type"`` are required;
            ``"properties"`` may be missing or ``None``.

    Returns:
        One dict per kept column with exactly the keys ``name``, ``type`` and
        ``properties`` (the latter holding only ``description``).
    """
    return [
        {
            "name": column["name"],
            "type": column["type"],
            "properties": {
                # Columns are allowed to omit "properties" entirely, so fall
                # back to an empty mapping instead of raising KeyError.
                "description": (column.get("properties") or {}).get(
                    "description", ""
                ),
            },
        }
        for column in columns
        if relation_filter(column)
    ]


def extract(model: dict) -> dict:
    """Build the trimmed representation of a single model.

    Args:
        model: Raw model dict. ``"name"`` is required; ``"columns"`` and
            ``"properties"`` may be missing or ``None``.

    Returns:
        A dict with the model name, its formatted non-relationship columns and
        a ``properties`` dict holding only the model ``description``.
    """
    return {
        "name": model["name"],
        # A model without columns is still a valid model to describe.
        "columns": column_formatter(model.get("columns") or []),
        "properties": {
            "description": (model.get("properties") or {}).get("description", ""),
        },
    }
2b0e2b2ab..b3d17654a 100644 --- a/wren-ai-service/src/web/v1/routers/relationship_recommendation.py +++ b/wren-ai-service/src/web/v1/routers/relationship_recommendation.py @@ -25,7 +25,8 @@ - Generates a new relationship recommendation - Request body: PostRequest { - "mdl": "{ ... }" # JSON string of the MDL (Model Definition Language) + "mdl": "{ ... }", # JSON string of the MDL (Model Definition Language) + "project_id": "project-id" # Optional project ID } - Response: PostResponse { @@ -62,6 +63,7 @@ class PostRequest(BaseModel): mdl: str + project_id: Optional[str] = None class PostResponse(BaseModel): diff --git a/wren-ai-service/src/web/v1/routers/semantics_description.py b/wren-ai-service/src/web/v1/routers/semantics_description.py index a66ac8806..fcaa5caae 100644 --- a/wren-ai-service/src/web/v1/routers/semantics_description.py +++ b/wren-ai-service/src/web/v1/routers/semantics_description.py @@ -29,6 +29,7 @@ "selected_models": ["model1", "model2"], # List of model names to describe "user_prompt": "Describe these models", # User's instruction for description "mdl": "{ ... 
}", # JSON string of the MDL (Model Definition Language) + "project_id": "project-id", # Optional project ID "configuration": { # Optional configuration settings "language": "English" # Optional language, defaults to "English" } @@ -89,6 +90,7 @@ class PostRequest(BaseModel): selected_models: list[str] user_prompt: str mdl: str + project_id: Optional[str] = None configuration: Optional[Configuration] = Configuration() diff --git a/wren-ai-service/src/web/v1/services/semantics_description.py b/wren-ai-service/src/web/v1/services/semantics_description.py index 7ea655f20..d4e97879e 100644 --- a/wren-ai-service/src/web/v1/services/semantics_description.py +++ b/wren-ai-service/src/web/v1/services/semantics_description.py @@ -57,27 +57,45 @@ def _handle_exception( logger.error(error_message) def _chunking( - self, mdl_dict: dict, request: Input, chunk_size: int = 1 + self, mdl_dict: dict, request: Input, chunk_size: int = 50 ) -> list[dict]: template = { "user_prompt": request.user_prompt, - "mdl": mdl_dict, "language": request.configuration.language, } + + chunks = [ + { + **model, + "columns": model["columns"][i : i + chunk_size], + } + for model in mdl_dict["models"] + if model["name"] in request.selected_models + for i in range(0, len(model["columns"]), chunk_size) + ] + return [ { **template, - "selected_models": request.selected_models[i : i + chunk_size], + "mdl": {"models": [chunk]}, + "selected_models": [chunk["name"]], } - for i in range(0, len(request.selected_models), chunk_size) + for chunk in chunks ] async def _generate_task(self, request_id: str, chunk: dict): resp = await self._pipelines["semantics_description"].run(**chunk) + normalize = resp.get("normalize") current = self[request_id] current.response = current.response or {} - current.response.update(resp.get("normalize")) + + for key in normalize.keys(): + if key not in current.response: + current.response[key] = normalize[key] + continue + + 
current.response[key]["columns"].extend(normalize[key]["columns"]) @observe(name="Generate Semantics Description") @trace_metadata diff --git a/wren-ai-service/tests/pytest/services/test_semantics_description.py b/wren-ai-service/tests/pytest/services/test_semantics_description.py index d8dfd694c..dc48e7339 100644 --- a/wren-ai-service/tests/pytest/services/test_semantics_description.py +++ b/wren-ai-service/tests/pytest/services/test_semantics_description.py @@ -32,7 +32,7 @@ async def test_generate_semantics_description( id="test_id", user_prompt="Describe the model", selected_models=["model1"], - mdl='{"models": [{"name": "model1", "columns": []}]}', + mdl='{"models": [{"name": "model1", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}]}', ) await service.generate(request) @@ -80,7 +80,7 @@ async def test_generate_semantics_description_with_exception( id="test_id", user_prompt="Describe the model", selected_models=["model1"], - mdl='{"models": [{"name": "model1", "columns": []}]}', + mdl='{"models": [{"name": "model1", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}]}', ) service._pipelines["semantics_description"].run.side_effect = Exception( @@ -136,7 +136,7 @@ async def test_batch_processing_with_multiple_models( id="test_id", user_prompt="Describe the models", selected_models=["model1", "model2", "model3"], - mdl='{"models": [{"name": "model1"}, {"name": "model2"}, {"name": "model3"}]}', + mdl='{"models": [{"name": "model1", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}, {"name": "model2", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}, {"name": "model3", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}]}', ) # Mock pipeline responses for each chunk @@ -172,16 +172,18 @@ def test_batch_processing_with_custom_chunk_size( id="test_id", user_prompt="Describe the models", selected_models=["model1", "model2", "model3", "model4"], - mdl='{"models": 
[{"name": "model1"}, {"name": "model2"}, {"name": "model3"}, {"name": "model4"}]}', + mdl='{"models": [{"name": "model1", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}, {"name": "model2", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}, {"name": "model3", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}, {"name": "model4", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}]}', ) # Test chunking with custom chunk size chunks = service._chunking(orjson.loads(request.mdl), request, chunk_size=2) - assert len(chunks) == 2 # Should create 2 chunks with size 2 - assert [len(chunk["selected_models"]) for chunk in chunks] == [2, 2] - assert chunks[0]["selected_models"] == ["model1", "model2"] - assert chunks[1]["selected_models"] == ["model3", "model4"] + assert len(chunks) == 4 + assert [len(chunk["selected_models"]) for chunk in chunks] == [1, 1, 1, 1] + assert chunks[0]["selected_models"] == ["model1"] + assert chunks[1]["selected_models"] == ["model2"] + assert chunks[2]["selected_models"] == ["model3"] + assert chunks[3]["selected_models"] == ["model4"] @pytest.mark.asyncio @@ -193,7 +195,7 @@ async def test_batch_processing_partial_failure( id="test_id", user_prompt="Describe the models", selected_models=["model1", "model2"], - mdl='{"models": [{"name": "model1"}, {"name": "model2"}]}', + mdl='{"models": [{"name": "model1", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}, {"name": "model2", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}]}', ) # Mock first chunk succeeds, second chunk fails @@ -222,7 +224,7 @@ async def test_concurrent_updates_no_race_condition( id=test_id, user_prompt="Test concurrent updates", selected_models=["model1", "model2", "model3", "model4", "model5"], - mdl='{"models": [{"name": "model1"}, {"name": "model2"}, {"name": "model3"}, {"name": "model4"}, {"name": "model5"}]}', + mdl='{"models": [{"name": 
"model1", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}, {"name": "model2", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}, {"name": "model3", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}, {"name": "model4", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}, {"name": "model5", "columns": [{"name": "column1", "type": "varchar", "notNull": false}]}]}', ) # Mock pipeline responses with delays to simulate concurrent execution