Skip to content

Commit

Permalink
add spotinst token renewer
Browse files Browse the repository at this point in the history
  • Loading branch information
sigmarkarl committed Nov 3, 2023
1 parent aeb6f9d commit 76ac37f
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 1 deletion.
14 changes: 13 additions & 1 deletion jupyter_server/base/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Distributed under the terms of the Modified BSD License.
from __future__ import annotations

import contextvars
import functools
import inspect
import ipaddress
Expand Down Expand Up @@ -43,6 +44,7 @@
if TYPE_CHECKING:
from jupyter_server.auth.identity import User

_current_request_var: contextvars.ContextVar = contextvars.ContextVar("current_request")
# -----------------------------------------------------------------------------
# Top-level handlers
# -----------------------------------------------------------------------------
Expand All @@ -69,6 +71,9 @@ def log():
class AuthenticatedHandler(web.RequestHandler):
"""A RequestHandler with an authenticated user."""

def prepare(self):
_current_request_var.set(self.request)

@property
def base_url(self) -> str:
return self.settings.get("base_url", "/")
Expand All @@ -89,7 +94,7 @@ def content_security_policy(self):
# Make sure the report-uri is relative to the base_url
"report-uri "
+ self.settings.get("csp_report_uri", url_path_join(self.base_url, csp_report_uri)),
]
]
)

def set_default_headers(self):
Expand Down Expand Up @@ -1098,6 +1103,13 @@ def get(self):
self.write(prometheus_client.generate_latest(prometheus_client.REGISTRY))


def get_current_request():
"""
Get :class:`tornado.httputil.HTTPServerRequest` that is currently being processed.
"""
return _current_request_var.get(None)


# -----------------------------------------------------------------------------
# URL pattern fragments for re-use
# -----------------------------------------------------------------------------
Expand Down
43 changes: 43 additions & 0 deletions jupyter_server/gateway/spottokenrenewer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import typing as ty

import logging
from jupyter_server.gateway.gateway_client import GatewayTokenRenewerBase
import jupyter_server.base.handlers
import jupyter_server.serverapp


def get_header_value(request: ty.Any, header: str) -> str:
if header not in request.headers:
logging.error(f'Header "{header}" is missing')
return ""
logging.debug(f'Getting value from header "{header}"')
value = request.headers[header]
if len(value) == 0:
logging.error(f'Header "{header}" is empty')
return ""
return value


class SpotTokenRenewer(GatewayTokenRenewerBase):

def get_token(
self,
auth_header_key: str,
auth_scheme: ty.Union[str, None],
auth_token: str,
**kwargs: ty.Any,
) -> str:
request = jupyter_server.base.handlers.get_current_request()
if request is None:
logging.error("Could not get current request")
return auth_token

auth_header_value = get_header_value(request, auth_header_key)
if auth_header_value:
try:
# We expect the header value to be of the form "Bearer: XXX"
auth_token = auth_header_value.split(" ", maxsplit=1)[1]
except Exception as e:
logging.error(f"Could not read token from auth header: {str(e)}")

return auth_token

0 comments on commit 76ac37f

Please sign in to comment.