diff --git a/dojo_plugin/__init__.py b/dojo_plugin/__init__.py index ae1186a78..19e6713ce 100644 --- a/dojo_plugin/__init__.py +++ b/dojo_plugin/__init__.py @@ -24,6 +24,7 @@ from .pages.workspace import workspace from .pages.desktop import desktop from .pages.users import users +from .pages.sso_login import sso from .pages.settings import settings_override from .pages.course import course from .pages.writeups import writeups @@ -126,6 +127,7 @@ def load(app): app.register_blueprint(dojo) app.register_blueprint(workspace) app.register_blueprint(desktop) + app.register_blueprint(sso) app.register_blueprint(users) app.register_blueprint(course) app.register_blueprint(writeups) diff --git a/dojo_plugin/api/__init__.py b/dojo_plugin/api/__init__.py index 0b7a98f41..e12a4a1d9 100644 --- a/dojo_plugin/api/__init__.py +++ b/dojo_plugin/api/__init__.py @@ -20,3 +20,4 @@ api_v1.add_namespace(dojo_namespace, "/dojo") api_v1.add_namespace(belts_namespace, "/belts") api_v1.add_namespace(score_namespace, "/score") +api_v1.add_namespace(score_namespace, "/sso_login") diff --git a/dojo_plugin/api/v1/sso_login.py b/dojo_plugin/api/v1/sso_login.py new file mode 100644 index 000000000..703eb4689 --- /dev/null +++ b/dojo_plugin/api/v1/sso_login.py @@ -0,0 +1,121 @@ +from urllib.parse import urlencode, urljoin +from urllib.request import urlopen, Request +from urllib.error import URLError + +from CTFd.models import Users, db +from CTFd.utils import user as current_user +from CTFd.utils.helpers import error_for, get_errors, markup +from CTFd.utils.logging import log + + +from xml.etree import ElementTree + +class Settings: + def __init__(self): + self.DEBUG = True + self.CAS_SERVER_URL = 'https://pass.hust.edu.cn/cas/login' + self.CAS_ADMIN_PREFIX = None + self.CAS_EXTRA_LOGIN_PARAMS = None + self.CAS_IGNORE_REFERER = False + self.CAS_LOGOUT_COMPLETELY = True + self.CAS_REDIRECT_URL = 'http://pwn.cse.hust.edu.cn/cas-login/' + self.CAS_RETRY_LOGIN = False + self.CAS_VERSION = '2' + + def __getattr__(self, item): + raise AttributeError(f'Setting {item} not found') + +settings = Settings() + + + +def _verify_cas2(ticket, service): + """Verifies CAS 2.0+ XML-based authentication ticket.""" + headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'} + params = {'ticket': ticket, 'service': service} + url = urljoin(settings.CAS_SERVER_URL, 'proxyValidate') + '?' + urlencode(params) + request = Request(url, headers=headers) + + try: + with urlopen(request,timeout=2) as page: + response = page.read() + tree = ElementTree.fromstring(response) + if tree[0].tag.endswith('authenticationSuccess'): + return tree[0][0].text + else: + return None + except URLError as e: + print(f"URL Error: {e.reason}") + return None + + +def register_sso(studentID): + errors = get_errors() + name = studentID.strip() + email_address = studentID.strip().lower()+"@hust.edu.cn" + password = "" + bracket_id = None + oauth_id = int(studentID[1:]) + # website = request.form.get("website") + # affiliation = request.form.get("affiliation") + # country = request.form.get("country") + # registration_code = str(request.form.get("registration_code", "")) + # bracket_id = request.form.get("bracket_id", None) + + names = ( + Users.query.add_columns(Users.name, Users.id).filter_by(name=name).first() + ) + emails = ( + Users.query.add_columns(Users.email, Users.id) + .filter_by(email=email_address) + .first() + ) + if names: + errors.append("That user name is already taken") + if emails: + errors.append("That email has already been used") + + user = Users( + name=name, + email=email_address, + password=password, + oauth_id=oauth_id, + verified=True, + ) + db.session.add(user) + db.session.commit() + + log( + "registrations", + format="[{date}] {ip} - {name} registered with {email}", + name=user.name, + email=user.email, + ) + db.session.close() + return user + + + +class CASBackend(object): + """CAS authentication backend""" + + def authenticate(self,ticket): + service = settings.CAS_REDIRECT_URL + username = _verify_cas2(ticket, service) + user = Users.query.filter_by(oauth_id=username[1:]).first() + + if user : + return user + else: + # user will have an "unusable" password + user = register_sso(username) + return user + + def get_user(self, user_id): + """Retrieve the user's entry in the User model if it exists""" + + def get_login_url(): + """Generates CAS login URL""" + params = {'service': settings.CAS_REDIRECT_URL} + #return urlopen(settings.CAS_SERVER_URL) + return urljoin(settings.CAS_SERVER_URL, 'login') + '?' + urlencode(params) diff --git a/dojo_plugin/pages/sso_login.py b/dojo_plugin/pages/sso_login.py new file mode 100644 index 000000000..287cd3691 --- /dev/null +++ b/dojo_plugin/pages/sso_login.py @@ -0,0 +1,35 @@ +import datetime +import hashlib +import itertools +import re + +from flask import Blueprint, Response, render_template, abort, url_for,request,redirect +from sqlalchemy.sql import and_, or_ +from CTFd.utils.user import get_current_user +from CTFd.utils.decorators import authed_only +from CTFd.utils.security.auth import login_user, logout_user + +from ..api.v1.sso_login import CASBackend +from ..models import Dojos, DojoModules, DojoChallenges +from ..config import DATA_DIR +from ..utils.scores import dojo_scores, module_scores +from ..utils.awards import get_belts, get_viewable_emojis + +sso = Blueprint("pwncollege_sso", __name__) + + +@sso.route('/cas-login/') +def cas_login(): + ticket = request.args.get('ticket') + casbackend = CASBackend() + if ticket: + user = casbackend.authenticate(ticket) + if user: + # 用户认证成功,创建本地会话等 + login_user(user) + return redirect(url_for('pwncollege_dojos.listing')) + else: + return redirect(url_for("auth.login")) + else: + return redirect(CASBackend.get_login_url()) + diff --git a/dojo_theme/templates/components/navbar.html b/dojo_theme/templates/components/navbar.html index 6ad1ed025..839426983 100644 --- a/dojo_theme/templates/components/navbar.html +++ b/dojo_theme/templates/components/navbar.html @@ -38,10 +38,9 @@ {% else %} {% if registration_visible() %} - {{ navitem("注册", url_for("auth.register"), "fa-user-plus") }} {% endif %} - {{ navitem("登录", url_for("auth.login", next=request.full_path), "fa-sign-in-alt") }} + {{ navitem("统一身份认证", url_for("pwncollege_sso.cas_login", next=request.full_path), "fa-sign-in-alt") }} {% endif %}