Skip to content

Commit

Permalink
Merge branch 'auth-backend' into cloud-auth
Browse files Browse the repository at this point in the history
  • Loading branch information
LatentDream committed Apr 30, 2024
2 parents e3c57c8 + a887e6e commit 48ca647
Show file tree
Hide file tree
Showing 51 changed files with 1,781 additions and 701 deletions.
85 changes: 64 additions & 21 deletions captain/middleware/auth_middleware.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,82 @@
from typing import Callable, Optional
from fastapi import Request, HTTPException, status
from captain.services.auth.auth_service import validate_credentials
import base64
from captain.services.auth.auth_service import get_user, has_cloud_access, has_write_access
from captain.types.auth import Auth


async def is_admin(req: Request):
def _with_verify_access(func: Callable[[str, str]]):
async def wrapper(req: Request):
exception_txt = "You are not authorized to perform this action"
studio_cookie = req.cookies.get("studio-auth")

if not studio_cookie:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=exception_txt,
)

try:
credentials = base64.b64decode(studio_cookie).decode("utf-8")
username, token = credentials.split(":", 1)
authorized = has_cloud_access(username, token)
func(username, token)

if not authorized:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=exception_txt,
)
except Exception:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=exception_txt,
)
return wrapper


@_with_verify_access
async def can_write(username, token):
"""
Middleware to check if the user can modify protected resources
Example of use
@router.get("/write", dependencies=[Depends(can_write)])
async def update():
return "resource updated"
"""
Middleware to check if the user is an admin
return has_write_access(username, token)


@_with_verify_access
async def is_connected(username, token):
"""
Middleware to check if the user has access to the cloud
Example of use
@router.get("/write", dependencies=[Depends(is_admin)])
@router.get("/write", dependencies=[Depends(is_connected)])
async def update():
return "resource updated"
"""
return has_cloud_access(username, token)


def retreive_user(req: Request) -> Auth:
"""
Access the information store in the current user
Should be use in tendem with the `dependencies=[Depends(is_connected)]` middleware
- Raise an HTTPException if the user is not connected
"""
exception_txt = "You are not authorized to perform this action"
studio_cookie = req.cookies.get("studio-auth")

if not studio_cookie:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=exception_txt,
detail="User is not connected"
)

try:
credentials = base64.b64decode(studio_cookie).decode("utf-8")
username, password = credentials.split(":", 1)
authorized = validate_credentials(username, password)

if not authorized:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=exception_txt,
)
except Exception:
credentials = base64.b64decode(studio_cookie).decode("utf-8")
username, token = credentials.split(":", 1)
user = get_user(username, token)
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=exception_txt,
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
return user
3 changes: 3 additions & 0 deletions captain/models/test_sequencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class TestTypes(StrEnum):
python = "python"
flojoy = "flojoy"
matlab = "matlab"
placeholder = "placeholder"
robotframework = "robotframework"


class StatusTypes(StrEnum):
Expand Down Expand Up @@ -65,6 +67,7 @@ class Test(BaseModel):
max_value: Optional[float] = Field(None, alias="maxValue")
measured_value: Optional[float] = Field(None, alias="measuredValue")
unit: Optional[str] = Field(None, alias="unit")
args: Optional[List[str]] = Field(None, alias="args")


class Role(StrEnum):
Expand Down
20 changes: 7 additions & 13 deletions captain/routes/auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from fastapi import APIRouter, Response
from captain.services.auth.auth_service import (
validate_credentials,
save_user,
get_base64_credentials,
)
from captain.types.auth import Auth
Expand All @@ -10,18 +10,12 @@

@router.post("/auth/login/")
async def login(response: Response, auth: Auth):
if not validate_credentials(auth.username, auth.password):
response.set_cookie(
key="studio-auth",
value="",
path="/",
samesite="none",
secure=True,
)
return "Login failed"

encoded_credentials = get_base64_credentials(auth.username, auth.password)

""" Login to the backend of the app
- Actual auth with password and username is done in the frontend with cloud
- Backend auth serves as a middleware to store Cloud credentials
"""
save_user(auth)
encoded_credentials = get_base64_credentials(auth.username, auth.token)
response.set_cookie(
key="studio-auth",
value=encoded_credentials,
Expand Down
85 changes: 47 additions & 38 deletions captain/routes/cloud.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import json
import logging
import requests
from fastapi import APIRouter, Header, Response
from flojoy.env_var import get_env_var, get_flojoy_cloud_url
from fastapi import APIRouter, Depends, HTTPException, Header, Request, Response
from flojoy_cloud import test_sequencer
from pydantic import BaseModel, Field
from typing import Annotated, Optional
Expand All @@ -11,6 +10,9 @@
import pandas as pd
from functools import wraps

from captain.middleware.auth_middleware import is_connected, retreive_user
from captain.types.auth import Auth


# Utils ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down Expand Up @@ -49,10 +51,10 @@ def inner(*args, **kwargs):
return decorator


async def get_cloud_part_variation(part_variation_id: str):
async def get_cloud_part_variation(part_variation_id: str, user: Auth):
logging.info("Querying part variation")
url = get_flojoy_cloud_url() + "partVariation/" + part_variation_id
response = requests.get(url, headers=headers_builder())
url = user.url + "partVariation/" + part_variation_id
response = requests.get(url, headers=headers_builder(user))
res = response.json()
res["partVariationId"] = part_variation_id
logging.info("Part variation retrieved: %s", res)
Expand All @@ -65,24 +67,23 @@ class SecretNotFound(Exception):

def error_response_builder(e: Exception) -> Response:
logging.error(f"Error from Flojoy Cloud: {e}")
if isinstance(e, HTTPException):
return Response(status_code=e.status_code, content=json.dumps(e.__str__))
if isinstance(e, SecretNotFound):
return Response(status_code=401, content=json.dumps([]))
else:
return Response(status_code=500, content=json.dumps([]))


@temporary_cache
def headers_builder(with_workspace_id=True) -> dict:
workspace_secret = get_env_var("FLOJOY_CLOUD_WORKSPACE_SECRET")
def headers_builder(user: Auth, with_workspace_id=True) -> dict:
logging.info("Querying workspace current")
if workspace_secret is None:
raise SecretNotFound
headers = {
"Content-Type": "application/json",
"flojoy-workspace-personal-secret": workspace_secret,
"flojoy-workspace-personal-secret": user.token,
}
if with_workspace_id:
url = get_flojoy_cloud_url() + "workspace/"
url = user.url + "workspace/"
response = requests.get(url, headers=headers)
if response.status_code != 200:
logging.error(f"Failed to get workspace id {url}: {response.text}")
Expand Down Expand Up @@ -187,10 +188,10 @@ def get_measurement(m: Measurement) -> MeasurementData:
return data


async def get_part(part_id: str) -> Part:
async def get_part(part_id: str, user: Auth) -> Part:
logging.info("Querying part")
url = get_flojoy_cloud_url() + "part/" + part_id
response = requests.get(url, headers=headers_builder())
url = user.url + "part/" + part_id
response = requests.get(url, headers=headers_builder(user))
return Part(**response.json())


Expand All @@ -200,22 +201,23 @@ async def get_part(part_id: str) -> Part:
router = APIRouter(tags=["cloud"])


@router.get("/cloud/projects/")
async def get_cloud_projects():
@router.get("/cloud/projects/", dependencies=[Depends(is_connected)])
async def get_cloud_projects(req: Request):
"""
Get all projects from the Flojoy Cloud.
"""
try:
logging.info("Querying projects")
url = get_flojoy_cloud_url() + "project/"
response = requests.get(url, headers=headers_builder())
user = retreive_user(req)
url = user.url + "project/"
response = requests.get(url, headers=headers_builder(user))
if response.status_code != 200:
return Response(status_code=response.status_code, content=json.dumps([]))
projects = [Project(**project_data) for project_data in response.json()]
projects_res = []
for p in projects:
part_var = await get_cloud_part_variation(p.part_variation_id)
part = await get_part(part_var.part_id)
part_var = await get_cloud_part_variation(p.part_variation_id, user)
part = await get_part(part_var.part_id, user)
projects_res.append(
{
"label": p.name,
Expand All @@ -234,16 +236,17 @@ async def get_cloud_projects():
return error_response_builder(e)


@router.get("/cloud/stations/{project_id}")
async def get_cloud_stations(project_id: str):
@router.get("/cloud/stations/{project_id}", dependencies=[Depends(is_connected)])
async def get_cloud_stations(project_id: str, req: Request):
"""
Get all station of a project from the Flojoy Cloud.
"""
try:
logging.info("Querying stations")
url = get_flojoy_cloud_url() + "station/"
user = retreive_user(req)
url = user.url + "station/"
querystring = {"projectId": project_id}
response = requests.get(url, headers=headers_builder(), params=querystring)
response = requests.get(url, headers=headers_builder(user), params=querystring)
if response.status_code != 200:
logging.error(f"Error getting stations from Flojoy Cloud: {response.text}")
return Response(status_code=response.status_code, content=json.dumps([]))
Expand All @@ -256,12 +259,13 @@ async def get_cloud_stations(project_id: str):
return error_response_builder(e)


@router.get("/cloud/partVariation/{part_var_id}/unit")
async def get_cloud_variant_unit(part_var_id: str):
@router.get("/cloud/partVariation/{part_var_id}/unit", dependencies=[Depends(is_connected)])
async def get_cloud_variant_unit(part_var_id: str, req: Request):
try:
logging.info(f"Querying unit for part {part_var_id}")
url = f"{get_flojoy_cloud_url()}partVariation/{part_var_id}/unit"
response = requests.get(url, headers=headers_builder())
user = retreive_user(req)
url = f"{user.url}partVariation/{part_var_id}/unit"
response = requests.get(url, headers=headers_builder(user))
if response.status_code != 200:
logging.error(f"Error getting stations from Flojoy Cloud: {response.text}")
return Response(status_code=response.status_code, content=json.dumps([]))
Expand All @@ -272,19 +276,20 @@ async def get_cloud_variant_unit(part_var_id: str):
return error_response_builder(e)


@router.post("/cloud/session/")
async def post_cloud_session(_: Response, body: Session):
@router.post("/cloud/session/", dependencies=[Depends(is_connected)])
async def post_cloud_session(_: Response, body: Session, req: Request):
try:
logging.info("Posting session")
url = get_flojoy_cloud_url() + "session/"
user = retreive_user(req)
url = user.url + "session/"
payload = body.model_dump(by_alias=True)
payload["createdAt"] = utcnow_str()
for i, m in enumerate(payload["measurements"]):
m["data"] = make_payload(get_measurement(body.measurements[i]), m["unit"])
m["pass"] = m.pop("pass_")
m["durationMs"] = int(m.pop("completionTime") * 1000)
del m["unit"]
response = requests.post(url, json=payload, headers=headers_builder())
response = requests.post(url, json=payload, headers=headers_builder(user))
if response.status_code == 200:
return Response(status_code=200, content=json.dumps(response.json()))
else:
Expand All @@ -296,15 +301,19 @@ async def post_cloud_session(_: Response, body: Session):
return error_response_builder(e)


@router.get("/cloud/user/")
async def get_user_info(secret: Annotated[str | None, Header()]):
# Verify cloud connection ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


@router.get("/cloud/user/", dependencies=[Depends(is_connected)])
async def get_user_info(req: Request, secret: Annotated[str | None, Header()]):
try:
logging.info("Querying user info")
url = get_flojoy_cloud_url() + "user/"
user = retreive_user(req)
url = user.url + "user/"
headers = (
{"flojoy-workspace-personal-secret": secret}
if secret
else headers_builder(with_workspace_id=False)
else headers_builder(user, with_workspace_id=False)
)
response = requests.get(url, headers=headers)
if response.status_code == 200:
Expand All @@ -319,11 +328,11 @@ async def get_user_info(secret: Annotated[str | None, Header()]):


@router.get("/cloud/health/")
async def get_cloud_health(url: Annotated[str | None, Header()]):
async def get_cloud_health(url: Annotated[str | None, Header()], req: Request):
try:
logging.info("Querying health")
if url is None:
url = get_flojoy_cloud_url()
url = retreive_user(req).url
url = url + "health/"
response = requests.get(url)
if response.status_code == 200:
Expand Down
8 changes: 4 additions & 4 deletions captain/routes/key.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from fastapi import APIRouter, Response, status, Depends
from flojoy import delete_env_var, get_credentials, get_env_var, set_env_var
from captain.middleware.auth_middleware import is_admin
from captain.middleware.auth_middleware import can_write, is_connected
from captain.types.key import EnvVar
from captain.utils.logger import logger

router = APIRouter(tags=["env"])


@router.post("/env/", dependencies=[Depends(is_admin)])
@router.post("/env/", dependencies=[Depends(can_write)])
async def set_env_var_route(env_var: EnvVar):
try:
set_env_var(env_var.key, env_var.value)
Expand All @@ -25,7 +25,7 @@ async def set_env_var_route(env_var: EnvVar):
return Response(status_code=200)


@router.delete("/env/{key_name}", dependencies=[Depends(is_admin)])
@router.delete("/env/{key_name}", dependencies=[Depends(can_write)])
async def delete_env_var_route(key_name: str):
try:
delete_env_var(key_name)
Expand All @@ -36,7 +36,7 @@ async def delete_env_var_route(key_name: str):
return Response(status_code=200)


@router.get("/env/{key_name}", response_model=EnvVar, dependencies=[Depends(is_admin)])
@router.get("/env/{key_name}", response_model=EnvVar, dependencies=[Depends(can_write)])
async def get_env_var_by_name_route(key_name: str):
value: Optional[str] = get_env_var(key_name)
if value is None:
Expand Down
Loading

0 comments on commit 48ca647

Please sign in to comment.