diff options
author | Botond Hende <nettingman@gmail.com> | 2024-08-25 19:59:39 +0200 |
---|---|---|
committer | Botond Hende <nettingman@gmail.com> | 2024-08-25 19:59:39 +0200 |
commit | dbb410dafec593ff8f082c07074fbf8148613d9b (patch) | |
tree | 24c79e705296f5146199d29ff486305d330284a5 /__main__.py | |
parent | b3c35c2ca00106fe475d82baa580ee26ee64e1ba (diff) |
basic authentication for clariossAI
Diffstat (limited to '__main__.py')
-rw-r--r-- | __main__.py | 47 |
1 files changed, 44 insertions, 3 deletions
diff --git a/__main__.py b/__main__.py index 7ae3047..2002425 100644 --- a/__main__.py +++ b/__main__.py @@ -1,29 +1,36 @@ import os.path import sqlite3 from pathlib import Path +from enum import Enum from contextlib import asynccontextmanager from typing import Annotated -from fastapi import FastAPI, Depends, Request +from fastapi import FastAPI, Depends, Response from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm +from passlib.context import CryptContext from .config import Config +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +cursor = None -sq_con = None @asynccontextmanager async def lifespan(app: FastAPI): db_dir = os.path.dirname(Config.DATABASE_PATH) Path(db_dir).mkdir(parents=True, exist_ok=True) + assert sqlite3.threadsafety == 3, "QLite thread safety is not set to 'Serialized'." sq_con = sqlite3.connect(Config.DATABASE_PATH) + global cursor + cursor = sq_con.cursor() + cursor.execute( + "CREATE TABLE IF NOT EXISTS agents(primary_key INTEGER PRIMARY KEY, callsign TEXT NOT NULL UNIQUE, token_hash TEXT NOT NULL)") yield app = FastAPI(lifespan=lifespan) - oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") @@ -33,6 +40,40 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()): return {"access_token": user, "token_type": "bearer"} +class AuthResult(Enum): + SUCCESS = 0 + NOT_FOUND = 1 + TOKEN_MISMATCH = 2 + + +async def auth_agent(callsign: str, token: str) -> AuthResult: + cursor.execute("SELECT token_hash from agents WHERE callsign = ?", (callsign,)) + row = cursor.fetchone() + if row == None: + return AuthResult.NOT_FOUND + + return AuthResult.SUCCESS if pwd_context.verify(token, row[0]) else AuthResult.TOKEN_MISMATCH + + +@app.post("/{callsign}/init", status_code=201) +async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme)], response: Response): + result = await auth_agent(callsign, token) + + if result == AuthResult.SUCCESS: + response.status_code = 200 + return '{"result": "Agent already registered."}' + + if result == AuthResult.TOKEN_MISMATCH: + response.status_code = 400 + return '{"error": "Agent already registered but with a different token."}' + + # TODO: test token on spacetraders api + + hash = pwd_context.hash(token) + cursor.execute("INSERT INTO agents (callsign, token_hash) VALUES (?, ?)", (callsign, hash)) + return '{"result": "Agent successfully registered."}' + + @app.get("/{callsign}/tasks") async def get_tasks(callsign: str, token: Annotated[str, Depends(oauth2_scheme)]): return f'{{"callsign": "{callsign}", "token": "{token}"}}' |