summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBotond Hende <nettingman@gmail.com>2024-08-31 14:27:35 +0200
committerBotond Hende <nettingman@gmail.com>2024-08-31 14:27:35 +0200
commit21019bc8414e60495573807b0fe0eb562a0e8876 (patch)
tree49a1c412dfe3068e87ff51b2a66e1ea49b290206
parent254a5eda7d17ff28d4604d95dba6f102120189b8 (diff)
agent auth dependency
-rw-r--r--__main__.py6
-rw-r--r--apirouters/agents.py29
2 files changed, 20 insertions, 15 deletions
diff --git a/__main__.py b/__main__.py
index 03fb366..2434422 100644
--- a/__main__.py
+++ b/__main__.py
@@ -22,6 +22,6 @@ app = FastAPI(lifespan=lifespan)
app.include_router(agents.router)
-@app.get("/{callsign}/tasks")
-async def get_tasks(callsign: str, token: Annotated[str, Depends(agents.oauth2_scheme)]):
- return f'{{"callsign": "{callsign}", "token": "{token}"}}'
+@app.get("/tasks")
+async def get_tasks(agent: Annotated[agents.Agent, Depends(agents.auth_agent)]):
+ return {"callsign": agent.agent_symbol}
diff --git a/apirouters/agents.py b/apirouters/agents.py
index 4ac2167..55825e6 100644
--- a/apirouters/agents.py
+++ b/apirouters/agents.py
@@ -2,7 +2,7 @@ from enum import Enum
from typing import Annotated, Dict, List
import openapi_client
-from fastapi import APIRouter, Depends, Response
+from fastapi import APIRouter, Depends, Response, HTTPException
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from ..modules import tasks
@@ -39,7 +39,7 @@ class AuthResult(Enum):
TOKEN_MISMATCH = 2
-def auth_agent(callsign: str, token: str) -> AuthResult:
+def check_init_agent(callsign: str, token: str) -> AuthResult:
cursor.execute("SELECT token from agents WHERE callsign = ?", (callsign,))
row = cursor.fetchone()
if row is None:
@@ -48,6 +48,15 @@ def auth_agent(callsign: str, token: str) -> AuthResult:
return AuthResult.SUCCESS if token == row[0] else AuthResult.TOKEN_MISMATCH
+async def auth_agent(token: Annotated[str, Depends(oauth2_scheme)]) -> Agent:
+ cursor.execute("SELECT callsign from agents WHERE token = ?", (token,))
+ row = cursor.fetchone()
+ if row is None:
+ raise HTTPException(status_code=400, detail="Authentication failed.")
+
+ return agents[row[0]]
+
+
def load_agents_from_database() -> None:
cursor.execute("SELECT callsign, token from agents")
invalid_agents = []
@@ -91,21 +100,17 @@ def load_agents_from_database() -> None:
agent.ships[row[0]] = missing_ship
-
-
-
-
-@router.post("/{callsign}/init", status_code=201)
+@router.post("/init/{callsign}", status_code=201)
async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme)], response: Response):
- result = auth_agent(callsign, token)
+ result = check_init_agent(callsign, token)
if result == AuthResult.SUCCESS:
response.status_code = 200
- return '{"result": "Agent already registered."}'
+ return {"result": "Agent already registered."}
if result == AuthResult.TOKEN_MISMATCH:
response.status_code = 400
- return '{"error": "Agent already registered but with a different token."}'
+ return {"error": "Agent already registered but with a different token."}
success = True
agent = None
@@ -121,8 +126,8 @@ async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme)
agents[callsign] = agent
else:
response.status_code = 400
- return '{"error": "Agent could not be authenticated on the SpaceTraders API."}'
+ return {"error": "Agent could not be authenticated on the SpaceTraders API."}
cursor.execute("INSERT INTO agents (callsign, token) VALUES (?, ?)", (callsign, token))
conn.commit()
- return '{"result": "Agent successfully registered."}'
+ return {"result": "Agent successfully registered."}