From a14c6ff9189fdfe4dfc4507bdcf6a52e5ec0b79c Mon Sep 17 00:00:00 2001 From: Lily Kuang Date: Tue, 15 Feb 2022 10:47:42 -0800 Subject: [PATCH 1/3] feat(embedded): make guest token JWT audience callable --- superset/config.py | 3 ++- superset/security/manager.py | 13 ++++++++----- tests/integration_tests/security_tests.py | 21 +++++++++++++++++++++ 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/superset/config.py b/superset/config.py index 2637c0032bef5..09ef4e71a2d54 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1316,7 +1316,8 @@ def SQL_QUERY_MUTATOR( # pylint: disable=invalid-name,unused-argument GUEST_TOKEN_JWT_ALGO = "HS256" GUEST_TOKEN_HEADER_NAME = "X-GuestToken" GUEST_TOKEN_JWT_EXP_SECONDS = 300 # 5 minutes -GUEST_TOKEN_JWT_AUDIENCE = None +# Guest token audience for the embedded superset, either string or callable +GUEST_TOKEN_JWT_AUDIENCE: Optional[Union[Callable[[], str], str]] = None # A SQL dataset health check. Note if enabled it is strongly advised that the callable # be memoized to aid with performance, i.e., diff --git a/superset/security/manager.py b/superset/security/manager.py index ac494a1837827..3fe212b0bcef5 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -1309,8 +1309,9 @@ def create_guest_access_token( secret = current_app.config["GUEST_TOKEN_JWT_SECRET"] algo = current_app.config["GUEST_TOKEN_JWT_ALGO"] exp_seconds = current_app.config["GUEST_TOKEN_JWT_EXP_SECONDS"] - aud = current_app.config["GUEST_TOKEN_JWT_AUDIENCE"] or get_url_host() - + audience = current_app.config["GUEST_TOKEN_JWT_AUDIENCE"] or get_url_host() + if callable(audience): + audience = audience() # calculate expiration time now = self._get_current_epoch_time() exp = now + (exp_seconds * 1000) @@ -1321,7 +1322,7 @@ def create_guest_access_token( # standard jwt claims: "iat": now, # issued at "exp": exp, # expiration time - "aud": aud, + "aud": audience, "type": "guest", } token = jwt.encode(claims, secret, algorithm=algo) @@ -1372,8 +1373,10 @@ def parse_jwt_guest_token(raw_token: str) -> Dict[str, Any]: """ secret = current_app.config["GUEST_TOKEN_JWT_SECRET"] algo = current_app.config["GUEST_TOKEN_JWT_ALGO"] - aud = current_app.config["GUEST_TOKEN_JWT_AUDIENCE"] or get_url_host() - return jwt.decode(raw_token, secret, algorithms=[algo], audience=aud) + audience = current_app.config["GUEST_TOKEN_JWT_AUDIENCE"] or get_url_host() + if callable(audience): + audience = audience() + return jwt.decode(raw_token, secret, algorithms=[algo], audience=audience) @staticmethod def is_guest_user(user: Optional[Any] = None) -> bool: diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index efcd191ffafc1..6562c36632dc7 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -1299,3 +1299,24 @@ def test_get_guest_user_bad_audience(self): self.assertRaisesRegex(jwt.exceptions.InvalidAudienceError, "Invalid audience") self.assertIsNone(guest_user) + + @patch("superset.security.SupersetSecurityManager._get_current_epoch_time") + def test_create_guest_access_token_callable_audience(self, get_time_mock): + now = time.time() + get_time_mock.return_value = now + app.config["GUEST_TOKEN_JWT_AUDIENCE"] = Mock(return_value="cool_code") + + user = {"username": "test_guest"} + resources = [{"some": "resource"}] + rls = [{"dataset": 1, "clause": "access = 1"}] + token = security_manager.create_guest_access_token(user, resources, rls) + + decoded_token = jwt.decode( + token, + self.app.config["GUEST_TOKEN_JWT_SECRET"], + algorithms=[self.app.config["GUEST_TOKEN_JWT_ALGO"]], + audience="cool_code", + ) + app.config["GUEST_TOKEN_JWT_AUDIENCE"].assert_called_once() + self.assertEqual("cool_code", decoded_token["aud"]) + self.assertEqual("guest", decoded_token["type"]) From 32da50ada5a6d156c08a360eb4fa57bca9cd26ad Mon Sep 17 00:00:00 2001 From: Lily Kuang Date: Tue, 15 Feb 2022 14:14:08 -0800 Subject: [PATCH 2/3] reset GUEST_TOKEN_JWT_AUDIENCE after test --- tests/integration_tests/security_tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index 6562c36632dc7..9dca5ac51375c 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -1320,3 +1320,4 @@ def test_create_guest_access_token_callable_audience(self, get_time_mock): app.config["GUEST_TOKEN_JWT_AUDIENCE"].assert_called_once() self.assertEqual("cool_code", decoded_token["aud"]) self.assertEqual("guest", decoded_token["type"]) + app.config["GUEST_TOKEN_JWT_AUDIENCE"] = None From 7563982586539165ada1ed00c320ab76104bb719 Mon Sep 17 00:00:00 2001 From: Lily Kuang Date: Tue, 15 Feb 2022 15:44:14 -0800 Subject: [PATCH 3/3] helper method for get audience --- superset/security/manager.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/superset/security/manager.py b/superset/security/manager.py index 3fe212b0bcef5..91b203e83f774 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -1300,6 +1300,13 @@ def _get_current_epoch_time() -> float: """ This is used so the tests can mock time """ return time.time() + @staticmethod + def _get_guest_token_jwt_audience() -> str: + audience = current_app.config["GUEST_TOKEN_JWT_AUDIENCE"] or get_url_host() + if callable(audience): + audience = audience() + return audience + def create_guest_access_token( self, user: GuestTokenUser, @@ -1309,9 +1316,7 @@ def create_guest_access_token( secret = current_app.config["GUEST_TOKEN_JWT_SECRET"] algo = current_app.config["GUEST_TOKEN_JWT_ALGO"] exp_seconds = current_app.config["GUEST_TOKEN_JWT_EXP_SECONDS"] - audience = current_app.config["GUEST_TOKEN_JWT_AUDIENCE"] or get_url_host() - if callable(audience): - audience = audience() + audience = self._get_guest_token_jwt_audience() # calculate expiration time now = self._get_current_epoch_time() exp = now + (exp_seconds * 1000) @@ -1364,8 +1369,7 @@ def get_guest_user_from_token(self, token: GuestToken) -> GuestUser: token=token, roles=[self.find_role(current_app.config["GUEST_ROLE_NAME"])], ) - @staticmethod - def parse_jwt_guest_token(raw_token: str) -> Dict[str, Any]: + def parse_jwt_guest_token(self, raw_token: str) -> Dict[str, Any]: """ Parses a guest token. Raises an error if the jwt fails standard claims checks. :param raw_token: the token gotten from the request @@ -1373,9 +1377,7 @@ def parse_jwt_guest_token(raw_token: str) -> Dict[str, Any]: """ secret = current_app.config["GUEST_TOKEN_JWT_SECRET"] algo = current_app.config["GUEST_TOKEN_JWT_ALGO"] - audience = current_app.config["GUEST_TOKEN_JWT_AUDIENCE"] or get_url_host() - if callable(audience): - audience = audience() + audience = self._get_guest_token_jwt_audience() return jwt.decode(raw_token, secret, algorithms=[algo], audience=audience) @staticmethod