diff --git a/CHANGELOG.md b/CHANGELOG.md index c868573a3d..375b49a760 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ These are the section headers that we use: - Fixed error in the unification strategy for `RankingQuestion` ([#4295](https://github.com/argilla-io/argilla/pull/4295)) - Fixed `TextClassificationSettings.labels_schema` order was not being preserved. Closes [#3828](https://github.com/argilla-io/argilla/issues/3828) ([#4332](https://github.com/argilla-io/argilla/pull/4332)) - Fixed error when requesting non-existing API endpoints. Closes [#4073](https://github.com/argilla-io/argilla/issues/4073) ([#4325](https://github.com/argilla-io/argilla/pull/4325)) +- Fixed error when passing `draft` responses to create records endpoint. ([#4354](https://github.com/argilla-io/argilla/pull/4354)) ### Changed diff --git a/src/argilla/server/schemas/v1/datasets.py b/src/argilla/server/schemas/v1/datasets.py index c0cbedd5bb..625f211fc5 100644 --- a/src/argilla/server/schemas/v1/datasets.py +++ b/src/argilla/server/schemas/v1/datasets.py @@ -491,8 +491,14 @@ class UserDiscardedResponseCreate(BaseModel): status: Literal[ResponseStatus.discarded] +class UserDraftResponseCreate(BaseModel): + user_id: UUID + values: Dict[str, ResponseValueCreate] + status: Literal[ResponseStatus.draft] + + UserResponseCreate = Annotated[ - Union[UserSubmittedResponseCreate, UserDiscardedResponseCreate], + Union[UserSubmittedResponseCreate, UserDraftResponseCreate, UserDiscardedResponseCreate], PydanticField(discriminator="status"), ] diff --git a/tests/unit/server/api/v1/test_datasets.py b/tests/unit/server/api/v1/test_datasets.py index d26fc4644a..c3bf220040 100644 --- a/tests/unit/server/api/v1/test_datasets.py +++ b/tests/unit/server/api/v1/test_datasets.py @@ -3088,6 +3088,45 @@ async def test_create_dataset_records_with_discarded_response( await db.execute(select(func.count(Response.id)).filter(Response.status == ResponseStatus.discarded)) ).scalar() == 1 + async def test_create_dataset_records_with_draft_response( + self, + async_client: "AsyncClient", + db: "AsyncSession", + owner: User, + owner_auth_header: dict, + ): + dataset = await DatasetFactory.create(status=DatasetStatus.ready) + await TextFieldFactory.create(name="input", dataset=dataset) + await TextFieldFactory.create(name="output", dataset=dataset) + + await TextQuestionFactory.create(name="input_ok", dataset=dataset) + await TextQuestionFactory.create(name="output_ok", dataset=dataset) + + records_json = { + "items": [ + { + "fields": {"input": "Say Hello", "output": "Hello"}, + "responses": [ + { + "values": {"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}}, + "status": "draft", + "user_id": str(owner.id), + } + ], + }, + ] + } + + response = await async_client.post( + f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + ) + + assert response.status_code == 204 + assert (await db.execute(select(func.count(Record.id)))).scalar() == 1 + assert ( + await db.execute(select(func.count(Response.id)).filter(Response.status == ResponseStatus.draft)) + ).scalar() == 1 + async def test_create_dataset_records_with_invalid_response_status( self, async_client: "AsyncClient",