diff --git a/inference/__init__.py b/inference/__init__.py
index 670d7ec941..8dc9d99aed 100644
--- a/inference/__init__.py
+++ b/inference/__init__.py
@@ -1,3 +1,3 @@
-from inference.core.interfaces.stream.stream import Stream
+from inference.core.interfaces.stream.stream import Stream  # isort:skip
 from inference.core.interfaces.stream.inference_pipeline import InferencePipeline
 from inference.models.utils import get_roboflow_model
diff --git a/inference/core/interfaces/stream/inference_pipeline.py b/inference/core/interfaces/stream/inference_pipeline.py
index f6b1156f95..58b3e69569 100644
--- a/inference/core/interfaces/stream/inference_pipeline.py
+++ b/inference/core/interfaces/stream/inference_pipeline.py
@@ -127,7 +127,9 @@ def init(
             active_learning_enabled (Optional[bool]): Flag to enable / disable Active Learning middleware
                 (setting it true does not guarantee any data to be collected, as data collection is controlled
                 by Roboflow backend - it just enables middleware intercepting predictions). If not given, env variable
-                `ACTIVE_LEARNING_ENABLED` will be used.
+                `ACTIVE_LEARNING_ENABLED` will be used. Please note that Active Learning will be forcibly
+                disabled when the Roboflow API key is not given, as a Roboflow account is required
+                for this feature to be operational.

        Other ENV variables involved in low-level configuration:
        * INFERENCE_PIPELINE_PREDICTIONS_QUEUE_SIZE - size of buffer for predictions that are ready for dispatching
@@ -170,6 +172,11 @@ def init(
                 f"with value: {ACTIVE_LEARNING_ENABLED}"
             )
             active_learning_enabled = ACTIVE_LEARNING_ENABLED
+        if api_key is None:
+            logger.info(
+                "Roboflow API key not given - Active Learning is forcibly disabled."
+            )
+            active_learning_enabled = False
         if active_learning_enabled is True:
             active_learning_middleware = ThreadingActiveLearningMiddleware.init(
                 api_key=api_key,
diff --git a/inference/core/managers/active_learning.py b/inference/core/managers/active_learning.py
index 2d004d0c99..4dc497f080 100644
--- a/inference/core/managers/active_learning.py
+++ b/inference/core/managers/active_learning.py
@@ -31,10 +31,10 @@ async def infer_from_request(
         self, model_id: str, request: InferenceRequest, **kwargs
     ) -> InferenceResponse:
         prediction = await super().infer_from_request(
-            model_id=model_id, request=request
+            model_id=model_id, request=request, **kwargs
         )
         active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
-        if not active_learning_eligible:
+        if not active_learning_eligible or request.api_key is None:
             return prediction
         self.register(prediction=prediction, model_id=model_id, request=request)
         return prediction
@@ -108,11 +108,12 @@ class BackgroundTaskActiveLearningManager(ActiveLearningManager):
     async def infer_from_request(
         self, model_id: str, request: InferenceRequest, **kwargs
     ) -> InferenceResponse:
+        active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
+        kwargs[ACTIVE_LEARNING_ELIGIBLE_PARAM] = False  # disable AL in super-classes to avoid double registration
         prediction = await super().infer_from_request(
-            model_id=model_id, request=request
+            model_id=model_id, request=request, **kwargs
         )
-        active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
-        if not active_learning_eligible:
+        if not active_learning_eligible or request.api_key is None:
             return prediction
         if BACKGROUND_TASKS_PARAM not in kwargs:
             logger.warning(
diff --git a/inference/enterprise/stream_management/api/entities.py b/inference/enterprise/stream_management/api/entities.py
index 4c93d69b54..16648a6e87 100644
--- a/inference/enterprise/stream_management/api/entities.py
+++ b/inference/enterprise/stream_management/api/entities.py
@@ -51,7 +51,7 @@ class PipelineInitialisationRequest(BaseModel):
     sink_configuration: UDPSinkConfiguration = Field(
         description="Configuration of the sink."
     )
-    api_key: str = Field(description="Roboflow API key")
+    api_key: Optional[str] = Field(description="Roboflow API key", default=None)
     max_fps: Optional[Union[float, int]] = Field(
         description="Limit of FPS in video processing.", default=None
     )
diff --git a/inference/enterprise/stream_management/manager/inference_pipeline_manager.py b/inference/enterprise/stream_management/manager/inference_pipeline_manager.py
index 0eef0d8039..36c0ba7edb 100644
--- a/inference/enterprise/stream_management/manager/inference_pipeline_manager.py
+++ b/inference/enterprise/stream_management/manager/inference_pipeline_manager.py
@@ -115,7 +115,7 @@ def _initialise_pipeline(self, request_id: str, payload: dict) -> None:
             model_id=payload["model_id"],
             video_reference=payload["video_reference"],
             on_prediction=sink,
-            api_key=payload["api_key"],
+            api_key=payload.get("api_key"),
             max_fps=payload.get("max_fps"),
             watchdog=watchdog,
             source_buffer_filling_strategy=source_buffer_filling_strategy,
diff --git a/inference_sdk/http/client.py b/inference_sdk/http/client.py
index dc91a5c541..f04ea3de34 100644
--- a/inference_sdk/http/client.py
+++ b/inference_sdk/http/client.py
@@ -478,7 +478,7 @@ def clip_compare(
         )
         payload = self.__initialise_payload()
         payload["subject_type"] = subject_type
-        payload["prompt_type"] = subject_type
+        payload["prompt_type"] = prompt_type
         if subject_type == "image":
             encoded_image = load_static_inference_input(
                 inference_input=subject,
diff --git a/tests/inference/unit_tests/enterprise/stream_management/api/test_app.py b/tests/inference/unit_tests/enterprise/stream_management/api/test_app.py
index ce53d62a7d..e8e3e3bde4 100644
--- a/tests/inference/unit_tests/enterprise/stream_management/api/test_app.py
+++ b/tests/inference/unit_tests/enterprise/stream_management/api/test_app.py
@@ -177,6 +177,44 @@ def test_initialise_pipeline_when_valid_payload_given(
     }, "CommandResponse must be serialised directly to JSON response"


+@mock.patch.object(app, "STREAM_MANAGER_CLIENT", new_callable=AsyncMock)
+def test_initialise_pipeline_when_valid_payload_given_without_api_key(
+    stream_manager_client: AsyncMock,
+) -> None:
+    # given
+    client = TestClient(app.app)
+    stream_manager_client.initialise_pipeline.return_value = CommandResponse(
+        status="success",
+        context=CommandContext(request_id="my_request", pipeline_id="my_pipeline"),
+    )
+
+    # when
+    response = client.post(
+        "/initialise",
+        json={
+            "model_id": "some/1",
+            "video_reference": "rtsp://some:543",
+            "sink_configuration": {
+                "type": "udp_sink",
+                "host": "127.0.0.1",
+                "port": 9090,
+            },
+            "model_configuration": {"type": "object-detection"},
+            "active_learning_enabled": True,
+        },
+    )
+
+    # then
+    assert response.status_code == 200, "Status code for success must be 200"
+    assert response.json() == {
+        "status": "success",
+        "context": {
+            "request_id": "my_request",
+            "pipeline_id": "my_pipeline",
+        },
+    }, "CommandResponse must be serialised directly to JSON response"
+
+
 @mock.patch.object(app, "STREAM_MANAGER_CLIENT", new_callable=AsyncMock)
 def test_pause_pipeline_when_successful_response_expected(
     stream_manager_client: AsyncMock,
diff --git a/tests/inference/unit_tests/enterprise/stream_management/api/test_stream_manager_client.py b/tests/inference/unit_tests/enterprise/stream_management/api/test_stream_manager_client.py
index 13db0670f2..e7f2f24fe0 100644
--- a/tests/inference/unit_tests/enterprise/stream_management/api/test_stream_manager_client.py
+++ b/tests/inference/unit_tests/enterprise/stream_management/api/test_stream_manager_client.py
@@ -506,6 +506,52 @@ async def test_stream_manager_client_can_successfully_initialise_pipeline(
     )


+@pytest.mark.asyncio
+@mock.patch.object(stream_manager_client, "establish_socket_connection")
+async def test_stream_manager_client_can_successfully_initialise_pipeline_without_api_key(
+    establish_socket_connection_mock: AsyncMock,
+) -> None:
+    # given
+    reader = assembly_socket_reader(
+        message={
+            "request_id": "my_request",
+            "pipeline_id": "new_pipeline",
+            "response": {"status": "success"},
+        },
+        header_size=4,
+    )
+    writer = DummyStreamWriter()
+    establish_socket_connection_mock.return_value = (reader, writer)
+    initialisation_request = PipelineInitialisationRequest(
+        model_id="some/1",
+        video_reference="rtsp://some:543",
+        sink_configuration=UDPSinkConfiguration(
+            type="udp_sink",
+            host="127.0.0.1",
+            port=9090,
+        ),
+        model_configuration=ObjectDetectionModelConfiguration(type="object_detection"),
+    )
+    client = StreamManagerClient.init(
+        host="127.0.0.1",
+        port=7070,
+        operations_timeout=1.0,
+        header_size=4,
+        buffer_size=16438,
+    )
+
+    # when
+    result = await client.initialise_pipeline(
+        initialisation_request=initialisation_request
+    )
+
+    # then
+    assert result == CommandResponse(
+        status="success",
+        context=CommandContext(request_id="my_request", pipeline_id="new_pipeline"),
+    )
+
+
 @pytest.mark.asyncio
 @mock.patch.object(stream_manager_client, "establish_socket_connection")
 async def test_stream_manager_client_can_successfully_terminate_pipeline(
diff --git a/tests/inference/unit_tests/enterprise/stream_management/manager/test_inference_pipeline_manager.py b/tests/inference/unit_tests/enterprise/stream_management/manager/test_inference_pipeline_manager.py
index a89a486c6f..2d1beece4c 100644
--- a/tests/inference/unit_tests/enterprise/stream_management/manager/test_inference_pipeline_manager.py
+++ b/tests/inference/unit_tests/enterprise/stream_management/manager/test_inference_pipeline_manager.py
@@ -60,7 +60,7 @@ def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested(

 @pytest.mark.timeout(30)
 @mock.patch.object(inference_pipeline_manager.InferencePipeline, "init")
-def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_but_invalid_payload_sent(
+def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_without_api_key(
     pipeline_init_mock: MagicMock,
 ) -> None:
     # given
@@ -70,7 +70,7 @@ def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_bu
         command_queue=command_queue, responses_queue=responses_queue
     )
     init_payload = assembly_valid_init_payload()
-    del init_payload["model_configuration"]
+    del init_payload["api_key"]

     # when
     command_queue.put(("1", init_payload))
@@ -82,24 +82,18 @@ def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_bu
     status_2 = responses_queue.get()

     # then
-    assert (
-        status_1[0] == "1"
-    ), "First request should be reported in responses_queue at first"
-    assert (
-        status_1[1]["status"] == OperationStatus.FAILURE
-    ), "Init operation should fail"
-    assert (
-        status_1[1]["error_type"] == ErrorType.INVALID_PAYLOAD
-    ), "Invalid Payload error is expected"
+    assert status_1 == (
+        "1",
+        {"status": OperationStatus.SUCCESS},
+    ), "Initialisation operation must succeed"
     assert status_2 == (
         "2",
         {"status": OperationStatus.SUCCESS},
-    ), "Termination of pipeline must happen"
-
+    ), "Termination operation must succeed"

 @pytest.mark.timeout(30)
 @mock.patch.object(inference_pipeline_manager.InferencePipeline, "init")
-def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_but_api_key_not_given(
+def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_but_invalid_payload_sent(
     pipeline_init_mock: MagicMock,
 ) -> None:
     # given
@@ -109,7 +103,7 @@ def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_bu
         command_queue=command_queue, responses_queue=responses_queue
     )
     init_payload = assembly_valid_init_payload()
-    del init_payload["api_key"]
+    del init_payload["model_configuration"]

     # when
     command_queue.put(("1", init_payload))
diff --git a/tests/inference_sdk/unit_tests/http/test_client.py b/tests/inference_sdk/unit_tests/http/test_client.py
index 1211455e13..6b7c8a2eef 100644
--- a/tests/inference_sdk/unit_tests/http/test_client.py
+++ b/tests/inference_sdk/unit_tests/http/test_client.py
@@ -1561,6 +1561,47 @@ def test_clip_compare_when_both_prompt_and_subject_are_texts(
     }, "Request must contain API key, subject and prompt types as text, exact values of subject and list of prompt values"


+@mock.patch.object(client, "load_static_inference_input")
+def test_clip_compare_when_mixed_input_is_given(
+    load_static_inference_input_mock: MagicMock,
+    requests_mock: Mocker,
+) -> None:
+    # given
+    api_url = "http://some.com"
+    http_client = InferenceHTTPClient(api_key="my-api-key", api_url=api_url)
+    load_static_inference_input_mock.side_effect = [
+        [("base64_image_1", 0.5)]
+    ]
+    requests_mock.post(
+        f"{api_url}/clip/compare",
+        json={
+            "frame_id": None,
+            "time": 0.1435863340011565,
+            "similarity": [0.8963012099266052, 0.8830886483192444],
+        },
+    )
+
+    # when
+    result = http_client.clip_compare(
+        subject="/some/image.jpg",
+        prompt=["dog", "house"],
+    )
+
+    # then
+    assert result == {
+        "frame_id": None,
+        "time": 0.1435863340011565,
+        "similarity": [0.8963012099266052, 0.8830886483192444],
+    }, "Result must match the value returned by HTTP endpoint"
+    assert requests_mock.request_history[0].json() == {
+        "api_key": "my-api-key",
+        "subject": {"type": "base64", "value": "base64_image_1"},
+        "prompt": ["dog", "house"],
+        "prompt_type": "text",
+        "subject_type": "image",
+    }, "Request must contain API key, subject type as image, prompt type as text, exact value of subject and list of prompt values"
+
+
 @mock.patch.object(client, "load_static_inference_input")
 def test_clip_compare_when_both_prompt_and_subject_are_images(
     load_static_inference_input_mock: MagicMock,
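
Usage illustration (not part of the diff): with the changes above, InferencePipeline.init can be called without a Roboflow API key, in which case Active Learning is forcibly disabled instead of initialisation failing. A minimal sketch follows; the model ID, video reference and print sink are illustrative assumptions, not values taken from this changeset.

from inference import InferencePipeline

# Sketch only: no api_key is passed, so the branch added above logs
# "Roboflow API key not given - Active Learning is forcibly disabled."
# and sets active_learning_enabled = False before the pipeline starts.
pipeline = InferencePipeline.init(
    model_id="some-public-model/1",  # hypothetical model id
    video_reference="rtsp://camera.local:554/stream",  # hypothetical stream
    on_prediction=lambda predictions, frame: print(predictions),
)
pipeline.start()
pipeline.join()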