summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--__main__.py47
1 files changed, 44 insertions, 3 deletions
diff --git a/__main__.py b/__main__.py
index 7ae3047..2002425 100644
--- a/__main__.py
+++ b/__main__.py
@@ -1,29 +1,36 @@
import os.path
import sqlite3
from pathlib import Path
+from enum import Enum
from contextlib import asynccontextmanager
from typing import Annotated
-from fastapi import FastAPI, Depends, Request
+from fastapi import FastAPI, Depends, Response
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
+from passlib.context import CryptContext
from .config import Config
+pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
+cursor = None
-sq_con = None
@asynccontextmanager
async def lifespan(app: FastAPI):
db_dir = os.path.dirname(Config.DATABASE_PATH)
Path(db_dir).mkdir(parents=True, exist_ok=True)
+ assert sqlite3.threadsafety == 3, "QLite thread safety is not set to 'Serialized'."
sq_con = sqlite3.connect(Config.DATABASE_PATH)
+ global cursor
+ cursor = sq_con.cursor()
+ cursor.execute(
+ "CREATE TABLE IF NOT EXISTS agents(primary_key INTEGER PRIMARY KEY, callsign TEXT NOT NULL UNIQUE, token_hash TEXT NOT NULL)")
yield
app = FastAPI(lifespan=lifespan)
-
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
@@ -33,6 +40,40 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()):
return {"access_token": user, "token_type": "bearer"}
+class AuthResult(Enum):
+ SUCCESS = 0
+ NOT_FOUND = 1
+ TOKEN_MISMATCH = 2
+
+
+async def auth_agent(callsign: str, token: str) -> AuthResult:
+ cursor.execute("SELECT token_hash from agents WHERE callsign = ?", (callsign,))
+ row = cursor.fetchone()
+ if row == None:
+ return AuthResult.NOT_FOUND
+
+ return AuthResult.SUCCESS if pwd_context.verify(token, row[0]) else AuthResult.TOKEN_MISMATCH
+
+
+@app.post("/{callsign}/init", status_code=201)
+async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme)], response: Response):
+ result = await auth_agent(callsign, token)
+
+ if result == AuthResult.SUCCESS:
+ response.status_code = 200
+ return '{"result": "Agent already registered."}'
+
+ if result == AuthResult.TOKEN_MISMATCH:
+ response.status_code = 400
+ return '{"error": "Agent already registered but with a different token."}'
+
+ # TODO: test token on spacetraders api
+
+ hash = pwd_context.hash(token)
+ cursor.execute("INSERT INTO agents (callsign, token_hash) VALUES (?, ?)", (callsign, hash))
+ return '{"result": "Agent successfully registered."}'
+
+
@app.get("/{callsign}/tasks")
async def get_tasks(callsign: str, token: Annotated[str, Depends(oauth2_scheme)]):
return f'{{"callsign": "{callsign}", "token": "{token}"}}'