Skip to content

Commit

Permalink
Merge pull request #3112 from dbluhm/fix/did-peer-4-conn-lookup
Browse files Browse the repository at this point in the history
fix: multiuse invites with did peer 4
  • Loading branch information
dbluhm authored Jul 22, 2024
2 parents 9f0a9d2 + c428b23 commit e07066c
Show file tree
Hide file tree
Showing 8 changed files with 162 additions and 94 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,3 @@ repos:
# Run the formatter
- id: ruff-format
stages: [commit]
args: [--fix, --exit-non-zero-on-fix, --formatter]
40 changes: 18 additions & 22 deletions aries_cloudagent/connections/base_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,6 @@ async def find_did_for_key(self, key: str) -> str:
storage: BaseStorage = session.inject(BaseStorage)
record = await storage.find_record(self.RECORD_TYPE_DID_KEY, {"key": key})
ret_did = record.tags["did"]
if ret_did.startswith("did:peer:4"):
ret_did = self.long_did_peer_to_short(ret_did)
return ret_did

async def remove_keys_for_did(self, did: str):
Expand All @@ -452,9 +450,7 @@ async def resolve_didcomm_services(
doc_dict: dict = await resolver.resolve(self._profile, did, service_accept)
doc: ResolvedDocument = pydid.deserialize_document(doc_dict, strict=True)
except ResolverError as error:
raise BaseConnectionManagerError(
"Failed to resolve DID services"
) from error
raise BaseConnectionManagerError("Failed to resolve DID services") from error

if not doc.service:
raise BaseConnectionManagerError(
Expand Down Expand Up @@ -523,10 +519,7 @@ async def resolve_invitation(

return (
endpoint,
[
self._extract_key_material_in_base58_format(key)
for key in recipient_keys
],
[self._extract_key_material_in_base58_format(key) for key in recipient_keys],
[self._extract_key_material_in_base58_format(key) for key in routing_keys],
)

Expand Down Expand Up @@ -800,9 +793,7 @@ async def get_connection_targets(
async with cache.acquire(cache_key) as entry:
if entry.result:
self._logger.debug("Connection targets retrieved from cache")
targets = [
ConnectionTarget.deserialize(row) for row in entry.result
]
targets = [ConnectionTarget.deserialize(row) for row in entry.result]
else:
if not connection:
async with self._profile.session() as session:
Expand All @@ -817,9 +808,7 @@ async def get_connection_targets(
# Otherwise, a replica that participated early in exchange
# may have bad data set in cache.
self._logger.debug("Caching connection targets")
await entry.set_result(
[row.serialize() for row in targets], 3600
)
await entry.set_result([row.serialize() for row in targets], 3600)
else:
self._logger.debug(
"Not caching connection targets for connection in "
Expand Down Expand Up @@ -878,12 +867,8 @@ def diddoc_connection_targets(
did=doc.did,
endpoint=service.endpoint,
label=their_label,
recipient_keys=[
key.value for key in (service.recip_keys or ())
],
routing_keys=[
key.value for key in (service.routing_keys or ())
],
recipient_keys=[key.value for key in (service.recip_keys or ())],
routing_keys=[key.value for key in (service.routing_keys or ())],
sender_key=sender_verkey,
)
)
Expand Down Expand Up @@ -920,7 +905,18 @@ async def find_connection(
"""
connection = None
if their_did:
if their_did and their_did.startswith("did:peer:4"):
# did:peer:4 always recorded as long
long = their_did
short = self.long_did_peer_to_short(their_did)
try:
async with self._profile.session() as session:
connection = await ConnRecord.retrieve_by_did_peer_4(
session, long, short, my_did
)
except StorageNotFoundError:
pass
elif their_did:
try:
async with self._profile.session() as session:
connection = await ConnRecord.retrieve_by_did(
Expand Down
46 changes: 42 additions & 4 deletions aries_cloudagent/connections/models/conn_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ def __init__(
self.their_role = (
ConnRecord.Role.get(their_role).rfc160
if isinstance(their_role, str)
else None if their_role is None else their_role.rfc160
else None
if their_role is None
else their_role.rfc160
)
self.invitation_key = invitation_key
self.invitation_msg_id = invitation_msg_id
Expand Down Expand Up @@ -293,6 +295,44 @@ async def retrieve_by_did(

return await cls.retrieve_by_tag_filter(session, tag_filter, post_filter)

@classmethod
async def retrieve_by_did_peer_4(
cls,
session: ProfileSession,
their_did_long: Optional[str] = None,
their_did_short: Optional[str] = None,
my_did: Optional[str] = None,
their_role: Optional[str] = None,
) -> "ConnRecord":
"""Retrieve a connection record by target DID.
Args:
session: The active profile session
their_did_long: The target DID to filter by, in long form
their_did_short: The target DID to filter by, in short form
my_did: One of our DIDs to filter by
my_role: Filter connections by their role
their_role: Filter connections by their role
"""
tag_filter = {}
if their_did_long and their_did_short:
tag_filter["$or"] = [
{"their_did": their_did_long},
{"their_did": their_did_short},
]
elif their_did_short:
tag_filter["their_did"] = their_did_short
elif their_did_long:
tag_filter["their_did"] = their_did_long
if my_did:
tag_filter["my_did"] = my_did

post_filter = {}
if their_role:
post_filter["their_role"] = cls.Role.get(their_role).rfc160

return await cls.retrieve_by_tag_filter(session, tag_filter, post_filter)

@classmethod
async def retrieve_by_invitation_key(
cls, session: ProfileSession, invitation_key: str, their_role: str = None
Expand Down Expand Up @@ -375,9 +415,7 @@ async def retrieve_by_request_id(
return await cls.retrieve_by_tag_filter(session, tag_filter)

@classmethod
async def retrieve_by_alias(
cls, session: ProfileSession, alias: str
) -> "ConnRecord":
async def retrieve_by_alias(cls, session: ProfileSession, alias: str) -> "ConnRecord":
"""Retrieve a connection record from an alias.
Args:
Expand Down
86 changes: 77 additions & 9 deletions aries_cloudagent/connections/models/tests/test_conn_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ def setUp(self):
self.test_target_did = "GbuDUYXaUZRfHD2jeDuQuP"
self.test_target_verkey = "9WCgWKUaAJj3VWxxtzvvMQN3AoFxoBtBDo9ntwJnVVCC"

self.test_did_peer_4_a = "did:peer:4zQmV3Hf1TT4Xn73MBVf2NAWdMwrzUabpEvwtV3RoZc17Vxr:z2pfttj3xn6tJ7wpHV9ZSwpQVMNtHC7EtM36r1mC5fDpZ25882Yitk21QPbqzuefKPrbFsexWmQtE78vYWweckXtKeu5BhuFDvCjMUf8SC5z7cMPvp8SCdcbqWnHxygjBH9zAAs9myGRnZYXuAkq6CfBdn6ZiNmdRf65TdVfE3cYfS4jNzVZDs1abwytn4jdFJ2fwVegPB3vLY8XxeUEx12a4rtjkqMhs6zBQbJvc4PVUM9rvMbPM2QeXDy7ovkkHaKLUbNUxjQrcQeiR8MTLe1iaVtUv6RpBf4z7ioqfa4VDRmAZT7isVM3NvENUceeUfDZoFbM8PZqGkCbFvfoKiK3SrmTsvPtpXaBAfR4z7w18cFjsvvLBNMZbPnARn4oZijCkYwgaNmAUthgDP4XBFetdUo8728w25FUwTWjAPc1BdSSWPWMRKwCqyAP1Q1hM8dU6otT27MQaQ1rozKncn3U48CXEi2Ef26EDBrSozEWR273ancFojNXBbZVghZG5b6xdypjQir9PgTF94dsygtu47hNxQweVKLUM1p9umqHLhjvLhpS1aGQkGZNnKUHjLDHdToigo15F7TAf8RfMaducHBThFzEp9TUJmiZFTUYQ1uaBgSPMSaWnvTfUoFmLoGbdrWj1vVEsRrARq37u1SJGLqBx7FM2SUd8nxPsChP5jY8ka8F8r7j8qZLHZqvUXbynPUsViwwdFFk8SCsBWfiQgvq7sRiTdLnYv3H5DSwA1uW2GNYXGgkT9aJza4Sk1gvag5iAbQZgxbU594enjVSTjiWsFw2oYQ75JJwiSEgsP2rhpGsNhXxfECNLUtb7FQbDQPtUvLHCJATf7QXJEoWjpfAywmB6NyQcXfskco6FKJNNHeZBnST6U1meH98Ku66vha1k8hAc72iBhXQBnWUjaGRyzELsh2LkBH2UNwW9TuFhxz3SKtL5pGShVQ5XGQhmdrkWP68d6h7c1JqsfogcDBnmWS4VSbJwgtsPNTSsTHGX8hpGvg"
self.test_did_peer_4_short_a = (
"did:peer:4zQmV3Hf1TT4Xn73MBVf2NAWdMwrzUabpEvwtV3RoZc17Vxr"
)
self.test_did_peer_4_b = "did:peer:4zQmQ4dEtoGcivpiH6gtWwhWJY2ENVWuZifb62uzR76HGPPw:z7p4QX8zEXt2sMjv1Tqq8Lv8Nx8oGo2uRczBe21vyfMhQzsWDnwGmjriYfUX75WDq622czcdHjWGhh2VTbzKhLXUjY8Ma7g64dKAVcy8SaxN5QVdjwpXgD7htKCgCjah8jHEzyBZFrtdfTHiVXfSUz1BiURQf1Z3NfxW5cWYsvDJVvQzVmdHb8ekzCnvxCqL2UV1v9SBb1DsU66N3PCp9HVpSrqUJQyFU2Ddc8bb6u8SJfBU1nyCkNMgfA1zAyKnSBrzZWyyNzAm9oBV36qjC1Qjfcpq4FBnGr7foh5sLXppBwu2ES8U2nxdGrQzAbN47DKBoKJqPVxNh5tTuBdYjDGt7PcvZQjHQGNXXuhJctM5besZci2saGefCHzoZ87vSsFuKq6oXEsW512eadiNZWjHSdG9J4ToMEMK9WT66vGGLFdZszB3xhdFqEDnAMcpnoFUL5WN243aH6492jPC2Zjdi1BvHC1J8bUuvyihAKXF3WmFz7gJWmh6MrTEWNqb17K6tqbyXjFmfnS2RbAi8xBFj3sSsXkSs6TRTXAZD9DenYaQq4RMa2Kqh6VKGvkXAjVHKcPh9Ncpt6rU9ZYttNHbDJFgahwB8KisVBK8FBpG"
self.test_did_peer_4_short_b = (
"did:peer:4zQmQ4dEtoGcivpiH6gtWwhWJY2ENVWuZifb62uzR76HGPPw"
)

self.test_conn_record = ConnRecord(
my_did=self.test_did,
their_did=self.test_target_did,
Expand All @@ -39,9 +48,7 @@ async def test_get_enums(self):
assert ConnRecord.Role.get("Larry") is None
assert ConnRecord.State.get("a suffusion of yellow") is None

assert (
ConnRecord.Role.get(ConnRecord.Role.REQUESTER) is ConnRecord.Role.REQUESTER
)
assert ConnRecord.Role.get(ConnRecord.Role.REQUESTER) is ConnRecord.Role.REQUESTER

assert (
ConnRecord.State.get(ConnRecord.State.RESPONSE) is ConnRecord.State.RESPONSE
Expand Down Expand Up @@ -133,6 +140,71 @@ async def test_retrieve_by_did(self):
)
assert result == record

async def test_retrieve_by_did_peer_4_by_long(self):
record = ConnRecord(
my_did=self.test_did,
their_did=self.test_did_peer_4_a,
their_role=ConnRecord.Role.RESPONDER.rfc23,
state=ConnRecord.State.COMPLETED.rfc23,
)
rec_id = await record.save(self.session)
result = await ConnRecord.retrieve_by_did_peer_4(
session=self.session,
my_did=self.test_did,
their_did_long=self.test_did_peer_4_a,
their_role=ConnRecord.Role.RESPONDER.rfc160,
)
assert result == record

async def test_retrieve_by_did_peer_4_by_short(self):
record = ConnRecord(
my_did=self.test_did,
their_did=self.test_did_peer_4_short_b,
their_role=ConnRecord.Role.RESPONDER.rfc23,
state=ConnRecord.State.COMPLETED.rfc23,
)
await record.save(self.session)
result = await ConnRecord.retrieve_by_did_peer_4(
session=self.session,
my_did=self.test_did,
their_did_short=self.test_did_peer_4_short_b,
their_role=ConnRecord.Role.RESPONDER.rfc160,
)
assert result == record

async def test_retrieve_by_did_peer_4_by_either(self):
record_short = ConnRecord(
my_did=self.test_did,
their_did=self.test_did_peer_4_short_a,
their_role=ConnRecord.Role.RESPONDER.rfc23,
state=ConnRecord.State.COMPLETED.rfc23,
)
await record_short.save(self.session)
record_long = ConnRecord(
my_did=self.test_did,
their_did=self.test_did_peer_4_b,
their_role=ConnRecord.Role.RESPONDER.rfc23,
state=ConnRecord.State.COMPLETED.rfc23,
)
await record_long.save(self.session)

result = await ConnRecord.retrieve_by_did_peer_4(
session=self.session,
my_did=self.test_did,
their_did_short=self.test_did_peer_4_short_a,
their_did_long=self.test_did_peer_4_a,
their_role=ConnRecord.Role.RESPONDER.rfc160,
)
assert result == record_short
result = await ConnRecord.retrieve_by_did_peer_4(
session=self.session,
my_did=self.test_did,
their_did_short=self.test_did_peer_4_short_b,
their_did_long=self.test_did_peer_4_b,
their_role=ConnRecord.Role.RESPONDER.rfc160,
)
assert result == record_long

async def test_from_storage_with_initiator_old(self):
record = ConnRecord(my_did=self.test_did, state=ConnRecord.State.COMPLETED)
ser = record.serialize()
Expand Down Expand Up @@ -300,9 +372,7 @@ async def test_attach_retrieve_request(self):
connection_id = await record.save(self.session)

req = ConnectionRequest(
connection=ConnectionDetail(
did=self.test_did, did_doc=DIDDoc(self.test_did)
),
connection=ConnectionDetail(did=self.test_did, did_doc=DIDDoc(self.test_did)),
label="abc123",
)
await record.attach_request(self.session, req)
Expand All @@ -317,9 +387,7 @@ async def test_attach_request_abstain_on_alien_deco(self):
connection_id = await record.save(self.session)

req = ConnectionRequest(
connection=ConnectionDetail(
did=self.test_did, did_doc=DIDDoc(self.test_did)
),
connection=ConnectionDetail(did=self.test_did, did_doc=DIDDoc(self.test_did)),
label="abc123",
)
ser = req.serialize()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ async def test_called(self):
await handler.handle(request_context, responder)

mock_tran_mgr.return_value.set_transaction_their_job.assert_called_once_with(
request_context.message, request_context.message_receipt
request_context.message, request_context.connection_record
)
assert not responder.messages

Expand All @@ -48,6 +48,6 @@ async def test_called_x(self):
await handler.handle(request_context, responder)

mock_tran_mgr.return_value.set_transaction_their_job.assert_called_once_with(
request_context.message, request_context.message_receipt
request_context.message, request_context.connection_record
)
assert not responder.messages
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,12 @@ async def handle(self, context: RequestContext, responder: BaseResponder):

if not context.connection_ready:
raise HandlerException("No connection established")
assert context.connection_record

mgr = TransactionManager(context.profile)
try:
await mgr.set_transaction_their_job(context.message, context.message_receipt)
await mgr.set_transaction_their_job(
context.message, context.connection_record
)
except TransactionManagerError:
self._logger.exception("Error receiving transaction jobs")
20 changes: 5 additions & 15 deletions aries_cloudagent/protocols/endorse_transaction/v1_0/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
notify_revocation_reg_endorsed_event,
)
from ....storage.error import StorageError, StorageNotFoundError
from ....transport.inbound.receipt import MessageReceipt
from ....wallet.base import BaseWallet
from ....wallet.util import notify_endorse_did_attrib_event, notify_endorse_did_event
from .messages.cancel_transaction import CancelTransaction
Expand Down Expand Up @@ -310,9 +309,7 @@ async def create_endorse_response(
)
# we don't have an endorsed transaction so just return did meta-data
ledger_response = {
"result": {
"txn": {"type": "1", "data": {"dest": meta_data["did"]}}
},
"result": {"txn": {"type": "1", "data": {"dest": meta_data["did"]}}},
"meta_data": meta_data,
}
endorsed_msg = json.dumps(ledger_response)
Expand Down Expand Up @@ -430,9 +427,7 @@ async def complete_transaction(

# if we are the author, we need to write the endorsed ledger transaction ...
# ... EXCEPT for DID transactions, which the endorser will write
if (not endorser) and (
txn_goal_code != TransactionRecord.WRITE_DID_TRANSACTION
):
if (not endorser) and (txn_goal_code != TransactionRecord.WRITE_DID_TRANSACTION):
ledger = self.profile.inject(BaseLedger)
if not ledger:
raise TransactionManagerError("No ledger available")
Expand Down Expand Up @@ -772,20 +767,17 @@ async def set_transaction_my_job(self, record: ConnRecord, transaction_my_job: s
return tx_job_to_send

async def set_transaction_their_job(
self, tx_job_received: TransactionJobToSend, receipt: MessageReceipt
self, tx_job_received: TransactionJobToSend, connection: ConnRecord
):
"""Set transaction_their_job.
Args:
tx_job_received: The transaction job that is received from the other agent
receipt: The Message Receipt Object
connection: connection to set metadata on
"""

try:
async with self._profile.session() as session:
connection = await ConnRecord.retrieve_by_did(
session, receipt.sender_did, receipt.recipient_did
)
value = await connection.metadata_get(session, "transaction_jobs")
if value:
value["transaction_their_job"] = tx_job_received.job
Expand Down Expand Up @@ -893,9 +885,7 @@ async def endorsed_txn_post_processing(
elif ledger_response["result"]["txn"]["type"] == "114":
# revocation entry transaction
rev_reg_id = ledger_response["result"]["txn"]["data"]["revocRegDefId"]
revoked = ledger_response["result"]["txn"]["data"]["value"].get(
"revoked", []
)
revoked = ledger_response["result"]["txn"]["data"]["value"].get("revoked", [])
meta_data["context"]["rev_reg_id"] = rev_reg_id
if is_anoncreds:
await AnonCredsRevocation(self._profile).finish_revocation_list(
Expand Down
Loading

0 comments on commit e07066c

Please sign in to comment.