Skip to content

Commit

Permalink
Set dataset_id in params from op doc (#3934)
Browse files Browse the repository at this point in the history
  • Loading branch information
kaixi-wang authored and benjaminpkane committed Dec 20, 2023
1 parent 7679100 commit a36d746
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 23 deletions.
36 changes: 25 additions & 11 deletions fiftyone/factory/repos/delegated_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
| `voxel51.com <https://voxel51.com/>`_
|
"""
import logging
from datetime import datetime
from typing import Any, List

Expand All @@ -13,7 +14,6 @@
from pymongo import IndexModel
from pymongo.collection import Collection

import fiftyone.core.dataset as fod
from fiftyone.factory import DelegatedOperationPagingParams
from fiftyone.factory.repos import DelegatedOperationDocument
from fiftyone.operators.executor import (
Expand All @@ -23,6 +23,8 @@
ExecutionRunState,
)

logger = logging.getLogger(__name__)


class DelegatedOperationRepo(object):
"""Base Class for a delegated operation repository."""
Expand Down Expand Up @@ -118,7 +120,6 @@ def __init__(self, collection: Collection = None):
self._create_indexes()

def _get_collection(self) -> Collection:
import fiftyone as fo
import fiftyone.core.odm as foo

database: pymongo.database.Database = foo.get_db_conn()
Expand Down Expand Up @@ -162,19 +163,32 @@ def queue_operation(self, **kwargs: Any) -> DelegatedOperationDocument:
if delegation_target:
setattr(op, "delegation_target", delegation_target)

context = None
if isinstance(op.context, dict):
context = ExecutionContext(
request_params=op.context.get("request_params", {})
)
elif isinstance(op.context, ExecutionContext):
context = op.context
if not op.dataset_id:
# For consistency, set the dataset_id using the ExecutionContext.dataset
# For consistency, set the dataset_id using the
# ExecutionContext.dataset
# rather than calling load_dataset() on a potentially stale
# dataset_name in the request_params
context = None
if isinstance(op.context, dict):
context = ExecutionContext(
request_params=op.context.get("request_params", {})
)
elif isinstance(op.context, ExecutionContext):
context = op.context
if context and context.dataset:
try:
op.dataset_id = context.dataset._doc.id
except:
# If we can't resolve the dataset_id, it is possible the
# dataset doesn't exist (deleted/being created). However,
# it's also possible that future operators can run
# dataset-less, so don't raise an error here and just log it
# in case we need to debug later.
logger.debug("Could not resolve dataset_id for operation. ")
elif op.dataset_id:
# If the dataset_id is provided, we set it in the request_params
# to ensure that the operation is executed on the correct dataset
context.request_params["dataset_id"] = str(op.dataset_id)
context.request_params["dataset_name"] = context.dataset.name

doc = self._collection.insert_one(op.to_pymongo())
op.id = doc.inserted_id
Expand Down
75 changes: 63 additions & 12 deletions tests/unittests/delegated_operators_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,19 @@
from fiftyone.operators.operator import Operator, OperatorConfig


class MockDataset:
    """Lightweight dataset stand-in for the delegated-operation tests.

    Exposes only a ``name`` attribute and a ``_doc`` mock carrying an ``id``,
    plus no-op ``save()`` and ``delete()`` methods. The tests use instances
    of this class as the return value of the patched ``load_dataset``.

    Args:
        **kwargs: optional overrides; ``name`` defaults to
            ``"test_dataset"`` and ``id`` defaults to a freshly generated
            ``ObjectId``
    """

    def __init__(self, **kwargs):
        # Build the fake backing document first, then attach it.
        # NOTE: the default ``ObjectId()`` is constructed on every call,
        # even when an explicit ``id`` is supplied.
        fake_doc = mock.MagicMock()
        fake_doc.id = kwargs.get("id", ObjectId())
        self._doc = fake_doc

        self.name = kwargs.get("name", "test_dataset")

    def save(self):
        """No-op: nothing is persisted."""
        return None

    def delete(self):
        """No-op: nothing is removed."""
        return None


class MockOperator(Operator):
def __init__(self, success=True, sets_progress=False, **kwargs):
self.success = success
Expand Down Expand Up @@ -257,12 +270,20 @@ def test_list_queued_operations(
dataset.delete()
dataset2.delete()

def test_set_run_states(self, mock_get_operator, mock_operator_exists):
@patch(
"fiftyone.core.odm.utils.load_dataset",
)
def test_set_run_states(
self, mock_load_dataset, mock_get_operator, mock_operator_exists
):
mock_load_dataset.return_value = MockDataset()
doc = self.svc.queue_operation(
operator="@voxelfiftyone/operator/foo",
label=mock_get_operator.return_value.name,
delegation_target=f"test_target",
context=ExecutionContext(request_params={"foo": "bar"}),
context=ExecutionContext(
request_params={"foo": "bar", "dataset_id": str(ObjectId())}
),
)

original_updated_at = doc.updated_at
Expand Down Expand Up @@ -291,13 +312,21 @@ def test_set_run_states(self, mock_get_operator, mock_operator_exists):
self.assertIsNotNone(doc.result.error)
self.assertNotEqual(doc.updated_at, original_updated_at)

def test_sets_progress(self, mock_get_operator, mock_operator_exists):
@patch(
"fiftyone.core.odm.utils.load_dataset",
)
def test_sets_progress(
self, mock_load_dataset, mock_get_operator, mock_operator_exists
):
mock_load_dataset.return_value = MockDataset()
mock_get_operator.return_value = MockOperator(sets_progress=True)

doc = self.svc.queue_operation(
operator="@voxelfiftyone/operator/foo",
delegation_target=f"test_target",
context=ExecutionContext(request_params={"foo": "bar"}),
context=ExecutionContext(
request_params={"foo": "bar", "dataset_id": str(ObjectId())}
),
)

self.docs_to_delete.append(doc)
Expand All @@ -312,12 +341,20 @@ def test_sets_progress(self, mock_get_operator, mock_operator_exists):
self.assertEqual(doc.status.label, "halfway there")
self.assertIsNotNone(doc.status.updated_at)

def test_full_run_success(self, mock_get_operator, mock_operator_exists):
@patch(
"fiftyone.core.odm.utils.load_dataset",
)
def test_full_run_success(
self, mock_load_dataset, mock_get_operator, mock_operator_exists
):
mock_load_dataset.return_value = MockDataset()
doc = self.svc.queue_operation(
operator="@voxelfiftyone/operator/foo",
label=mock_get_operator.return_value.name,
delegation_target=f"test_target",
context=ExecutionContext(request_params={"foo": "bar"}),
context=ExecutionContext(
request_params={"foo": "bar", "dataset_id": str(ObjectId())}
),
)

self.docs_to_delete.append(doc)
Expand All @@ -336,17 +373,22 @@ def test_full_run_success(self, mock_get_operator, mock_operator_exists):

self.assertEqual(doc.result.result, {"executed": True})

@patch(
"fiftyone.core.odm.utils.load_dataset",
)
def test_generator_run_success(
self, mock_get_operator, mock_operator_exists
self, mock_load_dataset, mock_get_operator, mock_operator_exists
):

mock_load_dataset.return_value = MockDataset()
mock_get_operator.return_value = MockGeneratorOperator()

doc = self.svc.queue_operation(
operator="@voxelfiftyone/operator/generator_op",
label=mock_get_operator.return_value.name,
delegation_target=f"test_target_generator",
context=ExecutionContext(request_params={"foo": "bar"}),
context=ExecutionContext(
request_params={"foo": "bar", "dataset_id": str(ObjectId())}
),
)

self.docs_to_delete.append(doc)
Expand All @@ -364,9 +406,13 @@ def test_generator_run_success(
self.assertIsNone(doc.result)
self.assertIsNone(doc.failed_at)

@patch(
"fiftyone.core.odm.utils.load_dataset",
)
def test_generator_sets_progress(
self, mock_get_operator, mock_operator_exists
self, mock_load_dataset, mock_get_operator, mock_operator_exists
):
mock_load_dataset.return_value = MockDataset()
mock_get_operator.return_value = MockGeneratorOperator(
sets_progress=True
)
Expand All @@ -389,9 +435,14 @@ def test_generator_sets_progress(
self.assertEqual(doc.status.label, "halfway there")
self.assertIsNotNone(doc.status.updated_at)

def test_updates_progress(self, mock_get_operator, mock_operator_exists):
@patch(
"fiftyone.core.odm.utils.load_dataset",
)
def test_updates_progress(
self, mock_load_dataset, mock_get_operator, mock_operator_exists
):
mock_get_operator.return_value = MockProgressiveOperator()

mock_load_dataset.return_value = MockDataset()
doc = self.svc.queue_operation(
operator="@voxelfiftyone/operator/foo",
delegation_target=f"test_target",
Expand Down

0 comments on commit a36d746

Please sign in to comment.