Skip to content

Commit

Permalink
providers/oauth2: remember session_id from initial token (#7976)
Browse files Browse the repository at this point in the history
* providers/oauth2: remember session_id original token was created with for future access/refresh tokens

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* providers/proxy: use hashed session as `sid`

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
  • Loading branch information
BeryJu authored Dec 22, 2023
1 parent 4776d2b commit 21888f5
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Generated by Django 5.0 on 2023-12-22 23:20

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("authentik_providers_oauth2", "0016_alter_refreshtoken_token"),
]

operations = [
migrations.AddField(
model_name="accesstoken",
name="session_id",
field=models.CharField(blank=True, default=""),
),
migrations.AddField(
model_name="authorizationcode",
name="session_id",
field=models.CharField(blank=True, default=""),
),
migrations.AddField(
model_name="refreshtoken",
name="session_id",
field=models.CharField(blank=True, default=""),
),
]
1 change: 1 addition & 0 deletions authentik/providers/oauth2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ class BaseGrantModel(models.Model):
revoked = models.BooleanField(default=False)
_scope = models.TextField(default="", verbose_name=_("Scopes"))
auth_time = models.DateTimeField(verbose_name="Authentication time")
session_id = models.CharField(default="", blank=True)

@property
def scope(self) -> list[str]:
Expand Down
3 changes: 3 additions & 0 deletions authentik/providers/oauth2/views/authorize.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""authentik OAuth2 Authorization views"""
from dataclasses import dataclass, field
from datetime import timedelta
from hashlib import sha256
from json import dumps
from re import error as RegexError
from re import fullmatch
Expand Down Expand Up @@ -282,6 +283,7 @@ def create_code(self, request: HttpRequest) -> AuthorizationCode:
expires=now + timedelta_from_string(self.provider.access_code_validity),
scope=self.scope,
nonce=self.nonce,
session_id=sha256(request.session.session_key.encode("ascii")).hexdigest(),
)

if self.code_challenge and self.code_challenge_method:
Expand Down Expand Up @@ -569,6 +571,7 @@ def create_implicit_response(self, code: Optional[AuthorizationCode]) -> dict:
expires=access_token_expiry,
provider=self.provider,
auth_time=auth_event.created if auth_event else now,
session_id=sha256(self.request.session.session_key.encode("ascii")).hexdigest(),
)

id_token = IDToken.new(self.provider, token, self.request)
Expand Down
4 changes: 4 additions & 0 deletions authentik/providers/oauth2/views/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,7 @@ def create_code_response(self) -> dict[str, Any]:
# Keep same scopes as previous token
scope=self.params.authorization_code.scope,
auth_time=self.params.authorization_code.auth_time,
session_id=self.params.authorization_code.session_id,
)
access_token.id_token = IDToken.new(
self.provider,
Expand All @@ -502,6 +503,7 @@ def create_code_response(self) -> dict[str, Any]:
expires=refresh_token_expiry,
provider=self.provider,
auth_time=self.params.authorization_code.auth_time,
session_id=self.params.authorization_code.session_id,
)
id_token = IDToken.new(
self.provider,
Expand Down Expand Up @@ -539,6 +541,7 @@ def create_refresh_response(self) -> dict[str, Any]:
# Keep same scopes as previous token
scope=self.params.refresh_token.scope,
auth_time=self.params.refresh_token.auth_time,
session_id=self.params.refresh_token.session_id,
)
access_token.id_token = IDToken.new(
self.provider,
Expand All @@ -554,6 +557,7 @@ def create_refresh_response(self) -> dict[str, Any]:
expires=refresh_token_expiry,
provider=self.provider,
auth_time=self.params.refresh_token.auth_time,
session_id=self.params.refresh_token.session_id,
)
id_token = IDToken.new(
self.provider,
Expand Down
5 changes: 4 additions & 1 deletion authentik/providers/proxy/tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""proxy provider tasks"""
from hashlib import sha256

from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.db import DatabaseError, InternalError, ProgrammingError
Expand All @@ -23,13 +25,14 @@ def proxy_set_defaults():
def proxy_on_logout(session_id: str):
"""Update outpost instances connected to a single outpost"""
layer = get_channel_layer()
hashed_session_id = sha256(session_id.encode("ascii")).hexdigest()

Check warning on line 28 in authentik/providers/proxy/tasks.py

View check run for this annotation

Codecov / codecov/patch

authentik/providers/proxy/tasks.py#L28

Added line #L28 was not covered by tests
for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
async_to_sync(layer.group_send)(
group,
{
"type": "event.provider.specific",
"sub_type": "logout",
"session_id": session_id,
"session_id": hashed_session_id,
},
)
2 changes: 1 addition & 1 deletion blueprints/system/providers-proxy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ entries:
# This mapping is used by the authentik proxy. It passes extra user attributes,
# which are used for example for the HTTP-Basic Authentication mapping.
return {
"sid": request.http_request.session.session_key,
"sid": token.session_id,
"ak_proxy": {
"user_attributes": request.user.group_attributes(request),
"is_superuser": request.user.is_superuser,
Expand Down
1 change: 1 addition & 0 deletions internal/outpost/proxyv2/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func (ps *ProxyServer) handleWSMessage(ctx context.Context, args map[string]inte
switch msg.SubType {
case WSProviderSubTypeLogout:
for _, p := range ps.apps {
ps.log.WithField("provider", p.Host).Debug("Logging out")
err := p.Logout(ctx, func(c application.Claims) bool {
return c.Sid == msg.SessionID
})
Expand Down

0 comments on commit 21888f5

Please sign in to comment.