Skip to content

Commit

Permalink
Improve retries + caching (#21)
Browse files Browse the repository at this point in the history
* Improve request retry behavior
* Improve caching
- Cache `get_identity_groups` results
- Wait before trying to refresh again after a failed request
* Add test for failing `get_api_session`
  • Loading branch information
ThiefMaster authored Jul 3, 2024
1 parent 2dba5eb commit e6e734c
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 34 deletions.
62 changes: 40 additions & 22 deletions flask_multipass_cern.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
CACHE_LONG_TTL = 86400 * 7
CACHE_TTL = 1800
CERN_OIDC_WELLKNOWN_URL = 'https://auth.cern.ch/auth/realms/cern/.well-known/openid-configuration'
HTTP_RETRY_COUNT = 5

# not sure if retries are still needed, but by not using a backoff we don't risk taking down the site
# using this library in case the API is persistently failing with an error
HTTP_RETRY_COUNT = 2
retry_config = HTTPAdapter(max_retries=Retry(total=HTTP_RETRY_COUNT,
backoff_factor=0.5,
backoff_factor=0,
status_forcelist=[503, 504],
allowed_methods=frozenset(['GET']),
raise_on_status=False))
Expand Down Expand Up @@ -64,6 +65,9 @@ def set(self, key, value, timeout=0, refresh_timeout=None):
if refresh_timeout:
self.cache.set(f'{key}:timestamp', datetime.now(), refresh_timeout)

def delay_refresh(self, key, timeout):
self.cache.set(f'{key}:timestamp', datetime.now(), timeout)

def should_refresh(self, key):
if self.cache is None:
return True
Expand Down Expand Up @@ -163,21 +167,12 @@ def get_members(self):
yield IdentityInfo(self.provider, identifier, extra_data, **res)

def has_member(self, identifier):
cache = self.provider.cache
logger = self.provider.logger
cache_key = f'flask-multipass-cern:{self.provider.name}:groups:{identifier}'
all_groups = cache.get(cache_key)

if all_groups is None or cache.should_refresh(cache_key):
try:
all_groups = {g.name.lower() for g in self.provider.get_identity_groups(identifier)}
cache.set(cache_key, all_groups, CACHE_LONG_TTL, CACHE_TTL)
except RequestException:
logger.warning('Refreshing user groups failed for %s', identifier)
if all_groups is None:
logger.error('Getting user groups failed for %s, access will be denied', identifier)
return False

try:
all_groups = {g.name.lower() for g in self.provider.get_identity_groups(identifier)}
except RequestException:
# request failed and could not be satisfied from cache
self.provider.logger.error('Getting user groups failed for %s, access will be denied', identifier)
return False
if self.provider.settings['cern_users_group'] and self.name.lower() == 'cern users':
return self.provider.settings['cern_users_group'].lower() in all_groups
return self.name.lower() in all_groups
Expand Down Expand Up @@ -351,6 +346,7 @@ def search_identities_ex(self, criteria, exact=False, limit=None):
except RequestException:
self.logger.warning('Refreshing identities failed for criteria %s (could not get API token)', criteria)
if use_cache and cached_data:
self.cache.delay_refresh(cache_key, CACHE_TTL)
return cached_results, cached_data[1]
else:
self.logger.error('Getting identities failed for criteria %s (could not get API token)', criteria)
Expand All @@ -363,6 +359,7 @@ def search_identities_ex(self, criteria, exact=False, limit=None):
except RequestException:
self.logger.warning('Refreshing identities failed for criteria %s', criteria)
if use_cache and cached_data:
self.cache.delay_refresh(cache_key, CACHE_TTL)
return cached_results, cached_data[1]
else:
self.logger.error('Getting identities failed for criteria %s', criteria)
Expand All @@ -387,19 +384,40 @@ def search_identities_ex(self, criteria, exact=False, limit=None):
self.cache.set(cache_key, (cache_data, total), CACHE_LONG_TTL, CACHE_TTL * 2)
return identities, total

def get_identity_groups(self, identifier):
def _fetch_identity_group_names(self, identifier):
with self._get_api_session() as api_session:
identifier = identifier.replace('/', '%2F') # edugain identifiers sometimes contain slashes
resp = api_session.get(f'{self.authz_api_base}/api/v1.0/IdentityMembership/{identifier}/precomputed')
if resp.status_code == 404:
return set()
resp.raise_for_status()
results = resp.json()['data']
groups = {self.group_class(self, res['groupIdentifier']) for res in results}
if self.settings['cern_users_group'] and any(g.name == self.settings['cern_users_group'] for g in groups):
groups.add(self.group_class(self, 'CERN Users'))
groups = {res['groupIdentifier'] for res in results}
if self.settings['cern_users_group'] and any(g == self.settings['cern_users_group'] for g in groups):
groups.add('CERN Users')
return groups

def get_identity_groups(self, identifier):
cache_key = f'flask-multipass-cern:{self.name}:groups:{identifier}'
group_names = self.cache.get(cache_key)

if group_names is None or self.cache.should_refresh(cache_key):
try:
group_names = self._fetch_identity_group_names(identifier)
self.cache.set(cache_key, group_names, CACHE_LONG_TTL, CACHE_TTL)
except RequestException:
self.logger.warning('Refreshing user groups failed for %s', identifier)
if group_names is not None:
self.cache.delay_refresh(cache_key, CACHE_TTL)
else:
self.logger.error('Getting user groups failed for %s, request will fail', identifier)
raise

if self.settings['cern_users_group'] and any(g == self.settings['cern_users_group'] for g in group_names):
group_names.add('CERN Users')

return {self.group_class(self, g) for g in group_names}

def get_group(self, name):
return self.group_class(self, name)

Expand Down
21 changes: 9 additions & 12 deletions tests/test_has_member.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from datetime import datetime
from unittest.mock import MagicMock

import pytest
from requests import Session
Expand All @@ -17,17 +16,15 @@ def mock_get_api_session(mocker):


@pytest.fixture
def mock_get_identity_groups(mocker):
get_identity_groups = mocker.patch('flask_multipass_cern.CERNIdentityProvider.get_identity_groups')
group = MagicMock()
group.name = 'cern users'
get_identity_groups.return_value = {group}
def mock_fetch_identity_group_names(mocker):
get_identity_groups = mocker.patch('flask_multipass_cern.CERNIdentityProvider._fetch_identity_group_names')
get_identity_groups.return_value = {'cern users'}
return get_identity_groups


@pytest.fixture
def mock_get_identity_groups_fail(mocker):
get_identity_groups = mocker.patch('flask_multipass_cern.CERNIdentityProvider.get_identity_groups')
get_identity_groups = mocker.patch('flask_multipass_cern.CERNIdentityProvider._fetch_identity_group_names')
get_identity_groups.side_effect = RequestException()
return get_identity_groups

Expand All @@ -37,7 +34,7 @@ def spy_cache_set(mocker):
return mocker.spy(MemoryCache, 'set')


@pytest.mark.usefixtures('mock_get_identity_groups')
@pytest.mark.usefixtures('mock_fetch_identity_group_names')
def test_has_member_cache(provider):
test_group = CERNGroup(provider, 'cern users')
test_group.has_member('12345')
Expand All @@ -46,24 +43,24 @@ def test_has_member_cache(provider):
assert test_group.provider.cache.get('flask-multipass-cern:cip:groups:12345:timestamp')


@pytest.mark.usefixtures('mock_get_identity_groups')
@pytest.mark.usefixtures('mock_fetch_identity_group_names')
def test_has_member_cache_miss(provider, spy_cache_set):
test_group = CERNGroup(provider, 'cern users')
test_group.has_member('12345')

assert spy_cache_set.call_count == 2


def test_has_member_cache_hit(provider, mock_get_identity_groups):
def test_has_member_cache_hit(provider, mock_fetch_identity_group_names):
test_group = CERNGroup(provider, 'cern users')
test_group.provider.cache.set('flask-multipass-cern:cip:groups:12345', 'cern users')
test_group.provider.cache.set('flask-multipass-cern:cip:groups:12345:timestamp', datetime.now())
test_group.has_member('12345')

assert not mock_get_identity_groups.called
assert not mock_fetch_identity_group_names.called


@pytest.mark.usefixtures('mock_get_identity_groups')
@pytest.mark.usefixtures('mock_fetch_identity_group_names')
def test_has_member_request_fails(provider, mock_get_identity_groups_fail):
test_group = CERNGroup(provider, 'cern users')
res = test_group.has_member('12345')
Expand Down
19 changes: 19 additions & 0 deletions tests/test_search_identities_ex.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,22 @@ def test_search_identities_cache_hit_stale(provider, mock_data, freeze_time):
assert mock_data[data_key] == identities[0][0].data.get(identities_key)
assert isinstance(identities[0][0], IdentityInfo)
assert identities[1] == 1


@pytest.mark.usefixtures('httpretty_enabled')
def test_search_identities_cache_hit_broken_sso(mocker, provider, mock_data, freeze_time):
get_api_session = mocker.patch('flask_multipass_cern.CERNIdentityProvider._get_api_session')
get_api_session.side_effect = RequestException()

test_uri = f'{provider.settings.get("authz_api")}/api/v1.0/Identity'
httpretty.register_uri(httpretty.GET, test_uri, status=401)
cache_key = 'flask-multipass-cern:cip:email-identities:test@cern.ch'
provider.cache.set(cache_key, ([mock_data], 1), 2000, 10)
freeze_time(datetime.now() + timedelta(seconds=100))

identities = provider.search_identities_ex({'primaryAccountEmail': {'test@cern.ch'}}, True)

for identities_key, data_key in provider.settings.get('mapping').items():
assert mock_data[data_key] == identities[0][0].data.get(identities_key)
assert isinstance(identities[0][0], IdentityInfo)
assert identities[1] == 1

0 comments on commit e6e734c

Please sign in to comment.