Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Ratelimit 3PID /requestToken API (#9238)
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston authored Jan 28, 2021
1 parent 54a6afe commit 4b73488
Show file tree
Hide file tree
Showing 11 changed files with 159 additions and 14 deletions.
1 change: 1 addition & 0 deletions changelog.d/9238.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add ratelimited to 3PID `/requestToken` API.
6 changes: 5 additions & 1 deletion docs/sample_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,7 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# users are joining rooms the server is already in (this is cheap) vs
# "remote" for when users are trying to join rooms not on the server (which
# can be more expensive)
# - one for ratelimiting how often a user or IP can attempt to validate a 3PID.
#
# The defaults are as shown below.
#
Expand Down Expand Up @@ -857,7 +858,10 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# remote:
# per_second: 0.01
# burst_count: 3

#
#rc_3pid_validation:
# per_second: 0.003
# burst_count: 5

# Ratelimiting settings for incoming federation
#
Expand Down
2 changes: 1 addition & 1 deletion synapse/config/_base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class RootConfig:
tls: tls.TlsConfig
database: database.DatabaseConfig
logging: logger.LoggingConfig
ratelimit: ratelimiting.RatelimitConfig
ratelimiting: ratelimiting.RatelimitConfig
media: repository.ContentRepositoryConfig
captcha: captcha.CaptchaConfig
voip: voip.VoipConfig
Expand Down
13 changes: 11 additions & 2 deletions synapse/config/ratelimiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
defaults={"per_second": 0.17, "burst_count": 3.0},
):
self.per_second = config.get("per_second", defaults["per_second"])
self.burst_count = config.get("burst_count", defaults["burst_count"])
self.burst_count = int(config.get("burst_count", defaults["burst_count"]))


class FederationRateLimitConfig:
Expand Down Expand Up @@ -102,6 +102,11 @@ def read_config(self, config, **kwargs):
defaults={"per_second": 0.01, "burst_count": 3},
)

self.rc_3pid_validation = RateLimitConfig(
config.get("rc_3pid_validation") or {},
defaults={"per_second": 0.003, "burst_count": 5},
)

def generate_config_section(self, **kwargs):
return """\
## Ratelimiting ##
Expand Down Expand Up @@ -131,6 +136,7 @@ def generate_config_section(self, **kwargs):
# users are joining rooms the server is already in (this is cheap) vs
# "remote" for when users are trying to join rooms not on the server (which
# can be more expensive)
# - one for ratelimiting how often a user or IP can attempt to validate a 3PID.
#
# The defaults are as shown below.
#
Expand Down Expand Up @@ -164,7 +170,10 @@ def generate_config_section(self, **kwargs):
# remote:
# per_second: 0.01
# burst_count: 3
#
#rc_3pid_validation:
# per_second: 0.003
# burst_count: 5
# Ratelimiting settings for incoming federation
#
Expand Down
28 changes: 28 additions & 0 deletions synapse/handlers/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
HttpResponseException,
SynapseError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http import RequestTimedOutError
from synapse.http.client import SimpleHttpClient
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, Requester
from synapse.util import json_decoder
from synapse.util.hash import sha256_and_url_safe_base64
Expand Down Expand Up @@ -57,6 +59,32 @@ def __init__(self, hs):

self._web_client_location = hs.config.invite_client_location

# Ratelimiters for `/requestToken` endpoints.
self._3pid_validation_ratelimiter_ip = Ratelimiter(
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
)
self._3pid_validation_ratelimiter_address = Ratelimiter(
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
)

def ratelimit_request_token_requests(
self, request: SynapseRequest, medium: str, address: str,
):
"""Used to ratelimit requests to `/requestToken` by IP and address.
Args:
request: The associated request
medium: The type of threepid, e.g. "msisdn" or "email"
address: The actual threepid ID, e.g. the phone number or email address
"""

self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP()))
self._3pid_validation_ratelimiter_address.ratelimit((medium, address))

async def threepid_from_creds(
self, id_server: str, creds: Dict[str, str]
) -> Optional[JsonDict]:
Expand Down
12 changes: 10 additions & 2 deletions synapse/rest/client/v2_alpha/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
class EmailPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password/email/requestToken$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.datastore = hs.get_datastore()
Expand Down Expand Up @@ -103,6 +103,8 @@ async def on_POST(self, request):
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)

self.identity_handler.ratelimit_request_token_requests(request, "email", email)

# The email will be sent to the stored address.
# This avoids a potential account hijack by requesting a password reset to
# an email address which is controlled by the attacker but which, after
Expand Down Expand Up @@ -379,6 +381,8 @@ async def on_POST(self, request):
Codes.THREEPID_DENIED,
)

self.identity_handler.ratelimit_request_token_requests(request, "email", email)

if next_link:
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
Expand Down Expand Up @@ -430,7 +434,7 @@ async def on_POST(self, request):
class MsisdnThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
super().__init__()
self.store = self.hs.get_datastore()
Expand Down Expand Up @@ -458,6 +462,10 @@ async def on_POST(self, request):
Codes.THREEPID_DENIED,
)

self.identity_handler.ratelimit_request_token_requests(
request, "msisdn", msisdn
)

if next_link:
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
Expand Down
6 changes: 6 additions & 0 deletions synapse/rest/client/v2_alpha/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ async def on_POST(self, request):
Codes.THREEPID_DENIED,
)

self.identity_handler.ratelimit_request_token_requests(request, "email", email)

existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", email
)
Expand Down Expand Up @@ -205,6 +207,10 @@ async def on_POST(self, request):
Codes.THREEPID_DENIED,
)

self.identity_handler.ratelimit_request_token_requests(
request, "msisdn", msisdn
)

existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"msisdn", msisdn
)
Expand Down
90 changes: 84 additions & 6 deletions tests/rest/client/v2_alpha/test_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import synapse.rest.admin
from synapse.api.constants import LoginType, Membership
from synapse.api.errors import Codes
from synapse.api.errors import Codes, HttpResponseException
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import account, register
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
Expand Down Expand Up @@ -112,6 +112,56 @@ def test_basic_password_reset(self):
# Assert we can't log in with the old password
self.attempt_wrong_password_login("kermit", old_password)

@override_config({"rc_3pid_validation": {"burst_count": 3}})
def test_ratelimit_by_email(self):
"""Test that we ratelimit /requestToken for the same email.
"""
old_password = "monkey"
new_password = "kangeroo"

user_id = self.register_user("kermit", old_password)
self.login("kermit", old_password)

email = "test1@example.com"

# Add a threepid
self.get_success(
self.store.user_add_threepid(
user_id=user_id,
medium="email",
address=email,
validated_at=0,
added_at=0,
)
)

def reset(ip):
client_secret = "foobar"
session_id = self._request_token(email, client_secret, ip)

self.assertEquals(len(self.email_attempts), 1)
link = self._get_link_from_email()

self._validate_token(link)

self._reset_password(new_password, session_id, client_secret)

self.email_attempts.clear()

# We expect to be able to make three requests before getting rate
# limited.
#
# We change IPs to ensure that we're not being ratelimited due to the
# same IP
reset("127.0.0.1")
reset("127.0.0.2")
reset("127.0.0.3")

with self.assertRaises(HttpResponseException) as cm:
reset("127.0.0.4")

self.assertEqual(cm.exception.code, 429)

def test_basic_password_reset_canonicalise_email(self):
"""Test basic password reset flow
Request password reset with different spelling
Expand Down Expand Up @@ -239,13 +289,18 @@ def test_password_reset_bad_email_inhibit_error(self):

self.assertIsNotNone(session_id)

def _request_token(self, email, client_secret):
def _request_token(self, email, client_secret, ip="127.0.0.1"):
channel = self.make_request(
"POST",
b"account/password/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1},
client_ip=ip,
)
self.assertEquals(200, channel.code, channel.result)

if channel.code != 200:
raise HttpResponseException(
channel.code, channel.result["reason"], channel.result["body"],
)

return channel.json_body["sid"]

Expand Down Expand Up @@ -509,6 +564,21 @@ def test_add_email_address_casefold(self):
def test_address_trim(self):
self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar"))

@override_config({"rc_3pid_validation": {"burst_count": 3}})
def test_ratelimit_by_ip(self):
"""Tests that adding emails is ratelimited by IP
"""

# We expect to be able to set three emails before getting ratelimited.
self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar"))
self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar"))
self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar"))

with self.assertRaises(HttpResponseException) as cm:
self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar"))

self.assertEqual(cm.exception.code, 429)

def test_add_email_if_disabled(self):
"""Test adding email to profile when doing so is disallowed
"""
Expand Down Expand Up @@ -777,7 +847,11 @@ def _request_token(
body["next_link"] = next_link

channel = self.make_request("POST", b"account/3pid/email/requestToken", body,)
self.assertEquals(expect_code, channel.code, channel.result)

if channel.code != expect_code:
raise HttpResponseException(
channel.code, channel.result["reason"], channel.result["body"],
)

return channel.json_body.get("sid")

Expand Down Expand Up @@ -823,10 +897,12 @@ def _get_link_from_email(self):
def _add_email(self, request_email, expected_email):
"""Test adding an email to profile
"""
previous_email_attempts = len(self.email_attempts)

client_secret = "foobar"
session_id = self._request_token(request_email, client_secret)

self.assertEquals(len(self.email_attempts), 1)
self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1)
link = self._get_link_from_email()

self._validate_token(link)
Expand Down Expand Up @@ -855,4 +931,6 @@ def _add_email(self, request_email, expected_email):

self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"])

threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
self.assertIn(expected_email, threepids)
9 changes: 7 additions & 2 deletions tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class FakeChannel:
site = attr.ib(type=Site)
_reactor = attr.ib()
result = attr.ib(type=dict, default=attr.Factory(dict))
_ip = attr.ib(type=str, default="127.0.0.1")
_producer = None

@property
Expand Down Expand Up @@ -120,7 +121,7 @@ def requestDone(self, _self):
def getPeer(self):
# We give an address so that getClientIP returns a non null entry,
# causing us to record the MAU
return address.IPv4Address("TCP", "127.0.0.1", 3423)
return address.IPv4Address("TCP", self._ip, 3423)

def getHost(self):
return None
Expand Down Expand Up @@ -196,6 +197,7 @@ def make_request(
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
Make a web request using the given method, path and content, and render it
Expand Down Expand Up @@ -223,6 +225,9 @@ def make_request(
will pump the reactor until the the renderer tells the channel the request
is finished.
client_ip: The IP to use as the requesting IP. Useful for testing
ratelimiting.
Returns:
channel
"""
Expand Down Expand Up @@ -250,7 +255,7 @@ def make_request(
if isinstance(content, str):
content = content.encode("utf8")

channel = FakeChannel(site, reactor)
channel = FakeChannel(site, reactor, ip=client_ip)

req = request(channel)
req.content = BytesIO(content)
Expand Down
5 changes: 5 additions & 0 deletions tests/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def make_request(
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
Create a SynapseRequest at the path using the method and containing the
Expand All @@ -410,6 +411,9 @@ def make_request(
custom_headers: (name, value) pairs to add as request headers
client_ip: The IP to use as the requesting IP. Useful for testing
ratelimiting.
Returns:
The FakeChannel object which stores the result of the request.
"""
Expand All @@ -426,6 +430,7 @@ def make_request(
content_is_form,
await_result,
custom_headers,
client_ip,
)

def setup_test_homeserver(self, *args, **kwargs):
Expand Down
Loading

0 comments on commit 4b73488

Please sign in to comment.