summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBotond Hende <nettingman@gmail.com>2024-08-25 21:18:08 +0200
committerBotond Hende <nettingman@gmail.com>2024-08-25 21:18:08 +0200
commita0e04e28bbab7ecc6a0c3b1bee72da303b82aebd (patch)
tree8a8000e514a27280104266c0870dd4d62953aaa2
parentdbb410dafec593ff8f082c07074fbf8148613d9b (diff)
split code into separate files
-rw-r--r--__main__.py75
-rw-r--r--apirouters/__init__.py0
-rw-r--r--apirouters/agents.py53
-rw-r--r--datatemplates/agent_id.py6
-rw-r--r--modules/database.py12
5 files changed, 71 insertions, 75 deletions
diff --git a/__main__.py b/__main__.py
index 2002425..833bf72 100644
--- a/__main__.py
+++ b/__main__.py
@@ -1,86 +1,23 @@
-import os.path
-import sqlite3
-from pathlib import Path
-from enum import Enum
-
from contextlib import asynccontextmanager
from typing import Annotated
-from fastapi import FastAPI, Depends, Response
-from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
-from passlib.context import CryptContext
-
-from .config import Config
+from fastapi import FastAPI, Depends
-pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
-cursor = None
+import apirouters.agents
+import modules.database
@asynccontextmanager
async def lifespan(app: FastAPI):
- db_dir = os.path.dirname(Config.DATABASE_PATH)
- Path(db_dir).mkdir(parents=True, exist_ok=True)
-
- assert sqlite3.threadsafety == 3, "QLite thread safety is not set to 'Serialized'."
- sq_con = sqlite3.connect(Config.DATABASE_PATH)
- global cursor
- cursor = sq_con.cursor()
- cursor.execute(
+ modules.database.cursor.execute(
"CREATE TABLE IF NOT EXISTS agents(primary_key INTEGER PRIMARY KEY, callsign TEXT NOT NULL UNIQUE, token_hash TEXT NOT NULL)")
yield
app = FastAPI(lifespan=lifespan)
-oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
-
-
-@app.post("/token")
-async def login(form_data: OAuth2PasswordRequestForm = Depends()):
- user = form_data.username
- return {"access_token": user, "token_type": "bearer"}
-
-
-class AuthResult(Enum):
- SUCCESS = 0
- NOT_FOUND = 1
- TOKEN_MISMATCH = 2
-
-
-async def auth_agent(callsign: str, token: str) -> AuthResult:
- cursor.execute("SELECT token_hash from agents WHERE callsign = ?", (callsign,))
- row = cursor.fetchone()
- if row == None:
- return AuthResult.NOT_FOUND
-
- return AuthResult.SUCCESS if pwd_context.verify(token, row[0]) else AuthResult.TOKEN_MISMATCH
-
-
-@app.post("/{callsign}/init", status_code=201)
-async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme)], response: Response):
- result = await auth_agent(callsign, token)
-
- if result == AuthResult.SUCCESS:
- response.status_code = 200
- return '{"result": "Agent already registered."}'
-
- if result == AuthResult.TOKEN_MISMATCH:
- response.status_code = 400
- return '{"error": "Agent already registered but with a different token."}'
-
- # TODO: test token on spacetraders api
-
- hash = pwd_context.hash(token)
- cursor.execute("INSERT INTO agents (callsign, token_hash) VALUES (?, ?)", (callsign, hash))
- return '{"result": "Agent successfully registered."}'
+app.include_router(apirouters.agents.router)
@app.get("/{callsign}/tasks")
-async def get_tasks(callsign: str, token: Annotated[str, Depends(oauth2_scheme)]):
+async def get_tasks(callsign: str, token: Annotated[str, Depends(apirouters.agents.oauth2_scheme)]):
return f'{{"callsign": "{callsign}", "token": "{token}"}}'
-
-# if __name__ == "__main__":
-# with open(os.path.join(os.path.dirname(__file__), Config.TOKEN_FILE_NAME)) as f:
-# token = f.read().strip()
-#
-# d = daemon.Daemon(Config.AGENT_SYMBOL, token)
-# d.run()
diff --git a/apirouters/__init__.py b/apirouters/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apirouters/__init__.py
diff --git a/apirouters/agents.py b/apirouters/agents.py
new file mode 100644
index 0000000..3d148d2
--- /dev/null
+++ b/apirouters/agents.py
@@ -0,0 +1,53 @@
+from enum import Enum
+from typing import Annotated
+
+from fastapi import APIRouter, Depends, Response
+from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
+from passlib.context import CryptContext
+
+from modules.database import cursor, sq_con
+
+router = APIRouter()
+pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
+
+@router.post("/token")
+async def login(form_data: OAuth2PasswordRequestForm = Depends()):
+ user = form_data.username
+ return {"access_token": user, "token_type": "bearer"}
+
+
+class AuthResult(Enum):
+ SUCCESS = 0
+ NOT_FOUND = 1
+ TOKEN_MISMATCH = 2
+
+
+async def auth_agent(callsign: str, token: str) -> AuthResult:
+ cursor.execute("SELECT token_hash from agents WHERE callsign = ?", (callsign,))
+ row = cursor.fetchone()
+ if row is None:
+ return AuthResult.NOT_FOUND
+
+ return AuthResult.SUCCESS if pwd_context.verify(token, row[0]) else AuthResult.TOKEN_MISMATCH
+
+
+@router.post("/{callsign}/init", status_code=201)
+async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme)], response: Response):
+ result = await auth_agent(callsign, token)
+
+ if result == AuthResult.SUCCESS:
+ response.status_code = 200
+ return '{"result": "Agent already registered."}'
+
+ if result == AuthResult.TOKEN_MISMATCH:
+ response.status_code = 400
+ return '{"error": "Agent already registered but with a different token."}'
+
+ # TODO: test token on spacetraders api
+
+ token_hash = pwd_context.hash(token)
+ cursor.execute("INSERT INTO agents (callsign, token_hash) VALUES (?, ?)", (callsign, token_hash))
+ sq_con.commit()
+ return '{"result": "Agent successfully registered."}'
diff --git a/datatemplates/agent_id.py b/datatemplates/agent_id.py
deleted file mode 100644
index 2424a40..0000000
--- a/datatemplates/agent_id.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from pydantic import BaseModel
-
-
-class AgentId(BaseModel):
- callsign: str
- token: str
diff --git a/modules/database.py b/modules/database.py
new file mode 100644
index 0000000..f75f68e
--- /dev/null
+++ b/modules/database.py
@@ -0,0 +1,12 @@
+import os.path
+import sqlite3
+from pathlib import Path
+from config import Config
+
+db_dir = os.path.dirname(Config.DATABASE_PATH)
+Path(db_dir).mkdir(parents=True, exist_ok=True)
+
+assert sqlite3.threadsafety == 3, "QLite thread safety is not set to 'Serialized'."
+sq_con = sqlite3.connect(Config.DATABASE_PATH)
+
+cursor = sq_con.cursor()