diff --git a/carsharing.db b/carsharing.db index 1f7d1c7..458caaa 100644 Binary files a/carsharing.db and b/carsharing.db differ diff --git a/carsharing.py b/carsharing.py index e221350..2f10152 100644 --- a/carsharing.py +++ b/carsharing.py @@ -1,18 +1,22 @@ import uvicorn -from fastapi import FastAPI, Request +from fastapi import FastAPI +from fastapi import Request from fastapi.middleware.cors import CORSMiddleware -from sqlmodel import SQLModel +from fastapi.security import HTTPBasic, HTTPBasicCredentials +from sqlmodel import SQLModel, Session, select +from starlette import status from starlette.responses import JSONResponse from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY -from db import engine -from routers import cars, web +from db import engine, get_session +from routers import cars, web, auth from routers.cars import BadTripException +from schemas import UserOutput, User app = FastAPI(title="Car Sharing") app.include_router(web.router) app.include_router(cars.router) - +app.include_router(auth.router) origins = [ "http://localhost:8000", @@ -50,4 +54,3 @@ async def add_cars_cookie(request: Request, call_next): if __name__ == "__main__": uvicorn.run("carsharing:app", reload=True) - diff --git a/create_user.py b/create_user.py new file mode 100644 index 0000000..6935118 --- /dev/null +++ b/create_user.py @@ -0,0 +1,36 @@ +""" +create_user.py +------------- +A convenience script to create a user. +""" + +from getpass import getpass + +from sqlmodel import SQLModel, Session, create_engine + +from schemas import User + + +engine = create_engine( + "sqlite:///carsharing.db", + connect_args={"check_same_thread": False}, # Needed for SQLite + echo=True # Log generated SQL +) + + +if __name__ == "__main__": + print("Creating tables (if necessary)") + SQLModel.metadata.create_all(engine) + + print("--------") + + print("This script will create a user and save it in the database.") + + username = input("Please enter username\n") + pwd = getpass("Please enter password\n") + + with Session(engine) as session: + user = User(username=username) + user.set_password(pwd) + session.add(user) + session.commit() diff --git a/routers/auth.py b/routers/auth.py new file mode 100644 index 0000000..beea008 --- /dev/null +++ b/routers/auth.py @@ -0,0 +1,36 @@ +from fastapi import Depends, HTTPException, APIRouter +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm +from sqlmodel import Session, select +from starlette import status + +from db import get_session +from schemas import UserOutput, User + +URL_PREFIX="/auth" +router = APIRouter(prefix=URL_PREFIX) +oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{URL_PREFIX}/token") + + +def get_current_user(token: str = Depends(oauth2_scheme), + session: Session = Depends(get_session)) -> UserOutput: + query = select(User).where(User.username == token) + user = session.exec(query).first() + if user: + return UserOutput.from_orm(user) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Username or password incorrect", + headers={"WWW-Authenticate": "Bearer"}, + ) + + +@router.post("/token") +async def login(form_data: OAuth2PasswordRequestForm = Depends(), + session: Session = Depends(get_session)): + query = select(User).where(User.username == form_data.username) + user = session.exec(query).first() + if user and user.verify_password(form_data.password): + return {"access_token": user.username, "token_type": "bearer"} + else: + raise HTTPException(status_code=400, detail="Incorrect username or password") \ No newline at end of file diff --git a/routers/cars.py b/routers/cars.py index 226e407..ebc5d45 100644 --- a/routers/cars.py +++ b/routers/cars.py @@ -1,9 +1,9 @@ from fastapi import Depends, HTTPException, APIRouter from sqlmodel import Session, select +from routers.auth import get_current_user from db import get_session -from schemas import Car, CarOutput, CarInput, Trip, TripInput - +from schemas import Car, CarOutput, CarInput, Trip, TripInput, User router = APIRouter(prefix="/api/cars") @@ -29,7 +29,9 @@ def car_by_id(id: int, session: Session = Depends(get_session)) -> Car: @router.post("/", response_model=Car) -def add_car(car_input: CarInput, session: Session = Depends(get_session)) -> Car: +def add_car(car_input: CarInput, + session: Session = Depends(get_session), + user: User = Depends(get_current_user)) -> Car: new_car = Car.from_orm(car_input) session.add(new_car) session.commit() diff --git a/schemas.py b/schemas.py index 45770b8..2f3dc06 100644 --- a/schemas.py +++ b/schemas.py @@ -1,4 +1,25 @@ -from sqlmodel import SQLModel, Field, Relationship +from sqlmodel import SQLModel, Field, Relationship, Column, VARCHAR +from passlib.context import CryptContext + +pwd_context = CryptContext(schemes=["bcrypt"]) + +class UserOutput(SQLModel): + id: int + username: str + + +class User(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + username: str = Field(sa_column=Column("username", VARCHAR, unique=True, index=True)) + password_hash: str = "" + + def set_password(self, password): + """Setting the passwords actually sets password_hash.""" + self.password_hash = pwd_context.hash(password) + + def verify_password(self, password): + """Verify given password by hashing and comparing to password_hash.""" + return pwd_context.verify(password, self.password_hash) class TripInput(SQLModel):