Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: multiuse invites with did peer 4 #3112

Merged
merged 6 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,3 @@ repos:
# Run the formatter
- id: ruff-format
stages: [commit]
args: [--fix, --exit-non-zero-on-fix, --formatter]
40 changes: 18 additions & 22 deletions aries_cloudagent/connections/base_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,6 @@ async def find_did_for_key(self, key: str) -> str:
storage: BaseStorage = session.inject(BaseStorage)
record = await storage.find_record(self.RECORD_TYPE_DID_KEY, {"key": key})
ret_did = record.tags["did"]
if ret_did.startswith("did:peer:4"):
ret_did = self.long_did_peer_to_short(ret_did)
return ret_did

async def remove_keys_for_did(self, did: str):
Expand All @@ -452,9 +450,7 @@ async def resolve_didcomm_services(
doc_dict: dict = await resolver.resolve(self._profile, did, service_accept)
doc: ResolvedDocument = pydid.deserialize_document(doc_dict, strict=True)
except ResolverError as error:
raise BaseConnectionManagerError(
"Failed to resolve DID services"
) from error
raise BaseConnectionManagerError("Failed to resolve DID services") from error

if not doc.service:
raise BaseConnectionManagerError(
Expand Down Expand Up @@ -523,10 +519,7 @@ async def resolve_invitation(

return (
endpoint,
[
self._extract_key_material_in_base58_format(key)
for key in recipient_keys
],
[self._extract_key_material_in_base58_format(key) for key in recipient_keys],
[self._extract_key_material_in_base58_format(key) for key in routing_keys],
)

Expand Down Expand Up @@ -800,9 +793,7 @@ async def get_connection_targets(
async with cache.acquire(cache_key) as entry:
if entry.result:
self._logger.debug("Connection targets retrieved from cache")
targets = [
ConnectionTarget.deserialize(row) for row in entry.result
]
targets = [ConnectionTarget.deserialize(row) for row in entry.result]
else:
if not connection:
async with self._profile.session() as session:
Expand All @@ -817,9 +808,7 @@ async def get_connection_targets(
# Otherwise, a replica that participated early in exchange
# may have bad data set in cache.
self._logger.debug("Caching connection targets")
await entry.set_result(
[row.serialize() for row in targets], 3600
)
await entry.set_result([row.serialize() for row in targets], 3600)
else:
self._logger.debug(
"Not caching connection targets for connection in "
Expand Down Expand Up @@ -878,12 +867,8 @@ def diddoc_connection_targets(
did=doc.did,
endpoint=service.endpoint,
label=their_label,
recipient_keys=[
key.value for key in (service.recip_keys or ())
],
routing_keys=[
key.value for key in (service.routing_keys or ())
],
recipient_keys=[key.value for key in (service.recip_keys or ())],
routing_keys=[key.value for key in (service.routing_keys or ())],
sender_key=sender_verkey,
)
)
Expand Down Expand Up @@ -920,7 +905,18 @@ async def find_connection(

"""
connection = None
if their_did:
if their_did and their_did.startswith("did:peer:4"):
# did:peer:4 always recorded as long
long = their_did
short = self.long_did_peer_to_short(their_did)
try:
async with self._profile.session() as session:
connection = await ConnRecord.retrieve_by_did_peer_4(
session, long, short, my_did
)
except StorageNotFoundError:
pass
elif their_did:
try:
async with self._profile.session() as session:
connection = await ConnRecord.retrieve_by_did(
Expand Down
46 changes: 42 additions & 4 deletions aries_cloudagent/connections/models/conn_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ def __init__(
self.their_role = (
ConnRecord.Role.get(their_role).rfc160
if isinstance(their_role, str)
else None if their_role is None else their_role.rfc160
else None
if their_role is None
else their_role.rfc160
)
self.invitation_key = invitation_key
self.invitation_msg_id = invitation_msg_id
Expand Down Expand Up @@ -293,6 +295,44 @@ async def retrieve_by_did(

return await cls.retrieve_by_tag_filter(session, tag_filter, post_filter)

@classmethod
async def retrieve_by_did_peer_4(
cls,
session: ProfileSession,
their_did_long: Optional[str] = None,
their_did_short: Optional[str] = None,
my_did: Optional[str] = None,
their_role: Optional[str] = None,
) -> "ConnRecord":
"""Retrieve a connection record by target DID.

Args:
session: The active profile session
their_did_long: The target DID to filter by, in long form
their_did_short: The target DID to filter by, in short form
my_did: One of our DIDs to filter by
my_role: Filter connections by their role
their_role: Filter connections by their role
"""
tag_filter = {}
if their_did_long and their_did_short:
tag_filter["$or"] = [
{"their_did": their_did_long},
{"their_did": their_did_short},
]
elif their_did_short:
tag_filter["their_did"] = their_did_short
elif their_did_long:
tag_filter["their_did"] = their_did_long
if my_did:
tag_filter["my_did"] = my_did

post_filter = {}
if their_role:
post_filter["their_role"] = cls.Role.get(their_role).rfc160

return await cls.retrieve_by_tag_filter(session, tag_filter, post_filter)

@classmethod
async def retrieve_by_invitation_key(
cls, session: ProfileSession, invitation_key: str, their_role: str = None
Expand Down Expand Up @@ -375,9 +415,7 @@ async def retrieve_by_request_id(
return await cls.retrieve_by_tag_filter(session, tag_filter)

@classmethod
async def retrieve_by_alias(
cls, session: ProfileSession, alias: str
) -> "ConnRecord":
async def retrieve_by_alias(cls, session: ProfileSession, alias: str) -> "ConnRecord":
"""Retrieve a connection record from an alias.

Args:
Expand Down
86 changes: 77 additions & 9 deletions aries_cloudagent/connections/models/tests/test_conn_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ def setUp(self):
self.test_target_did = "GbuDUYXaUZRfHD2jeDuQuP"
self.test_target_verkey = "9WCgWKUaAJj3VWxxtzvvMQN3AoFxoBtBDo9ntwJnVVCC"

self.test_did_peer_4_a = "did:peer:4zQmV3Hf1TT4Xn73MBVf2NAWdMwrzUabpEvwtV3RoZc17Vxr:z2pfttj3xn6tJ7wpHV9ZSwpQVMNtHC7EtM36r1mC5fDpZ25882Yitk21QPbqzuefKPrbFsexWmQtE78vYWweckXtKeu5BhuFDvCjMUf8SC5z7cMPvp8SCdcbqWnHxygjBH9zAAs9myGRnZYXuAkq6CfBdn6ZiNmdRf65TdVfE3cYfS4jNzVZDs1abwytn4jdFJ2fwVegPB3vLY8XxeUEx12a4rtjkqMhs6zBQbJvc4PVUM9rvMbPM2QeXDy7ovkkHaKLUbNUxjQrcQeiR8MTLe1iaVtUv6RpBf4z7ioqfa4VDRmAZT7isVM3NvENUceeUfDZoFbM8PZqGkCbFvfoKiK3SrmTsvPtpXaBAfR4z7w18cFjsvvLBNMZbPnARn4oZijCkYwgaNmAUthgDP4XBFetdUo8728w25FUwTWjAPc1BdSSWPWMRKwCqyAP1Q1hM8dU6otT27MQaQ1rozKncn3U48CXEi2Ef26EDBrSozEWR273ancFojNXBbZVghZG5b6xdypjQir9PgTF94dsygtu47hNxQweVKLUM1p9umqHLhjvLhpS1aGQkGZNnKUHjLDHdToigo15F7TAf8RfMaducHBThFzEp9TUJmiZFTUYQ1uaBgSPMSaWnvTfUoFmLoGbdrWj1vVEsRrARq37u1SJGLqBx7FM2SUd8nxPsChP5jY8ka8F8r7j8qZLHZqvUXbynPUsViwwdFFk8SCsBWfiQgvq7sRiTdLnYv3H5DSwA1uW2GNYXGgkT9aJza4Sk1gvag5iAbQZgxbU594enjVSTjiWsFw2oYQ75JJwiSEgsP2rhpGsNhXxfECNLUtb7FQbDQPtUvLHCJATf7QXJEoWjpfAywmB6NyQcXfskco6FKJNNHeZBnST6U1meH98Ku66vha1k8hAc72iBhXQBnWUjaGRyzELsh2LkBH2UNwW9TuFhxz3SKtL5pGShVQ5XGQhmdrkWP68d6h7c1JqsfogcDBnmWS4VSbJwgtsPNTSsTHGX8hpGvg"
self.test_did_peer_4_short_a = (
"did:peer:4zQmV3Hf1TT4Xn73MBVf2NAWdMwrzUabpEvwtV3RoZc17Vxr"
)
self.test_did_peer_4_b = "did:peer:4zQmQ4dEtoGcivpiH6gtWwhWJY2ENVWuZifb62uzR76HGPPw:z7p4QX8zEXt2sMjv1Tqq8Lv8Nx8oGo2uRczBe21vyfMhQzsWDnwGmjriYfUX75WDq622czcdHjWGhh2VTbzKhLXUjY8Ma7g64dKAVcy8SaxN5QVdjwpXgD7htKCgCjah8jHEzyBZFrtdfTHiVXfSUz1BiURQf1Z3NfxW5cWYsvDJVvQzVmdHb8ekzCnvxCqL2UV1v9SBb1DsU66N3PCp9HVpSrqUJQyFU2Ddc8bb6u8SJfBU1nyCkNMgfA1zAyKnSBrzZWyyNzAm9oBV36qjC1Qjfcpq4FBnGr7foh5sLXppBwu2ES8U2nxdGrQzAbN47DKBoKJqPVxNh5tTuBdYjDGt7PcvZQjHQGNXXuhJctM5besZci2saGefCHzoZ87vSsFuKq6oXEsW512eadiNZWjHSdG9J4ToMEMK9WT66vGGLFdZszB3xhdFqEDnAMcpnoFUL5WN243aH6492jPC2Zjdi1BvHC1J8bUuvyihAKXF3WmFz7gJWmh6MrTEWNqb17K6tqbyXjFmfnS2RbAi8xBFj3sSsXkSs6TRTXAZD9DenYaQq4RMa2Kqh6VKGvkXAjVHKcPh9Ncpt6rU9ZYttNHbDJFgahwB8KisVBK8FBpG"
self.test_did_peer_4_short_b = (
"did:peer:4zQmQ4dEtoGcivpiH6gtWwhWJY2ENVWuZifb62uzR76HGPPw"
)

self.test_conn_record = ConnRecord(
my_did=self.test_did,
their_did=self.test_target_did,
Expand All @@ -39,9 +48,7 @@ async def test_get_enums(self):
assert ConnRecord.Role.get("Larry") is None
assert ConnRecord.State.get("a suffusion of yellow") is None

assert (
ConnRecord.Role.get(ConnRecord.Role.REQUESTER) is ConnRecord.Role.REQUESTER
)
assert ConnRecord.Role.get(ConnRecord.Role.REQUESTER) is ConnRecord.Role.REQUESTER

assert (
ConnRecord.State.get(ConnRecord.State.RESPONSE) is ConnRecord.State.RESPONSE
Expand Down Expand Up @@ -133,6 +140,71 @@ async def test_retrieve_by_did(self):
)
assert result == record

async def test_retrieve_by_did_peer_4_by_long(self):
record = ConnRecord(
my_did=self.test_did,
their_did=self.test_did_peer_4_a,
their_role=ConnRecord.Role.RESPONDER.rfc23,
state=ConnRecord.State.COMPLETED.rfc23,
)
rec_id = await record.save(self.session)
result = await ConnRecord.retrieve_by_did_peer_4(
session=self.session,
my_did=self.test_did,
their_did_long=self.test_did_peer_4_a,
their_role=ConnRecord.Role.RESPONDER.rfc160,
)
assert result == record

async def test_retrieve_by_did_peer_4_by_short(self):
record = ConnRecord(
my_did=self.test_did,
their_did=self.test_did_peer_4_short_b,
their_role=ConnRecord.Role.RESPONDER.rfc23,
state=ConnRecord.State.COMPLETED.rfc23,
)
await record.save(self.session)
result = await ConnRecord.retrieve_by_did_peer_4(
session=self.session,
my_did=self.test_did,
their_did_short=self.test_did_peer_4_short_b,
their_role=ConnRecord.Role.RESPONDER.rfc160,
)
assert result == record

async def test_retrieve_by_did_peer_4_by_either(self):
record_short = ConnRecord(
my_did=self.test_did,
their_did=self.test_did_peer_4_short_a,
their_role=ConnRecord.Role.RESPONDER.rfc23,
state=ConnRecord.State.COMPLETED.rfc23,
)
await record_short.save(self.session)
record_long = ConnRecord(
my_did=self.test_did,
their_did=self.test_did_peer_4_b,
their_role=ConnRecord.Role.RESPONDER.rfc23,
state=ConnRecord.State.COMPLETED.rfc23,
)
await record_long.save(self.session)

result = await ConnRecord.retrieve_by_did_peer_4(
session=self.session,
my_did=self.test_did,
their_did_short=self.test_did_peer_4_short_a,
their_did_long=self.test_did_peer_4_a,
their_role=ConnRecord.Role.RESPONDER.rfc160,
)
assert result == record_short
result = await ConnRecord.retrieve_by_did_peer_4(
session=self.session,
my_did=self.test_did,
their_did_short=self.test_did_peer_4_short_b,
their_did_long=self.test_did_peer_4_b,
their_role=ConnRecord.Role.RESPONDER.rfc160,
)
assert result == record_long

async def test_from_storage_with_initiator_old(self):
record = ConnRecord(my_did=self.test_did, state=ConnRecord.State.COMPLETED)
ser = record.serialize()
Expand Down Expand Up @@ -300,9 +372,7 @@ async def test_attach_retrieve_request(self):
connection_id = await record.save(self.session)

req = ConnectionRequest(
connection=ConnectionDetail(
did=self.test_did, did_doc=DIDDoc(self.test_did)
),
connection=ConnectionDetail(did=self.test_did, did_doc=DIDDoc(self.test_did)),
label="abc123",
)
await record.attach_request(self.session, req)
Expand All @@ -317,9 +387,7 @@ async def test_attach_request_abstain_on_alien_deco(self):
connection_id = await record.save(self.session)

req = ConnectionRequest(
connection=ConnectionDetail(
did=self.test_did, did_doc=DIDDoc(self.test_did)
),
connection=ConnectionDetail(did=self.test_did, did_doc=DIDDoc(self.test_did)),
label="abc123",
)
ser = req.serialize()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ async def test_called(self):
await handler.handle(request_context, responder)

mock_tran_mgr.return_value.set_transaction_their_job.assert_called_once_with(
request_context.message, request_context.message_receipt
request_context.message, request_context.connection_record
)
assert not responder.messages

Expand All @@ -48,6 +48,6 @@ async def test_called_x(self):
await handler.handle(request_context, responder)

mock_tran_mgr.return_value.set_transaction_their_job.assert_called_once_with(
request_context.message, request_context.message_receipt
request_context.message, request_context.connection_record
)
assert not responder.messages
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,12 @@ async def handle(self, context: RequestContext, responder: BaseResponder):

if not context.connection_ready:
raise HandlerException("No connection established")
assert context.connection_record

mgr = TransactionManager(context.profile)
try:
await mgr.set_transaction_their_job(context.message, context.message_receipt)
await mgr.set_transaction_their_job(
context.message, context.connection_record
)
except TransactionManagerError:
self._logger.exception("Error receiving transaction jobs")
20 changes: 5 additions & 15 deletions aries_cloudagent/protocols/endorse_transaction/v1_0/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
notify_revocation_reg_endorsed_event,
)
from ....storage.error import StorageError, StorageNotFoundError
from ....transport.inbound.receipt import MessageReceipt
from ....wallet.base import BaseWallet
from ....wallet.util import notify_endorse_did_attrib_event, notify_endorse_did_event
from .messages.cancel_transaction import CancelTransaction
Expand Down Expand Up @@ -310,9 +309,7 @@ async def create_endorse_response(
)
# we don't have an endorsed transaction so just return did meta-data
ledger_response = {
"result": {
"txn": {"type": "1", "data": {"dest": meta_data["did"]}}
},
"result": {"txn": {"type": "1", "data": {"dest": meta_data["did"]}}},
"meta_data": meta_data,
}
endorsed_msg = json.dumps(ledger_response)
Expand Down Expand Up @@ -430,9 +427,7 @@ async def complete_transaction(

# if we are the author, we need to write the endorsed ledger transaction ...
# ... EXCEPT for DID transactions, which the endorser will write
if (not endorser) and (
txn_goal_code != TransactionRecord.WRITE_DID_TRANSACTION
):
if (not endorser) and (txn_goal_code != TransactionRecord.WRITE_DID_TRANSACTION):
ledger = self.profile.inject(BaseLedger)
if not ledger:
raise TransactionManagerError("No ledger available")
Expand Down Expand Up @@ -772,20 +767,17 @@ async def set_transaction_my_job(self, record: ConnRecord, transaction_my_job: s
return tx_job_to_send

async def set_transaction_their_job(
self, tx_job_received: TransactionJobToSend, receipt: MessageReceipt
self, tx_job_received: TransactionJobToSend, connection: ConnRecord
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was looking at this for a bit and didn't understand why it used MessageReceipt here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed; my impression is that it was just a question of the implementer being unfamiliar with the request context or something along those lines.

):
"""Set transaction_their_job.

Args:
tx_job_received: The transaction job that is received from the other agent
receipt: The Message Receipt Object
connection: connection to set metadata on
"""

try:
async with self._profile.session() as session:
connection = await ConnRecord.retrieve_by_did(
session, receipt.sender_did, receipt.recipient_did
)
value = await connection.metadata_get(session, "transaction_jobs")
if value:
value["transaction_their_job"] = tx_job_received.job
Expand Down Expand Up @@ -893,9 +885,7 @@ async def endorsed_txn_post_processing(
elif ledger_response["result"]["txn"]["type"] == "114":
# revocation entry transaction
rev_reg_id = ledger_response["result"]["txn"]["data"]["revocRegDefId"]
revoked = ledger_response["result"]["txn"]["data"]["value"].get(
"revoked", []
)
revoked = ledger_response["result"]["txn"]["data"]["value"].get("revoked", [])
meta_data["context"]["rev_reg_id"] = rev_reg_id
if is_anoncreds:
await AnonCredsRevocation(self._profile).finish_revocation_list(
Expand Down
Loading
Loading