Skip to content

Commit

Permalink
Mise à jour majeure de django-oauth-toolkit (#6537)
Browse files Browse the repository at this point in the history
  • Loading branch information
Situphen authored Feb 10, 2024
1 parent 91fe319 commit cc1e21e
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 107 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ requests==2.31.0
# Api dependencies
django-cors-headers==4.3.1
django-filter==23.5
django-oauth-toolkit==1.7.0
django-oauth-toolkit==2.3.0
djangorestframework==3.14.0
drf-extensions==0.7.1
dry-rest-permissions==0.1.10
Expand Down
32 changes: 32 additions & 0 deletions zds/api/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from django.urls import reverse
from oauth2_provider.models import Application, AccessToken


# As of django-oauth-toolkit (oauth2_provider) 2.0.0, `Application.client_secret` values are hashed
# before being saved in the database. For the tests, we use the same method as django-oauth-toolkit's tests
# which is to store the client_secret cleartext value in CLEARTEXT_SECRET.
# (See https://github.com/jazzband/django-oauth-toolkit/blob/fda64f97974aac78d4ac9c9f0f36e137dbe4fb8c/tests/test_client_credential.py#L26C58-L26C58)
CLEARTEXT_SECRET = "abcdefghijklmnopqrstuvwxyz1234567890"


def authenticate_oauth2_client(client, user, password):
oauth2_client = Application.objects.create(
user=user,
client_type=Application.CLIENT_CONFIDENTIAL,
authorization_grant_type=Application.GRANT_PASSWORD,
client_secret=CLEARTEXT_SECRET,
)
oauth2_client.save()

client.post(
reverse("oauth2_provider:token"),
{
"client_id": oauth2_client.client_id,
"client_secret": CLEARTEXT_SECRET,
"username": user.username,
"password": password,
"grant_type": "password",
},
)
access_token = AccessToken.objects.get(user=user)
client.credentials(HTTP_AUTHORIZATION=f"Bearer {access_token}")
20 changes: 7 additions & 13 deletions zds/gallery/api/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from rest_framework.test import APITestCase, APIClient
from rest_framework_extensions.settings import extensions_api_settings

from zds.api.utils import authenticate_oauth2_client
from zds.gallery.tests.factories import UserGalleryFactory, GalleryFactory, ImageFactory
from zds.gallery.models import Gallery, UserGallery, GALLERY_WRITE, Image, GALLERY_READ
from zds.member.tests.factories import ProfileFactory
from zds.member.api.tests import create_oauth2_client, authenticate_client
from zds.tutorialv2.tests.factories import PublishableContentFactory
from zds.tutorialv2.tests import TutorialTestMixin, override_for_contents

Expand All @@ -21,8 +21,7 @@ class GalleryListAPITest(APITestCase):
def setUp(self):
self.profile = ProfileFactory()
self.client = APIClient()
client_oauth2 = create_oauth2_client(self.profile.user)
authenticate_client(self.client, client_oauth2, self.profile.user.username, "hostel77")
authenticate_oauth2_client(self.client, self.profile.user, "hostel77")

caches[extensions_api_settings.DEFAULT_USE_CACHE].clear()

Expand Down Expand Up @@ -92,8 +91,7 @@ def setUp(self):
self.profile = ProfileFactory()
self.other = ProfileFactory()
self.client = APIClient()
client_oauth2 = create_oauth2_client(self.profile.user)
authenticate_client(self.client, client_oauth2, self.profile.user.username, "hostel77")
authenticate_oauth2_client(self.client, self.profile.user, "hostel77")

self.gallery = GalleryFactory()

Expand Down Expand Up @@ -222,8 +220,7 @@ def setUp(self):
self.profile = ProfileFactory()
self.other = ProfileFactory()
self.client = APIClient()
client_oauth2 = create_oauth2_client(self.profile.user)
authenticate_client(self.client, client_oauth2, self.profile.user.username, "hostel77")
authenticate_oauth2_client(self.client, self.profile.user, "hostel77")

self.gallery = GalleryFactory()
UserGalleryFactory(user=self.profile.user, gallery=self.gallery)
Expand Down Expand Up @@ -358,8 +355,7 @@ def setUp(self):
self.profile = ProfileFactory()
self.other = ProfileFactory()
self.client = APIClient()
client_oauth2 = create_oauth2_client(self.profile.user)
authenticate_client(self.client, client_oauth2, self.profile.user.username, "hostel77")
authenticate_oauth2_client(self.client, self.profile.user, "hostel77")

self.gallery = GalleryFactory()
UserGalleryFactory(user=self.profile.user, gallery=self.gallery)
Expand Down Expand Up @@ -506,8 +502,7 @@ def setUp(self):
self.other = ProfileFactory()
self.client = APIClient()
self.new_participant = ProfileFactory()
client_oauth2 = create_oauth2_client(self.profile.user)
authenticate_client(self.client, client_oauth2, self.profile.user.username, "hostel77")
authenticate_oauth2_client(self.client, self.profile.user, "hostel77")

self.gallery = GalleryFactory()
UserGalleryFactory(user=self.profile.user, gallery=self.gallery)
Expand Down Expand Up @@ -620,8 +615,7 @@ def setUp(self):
self.other = ProfileFactory()
self.new_participant = ProfileFactory()
self.client = APIClient()
client_oauth2 = create_oauth2_client(self.profile.user)
authenticate_client(self.client, client_oauth2, self.profile.user.username, "hostel77")
authenticate_oauth2_client(self.client, self.profile.user, "hostel77")

self.gallery = GalleryFactory()
UserGalleryFactory(user=self.profile.user, gallery=self.gallery)
Expand Down
82 changes: 20 additions & 62 deletions zds/member/api/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from django.contrib.auth.models import User, Group
from django.core import mail
from django.urls import reverse
from oauth2_provider.models import Application, AccessToken
from rest_framework import status
from rest_framework.test import APITestCase
from rest_framework.test import APIClient

from zds.api.pagination import REST_PAGE_SIZE, REST_MAX_PAGE_SIZE, REST_PAGE_SIZE_QUERY_PARAM
from zds.api.utils import authenticate_oauth2_client
from zds.member.tests.factories import ProfileFactory, StaffProfileFactory, ProfileNotSyncFactory
from zds.member.models import TokenRegister, BannedEmailProvider
from rest_framework_extensions.settings import extensions_api_settings
Expand Down Expand Up @@ -366,9 +366,8 @@ def test_detail_of_the_member(self):
Gets all information about the user.
"""
profile = ProfileFactory()
client_oauth2 = create_oauth2_client(profile.user)
client_authenticated = APIClient()
authenticate_client(client_authenticated, client_oauth2, profile.user.username, "hostel77")
authenticate_oauth2_client(client_authenticated, profile.user, "hostel77")

response = client_authenticated.get(reverse("api:member:profile"))
self.assertEqual(response.status_code, status.HTTP_200_OK)
Expand Down Expand Up @@ -403,9 +402,8 @@ def setUp(self):
self.client = APIClient()
self.profile = ProfileFactory()

client_oauth2 = create_oauth2_client(self.profile.user)
self.client_authenticated = APIClient()
authenticate_client(self.client_authenticated, client_oauth2, self.profile.user.username, "hostel77")
authenticate_oauth2_client(self.client_authenticated, self.profile.user, "hostel77")

caches[extensions_api_settings.DEFAULT_USE_CACHE].clear()

Expand Down Expand Up @@ -497,9 +495,8 @@ def test_update_member_details_with_user_not_synchronized(self):
"""
decal = ProfileNotSyncFactory()

client_oauth2 = create_oauth2_client(decal.user)
client_authenticated = APIClient()
authenticate_client(client_authenticated, client_oauth2, decal.user.username, "hostel77")
authenticate_oauth2_client(client_authenticated, decal.user, "hostel77")

response = client_authenticated.put(reverse("api:member:detail", args=[decal.user.id]))
self.assertEqual(response.status_code, status.HTTP_200_OK)
Expand Down Expand Up @@ -681,9 +678,8 @@ def setUp(self):

self.profile = ProfileFactory()
self.staff = StaffProfileFactory()
client_oauth2 = create_oauth2_client(self.staff.user)
self.client_authenticated = APIClient()
authenticate_client(self.client_authenticated, client_oauth2, self.staff.user.username, "hostel77")
authenticate_oauth2_client(self.client_authenticated, self.staff.user, "hostel77")

caches[extensions_api_settings.DEFAULT_USE_CACHE].clear()

Expand Down Expand Up @@ -749,9 +745,8 @@ def test_apply_read_only_at_a_member_without_permissions(self):
"""
Tries to apply a read only sanction at a member with a user isn't authenticated.
"""
client_oauth2 = create_oauth2_client(self.profile.user)
client_authenticated = APIClient()
authenticate_client(client_authenticated, client_oauth2, self.profile.user.username, "hostel77")
authenticate_oauth2_client(client_authenticated, self.profile.user, "hostel77")

response = client_authenticated.post(reverse("api:member:read-only", args=[self.profile.user.id]))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
Expand Down Expand Up @@ -817,9 +812,8 @@ def test_remove_read_only_at_a_member_without_permissions(self):
"""
Tries to remove a read only sanction at a member with a user isn't authenticated.
"""
client_oauth2 = create_oauth2_client(self.profile.user)
client_authenticated = APIClient()
authenticate_client(client_authenticated, client_oauth2, self.profile.user.username, "hostel77")
authenticate_oauth2_client(client_authenticated, self.profile.user, "hostel77")

response = client_authenticated.delete(reverse("api:member:read-only", args=[self.profile.user.id]))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
Expand All @@ -840,9 +834,8 @@ def setUp(self):

self.profile = ProfileFactory()
self.staff = StaffProfileFactory()
client_oauth2 = create_oauth2_client(self.staff.user)
self.client_authenticated = APIClient()
authenticate_client(self.client_authenticated, client_oauth2, self.staff.user.username, "hostel77")
authenticate_oauth2_client(self.client_authenticated, self.staff.user, "hostel77")

caches[extensions_api_settings.DEFAULT_USE_CACHE].clear()

Expand Down Expand Up @@ -908,9 +901,8 @@ def test_apply_ban_at_a_member_without_permissions(self):
"""
Tries to apply a ban sanction at a member with a user isn't authenticated.
"""
client_oauth2 = create_oauth2_client(self.profile.user)
client_authenticated = APIClient()
authenticate_client(client_authenticated, client_oauth2, self.profile.user.username, "hostel77")
authenticate_oauth2_client(client_authenticated, self.profile.user, "hostel77")

response = client_authenticated.post(reverse("api:member:ban", args=[self.profile.user.id]))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
Expand Down Expand Up @@ -976,9 +968,8 @@ def test_remove_ban_at_a_member_without_permissions(self):
"""
Tries to remove a ban sanction at a member with a user isn't authenticated.
"""
client_oauth2 = create_oauth2_client(self.profile.user)
client_authenticated = APIClient()
authenticate_client(client_authenticated, client_oauth2, self.profile.user.username, "hostel77")
authenticate_oauth2_client(client_authenticated, self.profile.user, "hostel77")

response = client_authenticated.delete(reverse("api:member:ban", args=[self.profile.user.id]))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
Expand Down Expand Up @@ -1011,9 +1002,7 @@ def test_has_read_permission_for_authenticated_users(self):
"""
Authenticated users have the permission to read any member.
"""
authenticate_client(
self.client, create_oauth2_client(self.profile.user), self.profile.user.username, "hostel77"
)
authenticate_oauth2_client(self.client, self.profile.user, "hostel77")

response = self.client.get(reverse("api:member:detail", args=[self.profile.user.id]))
self.assertEqual(response.status_code, status.HTTP_200_OK)
Expand All @@ -1023,7 +1012,7 @@ def test_has_read_permission_for_staff_users(self):
"""
Staff users have the permission to read any member.
"""
authenticate_client(self.client, create_oauth2_client(self.staff.user), self.staff.user.username, "hostel77")
authenticate_oauth2_client(self.client, self.staff.user, "hostel77")

response = self.client.get(reverse("api:member:detail", args=[self.profile.user.id]))
self.assertEqual(response.status_code, status.HTTP_200_OK)
Expand All @@ -1041,9 +1030,7 @@ def test_has_write_permission_for_authenticated_user(self):
"""
A user authenticated have write permissions.
"""
authenticate_client(
self.client, create_oauth2_client(self.profile.user), self.profile.user.username, "hostel77"
)
authenticate_oauth2_client(self.client, self.profile.user, "hostel77")

response = self.client.get(reverse("api:member:detail", args=[self.profile.user.id]))
self.assertEqual(response.status_code, status.HTTP_200_OK)
Expand All @@ -1053,7 +1040,7 @@ def test_has_write_permission_for_staff(self):
"""
A staff user have write permissions.
"""
authenticate_client(self.client, create_oauth2_client(self.staff.user), self.staff.user.username, "hostel77")
authenticate_oauth2_client(self.client, self.staff.user, "hostel77")

response = self.client.get(reverse("api:member:detail", args=[self.profile.user.id]))
self.assertEqual(response.status_code, status.HTTP_200_OK)
Expand All @@ -1071,9 +1058,7 @@ def test_has_update_permission_for_authenticated_user(self):
"""
Only the user authenticated have update permissions.
"""
authenticate_client(
self.client, create_oauth2_client(self.profile.user), self.profile.user.username, "hostel77"
)
authenticate_oauth2_client(self.client, self.profile.user, "hostel77")

response = self.client.get(reverse("api:member:detail", args=[self.profile.user.id]))
self.assertEqual(response.status_code, status.HTTP_200_OK)
Expand All @@ -1083,7 +1068,7 @@ def test_has_not_update_permission_for_staff(self):
"""
Only the user authenticated have update permissions.
"""
authenticate_client(self.client, create_oauth2_client(self.staff.user), self.staff.user.username, "hostel77")
authenticate_oauth2_client(self.client, self.staff.user, "hostel77")

response = self.client.get(reverse("api:member:detail", args=[self.profile.user.id]))
self.assertEqual(response.status_code, status.HTTP_200_OK)
Expand All @@ -1101,9 +1086,7 @@ def test_has_not_ban_permission_for_authenticated_user(self):
"""
Only staff have ban permission.
"""
authenticate_client(
self.client, create_oauth2_client(self.profile.user), self.profile.user.username, "hostel77"
)
authenticate_oauth2_client(self.client, self.profile.user, "hostel77")

response = self.client.get(reverse("api:member:detail", args=[self.profile.user.id]))
self.assertEqual(response.status_code, status.HTTP_200_OK)
Expand All @@ -1113,7 +1096,7 @@ def test_has_ban_permission_for_staff(self):
"""
Only staff have ban permission.
"""
authenticate_client(self.client, create_oauth2_client(self.staff.user), self.staff.user.username, "hostel77")
authenticate_oauth2_client(self.client, self.staff.user, "hostel77")

response = self.client.get(reverse("api:member:detail", args=[self.profile.user.id]))
self.assertEqual(response.status_code, status.HTTP_200_OK)
Expand All @@ -1132,14 +1115,12 @@ def test_cache_of_user_authenticated_for_member_profile(self):
profile = ProfileFactory()
another_profile = ProfileFactory()

authenticate_client(self.client, create_oauth2_client(profile.user), profile.user.username, "hostel77")
authenticate_oauth2_client(self.client, profile.user, "hostel77")
response = self.client.get(reverse("api:member:profile"))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(profile.user.username, response.data.get("username"))

authenticate_client(
self.client, create_oauth2_client(another_profile.user), another_profile.user.username, "hostel77"
)
authenticate_oauth2_client(self.client, another_profile.user, "hostel77")
response = self.client.get(reverse("api:member:profile"))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(another_profile.user.username, response.data.get("username"))
Expand Down Expand Up @@ -1167,26 +1148,3 @@ def test_cache_invalidated_when_new_member(self):
response = self.client.get(reverse("api:member:list"))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data.get("count"), count + 1)


def create_oauth2_client(user):
client = Application.objects.create(
user=user, client_type=Application.CLIENT_CONFIDENTIAL, authorization_grant_type=Application.GRANT_PASSWORD
)
client.save()
return client


def authenticate_client(client, client_auth, username, password):
client.post(
"/oauth2/token/",
{
"client_id": client_auth.client_id,
"client_secret": client_auth.client_secret,
"username": username,
"password": password,
"grant_type": "password",
},
)
access_token = AccessToken.objects.get(user__username=username)
client.credentials(HTTP_AUTHORIZATION=f"Bearer {access_token}")
Loading

0 comments on commit cc1e21e

Please sign in to comment.