Skip to content

Commit

Permalink
Fix type errors and add type annotations.
Browse files Browse the repository at this point in the history
  • Loading branch information
Majsvaffla committed Apr 16, 2024
1 parent aed72c6 commit 1430080
Show file tree
Hide file tree
Showing 17 changed files with 186 additions and 170 deletions.
58 changes: 26 additions & 32 deletions bankid/asyncclient.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Optional, Tuple, Dict, Any
from typing import Tuple, Dict, Any

import httpx

from bankid.baseclient import BankIDClientBaseclass
from bankid.exceptions import get_json_error_class


class BankIDAsyncClient(BankIDClientBaseclass):
class BankIDAsyncClient(BankIDClientBaseclass[httpx.AsyncClient]):
"""The asynchronous client to use for communicating with BankID servers via the v6 API.
:param certificates: Tuple of string paths to the certificate to use and
Expand All @@ -19,25 +19,19 @@ class BankIDAsyncClient(BankIDClientBaseclass):
"""

def __init__(self, certificates: Tuple[str, str], test_server: bool = False, request_timeout: Optional[int] = None):
def __init__(self, certificates: Tuple[str, str], test_server: bool = False, request_timeout: int = 5):
super().__init__(certificates, test_server, request_timeout)

kwargs = {
"cert": self.certs,
"headers": {"Content-Type": "application/json"},
"verify": self.verify_cert,
}
if request_timeout:
kwargs["timeout"] = request_timeout
self.client = httpx.AsyncClient(**kwargs)
headers = {"Content-Type": "application/json"}
self.client = httpx.AsyncClient(cert=self.certs, headers=headers, verify=str(self.verify_cert), timeout=request_timeout)

async def authenticate(
self,
end_user_ip: str,
requirement: Dict[str, Any] = None,
user_visible_data: str = None,
user_non_visible_data: str = None,
user_visible_data_format: str = None,
requirement: Dict[str, Any] | None = None,
user_visible_data: str | None = None,
user_non_visible_data: str | None = None,
user_visible_data_format: str | None = None,
) -> Dict[str, str]:
"""Request an authentication order. The :py:meth:`collect` method
is used to query the status of the order.
Expand Down Expand Up @@ -85,18 +79,18 @@ async def authenticate(
response = await self.client.post(self._auth_endpoint, json=data)

if response.status_code == 200:
return response.json()
return response.json() # type: ignore[no-any-return]
else:
raise get_json_error_class(response)

async def phone_authenticate(
self,
personal_number: str,
call_initiator: str,
requirement: Dict[str, Any] = None,
user_visible_data: str = None,
user_non_visible_data: str = None,
user_visible_data_format: str = None,
requirement: Dict[str, Any] | None = None,
user_visible_data: str | None = None,
user_non_visible_data: str | None = None,
user_visible_data_format: str | None = None,
) -> Dict[str, str]:
"""Initiates an authentication order when the user is talking
to the RP over the phone. The :py:meth:`collect` method
Expand Down Expand Up @@ -150,17 +144,17 @@ async def phone_authenticate(
response = await self.client.post(self._phone_auth_endpoint, json=data)

if response.status_code == 200:
return response.json()
return response.json() # type: ignore[no-any-return]
else:
raise get_json_error_class(response)

async def sign(
self,
end_user_ip,
end_user_ip: str,
user_visible_data: str,
requirement: Dict[str, Any] = None,
user_non_visible_data: str = None,
user_visible_data_format: str = None,
requirement: Dict[str, Any] | None = None,
user_non_visible_data: str | None = None,
user_visible_data_format: str | None = None,
) -> Dict[str, str]:
"""Request a signing order. The :py:meth:`collect` method
is used to query the status of the order.
Expand Down Expand Up @@ -206,7 +200,7 @@ async def sign(
response = await self.client.post(self._sign_endpoint, json=data)

if response.status_code == 200:
return response.json()
return response.json() # type: ignore[no-any-return]
else:
raise get_json_error_class(response)

Expand All @@ -215,9 +209,9 @@ async def phone_sign(
personal_number: str,
call_initiator: str,
user_visible_data: str,
requirement: Dict[str, Any] = None,
user_non_visible_data: str = None,
user_visible_data_format: str = None,
requirement: Dict[str, Any] | None = None,
user_non_visible_data: str | None = None,
user_visible_data_format: str | None = None,
) -> Dict[str, str]:
"""Initiates an authentication order when the user is talking to
the RP over the phone. The :py:meth:`collect` method
Expand Down Expand Up @@ -269,7 +263,7 @@ async def phone_sign(
response = await self.client.post(self._phone_sign_endpoint, json=data)

if response.status_code == 200:
return response.json()
return response.json() # type: ignore[no-any-return]
else:
raise get_json_error_class(response)

Expand Down Expand Up @@ -341,7 +335,7 @@ async def collect(self, order_ref: str) -> dict:
response = await self.client.post(self._collect_endpoint, json={"orderRef": order_ref})

if response.status_code == 200:
return response.json()
return response.json() # type: ignore[no-any-return]
else:
raise get_json_error_class(response)

Expand All @@ -362,6 +356,6 @@ async def cancel(self, order_ref: str) -> bool:
response = await self.client.post(self._cancel_endpoint, json={"orderRef": order_ref})

if response.status_code == 200:
return response.json() == {}
return response.json() == {} # type: ignore[no-any-return]
else:
raise get_json_error_class(response)
39 changes: 20 additions & 19 deletions bankid/baseclient.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,32 @@
from __future__ import annotations
import base64
from datetime import datetime
from typing import Tuple, Optional, Dict, Any
from typing import Tuple, Dict, Any, Union, TYPE_CHECKING, TypeVar, Generic
from urllib.parse import urljoin

from bankid.qr import generate_qr_code_content
from bankid.certutils import resolve_cert_path

import httpx

class BankIDClientBaseclass:
TClient = TypeVar("TClient", httpx.AsyncClient,httpx.Client)


class BankIDClientBaseclass(Generic[TClient]):
"""Baseclass for BankID clients.
Both the synchronous and asynchronous clients inherit from this base class and has the methods implemented here.
"""

client: TClient

def __init__(
self,
certificates: Tuple[str, str],
test_server: bool = False,
request_timeout: Optional[int] = None,
request_timeout: int = 5,
):
self.certs = certificates
self._request_timeout = request_timeout

if test_server:
self.api_url = "https://appapi2.test.bankid.com/rp/v6.0/"
Expand All @@ -36,28 +42,23 @@ def __init__(
self._collect_endpoint = urljoin(self.api_url, "collect")
self._cancel_endpoint = urljoin(self.api_url, "cancel")

self.client = None

@staticmethod
def generate_qr_code_content(qr_start_token: str, start_t: [float, datetime], qr_start_secret: str) -> str:
def generate_qr_code_content(qr_start_token: str, start_t: Union[float, datetime], qr_start_secret: str) -> str:
return generate_qr_code_content(qr_start_token, start_t, qr_start_secret)

@staticmethod
def _encode_user_data(user_data):
if isinstance(user_data, str):
return base64.b64encode(user_data.encode("utf-8")).decode("ascii")
else:
return base64.b64encode(user_data).decode("ascii")
def _encode_user_data(user_data: str) -> str:
return base64.b64encode(user_data.encode("utf-8")).decode("ascii")

def _create_payload(
self,
end_user_ip: str = None,
requirement: Dict[str, Any] = None,
user_visible_data: str = None,
user_non_visible_data: str = None,
user_visible_data_format: str = None,
):
data = {}
end_user_ip: str | None = None,
requirement: Dict[str, Any] | None = None,
user_visible_data: str | None = None,
user_non_visible_data: str | None = None,
user_visible_data_format: str | None = None,
) -> Dict[str, str]:
data: Dict[str, Any] = {}
if end_user_ip:
data["endUserIp"] = end_user_ip
if requirement and isinstance(requirement, dict):
Expand Down
5 changes: 3 additions & 2 deletions bankid/certs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
# We have to pin these to prevent basic MITM attacks.

from pathlib import Path
from typing import Tuple


def get_test_cert_p12():
def get_test_cert_p12() -> Path:
return (Path(__file__).parent / "FPTestcert4_20230629.p12").resolve()


def get_test_cert_and_key():
def get_test_cert_and_key() -> Tuple[Path, Path]:
return (
(Path(__file__).parent / "FPTestcert4_20230629_cert.pem").resolve(),
(Path(__file__).parent / "FPTestcert4_20230629_key.pem").resolve(),
Expand Down
18 changes: 10 additions & 8 deletions bankid/certutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@


def resolve_cert_path(file: str) -> pathlib.Path:
return importlib.resources.files("bankid.certs").joinpath(file)
path = importlib.resources.files("bankid.certs").joinpath(file)
assert isinstance(path, pathlib.Path)
return path


def create_bankid_test_server_cert_and_key(destination_path: str = ".") -> Tuple[str]:
def create_bankid_test_server_cert_and_key(destination_path: str = ".") -> Tuple[str, str]:
"""Split the bundled test certificate into certificate and key parts and save them
as separate files, stored in PEM format.
Expand All @@ -35,9 +37,9 @@ def create_bankid_test_server_cert_and_key(destination_path: str = ".") -> Tuple
:rtype: tuple
"""
if os.getenv("TEST_CERT_FILE"):
if test_cert_file := os.getenv("TEST_CERT_FILE"):
certificate, key = split_certificate(
os.getenv("TEST_CERT_FILE"), destination_path, password=_TEST_CERT_PASSWORD
test_cert_file, destination_path, password=_TEST_CERT_PASSWORD
)

else:
Expand All @@ -48,7 +50,7 @@ def create_bankid_test_server_cert_and_key(destination_path: str = ".") -> Tuple
return certificate, key


def split_certificate(certificate_path, destination_folder, password=None):
def split_certificate(certificate_path: str, destination_folder: str, password: str | None = None) -> Tuple[str, str]:
"""Splits a PKCS12 certificate into Base64-encoded DER certificate and key.
This method splits a potentially password-protected
Expand All @@ -64,7 +66,7 @@ def split_certificate(certificate_path, destination_folder, password=None):
try:
# Attempt Linux and Darwin call first.
p = subprocess.Popen(["openssl", "version"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
sout, serr = p.communicate()
sout, _ = p.communicate()
openssl_executable_version = sout.decode().lower()
if not (openssl_executable_version.startswith("openssl") or openssl_executable_version.startswith("libressl")):
raise BankIDError("OpenSSL executable could not be found. " "Splitting cannot be performed.")
Expand All @@ -76,7 +78,7 @@ def split_certificate(certificate_path, destination_folder, password=None):
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
sout, serr = p.communicate()
sout, _ = p.communicate()
if not sout.decode().lower().startswith("openssl"):
raise BankIDError("OpenSSL executable could not be found. " "Splitting cannot be performed.")
openssl_executable = "C:\\Program Files\\Git\\mingw64\\bin\\openssl.exe"
Expand Down Expand Up @@ -129,7 +131,7 @@ def split_certificate(certificate_path, destination_folder, password=None):
return out_cert_path, out_key_path


def main(verbose=True):
def main(verbose: bool = True) -> Tuple[str, str]:
paths = create_bankid_test_server_cert_and_key(os.path.expanduser("~"))
if verbose:
print("Saved certificate as {0}".format(paths[0]))
Expand Down
30 changes: 16 additions & 14 deletions bankid/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import httpx
from typing import Any, Dict


def get_json_error_class(response):
def get_json_error_class(response: httpx.Response) -> BankIDError:
data = response.json()
error_class = _JSON_ERROR_CODE_TO_CLASS.get(data.get("errorCode"), BankIDError)
return error_class("{0}: {1}".format(data.get("errorCode"), data.get("details")), raw_data=data)
Expand All @@ -10,10 +12,10 @@ def get_json_error_class(response):
class BankIDError(Exception):
"""Parent exception class for all PyBankID errors."""

def __init__(self, *args, **kwargs):
super(BankIDError, self).__init__(*args)
self.rfa = None
self.json = kwargs.get("raw_data", {})
def __init__(self, *args: Any, raw_data: Dict[str, Any] | None = None, **kwargs: Any) -> None:
super(BankIDError, self).__init__(*args, **kwargs)
self.rfa: int | None = None
self.json = raw_data or {}


class BankIDWarning(Warning):
Expand All @@ -35,7 +37,7 @@ class InvalidParametersError(BankIDError):
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)


Expand All @@ -53,7 +55,7 @@ class AlreadyInProgressError(BankIDError):
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.rfa = 4

Expand All @@ -71,7 +73,7 @@ class InternalError(BankIDError):
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.rfa = 5

Expand All @@ -87,7 +89,7 @@ class MaintenanceError(BankIDError):
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.rfa = 5

Expand All @@ -103,7 +105,7 @@ class UnauthorizedError(BankIDError):
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)


Expand All @@ -118,7 +120,7 @@ class NotFoundError(BankIDError):
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)


Expand All @@ -133,12 +135,12 @@ class RequestTimeoutError(BankIDError):
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.rfa = 5


_JSON_ERROR_CODE_TO_CLASS = {
_JSON_ERROR_CODE_TO_CLASS: Dict[str, type[BankIDError]] = {
"invalidParameters": InvalidParametersError,
"alreadyInProgress": AlreadyInProgressError,
"unauthorized": UnauthorizedError,
Expand Down
Loading

0 comments on commit 1430080

Please sign in to comment.