diff --git a/client/autotest_client/__init__.py b/client/autotest_client/__init__.py index 3b996728..5ac845fe 100644 --- a/client/autotest_client/__init__.py +++ b/client/autotest_client/__init__.py @@ -23,6 +23,8 @@ SETTINGS_JOB_TIMEOUT = os.environ.get("SETTINGS_JOB_TIMEOUT", 600) REDIS_URL = os.environ["REDIS_URL"] +REDIS_CONNECTION = redis.Redis.from_url(REDIS_URL) + app = Flask(__name__) @@ -35,23 +37,6 @@ def _open_log(log, mode="a", fallback=sys.stdout): yield fallback -def _redis_connection(decode_responses=True) -> redis.Redis: - return redis.Redis.from_url(REDIS_URL, decode_responses=decode_responses) - - -def _rq_connection() -> redis.Redis: - """ - Return the currently open redis connection object. If there is no - connection currently open, one is created using the url specified in - REDIS_URL. - """ - conn = rq.get_current_connection() - if conn: - return conn - rq.use_connection(redis=redis.Redis.from_url(REDIS_URL)) - return rq.get_current_connection() - - @app.errorhandler(Exception) def _handle_error(e): code = 500 @@ -71,14 +56,13 @@ def _handle_error(e): def _check_rate_limit(api_key): - conn = _redis_connection() key = f"autotest:ratelimit:{api_key}:{datetime.now().minute}" - n_requests = conn.get(key) or 0 - user_limit = conn.get(f"autotest:ratelimit:{api_key}:limit") or 20 # TODO: make default limit configurable + n_requests = REDIS_CONNECTION.get(key) or 0 + user_limit = REDIS_CONNECTION.get(f"autotest:ratelimit:{api_key}:limit") or 20 # TODO: make default configurable if int(n_requests) > int(user_limit): abort(make_response(jsonify(message="Too many requests"), 429)) else: - with conn.pipeline() as pipe: + with REDIS_CONNECTION.pipeline() as pipe: pipe.incr(key) pipe.expire(key, 59) pipe.execute() @@ -86,8 +70,7 @@ def _check_rate_limit(api_key): def _authorize_user(): api_key = request.headers.get("Api-Key") - user_name = (_redis_connection().hgetall("autotest:user_credentials") or {}).get(api_key) - if user_name is None: + if api_key is None or (REDIS_CONNECTION.hgetall("autotest:user_credentials") or {}).get(api_key.encode()) is None: abort(make_response(jsonify(message="Unauthorized"), 401)) _check_rate_limit(api_key) return api_key @@ -95,7 +78,7 @@ def _authorize_user(): def _authorize_settings(user, settings_id=None, **_kw): if settings_id: - settings_ = _redis_connection().hget("autotest:settings", settings_id) + settings_ = REDIS_CONNECTION.hget("autotest:settings", settings_id) if settings_ is None: abort(make_response(jsonify(message="Settings not found"), 404)) if json.loads(settings_).get("_user") != user: @@ -104,10 +87,10 @@ def _authorize_settings(user, settings_id=None, **_kw): def _authorize_tests(tests_id=None, settings_id=None, **_kw): if settings_id and tests_id: - test_setting = _redis_connection().hget("autotest:tests", tests_id) + test_setting = REDIS_CONNECTION.hget("autotest:tests", tests_id) if test_setting is None: abort(make_response(jsonify(message="Test not found"), 404)) - if test_setting != settings_id: + if int(test_setting) != int(settings_id): abort(make_response(jsonify(message="Unauthorized"), 401)) @@ -125,7 +108,7 @@ def _update_settings(settings_id, user): if error: abort(make_response(jsonify(message=error), 422)) - queue = rq.Queue("settings", connection=_rq_connection()) + queue = rq.Queue("settings", connection=REDIS_CONNECTION) data = {"user": user, "settings_id": settings_id, "test_settings": test_settings, "file_url": file_url} queue.enqueue_call( "autotest_server.update_test_settings", @@ -137,12 +120,12 @@ def _update_settings(settings_id, user): def _get_jobs(test_ids, settings_id): for id_ in test_ids: - test_setting = _redis_connection().hget("autotest:tests", id_) - if test_setting is None or test_setting != settings_id: + test_setting = REDIS_CONNECTION.hget("autotest:tests", id_) + if test_setting is None or int(test_setting) != int(settings_id): yield None else: try: - yield rq.job.Job.fetch(str(id_), connection=_rq_connection()) + yield rq.job.Job.fetch(str(id_), connection=REDIS_CONNECTION) except rq.exceptions.NoSuchJobError: yield None @@ -181,7 +164,7 @@ def register(): credentials = request.json.get("credentials") key = base64.b64encode(os.urandom(24)).decode("utf-8") data = {"auth_type": auth_type, "credentials": credentials} - while not _redis_connection().hsetnx("autotest:user_credentials", key=key, value=json.dumps(data)): + while not REDIS_CONNECTION.hsetnx("autotest:user_credentials", key=key, value=json.dumps(data)): key = base64.b64encode(os.urandom(24)).decode("utf-8") return {"api_key": key} @@ -192,20 +175,20 @@ def reset_credentials(user): auth_type = request.json.get("auth_type") credentials = request.json.get("credentials") data = {"auth_type": auth_type, "credentials": credentials} - _redis_connection().hset("autotest:user_credentials", key=user, value=json.dumps(data)) + REDIS_CONNECTION.hset("autotest:user_credentials", key=user, value=json.dumps(data)) return jsonify(success=True) @app.route("/schema", methods=["GET"]) @authorize def schema(**_kwargs): - return json.loads(_redis_connection().get("autotest:schema") or "{}") + return json.loads(REDIS_CONNECTION.get("autotest:schema") or "{}") @app.route("/settings/", methods=["GET"]) @authorize def settings(settings_id, **_kw): - settings_ = json.loads(_redis_connection().hget("autotest:settings", key=settings_id) or "{}") + settings_ = json.loads(REDIS_CONNECTION.hget("autotest:settings", key=settings_id) or "{}") if settings_.get("_error"): raise Exception(f"Settings Error: {settings_['_error']}") return {k: v for k, v in settings_.items() if not k.startswith("_")} @@ -214,8 +197,8 @@ def settings(settings_id, **_kw): @app.route("/settings", methods=["POST"]) @authorize def create_settings(user): - settings_id = _redis_connection().incr("autotest:settings_id") - _redis_connection().hset("autotest:settings", key=settings_id, value=json.dumps({"_user": user})) + settings_id = REDIS_CONNECTION.incr("autotest:settings_id") + REDIS_CONNECTION.hset("autotest:settings", key=settings_id, value=json.dumps({"_user": user})) _update_settings(settings_id, user) return {"settings_id": settings_id} @@ -234,7 +217,7 @@ def run_tests(settings_id, user): categories = request.json["categories"] high_priority = request.json.get("request_high_priority") queue_name = "batch" if len(test_data) > 1 else ("high" if high_priority else "low") - queue = rq.Queue(queue_name, connection=_rq_connection()) + queue = rq.Queue(queue_name, connection=REDIS_CONNECTION) timeout = 0 @@ -246,8 +229,8 @@ def run_tests(settings_id, user): for data in test_data: url = data["file_url"] test_env_vars = data.get("env_vars", {}) - id_ = _redis_connection().incr("autotest:tests_id") - _redis_connection().hset("autotest:tests", key=id_, value=settings_id) + id_ = REDIS_CONNECTION.incr("autotest:tests_id") + REDIS_CONNECTION.hset("autotest:tests", key=id_, value=settings_id) ids.append(id_) data = { "settings_id": settings_id, @@ -272,11 +255,11 @@ def run_tests(settings_id, user): @app.route("/settings//test/", methods=["GET"]) @authorize def get_result(settings_id, tests_id, **_kw): - job = rq.job.Job.fetch(tests_id, connection=_rq_connection()) + job = rq.job.Job.fetch(tests_id, connection=REDIS_CONNECTION) job_status = job.get_status() result = {"status": job_status} if job_status == "finished": - test_result = _redis_connection().get(f"autotest:test_result:{tests_id}") + test_result = REDIS_CONNECTION.get(f"autotest:test_result:{tests_id}") try: result.update(json.loads(test_result)) except json.JSONDecodeError: @@ -284,7 +267,7 @@ def get_result(settings_id, tests_id, **_kw): elif job_status == "failed": result.update({"error": str(job.exc_info)}) job.delete() - _redis_connection().delete(f"autotest:test_result:{tests_id}") + REDIS_CONNECTION.delete(f"autotest:test_result:{tests_id}") return result @@ -292,10 +275,10 @@ def get_result(settings_id, tests_id, **_kw): @authorize def get_feedback_file(settings_id, tests_id, feedback_id, **_kw): key = f"autotest:feedback_file:{tests_id}:{feedback_id}" - data = _redis_connection(decode_responses=False).get(key) + data = REDIS_CONNECTION.get(key) if data is None: abort(make_response(jsonify(message="File doesn't exist"), 404)) - _redis_connection().delete(key) + REDIS_CONNECTION.delete(key) return send_file(io.BytesIO(data), mimetype="application/gzip", as_attachment=True, download_name=str(feedback_id)) diff --git a/client/autotest_client/tests/test_flask_app.py b/client/autotest_client/tests/test_flask_app.py index e0dd5430..87a6c081 100644 --- a/client/autotest_client/tests/test_flask_app.py +++ b/client/autotest_client/tests/test_flask_app.py @@ -13,18 +13,12 @@ def client(): @pytest.fixture def fake_redis_conn(): - yield fakeredis.FakeStrictRedis(decode_responses=True) - - -@pytest.fixture -def fake_rq_conn(): - conn = fakeredis.FakeStrictRedis(decode_responses=False) - autotest_client.rq.use_connection(conn) + yield fakeredis.FakeStrictRedis() @pytest.fixture(autouse=True) def fake_redis_db(monkeypatch, fake_redis_conn): - monkeypatch.setattr(autotest_client.redis.Redis, "from_url", lambda *a, **kw: fake_redis_conn) + monkeypatch.setattr(autotest_client, "REDIS_CONNECTION", fake_redis_conn) class TestRegister: diff --git a/server/autotest_server/__init__.py b/server/autotest_server/__init__.py index 9206c3e8..f65b2d9e 100644 --- a/server/autotest_server/__init__.py +++ b/server/autotest_server/__init__.py @@ -13,6 +13,7 @@ import importlib import psycopg2 import mimetypes +import rq from typing import Optional, Dict, Union, List, Tuple, Callable, Type from types import TracebackType @@ -20,14 +21,13 @@ from .utils import loads_partial_json, set_rlimits_before_test, extract_zip_stream, recursive_iglob, copy_tree DEFAULT_ENV_DIR = "defaultvenv" -REDIS_URL = config["redis_url"] TEST_SCRIPT_DIR = os.path.join(config["workspace"], "scripts") ResultData = Dict[str, Union[str, int, type(None), Dict]] def redis_connection() -> redis.Redis: - return redis.Redis.from_url(REDIS_URL, decode_responses=True) + return rq.get_current_job().connection def run_test_command(test_username: Optional[str] = None) -> str: diff --git a/server/autotest_server/tests/test_autotest_server.py b/server/autotest_server/tests/test_autotest_server.py index d596fb34..bc3be969 100644 --- a/server/autotest_server/tests/test_autotest_server.py +++ b/server/autotest_server/tests/test_autotest_server.py @@ -1,16 +1,27 @@ import pytest import fakeredis +import rq import autotest_server @pytest.fixture def fake_redis_conn(): - yield fakeredis.FakeStrictRedis(decode_responses=True) + yield fakeredis.FakeStrictRedis() + + +@pytest.fixture +def fake_queue(fake_redis_conn): + yield rq.Queue(is_async=False, connection=fake_redis_conn) + + +@pytest.fixture +def fake_job(fake_queue): + yield fake_queue.enqueue(lambda: None) @pytest.fixture(autouse=True) -def fake_redis_db(monkeypatch, fake_redis_conn): - monkeypatch.setattr(autotest_server.redis.Redis, "from_url", lambda *a, **kw: fake_redis_conn) +def fake_redis_db(monkeypatch, fake_job): + monkeypatch.setattr(autotest_server.rq, "get_current_job", lambda *a, **kw: fake_job) def test_redis_connection(fake_redis_conn): diff --git a/server/install.py b/server/install.py index a8288d93..878763fc 100644 --- a/server/install.py +++ b/server/install.py @@ -7,10 +7,13 @@ import json import subprocess import getpass +import redis from autotest_server.config import config -from autotest_server import redis_connection, run_test_command +from autotest_server import run_test_command from autotest_server.testers import install as install_testers +REDIS_CONNECTION = redis.Redis.from_url(config["redis_url"]) + def _print(*args, **kwargs): print("[AUTOTESTER]", *args, **kwargs) @@ -19,7 +22,7 @@ def _print(*args, **kwargs): def check_dependencies(): _print("checking if redis url is valid:") try: - redis_connection().keys() + REDIS_CONNECTION.ping() except Exception as e: raise Exception(f'Cannot connect to redis database with url: {config["redis_url"]}') from e for w in config["workers"]: @@ -66,7 +69,7 @@ def install_all_testers(): skeleton = json.load(f) skeleton["definitions"]["installed_testers"]["enum"] = list(settings.keys()) skeleton["definitions"]["tester_schemas"]["oneOf"] = list(settings.values()) - redis_connection().set("autotest:schema", json.dumps(skeleton)) + REDIS_CONNECTION.set("autotest:schema", json.dumps(skeleton)) def install(): diff --git a/server/start_stop.py b/server/start_stop.py index 55d926cc..218dabb2 100644 --- a/server/start_stop.py +++ b/server/start_stop.py @@ -43,9 +43,7 @@ """ - -def redis_connection() -> redis.Redis: - return redis.Redis.from_url(config["redis_url"], decode_responses=True) +REDIS_CONNECTION = redis.Redis.from_url(config["redis_url"], decode_responses=True) def create_enqueuer_wrapper(rq): @@ -82,7 +80,7 @@ def stat(rq, extra_args): def clean(age, dry_run): - for settings_id, settings in dict(redis_connection().hgetall("autotest:settings") or {}).items(): + for settings_id, settings in dict(REDIS_CONNECTION.hgetall("autotest:settings") or {}).items(): settings = json.loads(settings) last_access_timestamp = settings.get("_last_access") access = int(time.time() - (last_access_timestamp or 0)) @@ -93,7 +91,7 @@ def clean(age, dry_run): print(f"{dir_path} -> last accessed {last_access or '< 1'} days ago") else: settings["_error"] = "the settings for this test have expired, please re-upload the settings." - redis_connection().hset("autotest:settings", key=settings_id, value=json.dumps(settings)) + REDIS_CONNECTION.hset("autotest:settings", key=settings_id, value=json.dumps(settings)) if os.path.isdir(dir_path): shutil.rmtree(dir_path)