Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

A third batch of Pydantic validation for rest/client/account.py #13736

Merged
merged 11 commits into from
Sep 15, 2022
1 change: 1 addition & 0 deletions changelog.d/13736.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve validation of request bodies for the following client-server API endpoints: [`/account/3pid/add`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidadd), [`/account/3pid/bind`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidbind), [`/account/3pid/delete`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3piddelete) and [`/account/3pid/unbind`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidunbind).
65 changes: 36 additions & 29 deletions synapse/rest/client/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@
from synapse.push.mailer import Mailer
from synapse.rest.client.models import (
AuthenticationData,
ClientSecretType,
EmailRequestTokenBody,
MsisdnRequestTokenBody,
ThreepidMedium,
)
from synapse.rest.models import RequestBodyModel
from synapse.types import JsonDict
Expand Down Expand Up @@ -627,6 +629,11 @@ def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()

class PostBody(RequestBodyModel):
auth: Optional[AuthenticationData] = None
client_secret: ClientSecretType
sid: StrictStr

@interactive_auth_handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.registration.enable_3pid_changes:
Expand All @@ -636,22 +643,17 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:

requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)

assert_params_in_dict(body, ["client_secret", "sid"])
sid = body["sid"]
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
body = parse_and_validate_json_object_from_request(request, self.PostBody)

await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
body.dict(exclude_unset=True),
"add a third-party identifier to your account",
)

validation_session = await self.identity_handler.validate_threepid_session(
client_secret, sid
body.client_secret, body.sid
)
if validation_session:
await self.auth_handler.add_threepid(
Expand All @@ -676,23 +678,20 @@ def __init__(self, hs: "HomeServer"):
self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
class PostBody(RequestBodyModel):
client_secret: ClientSecretType
id_access_token: StrictStr
id_server: StrictStr
sid: StrictStr

assert_params_in_dict(
body, ["id_server", "sid", "id_access_token", "client_secret"]
)
id_server = body["id_server"]
sid = body["sid"]
id_access_token = body["id_access_token"]
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_and_validate_json_object_from_request(request, self.PostBody)

requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()

await self.identity_handler.bind_threepid(
client_secret, sid, user_id, id_server, id_access_token
body.client_secret, body.sid, user_id, body.id_server, body.id_access_token
)

return 200, {}
Expand All @@ -708,23 +707,27 @@ def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
self.datastore = self.hs.get_datastores().main

class PostBody(RequestBodyModel):
address: StrictStr
id_server: Optional[StrictStr] = None
medium: ThreepidMedium

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""Unbind the given 3pid from a specific identity server, or identity servers that are
known to have this 3pid bound
"""
requester = await self.auth.get_user_by_req(request)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"])

medium = body.get("medium")
address = body.get("address")
id_server = body.get("id_server")
body = parse_and_validate_json_object_from_request(request, self.PostBody)

# Attempt to unbind the threepid from an identity server. If id_server is None, try to
# unbind from all identity servers this threepid has been added to in the past
result = await self.identity_handler.try_unbind_threepid(
requester.user.to_string(),
{"address": address, "medium": medium, "id_server": id_server},
{
"address": body.address,
"medium": body.medium,
"id_server": body.id_server,
},
)
return 200, {"id_server_unbind_result": "success" if result else "no-support"}

Expand All @@ -738,21 +741,25 @@ def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()

class PostBody(RequestBodyModel):
address: StrictStr
id_server: Optional[StrictStr] = None
medium: ThreepidMedium

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.registration.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
)

body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"])
body = parse_and_validate_json_object_from_request(request, self.PostBody)

requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()

try:
ret = await self.auth_handler.delete_threepid(
user_id, body["medium"], body["address"], body.get("id_server")
user_id, body.medium, body.address, body.id_server
)
except Exception:
# NB. This endpoint should succeed if there is nothing to
Expand Down
35 changes: 22 additions & 13 deletions synapse/rest/client/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from typing import TYPE_CHECKING, Dict, Optional

from pydantic import Extra, StrictInt, StrictStr, constr, validator
Expand All @@ -19,6 +20,12 @@
from synapse.util.threepids import validate_email


class ThreepidMedium(str, Enum):
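# Subclass both `str` and `Enum` so that a field of this type only accepts one of
# these specific strings, per the advice at
# https://pydantic-docs.helpmanual.io/usage/types/#enums-and-choices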
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you be specific about what advice you're applying / what this is commenting on?
The constants being lowercased perhaps (main thing that strikes me as weird)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. My first thought was to say medium: StrictStr but I wanted to see if there was a way to say that "medium should be one of these specific strings". The link seems to be the blessed way to do this.

What I was suspicious of is the multiple inheritance from str and Enum. These ThreepidMedium types will get passed through to the rest of the application, where I want them to be treated like any other str. I'm not fully convinced that the rest of the application will be happy with that tbh... but I don't have anything more concrete than a bad feeling about this.

Having said that, looking here it might be possible to write medium: Literal["email", "msisdn"]. If it works, that seems a bit less magic.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having said that, looking here it might be possible to write medium: Literal["email", "msisdn"]. If it works, that seems a bit less magic.

This seems to work pretty well:

class ThreepidMediumEnumTestCase(stdlib_unittest.TestCase):
    """Check how Pydantic validates a `medium` field declared as a `Literal`.

    The model under test permits exactly the strings "email" and "msisdn".
    """

    class Model(BaseModel):
        # Only these two string values are permitted for `medium`.
        medium: Literal["email", "msisdn"]

    def test_accepts_valid_medium_string(self) -> None:
        """Sanity check that Pydantic behaves sensibly with a Literal of strings

        A permitted value should be parsed to a plain `str` that compares equal to
        the string that was supplied.
        """
        model = self.Model.parse_obj({"medium": "email"})
        self.assertIsInstance(model.medium, str)
        self.assertEqual(model.medium, "email")

    def test_rejects_invalid_medium_value(self) -> None:
        """A string outside the permitted set should raise a ValidationError."""
        with self.assertRaises(ValidationError) as ctx:
            self.Model.parse_obj({"medium": "interpretive_dance"})
        # Print the error so that the generated message can be inspected by hand.
        print(ctx.exception)

    def test_rejects_invalid_medium_type(self) -> None:
        """A non-string value should raise a ValidationError."""
        with self.assertRaises(ValidationError) as ctx:
            self.Model.parse_obj({"medium": 123})
        # Print the error so that the generated message can be inspected by hand.
        print(ctx.exception)

The error messages generated are

Ran 3 tests in 0.002s

PASSED (successes=3)

Process finished with exit code 0
1 validation error for Model
medium
  unexpected value; permitted: 'email', 'msisdn' (type=value_error.const; given=123; permitted=('email', 'msisdn'))
1 validation error for Model
medium
  unexpected value; permitted: 'email', 'msisdn' (type=value_error.const; given=interpretive_dance; permitted=('email', 'msisdn'))

(the errors are a little technical, but see #13337)

Copy link
Contributor

@reivilibre reivilibre Sep 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was actually totally fine with the enum, but enum constants usually have uppercase names — that's what I was mostly wondering about (alongside a slightly underspecified comment).

Inheriting from both str and Enum and supplying a proper __str__ is, as far as I remember, the sensible way to do it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(In passing, I noticed pydantic/pydantic#4505)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see. That's fair.

I really want to like stdlib enums, but I've been bitten by their sharp corners in the past. We've even seen it ourselves with IntEnum a few times I think? In that light, I think I prefer the Literal approach. But I can also see that a proper Enum can be passed throughout the rest of the application.

Maybe we should dwell on this for a bit. (I'd be interested in other opinions too!)

email = "email"
msisdn = "msisdn"


class AuthenticationData(RequestBodyModel):
"""
Data used during user-interactive authentication.
Expand All @@ -36,18 +43,20 @@ class Config:
type: Optional[StrictStr] = None


class ThreePidRequestTokenBody(RequestBodyModel):
if TYPE_CHECKING:
client_secret: StrictStr
else:
# See also assert_valid_client_secret()
client_secret: constr(
regex="[0-9a-zA-Z.=_-]", # noqa: F722
min_length=0,
max_length=255,
strict=True,
)
if TYPE_CHECKING:
ClientSecretType = StrictStr
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tiny, tiny question: would calling this ClientSecretStr make more sense, so that the class is readable without having to look up this definition? I don't think it matters either way, but it may be a good convention to have, in which case we had better consider it now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, I like it. So what's the rule of thumb: if you need an alias for a constr, conint or confloat, call it SomethingSomethingStr, FooBarInt or PotatoFloat rather than BlahBlahType?

Perhaps you'd also like to see this for conlist (JSON arrays) and certain mapping types too?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, sounds good to me; I wouldn't do it for pydantic models though, so you wouldn't call e.g. PostBody PostBodyMap.

else:
# See also assert_valid_client_secret()
# A client secret is 1-255 characters drawn from [0-9a-zA-Z.=_-].
#
# The pattern must be anchored at both ends and repeated with `+`: pydantic applies
# `regex` using `re.match`, which only anchors at the start of the string. An
# unanchored single-character class would therefore accept any value whose *first*
# character is valid (e.g. "a b!"). `\Z` is used rather than `$` so that a trailing
# newline is rejected too.
ClientSecretType = constr(
    regex=r"^[0-9a-zA-Z.=_-]+\Z",  # noqa: F722
    # An empty secret can never match the pattern; say so explicitly so that the
    # caller gets a length error rather than a regex error.
    min_length=1,
    max_length=255,
    strict=True,
)


class ThreepidRequestTokenBody(RequestBodyModel):
client_secret: ClientSecretType
id_server: Optional[StrictStr]
id_access_token: Optional[StrictStr]
next_link: Optional[StrictStr]
Expand All @@ -62,7 +71,7 @@ def token_required_for_identity_server(
return token


class EmailRequestTokenBody(ThreePidRequestTokenBody):
class EmailRequestTokenBody(ThreepidRequestTokenBody):
email: StrictStr

# Canonicalise the email address. The addresses are all stored canonicalised
Expand All @@ -80,6 +89,6 @@ class EmailRequestTokenBody(ThreePidRequestTokenBody):
ISO3116_1_Alpha_2 = constr(regex="[A-Z]{2}", strict=True)


class MsisdnRequestTokenBody(ThreePidRequestTokenBody):
class MsisdnRequestTokenBody(ThreepidRequestTokenBody):
country: ISO3116_1_Alpha_2
phone_number: StrictStr
41 changes: 37 additions & 4 deletions tests/rest/client/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,47 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import unittest as stdlib_unittest

from pydantic import ValidationError
from pydantic import BaseModel, ValidationError

from synapse.rest.client.models import EmailRequestTokenBody
from synapse.rest.client.models import EmailRequestTokenBody, ThreepidMedium


class EmailRequestTokenBodyTestCase(unittest.TestCase):
class ThreepidMediumEnumTestCase(stdlib_unittest.TestCase):
class Model(BaseModel):
medium: ThreepidMedium

def test_accepts_valid_medium_string(self) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wonder if we should roll our own StrEnum that also does __repr__ and __str__ in a less surprising way (i.e. by calling it on the string value)?
I suppose it ultimately doesn't matter for the sake of Pydantic

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that 3.11 will have a StrEnum in the stdlib: https://docs.python.org/3.11/library/enum.html?highlight=strenum#enum.StrEnum

There's a backport here. If we do this I'd prefer to import from stdlib and fallback to the backport.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

absolutely fine with that idea, as long as we have confidence that the usual pitfall around __str__ is addressed

"""Sanity check that Pydantic behaves sensibly with an enum-of-str

This is arguably more of a test of a class that inherits from str and Enum
simultaneously.
"""
model = self.Model.parse_obj({"medium": "email"})
self.assertIsInstance(model.medium, str)
self.assertEqual(model.medium, "email")
self.assertEqual(model.medium, ThreepidMedium.email)
self.assertIs(model.medium, ThreepidMedium.email)

self.assertNotEqual(model.medium, "msisdn")
self.assertNotEqual(model.medium, ThreepidMedium.msisdn)
self.assertIsNot(model.medium, ThreepidMedium.msisdn)

def test_accepts_valid_medium_enum(self) -> None:
model = self.Model.parse_obj({"medium": ThreepidMedium.email})
self.assertIs(model.medium, ThreepidMedium.email)

def test_rejects_invalid_medium_value(self) -> None:
with self.assertRaises(ValidationError):
self.Model.parse_obj({"medium": "interpretive_dance"})

def test_rejects_invalid_medium_type(self) -> None:
with self.assertRaises(ValidationError):
self.Model.parse_obj({"medium": 123})


class EmailRequestTokenBodyTestCase(stdlib_unittest.TestCase):
base_request = {
"client_secret": "hunter2",
"email": "alice@wonderland.com",
Expand Down