From 21019bc8414e60495573807b0fe0eb562a0e8876 Mon Sep 17 00:00:00 2001 From: Botond Hende Date: Sat, 31 Aug 2024 14:27:35 +0200 Subject: agent auth dependency --- apirouters/agents.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) (limited to 'apirouters/agents.py') diff --git a/apirouters/agents.py b/apirouters/agents.py index 4ac2167..55825e6 100644 --- a/apirouters/agents.py +++ b/apirouters/agents.py @@ -2,7 +2,7 @@ from enum import Enum from typing import Annotated, Dict, List import openapi_client -from fastapi import APIRouter, Depends, Response +from fastapi import APIRouter, Depends, Response, HTTPException from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from ..modules import tasks @@ -39,7 +39,7 @@ class AuthResult(Enum): TOKEN_MISMATCH = 2 -def auth_agent(callsign: str, token: str) -> AuthResult: +def check_init_agent(callsign: str, token: str) -> AuthResult: cursor.execute("SELECT token from agents WHERE callsign = ?", (callsign,)) row = cursor.fetchone() if row is None: @@ -48,6 +48,15 @@ def auth_agent(callsign: str, token: str) -> AuthResult: return AuthResult.SUCCESS if token == row[0] else AuthResult.TOKEN_MISMATCH +async def auth_agent(token: Annotated[str, Depends(oauth2_scheme)]) -> Agent: + cursor.execute("SELECT callsign from agents WHERE token = ?", (token,)) + row = cursor.fetchone() + if row is None: + raise HTTPException(status_code=400, detail="Authentication failed.") + + return agents[row[0]] + + def load_agents_from_database() -> None: cursor.execute("SELECT callsign, token from agents") invalid_agents = [] @@ -91,21 +100,17 @@ def load_agents_from_database() -> None: agent.ships[row[0]] = missing_ship - - - - -@router.post("/{callsign}/init", status_code=201) +@router.post("/init/{callsign}", status_code=201) async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme)], response: Response): - result = auth_agent(callsign, token) + result = check_init_agent(callsign, token) if result == AuthResult.SUCCESS: response.status_code = 200 - return '{"result": "Agent already registered."}' + return {"result": "Agent already registered."} if result == AuthResult.TOKEN_MISMATCH: response.status_code = 400 - return '{"error": "Agent already registered but with a different token."}' + return {"error": "Agent already registered but with a different token."} success = True agent = None @@ -121,8 +126,8 @@ async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme) agents[callsign] = agent else: response.status_code = 400 - return '{"error": "Agent could not be authenticated on the SpaceTraders API."}' + return {"error": "Agent could not be authenticated on the SpaceTraders API."} cursor.execute("INSERT INTO agents (callsign, token) VALUES (?, ?)", (callsign, token)) conn.commit() - return '{"result": "Agent successfully registered."}' + return {"result": "Agent successfully registered."} -- cgit v1.2.3-70-g09d2