From dbb410dafec593ff8f082c07074fbf8148613d9b Mon Sep 17 00:00:00 2001 From: Botond Hende Date: Sun, 25 Aug 2024 19:59:39 +0200 Subject: basic authentication for clariossAI --- __main__.py | 47 ++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 3 deletions(-) (limited to '__main__.py') diff --git a/__main__.py b/__main__.py index 7ae3047..2002425 100644 --- a/__main__.py +++ b/__main__.py @@ -1,29 +1,36 @@ import os.path import sqlite3 from pathlib import Path +from enum import Enum from contextlib import asynccontextmanager from typing import Annotated -from fastapi import FastAPI, Depends, Request +from fastapi import FastAPI, Depends, Response from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm +from passlib.context import CryptContext from .config import Config +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +cursor = None -sq_con = None @asynccontextmanager async def lifespan(app: FastAPI): db_dir = os.path.dirname(Config.DATABASE_PATH) Path(db_dir).mkdir(parents=True, exist_ok=True) + assert sqlite3.threadsafety == 3, "QLite thread safety is not set to 'Serialized'." sq_con = sqlite3.connect(Config.DATABASE_PATH) + global cursor + cursor = sq_con.cursor() + cursor.execute( + "CREATE TABLE IF NOT EXISTS agents(primary_key INTEGER PRIMARY KEY, callsign TEXT NOT NULL UNIQUE, token_hash TEXT NOT NULL)") yield app = FastAPI(lifespan=lifespan) - oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") @@ -33,6 +40,40 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()): return {"access_token": user, "token_type": "bearer"} +class AuthResult(Enum): + SUCCESS = 0 + NOT_FOUND = 1 + TOKEN_MISMATCH = 2 + + +async def auth_agent(callsign: str, token: str) -> AuthResult: + cursor.execute("SELECT token_hash from agents WHERE callsign = ?", (callsign,)) + row = cursor.fetchone() + if row == None: + return AuthResult.NOT_FOUND + + return AuthResult.SUCCESS if pwd_context.verify(token, row[0]) else AuthResult.TOKEN_MISMATCH + + +@app.post("/{callsign}/init", status_code=201) +async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme)], response: Response): + result = await auth_agent(callsign, token) + + if result == AuthResult.SUCCESS: + response.status_code = 200 + return '{"result": "Agent already registered."}' + + if result == AuthResult.TOKEN_MISMATCH: + response.status_code = 400 + return '{"error": "Agent already registered but with a different token."}' + + # TODO: test token on spacetraders api + + hash = pwd_context.hash(token) + cursor.execute("INSERT INTO agents (callsign, token_hash) VALUES (?, ?)", (callsign, hash)) + return '{"result": "Agent successfully registered."}' + + @app.get("/{callsign}/tasks") async def get_tasks(callsign: str, token: Annotated[str, Depends(oauth2_scheme)]): return f'{{"callsign": "{callsign}", "token": "{token}"}}' -- cgit v1.2.3-70-g09d2