diff --git a/flowcept/commons/daos/document_db_dao.py b/flowcept/commons/daos/document_db_dao.py index e25f5e14..8c529ee0 100644 --- a/flowcept/commons/daos/document_db_dao.py +++ b/flowcept/commons/daos/document_db_dao.py @@ -17,6 +17,7 @@ MONGO_TASK_COLLECTION, MONGO_WORKFLOWS_COLLECTION, PERF_LOG, + MONGO_URI, ) from flowcept.flowceptor.consumers.consumer_utils import ( curate_dict_task_messages, @@ -27,7 +28,11 @@ class DocumentDBDao(object): def __init__(self): self.logger = FlowceptLogger().get_logger() - client = MongoClient(MONGO_HOST, MONGO_PORT) + + if MONGO_URI is not None: + client = MongoClient(MONGO_URI) + else: + client = MongoClient(MONGO_HOST, MONGO_PORT) self._db = client[MONGO_DB] self._tasks_collection = self._db[MONGO_TASK_COLLECTION] diff --git a/flowcept/commons/daos/keyvalue_dao.py b/flowcept/commons/daos/keyvalue_dao.py new file mode 100644 index 00000000..c3785c03 --- /dev/null +++ b/flowcept/commons/daos/keyvalue_dao.py @@ -0,0 +1,45 @@ +from redis import Redis + +from flowcept.commons.flowcept_logger import FlowceptLogger +from flowcept.configs import ( + REDIS_HOST, + REDIS_PORT, + REDIS_PASSWORD, +) + + +class KeyValueDAO: + def __init__(self, connection=None): + self.logger = FlowceptLogger().get_logger() + if connection is None: + self._redis = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + db=0, + password=REDIS_PASSWORD, + ) + else: + self._redis = connection + + def delete_set(self, set_name: str): + self._redis.delete(set_name) + + def add_key_into_set(self, set_name: str, key): + self._redis.sadd(set_name, key) + + def remove_key_from_set(self, set_name: str, key): + self._redis.srem(set_name, key) + + def set_has_key(self, set_name: str, key) -> bool: + return self._redis.sismember(set_name, key) + + def set_count(self, set_name: str): + return self._redis.scard(set_name) + + def set_is_empty(self, set_name: str) -> bool: + return self.set_count(set_name) == 0 + + def delete_all_matching_sets(self, key_pattern): + matching_sets = self._redis.keys(key_pattern) + for set_name in matching_sets: + self.delete_set(set_name) diff --git a/flowcept/commons/daos/mq_dao.py b/flowcept/commons/daos/mq_dao.py index 2a933d5f..74104774 100644 --- a/flowcept/commons/daos/mq_dao.py +++ b/flowcept/commons/daos/mq_dao.py @@ -4,6 +4,7 @@ from threading import Thread, Lock from time import time, sleep +from flowcept.commons.daos.keyvalue_dao import KeyValueDAO from flowcept.commons.utils import perf_log from flowcept.commons.flowcept_logger import FlowceptLogger from flowcept.configs import ( @@ -16,6 +17,7 @@ REDIS_BUFFER_SIZE, REDIS_INSERTION_BUFFER_TIME, PERF_LOG, + REDIS_URI, ) from flowcept.commons.utils import GenericJSONEncoder @@ -26,11 +28,33 @@ class MQDao: ENCODER = GenericJSONEncoder if JSON_SERIALIZER == "complex" else None # TODO we don't have a unit test to cover complex dict! + @staticmethod + def _get_set_name(exec_bundle_id=None): + """ + :param exec_bundle_id: A way to group one or many interceptors, and treat each group as a bundle to control when their time_based threads started and ended. 
+ :return: + """ + set_id = f"started_mq_thread_execution" + if exec_bundle_id is not None: + set_id += "_" + str(exec_bundle_id) + return set_id + def __init__(self): self.logger = FlowceptLogger().get_logger() - self._redis = Redis( - host=REDIS_HOST, port=REDIS_PORT, db=0, password=REDIS_PASSWORD - ) + + if REDIS_URI is not None: + # If a URI is provided, use it for connection + self._redis = Redis.from_url(REDIS_URI) + else: + # Otherwise, use the host, port, and password settings + self._redis = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + db=0, + password=REDIS_PASSWORD if REDIS_PASSWORD else None, + ) + + self._keyvalue_dao = KeyValueDAO(connection=self._redis) self._buffer = None self._time_thread: Thread = None self._previous_time = -1 @@ -38,36 +62,71 @@ def __init__(self): self._time_based_flushing_started = False self._lock = None - def start_time_based_flushing(self): + def register_time_based_thread_init( + self, interceptor_instance_id: int, exec_bundle_id=None + ): + set_name = MQDao._get_set_name(exec_bundle_id) + self.logger.debug( + f"Registering the beginning of the time_based MQ flush thread {set_name}.{interceptor_instance_id}" + ) + self._keyvalue_dao.add_key_into_set(set_name, interceptor_instance_id) + + def register_time_based_thread_end( + self, interceptor_instance_id: int, exec_bundle_id=None + ): + set_name = MQDao._get_set_name(exec_bundle_id) + self.logger.debug( + f"Registering the end of the time_based MQ flush thread {set_name}.{interceptor_instance_id}" + ) + self._keyvalue_dao.remove_key_from_set( + set_name, interceptor_instance_id + ) + + def all_time_based_threads_ended(self, exec_bundle_id=None): + set_name = MQDao._get_set_name(exec_bundle_id) + return self._keyvalue_dao.set_is_empty(set_name) + + def delete_all_time_based_threads_sets(self): + return self._keyvalue_dao.delete_all_matching_sets( + MQDao._get_set_name() + "*" + ) + + def start_time_based_flushing( + self, interceptor_instance_id: int, exec_bundle_id=None + ): self._buffer = list() self._time_thread: Thread = None self._previous_time = time() self._stop_flag = False self._time_based_flushing_started = False self._lock = Lock() - self._time_thread = Thread(target=self.time_based_flushing) - self._redis.incr(REDIS_STARTED_MQ_THREADS_KEY) - self.logger.debug( - f"Incrementing REDIS_STARTED_MQ_THREADS_KEY. Now: {self.get_started_mq_threads()}" + self.register_time_based_thread_init( + interceptor_instance_id, exec_bundle_id ) + # self._redis.incr(REDIS_STARTED_MQ_THREADS_KEY) + # self.logger.debug( + # f"Incrementing REDIS_STARTED_MQ_THREADS_KEY. 
Now: {self.get_started_mq_threads()}" + # ) self._time_based_flushing_started = True self._time_thread.start() - def get_started_mq_threads(self): - return int(self._redis.get(REDIS_STARTED_MQ_THREADS_KEY)) - - def reset_started_mq_threads(self): - self.logger.debug("RESETTING REDIS_STARTED_MQ_THREADS_KEY TO 0") - self._redis.set(REDIS_STARTED_MQ_THREADS_KEY, 0) + # def get_started_mq_threads(self): + # return int(self._redis.get(REDIS_STARTED_MQ_THREADS_KEY)) + # + # def reset_started_mq_threads(self): + # self.logger.debug("RESETTING REDIS_STARTED_MQ_THREADS_KEY TO 0") + # self._redis.set(REDIS_STARTED_MQ_THREADS_KEY, 0) - def stop(self): + def stop_time_based_flushing( + self, interceptor_instance_id: int, exec_bundle_id: int = None + ): self.logger.info("MQ time-based received stop signal!") if self._time_based_flushing_started: self._stop_flag = True self._time_thread.join() self._flush() - self._send_stop_message() + self._send_stop_message(interceptor_instance_id, exec_bundle_id) self._time_based_flushing_started = False self.logger.info("MQ time-based flushing stopped.") else: @@ -123,9 +182,16 @@ def time_based_flushing(self): ) sleep(REDIS_INSERTION_BUFFER_TIME) - def _send_stop_message(self): + def _send_stop_message( + self, interceptor_instance_id, exec_bundle_id=None + ): # TODO: these should be constants - msg = {"type": "flowcept_control", "info": "mq_dao_thread_stopped"} + msg = { + "type": "flowcept_control", + "info": "mq_dao_thread_stopped", + "interceptor_instance_id": interceptor_instance_id, + "exec_bundle_id": exec_bundle_id, + } self._redis.publish(REDIS_CHANNEL, json.dumps(msg)) def stop_document_inserter(self): diff --git a/flowcept/configs.py b/flowcept/configs.py index 66ed7c12..cda7b3a4 100644 --- a/flowcept/configs.py +++ b/flowcept/configs.py @@ -50,11 +50,15 @@ ###################### # Redis Settings # ###################### +REDIS_URI = settings["main_redis"].get("uri", None) REDIS_HOST = settings["main_redis"].get("host", "localhost") REDIS_PORT = int(settings["main_redis"].get("port", "6379")) REDIS_CHANNEL = settings["main_redis"].get("channel", "interception") REDIS_PASSWORD = settings["main_redis"].get("password", None) REDIS_STARTED_MQ_THREADS_KEY = "started_mq_threads" +REDIS_RESET_DB_AT_START = settings["main_redis"].get( + "reset_db_at_start", True +) REDIS_BUFFER_SIZE = int(settings["main_redis"].get("buffer_size", 50)) REDIS_INSERTION_BUFFER_TIME = int( settings["main_redis"].get("insertion_buffer_time_secs", 5) @@ -67,9 +71,10 @@ ###################### # MongoDB Settings # ###################### +MONGO_URI = settings["mongodb"].get("uri", None) MONGO_HOST = settings["mongodb"].get("host", "localhost") MONGO_PORT = int(settings["mongodb"].get("port", "27017")) -MONGO_DB = settings["mongodb"].get("db", "flowcept") +MONGO_DB = settings["mongodb"].get("db", PROJECT_NAME) MONGO_TASK_COLLECTION = "tasks" MONGO_WORKFLOWS_COLLECTION = "workflows" diff --git a/flowcept/flowcept_api/consumer_api.py b/flowcept/flowcept_api/consumer_api.py index f1d2869e..90514e5a 100644 --- a/flowcept/flowcept_api/consumer_api.py +++ b/flowcept/flowcept_api/consumer_api.py @@ -26,13 +26,12 @@ def start(self): self.logger.warning("Consumer is already started!") return self - self._mq_dao.reset_started_mq_threads() if self._interceptors and len(self._interceptors): for interceptor in self._interceptors: self.logger.debug( f"Flowceptor {interceptor.settings.key} starting..." ) - interceptor.start() + interceptor.start(bundle_exec_id=id(self)) self.logger.debug("... 
ok!") self.logger.debug("Flowcept Consumer starting...") @@ -66,3 +65,13 @@ def stop(self): self._document_inserter.stop() self.is_started = False self.logger.debug("All stopped!") + + def reset_time_based_threads_tracker(self): + self._mq_dao.delete_all_time_based_threads_sets() + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() diff --git a/flowcept/flowceptor/adapters/base_interceptor.py b/flowcept/flowceptor/adapters/base_interceptor.py index 30b099e8..0e4573e1 100644 --- a/flowcept/flowceptor/adapters/base_interceptor.py +++ b/flowcept/flowceptor/adapters/base_interceptor.py @@ -65,17 +65,22 @@ def __init__(self, plugin_key): self.logger = FlowceptLogger().get_logger() self.settings = get_settings(plugin_key) self._mq_dao = MQDao() + self._bundle_exec_id = None + self._interceptor_instance_id = id(self) self.telemetry_capture = TelemetryCapture() def prepare_task_msg(self, *args, **kwargs) -> TaskMessage: raise NotImplementedError() - def start(self) -> "BaseInterceptor": + def start(self, bundle_exec_id) -> "BaseInterceptor": """ Starts an interceptor :return: """ - self._mq_dao.start_time_based_flushing() + self._bundle_exec_id = bundle_exec_id + self._mq_dao.start_time_based_flushing( + self._interceptor_instance_id, bundle_exec_id + ) self.telemetry_capture.init_gpu_telemetry() return self @@ -84,7 +89,9 @@ def stop(self) -> bool: Gracefully stops an interceptor :return: """ - self._mq_dao.stop() + self._mq_dao.stop_time_based_flushing( + self._interceptor_instance_id, self._bundle_exec_id + ) self.telemetry_capture.shutdown_gpu_telemetry() def observe(self, *args, **kwargs): diff --git a/flowcept/flowceptor/adapters/dask/dask_interceptor.py b/flowcept/flowceptor/adapters/dask/dask_interceptor.py index 0fed8670..007f6442 100644 --- a/flowcept/flowceptor/adapters/dask/dask_interceptor.py +++ b/flowcept/flowceptor/adapters/dask/dask_interceptor.py @@ -82,7 +82,7 @@ class DaskSchedulerInterceptor(BaseInterceptor): def __init__(self, scheduler, plugin_key="dask"): self._scheduler = scheduler super().__init__(plugin_key) - super().start() + super().start(bundle_exec_id=self._scheduler.address) def callback(self, task_id, start, finish, *args, **kwargs): try: @@ -129,7 +129,7 @@ def setup_worker(self, worker): """ self._worker = worker super().__init__(self._plugin_key) - super().start() + super().start(bundle_exec_id=self._worker.scheduler.address) # Note that both scheduler and worker get the exact same input. # Worker does not resolve intermediate inputs, just like the scheduler. 
# But careful: we are only able to capture inputs in client.map on @@ -197,6 +197,3 @@ def callback(self, task_id, start, finish, *args, **kwargs): f"Error with dask worker: {self._worker.worker_address}" ) self.logger.exception(e) - - def stop(self) -> bool: - super().stop() diff --git a/flowcept/flowceptor/adapters/mlflow/mlflow_interceptor.py b/flowcept/flowceptor/adapters/mlflow/mlflow_interceptor.py index dd4de62d..c95fe273 100644 --- a/flowcept/flowceptor/adapters/mlflow/mlflow_interceptor.py +++ b/flowcept/flowceptor/adapters/mlflow/mlflow_interceptor.py @@ -2,6 +2,7 @@ import time from watchdog.observers import Observer +from watchdog.observers.polling import PollingObserver from flowcept.commons.flowcept_dataclasses.task_message import TaskMessage from flowcept.commons.utils import get_utc_now, get_status_from_str @@ -22,7 +23,7 @@ class MLFlowInterceptor(BaseInterceptor): def __init__(self, plugin_key="mlflow"): super().__init__(plugin_key) - self._observer = None + self._observer: PollingObserver = None self.state_manager = InterceptorStateManager(self.settings) self.dao = MLFlowDAO(self.settings) @@ -58,8 +59,8 @@ def callback(self): task_msg = self.prepare_task_msg(run_data) self.intercept(task_msg) - def start(self) -> "MLFlowInterceptor": - super().start() + def start(self, bundle_exec_id) -> "MLFlowInterceptor": + super().start(bundle_exec_id) self.observe() return self diff --git a/flowcept/flowceptor/adapters/tensorboard/tensorboard_interceptor.py b/flowcept/flowceptor/adapters/tensorboard/tensorboard_interceptor.py index 53d59f04..ae31c9f3 100644 --- a/flowcept/flowceptor/adapters/tensorboard/tensorboard_interceptor.py +++ b/flowcept/flowceptor/adapters/tensorboard/tensorboard_interceptor.py @@ -3,6 +3,7 @@ from watchdog.observers import Observer from tbparse import SummaryReader +from watchdog.observers.polling import PollingObserver from flowcept.commons.flowcept_dataclasses.task_message import ( TaskMessage, @@ -23,7 +24,7 @@ class TensorboardInterceptor(BaseInterceptor): def __init__(self, plugin_key="tensorboard"): super().__init__(plugin_key) - self._observer = None + self._observer: PollingObserver = None self.state_manager = InterceptorStateManager(self.settings) self.state_manager.reset() self.log_metrics = set(self.settings.log_metrics) @@ -87,8 +88,8 @@ def callback(self): self.intercept(task_msg) self.state_manager.add_element_id(child_event.log_path) - def start(self) -> "TensorboardInterceptor": - super().start() + def start(self, bundle_exec_id) -> "TensorboardInterceptor": + super().start(bundle_exec_id) self.observe() return self diff --git a/flowcept/flowceptor/adapters/zambeze/zambeze_dataclasses.py b/flowcept/flowceptor/adapters/zambeze/zambeze_dataclasses.py index 7947ab37..827a4ed2 100644 --- a/flowcept/flowceptor/adapters/zambeze/zambeze_dataclasses.py +++ b/flowcept/flowceptor/adapters/zambeze/zambeze_dataclasses.py @@ -25,7 +25,7 @@ class ZambezeMessage: class ZambezeSettings(BaseSettings): host: str port: int - queue_name: str + queue_names: List[str] key_values_to_filter: List[KeyValue] = None kind = "zambeze" diff --git a/flowcept/flowceptor/adapters/zambeze/zambeze_interceptor.py b/flowcept/flowceptor/adapters/zambeze/zambeze_interceptor.py index b600331c..e9577e07 100644 --- a/flowcept/flowceptor/adapters/zambeze/zambeze_interceptor.py +++ b/flowcept/flowceptor/adapters/zambeze/zambeze_interceptor.py @@ -38,8 +38,8 @@ def prepare_task_msg(self, zambeze_msg: Dict) -> TaskMessage: } return task_msg - def start(self) -> "ZambezeInterceptor": - 
super().start() + def start(self, bundle_exec_id) -> "ZambezeInterceptor": + super().start(bundle_exec_id) self._observer_thread = Thread(target=self.observe) self._observer_thread.start() return self @@ -48,7 +48,7 @@ def stop(self) -> bool: self.logger.debug("Interceptor stopping...") super().stop() try: - self._channel.basic_cancel(self._consumer_tag) + self._channel.stop_consuming() except Exception as e: self.logger.warning( f"This exception is expected to occur after " @@ -66,13 +66,20 @@ def observe(self): ) ) self._channel = connection.channel() - self._channel.queue_declare(queue=self.settings.queue_name) - self._consumer_tag = self._channel.basic_consume( - queue=self.settings.queue_name, - on_message_callback=self.callback, - auto_ack=True, - ) - self.logger.debug("Waiting for Zambeze messages.") + for queue in self.settings.queue_names: + self._channel.queue_declare(queue=queue) + + # self._consumer_tag =\ + for queue in self.settings.queue_names: + self._channel.basic_consume( + queue=queue, + on_message_callback=self.callback, + auto_ack=True, + ) + self.logger.debug( + f"Waiting for Zambeze messages on queue {queue}" + ) + try: self._channel.start_consuming() except Exception as e: diff --git a/flowcept/flowceptor/consumers/document_inserter.py b/flowcept/flowceptor/consumers/document_inserter.py index 3e36457e..13823a05 100644 --- a/flowcept/flowceptor/consumers/document_inserter.py +++ b/flowcept/flowceptor/consumers/document_inserter.py @@ -141,7 +141,6 @@ def _start(self): ) time_thread.start() pubsub = self._mq_dao.subscribe() - stoped_mq_threads = 0 should_continue = True while should_continue: try: @@ -160,15 +159,18 @@ def _start(self): "Received mq_dao_thread_stopped message " "in DocInserter!" ) - stoped_mq_threads += 1 - started_mq_threads = ( - self._mq_dao.get_started_mq_threads() + exec_bundle_id = _dict_obj.get( + "exec_bundle_id", None ) - self.logger.debug( - f"stoped_mq_threads={stoped_mq_threads}; " - f"REDIS_STARTED_MQ_THREADS_KEY={started_mq_threads}" + interceptor_instance_id = _dict_obj.get( + "interceptor_instance_id" + ) + self._mq_dao.register_time_based_thread_end( + interceptor_instance_id, exec_bundle_id ) - if stoped_mq_threads == started_mq_threads: + if self._mq_dao.all_time_based_threads_ended( + exec_bundle_id + ): self._safe_to_stop = True self.logger.debug("It is safe to stop.") diff --git a/resources/sample_settings.yaml b/resources/sample_settings.yaml index 2ba75e7f..0f799695 100644 --- a/resources/sample_settings.yaml +++ b/resources/sample_settings.yaml @@ -24,6 +24,7 @@ experiment: main_redis: host: localhost port: 6379 + reset_db_at_start: true channel: interception buffer_size: 50 insertion_buffer_time_secs: 5 @@ -68,7 +69,9 @@ adapters: kind: zambeze host: localhost port: 5672 - queue_name: hello + queue_names: + - hello + - hello2 # key_values_to_filter: # - key: activity_status # value: CREATED diff --git a/tests/adapters/test_dask_with_context_mgmt.py b/tests/adapters/test_dask_with_context_mgmt.py new file mode 100644 index 00000000..4aff3412 --- /dev/null +++ b/tests/adapters/test_dask_with_context_mgmt.py @@ -0,0 +1,65 @@ +import unittest +from time import sleep +from uuid import uuid4 +import numpy as np + +from dask.distributed import Client + +from flowcept import FlowceptConsumerAPI, TaskQueryAPI +from flowcept.commons.daos.document_db_dao import DocumentDBDao +from flowcept.commons.flowcept_logger import FlowceptLogger + + +def dummy_func1(x, workflow_id=None): + cool_var = "cool value" # test if we can intercept this 
var + print(cool_var) + y = cool_var + return x * 2 + + +class TestDaskContextMgmt(unittest.TestCase): + client: Client = None + + def __init__(self, *args, **kwargs): + super(TestDaskContextMgmt, self).__init__(*args, **kwargs) + self.logger = FlowceptLogger().get_logger() + + @classmethod + def setUpClass(cls): + TestDaskContextMgmt.client = ( + TestDaskContextMgmt._setup_local_dask_cluster() + ) + + @staticmethod + def _setup_local_dask_cluster(n_workers=2): + from dask.distributed import Client, LocalCluster + from flowcept import ( + FlowceptDaskSchedulerAdapter, + FlowceptDaskWorkerAdapter, + ) + + cluster = LocalCluster(n_workers=n_workers) + scheduler = cluster.scheduler + client = Client(scheduler.address) + + # Instantiate and Register FlowceptPlugins, which are the ONLY + # additional steps users would need to do in their code: + scheduler.add_plugin(FlowceptDaskSchedulerAdapter(scheduler)) + + client.register_worker_plugin(FlowceptDaskWorkerAdapter()) + + return client + + def test_workflow(self): + i1 = np.random.random() + wf_id = f"wf_{uuid4()}" + with FlowceptConsumerAPI(): + o1 = self.client.submit(dummy_func1, i1, workflow_id=wf_id) + self.logger.debug(o1.result()) + self.logger.debug(o1.key) + sleep(5) + TestDaskContextMgmt.client.shutdown() + + query_api = TaskQueryAPI() + docs = query_api.query({"workflow_id": wf_id}) + assert len(docs) diff --git a/tests/adapters/test_zambeze.py b/tests/adapters/test_zambeze.py index dc16e1fa..d2496933 100644 --- a/tests/adapters/test_zambeze.py +++ b/tests/adapters/test_zambeze.py @@ -18,7 +18,7 @@ def __init__(self, *args, **kwargs): self.logger = FlowceptLogger().get_logger() interceptor = ZambezeInterceptor() self.consumer = FlowceptConsumerAPI(interceptor) - + # self.consumer.reset_time_based_threads_tracker() self._connection = pika.BlockingConnection( pika.ConnectionParameters( interceptor.settings.host, @@ -26,9 +26,9 @@ def __init__(self, *args, **kwargs): ) ) self._channel = self._connection.channel() - self._queue_name = interceptor.settings.queue_name - self._channel.queue_declare(queue=self._queue_name) + self._queue_names = interceptor.settings.queue_names + self._channel.queue_declare(queue=self._queue_names[0]) self.consumer.start() def test_send_message(self): @@ -58,7 +58,7 @@ def test_send_message(self): self._channel.basic_publish( exchange="", - routing_key=self._queue_name, + routing_key=self._queue_names[0], body=json.dumps(msg.__dict__), ) print("Zambeze Activity_id", act_id)
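
Note on the shutdown handshake introduced above: each interceptor's start() registers its interceptor_instance_id in a Redis set named after the execution bundle, and the DocumentInserter removes that id when it receives the matching "mq_dao_thread_stopped" control message; once the set is empty, it is safe for the inserter to stop. The sketch below walks through that flow directly against the new KeyValueDAO. It is an illustration, not part of the changeset: the instance and bundle ids are made up (the real code uses id(self) on the interceptor and the bundle id passed to start(bundle_exec_id)), and it assumes a reachable Redis as configured in resources/sample_settings.yaml.

# Illustrative sketch only -- not part of this diff.
from flowcept.commons.daos.keyvalue_dao import KeyValueDAO

kv = KeyValueDAO()  # connects via REDIS_HOST/REDIS_PORT/REDIS_PASSWORD

# Hypothetical ids for illustration; the set name mirrors MQDao._get_set_name().
exec_bundle_id = 140236781  # e.g., id(consumer) or the Dask scheduler address
set_name = f"started_mq_thread_execution_{exec_bundle_id}"
kv.delete_set(set_name)  # start from a clean set for this walkthrough

# start_time_based_flushing() -> register_time_based_thread_init()
kv.add_key_into_set(set_name, 111)  # interceptor instance 1
kv.add_key_into_set(set_name, 222)  # interceptor instance 2
assert kv.set_count(set_name) == 2

# DocumentInserter handles each "mq_dao_thread_stopped" control message by
# calling register_time_based_thread_end(interceptor_instance_id, exec_bundle_id).
kv.remove_key_from_set(set_name, 111)
kv.remove_key_from_set(set_name, 222)

# all_time_based_threads_ended(exec_bundle_id) -> safe to stop the inserter.
assert kv.set_is_empty(set_name)

# Leftover sets can be cleaned across runs (see delete_all_time_based_threads_sets()).
kv.delete_all_matching_sets("started_mq_thread_execution*")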