Skip to content

Commit

Permalink
Fix telemetry_test and pubsub_test
Browse files Browse the repository at this point in the history
  • Loading branch information
linkous8 committed May 29, 2024
1 parent a675067 commit c3b2a27
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 28 deletions.
2 changes: 1 addition & 1 deletion servo/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def event_str(self, event: Events | str) -> str:
class Status(pydantic.BaseModel):
status: Statuses
message: Optional[str] = None
other_messages: Optional[list[str]] = (
additional_messages: Optional[list[str]] = (
None # other lower priority error in exception group
)
reason: Optional[str] = None
Expand Down
12 changes: 8 additions & 4 deletions servo/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,11 @@ def __init__(
content = text.encode()
elif json is not None:
content = (
json.json()
if (hasattr(json, "json") and callable(json.json))
json.model_dump_json()
if (
hasattr(json, "model_dump_json")
and callable(json.model_dump_json)
)
else json_.dumps(json)
)
elif yaml is not None:
Expand Down Expand Up @@ -204,7 +207,8 @@ class _ExchangeChildModel(pydantic.BaseModel):

def __init__(self, *args, exchange: Exchange, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._exchange = weakref.ref(exchange)
if exchange is not None:
self._exchange = weakref.ref(exchange)

@property
def exchange(self) -> Exchange:
Expand Down Expand Up @@ -1370,7 +1374,7 @@ def _random_unique_channel_name(self) -> str:
while True:
name = _random_string()
if self.pubsub_exchange.get_channel(name) is None and re.match(
ChannelName.regex, name
ChannelName.__metadata__[0].pattern, name
):
return name

Expand Down
5 changes: 3 additions & 2 deletions servo/servo.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ async def _post_event(
try:
try:
response = await self._api_client.post(
"servo", data=event_request.model_dump_json()
"servo", data=event_request.model_dump_json(exclude_none=True)
)
except RuntimeError as e:
if "the handler is closed" in str(e):
Expand All @@ -629,7 +629,8 @@ async def _post_event(
self.config.optimizer, self.config.settings
)
response = await self._api_client.post(
"servo", data=event_request.model_dump_json()
"servo",
data=event_request.model_dump_json(exclude_none=True),
)
else:
raise
Expand Down
44 changes: 25 additions & 19 deletions tests/pubsub_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,18 @@ def test_json_message(self) -> None:
@freezegun.freeze_time("2021-01-01 12:00:01")
def test_json_message_via_protocol(self) -> None:
# NOTE: Use Pydantic's json() method support
channel = servo.pubsub.Channel.construct(
name="whatever", created_at=datetime.datetime.now()
channel = servo.pubsub.Channel(
name="whatever", created_at=datetime.datetime.now(), exchange=None
)
message = servo.pubsub.Message(json=channel)
assert (
message.text
== '{"name": "whatever", "description": null, "created_at": "2021-01-01T12:00:01"}'
== '{"name":"whatever","description":null,"created_at":"2021-01-01T12:00:01"}'
)
assert message.content_type == "application/json"
assert (
message.content
== b'{"name": "whatever", "description": null, "created_at": "2021-01-01T12:00:01"}'
== b'{"name":"whatever","description":null,"created_at":"2021-01-01T12:00:01"}'
)

def test_yaml_message(self) -> None:
Expand Down Expand Up @@ -87,19 +87,23 @@ def test_content_required(self) -> None:
servo.pubsub.Message()

assert {
"input": None,
"loc": ("content",),
"msg": "none is not an allowed value",
"type": "type_error.none.not_allowed",
"msg": "Input should be a valid bytes",
"type": "bytes_type",
"url": "https://errors.pydantic.dev/2.7/v/bytes_type",
} in excinfo.value.errors()

def test_content_type_required(self) -> None:
with pytest.raises(pydantic.ValidationError) as excinfo:
servo.pubsub.Message(content=b"foo")

assert {
"input": None,
"loc": ("content_type",),
"msg": "none is not an allowed value",
"type": "type_error.none.not_allowed",
"msg": "Input should be a valid string",
"type": "string_type",
"url": "https://errors.pydantic.dev/2.7/v/string_type",
} in excinfo.value.errors()


Expand Down Expand Up @@ -127,22 +131,24 @@ def test_name_required(self, exchange: servo.pubsub.Exchange) -> None:
servo.pubsub.Channel(exchange=exchange)

assert {
"input": {},
"loc": ("name",),
"msg": "field required",
"type": "value_error.missing",
"msg": "Field required",
"type": "missing",
"url": "https://errors.pydantic.dev/2.7/v/missing",
} in excinfo.value.errors()

def test_name_constraints(self, exchange: servo.pubsub.Exchange) -> None:
with pytest.raises(pydantic.ValidationError) as excinfo:
servo.pubsub.Channel(exchange=exchange, name="THIS_IS_INVALID***")

assert {
"ctx": {"pattern": "^[0-9a-zA-Z]([0-9a-zA-Z\\.\\-_])*[0-9A-Za-z]$"},
"input": "THIS_IS_INVALID***",
"loc": ("name",),
"msg": 'string does not match regex "^[0-9a-zA-Z]([0-9a-zA-Z\\.\\-_])*[0-9A-Za-z]$"',
"type": "value_error.str.regex",
"ctx": {
"pattern": "^[0-9a-zA-Z]([0-9a-zA-Z\\.\\-_])*[0-9A-Za-z]$",
},
"msg": "String should match pattern '^[0-9a-zA-Z]([0-9a-zA-Z\\.\\-_])*[0-9A-Za-z]$'",
"type": "string_pattern_mismatch",
"url": "https://errors.pydantic.dev/2.7/v/string_pattern_mismatch",
} in excinfo.value.errors()

def test_exchange_required(self) -> None:
Expand All @@ -158,7 +164,7 @@ def test_hashing(self, channel: servo.pubsub.Channel) -> None:
channels = {
channel,
}
copy_of_channel = channel.copy()
copy_of_channel = channel.model_copy()
assert copy_of_channel in channels
copy_of_channel.name = "another_name"
assert copy_of_channel not in channels
Expand Down Expand Up @@ -916,7 +922,7 @@ async def _aggregate_metrics(
channel: servo.pubsub.Channel,
) -> None:
if aggregator.message is None:
aggregator.message = message.copy()
aggregator.message = message.model_copy()
else:
text = "\n".join([aggregator.message.text, message.text])
aggregator.message = servo.pubsub.Message(text=text)
Expand Down Expand Up @@ -944,7 +950,7 @@ async def _aggregate_metrics(
channel: servo.pubsub.Channel,
) -> None:
if aggregator.message is None:
aggregator.message = message.copy()
aggregator.message = message.model_copy()
else:
text = "\n".join([aggregator.message.text, message.text])
aggregator.message = servo.pubsub.Message(text=text)
Expand Down Expand Up @@ -997,7 +1003,7 @@ async def _aggregate_metrics(
channel: servo.pubsub.Channel,
) -> None:
if aggregator.message is None:
aggregator.message = message.copy()
aggregator.message = message.model_copy()
else:
text = "\n".join([aggregator.message.text, message.text])
aggregator.message = servo.pubsub.Message(text=text)
Expand Down
4 changes: 2 additions & 2 deletions tests/telemetry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
async def test_telemetry_hello(
monkeypatch, optimizer: servo.configuration.OpsaniOptimizer
) -> None:
expected = f'"telemetry": {{"servox.version": "{servo.__version__}", "servox.platform": "{platform.platform()}", "servox.namespace": "test-namespace"}}'
expected = f'"telemetry":{{"servox.version":"{servo.__version__}","servox.platform":"{platform.platform()}","servox.namespace":"test-namespace"}}'

# Simulate running as a k8s pod
monkeypatch.setenv("POD_NAMESPACE", "test-namespace")
Expand Down Expand Up @@ -142,7 +142,7 @@ async def test_diagnostics_put(
method="PUT",
endpoint=servo.telemetry.DIAGNOSTICS_OUTPUT_ENDPOINT,
output_model=servo.api.Status,
json=diagnostic_data.dict(),
json=diagnostic_data.model_dump(),
)

assert put.called
Expand Down

0 comments on commit c3b2a27

Please sign in to comment.