Skip to content

Commit

Permalink
root: Restructure broker / cache / channel / result configuration (#7097
Browse files Browse the repository at this point in the history
)

* Initial commit

* Remove any remaining mentions of Redis URL

This is handled in #5395

* Allow setting broker transport options

This enables usage of other brokers that require additional settings

* Remove remaining reference to Redis URL

This functionality is not part of this PR

* Reset default TLS requirements to none

* Fix linter errors

* Move dict from base64 encoded json to config.py

Additionally add tests

* Replace ast.literal_eval with json.loads

* Use default channel and cache backend configuration

If more customization is desired users shall look at goauthentik.io/docs/installation/configuration#custom-python-settings

* Send config deprecation notification to all superusers

* Remove duplicate method

* Add configuration explanation

For channel layer settings

* Use Event for deprecation warning

* Fix remove duplicated method

* Add missing comma

* Update authentik/lib/config.py

Signed-off-by: Jens L. <jens@beryju.org>

* Fix Event deprecation handling

---------

Signed-off-by: Jens L. <jens@beryju.org>
Co-authored-by: Jens L <jens@beryju.org>
  • Loading branch information
PKizzle and BeryJu authored Nov 10, 2023
1 parent 11dcda7 commit 9db9ad3
Show file tree
Hide file tree
Showing 11 changed files with 233 additions and 47 deletions.
8 changes: 4 additions & 4 deletions authentik/api/v3/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,10 @@ def get_config(self) -> ConfigSerializer:
"traces_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.4)),
},
"capabilities": self.get_capabilities(),
"cache_timeout": CONFIG.get_int("redis.cache_timeout"),
"cache_timeout_flows": CONFIG.get_int("redis.cache_timeout_flows"),
"cache_timeout_policies": CONFIG.get_int("redis.cache_timeout_policies"),
"cache_timeout_reputation": CONFIG.get_int("redis.cache_timeout_reputation"),
"cache_timeout": CONFIG.get_int("cache.timeout"),
"cache_timeout_flows": CONFIG.get_int("cache.timeout_flows"),
"cache_timeout_policies": CONFIG.get_int("cache.timeout_policies"),
"cache_timeout_reputation": CONFIG.get_int("cache.timeout_reputation"),
}
)

Expand Down
2 changes: 1 addition & 1 deletion authentik/flows/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
# Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan
# was restored.
PLAN_CONTEXT_IS_RESTORED = "is_restored"
CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_flows")
CACHE_TIMEOUT = CONFIG.get_int("cache.timeout_flows")
CACHE_PREFIX = "goauthentik.io/flows/planner/"


Expand Down
90 changes: 82 additions & 8 deletions authentik/lib/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""authentik core config loader"""
import base64
import json
import os
from collections.abc import Mapping
from contextlib import contextmanager
Expand All @@ -22,6 +24,25 @@
ENV_PREFIX = "AUTHENTIK"
ENVIRONMENT = os.getenv(f"{ENV_PREFIX}_ENV", "local")

REDIS_ENV_KEYS = [
f"{ENV_PREFIX}_REDIS__HOST",
f"{ENV_PREFIX}_REDIS__PORT",
f"{ENV_PREFIX}_REDIS__DB",
f"{ENV_PREFIX}_REDIS__USERNAME",
f"{ENV_PREFIX}_REDIS__PASSWORD",
f"{ENV_PREFIX}_REDIS__TLS",
f"{ENV_PREFIX}_REDIS__TLS_REQS",
]

DEPRECATIONS = {
"redis.broker_url": "broker.url",
"redis.broker_transport_options": "broker.transport_options",
"redis.cache_timeout": "cache.timeout",
"redis.cache_timeout_flows": "cache.timeout_flows",
"redis.cache_timeout_policies": "cache.timeout_policies",
"redis.cache_timeout_reputation": "cache.timeout_reputation",
}


def get_path_from_dict(root: dict, path: str, sep=".", default=None) -> Any:
"""Recursively walk through `root`, checking each part of `path` separated by `sep`.
Expand Down Expand Up @@ -81,6 +102,10 @@ def default(self, o: Any) -> Any:
return super().default(o)


class UNSET:
"""Used to test whether configuration key has not been set."""


class ConfigLoader:
"""Search through SEARCH_PATHS and load configuration. Environment variables starting with
`ENV_PREFIX` are also applied.
Expand Down Expand Up @@ -113,6 +138,40 @@ def __init__(self, **kwargs):
self.update_from_file(env_file)
self.update_from_env()
self.update(self.__config, kwargs)
self.check_deprecations()

def check_deprecations(self):
"""Warn if any deprecated configuration options are used"""

def _pop_deprecated_key(current_obj, dot_parts, index):
"""Recursive function to remove deprecated keys in configuration"""
dot_part = dot_parts[index]
if index == len(dot_parts) - 1:
return current_obj.pop(dot_part)
value = _pop_deprecated_key(current_obj[dot_part], dot_parts, index + 1)
if not current_obj[dot_part]:
current_obj.pop(dot_part)
return value

for deprecation, replacement in DEPRECATIONS.items():
if self.get(deprecation, default=UNSET) is not UNSET:
message = (
f"'{deprecation}' has been deprecated in favor of '{replacement}'! "
+ "Please update your configuration."
)
self.log(
"warning",
message,
)
try:
from authentik.events.models import Event, EventAction

Event.new(EventAction.CONFIGURATION_ERROR, message=message).save()
except ImportError:
continue

deprecated_attr = _pop_deprecated_key(self.__config, deprecation.split("."), 0)
self.set(replacement, deprecated_attr.value)

def log(self, level: str, message: str, **kwargs):
"""Custom Log method, we want to ensure ConfigLoader always logs JSON even when
Expand Down Expand Up @@ -180,6 +239,10 @@ def update_from_file(self, path: Path):
error=str(exc),
)

def update_from_dict(self, update: dict):
"""Update config from dict"""
self.__config.update(update)

def update_from_env(self):
"""Check environment variables"""
outer = {}
Expand All @@ -188,19 +251,13 @@ def update_from_env(self):
if not key.startswith(ENV_PREFIX):
continue
relative_key = key.replace(f"{ENV_PREFIX}_", "", 1).replace("__", ".").lower()
# Recursively convert path from a.b.c into outer[a][b][c]
current_obj = outer
dot_parts = relative_key.split(".")
for dot_part in dot_parts[:-1]:
if dot_part not in current_obj:
current_obj[dot_part] = {}
current_obj = current_obj[dot_part]
# Check if the value is json, and try to load it
try:
value = loads(value)
except JSONDecodeError:
pass
current_obj[dot_parts[-1]] = Attr(value, Attr.Source.ENV, key)
attr_value = Attr(value, Attr.Source.ENV, relative_key)
set_path_in_dict(outer, relative_key, attr_value)
idx += 1
if idx > 0:
self.log("debug", "Loaded environment variables", count=idx)
Expand Down Expand Up @@ -241,6 +298,23 @@ def get_bool(self, path: str, default=False) -> bool:
"""Wrapper for get that converts value into boolean"""
return str(self.get(path, default)).lower() == "true"

def get_dict_from_b64_json(self, path: str, default=None) -> dict:
"""Wrapper for get that converts value from Base64 encoded string into dictionary"""
config_value = self.get(path)
if config_value is None:
return {}
try:
b64decoded_str = base64.b64decode(config_value).decode("utf-8")
b64decoded_str = b64decoded_str.strip().lstrip("{").rstrip("}")
b64decoded_str = "{" + b64decoded_str + "}"
return json.loads(b64decoded_str)
except (JSONDecodeError, TypeError, ValueError) as exc:
self.log(
"warning",
f"Ignored invalid configuration for '{path}' due to exception: {str(exc)}",
)
return default if isinstance(default, dict) else {}

def set(self, path: str, value: Any, sep="."):
"""Set value using same syntax as get()"""
set_path_in_dict(self.raw, path, Attr(value), sep=sep)
Expand Down
24 changes: 19 additions & 5 deletions authentik/lib/default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,28 @@ listen:
redis:
host: localhost
port: 6379
db: 0
username: ""
password: ""
tls: false
tls_reqs: "none"
db: 0
cache_timeout: 300
cache_timeout_flows: 300
cache_timeout_policies: 300
cache_timeout_reputation: 300

# broker:
# url: ""
# transport_options: ""

cache:
# url: ""
timeout: 300
timeout_flows: 300
timeout_policies: 300
timeout_reputation: 300

# channel:
# url: ""

# result_backend:
# url: ""

paths:
media: ./media
Expand Down
87 changes: 82 additions & 5 deletions authentik/lib/tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,32 @@
"""Test config loader"""
import base64
from json import dumps
from os import chmod, environ, unlink, write
from tempfile import mkstemp
from unittest import mock

from django.conf import ImproperlyConfigured
from django.test import TestCase

from authentik.lib.config import ENV_PREFIX, ConfigLoader
from authentik.lib.config import ENV_PREFIX, UNSET, Attr, AttrEncoder, ConfigLoader


class TestConfig(TestCase):
"""Test config loader"""

check_deprecations_env_vars = {
ENV_PREFIX + "_REDIS__BROKER_URL": "redis://myredis:8327/43",
ENV_PREFIX + "_REDIS__BROKER_TRANSPORT_OPTIONS": "bWFzdGVybmFtZT1teW1hc3Rlcg==",
ENV_PREFIX + "_REDIS__CACHE_TIMEOUT": "124s",
ENV_PREFIX + "_REDIS__CACHE_TIMEOUT_FLOWS": "32m",
ENV_PREFIX + "_REDIS__CACHE_TIMEOUT_POLICIES": "3920ns",
ENV_PREFIX + "_REDIS__CACHE_TIMEOUT_REPUTATION": "298382us",
}

@mock.patch.dict(environ, {ENV_PREFIX + "_test__test": "bar"})
def test_env(self):
"""Test simple instance"""
config = ConfigLoader()
environ[ENV_PREFIX + "_test__test"] = "bar"
config.update_from_env()
self.assertEqual(config.get("test.test"), "bar")

Expand All @@ -27,12 +39,20 @@ def test_patch(self):
self.assertEqual(config.get("foo.bar"), "baz")
self.assertEqual(config.get("foo.bar"), "bar")

@mock.patch.dict(environ, {"foo": "bar"})
def test_uri_env(self):
"""Test URI parsing (environment)"""
config = ConfigLoader()
environ["foo"] = "bar"
self.assertEqual(config.parse_uri("env://foo").value, "bar")
self.assertEqual(config.parse_uri("env://foo?bar").value, "bar")
foo_uri = "env://foo"
foo_parsed = config.parse_uri(foo_uri)
self.assertEqual(foo_parsed.value, "bar")
self.assertEqual(foo_parsed.source_type, Attr.Source.URI)
self.assertEqual(foo_parsed.source, foo_uri)
foo_bar_uri = "env://foo?bar"
foo_bar_parsed = config.parse_uri(foo_bar_uri)
self.assertEqual(foo_bar_parsed.value, "bar")
self.assertEqual(foo_bar_parsed.source_type, Attr.Source.URI)
self.assertEqual(foo_bar_parsed.source, foo_bar_uri)

def test_uri_file(self):
"""Test URI parsing (file load)"""
Expand Down Expand Up @@ -91,3 +111,60 @@ def test_get_int_invalid(self):
config = ConfigLoader()
config.set("foo", "bar")
self.assertEqual(config.get_int("foo", 1234), 1234)

def test_get_dict_from_b64_json(self):
"""Test get_dict_from_b64_json"""
config = ConfigLoader()
test_value = ' { "foo": "bar" } '.encode("utf-8")
b64_value = base64.b64encode(test_value)
config.set("foo", b64_value)
self.assertEqual(config.get_dict_from_b64_json("foo"), {"foo": "bar"})

def test_get_dict_from_b64_json_missing_brackets(self):
"""Test get_dict_from_b64_json with missing brackets"""
config = ConfigLoader()
test_value = ' "foo": "bar" '.encode("utf-8")
b64_value = base64.b64encode(test_value)
config.set("foo", b64_value)
self.assertEqual(config.get_dict_from_b64_json("foo"), {"foo": "bar"})

def test_get_dict_from_b64_json_invalid(self):
"""Test get_dict_from_b64_json with invalid value"""
config = ConfigLoader()
config.set("foo", "bar")
self.assertEqual(config.get_dict_from_b64_json("foo"), {})

def test_attr_json_encoder(self):
"""Test AttrEncoder"""
test_attr = Attr("foo", Attr.Source.ENV, "AUTHENTIK_REDIS__USERNAME")
json_attr = dumps(test_attr, indent=4, cls=AttrEncoder)
self.assertEqual(json_attr, '"foo"')

def test_attr_json_encoder_no_attr(self):
"""Test AttrEncoder if no Attr is passed"""

class Test:
"""Non Attr class"""

with self.assertRaises(TypeError):
test_obj = Test()
dumps(test_obj, indent=4, cls=AttrEncoder)

@mock.patch.dict(environ, check_deprecations_env_vars)
def test_check_deprecations(self):
"""Test config key re-write for deprecated env vars"""
config = ConfigLoader()
config.update_from_env()
config.check_deprecations()
self.assertEqual(config.get("redis.broker_url", UNSET), UNSET)
self.assertEqual(config.get("redis.broker_transport_options", UNSET), UNSET)
self.assertEqual(config.get("redis.cache_timeout", UNSET), UNSET)
self.assertEqual(config.get("redis.cache_timeout_flows", UNSET), UNSET)
self.assertEqual(config.get("redis.cache_timeout_policies", UNSET), UNSET)
self.assertEqual(config.get("redis.cache_timeout_reputation", UNSET), UNSET)
self.assertEqual(config.get("broker.url"), "redis://myredis:8327/43")
self.assertEqual(config.get("broker.transport_options"), "bWFzdGVybmFtZT1teW1hc3Rlcg==")
self.assertEqual(config.get("cache.timeout"), "124s")
self.assertEqual(config.get("cache.timeout_flows"), "32m")
self.assertEqual(config.get("cache.timeout_policies"), "3920ns")
self.assertEqual(config.get("cache.timeout_reputation"), "298382us")
2 changes: 1 addition & 1 deletion authentik/outposts/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def disconnect(self, code):
expected=self.outpost.config.kubernetes_replicas,
).dec()

def receive_json(self, content: Data):
def receive_json(self, content: Data, **kwargs):
msg = from_dict(WebsocketMessage, content)
uid = msg.args.get("uuid", self.channel_name)
self.last_uid = uid
Expand Down
2 changes: 1 addition & 1 deletion authentik/policies/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
LOGGER = get_logger()

FORK_CTX = get_context("fork")
CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_policies")
CACHE_TIMEOUT = CONFIG.get_int("cache.timeout_policies")
PROCESS_CLASS = FORK_CTX.Process


Expand Down
2 changes: 1 addition & 1 deletion authentik/policies/reputation/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from authentik.stages.identification.signals import identification_failed

LOGGER = get_logger()
CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_reputation")
CACHE_TIMEOUT = CONFIG.get_int("cache.timeout_reputation")


def update_score(request: HttpRequest, identifier: str, amount: int):
Expand Down
14 changes: 8 additions & 6 deletions authentik/root/settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""root settings for authentik"""

import importlib
import os
from hashlib import sha512
Expand Down Expand Up @@ -195,8 +194,8 @@
CACHES = {
"default": {
"BACKEND": "django_redis.cache.RedisCache",
"LOCATION": f"{_redis_url}/{CONFIG.get('redis.db')}",
"TIMEOUT": CONFIG.get_int("redis.cache_timeout", 300),
"LOCATION": CONFIG.get("cache.url") or f"{_redis_url}/{CONFIG.get('redis.db')}",
"TIMEOUT": CONFIG.get_int("cache.timeout", 300),
"OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"},
"KEY_PREFIX": "authentik_cache",
}
Expand Down Expand Up @@ -256,7 +255,7 @@
"default": {
"BACKEND": "channels_redis.pubsub.RedisPubSubChannelLayer",
"CONFIG": {
"hosts": [f"{_redis_url}/{CONFIG.get('redis.db')}"],
"hosts": [CONFIG.get("channel.url", f"{_redis_url}/{CONFIG.get('redis.db')}")],
"prefix": "authentik_channels_",
},
},
Expand Down Expand Up @@ -349,8 +348,11 @@
},
"task_create_missing_queues": True,
"task_default_queue": "authentik",
"broker_url": f"{_redis_url}/{CONFIG.get('redis.db')}{_redis_celery_tls_requirements}",
"result_backend": f"{_redis_url}/{CONFIG.get('redis.db')}{_redis_celery_tls_requirements}",
"broker_url": CONFIG.get("broker.url")
or f"{_redis_url}/{CONFIG.get('redis.db')}{_redis_celery_tls_requirements}",
"broker_transport_options": CONFIG.get_dict_from_b64_json("broker.transport_options"),
"result_backend": CONFIG.get("result_backend.url")
or f"{_redis_url}/{CONFIG.get('redis.db')}{_redis_celery_tls_requirements}",
}

# Sentry integration
Expand Down
Loading

0 comments on commit 9db9ad3

Please sign in to comment.