Fix shutdown on Ctrl+C for Python source stages #1839

Merged
Changes from all commits (19 commits)
4c34d39
Implement stop method, and use a non-blocking call to queue.get, this…
dagardner-nv Aug 8, 2024
1955523
Move checking of stop_requested attribute to occur at the same time a…
dagardner-nv Aug 8, 2024
561b16b
Stop the http server on shutdown
dagardner-nv Aug 8, 2024
9417048
Move the stop method implemented in the kafka source stage to the bas…
dagardner-nv Aug 9, 2024
d4e7201
Remove, setting _stop_requested as this is now in the parent class
dagardner-nv Aug 9, 2024
ac86f05
wip
dagardner-nv Aug 9, 2024
bbb920d
Use non-blocking call to fetch from queue
dagardner-nv Aug 9, 2024
f27e8b0
Fix
dagardner-nv Aug 9, 2024
38fc546
Add a should_stop_fn callback to Watcher class
dagardner-nv Aug 9, 2024
7a12d43
Since we cannot pass a callback function to a module, as module confi…
dagardner-nv Aug 9, 2024
0ca515a
Since the interval time is often high (default is 10 minutes), rather…
dagardner-nv Aug 9, 2024
8fa5602
Remove unused import
dagardner-nv Aug 9, 2024
096aefd
Remove stop impl, this has been moved to the base class
dagardner-nv Aug 9, 2024
9e10924
Check is_stop_requested
dagardner-nv Aug 9, 2024
5412076
Use self.is_stop_requested() method, set type hint
dagardner-nv Aug 9, 2024
0a3fb0a
Update developer docs to reflect code changes
dagardner-nv Aug 9, 2024
8586673
Merge branch 'branch-24.10' into david-polling-source-stage-ctrl-c-18337
dagardner-nv Aug 13, 2024
9805dc8
Formatting fix
dagardner-nv Aug 14, 2024
e8be2a7
Merge branch 'branch-24.10' into david-polling-source-stage-ctrl-c-18337
dagardner-nv Aug 14, 2024
25 changes: 8 additions & 17 deletions docs/source/developer_guide/guides/2_real_world_phishing.md
@@ -761,20 +761,20 @@ def _build_source(self, builder: mrc.Builder) -> mrc.SegmentObject:
    return builder.make_source(self.unique_name, self.source_generator)
```

-The `source_generator` method is where most of the RabbitMQ-specific code exists. When we have a message that we wish to emit into the pipeline, we simply `yield` it.
+The `source_generator` method is where most of the RabbitMQ-specific code exists. When we have a message that we wish to emit into the pipeline, we simply `yield` it. We continue this process until the `is_stop_requested()` method returns `True`.

```python
def source_generator(self) -> collections.abc.Iterator[MessageMeta]:
    try:
-       while not self._stop_requested:
-           (method_frame, header_frame, body) = self._channel.basic_get(self._queue_name)
+       while not self.is_stop_requested():
+           (method_frame, _, body) = self._channel.basic_get(self._queue_name)
            if method_frame is not None:
                try:
                    buffer = StringIO(body.decode("utf-8"))
                    df = cudf.io.read_json(buffer, orient='records', lines=True)
                    yield MessageMeta(df=df)
                except Exception as ex:
-                   logger.exception("Error occurred converting RabbitMQ message to Dataframe: {}".format(ex))
+                   logger.exception("Error occurred converting RabbitMQ message to Dataframe: %s", ex)
                finally:
                    self._channel.basic_ack(method_frame.delivery_tag)
            else:
@@ -824,11 +824,11 @@ class RabbitMQSourceStage(PreallocatorMixin, SingleOutputSource):
        Hostname or IP of the RabbitMQ server.
    exchange : str
        Name of the RabbitMQ exchange to connect to.
-   exchange_type : str
+   exchange_type : str, optional
        RabbitMQ exchange type; defaults to `fanout`.
-   queue_name : str
+   queue_name : str, optional
        Name of the queue to listen to. If left blank, RabbitMQ will generate a random queue name bound to the exchange.
-   poll_interval : str
+   poll_interval : str, optional
        Amount of time between polling RabbitMQ for new messages
    """

@@ -854,9 +854,6 @@ class RabbitMQSourceStage(PreallocatorMixin, SingleOutputSource):

        self._poll_interval = pd.Timedelta(poll_interval)

-       # Flag to indicate whether or not we should stop
-       self._stop_requested = False

    @property
    def name(self) -> str:
        return "from-rabbitmq"
@@ -867,18 +864,12 @@ class RabbitMQSourceStage(PreallocatorMixin, SingleOutputSource):
    def compute_schema(self, schema: StageSchema):
        schema.output_schema.set_type(MessageMeta)

-   def stop(self):
-       # Indicate we need to stop
-       self._stop_requested = True
-
-       return super().stop()

    def _build_source(self, builder: mrc.Builder) -> mrc.SegmentObject:
        return builder.make_source(self.unique_name, self.source_generator)

    def source_generator(self) -> collections.abc.Iterator[MessageMeta]:
        try:
-           while not self._stop_requested:
+           while not self.is_stop_requested():
                (method_frame, _, body) = self._channel.basic_get(self._queue_name)
                if method_frame is not None:
                    try:
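With these changes, interrupting a running pipeline now shuts the source down cleanly. A minimal usage sketch (not part of this diff), assuming the `LinearPipeline` API used elsewhere in these guides and a hypothetical `logs` exchange:

```python
from morpheus.config import Config
from morpheus.pipeline import LinearPipeline

config = Config()
pipeline = LinearPipeline(config)
pipeline.set_source(RabbitMQSourceStage(config, host="localhost", exchange="logs"))

# Ctrl+C during run() invokes stop() on the source; source_generator() then
# sees is_stop_requested() return True and exits its polling loop.
pipeline.run()
```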
5 changes: 1 addition & 4 deletions docs/source/developer_guide/guides/4_source_cpp_stage.md
@@ -493,13 +493,10 @@ def __init__(self,
        self._exchange_type = exchange_type
        self._queue_name = queue_name

-       self._connection = None
+       self._connection: pika.BlockingConnection = None
        self._channel = None

        self._poll_interval = pd.Timedelta(poll_interval)

-       # Flag to indicate whether or not we should stop
-       self._stop_requested = False
```
```python
def connect(self):
11 changes: 1 addition & 10 deletions examples/developer_guide/2_2_rabbitmq/rabbitmq_source_stage.py
@@ -77,9 +77,6 @@ def __init__(self,

        self._poll_interval = pd.Timedelta(poll_interval)

-       # Flag to indicate whether or not we should stop
-       self._stop_requested = False

    @property
    def name(self) -> str:
        return "from-rabbitmq"
@@ -90,18 +87,12 @@ def supports_cpp_node(self) -> bool:
    def compute_schema(self, schema: StageSchema):
        schema.output_schema.set_type(MessageMeta)

-   def stop(self):
-       # Indicate we need to stop
-       self._stop_requested = True
-
-       return super().stop()

    def _build_source(self, builder: mrc.Builder) -> mrc.SegmentObject:
        return builder.make_source(self.unique_name, self.source_generator)

    def source_generator(self) -> collections.abc.Iterator[MessageMeta]:
        try:
-           while not self._stop_requested:
+           while not self.is_stop_requested():
                (method_frame, _, body) = self._channel.basic_get(self._queue_name)
                if method_frame is not None:
                    try:
@@ -67,14 +67,11 @@ def __init__(self,
        self._exchange_type = exchange_type
        self._queue_name = queue_name

-       self._connection = None
+       self._connection: pika.BlockingConnection = None
        self._channel = None

        self._poll_interval = pd.Timedelta(poll_interval)

-       # Flag to indicate whether or not we should stop
-       self._stop_requested = False

    @property
    def name(self) -> str:
        return "from-rabbitmq"
@@ -117,7 +114,7 @@ def connect(self):

    def source_generator(self):
        try:
-           while not self._stop_requested:
+           while not self.is_stop_requested():
                (method_frame, _, body) = self._channel.basic_get(self._queue_name)
                if method_frame is not None:
                    try:
@@ -128,7 +128,7 @@ def _polling_generate_frames_fsspec(self) -> typing.Iterable[fsspec.core.OpenFiles]:
        curr_time = time.monotonic()
        next_update_epoch = curr_time

-       while (True):
+       while (not self.is_stop_requested()):
            # Before doing any work, find the next update epoch after the current time
            while (next_update_epoch <= curr_time):
                # Only ever add `self._watch_interval` to next_update_epoch so all updates are at repeating intervals
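The rest of this loop sleeps until `next_update_epoch`. For a stop request to be honored mid-interval, any long sleep has to be sliced so the flag is re-checked periodically; a sketch of that pattern, using a hypothetical helper name:

```python
import time

def _sleep_until(self, deadline: float):
    # Sleep in short slices so a stop request (e.g. from Ctrl+C) is honored
    # without waiting out the full watch interval.
    while time.monotonic() < deadline and not self.is_stop_requested():
        time.sleep(max(0.0, min(1.0, deadline - time.monotonic())))
```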
71 changes: 70 additions & 1 deletion python/morpheus/morpheus/controllers/rss_controller.py
@@ -15,15 +15,21 @@
import logging
import os
import time
from collections.abc import Callable
from collections.abc import Iterable
from dataclasses import asdict
from dataclasses import dataclass
from datetime import datetime
from datetime import timedelta
from urllib.parse import urlparse

import requests
import requests_cache

import cudf

from morpheus.messages import MessageMeta

logger = logging.getLogger(__name__)

IMPORT_EXCEPTION = None
@@ -72,6 +78,12 @@ class RSSController:
        Request timeout in secs to fetch the feed.
    strip_markup : bool, optional, default = False
        When true, strip HTML & XML markup from the content, summary and title fields.
    stop_after : int, optional, default = 0
        Stops ingesting after emitting `stop_after` records (rows in the dataframe). Useful for testing. Disabled if `0`.
    interval_secs : float, optional, default = 600
        Interval in seconds between fetching new feed items.
    should_stop_fn : Callable[[], bool], optional
        Function that returns a boolean indicating whether the controller should stop processing feed entries.
    """

# Fields which may contain HTML or XML content
@@ -89,7 +101,10 @@ def __init__(self,
                 cache_dir: str = "./.cache/http",
                 cooldown_interval: int = 600,
                 request_timeout: float = 2.0,
-                strip_markup: bool = False):
+                strip_markup: bool = False,
+                stop_after: int = 0,
+                interval_secs: float = 600,
+                should_stop_fn: Callable[[], bool] = None):
        if IMPORT_EXCEPTION is not None:
            raise ImportError(IMPORT_ERROR_MESSAGE) from IMPORT_EXCEPTION

@@ -104,6 +119,11 @@ def __init__(self,
        self._request_timeout = request_timeout
        self._strip_markup = strip_markup

        if should_stop_fn is None:
            self._should_stop_fn = lambda: False
        else:
            self._should_stop_fn = should_stop_fn

        # Validate feed_input
        for f in self._feed_input:
            if not RSSController.is_url(f) and not os.path.exists(f):
@@ -113,7 +133,14 @@ def __init__(self,
        # If feed_input is URL. Runs indefinitely
        run_indefinitely = any(RSSController.is_url(f) for f in self._feed_input)

        if (stop_after > 0 and run_indefinitely):
            raise ValueError("Cannot set both `stop_after` and `run_indefinitely` to True.")

        self._stop_after = stop_after
        self._run_indefinitely = run_indefinitely
        self._interval_secs = interval_secs
        self._interval_td = timedelta(seconds=self._interval_secs)

        self._enable_cache = enable_cache

        if enable_cache:
@@ -381,3 +408,45 @@ def is_url(cls, feed_input: str) -> bool:
            return parsed_url.scheme != '' and parsed_url.netloc != ''
        except Exception:
            return False

    def feed_generator(self) -> Iterable[MessageMeta]:
        """
        Fetch RSS feed entries and yield as MessageMeta object.
        """
        stop_requested = False
        records_emitted = 0

        while (not stop_requested and not self._should_stop_fn()):
            try:
                for df in self.fetch_dataframes():
                    df_size = len(df)

                    if logger.isEnabledFor(logging.DEBUG):
                        logger.info("Received %d new entries...", df_size)
                        logger.info("Emitted %d records so far.", records_emitted)

                    yield MessageMeta(df=df)

                    records_emitted += df_size

                    if (0 < self._stop_after <= records_emitted):
                        stop_requested = True
                        logger.info("Stop limit reached... preparing to halt the source.")
                        break

            except Exception as exc:
                if not self.run_indefinitely:
                    logger.error("Failed either in the process of fetching or processing entries: %s.", exc)
                    raise
                logger.error("Failed either in the process of fetching or processing entries: %s.", exc)

            if not self.run_indefinitely:
                stop_requested = True
                continue

            logger.info("Waiting for %d seconds before fetching again...", self._interval_secs)
            sleep_until = datetime.now() + self._interval_td
            while (datetime.now() < sleep_until and not self._should_stop_fn()):
                time.sleep(1)

        logger.info("RSS source exhausted, stopping.")
51 changes: 5 additions & 46 deletions python/morpheus/morpheus/modules/input/rss_source.py
@@ -13,13 +13,11 @@
# limitations under the License.

import logging
-import time

import mrc
from pydantic import ValidationError

from morpheus.controllers.rss_controller import RSSController
-from morpheus.messages import MessageMeta
from morpheus.modules.schemas.rss_source_schema import RSSSourceSchema
from morpheus.utils.module_utils import ModuleLoaderFactory
from morpheus.utils.module_utils import register_module
@@ -57,6 +55,7 @@ def _rss_source(builder: mrc.Builder):

    module_config = builder.get_current_module_config()
    rss_config = module_config.get("rss_source", {})

    try:
        validated_config = RSSSourceSchema(**rss_config)
    except ValidationError as e:
@@ -74,50 +73,10 @@ def _rss_source(builder: mrc.Builder):
        cache_dir=validated_config.cache_dir,
        cooldown_interval=validated_config.cooldown_interval_sec,
        request_timeout=validated_config.request_timeout_sec,
-       strip_markup=validated_config.strip_markup)
-
-   stop_requested = False
-
-   def fetch_feeds() -> MessageMeta:
-       """
-       Fetch RSS feed entries and yield as MessageMeta object.
-       """
-       nonlocal stop_requested
-       records_emitted = 0
-
-       while (not stop_requested):
-           try:
-               for df in controller.fetch_dataframes():
-                   df_size = len(df)
-
-                   if logger.isEnabledFor(logging.DEBUG):
-                       logger.info("Received %d new entries...", df_size)
-                       logger.info("Emitted %d records so far.", records_emitted)
-
-                   yield MessageMeta(df=df)
-
-                   records_emitted += df_size
-
-                   if (0 < validated_config.stop_after_rec <= records_emitted):
-                       stop_requested = True
-                       logger.info("Stop limit reached... preparing to halt the source.")
-                       break
-
-           except Exception as exc:
-               if not controller.run_indefinitely:
-                   logger.error("Failed either in the process of fetching or processing entries: %s.", exc)
-                   raise
-               logger.error("Failed either in the process of fetching or processing entries: %s.", exc)
-
-           if not controller.run_indefinitely:
-               stop_requested = True
-               continue
-
-           logger.info("Waiting for %d seconds before fetching again...", validated_config.interval_sec)
-           time.sleep(validated_config.interval_sec)
-
-       logger.info("RSS source exhausted, stopping.")
+       strip_markup=validated_config.strip_markup,
+       stop_after=validated_config.stop_after_rec,
+       interval_secs=validated_config.interval_sec)

-   node = builder.make_source("fetch_feeds", fetch_feeds)
+   node = builder.make_source("fetch_feeds", controller.feed_generator)

    builder.register_module_output("output", node)
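Per the commit history, a callback such as `should_stop_fn` cannot be forwarded to the module version because module configs must be serializable data; the controller therefore keeps its default `lambda: False` here. A one-line illustration of the constraint (illustrative only, not from the diff):

```python
import json

# Module configs travel as plain data; a callable has no serialized form.
json.dumps({"should_stop_fn": lambda: False})  # raises TypeError
```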
32 changes: 32 additions & 0 deletions python/morpheus/morpheus/pipeline/single_output_source.py
@@ -40,6 +40,10 @@ def __init__(self, c: Config):

        self._create_ports(0, 1)

        # Flag to indicate if we need to stop, subclasses should check this value periodically, typically at the start
        # of a polling loop
        self._stop_requested = False

    # pylint: disable=unused-argument
    def _post_build_single(self, builder: mrc.Builder, out_node: mrc.SegmentObject) -> mrc.SegmentObject:
        return out_node
@@ -74,3 +78,31 @@ def _post_build(self, builder: mrc.Builder, out_ports_nodes: list[mrc.SegmentObject])
        logger.info("Added source: %s\n └─> %s", self, pretty_print_type_name(self.output_ports[0].output_type))

        return [ret_val]

    def stop(self):
        """
        This method is invoked by the pipeline whenever there is an unexpected shutdown.
        Subclasses should override this method to perform any necessary cleanup operations.
        """

        # Indicate we need to stop
        self.request_stop()

        return super().stop()

    def request_stop(self):
        """
        Request the source to stop processing data.
        """
        self._stop_requested = True

    def is_stop_requested(self) -> bool:
        """
        Returns `True` if a stop has been requested.

        Returns
        -------
        bool:
            True if a stop has been requested, False otherwise.
        """
        return self._stop_requested
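A minimal sketch of how a subclass is expected to consume this API (hypothetical stage, not part of this diff): check `is_stop_requested()` as the loop condition and yield until a stop arrives.

```python
import collections.abc
import time

import mrc

from morpheus.messages import MessageMeta
from morpheus.pipeline.single_output_source import SingleOutputSource
from morpheus.pipeline.stage_schema import StageSchema


class ExamplePollingSource(SingleOutputSource):
    """Hypothetical polling source built on the new stop API."""

    @property
    def name(self) -> str:
        return "example-polling-source"

    def supports_cpp_node(self) -> bool:
        return False

    def compute_schema(self, schema: StageSchema):
        schema.output_schema.set_type(MessageMeta)

    def _build_source(self, builder: mrc.Builder) -> mrc.SegmentObject:
        return builder.make_source(self.unique_name, self._generate)

    def _generate(self) -> collections.abc.Iterator[MessageMeta]:
        while not self.is_stop_requested():
            df = self._poll_once()  # hypothetical fetch helper
            if df is not None:
                yield MessageMeta(df=df)
            else:
                time.sleep(0.1)  # brief back-off between empty polls
```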
@@ -108,7 +108,8 @@ def __init__(self,
            sort_glob=sort_glob,
            recursive=recursive,
            queue_max_size=queue_max_size,
-           batch_timeout=batch_timeout)
+           batch_timeout=batch_timeout,
+           should_stop_fn=self.is_stop_requested)

    @property
    def name(self) -> str:
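Several commits in this PR also switch blocking queue reads to non-blocking ones for the same reason. A sketch of that pattern with hypothetical attribute names, as it might apply to draining the watcher's queue:

```python
import queue
import time

def _drain_queue(self):
    # A non-blocking get lets the loop re-check is_stop_requested() between
    # attempts instead of blocking indefinitely on an empty queue.
    while not self.is_stop_requested():
        try:
            file_batch = self._file_queue.get_nowait()  # hypothetical attribute
        except queue.Empty:
            time.sleep(0.1)
            continue
        yield file_batch
```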