Skip to content

Commit

Permalink
Refactoring of gitlab auth class
Browse files Browse the repository at this point in the history
  • Loading branch information
MiksIr committed Jun 10, 2020
1 parent 0ca6270 commit 45a8987
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 112 deletions.
2 changes: 1 addition & 1 deletion src/auth/auth_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def get_client_visible_config(self):
def get_groups(self, user, known_groups=None):
return []

def is_active(self, user, request_handler):
def validate_user(self, user, request_handler):
return True

def logout(self, user, request_handler):
Expand Down
98 changes: 45 additions & 53 deletions src/auth/auth_gitlab.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

from typing import List, Any, Dict, cast, Iterable, Union, Optional

from utils import file_utils

LOGGER = logging.getLogger('script_server.GitlabAuthorizer')


Expand Down Expand Up @@ -58,6 +60,7 @@ async def get_authenticated_user(
body=body,
raise_error=False)

default_response_values = {"state": "unknown"}
response_values = {}
if response.body:
response_values = escape.json_decode(response.body)
Expand Down Expand Up @@ -88,7 +91,7 @@ async def get_authenticated_user(
LOGGER.error(error_message)
raise AuthFailureError(error_message)

return {**response_values, **user}
return {**default_response_values, **response_values, **user}

async def fetch_user(self, access_token):
user = await self.oauth2_request(
Expand All @@ -115,9 +118,9 @@ def __init__(self, params_dict):
secret_value = model_helper.read_obligatory(params_dict, 'secret', ' for Gitlab OAuth')
self.secret = model_helper.resolve_env_vars(secret_value, full_match=True)

gitlabPrefix = params_dict.get('url')
if not model_helper.is_empty(gitlabPrefix):
self._GITLAB_PREFIX = gitlabPrefix
gitlab_prefix = params_dict.get('url')
if not model_helper.is_empty(gitlab_prefix):
self._GITLAB_PREFIX = gitlab_prefix

self.states = {}
self.user_states = {}
Expand All @@ -128,15 +131,13 @@ def __init__(self, params_dict):
now = time.time()

if self.gitlab_dump and os.path.exists(self.gitlab_dump):
dumpFile = open(self.gitlab_dump, "r")
stateStr = dumpFile.read()
self.user_states = escape.json_decode(stateStr)
dumpFile.close()
for userData in list(self.user_states.keys()):
state_str = file_utils.read_file(self.gitlab_dump)
self.user_states = escape.json_decode(state_str)
for user_data in list(self.user_states.keys()):
# force to update user from gitlab
self.user_states[userData]['updating'] = False
self.user_states[user_data]['updating'] = False
if self.gitlab_update:
self.user_states[userData]['updated'] = now - self.gitlab_update - 1
self.user_states[user_data]['updated'] = now - self.gitlab_update - 1
LOGGER.info("Readed state from file %s: " % self.gitlab_dump + str(self.user_states))

self.gitlab_group_search = params_dict.get('group_search')
Expand All @@ -152,40 +153,29 @@ def authenticate(self, request_handler):
LOGGER.error('Code is not specified')
raise AuthBadRequestException('Missing authorization information. Please contact your administrator')

return self.validate_user(code, request_handler)
return self.read_user(code, request_handler)

def is_active(self, user, request_handler):
def validate_user(self, user, request_handler):
access_token = request_handler.get_secure_cookie('token')
if access_token is None:
return False
access_token = access_token.decode("utf-8")

self.clean_and_persist_sessions()

if self.user_states.get(user) is None:
LOGGER.debug("User %s not found in state" % user)
return False

if self.user_states[user]['state'] is None or self.user_states[user]['state'] != "active":
LOGGER.info("User %s state inactive: " % user + str(self.user_states[user]))
del self.user_states[user]
self.dump_sessions_to_file()
return False

now = time.time()
# check session ttl
if self.session_expire and (self.user_states[user]['visit'] + self.session_expire) < now:
del self.user_states[user]
LOGGER.info("User %s session expired, logged out" % user)
self.dump_sessions_to_file()
return False

self.user_states[user]['visit'] = now

# check gitlab response ttl, also check for stale updating (ttl*2)
if self.gitlab_update is not None:
stale = (self.user_states[user]['updated'] + max(self.gitlab_update*2, 60)) < now
stale_update = (self.user_states[user]['updated'] + max(self.gitlab_update*2, 60)) < now
ttl_expired = (self.user_states[user]['updated'] + self.gitlab_update) < now
updating_now = self.user_states[user]['updating'] is True
if ttl_expired and (not updating_now or stale):
if ttl_expired and (not updating_now or stale_update):
if self.gitlab_group_support:
self.do_update_groups(user, access_token)
else:
Expand All @@ -203,31 +193,38 @@ def get_groups(self, user, known_groups=None):
def logout(self, user, request_handler):
request_handler.clear_cookie('token')

def clean_expired_sessions(self):
def clean_sessions(self):
now = time.time()
if self.session_expire:
for userData in list(self.user_states.keys()):
if (self.user_states[userData]['visit'] + self.session_expire) < now:
LOGGER.debug("User %s session expired and removed" % userData)
del self.user_states[userData]

def dump_sessions_to_file(self):
for user_data in list(self.user_states.keys()):
if self.session_expire and (self.user_states[user_data]['visit'] + self.session_expire) < now:
LOGGER.info("User %s removed because session expired" % user_data)
del self.user_states[user_data]
continue
if self.user_states[user_data]['state'] is None or self.user_states[user_data]['state'] != "active":
LOGGER.info("User %s removed because state '%s' != 'active'" %
(user_data, self.user_states[user_data]['state']))
del self.user_states[user_data]
continue

def clean_and_persist_sessions(self):
self.clean_sessions()
if self.gitlab_dump:
dumpFile = open(self.gitlab_dump, "w")
dumpFile.write(escape.json_encode(self.user_states))
dumpFile.close()
LOGGER.debug("Dumped state to file %s" % self.gitlab_dump)
self.persist_session()

def persist_session(self):
file_utils.write_file(self.gitlab_dump, escape.json_encode(self.user_states))
LOGGER.debug("Dumped state to file %s" % self.gitlab_dump)

def do_update_user(self, user, access_token):
self.user_states[user]['updating'] = True
tornado.ioloop.IOLoop.current().spawn_callback(self.update_user_state, user, access_token)
tornado.ioloop.IOLoop.current().spawn_callback(self.update_user, user, access_token)

def do_update_groups(self, user, access_token):
self.user_states[user]['updating'] = True
tornado.ioloop.IOLoop.current().spawn_callback(self.update_group_list, user, access_token)
tornado.ioloop.IOLoop.current().spawn_callback(self.update_groups, user, access_token)

@gen.coroutine
def update_group_list(self, user, access_token):
def update_groups(self, user, access_token):
group_list = yield self.read_groups(access_token)
if group_list is None:
LOGGER.error("Failed to refresh groups for %s" % user)
Expand All @@ -239,12 +236,10 @@ def update_group_list(self, user, access_token):
self.user_states[user]['updating'] = False
self.user_states[user]['updated'] = now
self.user_states[user]['visit'] = now
self.clean_expired_sessions()
self.dump_sessions_to_file()
return
self.clean_and_persist_sessions()

@gen.coroutine
def update_user_state(self, user, access_token):
def update_user(self, user, access_token):
user_state = yield self.fetch_user(access_token)
if user_state is None:
LOGGER.error("Failed to fetch user %s" % user)
Expand All @@ -256,8 +251,7 @@ def update_user_state(self, user, access_token):
self.user_states[user]['updating'] = False
self.user_states[user]['updated'] = now
self.user_states[user]['visit'] = now
self.clean_expired_sessions()
self.dump_sessions_to_file()
self.clean_and_persist_sessions()
return

@gen.coroutine
Expand Down Expand Up @@ -288,7 +282,7 @@ def read_groups(self, access_token):
return groups

@gen.coroutine
def validate_user(self, code, request_handler):
def read_user(self, code, request_handler):
user_response_future = self.get_authenticated_user(
get_path_for_redirect(request_handler),
self.client_id,
Expand Down Expand Up @@ -316,10 +310,8 @@ def validate_user(self, code, request_handler):
user_response['visit'] = time.time()
user_response['updating'] = False
oauth_access_token = user_response.pop('access_token')
oauth_refresh_token = user_response.pop('refresh_token') # not used atm
self.user_states[user_response['email']] = user_response
self.clean_expired_sessions()
self.dump_sessions_to_file()
self.clean_and_persist_sessions()
request_handler.set_secure_cookie('token', oauth_access_token)

return user_response['email']
Expand Down
2 changes: 1 addition & 1 deletion src/auth/tornado_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def is_authenticated(self, request_handler):
if not username:
return False

active = self.authenticator.is_active(username, request_handler)
active = self.authenticator.validate_user(username, request_handler)
if not active:
self.logout(request_handler)

Expand Down
44 changes: 21 additions & 23 deletions src/tests/auth/test_auth_gitlab.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,17 @@

mock_time = Mock()
mock_time.return_value = 10000.01
mock_dump_sessions_to_file = Mock()
mock_persist_session = Mock()
mock_do_update_groups = Mock()
mock_do_update_user = Mock()
mock_request_handler = Mock(**{'get_secure_cookie.return_value': "12345".encode()})


class TestAuthConfig(TestCase):
@patch('time.time', mock_time)
@patch('auth.auth_gitlab.GitlabOAuthAuthenticator.dump_sessions_to_file', mock_dump_sessions_to_file)
@patch('auth.auth_gitlab.GitlabOAuthAuthenticator.persist_session', mock_persist_session)
@patch('auth.auth_gitlab.GitlabOAuthAuthenticator.do_update_groups', mock_do_update_groups)
def test_gitlab_oauth(self):
tmp = tempfile.mkstemp('.json', 'test_auth_gitlab-')
now = time.time()
state = {
"user@test.com": {
Expand All @@ -57,8 +56,7 @@ def test_gitlab_oauth(self):
}
}

os.write(tmp[0], str.encode(escape.json_encode(state)))
os.fsync(tmp[0])
state_file = test_utils.create_file("gitlab_state.json", text=escape.json_encode(state))

config = _from_json({
'auth': {
Expand All @@ -68,15 +66,15 @@ def test_gitlab_oauth(self):
"secret": "abcd",
"group_search": "script-server",
"auth_info_ttl": 80,
"state_dump_file": tmp[1],
"state_dump_file": state_file,
"session_expire_minutes": 10
},
'access': {
'allowed_users': []
}})

self.assertIsInstance(config.authenticator, GitlabOAuthAuthenticator)
self.assertEqual(tmp[1], config.authenticator.gitlab_dump)
self.assertEqual(state_file, config.authenticator.gitlab_dump)
self.assertEqual("1234", config.authenticator._client_visible_config['client_id'])
self.assertEqual("https://gitlab/oauth/authorize", config.authenticator._client_visible_config['oauth_url'])
self.assertEqual("api", config.authenticator._client_visible_config['oauth_scope'])
Expand All @@ -88,61 +86,58 @@ def test_gitlab_oauth(self):
self.assertDictEqual(assert_state, config.authenticator.user_states)
saved_state = copy.deepcopy(config.authenticator.user_states)

self.assertEqual(False, config.authenticator.is_active("unknown@test.com", mock_request_handler))
self.assertEqual(False, config.authenticator.is_active("nogroups@test.com", mock_request_handler))
self.assertEqual(False, config.authenticator.validate_user("unknown@test.com", mock_request_handler))
self.assertEqual(False, config.authenticator.validate_user("nogroups@test.com", mock_request_handler))
self.assertListEqual([], config.authenticator.get_groups("unknown@test.com"))
self.assertListEqual([], config.authenticator.get_groups("nogroups@test.com"))

self.assertEqual(True, config.authenticator.is_active("user@test.com", mock_request_handler))
self.assertEqual(True, config.authenticator.validate_user("user@test.com", mock_request_handler))
self.assertEqual(time.time(), config.authenticator.user_states["user@test.com"]["visit"], "visit updated")
self.assertEqual(True, mock_do_update_groups.called, "state just loaded, gitlab updating")
mock_do_update_groups.reset_mock()

config.authenticator.user_states["user@test.com"]["updating"] = True
self.assertEqual(True, config.authenticator.is_active("user@test.com", mock_request_handler))
self.assertEqual(True, config.authenticator.validate_user("user@test.com", mock_request_handler))
self.assertEqual(False, mock_do_update_groups.called, "do not call parallel updated")
mock_do_update_groups.reset_mock()

mock_time.return_value = 10000.01 + 80*2 + 1 # stale request
self.assertEqual(True, config.authenticator.is_active("user@test.com", mock_request_handler))
self.assertEqual(True, config.authenticator.validate_user("user@test.com", mock_request_handler))
self.assertEqual(True, mock_do_update_groups.called, "parallel but stale")
mock_do_update_groups.reset_mock()
config.authenticator.user_states = copy.deepcopy(saved_state)
mock_time.return_value = 10000.01

config.authenticator.user_states["user@test.com"]['updated'] = now # gitlab info updated
config.authenticator.user_states["user@test.com"]['updating'] = False
self.assertEqual(True, config.authenticator.is_active("user@test.com", mock_request_handler))
self.assertEqual(True, config.authenticator.validate_user("user@test.com", mock_request_handler))
self.assertEqual(False, mock_do_update_groups.called, "do not update gitlab because ttl not expired")
mock_do_update_groups.reset_mock()

mock_time.return_value = 10000.01 + 81
self.assertEqual(True, config.authenticator.is_active("user@test.com", mock_request_handler))
self.assertEqual(True, config.authenticator.validate_user("user@test.com", mock_request_handler))
self.assertEqual(True, mock_do_update_groups.called, "ttl expired")
mock_do_update_groups.reset_mock()
config.authenticator.user_states = copy.deepcopy(saved_state)
mock_time.return_value = 10000.01

# session expire test
mock_time.return_value = 10000.01 + 601
self.assertEqual(False, config.authenticator.is_active("user@test.com", mock_request_handler), "shoud be expired")
self.assertEqual(True, mock_dump_sessions_to_file.called, "dump state to file")
mock_dump_sessions_to_file.reset_mock()
self.assertEqual(False, config.authenticator.validate_user("user@test.com", mock_request_handler), "shoud be expired")
self.assertEqual(True, mock_persist_session.called, "dump state to file")
mock_persist_session.reset_mock()
self.assertIsNone(config.authenticator.user_states.get("user@test.com"), "removed from state")
self.assertListEqual([], config.authenticator.get_groups("user@test.com"))
config.authenticator.user_states = copy.deepcopy(saved_state)
mock_time.return_value = 10000.01

# test clean expire
mock_time.return_value = 10000.01 + 601
config.authenticator.clean_expired_sessions()
config.authenticator.clean_sessions()
self.assertIsNone(config.authenticator.user_states.get("user@test.com"))
config.authenticator.user_states = copy.deepcopy(saved_state)
mock_time.return_value = 10000.01

os.close(tmp[0])
os.unlink(tmp[1])

@patch('time.time', mock_time)
@patch('auth.auth_gitlab.GitlabOAuthAuthenticator.do_update_user', mock_do_update_user)
@patch('auth.auth_gitlab.GitlabOAuthAuthenticator.do_update_groups', mock_do_update_groups)
Expand Down Expand Up @@ -181,19 +176,22 @@ def test_gitlab_oauth_user_read_scope(self):
self.assertIsInstance(config.authenticator, GitlabOAuthAuthenticator)
self.assertEqual("read_user", config.authenticator._client_visible_config['oauth_scope'])
config.authenticator.user_states = state
self.assertEqual(True, config.authenticator.is_active("user@test.com", mock_request_handler))
self.assertEqual(True, config.authenticator.validate_user("user@test.com", mock_request_handler))
self.assertEqual(False, mock_do_update_groups.called, "update==0, gitlab updating but not groups")
self.assertEqual(True, mock_do_update_user.called, "update==0, gitlab updating only user")
mock_do_update_groups.reset_mock()
mock_do_update_user.reset_mock()

config.authenticator.gitlab_update = None
self.assertEqual(True, config.authenticator.is_active("user@test.com", mock_request_handler))
self.assertEqual(True, config.authenticator.validate_user("user@test.com", mock_request_handler))
self.assertEqual(False, mock_do_update_groups.called, "gitab update disabled")
self.assertEqual(False, mock_do_update_user.called, "gitab update disabled")
mock_do_update_groups.reset_mock()
mock_do_update_user.reset_mock()

def tearDown(self):
test_utils.cleanup()


def _from_json(content):
json_obj = json.dumps(content)
Expand Down
Loading

0 comments on commit 45a8987

Please sign in to comment.