diff options
Diffstat (limited to 'apirouters/auth.py')
-rw-r--r-- | apirouters/auth.py | 120 |
1 files changed, 120 insertions, 0 deletions
diff --git a/apirouters/auth.py b/apirouters/auth.py new file mode 100644 index 0000000..49d3740 --- /dev/null +++ b/apirouters/auth.py @@ -0,0 +1,120 @@ +from enum import Enum +from typing import Annotated, Dict + +from fastapi import APIRouter, Depends, Response, HTTPException +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm + +from ..modules.database import cursor, conn +from ..entities import task_types +from ..entities.agent import Agent +from ..entities.ship import Ship + +router = APIRouter() +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") +agents: Dict[str, Agent] = {} + + +@router.post("/token") +async def login(form_data: OAuth2PasswordRequestForm = Depends()): + passwd = form_data.password + return {"access_token": passwd, "token_type": "bearer"} + + +class AuthResult(Enum): + SUCCESS = 0 + NOT_FOUND = 1 + TOKEN_MISMATCH = 2 + + +def check_init_agent(callsign: str, token: str) -> AuthResult: + cursor.execute("SELECT token from agents WHERE callsign = ?", (callsign,)) + row = cursor.fetchone() + if row is None: + return AuthResult.NOT_FOUND + + return AuthResult.SUCCESS if token == row[0] else AuthResult.TOKEN_MISMATCH + + +async def auth_agent(token: Annotated[str, Depends(oauth2_scheme)]) -> Agent: + cursor.execute("SELECT callsign from agents WHERE token = ?", (token,)) + row = cursor.fetchone() + if row is None: + raise HTTPException(status_code=400, detail="Authentication failed.") + + return agents[row[0]] + + +def load_agents_from_database() -> None: + cursor.execute("SELECT callsign, token from agents") + invalid_agents = [] + for row in cursor.fetchall(): + callsign = row[0] + token = row[1] + try: + agent = Agent(callsign, token) + if agent.agents_api.get_my_agent().data.symbol != callsign: + invalid_agents.append(callsign) + continue + agents[callsign] = agent + except: + invalid_agents.append(callsign) + + if len(invalid_agents) > 0: + for invalid in invalid_agents: + cursor.execute("DELETE from agents WHERE callsign = ?", (invalid,)) + conn.commit() + + for agent in agents.values(): + try: + response = agent.fleet_api.get_my_ships() + for ship_data in response.data: + agent.ships[ship_data.symbol] = Ship(ship_data.symbol) + + for ship in agent.ships.values(): + ship.load_task() + except: + for ship in agent.ships.values(): + ship.set_task(task_types.ERROR) + + cursor.execute("SELECT symbol FROM ships WHERE symbol LIKE ?", (f"{agent.agent_symbol}-%",)) + ship_names = agent.ships.keys() + for row in cursor.fetchall(): + if row[0] in ship_names: + continue + + missing_ship = Ship(row[0]) + missing_ship.set_task(task_types.MIA) + agent.ships[row[0]] = missing_ship + + +@router.post("/init/{callsign}", status_code=201) +async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme)], response: Response): + result = check_init_agent(callsign, token) + + if result == AuthResult.SUCCESS: + response.status_code = 200 + return {"result": "Agent already registered."} + + if result == AuthResult.TOKEN_MISMATCH: + response.status_code = 400 + return {"error": "Agent already registered but with a different token."} + + success = True + agent = None + + try: + agent = Agent(callsign, token) + if agent.agents_api.get_my_agent().data.symbol != callsign: + success = False + except: + success = False + + if success: + agents[callsign] = agent + else: + response.status_code = 400 + return {"error": "Agent could not be authenticated on the SpaceTraders API."} + + cursor.execute("INSERT INTO agents (callsign, token) VALUES (?, ?)", (callsign, token)) + conn.commit() + return {"result": "Agent successfully registered."} |