diff --git a/jupyter_server/base/handlers.py b/jupyter_server/base/handlers.py index 8aec2fe128..0656e44e80 100644 --- a/jupyter_server/base/handlers.py +++ b/jupyter_server/base/handlers.py @@ -3,6 +3,7 @@ # Distributed under the terms of the Modified BSD License. from __future__ import annotations +import contextvars import functools import inspect import ipaddress @@ -43,6 +44,7 @@ if TYPE_CHECKING: from jupyter_server.auth.identity import User +_current_request_var: contextvars.ContextVar = contextvars.ContextVar("current_request") # ----------------------------------------------------------------------------- # Top-level handlers # ----------------------------------------------------------------------------- @@ -69,6 +71,9 @@ def log(): class AuthenticatedHandler(web.RequestHandler): """A RequestHandler with an authenticated user.""" + def prepare(self): + _current_request_var.set(self.request) + @property def base_url(self) -> str: return self.settings.get("base_url", "/") @@ -89,7 +94,7 @@ def content_security_policy(self): # Make sure the report-uri is relative to the base_url "report-uri " + self.settings.get("csp_report_uri", url_path_join(self.base_url, csp_report_uri)), - ] + ] ) def set_default_headers(self): @@ -1098,6 +1103,13 @@ def get(self): self.write(prometheus_client.generate_latest(prometheus_client.REGISTRY)) +def get_current_request(): + """ + Get :class:`tornado.httputil.HTTPServerRequest` that is currently being processed. + """ + return _current_request_var.get(None) + + # ----------------------------------------------------------------------------- # URL pattern fragments for re-use # ----------------------------------------------------------------------------- diff --git a/jupyter_server/gateway/spottokenrenewer.py b/jupyter_server/gateway/spottokenrenewer.py new file mode 100644 index 0000000000..f3d1f0e927 --- /dev/null +++ b/jupyter_server/gateway/spottokenrenewer.py @@ -0,0 +1,43 @@ +import typing as ty + +import logging +from jupyter_server.gateway.gateway_client import GatewayTokenRenewerBase +import jupyter_server.base.handlers +import jupyter_server.serverapp + + +def get_header_value(request: ty.Any, header: str) -> str: + if header not in request.headers: + logging.error(f'Header "{header}" is missing') + return "" + logging.debug(f'Getting value from header "{header}"') + value = request.headers[header] + if len(value) == 0: + logging.error(f'Header "{header}" is empty') + return "" + return value + + +class SpotTokenRenewer(GatewayTokenRenewerBase): + + def get_token( + self, + auth_header_key: str, + auth_scheme: ty.Union[str, None], + auth_token: str, + **kwargs: ty.Any, + ) -> str: + request = jupyter_server.base.handlers.get_current_request() + if request is None: + logging.error("Could not get current request") + return auth_token + + auth_header_value = get_header_value(request, auth_header_key) + if auth_header_value: + try: + # We expect the header value to be of the form "Bearer: XXX" + auth_token = auth_header_value.split(" ", maxsplit=1)[1] + except Exception as e: + logging.error(f"Could not read token from auth header: {str(e)}") + + return auth_token