Skip to content

Commit

Permalink
Merge pull request #341 from golemfactory/az/handle-subscription-expi…
Browse files Browse the repository at this point in the history
…ration

Handle ApiException caused by subscription expiration
  • Loading branch information
azawlocki authored Apr 21, 2021
2 parents d76f040 + d33d676 commit 346f22c
Show file tree
Hide file tree
Showing 3 changed files with 274 additions and 60 deletions.
191 changes: 191 additions & 0 deletions tests/goth/test_resubscription.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
"""Test if subscription expiration is handled correctly by Executor"""
from datetime import timedelta
import logging
import os
from pathlib import Path
import time
from typing import Dict, Set, Type
from unittest.mock import Mock

import colors
import pytest

from goth.assertions import EventStream
from goth.assertions.monitor import EventMonitor
from goth.assertions.operators import eventually

from goth.configuration import load_yaml
from goth.runner import Runner
from goth.runner.log import configure_logging
from goth.runner.probe import RequestorProbe

from yapapi import Executor, Task
from yapapi.executor.events import (
Event,
ComputationStarted,
ComputationFinished,
SubscriptionCreated,
)
import yapapi.rest.market
from yapapi.log import enable_default_logger
from yapapi.package import vm

import ya_market.api.requestor_api
from ya_market import ApiException

logger = logging.getLogger("goth.test")

SUBSCRIPTION_EXPIRATION_TIME = 5
"""Number of seconds after which a subscription expires"""


class RequestorApi(ya_market.api.requestor_api.RequestorApi):
"""A replacement for market API that simulates early subscription expiration.
A call to `collect_offers(sub_id)` will raise `ApiException` indicating
subscription expiration when at least `SUBSCRIPTION_EXPIRATION_TIME`
elapsed after the given subscription has been created.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.subscriptions: Dict[str, float] = {}

def subscribe_demand(self, demand, **kwargs):
"""Override `RequestorApi.subscribe_demand()` to register subscription create time."""
id_coro = super().subscribe_demand(demand, **kwargs)

async def coro():
id = await id_coro
self.subscriptions[id] = time.time()
return id

return coro()

def collect_offers(self, subscription_id, **kwargs):
"""Override `RequestorApi.collect_offers()`.
Raise `ApiException(404)` if at least `SUBSCRIPTION_EXPIRATION_TIME` elapsed
since the subscription identified by `subscription_id` has been created.
"""
if time.time() > self.subscriptions[subscription_id] + SUBSCRIPTION_EXPIRATION_TIME:
logger.info("Subscription expired")

async def coro():
raise ApiException(
http_resp=Mock(
status=404,
reason="Not Found",
data=f"{{'message': 'Subscription [{subscription_id}] expired.'}}",
)
)

return coro()
else:
return super().collect_offers(subscription_id, **kwargs)


@pytest.fixture(autouse=True)
def patch_collect_offers(monkeypatch):
"""Install the patched `RequestorApi` class."""
monkeypatch.setattr(yapapi.rest.market, "RequestorApi", RequestorApi)


async def unsubscribe_demand(sub_id: str) -> None:
"""Auxiliary function that calls `unsubscribeDemand` operation for given `sub_id`."""
config = yapapi.rest.Configuration()
market_client = config.market()
requestor_api = yapapi.rest.market.RequestorApi(market_client)
await requestor_api.unsubscribe_demand(sub_id)


async def assert_demand_resubscribed(events: "EventStream[Event]"):
"""A temporal assertion that the requestor will have to satisfy."""

subscription_ids: Set[str] = set()

async def wait_for_event(event_type: Type[Event], timeout: float):
e = await eventually(events, lambda e: isinstance(e, event_type), timeout)
assert e, f"Timed out waiting for {event_type}"
logger.info(colors.cyan(str(e)))
return e

e = await wait_for_event(ComputationStarted, 10)

# Make sure new subscriptions are created at least three times
while len(subscription_ids) < 3:
e = await wait_for_event(SubscriptionCreated, SUBSCRIPTION_EXPIRATION_TIME + 10)
assert e.sub_id not in subscription_ids
subscription_ids.add(e.sub_id)

# Unsubscribe and make sure new subscription is created
await unsubscribe_demand(e.sub_id)
logger.info("Demand unsubscribed")
await wait_for_event(SubscriptionCreated, 5)

# Enough checking, wait until the computation finishes
await wait_for_event(ComputationFinished, 20)


@pytest.mark.asyncio
async def test_demand_resubscription(log_dir: Path, monkeypatch) -> None:
"""Test that checks that a demand is re-submitted after its previous submission expires."""

configure_logging(log_dir)

# Override the default test configuration to create only one provider node
nodes = [
{"name": "requestor", "type": "Requestor"},
{"name": "provider-1", "type": "VM-Wasm-Provider", "use-proxy": True},
]
goth_config = load_yaml(
Path(__file__).parent / "assets" / "goth-config.yml", [("nodes", nodes)]
)

vm_package = await vm.repo(
image_hash="9a3b5d67b0b27746283cb5f287c13eab1beaa12d92a9f536b747c7ae",
min_mem_gib=0.5,
min_storage_gib=2.0,
)

runner = Runner(base_log_dir=log_dir, compose_config=goth_config.compose_config)

async with runner(goth_config.containers):

requestor = runner.get_probes(probe_type=RequestorProbe)[0]
env = {**os.environ}
requestor.set_agent_env_vars(env)

# Setup the environment for the requestor
for key, val in env.items():
monkeypatch.setenv(key, val)

monitor = EventMonitor()
monitor.add_assertion(assert_demand_resubscribed)
monitor.start()

# The requestor

enable_default_logger()

async def worker(work_ctx, tasks):
async for task in tasks:
work_ctx.run("/bin/sleep", "5")
yield work_ctx.commit()
task.accept_result()

async with Executor(
budget=10.0,
package=vm_package,
max_workers=1,
timeout=timedelta(seconds=30),
event_consumer=monitor.add_event_sync,
) as executor:

task: Task # mypy needs this for some reason
async for task in executor.submit(worker, [Task(data=n) for n in range(20)]):
logger.info("Task %d computed", task.data)

await monitor.stop()
for a in monitor.failed:
raise a.result()
124 changes: 65 additions & 59 deletions yapapi/executor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""
import asyncio
from asyncio import CancelledError
import contextlib
from datetime import datetime, timedelta, timezone
from decimal import Decimal
import logging
Expand Down Expand Up @@ -36,10 +35,10 @@
from .utils import AsyncWrapper
from ..package import Package
from ..props import Activity, com, NodeInfo, NodeInfoKeys
from ..props.base import InvalidPropertiesError
from ..props.builder import DemandBuilder
from .. import rest
from ..rest.activity import CommandExecutionError
from ..rest.market import Subscription
from ..storage import gftp
from ._smartq import SmartQueue, Handle
from .strategy import (
Expand Down Expand Up @@ -337,8 +336,8 @@ async def process_debit_notes() -> None:
note_id=debit_note.debit_note_id,
)
)
allocation = self._get_allocation(debit_note)
try:
allocation = self._get_allocation(debit_note)
await debit_note.accept(
amount=debit_note.total_amount_due, allocation=allocation
)
Expand Down Expand Up @@ -369,82 +368,89 @@ async def accept_payment_for_agreement(agreement_id: str, *, partial: bool = Fal
)
)

async def find_offers() -> None:
async def find_offers_for_subscription(subscription: Subscription) -> None:
"""Subscribe to offers and process them continuously."""

async def reject_proposal(proposal, reason):
await proposal.reject(reason=reason)
emit(events.ProposalRejected(prop_id=proposal.id, reason=reason))

nonlocal offers_collected, proposals_confirmed

emit(events.SubscriptionCreated(sub_id=subscription.id))
try:
subscription = await builder.subscribe(market_api)
proposals = subscription.events()
except Exception as ex:
emit(events.SubscriptionFailed(reason=str(ex)))
emit(events.CollectFailed(sub_id=subscription.id, reason=str(ex)))
raise

async with subscription:
async for proposal in proposals:

emit(events.SubscriptionCreated(sub_id=subscription.id))
emit(events.ProposalReceived(prop_id=proposal.id, provider_id=proposal.issuer))
offers_collected += 1
try:
proposals = subscription.events()
except Exception as ex:
emit(events.CollectFailed(sub_id=subscription.id, reason=str(ex)))
raise

async for proposal in proposals:

emit(events.ProposalReceived(prop_id=proposal.id, provider_id=proposal.issuer))
offers_collected += 1
try:
score = await strategy.score_offer(proposal, agreements_pool)
logger.debug(
"Scored offer %s, provider: %s, strategy: %s, score: %f",
proposal.id,
proposal.props.get("golem.node.id.name"),
type(strategy).__name__,
score,
)
score = await strategy.score_offer(proposal, agreements_pool)
logger.debug(
"Scored offer %s, provider: %s, strategy: %s, score: %f",
proposal.id,
proposal.props.get("golem.node.id.name"),
type(strategy).__name__,
score,
)

if score < SCORE_NEUTRAL:
await reject_proposal(proposal, "Score too low")
if score < SCORE_NEUTRAL:
await reject_proposal(proposal, "Score too low")

elif not proposal.is_draft:
common_platforms = self._get_common_payment_platforms(proposal)
if common_platforms:
builder.properties["golem.com.payment.chosen-platform"] = next(
iter(common_platforms)
elif not proposal.is_draft:
common_platforms = self._get_common_payment_platforms(proposal)
if common_platforms:
builder.properties["golem.com.payment.chosen-platform"] = next(
iter(common_platforms)
)
else:
# reject proposal if there are no common payment platforms
await reject_proposal(proposal, "No common payment platform")
continue
timeout = proposal.props.get(DEBIT_NOTE_ACCEPTANCE_TIMEOUT_PROP)
if timeout:
if timeout < DEBIT_NOTE_MIN_TIMEOUT:
await reject_proposal(
proposal, "Debit note acceptance timeout too short"
)
else:
# reject proposal if there are no common payment platforms
await reject_proposal(proposal, "No common payment platform")
continue
timeout = proposal.props.get(DEBIT_NOTE_ACCEPTANCE_TIMEOUT_PROP)
if timeout:
if timeout < DEBIT_NOTE_MIN_TIMEOUT:
await reject_proposal(
proposal, "Debit note acceptance timeout too short"
)
continue
else:
builder.properties[DEBIT_NOTE_ACCEPTANCE_TIMEOUT_PROP] = timeout
else:
builder.properties[DEBIT_NOTE_ACCEPTANCE_TIMEOUT_PROP] = timeout

await proposal.respond(builder.properties, builder.constraints)
emit(events.ProposalResponded(prop_id=proposal.id))
await proposal.respond(builder.properties, builder.constraints)
emit(events.ProposalResponded(prop_id=proposal.id))

else:
emit(events.ProposalConfirmed(prop_id=proposal.id))
await agreements_pool.add_proposal(score, proposal)
proposals_confirmed += 1
else:
emit(events.ProposalConfirmed(prop_id=proposal.id))
await agreements_pool.add_proposal(score, proposal)
proposals_confirmed += 1

except CancelledError:
raise
except Exception:
emit(
events.ProposalFailed(
prop_id=proposal.id, exc_info=sys.exc_info() # type: ignore
)
except CancelledError:
raise
except Exception:
emit(
events.ProposalFailed(
prop_id=proposal.id, exc_info=sys.exc_info() # type: ignore
)
)

async def find_offers() -> None:
"""Create demand subscription and process offers.
When the subscription expires, create a new one. And so on...
"""
while True:
try:
subscription = await builder.subscribe(market_api)
except Exception as ex:
emit(events.SubscriptionFailed(reason=str(ex)))
raise
async with subscription:
await find_offers_for_subscription(subscription)

# aio_session = await self._stack.enter_async_context(aiohttp.ClientSession())
# storage_manager = await DavStorageProvider.for_directory(
Expand Down Expand Up @@ -751,7 +757,7 @@ def _get_allocation(
and allocation.payment_platform == item.payment_platform
)
except:
raise RuntimeError(f"No allocation for {item.payment_platform} {item.payer_addr}.")
raise ValueError(f"No allocation for {item.payment_platform} {item.payer_addr}.")

async def __aenter__(self) -> "Executor":
stack = self._stack
Expand Down
Loading

0 comments on commit 346f22c

Please sign in to comment.