Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

providers/oauth2: remember session_id from initial token #7976

Merged
merged 2 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Generated by Django 5.0 on 2023-12-22 23:20

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("authentik_providers_oauth2", "0016_alter_refreshtoken_token"),
]

operations = [
migrations.AddField(
model_name="accesstoken",
name="session_id",
field=models.CharField(blank=True, default=""),
),
migrations.AddField(
model_name="authorizationcode",
name="session_id",
field=models.CharField(blank=True, default=""),
),
migrations.AddField(
model_name="refreshtoken",
name="session_id",
field=models.CharField(blank=True, default=""),
),
]
1 change: 1 addition & 0 deletions authentik/providers/oauth2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ class BaseGrantModel(models.Model):
revoked = models.BooleanField(default=False)
_scope = models.TextField(default="", verbose_name=_("Scopes"))
auth_time = models.DateTimeField(verbose_name="Authentication time")
session_id = models.CharField(default="", blank=True)

@property
def scope(self) -> list[str]:
Expand Down
3 changes: 3 additions & 0 deletions authentik/providers/oauth2/views/authorize.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""authentik OAuth2 Authorization views"""
from dataclasses import dataclass, field
from datetime import timedelta
from hashlib import sha256
from json import dumps
from re import error as RegexError
from re import fullmatch
Expand Down Expand Up @@ -282,6 +283,7 @@ def create_code(self, request: HttpRequest) -> AuthorizationCode:
expires=now + timedelta_from_string(self.provider.access_code_validity),
scope=self.scope,
nonce=self.nonce,
session_id=sha256(request.session.session_key.encode("ascii")).hexdigest(),
)

if self.code_challenge and self.code_challenge_method:
Expand Down Expand Up @@ -569,6 +571,7 @@ def create_implicit_response(self, code: Optional[AuthorizationCode]) -> dict:
expires=access_token_expiry,
provider=self.provider,
auth_time=auth_event.created if auth_event else now,
session_id=sha256(self.request.session.session_key.encode("ascii")).hexdigest(),
)

id_token = IDToken.new(self.provider, token, self.request)
Expand Down
4 changes: 4 additions & 0 deletions authentik/providers/oauth2/views/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,7 @@ def create_code_response(self) -> dict[str, Any]:
# Keep same scopes as previous token
scope=self.params.authorization_code.scope,
auth_time=self.params.authorization_code.auth_time,
session_id=self.params.authorization_code.session_id,
)
access_token.id_token = IDToken.new(
self.provider,
Expand All @@ -502,6 +503,7 @@ def create_code_response(self) -> dict[str, Any]:
expires=refresh_token_expiry,
provider=self.provider,
auth_time=self.params.authorization_code.auth_time,
session_id=self.params.authorization_code.session_id,
)
id_token = IDToken.new(
self.provider,
Expand Down Expand Up @@ -539,6 +541,7 @@ def create_refresh_response(self) -> dict[str, Any]:
# Keep same scopes as previous token
scope=self.params.refresh_token.scope,
auth_time=self.params.refresh_token.auth_time,
session_id=self.params.refresh_token.session_id,
)
access_token.id_token = IDToken.new(
self.provider,
Expand All @@ -554,6 +557,7 @@ def create_refresh_response(self) -> dict[str, Any]:
expires=refresh_token_expiry,
provider=self.provider,
auth_time=self.params.refresh_token.auth_time,
session_id=self.params.refresh_token.session_id,
)
id_token = IDToken.new(
self.provider,
Expand Down
5 changes: 4 additions & 1 deletion authentik/providers/proxy/tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""proxy provider tasks"""
from hashlib import sha256

from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.db import DatabaseError, InternalError, ProgrammingError
Expand All @@ -23,13 +25,14 @@
def proxy_on_logout(session_id: str):
"""Update outpost instances connected to a single outpost"""
layer = get_channel_layer()
hashed_session_id = sha256(session_id.encode("ascii")).hexdigest()

Check warning on line 28 in authentik/providers/proxy/tasks.py

View check run for this annotation

Codecov / codecov/patch

authentik/providers/proxy/tasks.py#L28

Added line #L28 was not covered by tests
for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
async_to_sync(layer.group_send)(
group,
{
"type": "event.provider.specific",
"sub_type": "logout",
"session_id": session_id,
"session_id": hashed_session_id,
},
)
2 changes: 1 addition & 1 deletion blueprints/system/providers-proxy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ entries:
# This mapping is used by the authentik proxy. It passes extra user attributes,
# which are used for example for the HTTP-Basic Authentication mapping.
return {
"sid": request.http_request.session.session_key,
"sid": token.session_id,
"ak_proxy": {
"user_attributes": request.user.group_attributes(request),
"is_superuser": request.user.is_superuser,
Expand Down
1 change: 1 addition & 0 deletions internal/outpost/proxyv2/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func (ps *ProxyServer) handleWSMessage(ctx context.Context, args map[string]inte
switch msg.SubType {
case WSProviderSubTypeLogout:
for _, p := range ps.apps {
ps.log.WithField("provider", p.Host).Debug("Logging out")
err := p.Logout(ctx, func(c application.Claims) bool {
return c.Sid == msg.SessionID
})
Expand Down
Loading