Skip to content

Commit

Permalink
providers/proxy: add custom signal on logout to logout in provider
Browse files Browse the repository at this point in the history
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
  • Loading branch information
BeryJu committed Oct 8, 2023
1 parent be86903 commit f2226ce
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 5 deletions.
4 changes: 4 additions & 0 deletions authentik/providers/proxy/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,7 @@ class AuthentikProviderProxyConfig(ManagedAppConfig):
label = "authentik_providers_proxy"
verbose_name = "authentik Providers.Proxy"
default = True

def reconcile_load_providers_proxy_signals(self):
"""Load proxy signals"""
self.import_module("authentik.providers.proxy.signals")
20 changes: 20 additions & 0 deletions authentik/providers/proxy/signals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Proxy provider signals"""
from django.contrib.auth.signals import user_logged_out
from django.db.models.signals import pre_delete
from django.dispatch import receiver
from django.http import HttpRequest

from authentik.core.models import AuthenticatedSession, User
from authentik.providers.proxy.tasks import proxy_on_logout


@receiver(user_logged_out)
def logout_proxy_revoke_direct(sender: type[User], request: HttpRequest, **_):
"""Catch logout by direct logout and forward to proxy providers"""
proxy_on_logout.delay(request.session.session_key)


@receiver(pre_delete, sender=AuthenticatedSession)
def logout_proxy_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
"""Catch logout by expiring sessions being deleted"""
proxy_on_logout.delay(instance.session_key)
20 changes: 20 additions & 0 deletions authentik/providers/proxy/tasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""proxy provider tasks"""
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.db import DatabaseError, InternalError, ProgrammingError

from authentik.outposts.models import Outpost, OutpostState, OutpostType
from authentik.providers.proxy.models import ProxyProvider
from authentik.root.celery import CELERY_APP

Expand All @@ -13,3 +16,20 @@ def proxy_set_defaults():
for provider in ProxyProvider.objects.all():
provider.set_oauth_defaults()
provider.save()


@CELERY_APP.task()
def proxy_on_logout(session_id: str):
"""Update outpost instances connected to a single outpost"""
layer = get_channel_layer()
for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
for state in OutpostState.for_outpost(outpost):
for channel in state.channel_ids:
async_to_sync(layer.send)(
channel,
{
"type": "event.provider.specific",
"sub_type": "logout",
"session_id": session_id,
},
)
1 change: 1 addition & 0 deletions blueprints/system/providers-proxy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ entries:
# This mapping is used by the authentik proxy. It passes extra user attributes,
# which are used for example for the HTTP-Basic Authentication mapping.
return {
"sid": request.http_request.session.session_key,
"ak_proxy": {
"user_attributes": request.user.group_attributes(request),
"is_superuser": request.user.is_superuser,
Expand Down
4 changes: 3 additions & 1 deletion internal/outpost/proxyv2/application/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,9 @@ func (a *Application) handleSignOut(rw http.ResponseWriter, r *http.Request) {
"id_token_hint": []string{cc.RawToken},
}
redirect += "?" + uv.Encode()
err = a.Logout(r.Context(), cc.Sub)
err = a.Logout(r.Context(), func(c Claims) bool {
return c.Sub == cc.Sub
})
if err != nil {
a.log.WithError(err).Warning("failed to logout of other sessions")
}
Expand Down
3 changes: 2 additions & 1 deletion internal/outpost/proxyv2/application/claims.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ type Claims struct {
Exp int `json:"exp"`
Email string `json:"email"`
Verified bool `json:"email_verified"`
Proxy *ProxyClaims `json:"ak_proxy"`
Name string `json:"name"`
PreferredUsername string `json:"preferred_username"`
Groups []string `json:"groups"`
Sid string `json:"sid"`
Proxy *ProxyClaims `json:"ak_proxy"`

RawToken string
}
6 changes: 3 additions & 3 deletions internal/outpost/proxyv2/application/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (a *Application) getAllCodecs() []securecookie.Codec {
return cs
}

func (a *Application) Logout(ctx context.Context, sub string) error {
func (a *Application) Logout(ctx context.Context, filter func(c Claims) bool) error {
if _, ok := a.sessions.(*sessions.FilesystemStore); ok {
files, err := os.ReadDir(os.TempDir())
if err != nil {
Expand Down Expand Up @@ -118,7 +118,7 @@ func (a *Application) Logout(ctx context.Context, sub string) error {
continue
}
claims := s.Values[constants.SessionClaims].(Claims)
if claims.Sub == sub {
if filter(claims) {
a.log.WithField("path", fullPath).Trace("deleting session")
err := os.Remove(fullPath)
if err != nil {
Expand Down Expand Up @@ -153,7 +153,7 @@ func (a *Application) Logout(ctx context.Context, sub string) error {
continue
}
claims := c.(Claims)
if claims.Sub == sub {
if filter(claims) {
a.log.WithField("key", key).Trace("deleting session")
_, err := client.Del(ctx, key).Result()
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions internal/outpost/proxyv2/proxyv2.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ func NewProxyServer(ac *ak.APIController) *ProxyServer {
globalMux.PathPrefix("/outpost.goauthentik.io/static").HandlerFunc(s.HandleStatic)
globalMux.Path("/outpost.goauthentik.io/ping").HandlerFunc(sentryutils.SentryNoSample(s.HandlePing))
rootMux.PathPrefix("/").HandlerFunc(s.Handle)
ac.AddWSHandler(s.handleWSMessage)
return s
}

Expand Down
49 changes: 49 additions & 0 deletions internal/outpost/proxyv2/ws.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package proxyv2

import (
"context"

"github.com/mitchellh/mapstructure"
"goauthentik.io/internal/outpost/proxyv2/application"
)

type WSProviderSubType string

const (
WSProviderSubTypeLogout WSProviderSubType = "logout"
)

type WSProviderMsg struct {
SubType WSProviderSubType `mapstructure:"sub_type"`
SessionID string `mapstructure:"session_id"`
}

func ParseWSProvider(args map[string]interface{}) (*WSProviderMsg, error) {
msg := &WSProviderMsg{}
err := mapstructure.Decode(args, &msg)
if err != nil {
return nil, err
}
return msg, nil
}

func (ps *ProxyServer) handleWSMessage(ctx context.Context, args map[string]interface{}) {
msg, err := ParseWSProvider(args)
if err != nil {
ps.log.WithError(err).Warning("invalid provider-specific ws message")
return
}
switch msg.SubType {
case WSProviderSubTypeLogout:
for _, p := range ps.apps {
err := p.Logout(ctx, func(c application.Claims) bool {
return c.Sid == msg.SessionID
})
if err != nil {
ps.log.WithField("provider", p.Host).WithError(err).Warning("failed to logout")
}
}
default:
ps.log.WithField("sub_type", msg.SubType).Warning("invalid sub_type")
}
}

0 comments on commit f2226ce

Please sign in to comment.