From ce68dbeef81faf5a9263beb2e5dcbee616e98c5e Mon Sep 17 00:00:00 2001 From: "Davin K. Tanabe" Date: Tue, 12 May 2020 12:27:17 -0400 Subject: [PATCH] python: Introduce a new connection API that is a bit less stateful. --- python/dazl/__init__.py | 1 + python/dazl/protocols/config/__init__.py | 62 +++ python/dazl/protocols/config/access.py | 390 +++++++++++++++++ python/dazl/protocols/config/ssl.py | 55 +++ python/dazl/protocols/config/url.py | 97 +++++ python/dazl/protocols/core.py | 78 ++++ python/dazl/protocols/errors.py | 28 ++ python/dazl/protocols/ledgerapi/__init__.py | 17 + python/dazl/protocols/ledgerapi/__init__.pyi | 32 ++ python/dazl/protocols/ledgerapi/channel.py | 58 +++ python/dazl/protocols/ledgerapi/codec_aio.py | 291 +++++++++++++ python/dazl/protocols/ledgerapi/conn_aio.py | 431 +++++++++++++++++++ python/dazl/protocols/pkgloader_aio.py | 224 ++++++++++ python/dazl/protocols/stream_aio.py | 117 +++++ python/tests/unit/test_protocol_ledgerapi.py | 43 ++ 15 files changed, 1924 insertions(+) create mode 100644 python/dazl/protocols/config/__init__.py create mode 100644 python/dazl/protocols/config/access.py create mode 100644 python/dazl/protocols/config/ssl.py create mode 100644 python/dazl/protocols/config/url.py create mode 100644 python/dazl/protocols/core.py create mode 100644 python/dazl/protocols/errors.py create mode 100644 python/dazl/protocols/ledgerapi/__init__.py create mode 100644 python/dazl/protocols/ledgerapi/__init__.pyi create mode 100644 python/dazl/protocols/ledgerapi/channel.py create mode 100644 python/dazl/protocols/ledgerapi/codec_aio.py create mode 100644 python/dazl/protocols/ledgerapi/conn_aio.py create mode 100644 python/dazl/protocols/pkgloader_aio.py create mode 100644 python/dazl/protocols/stream_aio.py create mode 100644 python/tests/unit/test_protocol_ledgerapi.py diff --git a/python/dazl/__init__.py b/python/dazl/__init__.py index 55553d8c..65412692 100644 --- a/python/dazl/__init__.py +++ b/python/dazl/__init__.py @@ -14,6 +14,7 @@ from .prim import FrozenDict as frozendict from .pretty.table import write_acs +from .protocols.ledgerapi import connect, Connection try: # This method is undocumented, but is required to read large size of model files when using diff --git a/python/dazl/protocols/config/__init__.py b/python/dazl/protocols/config/__init__.py new file mode 100644 index 00000000..c88d7522 --- /dev/null +++ b/python/dazl/protocols/config/__init__.py @@ -0,0 +1,62 @@ +# Copyright (c) 2017-2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +from os import PathLike +from typing import Collection, Optional, Union + +from ...prim import Party, TimeDeltaLike +from .access import AccessConfig, create_access +from .ssl import SSLConfig +from .url import URLConfig + +__all__ = ["Config", "AccessConfig", "SSLConfig", "URLConfig"] + + +class Config: + @classmethod + def create( + cls, + url: str = None, + oauth_client_id: "Optional[str]" = None, + oauth_token: "Optional[str]" = None, + party: "Union[None, Party, Collection[Party]]" = None, + read_as: "Union[None, Party, Collection[Party]]" = None, + act_as: "Union[None, Party, Collection[Party]]" = None, + admin: "Optional[bool]" = False, + ca_file: "Optional[PathLike]" = None, + cert_file: "Optional[PathLike]" = None, + cert_key_file: "Optional[PathLike]" = None, + verify_ssl: "Optional[str]" = None, + connect_timeout: "Optional[TimeDeltaLike]" = None, + enable_http_proxy: "bool" = True, + ledger_id: "Optional[str]" = None, + application_name: "Optional[str]" = None, + ) -> "Config": + url_config = URLConfig( + url=url, + connect_timeout=connect_timeout, + enable_http_proxy=enable_http_proxy, + ) + + access_config = create_access( + party=party, + ledger_id=ledger_id, + application_name=application_name, + read_as=read_as, + act_as=act_as, + admin=admin, + oauth_token=oauth_token, + ) + + ssl_config = SSLConfig( + ca_file=ca_file, + cert_file=cert_file, + cert_key_file=cert_key_file, + verify_ssl=verify_ssl, + ) + + return cls(access_config, ssl_config, url_config) + + def __init__(self, access: "AccessConfig", ssl: "SSLConfig", url: "URLConfig"): + self.access = access + self.ssl = ssl + self.url = url diff --git a/python/dazl/protocols/config/access.py b/python/dazl/protocols/config/access.py new file mode 100644 index 00000000..91bb1cd4 --- /dev/null +++ b/python/dazl/protocols/config/access.py @@ -0,0 +1,390 @@ +# Copyright (c) 2017-2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +import base64 +import json +from collections.abc import MutableSet as MutableSetBase, Set as SetBase +from typing import ( + AbstractSet, + Any, + Collection, + Iterator, + Mapping, + MutableSet, + Optional, + Union, + overload, +) + +try: + from typing_extensions import Protocol +except ImportError: + from typing import Protocol + +from ...prim import Party + + +@overload +def create_access( + *, + read_as: "Union[None, Party, Collection[Party]]" = None, + act_as: "Union[None, Party, Collection[Party]]" = None, + party: "Optional[Party]" = None, + admin: "Optional[bool]" = None, + ledger_id: "Optional[str]" = None, + application_name: "Optional[str]" = None, +) -> "PropertyBasedAccessConfig": + ... + + +@overload +def create_access(*, oauth_token: "str" = None) -> "TokenBasedAccessConfig": + ... 
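+
+
+# A quick sketch of the two modes (illustrative only; "Alice", "Bob", and the ledger ID
+# below are placeholder values):
+#
+#     access = create_access(act_as="Alice", read_as="Bob", ledger_id="sandbox")
+#     access = create_access(oauth_token="<a JWT carrying https://daml.com/ledger-api claims>")
+#
+# Supplying oauth_token together with any of the property-style arguments raises ValueError,
+# as enforced below.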
+
+
+def create_access(
+    *,
+    party=None,
+    read_as=None,
+    act_as=None,
+    admin=None,
+    ledger_id=None,
+    application_name=None,
+    oauth_token=None,
+):
+    if oauth_token:
+        # if a token is supplied, none of the other arguments are allowed
+        if (
+            party is not None
+            or read_as is not None
+            or act_as is not None
+            or admin is not None
+            or ledger_id is not None
+            or application_name is not None
+        ):
+            raise ValueError(
+                "cannot configure access with both tokens and "
+                "party/read_as/act_as/admin/ledger_id/application_name configuration options"
+            )
+        return TokenBasedAccessConfig(oauth_token)
+    else:
+        return PropertyBasedAccessConfig(
+            party=party,
+            read_as=read_as,
+            act_as=act_as,
+            admin=admin,
+            ledger_id=ledger_id,
+            application_name=application_name,
+        )
+
+
+class TokenBasedAccessConfig:
+    """
+    Access configuration that is inherently token-based. The token can be changed at any time;
+    party rights, the application name, and the ledger ID are all derived from the token.
+    """
+
+    def __init__(self, oauth_token: str):
+        self.token = oauth_token
+
+    @property
+    def token(self) -> str:
+        """
+        The bearer token that provides authorization and authentication to a ledger.
+        """
+        return self._token
+
+    @token.setter
+    def token(self, value: str) -> None:
+        self._token = value
+        claims = decode_token(self._token)
+        self._admin = claims.get("admin", False)
+        self._ledger_id = claims.get("ledgerId", None)
+        self._application_name = claims.get("applicationId", None)
+        self._act_as = frozenset(claims.get("actAs", ()))
+        self._read_as = frozenset(claims.get("readAs", ())) | self._act_as
+
+    @property
+    def ledger_id(self) -> "Optional[str]":
+        return self._ledger_id
+
+    @property
+    def application_name(self) -> "Optional[str]":
+        return self._application_name
+
+    @property
+    def admin(self) -> bool:
+        return self._admin
+
+    @property
+    def read_as(self) -> "AbstractSet[Party]":
+        """
+        The set of parties with read rights; this always includes the act-as parties.
+        """
+        return self._read_as
+
+    @property
+    def read_only_as(self) -> "AbstractSet[Party]":
+        return self._read_as - self._act_as
+
+    @property
+    def act_as(self) -> "AbstractSet[Party]":
+        return self._act_as
+
+
+class PropertyBasedAccessConfig:
+    """
+    Access configuration that is manually specified outside of an authentication/authorization
+    framework. Suitable for local testing or when no auth server is available, and the Ledger API
+    implementation inherently trusts any caller to provide its own authentication and
+    authorization.
+    """
+
+    def __init__(
+        self,
+        party: "Optional[Party]" = None,
+        read_as: "Union[None, Party, Collection[Party]]" = None,
+        act_as: "Union[None, Party, Collection[Party]]" = None,
+        admin: "Optional[bool]" = False,
+        ledger_id: "Optional[str]" = None,
+        application_name: "Optional[str]" = None,
+    ):
+        """
+        :param party:
+            The singular party to use for reading and writing. This parameter is a convenience
+            parameter for the common case where "read as" and "act as" parties are the same,
+            and there is only one of them. If you specify this parameter, you CANNOT supply
+            ``read_as`` or ``act_as``, nor can you supply an access token. When connecting to the
+            HTTP JSON API, ``ledger_id`` must _also_ be supplied when using this parameter.
+        :param read_as:
+            A party or set of parties that can be used to read data from the ledger. In a
+            Daml-based ledger, read-as rights are implied by act-as rights, so you may choose to
+            supply only act-as parties if you wish. If you specify this parameter, you CANNOT
+            supply ``party``, nor can you supply an access token. When connecting to the HTTP
+            JSON API, ``ledger_id`` must _also_ be supplied when using this parameter.
+        :param act_as:
+            A party or set of parties that can be used to submit commands to the ledger. In a
+            Daml-based ledger, act-as rights imply read-as rights. If you specify this parameter,
+            you CANNOT supply ``party``, nor can you supply an access token. When connecting to
+            the HTTP JSON API, ``ledger_id`` must _also_ be supplied when using this parameter.
+        :param admin:
+            Whether to request admin rights on the generated token.
+        :param ledger_id:
+            The ID of the ledger to connect to. Required when connecting to the HTTP JSON API;
+            against the gRPC Ledger API, it can be discovered at connection time instead.
+        :param application_name:
+            The application ID to report to the ledger; defaults to "dazl-client".
+        """
+        self._parties = PartyRights()
+        self._parties.maybe_add(read_as, False)
+        self._parties.maybe_add(act_as, True)
+        self._parties.maybe_add(party, True)
+        self._admin = bool(admin)
+        self._ledger_id = ledger_id
+        self._application_name = application_name or "dazl-client"
+
+    @property
+    def token(self) -> str:
+        """
+        Produces an unsigned token from the configured parameters.
+        """
+        return encode_unsigned_token(
+            self.read_as, self.act_as, self.ledger_id, self.application_name, self.admin
+        )
+
+    @property
+    def ledger_id(self) -> str:
+        return self._ledger_id
+
+    @ledger_id.setter
+    def ledger_id(self, value: str) -> None:
+        self._ledger_id = value
+
+    @property
+    def application_name(self) -> str:
+        return self._application_name
+
+    @property
+    def read_as(self) -> "MutableSet[Party]":
+        """
+        Return the set of parties for which read rights are granted.
+
+        This set always includes the act_as parties. For the set of parties that can be read as
+        but NOT acted as, use :attr:`read_only_as`.
+        """
+        return self._parties
+
+    @property
+    def read_only_as(self) -> "MutableSet[Party]":
+        """
+        Return the set of parties for which read rights, but NOT act-as rights, are granted.
+        """
+        return self._parties.read_as
+
+    @property
+    def act_as(self) -> "MutableSet[Party]":
+        """
+        Return the set of parties for which act-as rights are granted. This collection can be
+        modified.
+        """
+        return self._parties.act_as
+
+    @property
+    def admin(self) -> bool:
+        return self._admin
+
+
+class AccessConfig:
+    """
+    Configuration parameters for providing access to a ledger.
+    """
+
+    @property
+    def ledger_id(self) -> str:
+        return decode_token(self.token)["ledgerId"]
+
+    @ledger_id.setter
+    def ledger_id(self, value: str):
+        token = decode_token(self.token)
+        self.token = encode_unsigned_token(
+            read_as=token["readAs"],
+            act_as=token["actAs"],
+            ledger_id=value,
+            application_id=token["applicationId"],
+        )
+
+    @property
+    def application_name(self) -> str:
+        return decode_token(self.token)["applicationId"]
+
+    @property
+    def token(self) -> str:
+        """
+        Return a JSON Web Token (JWT). If a token was _not_ supplied in the constructor or via a
+        call to the token setter, an unsigned token is generated and used instead.
+        """
+        return self._token
+
+    @token.setter
+    def token(self, value: str):
+        """
+        Set the current token. Note that tokens are NOT validated for authenticity.
+        """
+        self._token = value
+
+    @property
+    def read_as(self) -> "AbstractSet[Party]":
+        """
+        Return the set of parties for which read rights are granted.
+
+        This set always includes the act_as parties. For the set of parties that can be read as
+        but NOT acted as, use :attr:`read_only_as`.
+        """
+        claims = decode_token(self.token)
+        return frozenset((*claims["readAs"], *claims["actAs"]))
+
+    @property
+    def read_only_as(self) -> "AbstractSet[Party]":
+        """
+        Return the set of parties for which read rights are granted, AND write access is NOT
+        granted.
+        """
+        claims = decode_token(self.token)
+        a = {*claims["readAs"], *claims["actAs"]}
+        return frozenset(a.difference(claims["actAs"]))
+
+    @property
+    def act_as(self) -> "AbstractSet[Party]":
+        claims = decode_token(self.token)
+        return frozenset(claims["actAs"])
+
+
+def parties(p: "Union[None, Party, Collection[Party]]") -> "Collection[Party]":
+    if p is None:
+        return []
+    elif isinstance(p, str):
+        # Party is a string-based newtype, so a single party comes through this branch
+        return [p]
+    else:
+        return p
+
+
+DamlLedgerApiNamespace = "https://daml.com/ledger-api"
+
+
+def decode_token(token: str) -> "Mapping[str, Any]":
+    components = token.split(".", 3)
+    if len(components) != 3:
+        raise ValueError("not a JWT")
+    # JWT segments are base64-encoded without padding; restore it before decoding
+    claim_str = base64.urlsafe_b64decode(components[1] + "=" * (-len(components[1]) % 4))
+    claims = json.loads(claim_str)
+    claims_dict = claims.get(DamlLedgerApiNamespace)
+    if claims_dict is None:
+        raise ValueError(f"JWT is missing claim namespace: {DamlLedgerApiNamespace!r}")
+    return claims_dict
+
+
+def encode_unsigned_token(
+    read_as: "Collection[Party]",
+    act_as: "Collection[Party]",
+    ledger_id: str,
+    application_id: str,
+    admin: bool = True,
+) -> str:
+    header = {
+        "alg": "none",
+        "typ": "JWT",
+    }
+    payload = {
+        DamlLedgerApiNamespace: {
+            "ledgerId": ledger_id,
+            "applicationId": application_id,
+            "actAs": sorted(act_as),
+            "readAs": sorted(read_as),
+            "admin": admin,
+        }
+    }
+
+    # return a str so that callers can use this directly in "Bearer ..." headers
+    return (
+        base64.urlsafe_b64encode(json.dumps(header).encode("utf-8"))
+        + b"."
+        + base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8"))
+        + b"."
+    ).decode("ascii")
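+
+
+# Example (illustrative; "Alice" and "sandbox" are placeholder values): an unsigned token
+# produced by encode_unsigned_token() can be read back with decode_token(), with the claims
+# nested under the Daml Ledger API namespace:
+#
+#     token = encode_unsigned_token(["Alice"], ["Alice"], "sandbox", "dazl-client")
+#     claims = decode_token(token)
+#     assert claims["ledgerId"] == "sandbox"
+#     assert claims["actAs"] == ["Alice"]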
+
+
+class PartyRights(SetBase):
+    __slots__ = ("_rights", "read_as", "act_as")
+
+    def __init__(self):
+        self._rights = dict()
+        self.read_as = PartyRightsSet(self, False)
+        self.act_as = PartyRightsSet(self, True)
+
+    def maybe_add(
+        self, value: "Union[None, Party, Collection[Party]]", has_act_rights: bool
+    ) -> None:
+        if value is None:
+            return
+
+        # Party is a fake Python newtype, so isinstance checks don't work on it
+        if isinstance(value, str):
+            self.add(Party(value), has_act_rights)
+        else:
+            for party in value:
+                self.add(party, has_act_rights)
+
+    def add(self, value: "Party", has_act_rights: bool) -> None:
+        """
+        Add/replace a ``Party`` and its rights.
+        """
+        self._rights[value] = has_act_rights
+
+    def discard(self, value: "Party") -> None:
+        # discard, unlike remove, must not raise for missing values
+        self._rights.pop(value, None)
+
+    def get(self, value: "Party") -> "Optional[bool]":
+        return self._rights.get(value)
+
+    def count(self, act_as: bool) -> int:
+        return sum(1 for p, a in self._rights.items() if act_as == a)
+
+    def __contains__(self, party: object) -> bool:
+        return party in self._rights
+
+    def __len__(self) -> int:
+        return len(self._rights)
+
+    def __iter__(self) -> "Iterator[Party]":
+        return iter(sorted(self._rights))
+
+    def iter(self, act_as: bool) -> "Iterator[Party]":
+        return iter(p for p, a in sorted(self._rights.items()) if a == act_as)
+
+
+class PartyRightsSet(MutableSetBase):
+    def __init__(self, rights: "PartyRights", act_as: bool):
+        self._rights = rights
+        self._act_as = act_as
+
+    def add(self, value: "Party") -> None:
+        self._rights.add(value, self._act_as)
+
+    def discard(self, value: "Party") -> None:
+        self._rights.discard(value)
+
+    def __contains__(self, party: "Party") -> bool:
+        return self._rights.get(party) == self._act_as
+
+    def __len__(self) -> int:
+        return self._rights.count(self._act_as)
+
+    def __iter__(self) -> "Iterator[Party]":
+        return self._rights.iter(self._act_as)
diff --git a/python/dazl/protocols/config/ssl.py b/python/dazl/protocols/config/ssl.py
new file mode 100644
index 00000000..d0f62885
--- /dev/null
+++ b/python/dazl/protocols/config/ssl.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2017-2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+from os import PathLike, fspath
+from typing import Optional
+
+
+class SSLConfig:
+    """
+    Configuration parameters that affect SSL connections.
+    """
+
+    def __init__(
+        self,
+        ca_file: "Optional[PathLike]" = None,
+        cert_file: "Optional[PathLike]" = None,
+        cert_key_file: "Optional[PathLike]" = None,
+        verify_ssl: "Optional[str]" = None,
+    ):
+        # TODO: verify_ssl is currently accepted but not yet applied to the channel
+        if ca_file:
+            ca_file = fspath(ca_file)
+            # read as bytes; grpc.ssl_channel_credentials expects bytes, not str
+            with open(ca_file, "rb") as f:
+                self._root_certificates = f.read()
+        else:
+            self._root_certificates = None
+
+        if cert_file:
+            cert_file = fspath(cert_file)
+            with open(cert_file, "rb") as f:
+                self._certificate_chain = f.read()
+        else:
+            self._certificate_chain = None
+
+        if cert_key_file:
+            cert_key_file = fspath(cert_key_file)
+            with open(cert_key_file, "rb") as f:
+                self._private_key = f.read()
+        else:
+            self._private_key = None
+
+    def __bool__(self):
+        return bool(
+            self._root_certificates or self._certificate_chain or self._private_key
+        )
+
+    @property
+    def root_certificates(self) -> "Optional[bytes]":
+        return self._root_certificates
+
+    @property
+    def certificate_chain(self) -> "Optional[bytes]":
+        return self._certificate_chain
+
+    @property
+    def private_key(self) -> "Optional[bytes]":
+        return self._private_key
diff --git a/python/dazl/protocols/config/url.py b/python/dazl/protocols/config/url.py
new file mode 100644
index 00000000..4814b279
--- /dev/null
+++ b/python/dazl/protocols/config/url.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2017-2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+from datetime import timedelta
+from typing import TYPE_CHECKING, Optional
+from urllib.parse import urlparse
+
+if TYPE_CHECKING:
+    from ...prim import TimeDeltaLike
+
+__all__ = ["URLConfig"]
+
+DEFAULT_CONNECT_TIMEOUT = timedelta(seconds=30)
+
+
+class URLConfig:
+    def __init__(
+        self,
+        url: "Optional[str]" = None,
+        host: "Optional[str]" = None,
+        port: "Optional[int]" = None,
+        scheme: "Optional[str]" = None,
+        connect_timeout: "Optional[TimeDeltaLike]" = DEFAULT_CONNECT_TIMEOUT,
+        enable_http_proxy: bool = True,
+    ):
+        if url:
+            if host or port or scheme:
+                raise ValueError(
+                    "url or host/port/scheme must be specified, but not both"
+                )
+            self._url = sanitize_url(url)
+            components = urlparse(self._url, allow_fragments=False)
+            self._host = components.hostname
+            self._port = components.port
+            self._scheme = components.scheme
+        else:
+            self._scheme = scheme or ""
+            self._host = host or "localhost"
+            self._port = port or 6865
+            if self._scheme:
+                self._url = f"{self._scheme}://{self._host}:{self._port}"
+            else:
+                self._url = f"//{self._host}:{self._port}"
+        self._connect_timeout = connect_timeout
+        self._enable_http_proxy = enable_http_proxy
+
+    @property
+    def url(self) -> str:
+        """
+        The full URL to connect to, including a protocol, host, and port.
+        """
+        return self._url
+
+    @property
+    def scheme(self) -> "Optional[str]":
+        return self._scheme
+
+    @property
+    def host(self) -> str:
+        return self._host
+
+    @property
+    def port(self) -> "Optional[int]":
+        return self._port
+
+    @property
+    def enable_http_proxy(self) -> bool:
+        """
+        Whether to allow the use of HTTP proxies.
+        """
+        return self._enable_http_proxy
+
+    @property
+    def connect_timeout(self) -> timedelta:
+        """
+        How long to wait for a connection before giving up.
+
+        The default is 30 seconds.
+        """
+        return self._connect_timeout
+
+
+def sanitize_url(url: str) -> str:
+    """
+    Perform some basic sanitization on a URL string:
+     * Convert a URL with no specified protocol to one with a blank protocol
+
+    >>> sanitize_url("somewhere:1000")
+    '//somewhere:1000'
+
+    >>> sanitize_url("http://somewhere:1000")
+    'http://somewhere:1000'
+
+    >>> sanitize_url("http://somewhere:1000/")
+    'http://somewhere:1000/'
+    """
+    first_slash = url.find("/")
+    if first_slash == -1 or first_slash != url.find("//"):
+        url = "//" + url
+    return url
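+
+
+# Example (illustrative; the host and port are placeholders): a URL given without a scheme is
+# interpreted as host:port, and the pieces are broken out for the gRPC channel:
+#
+#     cfg = URLConfig("somewhere:6865")
+#     assert cfg.host == "somewhere"
+#     assert cfg.port == 6865
+#     assert cfg.url == "//somewhere:6865"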
diff --git a/python/dazl/protocols/core.py b/python/dazl/protocols/core.py
new file mode 100644
index 00000000..9d486dba
--- /dev/null
+++ b/python/dazl/protocols/core.py
@@ -0,0 +1,78 @@
+# Copyright (c) 2017-2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any, Collection, Mapping, Optional, Sequence, Union
+
+from ..prim import ContractData, ContractId, Party
+
+__all__ = ["Event", "CreateEvent", "ArchiveEvent", "ExerciseResponse", "PartyInfo", "Query"]
+
+
+Query = Union[None, Mapping[str, Any], Collection[Mapping[str, Any]]]
+
+
+class Event:
+    __slots__ = ("_cid",)
+
+    def __init__(self, cid: "ContractId"):
+        object.__setattr__(self, "_cid", cid)
+
+    @property
+    def cid(self) -> "ContractId":
+        return self._cid
+
+    def __setattr__(self, key, value):
+        raise AttributeError("Event instances are read-only")
+
+
+class CreateEvent(Event):
+    __slots__ = ("_cdata",)
+
+    def __init__(self, cid: "ContractId", cdata: "ContractData"):
+        super(CreateEvent, self).__init__(cid)
+        object.__setattr__(self, "_cdata", cdata)
+
+    @property
+    def cdata(self) -> "ContractData":
+        return self._cdata
+
+
+class ArchiveEvent(Event):
+    pass
+
+
+class ExerciseResponse:
+    __slots__ = ("_result", "_events")
+
+    def __init__(self, result, events):
+        object.__setattr__(self, "_result", result)
+        object.__setattr__(self, "_events", tuple(events))
+
+    @property
+    def result(self) -> "Optional[Any]":
+        return self._result
+
+    @property
+    def events(self) -> "Sequence[Event]":
+        return self._events
+
+    def __repr__(self):
+        return f"ExerciseResponse(result={self.result}, events={self.events})"
+
+
+class PartyInfo:
+    __slots__ = ("_party", "_display_name", "_is_local")
+
+    def __init__(self, party: "Party", display_name: str, is_local: bool):
+        object.__setattr__(self, "_party", party)
+        object.__setattr__(self, "_display_name", display_name)
+        object.__setattr__(self, "_is_local", is_local)
+
+    @property
+    def party(self) -> "Party":
+        return self._party
+
+    @property
+    def display_name(self) -> str:
+        return self._display_name
+
+    @property
+    def is_local(self) -> bool:
+        return self._is_local
diff --git a/python/dazl/protocols/errors.py b/python/dazl/protocols/errors.py
new file mode 100644
index 00000000..d41fca95
--- /dev/null
+++ b/python/dazl/protocols/errors.py
@@ -0,0 +1,28 @@
+# Copyright (c) 2017-2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+__all__ = ["StreamError", "ProtocolWarning", "CallbackReturnWarning"]
+
+
+class StreamError(Exception):
+    """
+    An error that arises when trying to read from a query stream.
+    """
+
+
+class ProtocolWarning(Warning):
+    """
+    Warnings that are raised when dazl detects incompatibilities between the Ledger API server-side
+    implementation and dazl.
+    """
+
+
+class CallbackReturnWarning(Warning):
+    """
+    Raised when a user callback on a stream returns a value. These objects have no meaning and are
+    ignored by dazl.
+
+    This warning is raised primarily because older versions of dazl interpreted returning commands
+    from a callback as a request to send commands to the underlying ledger, and this is not
+    supported in newer APIs.
+    """
diff --git a/python/dazl/protocols/ledgerapi/__init__.py b/python/dazl/protocols/ledgerapi/__init__.py
new file mode 100644
index 00000000..8a7b8a03
--- /dev/null
+++ b/python/dazl/protocols/ledgerapi/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2017-2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from ..config import Config
+from .conn_aio import Connection
+
+__all__ = ["connect", "Connection"]
+
+
+def connect(**kwargs):
+    """
+    Connect to a gRPC Ledger API implementation and return a connection that uses asyncio.
+    """
+    # TODO: Support blocking=True, which should return a Connection implementation that uses
+    # blocking calls instead of async calls
+    config = Config.create(**kwargs)
+    return Connection(config)
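+
+
+# Typical usage (sketch; the URL and template name are placeholders):
+#
+#     async with connect(url="//localhost:6865", party=some_party) as conn:
+#         event = await conn.create("Main:Asset", {"owner": some_party})
+#
+# Entering the connection (``async with``) performs final validation, including fetching the
+# ledger ID when it is not already known.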
diff --git a/python/dazl/protocols/ledgerapi/__init__.pyi b/python/dazl/protocols/ledgerapi/__init__.pyi
new file mode 100644
index 00000000..bb6c87e8
--- /dev/null
+++ b/python/dazl/protocols/ledgerapi/__init__.pyi
@@ -0,0 +1,32 @@
+# Copyright (c) 2017-2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Literal, NoReturn, Optional, overload
+
+from ... import Party
+from .conn_aio import Connection
+
+# When we support blocking connections, the return type of the blocking=True overload will
+# change to that blocking implementation. Right now we throw an exception if blocking=True
+# is specified.
+#
+# TODO: Figure out clever ways to make this function's type signature easier to maintain while
+# preserving its ease of use to callers.
+@overload
+def connect(
+    url: str,
+    *,
+    blocking: Literal[True],
+    party: "Optional[Party]" = None,
+    admin: "Optional[bool]" = None
+) -> NoReturn: ...
+@overload
+def connect(
+    url: str,
+    *,
+    blocking: Literal[False],
+    party: "Optional[Party]" = None,
+    admin: "Optional[bool]" = None
+) -> Connection: ...
+@overload
+def connect(
+    url: str, *, admin: "Optional[bool]" = None, party: "Optional[Party]" = None
+) -> Connection: ...
diff --git a/python/dazl/protocols/ledgerapi/channel.py b/python/dazl/protocols/ledgerapi/channel.py
new file mode 100644
index 00000000..c6395a0d
--- /dev/null
+++ b/python/dazl/protocols/ledgerapi/channel.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2017-2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from grpc import (
+    AuthMetadataContext,
+    AuthMetadataPlugin,
+    AuthMetadataPluginCallback,
+    composite_channel_credentials,
+    metadata_call_credentials,
+    ssl_channel_credentials,
+)
+from grpc.aio import Channel, insecure_channel, secure_channel
+
+from ..config import Config
+
+__all__ = ["create_channel"]
+
+
+def create_channel(config: "Config") -> "Channel":
+    """
+    Create a :class:`Channel` for the specified configuration.
+    """
+    target = f"{config.url.host}:{config.url.port}"
+    options = [
+        ("grpc.max_send_message_length", -1),
+        ("grpc.max_receive_message_length", -1),
+    ]
+    if not config.url.enable_http_proxy:
+        # ("grpc.enable_http_proxy", 0) DISABLES gRPC's use of HTTP proxies, so it is only
+        # added when proxy support is turned off
+        options.append(("grpc.enable_http_proxy", 0))
+
+    if (config.url.scheme in ("https", "grpcs")) or config.ssl:
+        credentials = ssl_channel_credentials(
+            root_certificates=config.ssl.root_certificates,
+            private_key=config.ssl.private_key,
+            certificate_chain=config.ssl.certificate_chain,
+        )
+        if config.access.token:
+            credentials = composite_channel_credentials(
+                credentials, metadata_call_credentials(GrpcAuth(config))
+            )
+        return secure_channel(target, credentials, options)
+    else:
+        return insecure_channel(target, options)
+
+
+class GrpcAuth(AuthMetadataPlugin):
+    def __init__(self, config: "Config"):
+        self._config = config
+
+    def __call__(self, context: "AuthMetadataContext", callback: "AuthMetadataPluginCallback"):
+        options = []
+
+        # TODO: Add support here for refresh tokens
+        token = self._config.access.token
+        if token:
+            # gRPC requires metadata keys to be lowercase
+            options.append(("authorization", "Bearer " + token))
+
+        callback(options, None)
diff --git a/python/dazl/protocols/ledgerapi/codec_aio.py b/python/dazl/protocols/ledgerapi/codec_aio.py
new file mode 100644
index 00000000..9e22c2da
--- /dev/null
+++ b/python/dazl/protocols/ledgerapi/codec_aio.py
@@ -0,0 +1,291 @@
+# Copyright (c) 2017-2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""
+This module contains the mapping between Protobuf objects and Python/dazl types.
+"""
+
+# Earlier versions of dazl (before v7.5.0) had an API that mapped less directly to the gRPC
+# Ledger API. But with the advent of the HTTP JSON API, many common ledger methods now have much
+# more direct translations that still manage to adhere quite closely to dazl's historical
+# behavior.
+#
+# References:
+#  * https://github.com/digital-asset/daml/blob/main/ledger-service/http-json/src/main/scala/com/digitalasset/http/CommandService.scala
+
+from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union
+
+from ... import LOG, ContractData
+from ..._gen.com.daml.ledger.api.v1.admin.party_management_service_pb2 import (
+    PartyDetails as G_PartyDetails,
+)
+from ..._gen.com.daml.ledger.api.v1.commands_pb2 import (
+    Command as G_Command,
+    CreateAndExerciseCommand as G_CreateAndExerciseCommand,
+    CreateCommand as G_CreateCommand,
+    ExerciseByKeyCommand as G_ExerciseByKeyCommand,
+    ExerciseCommand as G_ExerciseCommand,
+)
+from ..._gen.com.daml.ledger.api.v1.event_pb2 import (
+    ArchivedEvent as G_ArchivedEvent,
+    CreatedEvent as G_CreatedEvent,
+    ExercisedEvent as G_ExercisedEvent,
+)
+from ..._gen.com.daml.ledger.api.v1.transaction_filter_pb2 import (
+    Filters as G_Filters,
+    InclusiveFilters as G_InclusiveFilters,
+)
+from ..._gen.com.daml.ledger.api.v1.transaction_pb2 import (
+    TransactionTree as G_TransactionTree,
+)
+from ..._gen.com.daml.ledger.api.v1.value_pb2 import Identifier as G_Identifier
+from ...damlast.daml_lf_1 import DefTemplate, TemplateChoice, Type, TypeConName
+from ...damlast.daml_types import con
+from ...damlast.lookup import MultiPackageLookup
+from ...damlast.util import module_local_name, module_name, package_ref
+from ...prim import ContractId
+from ...values import Context
+from ...values.protobuf import ProtobufDecoder, ProtobufEncoder, set_value
+from ..core import ArchiveEvent, CreateEvent, Event, ExerciseResponse, PartyInfo
+from ..pkgloader_aio import PackageLoader
+from ..v1.pb_parse_event import to_type_con_name
+
+if TYPE_CHECKING:
+    from .conn_aio import Connection
+
+__all__ = ["Codec"]
+
+SHARED_PACKAGE_DATABASE = MultiPackageLookup()
+
+
+class Codec:
+    """
+    Contains methods for converting to/from Protobuf Ledger API types.
+
+    Some encode/decode methods require package information to be available, which is why a
+    connection must be supplied in order to use the codec.
+
+    By default, the package database is _globally_ shared; this is safe to do because we make the
+    same assumption that the remote gRPC Ledger API implementation makes: that package IDs
+    uniquely identify package contents.
+    """
+
+    def __init__(self, conn: "Connection"):
+        self.conn = conn
+        self._lookup = SHARED_PACKAGE_DATABASE
+        self._loader = PackageLoader(self._lookup, conn)
+        self._encode_context = Context(ProtobufEncoder(), self._lookup)
+        self._decode_context = Context(ProtobufDecoder(), self._lookup)
+
+    async def encode_create_command(
+        self, template_id: "Any", payload: "ContractData"
+    ) -> "G_Command":
+        item_type = await self._loader.do_with_retry(
+            lambda: self._lookup.template_name(template_id)
+        )
+        _, value = self._encode_context.convert(con(item_type), payload)
+        return G_Command(
+            create=G_CreateCommand(
+                template_id=self.encode_identifier(item_type), create_arguments=value
+            )
+        )
+
+    async def encode_exercise_command(
+        self,
+        contract_id: "ContractId",
+        choice_name: str,
+        argument: "Optional[Any]" = None,
+    ) -> "G_Command":
+        item_type, _, choice = await self._look_up_choice(contract_id.value_type, choice_name)
+
+        cmd_pb = G_ExerciseCommand(
+            template_id=self.encode_identifier(item_type),
+            contract_id=contract_id.value,
+            choice=choice_name,
+        )
+        value_field, value_pb = await self.encode_value(choice.arg_binder.type, argument)
+        set_value(cmd_pb.choice_argument, value_field, value_pb)
+
+        return G_Command(exercise=cmd_pb)
+
+    async def encode_create_and_exercise_command(
+        self,
+        template_id: str,
+        payload: "ContractData",
+        choice_name: str,
+        argument: "Optional[Any]" = None,
+    ) -> "G_Command":
+        item_type, _, choice = await self._look_up_choice(template_id, choice_name)
+
+        # encode_value returns a (field name, value) pair, so the payload record must be
+        # unpacked before being handed to the command constructor
+        _, payload_pb = await self.encode_value(con(item_type), payload)
+        cmd_pb = G_CreateAndExerciseCommand(
+            template_id=self.encode_identifier(item_type),
+            create_arguments=payload_pb,
+            choice=choice_name,
+        )
+        value_field, value_pb = await self.encode_value(choice.arg_binder.type, argument)
+        set_value(cmd_pb.choice_argument, value_field, value_pb)
+
+        return G_Command(createAndExercise=cmd_pb)
+
+    async def encode_exercise_by_key_command(
+        self,
+        template_id: str,
+        choice_name: str,
+        key: "Any",
+        argument: "Optional[ContractData]" = None,
+    ) -> "G_Command":
+        item_type, template, choice = await self._look_up_choice(template_id, choice_name)
+
+        cmd_pb = G_ExerciseByKeyCommand(
+            template_id=self.encode_identifier(item_type),
+            choice=choice_name,
+        )
+        key_field, key_pb = await self.encode_value(template.key.type, key)
+        set_value(cmd_pb.contract_key, key_field, key_pb)
+        value_field, value_pb = await self.encode_value(choice.arg_binder.type, argument)
+        set_value(cmd_pb.choice_argument, value_field, value_pb)
+
+        return G_Command(exerciseByKey=cmd_pb)
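+
+    # Example round trip (sketch; "Main:Asset" and its payload are placeholders):
+    #
+    #     cmd = await codec.encode_create_command("Main:Asset", {"owner": "Alice"})
+    #
+    # yields a G_Command with its `create` field populated; decode_created_event() below
+    # handles the reverse direction once the ledger echoes the creation back.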
+
+    async def encode_filters(self, template_ids: "Sequence[Any]") -> "G_Filters":
+        # Search for a reference to the "wildcard" template; if any of the requested template_ids
+        # is "*", then return results for all templates. We do this first because resolving
+        # template IDs otherwise requires do_with_retry, which can be expensive.
+        for template_id in template_ids:
+            if template_id == "*":
+                # a reference to the "wildcard" template means values for ALL templates were
+                # requested, which is expressed as a Filters message with no inclusive filter
+                return G_Filters()
+
+        # No wildcard template IDs, so inspect and resolve all template references to concrete
+        # template IDs
+        requested_types = set()
+        for template_id in template_ids:
+            requested_types.update(
+                await self._loader.do_with_retry(lambda: self._lookup.template_names(template_id))
+            )
+
+        return G_Filters(
+            inclusive=G_InclusiveFilters(
+                template_ids=[self.encode_identifier(i) for i in sorted(requested_types)]
+            )
+        )
+
+    async def encode_value(self, item_type: "Type", obj: "Any") -> "Tuple[str, Optional[Any]]":
+        """
+        Convert a dazl/Python value to its Protobuf equivalent.
+        """
+        return await self._loader.do_with_retry(
+            lambda: self._encode_context.convert(item_type, obj)
+        )
+
+    @staticmethod
+    def encode_identifier(name: "TypeConName") -> "G_Identifier":
+        return G_Identifier(
+            package_id=package_ref(name),
+            module_name=str(module_name(name)),
+            entity_name=module_local_name(name),
+        )
+
+    async def decode_created_event(self, event: "G_CreatedEvent") -> "CreateEvent":
+        cid = self.decode_contract_id(event)
+        cdata = await self.decode_value(con(cid.value_type), event.create_arguments)
+        return CreateEvent(cid, cdata)
+
+    async def decode_archived_event(self, event: "G_ArchivedEvent") -> "ArchiveEvent":
+        cid = self.decode_contract_id(event)
+        return ArchiveEvent(cid)
+
+    async def decode_exercise_response(self, tree: "G_TransactionTree") -> "ExerciseResponse":
+        """
+        Convert a Protobuf TransactionTree response to an ExerciseResponse. The TransactionTree is
+        expected to only contain a single exercise node at the root level.
+        """
+        found_choice = None
+        result = None
+        cid = None
+
+        events = []  # type: List[Event]
+        for event_id in tree.root_event_ids:
+            event_pb = tree.events_by_id[event_id]
+            event_pb_type = event_pb.WhichOneof("kind")
+            if event_pb_type == "created":
+                events.append(await self.decode_created_event(event_pb.created))
+            elif event_pb_type == "exercised":
+                # Find the "first" exercised node and grab its result value
+                if cid is None:
+                    cid = self.decode_contract_id(event_pb.exercised)
+
+                    template = self._lookup.template(cid.value_type)
+
+                    if found_choice is None:
+                        for choice in template.choices:
+                            if choice.name == event_pb.exercised.choice:
+                                found_choice = choice
+                                break
+                    if found_choice is not None:
+                        result = await self.decode_value(
+                            found_choice.ret_type,
+                            event_pb.exercised.exercise_result,
+                        )
+                    else:
+                        LOG.error(
+                            "Received an exercise node that referred to a choice that doesn't exist!"
+                        )
+
+                events.extend(await self._decode_exercised_child_events(tree, [event_id]))
+            else:
+                LOG.warning("Received an unknown event type: %s", event_pb_type)
+
+        return ExerciseResponse(result, events)
+
+    async def _decode_exercised_child_events(
+        self, tree: "G_TransactionTree", event_ids: "Sequence[str]"
+    ) -> "Sequence[Event]":
+        events = []  # type: List[Event]
+        for event_id in event_ids:
+            event_pb = tree.events_by_id[event_id]
+            event_pb_type = event_pb.WhichOneof("kind")
+            if event_pb_type == "created":
+                events.append(await self.decode_created_event(event_pb.created))
+            elif event_pb_type == "exercised":
+                if event_pb.exercised.consuming:
+                    events.append(ArchiveEvent(self.decode_contract_id(event_pb.exercised)))
+                events.extend(
+                    await self._decode_exercised_child_events(
+                        tree, event_pb.exercised.child_event_ids
+                    )
+                )
+            else:
+                LOG.warning("Received an unknown event type: %s", event_pb_type)
+        return events
+
+    async def decode_value(self, item_type: "Type", obj: "Any") -> "Optional[Any]":
+        """
+        Convert a Protobuf Ledger API value to its dazl/Python equivalent.
+        """
+        return await self._loader.do_with_retry(
+            lambda: self._decode_context.convert(item_type, obj)
+        )
+
+    @staticmethod
+    def decode_contract_id(
+        event: "Union[G_CreatedEvent, G_ExercisedEvent, G_ArchivedEvent]",
+    ) -> "ContractId":
+        vt = to_type_con_name(event.template_id)
+        return ContractId(vt, event.contract_id)
+
+    @staticmethod
+    def decode_identifier(identifier: "G_Identifier") -> "TypeConName":
+        return to_type_con_name(identifier)
+
+    @staticmethod
+    def decode_party_info(party_details: "G_PartyDetails") -> "PartyInfo":
+        return PartyInfo(party_details.party, party_details.display_name, party_details.is_local)
+
+    async def _look_up_choice(
+        self, template_id: "Any", choice_name: str
+    ) -> "Tuple[TypeConName, DefTemplate, TemplateChoice]":
+        template_type = await self._loader.do_with_retry(
+            lambda: self._lookup.template_name(template_id)
+        )
+        template = self._lookup.template(template_type)
+        for choice in template.choices:
+            if choice.name == choice_name:
+                return template_type, template, choice
+        raise ValueError(f"template {template.tycon} has no choice named {choice_name}")
diff --git a/python/dazl/protocols/ledgerapi/conn_aio.py b/python/dazl/protocols/ledgerapi/conn_aio.py
new file mode 100644
index 00000000..482f9549
--- /dev/null
+++ b/python/dazl/protocols/ledgerapi/conn_aio.py
@@ -0,0 +1,431 @@
+# Copyright (c) 2017-2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""
+This module contains the mapping between gRPC calls and Python/dazl types.
+"""
+import uuid
+import warnings
+from typing import (
+    AbstractSet,
+    Any,
+    AsyncIterable,
+    Collection,
+    Generic,
+    Mapping,
+    Optional,
+    Sequence,
+    TypeVar,
+    Union,
+)
+
+from grpc.aio import Channel
+
+from ... import ContractData, Party
+from ..._gen.com.daml.ledger.api.v1.active_contracts_service_pb2 import (
+    GetActiveContractsRequest as G_GetActiveContractsRequest,
+)
+from ..._gen.com.daml.ledger.api.v1.active_contracts_service_pb2_grpc import (
+    ActiveContractsServiceStub,
+)
+from ..._gen.com.daml.ledger.api.v1.admin.package_management_service_pb2 import (
+    UploadDarFileRequest as G_UploadDarFileRequest,
+)
+from ..._gen.com.daml.ledger.api.v1.admin.package_management_service_pb2_grpc import (
+    PackageManagementServiceStub,
+)
+from ..._gen.com.daml.ledger.api.v1.admin.party_management_service_pb2 import (
+    AllocatePartyRequest as G_AllocatePartyRequest,
+    ListKnownPartiesRequest as G_ListKnownPartiesRequest,
+)
+from ..._gen.com.daml.ledger.api.v1.admin.party_management_service_pb2_grpc import (
+    PartyManagementServiceStub,
+)
+from ..._gen.com.daml.ledger.api.v1.command_service_pb2 import (
+    SubmitAndWaitRequest as G_SubmitAndWaitRequest,
+)
+from ..._gen.com.daml.ledger.api.v1.command_service_pb2_grpc import CommandServiceStub
+from ..._gen.com.daml.ledger.api.v1.commands_pb2 import (
+    Command as G_Command,
+    Commands as G_Commands,
+)
+from ..._gen.com.daml.ledger.api.v1.ledger_identity_service_pb2 import (
+    GetLedgerIdentityRequest as G_GetLedgerIdentityRequest,
+)
+from ..._gen.com.daml.ledger.api.v1.ledger_identity_service_pb2_grpc import (
+    LedgerIdentityServiceStub,
+)
+from ..._gen.com.daml.ledger.api.v1.ledger_offset_pb2 import (
+    LedgerOffset as G_LedgerOffset,
+)
+from ..._gen.com.daml.ledger.api.v1.package_service_pb2 import (
+    GetPackageRequest as G_GetPackageRequest,
+    ListPackagesRequest as G_ListPackagesRequest,
+)
+from ..._gen.com.daml.ledger.api.v1.package_service_pb2_grpc import PackageServiceStub
+from ..._gen.com.daml.ledger.api.v1.transaction_filter_pb2 import (
+    TransactionFilter as G_TransactionFilter,
+)
+from ..._gen.com.daml.ledger.api.v1.transaction_service_pb2 import (
+    GetTransactionsRequest as G_GetTransactionsRequest,
+)
+from ..._gen.com.daml.ledger.api.v1.transaction_service_pb2_grpc import (
+    TransactionServiceStub,
+)
+from ...damlast.daml_lf_1 import PackageRef
+from ...prim import ContractId
+from ..config import Config
+from ..core import ArchiveEvent, CreateEvent, ExerciseResponse, PartyInfo, Query
+from ..errors import ProtocolWarning
+from ..stream_aio import QueryStreamBase
+from .channel import create_channel
+from .codec_aio import Codec
+
+__all__ = ["Connection"]
+
+T = TypeVar("T")
+
+
+class Connection:
+    def __init__(self, config: "Config"):
+        self._config = config
+        self._channel = create_channel(config)
+        self._codec = Codec(self)
+
+    @property
+    def config(self) -> "Config":
+        return self._config
+
+    @property
+    def channel(self) -> "Channel":
+        """
+        Provides access to the underlying gRPC channel.
+        """
+        return self._channel
+
+    @property
+    def codec(self) -> "Codec":
+        return self._codec
+
+    async def __aenter__(self) -> "Connection":
+        """
+        Does final validation of the token, including possibly fetching the ledger ID if it is not
+        yet known.
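+
+        This is normally entered via ``async with connect(...) as conn:`` rather than being
+        called directly.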
+ """ + if not self._config.access.ledger_id: + # most calls require a ledger ID; if it wasn't supplied as part of our token or we were + # never given a token in the first place, fetch the ledger ID from the destination + stub = LedgerIdentityServiceStub(self._channel) + response = await stub.GetLedgerIdentity(G_GetLedgerIdentityRequest()) + self._config.access.ledger_id = response.ledger_id + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + await self._channel.close() + + # region Write API + + async def create( + self, + template_id: str, + payload: "ContractData", + *, + workflow_id: "Optional[str]" = None, + command_id: "Optional[str]" = None, + ) -> "CreateEvent": + """ + Create a contract for a given template. + + :param template_id: The template of the contract to be created. + :param payload: Template arguments for the contract to be created. + :param workflow_id: An optional workflow ID. + :param command_id: An optional command ID. If unspecified, a random one will be created. + """ + stub = CommandServiceStub(self.channel) + + request = G_SubmitAndWaitRequest( + commands=G_Commands( + ledger_id=self._config.access.ledger_id, + application_id=self._config.access.application_name, + command_id=self._command_id(command_id), + workflow_id=self._workflow_id(workflow_id), + party=self._ensure_act_as(), + commands=[await self._codec.encode_create_command(template_id, payload)], + act_as=self._config.access.act_as, + read_as=self._config.access.read_only_as, + ) + ) + response = await stub.SubmitAndWaitForTransaction(request) + + return await self._codec.decode_created_event(response.transaction.events[0].created) + + async def exercise( + self, + contract_id: "ContractId", + choice_name: str, + argument: "Optional[ContractData]" = None, + *, + workflow_id: "Optional[str]" = None, + command_id: "Optional[str]" = None, + ) -> "ExerciseResponse": + """ + Exercise a choice on a contract identified by its contract ID. + + :param contract_id: The contract ID of the contract to exercise. + :param choice_name: The name of the choice to exercise. + :param argument: The choice arguments. Can be omitted for choices that take no argument. + :param workflow_id: An optional workflow ID. + :param command_id: An optional command ID. If unspecified, a random one will be created. 
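+
+        Example (sketch; the choice name and argument are placeholders)::
+
+            response = await conn.exercise(event.cid, "Accept", {"count": 1})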
+        :return:
+            An :class:`ExerciseResponse`, which includes the choice's return value (if any)
+            and the events that were produced as a result of the exercise.
+        """
+        stub = CommandServiceStub(self.channel)
+
+        commands = [await self._codec.encode_exercise_command(contract_id, choice_name, argument)]
+        request = self._submit_and_wait_request(commands, workflow_id, command_id)
+        response = await stub.SubmitAndWaitForTransactionTree(request)
+
+        return await self._codec.decode_exercise_response(response.transaction)
+
+    async def create_and_exercise(
+        self,
+        template_id: str,
+        payload: "ContractData",
+        choice_name: str,
+        argument: "Optional[ContractData]" = None,
+        *,
+        workflow_id: "Optional[str]" = None,
+        command_id: "Optional[str]" = None,
+    ) -> "ExerciseResponse":
+        stub = CommandServiceStub(self.channel)
+
+        commands = [
+            await self._codec.encode_create_and_exercise_command(
+                template_id, payload, choice_name, argument
+            )
+        ]
+        request = self._submit_and_wait_request(commands, workflow_id, command_id)
+        response = await stub.SubmitAndWaitForTransactionTree(request)
+
+        return await self._codec.decode_exercise_response(response.transaction)
+
+    async def exercise_by_key(
+        self,
+        template_id: str,
+        choice_name: str,
+        key: "Any",
+        argument: "Optional[ContractData]" = None,
+        *,
+        workflow_id: "Optional[str]" = None,
+        command_id: "Optional[str]" = None,
+    ) -> "ExerciseResponse":
+        stub = CommandServiceStub(self.channel)
+
+        commands = [
+            await self._codec.encode_exercise_by_key_command(
+                template_id, choice_name, key, argument
+            )
+        ]
+        # _submit_and_wait_request is a plain (non-async) helper, so it is not awaited
+        request = self._submit_and_wait_request(commands, workflow_id, command_id)
+        response = await stub.SubmitAndWaitForTransactionTree(request)
+
+        return await self._codec.decode_exercise_response(response.transaction)
+
+    async def archive(self, contract_id: "ContractId") -> "ArchiveEvent":
+        await self.exercise(contract_id, "Archive")
+        return ArchiveEvent(contract_id)
+
+    async def archive_by_key(self, template_id: str, key: "Any") -> "ArchiveEvent":
+        response = await self.exercise_by_key(template_id, "Archive", key)
+        return next(iter(event for event in response.events if isinstance(event, ArchiveEvent)))
+
+    def _ensure_act_as(self) -> "Party":
+        act_as_party = next(iter(self._config.access.act_as), None)
+        if not act_as_party:
+            raise ValueError("current access rights do not include any act-as parties")
+        return act_as_party
+
+    @staticmethod
+    def _workflow_id(workflow_id: "Optional[str]") -> "Optional[str]":
+        # TODO: workflow_id must be a LedgerString; we could enforce some minimal validation
+        #  here to make for a more obvious error than failing on the server-side
+        return workflow_id if workflow_id else None
+
+    @staticmethod
+    def _command_id(command_id: "Optional[str]") -> str:
+        # TODO: command_id must be a LedgerString; we could enforce some minimal validation
+        #  here to make for a more obvious error than failing on the server-side
+        return command_id or uuid.uuid4().hex
+
+    def _submit_and_wait_request(
+        self,
+        commands: "Collection[G_Command]",
+        workflow_id: "Optional[str]" = None,
+        command_id: "Optional[str]" = None,
+    ) -> "G_SubmitAndWaitRequest":
+        return G_SubmitAndWaitRequest(
+            commands=G_Commands(
+                ledger_id=self._config.access.ledger_id,
+                application_id=self._config.access.application_name,
+                command_id=self._command_id(command_id),
+                workflow_id=self._workflow_id(workflow_id),
+                party=self._ensure_act_as(),
+                commands=commands,
+                act_as=self._config.access.act_as,
+                read_as=self._config.access.read_only_as,
+            )
+        )
+
+    # endregion
+
+    # region Read API
+
+    def query(self, template_id: str = "*", query: "Query" = None) -> "QueryStream[CreateEvent]":
+        return QueryStream(self, {template_id: query}, False)
+
+    def query_many(
+        self, queries: "Optional[Mapping[str, Query]]" = None
+    ) -> "QueryStream[CreateEvent]":
+        return QueryStream(self, queries, False)
+
+    def stream(
+        self, template_id: str = "*", query: "Query" = None
+    ) -> "QueryStream[Union[CreateEvent, ArchiveEvent]]":
+        return QueryStream(self, {template_id: query}, True)
+
+    def stream_many(
+        self, queries: "Optional[Mapping[str, Query]]" = None
+    ) -> "QueryStream[Union[CreateEvent, ArchiveEvent]]":
+        return QueryStream(self, queries, True)
+
+    # endregion
+
+    # region Party Management calls
+
+    async def allocate_party(
+        self, identifier_hint: "Optional[str]" = None, display_name: "Optional[str]" = None
+    ) -> "PartyInfo":
+        """
+        Allocate a new party.
+        """
+        stub = PartyManagementServiceStub(self.channel)
+        request = G_AllocatePartyRequest(party_id_hint=identifier_hint, display_name=display_name)
+        response = await stub.AllocateParty(request)
+        return Codec.decode_party_info(response.party_details)
+
+    async def list_known_parties(self) -> "Sequence[PartyInfo]":
+        stub = PartyManagementServiceStub(self.channel)
+        response = await stub.ListKnownParties(G_ListKnownPartiesRequest())
+        return [Codec.decode_party_info(pd) for pd in response.party_details]
+
+    # endregion
+
+    # region Package Management calls
+
+    async def get_package(self, package_id: "PackageRef") -> bytes:
+        stub = PackageServiceStub(self.channel)
+        request = G_GetPackageRequest(
+            ledger_id=self._config.access.ledger_id, package_id=str(package_id)
+        )
+        response = await stub.GetPackage(request)
+        return response.archive_payload
+
+    async def list_package_ids(self) -> "AbstractSet[PackageRef]":
+        stub = PackageServiceStub(self.channel)
+        request = G_ListPackagesRequest(ledger_id=self._config.access.ledger_id)
+        response = await stub.ListPackages(request)
+        return frozenset({PackageRef(pkg_id) for pkg_id in response.package_ids})
+
+    async def upload_package(self, contents: bytes) -> None:
+        stub = PackageManagementServiceStub(self.channel)
+        request = G_UploadDarFileRequest(dar_file=contents)
+        await stub.UploadDarFile(request)
+
+    # endregion
+
+
+class QueryStream(Generic[T], QueryStreamBase):
+    def __init__(
+        self,
+        conn: "Connection",
+        queries: "Optional[Mapping[str, Query]]",
+        continue_stream: bool,
+    ):
+        self.conn = conn
+        self._queries = queries
+        self._continue_stream = continue_stream
+
+        self._offset = None
+        self._filter = None
+        self._response_stream = None
+
+    async def close(self) -> None:
+        if self._response_stream is not None:
+            self._response_stream.cancel()
+            self._response_stream = None
+
+    async def items(self) -> "AsyncIterable[T]":
+        """
+        Return an asynchronous stream of events.
+
+        .. code-block:: python
+
+            async with conn.query('SampleApp:Iou') as query:
+                async for event in query:
+                    print(f"cid: {event.cid}, cdata: {event.cdata}")
+
+        :return:
+            A stream of :class:`CreateEvent` objects and, for streams opened with ``stream`` or
+            ``stream_many``, :class:`ArchiveEvent` objects as contracts are archived.
+        """
+        # only the keys (template IDs) are currently consulted; an absent mapping is treated
+        # as a request for all templates
+        filters = await self.conn.codec.encode_filters(
+            list(self._queries) if self._queries is not None else ["*"]
+        )
+        filters_by_party = {party: filters for party in self.conn.config.access.read_as}
+        tx_filter_pb = G_TransactionFilter(filters_by_party=filters_by_party)
+
+        try:
+            async for event in self._acs_events(tx_filter_pb):
+                await self._emit_create(event)
+                yield event
+
+            if self._continue_stream:
+                # now start returning events as they come off the transaction stream; note this
+                # stream will never naturally close, so it's on the caller to call close() or to
+                # otherwise exit our current context
+                async for event in self._tx_events(tx_filter_pb):
+                    if isinstance(event, CreateEvent):
+                        await self._emit_create(event)
+                    elif isinstance(event, ArchiveEvent):
+                        await self._emit_archive(event)
+                    else:
+                        warnings.warn(f"Received an unknown event: {event}", ProtocolWarning)
+                    yield event
+        finally:
+            await self.close()
+
+    async def _acs_events(self, filter_pb: "G_TransactionFilter") -> "AsyncIterable[CreateEvent]":
+        stub = ActiveContractsServiceStub(self.conn.channel)
+
+        request = G_GetActiveContractsRequest(
+            ledger_id=self.conn.config.access.ledger_id, filter=filter_pb
+        )
+        self._response_stream = stub.GetActiveContracts(request)
+        async for response in self._response_stream:
+            self._offset = response.offset
+            for event in response.active_contracts:
+                yield await self.conn.codec.decode_created_event(event)
+
+    async def _tx_events(
+        self, filter_pb: "G_TransactionFilter"
+    ) -> "AsyncIterable[Union[CreateEvent, ArchiveEvent]]":
+        stub = TransactionServiceStub(self.conn.channel)
+
+        request = G_GetTransactionsRequest(
+            ledger_id=self.conn.config.access.ledger_id,
+            filter=filter_pb,
+            # resume where the active contract set stream left off
+            begin=G_LedgerOffset(absolute=self._offset),
+        )
+        self._response_stream = stub.GetTransactions(request)
+        async for response in self._response_stream:
+            # transactions arrive in batches; each transaction carries its own offset and a
+            # flat list of created/archived events
+            for tx in response.transactions:
+                self._offset = tx.offset
+                for event in tx.events:
+                    event_type = event.WhichOneof("event")
+                    if event_type == "created":
+                        yield await self.conn.codec.decode_created_event(event.created)
+                    elif event_type == "archived":
+                        yield await self.conn.codec.decode_archived_event(event.archived)
+                    else:
+                        warnings.warn(f"Unknown Event({event_type}=...)", ProtocolWarning)
diff --git a/python/dazl/protocols/pkgloader_aio.py b/python/dazl/protocols/pkgloader_aio.py
new file mode 100644
index 00000000..b1f17dbd
--- /dev/null
+++ b/python/dazl/protocols/pkgloader_aio.py
@@ -0,0 +1,224 @@
+# Copyright (c) 2017-2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import sys
+from asyncio import ensure_future, gather, get_event_loop, sleep, wait_for
+from concurrent.futures import ThreadPoolExecutor
+from datetime import timedelta
+from typing import AbstractSet, Awaitable, Callable, Dict, Set, TypeVar
+
+from .. import LOG
+from ..damlast.daml_lf_1 import Archive, Package, PackageRef
+from ..damlast.errors import NameNotFoundError, PackageNotFoundError
+from ..damlast.lookup import MultiPackageLookup
+from ..damlast.pkgfile import Dar
+from ..model.core import DazlError
+from ..model.lookup import validate_template
+
+if sys.version_info >= (3, 8):
+    # typing.Protocol was added in Python 3.8
+    from typing import Protocol
+else:
+    from typing_extensions import Protocol
+
+
+__all__ = ['PackageService', 'PackageLoader']
+
+
+T = TypeVar('T')
+
+DEFAULT_TIMEOUT = timedelta(seconds=30)
+
+
+class PackageService(Protocol):
+    """
+    A service that provides package information.
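+
+    :class:`Connection` satisfies this protocol via its ``get_package`` and
+    ``list_package_ids`` methods, and is the usual argument to :class:`PackageLoader`.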
+ """ + + async def get_package(self, package_id: 'PackageRef') -> bytes: + raise NotImplementedError('SyncPackageService.package_bytes requires an implementation') + + async def list_package_ids(self) -> 'AbstractSet[PackageRef]': + raise NotImplementedError('SyncPackageService.package_ids requires an implementation') + + +class PackageLoader: + """ + Loader for packages from a remote PackageService. + + This class handles retries and backoffs, and avoids having more than one request in flight for + the same package ID. It is intended to be shared by all local clients that may need package + information. + """ + + def __init__( + self, + package_lookup: 'MultiPackageLookup', + conn: 'PackageService' = None, + timeout: 'timedelta' = DEFAULT_TIMEOUT): + self._package_lookup = package_lookup + self._conn = conn + self._timeout = timeout + self._loading_futs = dict() # type: Dict[PackageRef, Awaitable[Package]] + self._parsing_futs = dict() # type: Dict[PackageRef, Awaitable[Archive]] + self._executor = ThreadPoolExecutor(3) + + def set_connection(self, conn): + self._conn = conn + + async def do_with_retry(self, fn: 'Callable[[], T]') -> 'T': + """ + Perform a synchronous action that assumes the existence of one or more packages. In the + event the function raises :class:`PackageNotFoundError` or a wildcarded + :class:`NameNotFoundError`, the required package/type is fetched and the operation retried. + + If, after a retry, an expected package or type could not be found, the exception is + re-raised to the caller. + + :param fn: A function to invoke. + :return: The result of that function. + """ + failed_types = set() # type: Set[str] + failed_packages = set() # type: Set[PackageRef] + while True: + try: + return fn() + + except PackageNotFoundError as ex: + # every time we fail serialization due to a missing package or type, + # try to resolve it; remember what we tried, because if we fail again + # for the same reason it is likely fatal + if ex.ref in failed_packages: + # we already looked for this package and couldn't find it; this will + # never succeed + raise + failed_packages.add(ex.ref) + await self.load(ex.ref) + + except NameNotFoundError as ex: + if ex.ref in failed_types: + # we already looked for this type and couldn't find it; this will + # never succeed + LOG.verbose( + "Failed to find name %s in all known packages, " + "even after fetching the latest.", + ex.ref) + raise + + pkg_id, name = validate_template(ex.ref) + if pkg_id == '*': + # we don't know what package contains this type, so we have no + # choice but to look in all known packages + LOG.verbose( + "Failed to find name %s in all known packages, " + "so loading ALL packages...", name) + failed_types.add(ex.ref) + await self.load_all() + else: + # we know what package this type comes from, but it did not contain + # the required type + LOG.warning("Found package %s, but it did not include type %s", pkg_id, name) + raise + + async def preload(self, *contents: 'Dar') -> None: + """ + Populate a :class:`PackageCache` with types from DARs. + + :param contents: + One or more DARs to load into a local package cache. + """ + + async def load(self, ref: 'PackageRef') -> 'Package': + """ + Load a package ID from the remote server. If the package has additional dependencies, they + are also loaded. + + :param ref: One or more :class:`PackageRef`s. 
+        :raises: PackageNotFoundError if the package could not be resolved
+        """
+        # If the package has already been loaded, then skip all the expensive I/O stuff
+        try:
+            return self._package_lookup.package(ref)
+        except PackageNotFoundError:
+            pass
+
+        # If we already have a request in-flight, simply return that same Future to our caller;
+        # do not try to schedule a new request
+        fut = self._loading_futs.get(ref)
+        if fut is None:
+            fut = ensure_future(self._load_and_parse_package(ref))
+            self._loading_futs[ref] = fut
+        package = await fut
+
+        _ = self._loading_futs.pop(ref, None)
+        _ = self._parsing_futs.pop(ref, None)
+
+        return package
+
+    async def _load_and_parse_package(self, package_id: 'PackageRef') -> 'Package':
+        from ..damlast.parse import parse_archive
+
+        LOG.info("Loading package: %s", package_id)
+
+        loop = get_event_loop()
+        conn = self._conn
+        if conn is None:
+            raise DazlError('a connection is not configured')
+
+        archive_bytes = await wait_for(
+            fetch_package_bytes(conn, package_id), timeout=self._timeout.total_seconds())
+
+        LOG.info("Loaded package: %s, %d bytes", package_id, len(archive_bytes))
+
+        # we only ever want a package to be parsed once; it could be that there were multiple
+        # attempts to load a package in flight (though this shouldn't happen either)
+        fut = self._parsing_futs.get(package_id)
+        if fut is None:
+            fut = ensure_future(loop.run_in_executor(
+                self._executor, lambda: parse_archive(package_id, archive_bytes)))
+            self._parsing_futs[package_id] = fut
+
+        archive = await fut
+        self._package_lookup.add_archive(archive)
+        return archive.package
+
+    async def load_all(self):
+        """
+        Load all packages from the remote server.
+        """
+        package_ids = set(await self._conn.list_package_ids())
+        package_ids -= self._package_lookup.package_ids()
+        if package_ids:
+            await gather(*(self.load(package_id) for package_id in package_ids))
+
+
+async def fetch_package_bytes(conn: "PackageService", package_id: "PackageRef") -> bytes:
+    """
+    Fetch package bytes, with retry.
+    """
+    sleep_interval = 1
+
+    while True:
+        # noinspection PyBroadException
+        try:
+            return await conn.get_package(package_id)
+        except Exception:
+            # We tried fetching the package but got an error. Retry, backing off to waiting as
+            # much as 30 seconds between each attempt.
+            LOG.exception("Failed to fetch package; this will be retried.")
+            await sleep(sleep_interval)
+            sleep_interval = min(sleep_interval * 2, 30)
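+
+
+# Example (sketch): PackageLoader.do_with_retry() wraps a synchronous lookup so that a missing
+# package triggers a fetch-and-retry instead of an immediate failure. The lookup and template
+# name below are placeholders:
+#
+#     item_type = await loader.do_with_retry(lambda: lookup.template_name("Main:Asset"))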
+# SPDX-License-Identifier: Apache-2.0
+
+import warnings
+from collections import defaultdict
+from inspect import iscoroutine
+from typing import (
+    Any,
+    AsyncIterable,
+    Awaitable,
+    Callable,
+    DefaultDict,
+    List,
+    TypeVar,
+    overload,
+)
+
+from .core import ArchiveEvent, CreateEvent, Event
+from .errors import CallbackReturnWarning
+
+__all__ = ["QueryStreamBase"]
+
+CREATE_EVENT = "create"
+ARCHIVE_EVENT = "archive"
+
+Self = TypeVar("Self")
+
+
+class QueryStreamBase:
+
+    @property
+    def _callbacks(self) -> "DefaultDict[str, List[Callable[[Any], None]]]":
+        # this property is a data descriptor, so it always takes precedence over an
+        # instance attribute of the same name; go through __dict__ directly, both to
+        # avoid infinite recursion on first access and because normal attribute
+        # assignment would be rejected for a property without a setter
+        cb = self.__dict__.get("_callbacks")
+        if cb is None:
+            cb = defaultdict(list)
+            self.__dict__["_callbacks"] = cb
+        return cb
+
+    @overload
+    def on_create(self, fn: "Callable[[CreateEvent], None]") -> "Callable[[CreateEvent], None]":
+        ...
+
+    @overload
+    def on_create(self, fn: "Callable[[CreateEvent], Awaitable[None]]") \
+            -> "Callable[[CreateEvent], Awaitable[None]]":
+        ...
+
+    def on_create(self, fn):
+        if not callable(fn):
+            raise ValueError("fn must be a callable")
+
+        self._callbacks[CREATE_EVENT].append(fn)
+        return fn
+
+    @overload
+    def on_archive(self, fn: "Callable[[ArchiveEvent], None]") \
+            -> "Callable[[ArchiveEvent], None]":
+        ...
+
+    @overload
+    def on_archive(self, fn: "Callable[[ArchiveEvent], Awaitable[None]]") \
+            -> "Callable[[ArchiveEvent], Awaitable[None]]":
+        ...
+
+    def on_archive(self, fn):
+        if not callable(fn):
+            raise ValueError("fn must be a callable")
+
+        self._callbacks[ARCHIVE_EVENT].append(fn)
+        return fn
+
+    async def __aenter__(self: Self) -> "Self":
+        """
+        Prepare the stream.
+        """
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
+        """
+        Close the stream.
+        """
+
+    async def close(self) -> None:
+        """
+        Close and dispose of any resources used by this stream.
+        """
+
+    async def run(self) -> "None":
+        """
+        "Runs" the stream. This can be called as an alternative to :meth:`items` when using
+        callback-based APIs.
+        """
+        async for _ in self:
+            pass
+
+    def items(self) -> "AsyncIterable[Event]":
+        """
+        Must be overridden by subclasses to provide a stream of events. The implementation is
+        expected to call :meth:`_emit_create` and :meth:`_emit_archive` for every encountered
+        event.
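+
+        As a rough sketch of what an override might look like (hypothetical subclass;
+        assumes events arrive on an internal ``asyncio.Queue`` held in ``self._queue``)::
+
+            async def items(self):
+                while True:
+                    event = await self._queue.get()
+                    if isinstance(event, CreateEvent):
+                        await self._emit_create(event)
+                    elif isinstance(event, ArchiveEvent):
+                        await self._emit_archive(event)
+                    yield event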
+        """
+        raise NotImplementedError
+
+    def __aiter__(self) -> "AsyncIterable[Event]":
+        return self.items()
+
+    async def _emit(self, name: str, obj: "Any"):
+        for cb in self._callbacks[name]:
+            result = cb(obj)
+            if iscoroutine(result):
+                result = await result
+            if result is not None:
+                warnings.warn(
+                    "callbacks should not return anything; the result will be ignored",
+                    CallbackReturnWarning)
+
+    async def _emit_create(self, event: "CreateEvent"):
+        await self._emit(CREATE_EVENT, event)
+
+    async def _emit_archive(self, event: "ArchiveEvent"):
+        await self._emit(ARCHIVE_EVENT, event)
diff --git a/python/tests/unit/test_protocol_ledgerapi.py b/python/tests/unit/test_protocol_ledgerapi.py
new file mode 100644
index 00000000..ac115888
--- /dev/null
+++ b/python/tests/unit/test_protocol_ledgerapi.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2017-2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+import logging
+
+import pytest
+
+from dazl.protocols.ledgerapi import connect
+
+from .dars import PostOffice
+
+
+@pytest.mark.asyncio
+async def test_protocol_ledger_api(sandbox):
+    # first, the administrative setup: upload the DAR and allocate the two parties used below
+    async with connect(url=sandbox, admin=True) as conn:
+        await conn.upload_package(PostOffice.read_bytes())
+        postman = (await conn.allocate_party()).party
+        participant = (await conn.allocate_party()).party
+
+    async with connect(url=sandbox, party=postman) as conn:
+        event = await conn.create("Main:PostmanRole", {"postman": postman})
+        result = await conn.exercise(
+            event.cid, "InviteParticipant", {"party": participant, "address": "Somewhere!"})
+        logging.info("Result of inviting a participant: %s", result)
+
+    async with connect(url=sandbox, party=participant) as conn:
+        # Stream Main:InviteAuthorRole contracts, then Main:InviteReceiverRole contracts,
+        # breaking out of each stream as soon as we see its first contract.
+        #
+        # We do NOT use query() here, because in a distributed ledger setting, the result of the
+        # postman inviting participants may not yet have been observed by the clients. Instead,
+        # use stream(), which remains open until explicitly closed.
+        async with conn.stream("Main:InviteAuthorRole") as query:
+            async for event in query:
+                result = await conn.exercise(event.cid, "AcceptInviteAuthorRole")
+                logging.info("The result of AcceptInviteAuthorRole: %s", result)
+                break
+
+        async with conn.stream("Main:InviteReceiverRole") as query:
+            async for event in query:
+                result = await conn.exercise(event.cid, "AcceptInviteReceiverRole")
+                logging.info("The result of AcceptInviteReceiverRole: %s", result)
+                break
+
+    logging.info("Done!")
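+
+
+# A sketch of the callback-based alternative to iterating a stream, kept as a comment
+# rather than an executable test: handlers registered with on_create() are invoked by
+# run(), which consumes the stream until it is closed. The party and template names
+# are reused from the test above purely for illustration.
+#
+#     async with connect(url=sandbox, party=participant) as conn:
+#         async with conn.stream("Main:InviteAuthorRole") as query:
+#             @query.on_create
+#             async def _accept(event):
+#                 await conn.exercise(event.cid, "AcceptInviteAuthorRole")
+#                 await query.close()
+#
+#             await query.run()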