Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

A third batch of Pydantic validation for rest/client/account.py #13736

Merged
merged 11 commits into from
Sep 15, 2022
1 change: 1 addition & 0 deletions changelog.d/13736.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve validation of request bodies for the following client-server API endpoints: [`/account/3pid/add`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidadd), [`/account/3pid/bind`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidbind), [`/account/3pid/delete`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3piddelete) and [`/account/3pid/unbind`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidunbind).
65 changes: 36 additions & 29 deletions synapse/rest/client/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from urllib.parse import urlparse

from pydantic import StrictBool, StrictStr, constr
from typing_extensions import Literal

from twisted.web.server import Request

Expand All @@ -43,6 +44,7 @@
from synapse.push.mailer import Mailer
from synapse.rest.client.models import (
AuthenticationData,
ClientSecretType,
EmailRequestTokenBody,
MsisdnRequestTokenBody,
)
Expand Down Expand Up @@ -627,6 +629,11 @@ def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()

class PostBody(RequestBodyModel):
auth: Optional[AuthenticationData] = None
client_secret: ClientSecretType
sid: StrictStr

@interactive_auth_handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.registration.enable_3pid_changes:
Expand All @@ -636,22 +643,17 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:

requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)

assert_params_in_dict(body, ["client_secret", "sid"])
sid = body["sid"]
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
body = parse_and_validate_json_object_from_request(request, self.PostBody)

await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
body.dict(exclude_unset=True),
"add a third-party identifier to your account",
)

validation_session = await self.identity_handler.validate_threepid_session(
client_secret, sid
body.client_secret, body.sid
)
if validation_session:
await self.auth_handler.add_threepid(
Expand All @@ -676,23 +678,20 @@ def __init__(self, hs: "HomeServer"):
self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
class PostBody(RequestBodyModel):
client_secret: ClientSecretType
id_access_token: StrictStr
id_server: StrictStr
sid: StrictStr

assert_params_in_dict(
body, ["id_server", "sid", "id_access_token", "client_secret"]
)
id_server = body["id_server"]
sid = body["sid"]
id_access_token = body["id_access_token"]
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_and_validate_json_object_from_request(request, self.PostBody)

requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()

await self.identity_handler.bind_threepid(
client_secret, sid, user_id, id_server, id_access_token
body.client_secret, body.sid, user_id, body.id_server, body.id_access_token
)

return 200, {}
Expand All @@ -708,23 +707,27 @@ def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
self.datastore = self.hs.get_datastores().main

class PostBody(RequestBodyModel):
address: StrictStr
id_server: Optional[StrictStr] = None
medium: Literal["email", "msisdn"]

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""Unbind the given 3pid from a specific identity server, or identity servers that are
known to have this 3pid bound
"""
requester = await self.auth.get_user_by_req(request)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"])

medium = body.get("medium")
address = body.get("address")
id_server = body.get("id_server")
body = parse_and_validate_json_object_from_request(request, self.PostBody)

# Attempt to unbind the threepid from an identity server. If id_server is None, try to
# unbind from all identity servers this threepid has been added to in the past
result = await self.identity_handler.try_unbind_threepid(
requester.user.to_string(),
{"address": address, "medium": medium, "id_server": id_server},
{
"address": body.address,
"medium": body.medium,
"id_server": body.id_server,
},
)
return 200, {"id_server_unbind_result": "success" if result else "no-support"}

Expand All @@ -738,21 +741,25 @@ def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()

class PostBody(RequestBodyModel):
address: StrictStr
id_server: Optional[StrictStr] = None
medium: Literal["email", "msisdn"]

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.registration.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
)

body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"])
body = parse_and_validate_json_object_from_request(request, self.PostBody)

requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()

try:
ret = await self.auth_handler.delete_threepid(
user_id, body["medium"], body["address"], body.get("id_server")
user_id, body.medium, body.address, body.id_server
)
except Exception:
# NB. This endpoint should succeed if there is nothing to
Expand Down
28 changes: 15 additions & 13 deletions synapse/rest/client/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,20 @@ class Config:
type: Optional[StrictStr] = None


class ThreePidRequestTokenBody(RequestBodyModel):
if TYPE_CHECKING:
client_secret: StrictStr
else:
# See also assert_valid_client_secret()
client_secret: constr(
regex="[0-9a-zA-Z.=_-]", # noqa: F722
min_length=0,
max_length=255,
strict=True,
)
if TYPE_CHECKING:
ClientSecretType = StrictStr
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tiny tiny question: would calling this ClientSecretStr make more sense so that the class is more readable without looking up the definition of this? I don't think it matters either way but it may be a good convention to have, in which case we would better consider that now

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, I like it. So what's the rule of thumb: if you need an alias for a constr, conint or confloat, call it SomethingSomethingStr, FooBarInt or PotatoFloat rather than BlahBlahType?

Perhaps you'd also like to see this for conlist (JSON arrays) and certain mapping types too?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, sounds good to me; I wouldn't do it for pydantic models though, so you wouldn't call e.g. PostBody PostBodyMap.

else:
# See also assert_valid_client_secret()
ClientSecretType = constr(
regex="[0-9a-zA-Z.=_-]", # noqa: F722
min_length=1,
max_length=255,
strict=True,
)


class ThreepidRequestTokenBody(RequestBodyModel):
client_secret: ClientSecretType
id_server: Optional[StrictStr]
id_access_token: Optional[StrictStr]
next_link: Optional[StrictStr]
Expand All @@ -62,7 +64,7 @@ def token_required_for_identity_server(
return token


class EmailRequestTokenBody(ThreePidRequestTokenBody):
class EmailRequestTokenBody(ThreepidRequestTokenBody):
email: StrictStr

# Canonicalise the email address. The addresses are all stored canonicalised
Expand All @@ -80,6 +82,6 @@ class EmailRequestTokenBody(ThreePidRequestTokenBody):
ISO3116_1_Alpha_2 = constr(regex="[A-Z]{2}", strict=True)


class MsisdnRequestTokenBody(ThreePidRequestTokenBody):
class MsisdnRequestTokenBody(ThreepidRequestTokenBody):
country: ISO3116_1_Alpha_2
phone_number: StrictStr
29 changes: 26 additions & 3 deletions tests/rest/client/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,37 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import unittest as stdlib_unittest

from pydantic import ValidationError
from pydantic import BaseModel, ValidationError
from typing_extensions import Literal

from synapse.rest.client.models import EmailRequestTokenBody


class EmailRequestTokenBodyTestCase(unittest.TestCase):
class ThreepidMediumEnumTestCase(stdlib_unittest.TestCase):
class Model(BaseModel):
medium: Literal["email", "msisdn"]

def test_accepts_valid_medium_string(self) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wonder if we should roll our own StrEnum that also does __repr__ and __str__ in a less surprising way (i.e. by calling it on the string value)?
I suppose it ultimately doesn't matter for the sake of Pydantic

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that 3.11 will have a StrEnum in the stdlib: https://docs.python.org/3.11/library/enum.html?highlight=strenum#enum.StrEnum

There's a backport here. If we do this I'd prefer to import from stdlib and fallback to the backport.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

absolutely fine with that idea, as long as we have confidence that the usual pitfall around __str__ is addressed

"""Sanity check that Pydantic behaves sensibly with an enum-of-str

This is arguably more of a test of a class that inherits from str and Enum
simultaneously.
"""
model = self.Model.parse_obj({"medium": "email"})
self.assertEqual(model.medium, "email")

def test_rejects_invalid_medium_value(self) -> None:
with self.assertRaises(ValidationError):
self.Model.parse_obj({"medium": "interpretive_dance"})

def test_rejects_invalid_medium_type(self) -> None:
with self.assertRaises(ValidationError):
self.Model.parse_obj({"medium": 123})


class EmailRequestTokenBodyTestCase(stdlib_unittest.TestCase):
base_request = {
"client_secret": "hunter2",
"email": "alice@wonderland.com",
Expand Down