Skip to content

Commit

Permalink
Add upgrade to anoncreds via api
Browse files Browse the repository at this point in the history
Signed-off-by: jamshale <jamiehalebc@gmail.com>
  • Loading branch information
jamshale committed Mar 15, 2024
1 parent 974ec1d commit ac87b3c
Show file tree
Hide file tree
Showing 13 changed files with 867 additions and 11 deletions.
21 changes: 20 additions & 1 deletion aries_cloudagent/admin/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
setup_aiohttp_apispec,
validation_middleware,
)

from marshmallow import fields

from ..config.injection_context import InjectionContext
Expand All @@ -38,6 +37,7 @@
from ..utils.stats import Collector
from ..utils.task_queue import TaskQueue
from ..version import __version__
from ..wallet.upgrade_singleton import UpgradeSingleton
from .base_server import BaseAdminServer
from .error import AdminSetupError
from .request_context import AdminRequestContext
Expand All @@ -58,6 +58,8 @@
"acapy::keylist::updated": "keylist",
}

upgrade_singleton = UpgradeSingleton()


class AdminModulesSchema(OpenAPISchema):
"""Schema for the modules endpoint."""
Expand Down Expand Up @@ -205,6 +207,17 @@ async def ready_middleware(request: web.BaseRequest, handler: Coroutine):
raise web.HTTPServiceUnavailable(reason="Shutdown in progress")


@web.middleware
async def upgrade_middleware(request: web.BaseRequest, handler: Coroutine):
"""Blocking middleware for upgrades."""
context: AdminRequestContext = request["context"]

if context._profile.name in upgrade_singleton:
raise web.HTTPServiceUnavailable(reason="Upgrade in progress")

return await handler(request)


@web.middleware
async def debug_middleware(request: web.BaseRequest, handler: Coroutine):
"""Show request detail in debug log."""
Expand Down Expand Up @@ -351,6 +364,8 @@ async def check_multitenant_authorization(request: web.Request, handler):

is_multitenancy_path = path.startswith("/multitenancy")
is_server_path = path in self.server_paths or path == "/features"
# allow base wallets to trigger update through api
is_upgrade_path = path.startswith("/anoncreds/wallet/upgrade")

# subwallets are not allowed to access multitenancy routes
if authorization_header and is_multitenancy_path:
Expand Down Expand Up @@ -380,6 +395,7 @@ async def check_multitenant_authorization(request: web.Request, handler):
and not is_unprotected_path(path)
and not base_limited_access_path
and not (request.method == "OPTIONS") # CORS fix
and not is_upgrade_path
):
raise web.HTTPUnauthorized()

Expand Down Expand Up @@ -453,6 +469,9 @@ async def setup_context(request: web.Request, handler):

middlewares.append(setup_context)

# Upgrade middleware needs the context setup
middlewares.append(upgrade_middleware)

# Register validation_middleware last avoiding unauthorized validations
middlewares.append(validation_middleware)

Expand Down
34 changes: 33 additions & 1 deletion aries_cloudagent/core/conductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

import asyncio
import hashlib
import json
import logging
Expand Down Expand Up @@ -40,7 +41,9 @@
BaseMultipleLedgerManager,
MultipleLedgerManagerError,
)
from ..ledger.multiple_ledger.ledger_requests_executor import IndyLedgerRequestsExecutor
from ..ledger.multiple_ledger.ledger_requests_executor import (
IndyLedgerRequestsExecutor,
)
from ..ledger.multiple_ledger.manager_provider import MultiIndyLedgerManagerProvider
from ..messaging.responder import BaseResponder
from ..multitenant.base import BaseMultitenantManager
Expand Down Expand Up @@ -71,10 +74,15 @@
from ..transport.outbound.message import OutboundMessage
from ..transport.outbound.status import OutboundSendStatus
from ..transport.wire_format import BaseWireFormat
from ..utils.profiles import get_subwallet_profiles_from_storage
from ..utils.stats import Collector
from ..utils.task_queue import CompletedTask, TaskQueue
from ..vc.ld_proofs.document_loader import DocumentLoader
from ..version import RECORD_TYPE_ACAPY_VERSION, __version__
from ..wallet.anoncreds_upgrade import (
set_storage_type_to_anoncreds,
upgrade_wallet_to_anoncreds,
)
from ..wallet.did_info import DIDInfo
from .dispatcher import Dispatcher
from .error import StartupError
Expand Down Expand Up @@ -522,6 +530,8 @@ async def start(self) -> None:
except Exception:
LOGGER.exception("Error accepting mediation invitation")

await self.check_for_wallet_upgrades_in_progress()

# notify protcols of startup status
await self.root_profile.notify(STARTUP_EVENT_TOPIC, {})

Expand Down Expand Up @@ -823,3 +833,25 @@ async def check_for_valid_wallet_type(self, profile):
raise StartupError(
f"Wallet type config [{storage_type_from_config}] doesn't match with the wallet type in storage [{storage_type_record.value}]" # noqa: E501
)

async def _upgrade_subwallet(self, profile: Profile):
upgraded = await upgrade_wallet_to_anoncreds(profile)
if upgraded:
await set_storage_type_to_anoncreds(profile)

async def check_for_wallet_upgrades_in_progress(self):
"""Check for upgrade and upgrade if needed."""
multitenant_mgr = self.context.inject_or(BaseMultitenantManager)
if multitenant_mgr:
subwallet_profiles = await get_subwallet_profiles_from_storage(
self.root_profile
)
# TODO: await here?
await asyncio.gather(
*[self._upgrade_subwallet(profile) for profile in subwallet_profiles]
)

else:
upgraded = await upgrade_wallet_to_anoncreds(self.root_profile)
if upgraded:
await set_storage_type_to_anoncreds(self.root_profile)
39 changes: 36 additions & 3 deletions aries_cloudagent/core/tests/test_conductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ async def test_startup_version_record_exists(self):
) as mock_outbound_mgr, mock.patch.object(
test_module, "LoggingConfigurator", autospec=True
) as mock_logger, mock.patch.object(
test_module, "upgrade_wallet_to_anoncreds", return_value=False
) as mock_upgrade, mock.patch.object(
BaseStorage,
"find_record",
mock.CoroutineMock(
Expand Down Expand Up @@ -166,6 +168,7 @@ async def test_startup_version_record_exists(self):

mock_inbound_mgr.return_value.stop.assert_awaited_once_with()
mock_outbound_mgr.return_value.stop.assert_awaited_once_with()
assert mock_upgrade.called

async def test_startup_version_no_upgrade_add_record(self):
builder: ContextBuilder = StubContextBuilder(self.test_settings)
Expand All @@ -176,6 +179,8 @@ async def test_startup_version_no_upgrade_add_record(self):
) as mock_inbound_mgr, mock.patch.object(
test_module, "OutboundTransportManager", autospec=True
) as mock_outbound_mgr, mock.patch.object(
test_module, "upgrade_wallet_to_anoncreds", return_value=False
) as mock_upgrade, mock.patch.object(
BaseStorage,
"find_record",
mock.CoroutineMock(
Expand Down Expand Up @@ -213,6 +218,8 @@ async def test_startup_version_no_upgrade_add_record(self):
) as mock_inbound_mgr, mock.patch.object(
test_module, "OutboundTransportManager", autospec=True
) as mock_outbound_mgr, mock.patch.object(
test_module, "upgrade_wallet_to_anoncreds", return_value=False
) as mock_upgrade, mock.patch.object(
BaseStorage,
"find_record",
mock.CoroutineMock(
Expand Down Expand Up @@ -257,6 +264,8 @@ async def test_startup_version_force_upgrade(self):
) as mock_outbound_mgr, mock.patch.object(
test_module, "LoggingConfigurator", autospec=True
) as mock_logger, mock.patch.object(
test_module, "upgrade_wallet_to_anoncreds", return_value=False
) as mock_upgrade, mock.patch.object(
BaseStorage,
"find_record",
mock.CoroutineMock(
Expand Down Expand Up @@ -296,6 +305,8 @@ async def test_startup_version_force_upgrade(self):
) as mock_outbound_mgr, mock.patch.object(
test_module, "LoggingConfigurator", autospec=True
) as mock_logger, mock.patch.object(
test_module, "upgrade_wallet_to_anoncreds", return_value=False
) as mock_upgrade, mock.patch.object(
BaseStorage,
"find_record",
mock.CoroutineMock(
Expand Down Expand Up @@ -335,6 +346,8 @@ async def test_startup_version_force_upgrade(self):
) as mock_outbound_mgr, mock.patch.object(
test_module, "LoggingConfigurator", autospec=True
) as mock_logger, mock.patch.object(
test_module, "upgrade_wallet_to_anoncreds", return_value=False
) as mock_upgrade, mock.patch.object(
BaseStorage,
"find_record",
mock.CoroutineMock(
Expand Down Expand Up @@ -373,6 +386,8 @@ async def test_startup_version_record_not_exists(self):
) as mock_outbound_mgr, mock.patch.object(
test_module, "LoggingConfigurator", autospec=True
) as mock_logger, mock.patch.object(
test_module, "upgrade_wallet_to_anoncreds", return_value=False
) as mock_upgrade, mock.patch.object(
BaseStorage,
"find_record",
mock.CoroutineMock(
Expand Down Expand Up @@ -449,6 +464,8 @@ async def test_startup_no_public_did(self):
) as mock_outbound_mgr, mock.patch.object(
test_module, "LoggingConfigurator", autospec=True
) as mock_logger, mock.patch.object(
test_module, "upgrade_wallet_to_anoncreds", return_value=False
) as mock_upgrade, mock.patch.object(
BaseStorage,
"find_record",
mock.CoroutineMock(
Expand Down Expand Up @@ -492,6 +509,8 @@ async def test_stats(self):
) as mock_inbound_mgr, mock.patch.object(
test_module, "OutboundTransportManager", autospec=True
) as mock_outbound_mgr, mock.patch.object(
test_module, "upgrade_wallet_to_anoncreds", return_value=False
) as mock_upgrade, mock.patch.object(
test_module, "LoggingConfigurator", autospec=True
) as mock_logger:
mock_inbound_mgr.return_value.sessions = ["dummy"]
Expand Down Expand Up @@ -884,6 +903,8 @@ async def test_admin(self):
) as admin_start, mock.patch.object(
admin, "stop", autospec=True
) as admin_stop, mock.patch.object(
test_module, "upgrade_wallet_to_anoncreds", return_value=False
) as mock_upgrade, mock.patch.object(
BaseStorage,
"find_record",
mock.CoroutineMock(
Expand Down Expand Up @@ -936,6 +957,8 @@ async def test_admin_startx(self):
) as oob_mgr, mock.patch.object(
test_module, "ConnectionManager"
) as conn_mgr, mock.patch.object(
test_module, "upgrade_wallet_to_anoncreds", return_value=False
) as mock_upgrade, mock.patch.object(
BaseStorage,
"find_record",
mock.CoroutineMock(
Expand Down Expand Up @@ -992,7 +1015,9 @@ async def test_start_static(self):
),
), mock.patch.object(
test_module, "OutboundTransportManager", autospec=True
) as mock_outbound_mgr:
) as mock_outbound_mgr, mock.patch.object(
test_module, "upgrade_wallet_to_anoncreds", return_value=False
) as mock_upgrade:
mock_outbound_mgr.return_value.registered_transports = {
"test": mock.MagicMock(schemes=["http"])
}
Expand Down Expand Up @@ -1164,7 +1189,9 @@ async def test_print_invite_connection(self):
),
), mock.patch.object(
test_module, "OutboundTransportManager", autospec=True
) as mock_outbound_mgr:
) as mock_outbound_mgr, mock.patch.object(
test_module, "upgrade_wallet_to_anoncreds", return_value=False
) as mock_upgrade:
mock_outbound_mgr.return_value.registered_transports = {
"test": mock.MagicMock(schemes=["http"])
}
Expand Down Expand Up @@ -1201,6 +1228,8 @@ async def test_clear_default_mediator(self):
"MediationManager",
return_value=mock.MagicMock(clear_default_mediator=mock.CoroutineMock()),
) as mock_mgr, mock.patch.object(
test_module, "upgrade_wallet_to_anoncreds", return_value=False
) as mock_upgrade, mock.patch.object(
BaseStorage,
"find_record",
mock.CoroutineMock(
Expand Down Expand Up @@ -1252,7 +1281,9 @@ async def test_set_default_mediator(self):
mock.MagicMock(value=f"v{__version__}"),
]
),
):
), mock.patch.object(
test_module, "upgrade_wallet_to_anoncreds", return_value=False
) as mock_upgrade:
await conductor.start()
await conductor.stop()
mock_mgr.return_value.set_default_mediator_by_id.assert_called_once()
Expand All @@ -1275,6 +1306,8 @@ async def test_set_default_mediator_x(self):
"retrieve_by_id",
mock.CoroutineMock(side_effect=Exception()),
), mock.patch.object(test_module, "LOGGER") as mock_logger, mock.patch.object(
test_module, "upgrade_wallet_to_anoncreds", return_value=False
) as mock_upgrade, mock.patch.object(
BaseStorage,
"find_record",
mock.CoroutineMock(
Expand Down
8 changes: 8 additions & 0 deletions aries_cloudagent/multitenant/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
from typing import Iterable, Optional

from ..askar.profile_anon import AskarAnoncredsProfile
from ..config.injection_context import InjectionContext
from ..config.wallet import wallet_config
from ..core.profile import Profile
Expand Down Expand Up @@ -84,6 +85,13 @@ async def get_wallet_profile(
profile, _ = await wallet_config(context, provision=provision)
self._profiles.put(wallet_id, profile)

# return anoncreds profile if explicitly set as wallet type
if profile.context.settings.get("wallet.type") == "askar-anoncreds":
return AskarAnoncredsProfile(
profile.opened,
profile.context,
)

return profile

async def update_wallet(self, wallet_id: str, new_settings: dict) -> WalletRecord:
Expand Down
1 change: 1 addition & 0 deletions aries_cloudagent/storage/type.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Library version information."""

RECORD_TYPE_ACAPY_STORAGE_TYPE = "acapy_storage_type"
RECORD_TYPE_ACAPY_UPGRADING = "acapy_upgrading"
28 changes: 28 additions & 0 deletions aries_cloudagent/utils/profiles.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
"""Profile utilities."""

import json

from aiohttp import web

from ..anoncreds.error_messages import ANONCREDS_PROFILE_REQUIRED_MSG
from ..askar.profile_anon import AskarAnoncredsProfile
from ..core.profile import Profile
from ..multitenant.manager import MultitenantManager
from ..storage.base import BaseStorageSearch
from ..wallet.models.wallet_record import WalletRecord


def is_anoncreds_profile_raise_web_exception(profile: Profile) -> None:
Expand All @@ -29,3 +34,26 @@ def subwallet_type_not_same_as_base_wallet_raise_web_exception(
raise web.HTTPForbidden(
reason="Subwallet type must be the same as the base wallet type"
)


async def get_subwallet_profiles_from_storage(root_profile: Profile) -> list[Profile]:
"""Get subwallet profiles from storage."""
subwallet_profiles = []
base_storage_search = root_profile.inject(BaseStorageSearch)
search_session = base_storage_search.search_records(
type_filter=WalletRecord.RECORD_TYPE, page_size=10
)
while search_session._done is False:
wallet_storage_records = await search_session.fetch()
for wallet_storage_record in wallet_storage_records:
wallet_record = WalletRecord.from_storage(
wallet_storage_record.id,
json.loads(wallet_storage_record.value),
)
subwallet_profiles.append(
await MultitenantManager(root_profile).get_wallet_profile(
base_context=root_profile.context,
wallet_record=wallet_record,
)
)
return subwallet_profiles
Loading

0 comments on commit ac87b3c

Please sign in to comment.