diff options
-rw-r--r-- | __main__.py | 6 | ||||
-rw-r--r-- | apirouters/agents.py | 29 |
2 files changed, 20 insertions, 15 deletions
diff --git a/__main__.py b/__main__.py index 03fb366..2434422 100644 --- a/__main__.py +++ b/__main__.py @@ -22,6 +22,6 @@ app = FastAPI(lifespan=lifespan) app.include_router(agents.router) -@app.get("/{callsign}/tasks") -async def get_tasks(callsign: str, token: Annotated[str, Depends(agents.oauth2_scheme)]): - return f'{{"callsign": "{callsign}", "token": "{token}"}}' +@app.get("/tasks") +async def get_tasks(agent: Annotated[agents.Agent, Depends(agents.auth_agent)]): + return {"callsign": agent.agent_symbol} diff --git a/apirouters/agents.py b/apirouters/agents.py index 4ac2167..55825e6 100644 --- a/apirouters/agents.py +++ b/apirouters/agents.py @@ -2,7 +2,7 @@ from enum import Enum from typing import Annotated, Dict, List import openapi_client -from fastapi import APIRouter, Depends, Response +from fastapi import APIRouter, Depends, Response, HTTPException from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from ..modules import tasks @@ -39,7 +39,7 @@ class AuthResult(Enum): TOKEN_MISMATCH = 2 -def auth_agent(callsign: str, token: str) -> AuthResult: +def check_init_agent(callsign: str, token: str) -> AuthResult: cursor.execute("SELECT token from agents WHERE callsign = ?", (callsign,)) row = cursor.fetchone() if row is None: @@ -48,6 +48,15 @@ def auth_agent(callsign: str, token: str) -> AuthResult: return AuthResult.SUCCESS if token == row[0] else AuthResult.TOKEN_MISMATCH +async def auth_agent(token: Annotated[str, Depends(oauth2_scheme)]) -> Agent: + cursor.execute("SELECT callsign from agents WHERE token = ?", (token,)) + row = cursor.fetchone() + if row is None: + raise HTTPException(status_code=400, detail="Authentication failed.") + + return agents[row[0]] + + def load_agents_from_database() -> None: cursor.execute("SELECT callsign, token from agents") invalid_agents = [] @@ -91,21 +100,17 @@ def load_agents_from_database() -> None: agent.ships[row[0]] = missing_ship - - - - -@router.post("/{callsign}/init", status_code=201) +@router.post("/init/{callsign}", status_code=201) async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme)], response: Response): - result = auth_agent(callsign, token) + result = check_init_agent(callsign, token) if result == AuthResult.SUCCESS: response.status_code = 200 - return '{"result": "Agent already registered."}' + return {"result": "Agent already registered."} if result == AuthResult.TOKEN_MISMATCH: response.status_code = 400 - return '{"error": "Agent already registered but with a different token."}' + return {"error": "Agent already registered but with a different token."} success = True agent = None @@ -121,8 +126,8 @@ async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme) agents[callsign] = agent else: response.status_code = 400 - return '{"error": "Agent could not be authenticated on the SpaceTraders API."}' + return {"error": "Agent could not be authenticated on the SpaceTraders API."} cursor.execute("INSERT INTO agents (callsign, token) VALUES (?, ?)", (callsign, token)) conn.commit() - return '{"result": "Agent successfully registered."}' + return {"result": "Agent successfully registered."} |