-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #341 from golemfactory/az/handle-subscription-expi…
…ration Handle ApiException caused by subscription expiration
- Loading branch information
Showing
3 changed files
with
274 additions
and
60 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
"""Test if subscription expiration is handled correctly by Executor""" | ||
from datetime import timedelta | ||
import logging | ||
import os | ||
from pathlib import Path | ||
import time | ||
from typing import Dict, Set, Type | ||
from unittest.mock import Mock | ||
|
||
import colors | ||
import pytest | ||
|
||
from goth.assertions import EventStream | ||
from goth.assertions.monitor import EventMonitor | ||
from goth.assertions.operators import eventually | ||
|
||
from goth.configuration import load_yaml | ||
from goth.runner import Runner | ||
from goth.runner.log import configure_logging | ||
from goth.runner.probe import RequestorProbe | ||
|
||
from yapapi import Executor, Task | ||
from yapapi.executor.events import ( | ||
Event, | ||
ComputationStarted, | ||
ComputationFinished, | ||
SubscriptionCreated, | ||
) | ||
import yapapi.rest.market | ||
from yapapi.log import enable_default_logger | ||
from yapapi.package import vm | ||
|
||
import ya_market.api.requestor_api | ||
from ya_market import ApiException | ||
|
||
logger = logging.getLogger("goth.test") | ||
|
||
SUBSCRIPTION_EXPIRATION_TIME = 5 | ||
"""Number of seconds after which a subscription expires""" | ||
|
||
|
||
class RequestorApi(ya_market.api.requestor_api.RequestorApi): | ||
"""A replacement for market API that simulates early subscription expiration. | ||
A call to `collect_offers(sub_id)` will raise `ApiException` indicating | ||
subscription expiration when at least `SUBSCRIPTION_EXPIRATION_TIME` | ||
elapsed after the given subscription has been created. | ||
""" | ||
|
||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.subscriptions: Dict[str, float] = {} | ||
|
||
def subscribe_demand(self, demand, **kwargs): | ||
"""Override `RequestorApi.subscribe_demand()` to register subscription create time.""" | ||
id_coro = super().subscribe_demand(demand, **kwargs) | ||
|
||
async def coro(): | ||
id = await id_coro | ||
self.subscriptions[id] = time.time() | ||
return id | ||
|
||
return coro() | ||
|
||
def collect_offers(self, subscription_id, **kwargs): | ||
"""Override `RequestorApi.collect_offers()`. | ||
Raise `ApiException(404)` if at least `SUBSCRIPTION_EXPIRATION_TIME` elapsed | ||
since the subscription identified by `subscription_id` has been created. | ||
""" | ||
if time.time() > self.subscriptions[subscription_id] + SUBSCRIPTION_EXPIRATION_TIME: | ||
logger.info("Subscription expired") | ||
|
||
async def coro(): | ||
raise ApiException( | ||
http_resp=Mock( | ||
status=404, | ||
reason="Not Found", | ||
data=f"{{'message': 'Subscription [{subscription_id}] expired.'}}", | ||
) | ||
) | ||
|
||
return coro() | ||
else: | ||
return super().collect_offers(subscription_id, **kwargs) | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def patch_collect_offers(monkeypatch): | ||
"""Install the patched `RequestorApi` class.""" | ||
monkeypatch.setattr(yapapi.rest.market, "RequestorApi", RequestorApi) | ||
|
||
|
||
async def unsubscribe_demand(sub_id: str) -> None: | ||
"""Auxiliary function that calls `unsubscribeDemand` operation for given `sub_id`.""" | ||
config = yapapi.rest.Configuration() | ||
market_client = config.market() | ||
requestor_api = yapapi.rest.market.RequestorApi(market_client) | ||
await requestor_api.unsubscribe_demand(sub_id) | ||
|
||
|
||
async def assert_demand_resubscribed(events: "EventStream[Event]"): | ||
"""A temporal assertion that the requestor will have to satisfy.""" | ||
|
||
subscription_ids: Set[str] = set() | ||
|
||
async def wait_for_event(event_type: Type[Event], timeout: float): | ||
e = await eventually(events, lambda e: isinstance(e, event_type), timeout) | ||
assert e, f"Timed out waiting for {event_type}" | ||
logger.info(colors.cyan(str(e))) | ||
return e | ||
|
||
e = await wait_for_event(ComputationStarted, 10) | ||
|
||
# Make sure new subscriptions are created at least three times | ||
while len(subscription_ids) < 3: | ||
e = await wait_for_event(SubscriptionCreated, SUBSCRIPTION_EXPIRATION_TIME + 10) | ||
assert e.sub_id not in subscription_ids | ||
subscription_ids.add(e.sub_id) | ||
|
||
# Unsubscribe and make sure new subscription is created | ||
await unsubscribe_demand(e.sub_id) | ||
logger.info("Demand unsubscribed") | ||
await wait_for_event(SubscriptionCreated, 5) | ||
|
||
# Enough checking, wait until the computation finishes | ||
await wait_for_event(ComputationFinished, 20) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_demand_resubscription(log_dir: Path, monkeypatch) -> None: | ||
"""Test that checks that a demand is re-submitted after its previous submission expires.""" | ||
|
||
configure_logging(log_dir) | ||
|
||
# Override the default test configuration to create only one provider node | ||
nodes = [ | ||
{"name": "requestor", "type": "Requestor"}, | ||
{"name": "provider-1", "type": "VM-Wasm-Provider", "use-proxy": True}, | ||
] | ||
goth_config = load_yaml( | ||
Path(__file__).parent / "assets" / "goth-config.yml", [("nodes", nodes)] | ||
) | ||
|
||
vm_package = await vm.repo( | ||
image_hash="9a3b5d67b0b27746283cb5f287c13eab1beaa12d92a9f536b747c7ae", | ||
min_mem_gib=0.5, | ||
min_storage_gib=2.0, | ||
) | ||
|
||
runner = Runner(base_log_dir=log_dir, compose_config=goth_config.compose_config) | ||
|
||
async with runner(goth_config.containers): | ||
|
||
requestor = runner.get_probes(probe_type=RequestorProbe)[0] | ||
env = {**os.environ} | ||
requestor.set_agent_env_vars(env) | ||
|
||
# Setup the environment for the requestor | ||
for key, val in env.items(): | ||
monkeypatch.setenv(key, val) | ||
|
||
monitor = EventMonitor() | ||
monitor.add_assertion(assert_demand_resubscribed) | ||
monitor.start() | ||
|
||
# The requestor | ||
|
||
enable_default_logger() | ||
|
||
async def worker(work_ctx, tasks): | ||
async for task in tasks: | ||
work_ctx.run("/bin/sleep", "5") | ||
yield work_ctx.commit() | ||
task.accept_result() | ||
|
||
async with Executor( | ||
budget=10.0, | ||
package=vm_package, | ||
max_workers=1, | ||
timeout=timedelta(seconds=30), | ||
event_consumer=monitor.add_event_sync, | ||
) as executor: | ||
|
||
task: Task # mypy needs this for some reason | ||
async for task in executor.submit(worker, [Task(data=n) for n in range(20)]): | ||
logger.info("Task %d computed", task.data) | ||
|
||
await monitor.stop() | ||
for a in monitor.failed: | ||
raise a.result() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.