Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OIDC: Plugin-customizable OpenIDProvider class #982

Merged
merged 4 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added mwdb/core/oauth/__init__.py
Empty file.
5 changes: 5 additions & 0 deletions mwdb/core/oauth.py → mwdb/core/oauth/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@


class OpenIDClient:
"""
Stateful client representing OpenID Connect session using
specified client and provider data
"""

supported_algorithms = ["HS256", "HS384", "HS512", "RS256", "RS384", "RS512"]

def __init__(
Expand Down
122 changes: 122 additions & 0 deletions mwdb/core/oauth/provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import hashlib
from typing import TYPE_CHECKING, Iterator

from authlib.oidc.core import UserInfo
from marshmallow import ValidationError
from sqlalchemy import exists

from mwdb.schema.user import UserLoginSchemaBase

from .client import OpenIDClient

if TYPE_CHECKING:
from mwdb.model import Group, User


class OpenIDProvider:
"""
OpenID Connect Identity Provider representation with generic handlers.

You can override these methods with your own implementation
that is specific for provider.
"""

scope = "openid profile email"

def __init__(
self,
name,
client_id,
client_secret,
authorization_endpoint,
token_endpoint,
userinfo_endpoint,
jwks_uri,
):
self.name = name
self.client = OpenIDClient(
client_id=client_id,
client_secret=client_secret,
grant_type="authorization_code",
response_type="code",
scope=self.scope,
authorization_endpoint=authorization_endpoint,
token_endpoint=token_endpoint,
userinfo_endpoint=userinfo_endpoint,
jwks_uri=jwks_uri,
state=None,
)

def get_group_name(self) -> str:
"""
Group name that is used for registering a new OpenID provider
"""
return ("OpenID_" + self.name)[:32]

def create_provider_group(self) -> "Group":
"""
Creates a Group model object for a new OpenID provider
"""
from mwdb.model import Group

group_name = self.get_group_name()
return Group(name=group_name, immutable=True, workspace=False)

def iter_user_name_variants(self, sub: bytes, userinfo: UserInfo) -> Iterator[str]:
"""
Yield username variants that are used when user registers using OpenID identity

Usernames are yielded starting from most-preferred
"""
login_claims = ["preferred_username", "nickname", "name"]

for claim in login_claims:
username = userinfo.get(claim)
if not username:
continue
yield username
# If no candidates in claims: try fallback login
sub_md5 = hashlib.md5(sub.encode("utf-8")).hexdigest()[:8]
yield f"{self.name}-{sub_md5}"

def get_user_email(self, sub: bytes, userinfo: UserInfo) -> str:
"""
User e-mail that is used when user registers using OpenID identity
"""
if "email" in userinfo.keys():
return userinfo["email"]
else:
return f"{sub}@mwdb.local"

def get_user_description(self, sub: bytes, userinfo: UserInfo) -> str:
"""
User description that is used when user registers using OpenID identity
"""
return "Registered via OpenID Connect protocol"

def create_user(self, sub: bytes, userinfo: UserInfo) -> "User":
"""
Creates a User model object for a new OpenID identity user
"""
from mwdb.model import Group, User, db

for username in self.iter_user_name_variants(sub, userinfo):
try:
UserLoginSchemaBase().load({"login": username})
except ValidationError:
continue
already_exists = db.session.query(
exists().where(Group.name == username)
).scalar()
if not already_exists:
break
else:
raise RuntimeError("Can't find any good username candidate for user")

user_email = self.get_user_email(sub, userinfo)
user_description = self.get_user_description(sub, userinfo)
return User.create(
username,
user_email,
user_description,
)
5 changes: 5 additions & 0 deletions mwdb/core/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

_plugin_handlers = []
loaded_plugins = {}
openid_provider_classes = {}


class PluginAppContext(object):
Expand All @@ -33,6 +34,10 @@ def register_converter(self, converter_name, converter):
def register_schema_spec(self, schema_name, schema):
api.spec.components.schema(schema_name, schema=schema)

def register_openid_provider_class(self, provider_name, provider_class):
global openid_provider_classes
openid_provider_classes[provider_name] = provider_class


def hook_handler_method(meth):
@functools.wraps(meth)
Expand Down
4 changes: 2 additions & 2 deletions mwdb/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def after_cursor_execute(conn, cursor, statement, parameters, context, executema
from .file import File # noqa: E402
from .group import Group, Member # noqa: E402
from .karton import KartonAnalysis, karton_object # noqa: E402
from .oauth import OpenIDProvider, OpenIDUserIdentity # noqa: E402
from .oauth import OpenIDProviderSettings, OpenIDUserIdentity # noqa: E402
from .object import Object, relation # noqa: E402
from .object_permission import ObjectPermission # noqa: E402
from .quick_query import QuickQuery # noqa: E402
Expand All @@ -74,7 +74,7 @@ def after_cursor_execute(conn, cursor, statement, parameters, context, executema
"AttributePermission",
"Object",
"ObjectPermission",
"OpenIDProvider",
"OpenIDProviderSettings",
"OpenIDUserIdentity",
"relation",
"QuickQuery",
Expand Down
28 changes: 15 additions & 13 deletions mwdb/model/oauth.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
from mwdb.core.oauth import OpenIDClient
from typing import Type

from mwdb.core.oauth.provider import OpenIDProvider

from . import db


class OpenIDProvider(db.Model):
def get_oidc_provider_class(provider_name: str) -> Type[OpenIDProvider]:
from mwdb.core.plugins import openid_provider_classes

return openid_provider_classes.get(provider_name, OpenIDProvider)


class OpenIDProviderSettings(db.Model):
__tablename__ = "openid_provider"

id = db.Column(db.Integer, primary_key=True, autoincrement=True)
Expand All @@ -28,24 +36,18 @@ class OpenIDProvider(db.Model):
cascade="all, delete",
)

def get_oidc_client(self):
return OpenIDClient(
def get_oidc_provider(self):
openid_provider_class = get_oidc_provider_class(self.name)
return openid_provider_class(
name=self.name,
client_id=self.client_id,
client_secret=self.client_secret,
scope="openid profile email",
grant_type="authorization_code",
response_type="code",
authorization_endpoint=self.authorization_endpoint,
token_endpoint=self.token_endpoint,
userinfo_endpoint=self.userinfo_endpoint,
jwks_uri=self.jwks_endpoint,
state=None,
)

@property
def group_name(self):
return ("OpenID_" + self.name)[:32]


class OpenIDUserIdentity(db.Model):
__tablename__ = "openid_identity"
Expand All @@ -63,5 +65,5 @@ class OpenIDUserIdentity(db.Model):

user = db.relationship("User", back_populates="openid_identities")
provider = db.relationship(
OpenIDProvider, back_populates="identities", lazy="selectin"
OpenIDProviderSettings, back_populates="identities", lazy="selectin"
)
Loading
Loading