From a0e04e28bbab7ecc6a0c3b1bee72da303b82aebd Mon Sep 17 00:00:00 2001 From: Botond Hende Date: Sun, 25 Aug 2024 21:18:08 +0200 Subject: split code into separate files --- __main__.py | 75 ++++------------------------------------------- apirouters/__init__.py | 0 apirouters/agents.py | 53 +++++++++++++++++++++++++++++++++ datatemplates/agent_id.py | 6 ---- modules/database.py | 12 ++++++++ 5 files changed, 71 insertions(+), 75 deletions(-) create mode 100644 apirouters/__init__.py create mode 100644 apirouters/agents.py delete mode 100644 datatemplates/agent_id.py create mode 100644 modules/database.py diff --git a/__main__.py b/__main__.py index 2002425..833bf72 100644 --- a/__main__.py +++ b/__main__.py @@ -1,86 +1,23 @@ -import os.path -import sqlite3 -from pathlib import Path -from enum import Enum - from contextlib import asynccontextmanager from typing import Annotated -from fastapi import FastAPI, Depends, Response -from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm -from passlib.context import CryptContext - -from .config import Config +from fastapi import FastAPI, Depends -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") -cursor = None +import apirouters.agents +import modules.database @asynccontextmanager async def lifespan(app: FastAPI): - db_dir = os.path.dirname(Config.DATABASE_PATH) - Path(db_dir).mkdir(parents=True, exist_ok=True) - - assert sqlite3.threadsafety == 3, "QLite thread safety is not set to 'Serialized'." - sq_con = sqlite3.connect(Config.DATABASE_PATH) - global cursor - cursor = sq_con.cursor() - cursor.execute( + modules.database.cursor.execute( "CREATE TABLE IF NOT EXISTS agents(primary_key INTEGER PRIMARY KEY, callsign TEXT NOT NULL UNIQUE, token_hash TEXT NOT NULL)") yield app = FastAPI(lifespan=lifespan) -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") - - -@app.post("/token") -async def login(form_data: OAuth2PasswordRequestForm = Depends()): - user = form_data.username - return {"access_token": user, "token_type": "bearer"} - - -class AuthResult(Enum): - SUCCESS = 0 - NOT_FOUND = 1 - TOKEN_MISMATCH = 2 - - -async def auth_agent(callsign: str, token: str) -> AuthResult: - cursor.execute("SELECT token_hash from agents WHERE callsign = ?", (callsign,)) - row = cursor.fetchone() - if row == None: - return AuthResult.NOT_FOUND - - return AuthResult.SUCCESS if pwd_context.verify(token, row[0]) else AuthResult.TOKEN_MISMATCH - - -@app.post("/{callsign}/init", status_code=201) -async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme)], response: Response): - result = await auth_agent(callsign, token) - - if result == AuthResult.SUCCESS: - response.status_code = 200 - return '{"result": "Agent already registered."}' - - if result == AuthResult.TOKEN_MISMATCH: - response.status_code = 400 - return '{"error": "Agent already registered but with a different token."}' - - # TODO: test token on spacetraders api - - hash = pwd_context.hash(token) - cursor.execute("INSERT INTO agents (callsign, token_hash) VALUES (?, ?)", (callsign, hash)) - return '{"result": "Agent successfully registered."}' +app.include_router(apirouters.agents.router) @app.get("/{callsign}/tasks") -async def get_tasks(callsign: str, token: Annotated[str, Depends(oauth2_scheme)]): +async def get_tasks(callsign: str, token: Annotated[str, Depends(apirouters.agents.oauth2_scheme)]): return f'{{"callsign": "{callsign}", "token": "{token}"}}' - -# if __name__ == "__main__": -# with open(os.path.join(os.path.dirname(__file__), Config.TOKEN_FILE_NAME)) as f: -# token = f.read().strip() -# -# d = daemon.Daemon(Config.AGENT_SYMBOL, token) -# d.run() diff --git a/apirouters/__init__.py b/apirouters/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apirouters/agents.py b/apirouters/agents.py new file mode 100644 index 0000000..3d148d2 --- /dev/null +++ b/apirouters/agents.py @@ -0,0 +1,53 @@ +from enum import Enum +from typing import Annotated + +from fastapi import APIRouter, Depends, Response +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm +from passlib.context import CryptContext + +from modules.database import cursor, sq_con + +router = APIRouter() +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + + +@router.post("/token") +async def login(form_data: OAuth2PasswordRequestForm = Depends()): + user = form_data.username + return {"access_token": user, "token_type": "bearer"} + + +class AuthResult(Enum): + SUCCESS = 0 + NOT_FOUND = 1 + TOKEN_MISMATCH = 2 + + +async def auth_agent(callsign: str, token: str) -> AuthResult: + cursor.execute("SELECT token_hash from agents WHERE callsign = ?", (callsign,)) + row = cursor.fetchone() + if row is None: + return AuthResult.NOT_FOUND + + return AuthResult.SUCCESS if pwd_context.verify(token, row[0]) else AuthResult.TOKEN_MISMATCH + + +@router.post("/{callsign}/init", status_code=201) +async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme)], response: Response): + result = await auth_agent(callsign, token) + + if result == AuthResult.SUCCESS: + response.status_code = 200 + return '{"result": "Agent already registered."}' + + if result == AuthResult.TOKEN_MISMATCH: + response.status_code = 400 + return '{"error": "Agent already registered but with a different token."}' + + # TODO: test token on spacetraders api + + token_hash = pwd_context.hash(token) + cursor.execute("INSERT INTO agents (callsign, token_hash) VALUES (?, ?)", (callsign, token_hash)) + sq_con.commit() + return '{"result": "Agent successfully registered."}' diff --git a/datatemplates/agent_id.py b/datatemplates/agent_id.py deleted file mode 100644 index 2424a40..0000000 --- a/datatemplates/agent_id.py +++ /dev/null @@ -1,6 +0,0 @@ -from pydantic import BaseModel - - -class AgentId(BaseModel): - callsign: str - token: str diff --git a/modules/database.py b/modules/database.py new file mode 100644 index 0000000..f75f68e --- /dev/null +++ b/modules/database.py @@ -0,0 +1,12 @@ +import os.path +import sqlite3 +from pathlib import Path +from config import Config + +db_dir = os.path.dirname(Config.DATABASE_PATH) +Path(db_dir).mkdir(parents=True, exist_ok=True) + +assert sqlite3.threadsafety == 3, "QLite thread safety is not set to 'Serialized'." +sq_con = sqlite3.connect(Config.DATABASE_PATH) + +cursor = sq_con.cursor() -- cgit v1.2.3-70-g09d2