From f2226ce58e1e7fd188dc822ef8ed21046b50de2c Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Sun, 8 Oct 2023 17:26:58 +0200 Subject: [PATCH] providers/proxy: add custom signal on logout to logout in provider Signed-off-by: Jens Langhammer --- authentik/providers/proxy/apps.py | 4 ++ authentik/providers/proxy/signals.py | 20 ++++++++ authentik/providers/proxy/tasks.py | 20 ++++++++ blueprints/system/providers-proxy.yaml | 1 + .../proxyv2/application/application.go | 4 +- .../outpost/proxyv2/application/claims.go | 3 +- .../outpost/proxyv2/application/session.go | 6 +-- internal/outpost/proxyv2/proxyv2.go | 1 + internal/outpost/proxyv2/ws.go | 49 +++++++++++++++++++ 9 files changed, 103 insertions(+), 5 deletions(-) create mode 100644 authentik/providers/proxy/signals.py create mode 100644 internal/outpost/proxyv2/ws.go diff --git a/authentik/providers/proxy/apps.py b/authentik/providers/proxy/apps.py index 5e49fe1810f9..4e1a9a88344d 100644 --- a/authentik/providers/proxy/apps.py +++ b/authentik/providers/proxy/apps.py @@ -9,3 +9,7 @@ class AuthentikProviderProxyConfig(ManagedAppConfig): label = "authentik_providers_proxy" verbose_name = "authentik Providers.Proxy" default = True + + def reconcile_load_providers_proxy_signals(self): + """Load proxy signals""" + self.import_module("authentik.providers.proxy.signals") diff --git a/authentik/providers/proxy/signals.py b/authentik/providers/proxy/signals.py new file mode 100644 index 000000000000..3e199d3c38a2 --- /dev/null +++ b/authentik/providers/proxy/signals.py @@ -0,0 +1,20 @@ +"""Proxy provider signals""" +from django.contrib.auth.signals import user_logged_out +from django.db.models.signals import pre_delete +from django.dispatch import receiver +from django.http import HttpRequest + +from authentik.core.models import AuthenticatedSession, User +from authentik.providers.proxy.tasks import proxy_on_logout + + +@receiver(user_logged_out) +def logout_proxy_revoke_direct(sender: type[User], request: HttpRequest, **_): + """Catch logout by direct logout and forward to proxy providers""" + proxy_on_logout.delay(request.session.session_key) + + +@receiver(pre_delete, sender=AuthenticatedSession) +def logout_proxy_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_): + """Catch logout by expiring sessions being deleted""" + proxy_on_logout.delay(instance.session_key) diff --git a/authentik/providers/proxy/tasks.py b/authentik/providers/proxy/tasks.py index a5a4dc45f42d..630b0d186a56 100644 --- a/authentik/providers/proxy/tasks.py +++ b/authentik/providers/proxy/tasks.py @@ -1,6 +1,9 @@ """proxy provider tasks""" +from asgiref.sync import async_to_sync +from channels.layers import get_channel_layer from django.db import DatabaseError, InternalError, ProgrammingError +from authentik.outposts.models import Outpost, OutpostState, OutpostType from authentik.providers.proxy.models import ProxyProvider from authentik.root.celery import CELERY_APP @@ -13,3 +16,20 @@ def proxy_set_defaults(): for provider in ProxyProvider.objects.all(): provider.set_oauth_defaults() provider.save() + + +@CELERY_APP.task() +def proxy_on_logout(session_id: str): + """Update outpost instances connected to a single outpost""" + layer = get_channel_layer() + for outpost in Outpost.objects.filter(type=OutpostType.PROXY): + for state in OutpostState.for_outpost(outpost): + for channel in state.channel_ids: + async_to_sync(layer.send)( + channel, + { + "type": "event.provider.specific", + "sub_type": "logout", + "session_id": session_id, + }, + ) diff --git a/blueprints/system/providers-proxy.yaml b/blueprints/system/providers-proxy.yaml index 1214d157d1ca..0086645a8edd 100644 --- a/blueprints/system/providers-proxy.yaml +++ b/blueprints/system/providers-proxy.yaml @@ -15,6 +15,7 @@ entries: # This mapping is used by the authentik proxy. It passes extra user attributes, # which are used for example for the HTTP-Basic Authentication mapping. return { + "sid": request.http_request.session.session_key, "ak_proxy": { "user_attributes": request.user.group_attributes(request), "is_superuser": request.user.is_superuser, diff --git a/internal/outpost/proxyv2/application/application.go b/internal/outpost/proxyv2/application/application.go index 657bcbec706a..eae4c6774242 100644 --- a/internal/outpost/proxyv2/application/application.go +++ b/internal/outpost/proxyv2/application/application.go @@ -280,7 +280,9 @@ func (a *Application) handleSignOut(rw http.ResponseWriter, r *http.Request) { "id_token_hint": []string{cc.RawToken}, } redirect += "?" + uv.Encode() - err = a.Logout(r.Context(), cc.Sub) + err = a.Logout(r.Context(), func(c Claims) bool { + return c.Sub == cc.Sub + }) if err != nil { a.log.WithError(err).Warning("failed to logout of other sessions") } diff --git a/internal/outpost/proxyv2/application/claims.go b/internal/outpost/proxyv2/application/claims.go index bd34e1309142..32f4d26ebe88 100644 --- a/internal/outpost/proxyv2/application/claims.go +++ b/internal/outpost/proxyv2/application/claims.go @@ -11,10 +11,11 @@ type Claims struct { Exp int `json:"exp"` Email string `json:"email"` Verified bool `json:"email_verified"` - Proxy *ProxyClaims `json:"ak_proxy"` Name string `json:"name"` PreferredUsername string `json:"preferred_username"` Groups []string `json:"groups"` + Sid string `json:"sid"` + Proxy *ProxyClaims `json:"ak_proxy"` RawToken string } diff --git a/internal/outpost/proxyv2/application/session.go b/internal/outpost/proxyv2/application/session.go index 739b23e844ee..55d2bbb468d9 100644 --- a/internal/outpost/proxyv2/application/session.go +++ b/internal/outpost/proxyv2/application/session.go @@ -88,7 +88,7 @@ func (a *Application) getAllCodecs() []securecookie.Codec { return cs } -func (a *Application) Logout(ctx context.Context, sub string) error { +func (a *Application) Logout(ctx context.Context, filter func(c Claims) bool) error { if _, ok := a.sessions.(*sessions.FilesystemStore); ok { files, err := os.ReadDir(os.TempDir()) if err != nil { @@ -118,7 +118,7 @@ func (a *Application) Logout(ctx context.Context, sub string) error { continue } claims := s.Values[constants.SessionClaims].(Claims) - if claims.Sub == sub { + if filter(claims) { a.log.WithField("path", fullPath).Trace("deleting session") err := os.Remove(fullPath) if err != nil { @@ -153,7 +153,7 @@ func (a *Application) Logout(ctx context.Context, sub string) error { continue } claims := c.(Claims) - if claims.Sub == sub { + if filter(claims) { a.log.WithField("key", key).Trace("deleting session") _, err := client.Del(ctx, key).Result() if err != nil { diff --git a/internal/outpost/proxyv2/proxyv2.go b/internal/outpost/proxyv2/proxyv2.go index 154f79e348af..70364957fcf0 100644 --- a/internal/outpost/proxyv2/proxyv2.go +++ b/internal/outpost/proxyv2/proxyv2.go @@ -65,6 +65,7 @@ func NewProxyServer(ac *ak.APIController) *ProxyServer { globalMux.PathPrefix("/outpost.goauthentik.io/static").HandlerFunc(s.HandleStatic) globalMux.Path("/outpost.goauthentik.io/ping").HandlerFunc(sentryutils.SentryNoSample(s.HandlePing)) rootMux.PathPrefix("/").HandlerFunc(s.Handle) + ac.AddWSHandler(s.handleWSMessage) return s } diff --git a/internal/outpost/proxyv2/ws.go b/internal/outpost/proxyv2/ws.go new file mode 100644 index 000000000000..b75ba50fd5b2 --- /dev/null +++ b/internal/outpost/proxyv2/ws.go @@ -0,0 +1,49 @@ +package proxyv2 + +import ( + "context" + + "github.com/mitchellh/mapstructure" + "goauthentik.io/internal/outpost/proxyv2/application" +) + +type WSProviderSubType string + +const ( + WSProviderSubTypeLogout WSProviderSubType = "logout" +) + +type WSProviderMsg struct { + SubType WSProviderSubType `mapstructure:"sub_type"` + SessionID string `mapstructure:"session_id"` +} + +func ParseWSProvider(args map[string]interface{}) (*WSProviderMsg, error) { + msg := &WSProviderMsg{} + err := mapstructure.Decode(args, &msg) + if err != nil { + return nil, err + } + return msg, nil +} + +func (ps *ProxyServer) handleWSMessage(ctx context.Context, args map[string]interface{}) { + msg, err := ParseWSProvider(args) + if err != nil { + ps.log.WithError(err).Warning("invalid provider-specific ws message") + return + } + switch msg.SubType { + case WSProviderSubTypeLogout: + for _, p := range ps.apps { + err := p.Logout(ctx, func(c application.Claims) bool { + return c.Sid == msg.SessionID + }) + if err != nil { + ps.log.WithField("provider", p.Host).WithError(err).Warning("failed to logout") + } + } + default: + ps.log.WithField("sub_type", msg.SubType).Warning("invalid sub_type") + } +}