Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

SSO: redirect to public URL before setting cookies #9436

Merged
merged 5 commits into from
Feb 26, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion synapse/http/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Union

from twisted.internet import task
from twisted.internet import address, task
from twisted.web.client import FileBodyProducer
from twisted.web.iweb import IRequest

Expand Down Expand Up @@ -53,6 +54,40 @@ def stopProducing(self):
pass


def get_request_uri(request: IRequest) -> bytes:
"""Return the full URI that was requested by the client"""
return b"%s://%s%s" % (
b"https" if request.isSecure() else b"http",
_get_requested_host(request),
# despite its name, "request.uri" is only the path and query-string.
request.uri,
)


def _get_requested_host(request: IRequest) -> bytes:
hostname = request.getHeader(b"host")
if hostname:
return hostname

# no Host header, use the address/port that the request arrived on
host = request.getHost() # type: Union[address.IPv4Address, address.IPv6Address]

hostname = host.host.encode("ascii")

if request.isSecure() and host.port == 443:
# default port for https
return hostname

if not request.isSecure() and host.port == 80:
# default port for http
return hostname

return b"%s:%i" % (
hostname,
host.port,
)


def get_request_user_agent(request: IRequest, default: str = "") -> str:
"""Return the last User-Agent header, or the given default."""
# There could be raw utf-8 bytes in the User-Agent header.
Expand Down
28 changes: 28 additions & 0 deletions synapse/rest/client/v1/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice import ApplicationService
from synapse.handlers.sso import SsoIdentityProvider
from synapse.http import get_request_uri
from synapse.http.server import HttpServer, finish_request
from synapse.http.servlet import (
RestServlet,
Expand Down Expand Up @@ -354,6 +355,7 @@ def __init__(self, hs: "HomeServer"):
hs.get_oidc_handler()
self._sso_handler = hs.get_sso_handler()
self._msc2858_enabled = hs.config.experimental.msc2858_enabled
self._public_baseurl = hs.config.public_baseurl

def register(self, http_server: HttpServer) -> None:
super().register(http_server)
Expand All @@ -373,6 +375,32 @@ def register(self, http_server: HttpServer) -> None:
async def on_GET(
self, request: SynapseRequest, idp_id: Optional[str] = None
) -> None:
if not self._public_baseurl:
raise SynapseError(400, "SSO requires a valid public_baseurl")

# if this isn't the expected hostname, redirect to the right one, so that we
# get our cookies back.
requested_uri = get_request_uri(request)
baseurl_bytes = self._public_baseurl.encode("utf-8")
if not requested_uri.startswith(baseurl_bytes):
# swap out the incorrect base URL for the right one.
#
# The idea here is to redirect from
# https://foo.bar/whatever/_matrix/...
# to
# https://public.baseurl/_matrix/...
#
i = requested_uri.index(b"/_matrix")
new_uri = baseurl_bytes[:-1] + requested_uri[i:]
richvdh marked this conversation as resolved.
Show resolved Hide resolved
logger.info(
"Requested URI %s is not canonical: redirecting to %s",
requested_uri.decode("utf-8", errors="replace"),
new_uri.decode("utf-8", errors="replace"),
)
request.redirect(new_uri)
finish_request(request)
return

client_redirect_url = parse_string(
request, "redirectUrl", required=True, encoding=None
)
Expand Down
56 changes: 32 additions & 24 deletions tests/rest/client/v1/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import time
import urllib.parse
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlencode

from mock import Mock
Expand Down Expand Up @@ -47,8 +47,9 @@
HAS_JWT = False


# public_base_url used in some tests
BASE_URL = "https://synapse/"
# synapse server name: used to populate public_base_url in some tests
SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse"
BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

swapping to http here, because FakeChannel.isSecure() returns False, so synapse will see the requested uri as http://synapse/.... Configuring that as the public_baseurl avoids the redirect.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be good to include this comment in the code. Just looking at the file, I'd be a bit confused why other URLs here were https, but not the public base url.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fair enough. I've updated some comments.


# CAS server used in some tests
CAS_SERVER = "https://fake.test"
Expand Down Expand Up @@ -480,11 +481,7 @@ def test_get_msc2858_login_flows(self):
def test_multi_sso_redirect(self):
"""/login/sso/redirect should redirect to an identity picker"""
# first hit the redirect url, which should redirect to our idp picker
channel = self.make_request(
"GET",
"/_matrix/client/r0/login/sso/redirect?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
)
channel = self._make_sso_redirect_request(False, None)
self.assertEqual(channel.code, 302, channel.result)
uri = channel.headers.getRawHeaders("Location")[0]

Expand Down Expand Up @@ -628,41 +625,52 @@ def test_multi_sso_redirect_to_unknown(self):

def test_client_idp_redirect_msc2858_disabled(self):
"""If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
channel = self.make_request(
"GET",
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
)
channel = self._make_sso_redirect_request(True, "oidc")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")

@override_config({"experimental_features": {"msc2858_enabled": True}})
def test_client_idp_redirect_to_unknown(self):
"""If the client tries to pick an unknown IdP, return a 404"""
channel = self.make_request(
"GET",
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
)
channel = self._make_sso_redirect_request(True, "xxx")
self.assertEqual(channel.code, 404, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")

@override_config({"experimental_features": {"msc2858_enabled": True}})
def test_client_idp_redirect_to_oidc(self):
"""If the client pick a known IdP, redirect to it"""
channel = self.make_request(
"GET",
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
)

channel = self._make_sso_redirect_request(True, "oidc")
self.assertEqual(channel.code, 302, channel.result)
oidc_uri = channel.headers.getRawHeaders("Location")[0]
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)

# it should redirect us to the auth page of the OIDC server
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)

def _make_sso_redirect_request(
self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None
):
"""Send a request to /_matrix/client/r0/login/sso/redirect

... or the unstable equivalent

... possibly specifying an IDP provider
"""
endpoint = (
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect"
if unstable_endpoint
else "/_matrix/client/r0/login/sso/redirect"
)
if idp_prov is not None:
endpoint += "/" + idp_prov
endpoint += "?redirectUrl=" + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)

return self.make_request(
"GET",
endpoint,
custom_headers=[("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME)],
)

@staticmethod
def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
prefix = key + " = "
Expand Down
14 changes: 14 additions & 0 deletions tests/rest/client/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,20 @@ def initiate_sso_login(
"/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params),
)

# that should 302 us to the public base url
assert channel.code == 302
location = channel.headers.getRawHeaders("Location")[0]
parts = urllib.parse.urlsplit(location)
channel = make_request(
self.hs.get_reactor(),
self.site,
"GET",
urllib.parse.urlunsplit(("", "") + parts[2:]),
custom_headers=[
("Host", parts[1]),
],
)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was the easiest way of figuring out what the Host header should be set to.

assert channel.code == 302
channel.extract_cookies(cookies)
return channel.headers.getRawHeaders("Location")[0]
Expand Down
2 changes: 1 addition & 1 deletion tests/rest/client/v2_alpha/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class UIAuthTests(unittest.HomeserverTestCase):

def default_config(self):
config = super().default_config()
config["public_baseurl"] = "https://synapse.test"
config["public_baseurl"] = "http://synapse.test"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as above


if HAS_OIDC:
# we enable OIDC as a way of testing SSO flows
Expand Down
6 changes: 5 additions & 1 deletion tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,11 @@ def getPeer(self):
return address.IPv4Address("TCP", self._ip, 3423)

def getHost(self):
return None
# this is called by Request.__init__ to configure Request.host.
return address.IPv4Address("TCP", "127.0.0.1", 8888)

def isSecure(self):
return False

@property
def transport(self):
Expand Down