From c88429d41a24ac772ba94f9587960758c883b81a Mon Sep 17 00:00:00 2001 From: roshii Date: Fri, 6 Oct 2023 20:39:39 +0200 Subject: [PATCH] JWT authority fixes --- docs/api/wallet-rpc.yaml | 9 +--- src/jmclient/auth.py | 12 +++-- src/jmclient/wallet_rpc.py | 10 ++-- test/jmclient/test_auth.py | 18 ++++--- test/jmclient/test_wallet_rpc.py | 84 ++++++++++++++++++-------------- test/jmclient/test_websocket.py | 13 ++--- 6 files changed, 81 insertions(+), 65 deletions(-) diff --git a/docs/api/wallet-rpc.yaml b/docs/api/wallet-rpc.yaml index 0dc79ae14..e18df43b7 100644 --- a/docs/api/wallet-rpc.yaml +++ b/docs/api/wallet-rpc.yaml @@ -21,7 +21,7 @@ paths: On initially creating, unlocking or recovering a wallet, store both the refresh and access tokens, the latter is valid for only 30 minutes (must be used for any authenticated call) while the former is for 4 hours (can only be used in the refresh request parameters). Use /token endpoint on a regular basis to get new access and refresh tokens, ideally before access token expiration to avoid authentication errors and in any case, before refresh token expiration. The newly issued tokens must be used in subsequent calls since operation invalidates previously issued tokens. responses: '200': - $ref: '#/components/responses/RefreshToken-200-OK' + $ref: '#/components/responses/Token-200-OK' '400': $ref: '#/components/responses/400-BadRequest' requestBody: @@ -579,11 +579,6 @@ paths: required: true schema: type: string - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/GetSeedResponse' responses: '200': $ref: '#/components/responses/GetSeed-200-OK' @@ -684,7 +679,7 @@ components: token_type: type: string expires_in: - type: int + type: integer scope: type: string refresh_token: diff --git a/src/jmclient/auth.py b/src/jmclient/auth.py index 84a259760..56d5627d8 100644 --- a/src/jmclient/auth.py +++ b/src/jmclient/auth.py @@ -1,5 +1,6 @@ import datetime import os +from base64 import b64encode import jwt @@ -19,6 +20,9 @@ def get_random_key(size: int = 16) -> str: return bintohex(os.urandom(size)) +def b64str(s: str) -> str: + return b64encode(s.encode()).decode() + class JMTokenAuthority: """Manage authorization tokens.""" @@ -57,13 +61,13 @@ def verify(self, token: str, *, token_type: str = "access"): if not self._scope <= token_claims: raise InvalidScopeError - def add_to_scope(self, *args: str): + def add_to_scope(self, *args: str, encoded: bool = True): for arg in args: - self._scope.add(arg) + self._scope.add(b64str(arg) if encoded else arg) - def discard_from_scope(self, *args: str): + def discard_from_scope(self, *args: str, encoded: bool = True): for arg in args: - self._scope.discard(arg) + self._scope.discard(b64str(arg) if encoded else arg) @property def scope(self): diff --git a/src/jmclient/wallet_rpc.py b/src/jmclient/wallet_rpc.py index c2ad79068..979c5d857 100644 --- a/src/jmclient/wallet_rpc.py +++ b/src/jmclient/wallet_rpc.py @@ -280,10 +280,10 @@ def stopSubServices(self): self.taker_finished(False) def auth_err(self, request, error, description=None): - request.setHeader("WWW-Authenticate", "Bearer") - request.setHeader("WWW-Authenticate", f'error="{error}"') + value = f'Bearer, error="{error}"' if description is not None: - request.setHeader("WWW-Authenticate", f'error_description="{description}"') + value += f', error_description="{description}"' + request.setHeader("WWW-Authenticate", value) return def err(self, request, message): @@ -305,7 +305,7 @@ def invalid_credentials(self, request, failure): @app.handle_errors(InvalidToken) def invalid_token(self, request, failure): request.setResponseCode(401) - return self.auth_err(request, "invalid_token", str(failure)) + return self.auth_err(request, "invalid_token", failure.getErrorMessage()) @app.handle_errors(InsufficientScope) def insufficient_scope(self, request, failure): @@ -643,7 +643,7 @@ def _mkerr(err, description=""): "The requested scope is invalid, unknown, malformed, " "or exceeds the scope granted by the resource owner.", ) - except auth.ExpiredSignatureError: + except Exception: return _mkerr( "invalid_grant", f"The provided {grant_type} is invalid, revoked, " diff --git a/test/jmclient/test_auth.py b/test/jmclient/test_auth.py index b1f65e466..68aea86db 100644 --- a/test/jmclient/test_auth.py +++ b/test/jmclient/test_auth.py @@ -6,7 +6,12 @@ import jwt import pytest -from jmclient.auth import ExpiredSignatureError, InvalidScopeError, JMTokenAuthority +from jmclient.auth import ( + ExpiredSignatureError, + InvalidScopeError, + JMTokenAuthority, + b64str, +) class TestJMTokenAuthority: @@ -17,7 +22,7 @@ class TestJMTokenAuthority: refresh_sig = copy.copy(token_auth.signature_key["refresh"]) validity = datetime.timedelta(hours=1) - scope = f"walletrpc {wallet_name}" + scope = f"walletrpc {b64str(wallet_name)}" @pytest.mark.parametrize( "sig, token_type", [(access_sig, "access"), (refresh_sig, "refresh")] @@ -83,15 +88,16 @@ def scope_equals(scope): def test_scope_operation(self): assert "walletrpc" in self.token_auth._scope - assert self.wallet_name in self.token_auth._scope + assert b64str(self.wallet_name) in self.token_auth._scope scope = copy.copy(self.token_auth._scope) s = "new_wallet" self.token_auth.add_to_scope(s) assert scope < self.token_auth._scope - assert s in self.token_auth._scope + assert b64str(s) in self.token_auth._scope - self.token_auth.discard_from_scope(s, "walletrpc") + self.token_auth.discard_from_scope(s) + self.token_auth.discard_from_scope("walletrpc", encoded=False) assert scope > self.token_auth._scope - assert s not in self.token_auth._scope + assert b64str(s) not in self.token_auth._scope diff --git a/test/jmclient/test_wallet_rpc.py b/test/jmclient/test_wallet_rpc.py index d0904874e..dee3b3cca 100644 --- a/test/jmclient/test_wallet_rpc.py +++ b/test/jmclient/test_wallet_rpc.py @@ -29,8 +29,7 @@ from commontest import make_wallets from test_coinjoin import make_wallets_to_list, sync_wallets -from test_websocket import (ClientTProtocol, test_tx_hex_1, - test_tx_hex_txid, test_token_authority) +from test_websocket import ClientTProtocol, test_tx_hex_1, test_tx_hex_txid pytestmark = pytest.mark.usefixtures("setup_regtest_bitcoind") @@ -41,10 +40,6 @@ jlog = get_log() class JMWalletDaemonT(JMWalletDaemon): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.token = test_token_authority - def check_cookie(self, request, *args, **kwargs): if self.auth_disabled: return True @@ -220,6 +215,7 @@ def test_notif(self): "ws://127.0.0.1:"+str(self.wss_port), delay=0.1, callbackfn=self.fire_tx_notif) self.client_factory.protocol = ClientNotifTestProto + self.client_factory.protocol.ACCESS_TOKEN = self.daemon.token.issue()["token"].encode("utf8") self.client_connector = connectWS(self.client_factory) self.attempt_receipt_counter = 0 return task.deferLater(reactor, 0.0, self.wait_to_receive) @@ -754,22 +750,28 @@ def process_get_seed_response(self, response, code): class TrialTestWRPC_JWT(WalletRPCTestBase, unittest.TestCase): + @defer.inlineCallbacks + def do_request(self, agent, method, addr, body, handler, token): + headers = Headers({"Authorization": ["Bearer " + token]}) + response = yield agent.request(method, addr, headers, bodyProducer=body) + handler(response) + def get_token(self, grant_type: str, status: str = "valid"): now, delta = datetime.datetime.utcnow(), datetime.timedelta(hours=1) exp = now - delta if status == "expired" else now + delta scope = f"walletrpc {self.daemon.wallet_name}" if status == "invalid_scope": - scope = "walletrpc another_wallet" + scope = status - alg = test_token_authority.SIGNATURE_ALGORITHM + alg = self.daemon.token.SIGNATURE_ALGORITHM if status == "invalid_alg": alg = ({"HS256", "HS384", "HS512"} - {alg}).pop() t = jwt.encode( {"exp": exp, "scope": scope}, - test_token_authority.signature_key[grant_type], - algorithm=test_token_authority.SIGNATURE_ALGORITHM, + self.daemon.token.signature_key[grant_type], + algorithm=alg, ) if status == "invalid_sig": @@ -792,22 +794,23 @@ def get_token(self, grant_type: str, status: str = "valid"): return t - def authorized_response_handler(self, response, code): - assert code == 200 + def authorized_response_handler(self, response): + assert response.code == 200 - def forbidden_response_handler(self, response, code): - assert code == 403 - assert "insufficient_scope" in response.headers.get("WWW-Authenticate") + def forbidden_response_handler(self, response): + assert response.code == 403 + assert "insufficient_scope" in response.headers.getRawHeaders("WWW-Authenticate").pop() - def unauthorized_response_handler(self, response, code): - assert code == 401 - assert "Bearer" in response.headers.get("WWW-Authenticate") + def unauthorized_response_handler(self, response): + assert response.code == 401 + assert "Bearer" in response.headers.getRawHeaders("WWW-Authenticate").pop() - def expired_access_token_response_handler(self, response, code): - self.unauthorized_response_handler(response, code) - assert "expired" in response.headers.get("WWW-Authenticate") + def expired_access_token_response_handler(self, response): + self.unauthorized_response_handler(response) + assert "expired" in response.headers.getRawHeaders("WWW-Authenticate").pop() - async def test_jwt_authentication(self): + @defer.inlineCallbacks + def test_jwt_authentication(self): """Test JWT authentication and authorization""" agent = get_nontor_agent() @@ -828,31 +831,37 @@ async def test_jwt_authentication(self): }[responde_handler] token = self.get_token("access", access_token_status) - await self.do_request(agent, b"GET", addr, None, handler, token) + yield self.do_request(agent, b"GET", addr, None, handler, token) - def successful_refresh_response_handler(self, response, code): - self.authorized_response_handler(response, code) - json_body = json.loads(response.decode("utf-8")) + @defer.inlineCallbacks + def successful_refresh_response_handler(self, response): + self.authorized_response_handler(response) + body = yield readBody(response) + json_body = json.loads(body.decode("utf-8")) assert {"token", "refresh_token", "expires_in", "token_type", "scope"} <= set( json_body.keys() ) + @defer.inlineCallbacks def failed_refresh_response_handler( - self, response, code, *, message=None, error_description=None + self, response, *, message=None, error_description=None ): - assert code == 400 - json_body = json.loads(response.decode("utf-8")) + assert response.code == 400 + body = yield readBody(response) + json_body = json.loads(body.decode("utf-8")) if message is not None: assert json_body.get("message") == message if error_description is not None: assert error_description in json_body.get("error_description") - async def do_refresh_request(self, body, handler, token): + @defer.inlineCallbacks + def do_refresh_request(self, body, handler, token): agent = get_nontor_agent() addr = (self.get_route_root() + "/token").encode() body = BytesProducer(json.dumps(body).encode()) - await self.do_request(agent, b"POST", addr, body, handler, token) + yield self.do_request(agent, b"POST", addr, body, handler, token) + @defer.inlineCallbacks def test_refresh_token_request(self): """Test token endpoint with valid refresh token""" for access_token_status, request_status, error in [ @@ -864,7 +873,7 @@ def test_refresh_token_request(self): if error is None: handler = self.successful_refresh_response_handler else: - handler = functools.partialmethod( + handler = functools.partial( self.failed_refresh_response_handler, message=error ) @@ -877,11 +886,12 @@ def test_refresh_token_request(self): if request_status == "unsupported_grant_type": body["grant_type"] = "joinmarket" - self.do_refresh_request( + yield self.do_refresh_request( body, handler, self.get_token("access", access_token_status) ) - async def test_refresh_token(self): + @defer.inlineCallbacks + def test_refresh_token(self): """Test refresh token endpoint""" for refresh_token_status, error in [ ("expired", "expired"), @@ -889,11 +899,11 @@ async def test_refresh_token(self): ("invalid_sig", "invalid_grant"), ]: if error == "expired": - handler = functools.partialmethod( + handler = functools.partial( self.failed_refresh_response_handler, error_description=error ) else: - handler = functools.partialmethod( + handler = functools.partial( self.failed_refresh_response_handler, message=error ) @@ -902,7 +912,7 @@ async def test_refresh_token(self): "refresh_token": self.get_token("refresh", refresh_token_status), } - self.do_refresh_request(body, handler, self.get_token("access")) + yield self.do_refresh_request(body, handler, self.get_token("access")) """ diff --git a/test/jmclient/test_websocket.py b/test/jmclient/test_websocket.py index 38ba9d87c..8ac4ec6a0 100644 --- a/test/jmclient/test_websocket.py +++ b/test/jmclient/test_websocket.py @@ -21,7 +21,8 @@ test_tx_hex_txid = "ca606efc5ba8f6669ba15e9262e5d38e745345ea96106d5a919688d1ff0da0cc" # Shared JWT token authority for test: -test_token_authority = JMTokenAuthority("dummywallet") +token_authority = JMTokenAuthority() + class ClientTProtocol(WebSocketClientProtocol): """ @@ -29,11 +30,11 @@ class ClientTProtocol(WebSocketClientProtocol): message every 2 seconds and print everything it receives. """ + ACCESS_TOKEN = token_authority.issue()["token"].encode("utf8") + def sendAuth(self): - """ Our server will not broadcast - to us unless we authenticate. - """ - self.sendMessage(test_token_authority.issue()["token"].encode('utf8')) + """Our server will not broadcast to us unless we authenticate.""" + self.sendMessage(self.ACCESS_TOKEN) def onOpen(self): # auth on startup @@ -65,7 +66,7 @@ def setUp(self): free_ports = get_free_tcp_ports(1) self.wss_port = free_ports[0] self.wss_url = "ws://127.0.0.1:" + str(self.wss_port) - self.wss_factory = JmwalletdWebSocketServerFactory(self.wss_url, test_token_authority) + self.wss_factory = JmwalletdWebSocketServerFactory(self.wss_url, token_authority) self.wss_factory.protocol = JmwalletdWebSocketServerProtocol self.listeningport = listenWS(self.wss_factory, contextFactory=None) self.test_tx = CTransaction.deserialize(hextobin(test_tx_hex_1))