From ec4452f2e502f46e8dcdbf5ff399fddd81fa0e10 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Wed, 10 Jul 2024 16:00:26 -0500 Subject: [PATCH] Remove expired JWTs as needed --- singlestoredb/management/utils.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/singlestoredb/management/utils.py b/singlestoredb/management/utils.py index 9431ceb07..4e085289d 100644 --- a/singlestoredb/management/utils.py +++ b/singlestoredb/management/utils.py @@ -123,14 +123,14 @@ def get(self, name_or_id: str, *default: Any) -> Any: def _setup_authentication_info_handler() -> Callable[..., Dict[str, Any]]: """Setup authentication info event handler.""" - authentication_info: List[Tuple[str, Any]] = [] + authentication_info: Dict[str, Any] = {} def handle_authentication_info(msg: Dict[str, Any]) -> None: """Handle authentication info events.""" nonlocal authentication_info if msg.get('name', '') != 'singlestore.portal.authentication_updated': return - authentication_info = list(msg.get('data', {}).items()) + authentication_info = dict(msg.get('data', {})) events.subscribe(handle_authentication_info) @@ -145,11 +145,27 @@ def handle_connection_info(msg: Dict[str, Any]) -> None: out['user'] = data['user'] if 'password' in data: out['password'] = data['password'] - authentication_info = list(out.items()) + authentication_info = out events.subscribe(handle_authentication_info) + def retrieve_current_authentication_info() -> List[Tuple[str, Any]]: + """Retrieve JWT if not expired.""" + nonlocal authentication_info + password = authentication_info.get('password') + if password: + expires = datetime.datetime.fromtimestamp( + jwt.decode( + password, + options={'verify_signature': False}, + )['exp'], + ) + if datetime.datetime.now() > expires: + authentication_info = {} + return list(authentication_info.items()) + def get_env() -> List[Tuple[str, Any]]: + """Retrieve JWT from environment.""" conn = {} url = os.environ.get('SINGLESTOREDB_URL') or get_option('host') if url: @@ -170,7 +186,7 @@ def get_authentication_info(include_env: bool = True) -> Dict[str, Any]: return dict( itertools.chain( (get_env() if include_env else []), - authentication_info, + retrieve_current_authentication_info(), ), )