Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: multiuse invites with did peer 4 #3112

Merged
merged 6 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,3 @@ repos:
# Run the formatter
- id: ruff-format
stages: [commit]
args: [--fix, --exit-non-zero-on-fix, --formatter]
40 changes: 18 additions & 22 deletions aries_cloudagent/connections/base_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,6 @@ async def find_did_for_key(self, key: str) -> str:
storage: BaseStorage = session.inject(BaseStorage)
record = await storage.find_record(self.RECORD_TYPE_DID_KEY, {"key": key})
ret_did = record.tags["did"]
if ret_did.startswith("did:peer:4"):
ret_did = self.long_did_peer_to_short(ret_did)
return ret_did

async def remove_keys_for_did(self, did: str):
Expand All @@ -452,9 +450,7 @@ async def resolve_didcomm_services(
doc_dict: dict = await resolver.resolve(self._profile, did, service_accept)
doc: ResolvedDocument = pydid.deserialize_document(doc_dict, strict=True)
except ResolverError as error:
raise BaseConnectionManagerError(
"Failed to resolve DID services"
) from error
raise BaseConnectionManagerError("Failed to resolve DID services") from error

if not doc.service:
raise BaseConnectionManagerError(
Expand Down Expand Up @@ -523,10 +519,7 @@ async def resolve_invitation(

return (
endpoint,
[
self._extract_key_material_in_base58_format(key)
for key in recipient_keys
],
[self._extract_key_material_in_base58_format(key) for key in recipient_keys],
[self._extract_key_material_in_base58_format(key) for key in routing_keys],
)

Expand Down Expand Up @@ -800,9 +793,7 @@ async def get_connection_targets(
async with cache.acquire(cache_key) as entry:
if entry.result:
self._logger.debug("Connection targets retrieved from cache")
targets = [
ConnectionTarget.deserialize(row) for row in entry.result
]
targets = [ConnectionTarget.deserialize(row) for row in entry.result]
else:
if not connection:
async with self._profile.session() as session:
Expand All @@ -817,9 +808,7 @@ async def get_connection_targets(
# Otherwise, a replica that participated early in exchange
# may have bad data set in cache.
self._logger.debug("Caching connection targets")
await entry.set_result(
[row.serialize() for row in targets], 3600
)
await entry.set_result([row.serialize() for row in targets], 3600)
else:
self._logger.debug(
"Not caching connection targets for connection in "
Expand Down Expand Up @@ -878,12 +867,8 @@ def diddoc_connection_targets(
did=doc.did,
endpoint=service.endpoint,
label=their_label,
recipient_keys=[
key.value for key in (service.recip_keys or ())
],
routing_keys=[
key.value for key in (service.routing_keys or ())
],
recipient_keys=[key.value for key in (service.recip_keys or ())],
routing_keys=[key.value for key in (service.routing_keys or ())],
sender_key=sender_verkey,
)
)
Expand Down Expand Up @@ -920,7 +905,18 @@ async def find_connection(

"""
connection = None
if their_did:
if their_did and their_did.startswith("did:peer:4"):
# did:peer:4 always recorded as long
long = their_did
short = self.long_did_peer_to_short(their_did)
try:
async with self._profile.session() as session:
connection = await ConnRecord.retrieve_by_did_peer_4(
session, long, short, my_did
)
except StorageNotFoundError:
pass
elif their_did:
try:
async with self._profile.session() as session:
connection = await ConnRecord.retrieve_by_did(
Expand Down
46 changes: 42 additions & 4 deletions aries_cloudagent/connections/models/conn_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ def __init__(
self.their_role = (
ConnRecord.Role.get(their_role).rfc160
if isinstance(their_role, str)
else None if their_role is None else their_role.rfc160
else None
if their_role is None
else their_role.rfc160
)
self.invitation_key = invitation_key
self.invitation_msg_id = invitation_msg_id
Expand Down Expand Up @@ -293,6 +295,44 @@ async def retrieve_by_did(

return await cls.retrieve_by_tag_filter(session, tag_filter, post_filter)

@classmethod
async def retrieve_by_did_peer_4(
cls,
session: ProfileSession,
their_did_long: Optional[str] = None,
their_did_short: Optional[str] = None,
my_did: Optional[str] = None,
their_role: Optional[str] = None,
) -> "ConnRecord":
"""Retrieve a connection record by target DID.

Args:
session: The active profile session
their_did_long: The target DID to filter by, in long form
their_did_short: The target DID to filter by, in short form
my_did: One of our DIDs to filter by
my_role: Filter connections by their role
their_role: Filter connections by their role
"""
tag_filter = {}
if their_did_long and their_did_short:
tag_filter["$or"] = [
{"their_did": their_did_long},
{"their_did": their_did_short},
]
elif their_did_short:
tag_filter["their_did"] = their_did_short
elif their_did_long:
tag_filter["their_did"] = their_did_long
if my_did:
tag_filter["my_did"] = my_did

post_filter = {}
if their_role:
post_filter["their_role"] = cls.Role.get(their_role).rfc160

return await cls.retrieve_by_tag_filter(session, tag_filter, post_filter)

@classmethod
async def retrieve_by_invitation_key(
cls, session: ProfileSession, invitation_key: str, their_role: str = None
Expand Down Expand Up @@ -375,9 +415,7 @@ async def retrieve_by_request_id(
return await cls.retrieve_by_tag_filter(session, tag_filter)

@classmethod
async def retrieve_by_alias(
cls, session: ProfileSession, alias: str
) -> "ConnRecord":
async def retrieve_by_alias(cls, session: ProfileSession, alias: str) -> "ConnRecord":
"""Retrieve a connection record from an alias.

Args:
Expand Down
86 changes: 77 additions & 9 deletions aries_cloudagent/connections/models/tests/test_conn_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ def setUp(self):
self.test_target_did = "GbuDUYXaUZRfHD2jeDuQuP"
self.test_target_verkey = "9WCgWKUaAJj3VWxxtzvvMQN3AoFxoBtBDo9ntwJnVVCC"

self.test_did_peer_4_a = "did:peer:4zQmV3Hf1TT4Xn73MBVf2NAWdMwrzUabpEvwtV3RoZc17Vxr:z2pfttj3xn6tJ7wpHV9ZSwpQVMNtHC7EtM36r1mC5fDpZ25882Yitk21QPbqzuefKPrbFsexWmQtE78vYWweckXtKeu5BhuFDvCjMUf8SC5z7cMPvp8SCdcbqWnHxygjBH9zAAs9myGRnZYXuAkq6CfBdn6ZiNmdRf65TdVfE3cYfS4jNzVZDs1abwytn4jdFJ2fwVegPB3vLY8XxeUEx12a4rtjkqMhs6zBQbJvc4PVUM9rvMbPM2QeXDy7ovkkHaKLUbNUxjQrcQeiR8MTLe1iaVtUv6RpBf4z7ioqfa4VDRmAZT7isVM3NvENUceeUfDZoFbM8PZqGkCbFvfoKiK3SrmTsvPtpXaBAfR4z7w18cFjsvvLBNMZbPnARn4oZijCkYwgaNmAUthgDP4XBFetdUo8728w25FUwTWjAPc1BdSSWPWMRKwCqyAP1Q1hM8dU6otT27MQaQ1rozKncn3U48CXEi2Ef26EDBrSozEWR273ancFojNXBbZVghZG5b6xdypjQir9PgTF94dsygtu47hNxQweVKLUM1p9umqHLhjvLhpS1aGQkGZNnKUHjLDHdToigo15F7TAf8RfMaducHBThFzEp9TUJmiZFTUYQ1uaBgSPMSaWnvTfUoFmLoGbdrWj1vVEsRrARq37u1SJGLqBx7FM2SUd8nxPsChP5jY8ka8F8r7j8qZLHZqvUXbynPUsViwwdFFk8SCsBWfiQgvq7sRiTdLnYv3H5DSwA1uW2GNYXGgkT9aJza4Sk1gvag5iAbQZgxbU594enjVSTjiWsFw2oYQ75JJwiSEgsP2rhpGsNhXxfECNLUtb7FQbDQPtUvLHCJATf7QXJEoWjpfAywmB6NyQcXfskco6FKJNNHeZBnST6U1meH98Ku66vha1k8hAc72iBhXQBnWUjaGRyzELsh2LkBH2UNwW9TuFhxz3SKtL5pGShVQ5XGQhmdrkWP68d6h7c1JqsfogcDBnmWS4VSbJwgtsPNTSsTHGX8hpGvg"
self.test_did_peer_4_short_a = (
"did:peer:4zQmV3Hf1TT4Xn73MBVf2NAWdMwrzUabpEvwtV3RoZc17Vxr"
)
self.test_did_peer_4_b = "did:peer:4zQmQ4dEtoGcivpiH6gtWwhWJY2ENVWuZifb62uzR76HGPPw:z7p4QX8zEXt2sMjv1Tqq8Lv8Nx8oGo2uRczBe21vyfMhQzsWDnwGmjriYfUX75WDq622czcdHjWGhh2VTbzKhLXUjY8Ma7g64dKAVcy8SaxN5QVdjwpXgD7htKCgCjah8jHEzyBZFrtdfTHiVXfSUz1BiURQf1Z3NfxW5cWYsvDJVvQzVmdHb8ekzCnvxCqL2UV1v9SBb1DsU66N3PCp9HVpSrqUJQyFU2Ddc8bb6u8SJfBU1nyCkNMgfA1zAyKnSBrzZWyyNzAm9oBV36qjC1Qjfcpq4FBnGr7foh5sLXppBwu2ES8U2nxdGrQzAbN47DKBoKJqPVxNh5tTuBdYjDGt7PcvZQjHQGNXXuhJctM5besZci2saGefCHzoZ87vSsFuKq6oXEsW512eadiNZWjHSdG9J4ToMEMK9WT66vGGLFdZszB3xhdFqEDnAMcpnoFUL5WN243aH6492jPC2Zjdi1BvHC1J8bUuvyihAKXF3WmFz7gJWmh6MrTEWNqb17K6tqbyXjFmfnS2RbAi8xBFj3sSsXkSs6TRTXAZD9DenYaQq4RMa2Kqh6VKGvkXAjVHKcPh9Ncpt6rU9ZYttNHbDJFgahwB8KisVBK8FBpG"
self.test_did_peer_4_short_b = (
"did:peer:4zQmQ4dEtoGcivpiH6gtWwhWJY2ENVWuZifb62uzR76HGPPw"
)

self.test_conn_record = ConnRecord(
my_did=self.test_did,
their_did=self.test_target_did,
Expand All @@ -39,9 +48,7 @@ async def test_get_enums(self):
assert ConnRecord.Role.get("Larry") is None
assert ConnRecord.State.get("a suffusion of yellow") is None

assert (
ConnRecord.Role.get(ConnRecord.Role.REQUESTER) is ConnRecord.Role.REQUESTER
)
assert ConnRecord.Role.get(ConnRecord.Role.REQUESTER) is ConnRecord.Role.REQUESTER

assert (
ConnRecord.State.get(ConnRecord.State.RESPONSE) is ConnRecord.State.RESPONSE
Expand Down Expand Up @@ -133,6 +140,71 @@ async def test_retrieve_by_did(self):
)
assert result == record

async def test_retrieve_by_did_peer_4_by_long(self):
record = ConnRecord(
my_did=self.test_did,
their_did=self.test_did_peer_4_a,
their_role=ConnRecord.Role.RESPONDER.rfc23,
state=ConnRecord.State.COMPLETED.rfc23,
)
rec_id = await record.save(self.session)
result = await ConnRecord.retrieve_by_did_peer_4(
session=self.session,
my_did=self.test_did,
their_did_long=self.test_did_peer_4_a,
their_role=ConnRecord.Role.RESPONDER.rfc160,
)
assert result == record

async def test_retrieve_by_did_peer_4_by_short(self):
record = ConnRecord(
my_did=self.test_did,
their_did=self.test_did_peer_4_short_b,
their_role=ConnRecord.Role.RESPONDER.rfc23,
state=ConnRecord.State.COMPLETED.rfc23,
)
await record.save(self.session)
result = await ConnRecord.retrieve_by_did_peer_4(
session=self.session,
my_did=self.test_did,
their_did_short=self.test_did_peer_4_short_b,
their_role=ConnRecord.Role.RESPONDER.rfc160,
)
assert result == record

async def test_retrieve_by_did_peer_4_by_either(self):
record_short = ConnRecord(
my_did=self.test_did,
their_did=self.test_did_peer_4_short_a,
their_role=ConnRecord.Role.RESPONDER.rfc23,
state=ConnRecord.State.COMPLETED.rfc23,
)
await record_short.save(self.session)
record_long = ConnRecord(
my_did=self.test_did,
their_did=self.test_did_peer_4_b,
their_role=ConnRecord.Role.RESPONDER.rfc23,
state=ConnRecord.State.COMPLETED.rfc23,
)
await record_long.save(self.session)

result = await ConnRecord.retrieve_by_did_peer_4(
session=self.session,
my_did=self.test_did,
their_did_short=self.test_did_peer_4_short_a,
their_did_long=self.test_did_peer_4_a,
their_role=ConnRecord.Role.RESPONDER.rfc160,
)
assert result == record_short
result = await ConnRecord.retrieve_by_did_peer_4(
session=self.session,
my_did=self.test_did,
their_did_short=self.test_did_peer_4_short_b,
their_did_long=self.test_did_peer_4_b,
their_role=ConnRecord.Role.RESPONDER.rfc160,
)
assert result == record_long

async def test_from_storage_with_initiator_old(self):
record = ConnRecord(my_did=self.test_did, state=ConnRecord.State.COMPLETED)
ser = record.serialize()
Expand Down Expand Up @@ -300,9 +372,7 @@ async def test_attach_retrieve_request(self):
connection_id = await record.save(self.session)

req = ConnectionRequest(
connection=ConnectionDetail(
did=self.test_did, did_doc=DIDDoc(self.test_did)
),
connection=ConnectionDetail(did=self.test_did, did_doc=DIDDoc(self.test_did)),
label="abc123",
)
await record.attach_request(self.session, req)
Expand All @@ -317,9 +387,7 @@ async def test_attach_request_abstain_on_alien_deco(self):
connection_id = await record.save(self.session)

req = ConnectionRequest(
connection=ConnectionDetail(
did=self.test_did, did_doc=DIDDoc(self.test_did)
),
connection=ConnectionDetail(did=self.test_did, did_doc=DIDDoc(self.test_did)),
label="abc123",
)
ser = req.serialize()
Expand Down
Loading