Merge pull request #102 from ORNL/dev
Main < Dev: Context management
renan-souza authored Feb 15, 2024
2 parents 8f17431 + 5ce31d4 commit b4e5d0d
Showing 15 changed files with 273 additions and 60 deletions.
7 changes: 6 additions & 1 deletion flowcept/commons/daos/document_db_dao.py
@@ -17,6 +17,7 @@
     MONGO_TASK_COLLECTION,
     MONGO_WORKFLOWS_COLLECTION,
     PERF_LOG,
+    MONGO_URI,
 )
 from flowcept.flowceptor.consumers.consumer_utils import (
     curate_dict_task_messages,
@@ -27,7 +28,11 @@
 class DocumentDBDao(object):
     def __init__(self):
         self.logger = FlowceptLogger().get_logger()
-        client = MongoClient(MONGO_HOST, MONGO_PORT)
+
+        if MONGO_URI is not None:
+            client = MongoClient(MONGO_URI)
+        else:
+            client = MongoClient(MONGO_HOST, MONGO_PORT)
         self._db = client[MONGO_DB]
 
         self._tasks_collection = self._db[MONGO_TASK_COLLECTION]
45 changes: 45 additions & 0 deletions flowcept/commons/daos/keyvalue_dao.py
@@ -0,0 +1,45 @@
+from redis import Redis
+
+from flowcept.commons.flowcept_logger import FlowceptLogger
+from flowcept.configs import (
+    REDIS_HOST,
+    REDIS_PORT,
+    REDIS_PASSWORD,
+)
+
+
+class KeyValueDAO:
+    def __init__(self, connection=None):
+        self.logger = FlowceptLogger().get_logger()
+        if connection is None:
+            self._redis = Redis(
+                host=REDIS_HOST,
+                port=REDIS_PORT,
+                db=0,
+                password=REDIS_PASSWORD,
+            )
+        else:
+            self._redis = connection
+
+    def delete_set(self, set_name: str):
+        self._redis.delete(set_name)
+
+    def add_key_into_set(self, set_name: str, key):
+        self._redis.sadd(set_name, key)
+
+    def remove_key_from_set(self, set_name: str, key):
+        self._redis.srem(set_name, key)
+
+    def set_has_key(self, set_name: str, key) -> bool:
+        return self._redis.sismember(set_name, key)
+
+    def set_count(self, set_name: str):
+        return self._redis.scard(set_name)
+
+    def set_is_empty(self, set_name: str) -> bool:
+        return self.set_count(set_name) == 0
+
+    def delete_all_matching_sets(self, key_pattern):
+        matching_sets = self._redis.keys(key_pattern)
+        for set_name in matching_sets:
+            self.delete_set(set_name)
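
The set operations above are what the MQ layer (next file) uses to track live flusher threads. A minimal usage sketch, assuming a Redis instance reachable at the configured host/port; the set name mirrors `MQDao._get_set_name(42)` and the member id is illustrative:

```python
# Illustrative only: requires a running Redis; name and id are made up.
from flowcept.commons.daos.keyvalue_dao import KeyValueDAO

dao = KeyValueDAO()
dao.add_key_into_set("started_mq_thread_execution_42", 140245310)
print(dao.set_has_key("started_mq_thread_execution_42", 140245310))  # True
dao.remove_key_from_set("started_mq_thread_execution_42", 140245310)
print(dao.set_is_empty("started_mq_thread_execution_42"))  # True: all members removed
dao.delete_all_matching_sets("started_mq_thread_execution*")  # cleanup by pattern
```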
102 changes: 84 additions & 18 deletions flowcept/commons/daos/mq_dao.py
@@ -4,6 +4,7 @@
 from threading import Thread, Lock
 from time import time, sleep
 
+from flowcept.commons.daos.keyvalue_dao import KeyValueDAO
 from flowcept.commons.utils import perf_log
 from flowcept.commons.flowcept_logger import FlowceptLogger
 from flowcept.configs import (
@@ -16,6 +17,7 @@
     REDIS_BUFFER_SIZE,
     REDIS_INSERTION_BUFFER_TIME,
     PERF_LOG,
+    REDIS_URI,
 )
 
 from flowcept.commons.utils import GenericJSONEncoder
@@ -26,48 +28,105 @@ class MQDao:
     ENCODER = GenericJSONEncoder if JSON_SERIALIZER == "complex" else None
     # TODO we don't have a unit test to cover complex dict!
 
+    @staticmethod
+    def _get_set_name(exec_bundle_id=None):
+        """
+        :param exec_bundle_id: A way to group one or many interceptors, and treat each group as a bundle to control when their time_based threads started and ended.
+        :return:
+        """
+        set_id = f"started_mq_thread_execution"
+        if exec_bundle_id is not None:
+            set_id += "_" + str(exec_bundle_id)
+        return set_id
+
     def __init__(self):
         self.logger = FlowceptLogger().get_logger()
-        self._redis = Redis(
-            host=REDIS_HOST, port=REDIS_PORT, db=0, password=REDIS_PASSWORD
-        )
+
+        if REDIS_URI is not None:
+            # If a URI is provided, use it for connection
+            self._redis = Redis.from_url(REDIS_URI)
+        else:
+            # Otherwise, use the host, port, and password settings
+            self._redis = Redis(
+                host=REDIS_HOST,
+                port=REDIS_PORT,
+                db=0,
+                password=REDIS_PASSWORD if REDIS_PASSWORD else None,
+            )
+
+        self._keyvalue_dao = KeyValueDAO(connection=self._redis)
         self._buffer = None
         self._time_thread: Thread = None
         self._previous_time = -1
         self._stop_flag = False
         self._time_based_flushing_started = False
         self._lock = None
 
-    def start_time_based_flushing(self):
+    def register_time_based_thread_init(
+        self, interceptor_instance_id: int, exec_bundle_id=None
+    ):
+        set_name = MQDao._get_set_name(exec_bundle_id)
+        self.logger.debug(
+            f"Registering the beginning of the time_based MQ flush thread {set_name}.{interceptor_instance_id}"
+        )
+        self._keyvalue_dao.add_key_into_set(set_name, interceptor_instance_id)
+
+    def register_time_based_thread_end(
+        self, interceptor_instance_id: int, exec_bundle_id=None
+    ):
+        set_name = MQDao._get_set_name(exec_bundle_id)
+        self.logger.debug(
+            f"Registering the end of the time_based MQ flush thread {set_name}.{interceptor_instance_id}"
+        )
+        self._keyvalue_dao.remove_key_from_set(
+            set_name, interceptor_instance_id
+        )
+
+    def all_time_based_threads_ended(self, exec_bundle_id=None):
+        set_name = MQDao._get_set_name(exec_bundle_id)
+        return self._keyvalue_dao.set_is_empty(set_name)
+
+    def delete_all_time_based_threads_sets(self):
+        return self._keyvalue_dao.delete_all_matching_sets(
+            MQDao._get_set_name() + "*"
+        )
+
+    def start_time_based_flushing(
+        self, interceptor_instance_id: int, exec_bundle_id=None
+    ):
         self._buffer = list()
         self._time_thread: Thread = None
         self._previous_time = time()
         self._stop_flag = False
         self._time_based_flushing_started = False
         self._lock = Lock()
 
         self._time_thread = Thread(target=self.time_based_flushing)
-        self._redis.incr(REDIS_STARTED_MQ_THREADS_KEY)
-        self.logger.debug(
-            f"Incrementing REDIS_STARTED_MQ_THREADS_KEY. Now: {self.get_started_mq_threads()}"
-        )
+        self.register_time_based_thread_init(
+            interceptor_instance_id, exec_bundle_id
+        )
+        # self._redis.incr(REDIS_STARTED_MQ_THREADS_KEY)
+        # self.logger.debug(
+        #     f"Incrementing REDIS_STARTED_MQ_THREADS_KEY. Now: {self.get_started_mq_threads()}"
+        # )
         self._time_based_flushing_started = True
         self._time_thread.start()
 
-    def get_started_mq_threads(self):
-        return int(self._redis.get(REDIS_STARTED_MQ_THREADS_KEY))
-
-    def reset_started_mq_threads(self):
-        self.logger.debug("RESETTING REDIS_STARTED_MQ_THREADS_KEY TO 0")
-        self._redis.set(REDIS_STARTED_MQ_THREADS_KEY, 0)
+    # def get_started_mq_threads(self):
+    #     return int(self._redis.get(REDIS_STARTED_MQ_THREADS_KEY))
+    #
+    # def reset_started_mq_threads(self):
+    #     self.logger.debug("RESETTING REDIS_STARTED_MQ_THREADS_KEY TO 0")
+    #     self._redis.set(REDIS_STARTED_MQ_THREADS_KEY, 0)
 
-    def stop(self):
+    def stop_time_based_flushing(
+        self, interceptor_instance_id: int, exec_bundle_id: int = None
+    ):
         self.logger.info("MQ time-based received stop signal!")
         if self._time_based_flushing_started:
             self._stop_flag = True
             self._time_thread.join()
             self._flush()
-            self._send_stop_message()
+            self._send_stop_message(interceptor_instance_id, exec_bundle_id)
             self._time_based_flushing_started = False
             self.logger.info("MQ time-based flushing stopped.")
         else:
@@ -123,9 +182,16 @@ def time_based_flushing(self):
             )
             sleep(REDIS_INSERTION_BUFFER_TIME)
 
-    def _send_stop_message(self):
+    def _send_stop_message(
+        self, interceptor_instance_id, exec_bundle_id=None
+    ):
         # TODO: these should be constants
-        msg = {"type": "flowcept_control", "info": "mq_dao_thread_stopped"}
+        msg = {
+            "type": "flowcept_control",
+            "info": "mq_dao_thread_stopped",
+            "interceptor_instance_id": interceptor_instance_id,
+            "exec_bundle_id": exec_bundle_id,
+        }
         self._redis.publish(REDIS_CHANNEL, json.dumps(msg))
 
     def stop_document_inserter(self):
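Taken together, these changes replace the global `started_mq_threads` counter with one Redis set per bundle: each flush thread registers itself on start and is deregistered when its stop message is consumed, so a bundle is finished exactly when its set is empty. A sketch of that lifecycle, assuming a running Redis; the ids are illustrative, and the deregistration normally happens in whatever consumes the control message, which is outside this diff, so it is done by hand here:

```python
# Illustrative lifecycle of the per-bundle thread tracking; ids are made up.
from flowcept.commons.daos.mq_dao import MQDao

mq = MQDao()
bundle_id = 7       # e.g., id() of the consumer grouping several interceptors
instance_id = 101   # e.g., id() of one interceptor

mq.start_time_based_flushing(instance_id, bundle_id)  # registers the flusher in the bundle's set
print(mq.all_time_based_threads_ended(bundle_id))     # False: one flusher is registered

# stop flushes the buffer and publishes a control message carrying both ids
mq.stop_time_based_flushing(instance_id, bundle_id)

# the consumer of that message is expected to deregister the flusher
mq.register_time_based_thread_end(instance_id, bundle_id)
print(mq.all_time_based_threads_ended(bundle_id))     # True: the set is empty

mq.delete_all_time_based_threads_sets()               # wipe any leftover tracking sets
```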
7 changes: 6 additions & 1 deletion flowcept/configs.py
@@ -50,11 +50,15 @@
 ######################
 #  Redis Settings    #
 ######################
+REDIS_URI = settings["main_redis"].get("uri", None)
 REDIS_HOST = settings["main_redis"].get("host", "localhost")
 REDIS_PORT = int(settings["main_redis"].get("port", "6379"))
 REDIS_CHANNEL = settings["main_redis"].get("channel", "interception")
 REDIS_PASSWORD = settings["main_redis"].get("password", None)
 REDIS_STARTED_MQ_THREADS_KEY = "started_mq_threads"
+REDIS_RESET_DB_AT_START = settings["main_redis"].get(
+    "reset_db_at_start", True
+)
 REDIS_BUFFER_SIZE = int(settings["main_redis"].get("buffer_size", 50))
 REDIS_INSERTION_BUFFER_TIME = int(
     settings["main_redis"].get("insertion_buffer_time_secs", 5)
@@ -67,9 +71,10 @@
 ######################
 #  MongoDB Settings  #
 ######################
+MONGO_URI = settings["mongodb"].get("uri", None)
 MONGO_HOST = settings["mongodb"].get("host", "localhost")
 MONGO_PORT = int(settings["mongodb"].get("port", "27017"))
-MONGO_DB = settings["mongodb"].get("db", "flowcept")
+MONGO_DB = settings["mongodb"].get("db", PROJECT_NAME)
 
 MONGO_TASK_COLLECTION = "tasks"
 MONGO_WORKFLOWS_COLLECTION = "workflows"
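The pattern introduced here is URI-first with host/port fallback, applied identically to both backends. A condensed sketch of that precedence, with placeholder connection strings (nothing below is a flowcept default):

```python
# Condensed sketch of the uri-over-host/port precedence added above.
# Both URIs are placeholders, not values shipped with flowcept.
from pymongo import MongoClient
from redis import Redis

settings = {
    "main_redis": {"uri": "redis://:secret@redis.example.com:6379/0"},
    "mongodb": {"host": "localhost", "port": "27017"},  # no "uri": fallback path
}

REDIS_URI = settings["main_redis"].get("uri", None)
MONGO_URI = settings["mongodb"].get("uri", None)

redis_client = (
    Redis.from_url(REDIS_URI)  # a full connection string wins when present
    if REDIS_URI is not None
    else Redis(host=settings["main_redis"].get("host", "localhost"), port=6379, db=0)
)
mongo_client = (
    MongoClient(MONGO_URI)
    if MONGO_URI is not None
    else MongoClient(
        settings["mongodb"].get("host", "localhost"),
        int(settings["mongodb"].get("port", "27017")),
    )
)
```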
13 changes: 11 additions & 2 deletions flowcept/flowcept_api/consumer_api.py
@@ -26,13 +26,12 @@ def start(self):
             self.logger.warning("Consumer is already started!")
             return self
 
-        self._mq_dao.reset_started_mq_threads()
         if self._interceptors and len(self._interceptors):
             for interceptor in self._interceptors:
                 self.logger.debug(
                     f"Flowceptor {interceptor.settings.key} starting..."
                 )
-                interceptor.start()
+                interceptor.start(bundle_exec_id=id(self))
                 self.logger.debug("... ok!")
 
         self.logger.debug("Flowcept Consumer starting...")
@@ -66,3 +65,13 @@ def stop(self):
         self._document_inserter.stop()
         self.is_started = False
         self.logger.debug("All stopped!")
+
+    def reset_time_based_threads_tracker(self):
+        self._mq_dao.delete_all_time_based_threads_sets()
+
+    def __enter__(self):
+        self.start()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.stop()
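
The `__enter__`/`__exit__` pair is the context management the PR title refers to: `stop()` now runs even if the monitored block raises. A usage sketch, assuming the class exported by consumer_api.py is `FlowceptConsumerAPI` and that it can be constructed with defaults:

```python
# Sketch of the new context-manager usage; constructor arguments are assumed.
from flowcept.flowcept_api.consumer_api import FlowceptConsumerAPI

with FlowceptConsumerAPI() as consumer:  # __enter__ calls start()
    ...  # run the instrumented workload; interceptors publish task messages
# __exit__ calls stop() on the way out, even if the block above raised
```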
13 changes: 10 additions & 3 deletions flowcept/flowceptor/adapters/base_interceptor.py
@@ -65,17 +65,22 @@ def __init__(self, plugin_key):
         self.logger = FlowceptLogger().get_logger()
         self.settings = get_settings(plugin_key)
         self._mq_dao = MQDao()
+        self._bundle_exec_id = None
+        self._interceptor_instance_id = id(self)
         self.telemetry_capture = TelemetryCapture()
 
     def prepare_task_msg(self, *args, **kwargs) -> TaskMessage:
         raise NotImplementedError()
 
-    def start(self) -> "BaseInterceptor":
+    def start(self, bundle_exec_id) -> "BaseInterceptor":
         """
         Starts an interceptor
         :return:
         """
-        self._mq_dao.start_time_based_flushing()
+        self._bundle_exec_id = bundle_exec_id
+        self._mq_dao.start_time_based_flushing(
+            self._interceptor_instance_id, bundle_exec_id
+        )
         self.telemetry_capture.init_gpu_telemetry()
         return self
 
@@ -84,7 +89,9 @@ def stop(self) -> bool:
         Gracefully stops an interceptor
         :return:
         """
-        self._mq_dao.stop()
+        self._mq_dao.stop_time_based_flushing(
+            self._interceptor_instance_id, self._bundle_exec_id
+        )
         self.telemetry_capture.shutdown_gpu_telemetry()
 
     def observe(self, *args, **kwargs):
def observe(self, *args, **kwargs):
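Interceptors now remember both their own instance id and the bundle they were started under, so each `start`/`stop` pair brackets exactly one registration in the MQ thread tracker. A sketch using the MLflow adapter updated below; it assumes an "mlflow" entry in the settings file and a running Redis, and the bundle id choice is arbitrary:

```python
# Sketch of the revised interceptor lifecycle; assumptions noted above.
from flowcept.flowceptor.adapters.mlflow.mlflow_interceptor import MLFlowInterceptor

interceptor = MLFlowInterceptor()                  # needs an "mlflow" settings entry
interceptor.start(bundle_exec_id=id(interceptor))  # registers its flusher under this bundle
# ... watcher callbacks intercept MLflow runs here ...
interceptor.stop()  # reports instance id + bundle id so the tracker can deregister it
```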
7 changes: 2 additions & 5 deletions flowcept/flowceptor/adapters/dask/dask_interceptor.py
@@ -82,7 +82,7 @@ class DaskSchedulerInterceptor(BaseInterceptor):
     def __init__(self, scheduler, plugin_key="dask"):
         self._scheduler = scheduler
         super().__init__(plugin_key)
-        super().start()
+        super().start(bundle_exec_id=self._scheduler.address)
 
     def callback(self, task_id, start, finish, *args, **kwargs):
         try:
@@ -129,7 +129,7 @@ def setup_worker(self, worker):
"""
self._worker = worker
super().__init__(self._plugin_key)
super().start()
super().start(bundle_exec_id=self._worker.scheduler.address)
# Note that both scheduler and worker get the exact same input.
# Worker does not resolve intermediate inputs, just like the scheduler.
# But careful: we are only able to capture inputs in client.map on
@@ -197,6 +197,3 @@ def callback(self, task_id, start, finish, *args, **kwargs):
f"Error with dask worker: {self._worker.worker_address}"
)
self.logger.exception(e)

def stop(self) -> bool:
super().stop()
7 changes: 4 additions & 3 deletions flowcept/flowceptor/adapters/mlflow/mlflow_interceptor.py
@@ -2,6 +2,7 @@
 import time
 
 from watchdog.observers import Observer
+from watchdog.observers.polling import PollingObserver
 
 from flowcept.commons.flowcept_dataclasses.task_message import TaskMessage
 from flowcept.commons.utils import get_utc_now, get_status_from_str
@@ -22,7 +23,7 @@
 class MLFlowInterceptor(BaseInterceptor):
     def __init__(self, plugin_key="mlflow"):
         super().__init__(plugin_key)
-        self._observer = None
+        self._observer: PollingObserver = None
         self.state_manager = InterceptorStateManager(self.settings)
         self.dao = MLFlowDAO(self.settings)
 
@@ -58,8 +59,8 @@ def callback(self):
             task_msg = self.prepare_task_msg(run_data)
             self.intercept(task_msg)
 
-    def start(self) -> "MLFlowInterceptor":
-        super().start()
+    def start(self, bundle_exec_id) -> "MLFlowInterceptor":
+        super().start(bundle_exec_id)
         self.observe()
         return self
 
flowcept/flowceptor/adapters/tensorboard/tensorboard_interceptor.py
@@ -3,6 +3,7 @@

 from watchdog.observers import Observer
 from tbparse import SummaryReader
+from watchdog.observers.polling import PollingObserver
 
 from flowcept.commons.flowcept_dataclasses.task_message import (
     TaskMessage,
@@ -23,7 +24,7 @@
 class TensorboardInterceptor(BaseInterceptor):
     def __init__(self, plugin_key="tensorboard"):
         super().__init__(plugin_key)
-        self._observer = None
+        self._observer: PollingObserver = None
         self.state_manager = InterceptorStateManager(self.settings)
         self.state_manager.reset()
         self.log_metrics = set(self.settings.log_metrics)
@@ -87,8 +88,8 @@ def callback(self):
             self.intercept(task_msg)
             self.state_manager.add_element_id(child_event.log_path)
 
-    def start(self) -> "TensorboardInterceptor":
-        super().start()
+    def start(self, bundle_exec_id) -> "TensorboardInterceptor":
+        super().start(bundle_exec_id)
         self.observe()
         return self
 
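Typing the observer as watchdog's `PollingObserver` in both adapters signals a polling strategy: it scans directory snapshots on an interval instead of relying on OS event APIs, which is slower but works on network and container filesystems where inotify-style events may never fire. A standalone sketch of the same watchdog pattern; the path and handler are illustrative:

```python
# Standalone watchdog polling sketch; "/tmp/mlruns" is an arbitrary example path.
import time

from watchdog.events import FileSystemEventHandler
from watchdog.observers.polling import PollingObserver


class PrintHandler(FileSystemEventHandler):
    def on_modified(self, event):
        print(f"modified: {event.src_path}")


observer = PollingObserver()  # polls snapshots instead of using OS file events
observer.schedule(PrintHandler(), "/tmp/mlruns", recursive=True)
observer.start()
try:
    time.sleep(5)  # watch for a few seconds
finally:
    observer.stop()
    observer.join()
```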