Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure token refresh is always scheduled #6802

Merged
merged 3 commits into from
Apr 30, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 42 additions & 27 deletions panel/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ async def _fetch_access_token(
return None, access_token, refresh_token, expires_in
elif id_token:= body.get('id_token'):
try:
user = self._on_auth(id_token, access_token, refresh_token, expires_in)
user = OAuthLoginHandler._on_auth(self, id_token, access_token, refresh_token, expires_in)
except HTTPError:
pass
else:
Expand Down Expand Up @@ -283,7 +283,7 @@ async def _fetch_access_token(
self._raise_error(response, body, status=401)

log.debug("%s successfully obtained access_token and userinfo.", type(self).__name__)
user = self._on_auth(id_token, access_token, refresh_token, expires_in)
user = OAuthLoginHandler._on_auth(self, id_token, access_token, refresh_token, expires_in)
return user, access_token, refresh_token, expires_in

def get_state_cookie(self):
Expand Down Expand Up @@ -400,34 +400,41 @@ async def get(self):
self.set_state_cookie(state)
await self.get_authenticated_user(**params)

def _on_auth(self, id_token, access_token, refresh_token=None, expires_in=None):
if isinstance(id_token, str):
decoded = decode_token(id_token)
else:
decoded = id_token
id_token = base64url_encode(json.dumps(id_token))
user_key = config.oauth_jwt_user or self._USER_KEY
if user_key in decoded:
user = decoded[user_key]
@staticmethod
def _on_auth(handler, id_token, access_token, refresh_token=None, expires_in=None):
if id_token:
if isinstance(id_token, str):
decoded = decode_token(id_token)
else:
decoded = id_token
id_token = base64url_encode(json.dumps(id_token))
user_key = config.oauth_jwt_user or handler._USER_KEY
if user_key in decoded:
user = decoded[user_key]
else:
log.error("%s token payload did not contain expected %r.",
type(handler).__name__, user_key)
raise HTTPError(401, "OAuth token payload missing user information")
handler.clear_cookie('is_guest')
handler.set_secure_cookie('user', user, expires_days=config.oauth_expiry)
else:
log.error("%s token payload did not contain expected %r.",
type(self).__name__, user_key)
raise HTTPError(401, "OAuth token payload missing user information")
self.clear_cookie('is_guest')
self.set_secure_cookie('user', user, expires_days=config.oauth_expiry)
user = None

if state.encryption:
access_token = state.encryption.encrypt(access_token.encode('utf-8'))
id_token = state.encryption.encrypt(id_token.encode('utf-8'))
if id_token:
id_token = state.encryption.encrypt(id_token.encode('utf-8'))
if refresh_token:
refresh_token = state.encryption.encrypt(refresh_token.encode('utf-8'))
self.set_secure_cookie('access_token', access_token, expires_days=config.oauth_expiry)
self.set_secure_cookie('id_token', id_token, expires_days=config.oauth_expiry)
handler.set_secure_cookie('access_token', access_token, expires_days=config.oauth_expiry)
if id_token:
handler.set_secure_cookie('id_token', id_token, expires_days=config.oauth_expiry)
if expires_in:
now_ts = dt.datetime.now(dt.timezone.utc).timestamp()
self.set_secure_cookie('oauth_expiry', str(int(now_ts + expires_in)), expires_days=config.oauth_expiry)
handler.set_secure_cookie('oauth_expiry', str(int(now_ts + expires_in)), expires_days=config.oauth_expiry)
if refresh_token:
self.set_secure_cookie('refresh_token', refresh_token, expires_days=config.oauth_expiry)
if user in state._oauth_user_overrides:
handler.set_secure_cookie('refresh_token', refresh_token, expires_days=config.oauth_expiry)
if user and user in state._oauth_user_overrides:
state._oauth_user_overrides.pop(user, None)
return user

Expand Down Expand Up @@ -1016,12 +1023,14 @@ async def get_user(handler):
refresh_cookie = handler.get_secure_cookie('refresh_token', max_age_days=config.oauth_expiry)
if refresh_cookie:
refresh_token = state._decrypt_cookie(refresh_cookie)
self._schedule_refresh(access_json['exp'], user, refresh_token, handler.application, handler.request)
else:
refresh_token = None

if expiry > now_ts:
if expiry > now_ts and refresh_token:
log.debug("Fully authenticated and access_token still valid.")
self._schedule_refresh(expiry, user, refresh_token, handler.application, handler.request)
expires_in = expiry - now_ts
OAuthLoginHandler._on_auth(handler, None, access_token, refresh_token, expires_in)
return user

if refresh_token:
Expand All @@ -1038,7 +1047,11 @@ async def get_user(handler):
return

log.debug("%s refreshing token", type(self).__name__)
await self._refresh_access_token(user, refresh_token, handler.application, handler.request)
access_token, refresh_token, expiry = await self._scheduled_refresh(
user, refresh_token, handler.application, handler.request
)
expires_in = expiry - now_ts
OAuthLoginHandler._on_auth(handler, None, access_token, refresh_token, expires_in)
return user
return get_user

Expand Down Expand Up @@ -1076,13 +1089,14 @@ def _schedule_refresh(self, expiry_ts, user, refresh_token, application, request
if not state._active_users.get(user):
return
now_ts = dt.datetime.now(dt.timezone.utc).timestamp()
expiry_seconds = expiry_ts - now_ts - 10
log.debug("%s scheduling token refresh in %d seconds", type(self).__name__, expiry_seconds)
expiry_seconds = expiry_ts - now_ts - 60
expiry_date = dt.datetime.now() + dt.timedelta(seconds=expiry_seconds) # schedule_task is in local TZ
refresh_cb = partial(self._scheduled_refresh, user, refresh_token, application, request)
if expiry_seconds <= 0:
log.debug("%s token expired unexpectedly, refreshing immediately.", type(self).__name__, expiry_seconds)
philippjfr marked this conversation as resolved.
Show resolved Hide resolved
state.execute(refresh_cb)
return
log.debug("%s scheduling token refresh in %d seconds", type(self).__name__, expiry_seconds)
task = f'{user}-refresh-access-tokens'
try:
state.cancel_task(task)
Expand All @@ -1100,6 +1114,7 @@ async def _scheduled_refresh(self, user, refresh_token, application, request):
else:
expiry = decode_token(access_token)['exp']
self._schedule_refresh(expiry, user, refresh_token, application, request)
return access_token, refresh_token, expiry

async def _refresh_access_token(self, user, refresh_token, application, request):
if user in state._oauth_user_overrides:
Expand Down
Loading