Skip to content

Commit

Permalink
Ensure refreshed tokens can be accessed across processes (#6817)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr committed May 9, 2024
1 parent 3348da7 commit 9e92736
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 24 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ jobs:
with:
name: unit_test_suite
python-version: ${{ matrix.python-version }}
channels: pyviz/label/dev,numba,bokeh/label/dev,conda-forge,nodefaults
channels: pyviz/label/dev,numba,conda-forge,nodefaults
conda-update: true
nodejs: true
nodejs-version: "20.9" # https://github.com/bokeh/bokeh/pull/13851
Expand Down Expand Up @@ -233,7 +233,7 @@ jobs:
with:
name: ui_test_suite
python-version: 3.9
channels: pyviz/label/dev,bokeh/label/dev,conda-forge,nodefaults
channels: pyviz/label/dev,conda-forge,nodefaults
envs: "-o recommended -o tests -o build"
cache: ${{ github.event.inputs.cache || github.event.inputs.cache == '' }}
nodejs: true
Expand Down
72 changes: 51 additions & 21 deletions panel/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import tornado

from bokeh.server.auth_provider import AuthProvider
from bokeh.util.token import get_token_payload
from tornado.auth import OAuth2Mixin
from tornado.httpclient import HTTPError as HTTPClientError, HTTPRequest
from tornado.web import HTTPError, RequestHandler, decode_signed_value
Expand Down Expand Up @@ -413,7 +414,7 @@ def set_auth_cookies(handler, id_token, access_token, refresh_token=None, expire
type(handler).__name__, user_key)
raise HTTPError(401, "OAuth token payload missing user information")
handler.clear_cookie('is_guest')
handler.set_secure_cookie('user', user, expires_days=config.oauth_expiry)
handler.set_secure_cookie('user', user, expires_days=config.oauth_expiry, httponly=True)
else:
user = None

Expand All @@ -423,14 +424,14 @@ def set_auth_cookies(handler, id_token, access_token, refresh_token=None, expire
id_token = state.encryption.encrypt(id_token.encode('utf-8'))
if refresh_token:
refresh_token = state.encryption.encrypt(refresh_token.encode('utf-8'))
handler.set_secure_cookie('access_token', access_token, expires_days=config.oauth_expiry)
handler.set_secure_cookie('access_token', access_token, expires_days=config.oauth_expiry, httponly=True)
if id_token:
handler.set_secure_cookie('id_token', id_token, expires_days=config.oauth_expiry)
handler.set_secure_cookie('id_token', id_token, expires_days=config.oauth_expiry, httponly=True)
if expires_in:
now_ts = dt.datetime.now(dt.timezone.utc).timestamp()
handler.set_secure_cookie('oauth_expiry', str(int(now_ts + expires_in)), expires_days=config.oauth_expiry)
handler.set_secure_cookie('oauth_expiry', str(int(now_ts + expires_in)), expires_days=config.oauth_expiry, httponly=True)
if refresh_token:
handler.set_secure_cookie('refresh_token', refresh_token, expires_days=config.oauth_expiry)
handler.set_secure_cookie('refresh_token', refresh_token, expires_days=config.oauth_expiry, httponly=True)
if user and user in state._oauth_user_overrides:
state._oauth_user_overrides.pop(user, None)
return user
Expand Down Expand Up @@ -848,11 +849,11 @@ def set_current_user(self, user):
self.clear_cookie("user")
return
self.clear_cookie("is_guest")
self.set_secure_cookie("user", user, expires_days=config.oauth_expiry)
self.set_secure_cookie("user", user, expires_days=config.oauth_expiry, httponly=True)
id_token = base64url_encode(json.dumps({'user': user}))
if state.encryption:
id_token = state.encryption.encrypt(id_token.encode('utf-8'))
self.set_secure_cookie('id_token', id_token, expires_days=config.oauth_expiry)
self.set_secure_cookie('id_token', id_token, expires_days=config.oauth_expiry, httponly=True)


class LogoutHandler(tornado.web.RequestHandler):
Expand Down Expand Up @@ -987,6 +988,20 @@ async def get_user(handler):
if not config.oauth_refresh_tokens or user is None:
return user

# Try to obtain user oauth overrides from WS headers
# in case the HTTP handler refreshed tokens
is_ws = isinstance(handler, WebSocketHandler)
if is_ws and 'Sec-Websocket-Protocol' in handler.request.headers:
protocol_header = handler.request.headers['Sec-Websocket-Protocol']
_, token = protocol_header.split(', ')
payload = get_token_payload(token)
if 'user_data' in payload:
user_data = payload['user_data']
if state.encryption:
user_data = state.encryption.decrypt(user_data).decode('utf-8')
user_data = json.loads(user_data)
state._oauth_user_overrides[user] = user_data

now_ts = dt.datetime.now(dt.timezone.utc).timestamp()
expiry = None
if user in state._oauth_user_overrides:
Expand All @@ -1003,16 +1018,20 @@ async def get_user(handler):
return
access_token = state._decrypt_cookie(access_cookie)

# Try to get expiry directly from the token since that is
# the real source of truth
try:
access_json = decode_token(access_token)
expiry = access_json['exp']
except Exception:
pass

if expiry is None:
try:
access_json = decode_token(access_token)
expiry = access_json['exp']
except Exception:
expiry = handler.get_secure_cookie('oauth_expiry', max_age_days=config.oauth_expiry)
if expiry is None:
# Token does not have content and therefore does not expire
log.debug("access_token is not a valid JWT token. Expiry cannot be determined.")
return user
expiry = handler.get_secure_cookie('oauth_expiry', max_age_days=config.oauth_expiry)
if expiry is None:
# Token does not have content and therefore does not expire
log.debug("access_token is not a valid JWT token. Expiry cannot be determined.")
return user

if user in state._oauth_user_overrides:
refresh_token = state._oauth_user_overrides[user]['refresh_token']
Expand All @@ -1025,7 +1044,8 @@ async def get_user(handler):

if expiry > now_ts and refresh_token:
log.debug("Fully authenticated and tokens still valid.")
self._schedule_refresh(expiry, user, refresh_token, handler.application, handler.request)
if is_ws:
self._schedule_refresh(expiry, user, refresh_token, handler.application, handler.request)
expires_in = expiry - now_ts
OAuthLoginHandler.set_auth_cookies(
handler, None, access_token, refresh_token, expires_in
Expand All @@ -1047,8 +1067,13 @@ async def get_user(handler):

log.debug("access_token has expired, %s using refresh_token to obtain new tokens.", type(self).__name__)
access_token, refresh_token, expiry = await self._scheduled_refresh(
user, refresh_token, handler.application, handler.request
user, refresh_token, handler.application, handler.request,
reschedule=is_ws
)
# If user not in overrides refresh failed and we need to
# fully reauthenticate
if user not in state._oauth_user_overrides:
return
expires_in = expiry - now_ts
OAuthLoginHandler.set_auth_cookies(
handler, None, access_token, refresh_token, expires_in
Expand Down Expand Up @@ -1106,15 +1131,18 @@ def _schedule_refresh(self, expiry_ts, user, refresh_token, application, request
finally:
state.schedule_task(task, refresh_cb, at=expiry_date)

async def _scheduled_refresh(self, user, refresh_token, application, request):
async def _scheduled_refresh(self, user, refresh_token, application, request, reschedule=True):
await self._refresh_access_token(user, refresh_token, application, request)
if user not in state._oauth_user_overrides:
return None, None, None
user_state = state._oauth_user_overrides[user]
access_token, refresh_token = user_state['access_token'], user_state['refresh_token']
if user_state['expiry']:
expiry = user_state['expiry']
else:
expiry = decode_token(access_token)['exp']
self._schedule_refresh(expiry, user, refresh_token, application, request)
if reschedule:
self._schedule_refresh(expiry, user, refresh_token, application, request)
return access_token, refresh_token, expiry

async def _refresh_access_token(self, user, refresh_token, application, request):
Expand All @@ -1126,7 +1154,7 @@ async def _refresh_access_token(self, user, refresh_token, application, request)
return
else:
refresh_token = state._oauth_user_overrides[user]['refresh_token']
log.debug("%s refreshing token", type(self).__name__)
log.debug("%s refreshing tokens", type(self).__name__)
state._oauth_user_overrides[user] = {}
auth_handler = self.login_handler(application=application, request=request)
_, access_token, refresh_token, expires_in = await auth_handler._fetch_access_token(
Expand All @@ -1135,13 +1163,15 @@ async def _refresh_access_token(self, user, refresh_token, application, request)
refresh_token=refresh_token
)
if access_token:
log.debug("%s successfully refreshed access_token", type(self).__name__)
now_ts = dt.datetime.now(dt.timezone.utc).timestamp()
state._oauth_user_overrides[user] = {
'access_token': access_token,
'refresh_token': refresh_token,
'expiry': now_ts+expires_in if expires_in else None
}
else:
log.debug("%s failed to refresh access_token", type(self).__name__)
del state._oauth_user_overrides[user]


Expand Down
26 changes: 25 additions & 1 deletion panel/io/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
"""
from __future__ import annotations

import json
import logging
import os

from functools import partial
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import bokeh.command.util

Expand Down Expand Up @@ -84,6 +85,29 @@ def _log_session_destroyed(session_context):
doc.on_event('document_ready', partial(state._schedule_on_load, doc))
doc.on_session_destroyed(_log_session_destroyed)

def process_request(self, request) -> dict[str, Any]:
''' Processes incoming HTTP request returning a dictionary of
additional data to add to the session_context.
Args:
request: HTTP request
Returns:
A dictionary of JSON serializable data to be included on
the session context.
'''
request_data = super().process_request(request)
user = request.cookies.get('user')
if user:
from tornado.web import decode_signed_value
user = decode_signed_value(config.cookie_secret, 'user', user.value).decode('utf-8')
if user in state._oauth_user_overrides:
user_data = json.dumps(state._oauth_user_overrides[user])
if state.encryption:
user_data = state.encryption.encrypt(user_data.encode('utf-8'))
request_data['user_data'] = user_data
return request_data

bokeh.command.util.Application = Application # type: ignore


Expand Down

0 comments on commit 9e92736

Please sign in to comment.