Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/update several existing tests #1273

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docker/conda/environments/cuda11.8_dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ dependencies:
- pytorch=2.0.1
- rapidjson=1.1.0
- requests=2.31
- requests-cache=1.1
- scikit-build=0.17.1
- scikit-learn=1.2.2
- sphinx
Expand Down
4 changes: 4 additions & 0 deletions morpheus/cli/register_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from morpheus.config import PipelineModes
from morpheus.utils.type_utils import _DecoratorType
from morpheus.utils.type_utils import get_full_qualname
from morpheus.utils.type_utils import is_union_type


def class_name_to_command_name(class_name: str) -> str:
Expand Down Expand Up @@ -177,6 +178,9 @@ def set_options_param_type(options_kwargs: dict, annotation, doc_type: str):
if (annotation == inspect.Parameter.empty):
raise RuntimeError("All types must be specified to auto register stage.")

if (is_union_type(annotation)):
raise RuntimeError("Union types are not supported for auto registering stages.")

if (issubtype(annotation, typing.List)):
# For variable length array, use multiple=True
options_kwargs["multiple"] = True
Expand Down
5 changes: 4 additions & 1 deletion morpheus/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,10 @@ def parse_log_level(ctx, param, value):

def is_enum(enum_class: typing.Type):
"""Returns True if the given class is an enum."""
return issubclass(enum_class, Enum) or is_pybind_enum(enum_class)
try:
return issubclass(enum_class, Enum) or is_pybind_enum(enum_class)
except TypeError:
return False


def get_enum_members(enum_class: typing.Type):
Expand Down
35 changes: 20 additions & 15 deletions morpheus/controllers/rss_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@
import logging
import os
import typing
import urllib
from urllib.parse import urlparse

import feedparser
import requests
import requests_cache

import cudf
Expand All @@ -33,7 +31,7 @@ class RSSController:

Parameters
----------
feed_input : str
feed_input : str or list[str]
The URL or file path of the RSS feed.
batch_size : int, optional, default = 128
Number of feed items to accumulate before creating a DataFrame.
Expand All @@ -50,14 +48,14 @@ def __init__(self, feed_input: str | list[str], batch_size: int = 128, run_indef

if (run_indefinitely is None):
# If feed_input is URL. Runs indefinitely
run_indefinitely = any([RSSController.is_url(f) for f in self._feed_input])
run_indefinitely = any(RSSController.is_url(f) for f in self._feed_input)

self._run_indefinitely = run_indefinitely

self._session = requests_cache.CachedSession(os.path.join("./.cache/http", "RSSController.sqlite"),
backend="sqlite")

self._blacklisted_feeds = [] # Feeds that have thrown an error and wont be retried
self._errored_feeds = [] # Feeds that have thrown an error and won't be retried

@property
def run_indefinitely(self):
Expand Down Expand Up @@ -86,23 +84,30 @@ def parse_feed(self) -> list[dict]:
raise RuntimeError(f"Invalid feed input: {self._feed_input}. No entries found.")

def _try_parse_feed(self, url: str):
is_url = RSSController.is_url(url)
if (is_url):
response = self._session.get(url)
cache_hit = response.from_cache

response = self._session.get(url)
feed_input = response.text
else:
cache_hit = False
feed_input = url

# Try to use requests to get the object
feed = feedparser.parse(response.text)
feed = feedparser.parse(feed_input)

cache_hit = response.from_cache
fallback = False

if (feed["bozo"]):
cache_hit = False
fallback = True

logger.info(f"Failed to parse feed: {url}. Trying to parse using feedparser directly.")
if (is_url):
fallback = True
logger.info(f"Failed to parse feed: {url}. Trying to parse using feedparser directly.")

# If that fails, use feedparser directly (can't cache this)
feed = feedparser.parse(url)
# If that fails, use feedparser directly (can't cache this)
feed = feedparser.parse(url)

if (feed["bozo"]):
raise RuntimeError(f"Invalid feed input: {url}. Error: {feed['bozo_exception']}")
Expand All @@ -127,7 +132,7 @@ def parse_feeds(self):
"""
for url in self._feed_input:
try:
if (url in self._blacklisted_feeds):
if (url in self._errored_feeds):
continue

feed = self._try_parse_feed(url)
Expand All @@ -138,9 +143,9 @@ def parse_feeds(self):
yield feed

except Exception as ex:
logger.warning(f"Failed to parse feed: {url}. The feed will be blacklisted and not retried.")
logger.warning("Failed to parse feed: %s: %s. The feed will not be retried.", url, ex)

self._blacklisted_feeds.append(url)
self._errored_feeds.append(url)

def fetch_dataframes(self):
"""
Expand Down
4 changes: 2 additions & 2 deletions morpheus/stages/input/rss_source_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class RSSSourceStage(PreallocatorMixin, SingleOutputSource):
----------
c : morpheus.config.Config
Pipeline configuration instance.
feed_input : str
feed_input : list[str]
The URL or file path of the RSS feed.
interval_secs : float, optional, default = 600
Interval in seconds between fetching new feed items.
Expand All @@ -49,7 +49,7 @@ class RSSSourceStage(PreallocatorMixin, SingleOutputSource):

def __init__(self,
c: Config,
feed_input: str | list[str],
feed_input: list[str],
interval_secs: float = 600,
stop_after: int = 0,
max_retries: int = 5,
Expand Down
4 changes: 3 additions & 1 deletion morpheus/stages/preprocess/deserialize_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
logger = logging.getLogger(__name__)


@register_stage("deserialize", modes=[PipelineModes.FIL, PipelineModes.NLP, PipelineModes.OTHER])
@register_stage("deserialize",
modes=[PipelineModes.FIL, PipelineModes.NLP, PipelineModes.OTHER],
ignore_args=["message_type", "task_type", "task_payload"])
class DeserializeStage(MultiMessageStage):
"""
Messages are logically partitioned based on the pipeline config's `pipeline_batch_size` parameter.
Expand Down
10 changes: 10 additions & 0 deletions morpheus/utils/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import inspect
import types
import typing
from collections import defaultdict

Expand Down Expand Up @@ -45,6 +46,15 @@ def greatest_ancestor(*cls_list):
return None # or raise, if that's more appropriate


def is_union_type(type_: type) -> bool:
"""
Returns True if the type is a `typing.Union` or a `types.UnionType`.
"""
# Unions in the form of `(float | int)` are instances of `types.UnionType`.
# However, unions in the form of `typing.Union[float, int]` are instances of `typing._UnionGenericAlias`.
return isinstance(type_, (types.UnionType, typing._UnionGenericAlias))


@typing.overload
def unpack_union(cls_1: typing.Type[T]) -> typing.Union[typing.Type[T]]:
...
Expand Down
6 changes: 3 additions & 3 deletions tests/controllers/test_rss_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ def test_run_indefinitely_false(feed_input):
@pytest.mark.parametrize("feed_input", test_urls)
def test_parse_feed_valid_url(feed_input):
controller = RSSController(feed_input=feed_input)
feed = controller.parse_feed()
feed = list(controller.parse_feeds())[0]
assert feed.entries


@pytest.mark.parametrize("feed_input", test_invalid_urls + test_invalid_file_paths)
def test_parse_feed_invalid_input(feed_input):
controller = RSSController(feed_input=feed_input)
with pytest.raises(RuntimeError):
controller.parse_feed()
list(controller.parse_feeds())
assert controller._errored_feeds == [feed_input]


@pytest.mark.parametrize("feed_input", test_urls + test_file_paths)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_deserialize_stage_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ def test_fixing_non_unique_indexes(dataset: DatasetManager):
# When processing the dataframe, a warning should be generated when there are non-unique IDs
with pytest.warns(RuntimeWarning):

DeserializeStage.process_dataframe(meta, 5, ensure_sliceable_index=False)
DeserializeStage.process_dataframe_to_multi_message(meta, 5, ensure_sliceable_index=False)

assert not meta.has_sliceable_index()
assert "_index_" not in meta.df.columns

dataset.assert_df_equal(meta.df, df)

DeserializeStage.process_dataframe(meta, 5, ensure_sliceable_index=True)
DeserializeStage.process_dataframe_to_multi_message(meta, 5, ensure_sliceable_index=True)

assert meta.has_sliceable_index()
assert "_index_" in meta.df.columns
Expand Down
10 changes: 6 additions & 4 deletions tests/test_rss_source_stage_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_constructor_with_feed_url(config):

ctlr = rss_source_stage._controller

assert ctlr._feed_input == "https://realpython.com/atom.xml"
assert ctlr._feed_input == ["https://realpython.com/atom.xml"]
assert ctlr._run_indefinitely is True
assert ctlr._batch_size == config.pipeline_batch_size
assert rss_source_stage._interval_secs == 600
Expand All @@ -47,7 +47,7 @@ def test_constructor_with_feed_file(config):

ctlr = rss_source_stage._controller

assert ctlr._feed_input == file_feed_input
assert ctlr._feed_input == [file_feed_input]
assert ctlr._run_indefinitely is False
assert ctlr._batch_size == config.pipeline_batch_size
assert rss_source_stage._interval_secs == 5
Expand Down Expand Up @@ -94,5 +94,7 @@ def test_invalid_input_rss_source_stage_pipe(config) -> None:

pipe.add_edge(rss_source_stage, sink_stage)

with pytest.raises(RuntimeError):
pipe.run()
pipe.run()

assert len(sink_stage.get_messages()) == 0
assert rss_source_stage._controller._errored_feeds == [feed_input]
4 changes: 2 additions & 2 deletions tests/tests_data/service/milvus_simple_collection_conf.json
Git LFS file not shown
Loading