diff --git a/run-tests.sh b/run-tests.sh index e322664c..3b718427 100755 --- a/run-tests.sh +++ b/run-tests.sh @@ -14,4 +14,4 @@ else . ./venv/bin/activate fi -pytest +pytest "$@" diff --git a/scitt_emulator/client.py b/scitt_emulator/client.py index 351c8922..e80844d0 100644 --- a/scitt_emulator/client.py +++ b/scitt_emulator/client.py @@ -4,6 +4,7 @@ from typing import Optional from pathlib import Path import json +import time import httpx @@ -11,19 +12,47 @@ from scitt_emulator.tree_algs import TREE_ALGS DEFAULT_URL = "http://127.0.0.1:8000" +CONNECT_RETRIES = 3 +HTTP_RETRIES = 3 +HTTP_DEFAULT_RETRY_DELAY = 1 def raise_for_status(response: httpx.Response): if response.is_success: return - try: - error = response.json() - except json.JSONDecodeError: - error = response.text - raise RuntimeError(f"HTTP error {response.status_code}: {error}") - raise RuntimeError( - f"HTTP error {response.status_code}: {error['error']['message']}" - ) + raise RuntimeError(f"HTTP error {response.status_code}: {response.text}") + + +def raise_for_operation_status(operation: dict): + if operation["status"] != "failed": + return + raise RuntimeError(f"Operation error: {operation['error']}") + + +class HttpClient: + def __init__(self, cacert: Optional[Path] = None): + verify = True if cacert is None else str(cacert) + transport = httpx.HTTPTransport(retries=CONNECT_RETRIES, verify=verify) + self.client = httpx.Client(transport=transport) + + def _request(self, *args, **kwargs): + response = self.client.request(*args, **kwargs) + retries = HTTP_RETRIES + while retries >= 0 and response.status_code == 503: + retries -= 1 + retry_after = int( + response.headers.get("retry-after", HTTP_DEFAULT_RETRY_DELAY) + ) + time.sleep(retry_after) + response = self.client.request(*args, **kwargs) + raise_for_status(response) + return response + + def get(self, *args, **kwargs): + return self._request("GET", *args, **kwargs) + + def post(self, *args, **kwargs): + return self._request("POST", *args, **kwargs) def create_claim(issuer: str, content_type: str, payload: str, claim_path: Path): @@ -31,19 +60,41 @@ def create_claim(issuer: str, content_type: str, payload: str, claim_path: Path) def submit_claim( - url: str, claim_path: Path, receipt_path: Path, entry_id_path: Optional[Path] + url: str, + claim_path: Path, + receipt_path: Path, + entry_id_path: Optional[Path], + client: HttpClient, ): with open(claim_path, "rb") as f: claim = f.read() # Submit claim - response = httpx.post(f"{url}/entries", content=claim) - raise_for_status(response) - entry_id = response.json()["entry_id"] + response = client.post(f"{url}/entries", content=claim) + if response.status_code == 201: + entry = response.json() + entry_id = entry["entryId"] + + elif response.status_code == 202: + operation = response.json() + + # Wait for registration to finish + while operation["status"] != "succeeded": + retry_after = int( + response.headers.get("retry-after", HTTP_DEFAULT_RETRY_DELAY) + ) + time.sleep(retry_after) + response = client.get(f"{url}/operations/{operation['operationId']}") + operation = response.json() + raise_for_operation_status(operation) + + entry_id = operation["entryId"] + + else: + raise RuntimeError(f"Unexpected status code: {response.status_code}") # Fetch receipt - response = httpx.get(f"{url}/entries/{entry_id}/receipt") - raise_for_status(response) + response = client.get(f"{url}/entries/{entry_id}/receipt") receipt = response.content print(f"Claim registered with entry ID {entry_id}") @@ -62,9 +113,8 @@ def submit_claim( print(f"Entry ID written to {entry_id_path}") -def retrieve_claim(url: str, entry_id: Path, claim_path: Path): - response = httpx.get(f"{url}/entries/{entry_id}") - raise_for_status(response) +def retrieve_claim(url: str, entry_id: Path, claim_path: Path, client: HttpClient): + response = client.get(f"{url}/entries/{entry_id}") claim = response.content with open(claim_path, "wb") as f: @@ -73,9 +123,8 @@ def retrieve_claim(url: str, entry_id: Path, claim_path: Path): print(f"Claim written to {claim_path}") -def retrieve_receipt(url: str, entry_id: Path, receipt_path: Path): - response = httpx.get(f"{url}/entries/{entry_id}/receipt") - raise_for_status(response) +def retrieve_receipt(url: str, entry_id: Path, receipt_path: Path, client: HttpClient): + response = client.get(f"{url}/entries/{entry_id}/receipt") receipt = response.content with open(receipt_path, "wb") as f: @@ -123,9 +172,10 @@ def cli(fn): help="Path to write the entry id to", ) p.add_argument("--url", required=False, default=DEFAULT_URL) + p.add_argument("--cacert", type=Path, help="CA certificate to verify host against") p.set_defaults( func=lambda args: submit_claim( - args.url, args.claim, args.out, args.out_entry_id + args.url, args.claim, args.out, args.out_entry_id, HttpClient(args.cacert) ) ) @@ -133,7 +183,12 @@ def cli(fn): p.add_argument("--entry-id", required=True, type=str) p.add_argument("--out", required=True, type=Path, help="Path to write the claim to") p.add_argument("--url", required=False, default=DEFAULT_URL) - p.set_defaults(func=lambda args: retrieve_claim(args.url, args.entry_id, args.out)) + p.add_argument("--cacert", type=Path, help="CA certificate to verify host against") + p.set_defaults( + func=lambda args: retrieve_claim( + args.url, args.entry_id, args.out, HttpClient(args.cacert) + ) + ) p = sub.add_parser("retrieve-receipt", description="Retrieve a SCITT receipt") p.add_argument("--entry-id", required=True, type=str) @@ -141,8 +196,11 @@ def cli(fn): "--out", required=True, type=Path, help="Path to write the receipt to" ) p.add_argument("--url", required=False, default=DEFAULT_URL) + p.add_argument("--cacert", type=Path, help="CA certificate to verify host against") p.set_defaults( - func=lambda args: retrieve_receipt(args.url, args.entry_id, args.out) + func=lambda args: retrieve_receipt( + args.url, args.entry_id, args.out, HttpClient(args.cacert) + ) ) p = sub.add_parser("verify-receipt", description="Verify a SCITT receipt") diff --git a/scitt_emulator/scitt.py b/scitt_emulator/scitt.py index 060629fa..18469941 100644 --- a/scitt_emulator/scitt.py +++ b/scitt_emulator/scitt.py @@ -6,6 +6,7 @@ from pathlib import Path import time import json +import uuid import cbor2 from pycose.messages import CoseMessage, Sign1Message @@ -30,6 +31,10 @@ class EntryNotFoundError(Exception): pass +class OperationNotFoundError(Exception): + pass + + class SCITTServiceEmulator(ABC): def __init__( self, service_parameters_path: Path, storage_path: Optional[Path] = None @@ -37,6 +42,10 @@ def __init__( self.storage_path = storage_path self.service_parameters_path = service_parameters_path + if storage_path is not None: + self.operations_path = storage_path / "operations" + self.operations_path.mkdir(exist_ok=True) + if self.service_parameters_path.exists(): with open(self.service_parameters_path) as f: self.service_parameters = json.load(f) @@ -53,6 +62,28 @@ def create_receipt_contents(self, countersign_tbi: bytes, entry_id: str): def verify_receipt_contents(receipt_contents: list, countersign_tbi: bytes): raise NotImplementedError + def get_operation(self, operation_id: str) -> dict: + operation_path = self.operations_path / f"{operation_id}.json" + try: + with open(operation_path, "r") as f: + operation = json.load(f) + except FileNotFoundError: + raise EntryNotFoundError(f"Operation {operation_id} not found") + + if operation["status"] == "running": + # Pretend that the service finishes the operation after + # the client having checked the operation status once. + operation = self._finish_operation(operation) + return operation + + def get_entry(self, entry_id: str) -> dict: + try: + self.get_claim(entry_id) + except EntryNotFoundError: + raise + # More metadata to follow in the future. + return { "entryId": entry_id } + def get_claim(self, entry_id: str) -> bytes: claim_path = self.storage_path / f"{entry_id}.cose" try: @@ -62,7 +93,13 @@ def get_claim(self, entry_id: str) -> bytes: raise EntryNotFoundError(f"Entry {entry_id} not found") return claim - def submit_claim(self, claim: bytes): + def submit_claim(self, claim: bytes, long_running=True) -> dict: + if long_running: + return self._create_operation(claim) + else: + return self._create_entry(claim) + + def _create_entry(self, claim: bytes) -> dict: last_entry_path = self.storage_path / "last_entry_id.txt" if last_entry_path.exists(): with open(last_entry_path, "r") as f: @@ -70,21 +107,59 @@ def submit_claim(self, claim: bytes): else: last_entry_id = 0 - entry_id = last_entry_id + 1 + entry_id = str(last_entry_id + 1) self._create_receipt(claim, entry_id) + last_entry_path.write_text(entry_id) + claim_path = self.storage_path / f"{entry_id}.cose" + claim_path.write_bytes(claim) + + print(f"Claim written to {claim_path}") + + entry = {"entryId": entry_id} + return entry + + def _create_operation(self, claim: bytes): + operation_id = str(uuid.uuid4()) + operation_path = self.operations_path / f"{operation_id}.json" + claim_path = self.operations_path / f"{operation_id}.cose" + + operation = { + "operationId": operation_id, + "status": "running" + } + + with open(operation_path, "w") as f: + json.dump(operation, f) + with open(claim_path, "wb") as f: f.write(claim) + + print(f"Operation {operation_id} created") print(f"Claim written to {claim_path}") - with open(last_entry_path, "w") as f: - f.write(str(entry_id)) + return operation + + def _finish_operation(self, operation: dict): + operation_id = operation["operationId"] + operation_path = self.operations_path / f"{operation_id}.json" + claim_src_path = self.operations_path / f"{operation_id}.cose" + + claim = claim_src_path.read_bytes() + entry = self._create_entry(claim) + claim_src_path.unlink() + + operation["status"] = "succeeded" + operation["entryId"] = entry["entryId"] + + with open(operation_path, "w") as f: + json.dump(operation, f) - return entry_id + return operation - def _create_receipt(self, claim: Path, entry_id: str): + def _create_receipt(self, claim: bytes, entry_id: str): # Validate claim # Note: This emulator does not verify the claim signature and does not apply # registration policies. diff --git a/scitt_emulator/server.py b/scitt_emulator/server.py index 825543f1..094d0b6a 100644 --- a/scitt_emulator/server.py +++ b/scitt_emulator/server.py @@ -4,25 +4,28 @@ import os from pathlib import Path from io import BytesIO +import random from flask import Flask, request, send_file, make_response from scitt_emulator.tree_algs import TREE_ALGS -from scitt_emulator.scitt import EntryNotFoundError, ClaimInvalidError +from scitt_emulator.scitt import EntryNotFoundError, ClaimInvalidError, OperationNotFoundError def make_error(code: str, msg: str, status_code: int): return make_response( { - "error": { - "code": code, - "message": msg, - } + "type": f"urn:ietf:params:scitt:error:{code}", + "detail": msg, }, status_code, ) +def make_unavailable_error(): + return make_error("serviceUnavailable", "Service unavailable, try again later", 503) + + def create_flask_app(config): app = Flask(__name__) @@ -30,6 +33,9 @@ def create_flask_app(config): app.config.update(dict(DEBUG=True)) app.config.update(config) + error_rate = app.config["error_rate"] + use_lro = app.config["use_lro"] + workspace_path = app.config["workspace"] storage_path = workspace_path / "storage" os.makedirs(storage_path, exist_ok=True) @@ -43,29 +49,63 @@ def create_flask_app(config): app.scitt_service.initialize_service() print(f"Service parameters: {app.service_parameters_path}") + def is_unavailable(): + return random.random() <= error_rate + @app.route("/entries//receipt", methods=["GET"]) def get_receipt(entry_id: str): + if is_unavailable(): + return make_unavailable_error() try: receipt = app.scitt_service.get_receipt(entry_id) except EntryNotFoundError as e: - return make_error("EntryNotFoundError", str(e), 404) + return make_error("entryNotFound", str(e), 404) return send_file(BytesIO(receipt), download_name=f"{entry_id}.receipt.cbor") @app.route("/entries/", methods=["GET"]) def get_claim(entry_id: str): + if is_unavailable(): + return make_unavailable_error() try: claim = app.scitt_service.get_claim(entry_id) except EntryNotFoundError as e: - return make_error("EntryNotFoundError", str(e), 404) + return make_error("entryNotFound", str(e), 404) return send_file(BytesIO(claim), download_name=f"{entry_id}.cose") @app.route("/entries", methods=["POST"]) def submit_claim(): + if is_unavailable(): + return make_unavailable_error() try: - entry_id = app.scitt_service.submit_claim(request.get_data()) + if use_lro: + result = app.scitt_service.submit_claim(request.get_data(), long_running=True) + headers = { + "Location": f"{request.host_url}/operations/{result['operationId']}", + "Retry-After": "1" + } + status_code = 202 + else: + result = app.scitt_service.submit_claim(request.get_data(), long_running=False) + headers = { + "Location": f"{request.host_url}/entries/{result['entryId']}", + } + status_code = 201 except ClaimInvalidError as e: - return make_error("ClaimInvalidError", str(e), 400) - return make_response({"entry_id": entry_id}) + return make_error("invalidInput", str(e), 400) + return make_response(result, status_code, headers) + + @app.route("/operations/", methods=["GET"]) + def get_operation(operation_id: str): + if is_unavailable(): + return make_unavailable_error() + try: + operation = app.scitt_service.get_operation(operation_id) + except OperationNotFoundError as e: + return make_error("operationNotFound", str(e), 404) + headers = {} + if operation["status"] == "running": + headers["Retry-After"] = "1" + return make_response(operation, 200, headers) return app @@ -73,6 +113,8 @@ def submit_claim(): def cli(fn): parser = fn() parser.add_argument("-p", "--port", type=int, default=8000) + parser.add_argument("--error-rate", type=float, default=0.01) + parser.add_argument("--use-lro", action="store_true", help="Create operations for submissions") parser.add_argument("--tree-alg", required=True, choices=list(TREE_ALGS.keys())) parser.add_argument("--workspace", type=Path, default=Path("workspace")) @@ -81,6 +123,8 @@ def cmd(args): { "tree_alg": args.tree_alg, "workspace": args.workspace, + "error_rate": args.error_rate, + "use_lro": args.use_lro } ) app.run(host="0.0.0.0", port=args.port) diff --git a/tests/test_cli.py b/tests/test_cli.py index 61d6fd69..f04f2cf5 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import os import threading +import pytest from werkzeug.serving import make_server from scitt_emulator import cli, server @@ -21,7 +22,10 @@ def __init__(self, config): def __enter__(self): app = server.create_flask_app(self.config) self.service_parameters_path = app.service_parameters_path - self.server = make_server("127.0.0.1", 8000, app) + host = "127.0.0.1" + self.server = make_server(host, 0, app) + port = self.server.port + self.url = f"http://{host}:{port}" self.thread = threading.Thread(name="server", target=self.server.serve_forever) self.thread.start() return self @@ -30,8 +34,10 @@ def __exit__(self, *args): self.server.shutdown() self.thread.join() - -def test_client_cli(tmp_path): +@pytest.mark.parametrize( + "use_lro", [True, False], +) +def test_client_cli(use_lro: bool, tmp_path): workspace_path = tmp_path / "workspace" claim_path = tmp_path / "claim.cose" @@ -43,6 +49,8 @@ def test_client_cli(tmp_path): { "tree_alg": "CCF", "workspace": workspace_path, + "error_rate": 0.1, + "use_lro": use_lro } ) as service: # create claim @@ -71,6 +79,8 @@ def test_client_cli(tmp_path): receipt_path, "--out-entry-id", entry_id_path, + "--url", + service.url ] execute_cli(command) assert os.path.exists(receipt_path) @@ -100,6 +110,8 @@ def test_client_cli(tmp_path): entry_id, "--out", retrieved_claim_path, + "--url", + service.url ] execute_cli(command) assert os.path.exists(retrieved_claim_path) @@ -119,6 +131,8 @@ def test_client_cli(tmp_path): entry_id, "--out", receipt_path_2, + "--url", + service.url ] execute_cli(command) assert os.path.exists(receipt_path_2)