Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix problem with keyless access and Active Learning #214

Merged
merged 2 commits into from
Dec 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion inference/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from inference.core.interfaces.stream.stream import Stream
from inference.core.interfaces.stream.stream import Stream # isort:skip
from inference.core.interfaces.stream.inference_pipeline import InferencePipeline
from inference.models.utils import get_roboflow_model
9 changes: 8 additions & 1 deletion inference/core/interfaces/stream/inference_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,9 @@ def init(
active_learning_enabled (Optional[bool]): Flag to enable / disable Active Learning middleware (setting it
true does not guarantee any data to be collected, as data collection is controlled by Roboflow backend -
it just enables middleware intercepting predictions). If not given, env variable
`ACTIVE_LEARNING_ENABLED` will be used.
`ACTIVE_LEARNING_ENABLED` will be used. Please point out that Active Learning will be forcefully
disabled in a scenario when Roboflow API key is not given, as Roboflow account is required
for this feature to be operational.

Other ENV variables involved in low-level configuration:
* INFERENCE_PIPELINE_PREDICTIONS_QUEUE_SIZE - size of buffer for predictions that are ready for dispatching
Expand Down Expand Up @@ -170,6 +172,11 @@ def init(
f"with value: {ACTIVE_LEARNING_ENABLED}"
)
active_learning_enabled = ACTIVE_LEARNING_ENABLED
if api_key is None:
logger.info(
f"Roboflow API key not given - Active Learning is forced to be disabled."
)
active_learning_enabled = False
if active_learning_enabled is True:
active_learning_middleware = ThreadingActiveLearningMiddleware.init(
api_key=api_key,
Expand Down
11 changes: 6 additions & 5 deletions inference/core/managers/active_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ async def infer_from_request(
self, model_id: str, request: InferenceRequest, **kwargs
) -> InferenceResponse:
prediction = await super().infer_from_request(
model_id=model_id, request=request
model_id=model_id, request=request, **kwargs
)
active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
if not active_learning_eligible:
if not active_learning_eligible or request.api_key is None:
return prediction
self.register(prediction=prediction, model_id=model_id, request=request)
return prediction
Expand Down Expand Up @@ -108,11 +108,12 @@ class BackgroundTaskActiveLearningManager(ActiveLearningManager):
async def infer_from_request(
self, model_id: str, request: InferenceRequest, **kwargs
) -> InferenceResponse:
active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
kwargs[ACTIVE_LEARNING_ELIGIBLE_PARAM] = False # disabling AL in super-classes
prediction = await super().infer_from_request(
model_id=model_id, request=request
model_id=model_id, request=request, **kwargs
)
active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
if not active_learning_eligible:
if not active_learning_eligible or request.api_key is None:
return prediction
if BACKGROUND_TASKS_PARAM not in kwargs:
logger.warning(
Expand Down
2 changes: 1 addition & 1 deletion inference/enterprise/stream_management/api/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class PipelineInitialisationRequest(BaseModel):
sink_configuration: UDPSinkConfiguration = Field(
description="Configuration of the sink."
)
api_key: str = Field(description="Roboflow API key")
api_key: Optional[str] = Field(description="Roboflow API key", default=None)
max_fps: Optional[Union[float, int]] = Field(
description="Limit of FPS in video processing.", default=None
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _initialise_pipeline(self, request_id: str, payload: dict) -> None:
model_id=payload["model_id"],
video_reference=payload["video_reference"],
on_prediction=sink,
api_key=payload["api_key"],
api_key=payload.get("api_key"),
max_fps=payload.get("max_fps"),
watchdog=watchdog,
source_buffer_filling_strategy=source_buffer_filling_strategy,
Expand Down
2 changes: 1 addition & 1 deletion inference_sdk/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def clip_compare(
)
payload = self.__initialise_payload()
payload["subject_type"] = subject_type
payload["prompt_type"] = subject_type
payload["prompt_type"] = prompt_type
if subject_type == "image":
encoded_image = load_static_inference_input(
inference_input=subject,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,44 @@ def test_initialise_pipeline_when_valid_payload_given(
}, "CommandResponse must be serialised directly to JSON response"


@mock.patch.object(app, "STREAM_MANAGER_CLIENT", new_callable=AsyncMock)
def test_initialise_pipeline_when_valid_payload_given_without_api_key(
stream_manager_client: AsyncMock,
) -> None:
# given
client = TestClient(app.app)
stream_manager_client.initialise_pipeline.return_value = CommandResponse(
status="success",
context=CommandContext(request_id="my_request", pipeline_id="my_pipeline"),
)

# when
response = client.post(
"/initialise",
json={
"model_id": "some/1",
"video_reference": "rtsp://some:543",
"sink_configuration": {
"type": "udp_sink",
"host": "127.0.0.1",
"port": 9090,
},
"model_configuration": {"type": "object-detection"},
"active_learning_enabled": True,
},
)

# then
assert response.status_code == 200, "Status code for success must be 200"
assert response.json() == {
"status": "success",
"context": {
"request_id": "my_request",
"pipeline_id": "my_pipeline",
},
}, "CommandResponse must be serialised directly to JSON response"


@mock.patch.object(app, "STREAM_MANAGER_CLIENT", new_callable=AsyncMock)
def test_pause_pipeline_when_successful_response_expected(
stream_manager_client: AsyncMock,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,52 @@ async def test_stream_manager_client_can_successfully_initialise_pipeline(
)


@pytest.mark.asyncio
@mock.patch.object(stream_manager_client, "establish_socket_connection")
async def test_stream_manager_client_can_successfully_initialise_pipeline_without_api_key(
establish_socket_connection_mock: AsyncMock,
) -> None:
# given
reader = assembly_socket_reader(
message={
"request_id": "my_request",
"pipeline_id": "new_pipeline",
"response": {"status": "success"},
},
header_size=4,
)
writer = DummyStreamWriter()
establish_socket_connection_mock.return_value = (reader, writer)
initialisation_request = PipelineInitialisationRequest(
model_id="some/1",
video_reference="rtsp://some:543",
sink_configuration=UDPSinkConfiguration(
type="udp_sink",
host="127.0.0.1",
port=9090,
),
model_configuration=ObjectDetectionModelConfiguration(type="object_detection"),
)
client = StreamManagerClient.init(
host="127.0.0.1",
port=7070,
operations_timeout=1.0,
header_size=4,
buffer_size=16438,
)

# when
result = await client.initialise_pipeline(
initialisation_request=initialisation_request
)

# then
assert result == CommandResponse(
status="success",
context=CommandContext(request_id="my_request", pipeline_id="new_pipeline"),
)


@pytest.mark.asyncio
@mock.patch.object(stream_manager_client, "establish_socket_connection")
async def test_stream_manager_client_can_successfully_terminate_pipeline(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested(

@pytest.mark.timeout(30)
@mock.patch.object(inference_pipeline_manager.InferencePipeline, "init")
def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_but_invalid_payload_sent(
def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_without_api_key(
pipeline_init_mock: MagicMock,
) -> None:
# given
Expand All @@ -70,7 +70,7 @@ def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_bu
command_queue=command_queue, responses_queue=responses_queue
)
init_payload = assembly_valid_init_payload()
del init_payload["model_configuration"]
del init_payload["api_key"]

# when
command_queue.put(("1", init_payload))
Expand All @@ -82,24 +82,18 @@ def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_bu
status_2 = responses_queue.get()

# then
assert (
status_1[0] == "1"
), "First request should be reported in responses_queue at first"
assert (
status_1[1]["status"] == OperationStatus.FAILURE
), "Init operation should fail"
assert (
status_1[1]["error_type"] == ErrorType.INVALID_PAYLOAD
), "Invalid Payload error is expected"
assert status_1 == (
"1",
{"status": OperationStatus.SUCCESS},
), "Initialisation operation must succeed"
assert status_2 == (
"2",
{"status": OperationStatus.SUCCESS},
), "Termination of pipeline must happen"

), "Termination operation must succeed"

@pytest.mark.timeout(30)
@mock.patch.object(inference_pipeline_manager.InferencePipeline, "init")
def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_but_api_key_not_given(
def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_but_invalid_payload_sent(
pipeline_init_mock: MagicMock,
) -> None:
# given
Expand All @@ -109,7 +103,7 @@ def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_bu
command_queue=command_queue, responses_queue=responses_queue
)
init_payload = assembly_valid_init_payload()
del init_payload["api_key"]
del init_payload["model_configuration"]

# when
command_queue.put(("1", init_payload))
Expand Down
41 changes: 41 additions & 0 deletions tests/inference_sdk/unit_tests/http/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1561,6 +1561,47 @@ def test_clip_compare_when_both_prompt_and_subject_are_texts(
}, "Request must contain API key, subject and prompt types as text, exact values of subject and list of prompt values"


@mock.patch.object(client, "load_static_inference_input")
def test_clip_compare_when_mixed_input_is_given(
load_static_inference_input_mock: MagicMock,
requests_mock: Mocker,
) -> None:
# given
api_url = "http://some.com"
http_client = InferenceHTTPClient(api_key="my-api-key", api_url=api_url)
load_static_inference_input_mock.side_effect = [
[("base64_image_1", 0.5)]
]
requests_mock.post(
f"{api_url}/clip/compare",
json={
"frame_id": None,
"time": 0.1435863340011565,
"similarity": [0.8963012099266052, 0.8830886483192444],
},
)

# when
result = http_client.clip_compare(
subject="/some/image.jpg",
prompt=["dog", "house"],
)

# then
assert result == {
"frame_id": None,
"time": 0.1435863340011565,
"similarity": [0.8963012099266052, 0.8830886483192444],
}, "Result must match the value returned by HTTP endpoint"
assert requests_mock.request_history[0].json() == {
"api_key": "my-api-key",
"subject": {"type": "base64", "value": "base64_image_1"},
"prompt": ["dog", "house"],
"prompt_type": "text",
"subject_type": "image",
}, "Request must contain API key, subject and prompt types as text, exact values of subject and list of prompt values"


@mock.patch.object(client, "load_static_inference_input")
def test_clip_compare_when_both_prompt_and_subject_are_images(
load_static_inference_input_mock: MagicMock,
Expand Down
Loading