Skip to content

Commit

Permalink
fixes review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
rishabhpoddar committed Aug 11, 2023
1 parent 84acec2 commit 5e08c6d
Show file tree
Hide file tree
Showing 22 changed files with 62 additions and 40 deletions.
2 changes: 1 addition & 1 deletion supertokens_python/recipe/session/asyncio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@


async def create_new_session(
tenant_id: str,
request: Any,
tenant_id: str,
user_id: str,
access_token_payload: Union[Dict[str, Any], None] = None,
session_data_in_database: Union[Dict[str, Any], None] = None,
Expand Down
2 changes: 1 addition & 1 deletion supertokens_python/recipe/session/syncio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@


def create_new_session(
tenant_id: str,
request: Any,
tenant_id: str,
user_id: str,
access_token_payload: Union[Dict[str, Any], None] = None,
session_data_in_database: Union[Dict[str, Any], None] = None,
Expand Down
1 change: 0 additions & 1 deletion supertokens_python/recipe/thirdparty/api/implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from supertokens_python.utils import utf_base64decode
from base64 import b64decode
import json

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
CreateResetPasswordLinkUknownUserIdError,
CreateResetPasswordLinkOkResult,
SendResetPasswordEmailUnknownUserIdError,
SendResetPasswordEmailEmailOkResult
SendResetPasswordEmailEmailOkResult,
)
from supertokens_python.recipe.emailpassword.utils import get_password_reset_link

Expand Down Expand Up @@ -61,6 +61,7 @@ async def get_user_by_third_party_info(
user_context,
)


async def thirdparty_manually_create_or_update_user(
tenant_id: str,
third_party_id: str,
Expand Down
15 changes: 11 additions & 4 deletions supertokens_python/syncio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,33 +29,40 @@


def get_users_oldest_first(
tenant_id: str,
limit: Union[int, None] = None,
pagination_token: Union[str, None] = None,
include_recipe_ids: Union[None, List[str]] = None,
query: Union[None, Dict[str, str]] = None,
) -> UsersResponse:
return sync(
Supertokens.get_instance().get_users(
"public", "ASC", limit, pagination_token, include_recipe_ids, query
tenant_id, "ASC", limit, pagination_token, include_recipe_ids, query
)
)


def get_users_newest_first(
tenant_id: str,
limit: Union[int, None] = None,
pagination_token: Union[str, None] = None,
include_recipe_ids: Union[None, List[str]] = None,
query: Union[None, Dict[str, str]] = None,
) -> UsersResponse:
return sync(
Supertokens.get_instance().get_users(
"public", "DESC", limit, pagination_token, include_recipe_ids, query
tenant_id, "DESC", limit, pagination_token, include_recipe_ids, query
)
)


def get_user_count(include_recipe_ids: Union[None, List[str]] = None) -> int:
return sync(Supertokens.get_instance().get_user_count(include_recipe_ids))
def get_user_count(
include_recipe_ids: Union[None, List[str]] = None,
tenant_id: Optional[str] = None,
) -> int:
return sync(
Supertokens.get_instance().get_user_count(include_recipe_ids, tenant_id)
)


def delete_user(user_id: str) -> None:
Expand Down
15 changes: 10 additions & 5 deletions tests/Django/test_django.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def get_cookies(response: HttpResponse) -> Dict[str, Any]:


async def create_new_session_view(request: HttpRequest):
await create_new_session("public", request, "user_id")
await create_new_session(request, "public", "user_id")
return JsonResponse({"foo": "bar"})


Expand Down Expand Up @@ -456,10 +456,12 @@ async def test_thirdparty_parsing_works(self):

start_st()

state = b64encode(json.dumps({"redirectURI": "http://localhost:3000/redirect" }).encode()).decode()
state = b64encode(
json.dumps({"redirectURI": "http://localhost:3000/redirect"}).encode()
).decode()
code = "testing"

data = { "state": state, "code": code}
data = {"state": state, "code": code}

request = self.factory.post(
"/auth/callback/apple",
Expand All @@ -472,8 +474,11 @@ async def test_thirdparty_parsing_works(self):
response = await temp

self.assertEqual(response.status_code, 303)
self.assertEqual(response.content, b'')
self.assertEqual(response.headers['location'], f"http://localhost:3000/redirect?state={state.replace('=', '%3D')}&code={code}")
self.assertEqual(response.content, b"")
self.assertEqual(
response.headers["location"],
f"http://localhost:3000/redirect?state={state.replace('=', '%3D')}&code={code}",
)

@pytest.mark.asyncio
async def test_search_with_multiple_emails(self):
Expand Down
6 changes: 3 additions & 3 deletions tests/Fastapi/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ async def driver_config_client():
@app.get("/login")
async def login(request: Request): # type: ignore
user_id = "userId"
await create_new_session("public", request, user_id, {}, {})
await create_new_session(request, "public", user_id, {}, {})
return {"userId": user_id}

@app.post("/refresh")
Expand Down Expand Up @@ -135,12 +135,12 @@ async def custom_logout(request: Request): # type: ignore

@app.post("/create")
async def _create(request: Request): # type: ignore
await create_new_session("public", request, "userId", {}, {})
await create_new_session(request, "public", "userId", {}, {})
return ""

@app.post("/create-throw")
async def _create_throw(request: Request): # type: ignore
await create_new_session("public", request, "userId", {}, {})
await create_new_session(request, "public", "userId", {}, {})
raise UnauthorisedError("unauthorised")

return TestClient(app)
Expand Down
17 changes: 11 additions & 6 deletions tests/Flask/test_flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def t(): # type: ignore
@app.route("/login") # type: ignore
def login(): # type: ignore
user_id = "userId"
create_new_session("public", request, user_id, {}, {})
create_new_session(request, "public", user_id, {}, {})

return jsonify({"userId": user_id, "session": "ssss"})

Expand Down Expand Up @@ -478,15 +478,20 @@ def test_thirdparty_parsing_works(driver_config_app: Any):
start_st()

test_client = driver_config_app.test_client()
state = b64encode(json.dumps({"redirectURI": "http://localhost:3000/redirect" }).encode()).decode()
state = b64encode(
json.dumps({"redirectURI": "http://localhost:3000/redirect"}).encode()
).decode()
code = "testing"

data = { "state": state, "code": code}
data = {"state": state, "code": code}
res = test_client.post("/auth/callback/apple", data=data)

assert res.status_code == 303
assert res.data == b''
assert res.headers["location"] == f"http://localhost:3000/redirect?state={state.replace('=', '%3D')}&code={code}"
assert res.data == b""
assert (
res.headers["location"]
== f"http://localhost:3000/redirect?state={state.replace('=', '%3D')}&code={code}"
)


from flask.wrappers import Response
Expand Down Expand Up @@ -747,7 +752,7 @@ def test_api(): # type: ignore
@app.route("/login") # type: ignore
def login(): # type: ignore
user_id = "userId"
s = create_new_session("public", request, user_id, {}, {})
s = create_new_session(request, "public", user_id, {}, {})
return jsonify({"user": s.get_user_id()})

@app.route("/ping") # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion tests/emailpassword/test_emailexists.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async def driver_config_client():
@app.get("/login")
async def login(request: Request): # type: ignore
user_id = "userId"
await create_new_session("public", request, user_id, {}, {})
await create_new_session(request, "public", user_id, {}, {})
return {"userId": user_id}

@app.post("/refresh")
Expand Down
2 changes: 1 addition & 1 deletion tests/emailpassword/test_emailverify.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ async def driver_config_client():
@app.get("/login")
async def login(request: Request): # type: ignore
user_id = "userId"
await create_new_session("public", request, user_id, {}, {})
await create_new_session(request, "public", user_id, {}, {})
return {"userId": user_id}

@app.post("/refresh")
Expand Down
2 changes: 1 addition & 1 deletion tests/emailpassword/test_passwordreset.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async def driver_config_client():
@app.get("/login")
async def login(request: Request): # type: ignore
user_id = "userId"
await create_new_session("public", request, user_id, {}, {})
await create_new_session(request, "public", user_id, {}, {})
return {"userId": user_id}

@app.post("/refresh")
Expand Down
2 changes: 1 addition & 1 deletion tests/emailpassword/test_signin.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def driver_config_client():
@app.get("/login")
async def login(request: Request): # type: ignore
user_id = "userId"
await create_new_session("public", request, user_id, {}, {})
await create_new_session(request, "public", user_id, {}, {})
return {"userId": user_id}

@app.post("/refresh")
Expand Down
2 changes: 1 addition & 1 deletion tests/frontendIntegration/django2x/polls/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def login(request: HttpRequest):
if request.method == "POST":
user_id = json.loads(request.body)["userId"]

session_ = create_new_session("public", request, user_id)
session_ = create_new_session(request, "public", user_id)
return HttpResponse(session_.get_user_id())
else:
return send_options_api_response()
Expand Down
2 changes: 1 addition & 1 deletion tests/frontendIntegration/django3x/polls/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ async def login(request: HttpRequest):
if request.method == "POST":
user_id = json.loads(request.body)["userId"]

session_ = await create_new_session("public", request, user_id)
session_ = await create_new_session(request, "public", user_id)
return HttpResponse(session_.get_user_id())
else:
return send_options_api_response()
Expand Down
2 changes: 1 addition & 1 deletion tests/frontendIntegration/fastapi-server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def login_options():
@app.post("/login")
async def login(request: Request):
user_id = (await request.json())["userId"]
_session = await create_new_session("public", request, user_id)
_session = await create_new_session(request, "public", user_id)
return PlainTextResponse(content=_session.get_user_id())


Expand Down
2 changes: 1 addition & 1 deletion tests/frontendIntegration/flask-server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def login_options():
@app.route("/login", methods=["POST"]) # type: ignore
def login():
user_id: str = request.get_json()["userId"] # type: ignore
_session = create_new_session("public", request, user_id)
_session = create_new_session(request, "public", user_id)
return _session.get_user_id()


Expand Down
2 changes: 1 addition & 1 deletion tests/jwt/test_get_JWKS.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async def driver_config_client():
@app.get("/login")
async def login(request: Request): # type: ignore
user_id = "userId"
await create_new_session("public", request, user_id, {}, {})
await create_new_session(request, "public", user_id, {}, {})
return {"userId": user_id}

return TestClient(app)
Expand Down
4 changes: 2 additions & 2 deletions tests/sessions/claims/test_verify_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,13 @@ async def fastapi_client():
@app.post("/login")
async def _login(request: Request): # type: ignore
user_id = "userId"
await create_new_session("public", request, user_id, {}, {})
await create_new_session(request, "public", user_id, {}, {})
return {"userId": user_id}

@app.post("/create-with-claim")
async def _create_with_claim(request: Request): # type: ignore
user_id = "userId"
_ = await create_new_session("public", request, user_id, {}, {})
_ = await create_new_session(request, "public", user_id, {}, {})
key: str = (await request.json())["key"]
# PrimitiveClaim(key, fetch_value="Value").add_to_session(session, "value")
return {"userId": key}
Expand Down
2 changes: 1 addition & 1 deletion tests/sessions/test_access_token_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ async def _create(request: Request): # type: ignore
except Exception:
pass

session = await create_new_session("public", request, "userId", body, {})
session = await create_new_session(request, "public", "userId", body, {})
return {"message": True, "sessionHandle": session.get_handle()}

@fast.get("/merge-into-payload")
Expand Down
2 changes: 1 addition & 1 deletion tests/sessions/test_auth_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def app():
@fast.post("/create")
async def _create(request: Request): # type: ignore
body = await request.json()
session = await create_new_session("public", request, "userId", body, {})
session = await create_new_session(request, "public", "userId", body, {})
return {"message": True, "sessionHandle": session.get_handle()}

@fast.get("/update-payload")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ async def home(_request: Request): # type: ignore

@app.post("/create")
async def create_api(request: Request): # type: ignore
await async_create_new_session("public", request, "test-user", {}, {})
await async_create_new_session(request, "public", "test-user", {}, {})
return ""

return TestClient(app)
Expand Down
13 changes: 9 additions & 4 deletions tests/thirdparty/test_thirdparty.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,17 @@ async def test_thirdpary_parsing_works(fastapi_client: TestClient):
init(**st_init_args) # type: ignore
start_st()

state = b64encode(json.dumps({"redirectURI": "http://localhost:3000/redirect" }).encode()).decode()
state = b64encode(
json.dumps({"redirectURI": "http://localhost:3000/redirect"}).encode()
).decode()
code = "testing"

data = { "state": state, "code": code}
data = {"state": state, "code": code}
res = fastapi_client.post("/auth/callback/apple", data=data)

assert res.status_code == 303
assert res.content == b''
assert res.headers["location"] == f"http://localhost:3000/redirect?state={state.replace('=', '%3D')}&code={code}"
assert res.content == b""
assert (
res.headers["location"]
== f"http://localhost:3000/redirect?state={state.replace('=', '%3D')}&code={code}"
)

0 comments on commit 5e08c6d

Please sign in to comment.