Skip to content

Commit

Permalink
fix the agreements pool tests
Browse files Browse the repository at this point in the history
  • Loading branch information
shadeofblue committed Jun 20, 2023
1 parent 9366cfc commit 81383db
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions tests/test_agreements_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ async def create_agreement():
return create_agreement


def get_agreements_pool() -> agreements_pool.AgreementsPool:
return agreements_pool.AgreementsPool(
lambda _event, **kwargs: None, lambda _offer: None, mock.Mock()
)


@pytest.mark.asyncio
async def test_use_agreement_chooses_max_score():
"""Test that a proposal with the largest score is chosen in AgreementsPool.use_agreement()."""
Expand All @@ -40,7 +46,7 @@ async def test_use_agreement_chooses_max_score():
mock_score = random.random()
proposals[n] = (mock_score, mock_proposal)

pool = agreements_pool.AgreementsPool(lambda _event, **kwargs: None, lambda _offer: None)
pool = get_agreements_pool()

for score, proposal in proposals.values():
await pool.add_proposal(score, proposal)
Expand Down Expand Up @@ -76,7 +82,7 @@ async def test_use_agreement_shuffles_proposals():
mock_score = 42.0 if n != 0 else 41.0
proposals.append((mock_score, mock_proposal))

pool = agreements_pool.AgreementsPool(lambda _event, **kwargs: None, lambda _offer: None)
pool = get_agreements_pool()

for score, proposal in proposals:
await pool.add_proposal(score, proposal)
Expand All @@ -95,7 +101,7 @@ def use_agreement_cb(agreement):
async def test_use_agreement_no_proposals():
"""Test that `AgreementPool.use_agreement()` returns `None` when there are no proposals."""

pool = agreements_pool.AgreementsPool(lambda _event, **kwargs: None, lambda _offer: None)
pool = get_agreements_pool()

def use_agreement_cb(_agreement):
assert False, "use_agreement callback called"
Expand All @@ -120,7 +126,7 @@ async def test_terminate_agreement(multi_activity, simulate_race, event_emitted)
events = []

pool = agreements_pool.AgreementsPool(
lambda event, **kwargs: events.append(event), lambda _offer: None # noqa
lambda event, **kwargs: events.append(event), lambda _offer: None, mock.Mock() # noqa
)
agreement: BufferedAgreement = BufferedAgreementFactory(has_multi_activity=multi_activity)
pool._agreements[agreement.agreement.id] = agreement
Expand Down

0 comments on commit 81383db

Please sign in to comment.