import logging
from enum import Enum
from typing import Annotated, Dict, Optional

import openapi_client
from fastapi import APIRouter, Depends, Response, HTTPException
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm

from ..modules import task_types
from ..modules.database import cursor, conn
from ..modules.ships import Ship

logger = logging.getLogger(__name__)

router = APIRouter()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# In-memory registry of verified agents, keyed by callsign.
agents: Dict[str, 'Agent'] = {}


class Agent:
    """SpaceTraders API clients bound to one agent's access token, plus its ships."""

    def __init__(self, agent_symbol, token):
        config = openapi_client.Configuration(access_token=token)
        self.api_client = openapi_client.ApiClient(config)
        self.agents_api = openapi_client.AgentsApi(self.api_client)
        self.fleet_api = openapi_client.FleetApi(self.api_client)
        self.systems_api = openapi_client.SystemsApi(self.api_client)
        self.agent_symbol = agent_symbol
        # Ship objects keyed by ship symbol; filled by load_agents_from_database.
        self.ships: Dict[str, Ship] = {}


@router.post("/token")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """OAuth2 password-flow token endpoint.

    The submitted password is returned unchanged as the bearer access token and
    the username is ignored. No validation happens in this endpoint.
    """
    passwd = form_data.password
    return {"access_token": passwd, "token_type": "bearer"}


class AuthResult(Enum):
    """Outcome of comparing a callsign/token pair with the ``agents`` table."""
    SUCCESS = 0
    NOT_FOUND = 1
    TOKEN_MISMATCH = 2


def check_init_agent(callsign: str, token: str) -> AuthResult:
    """Compare *token* with the token stored for *callsign* in the ``agents`` table.

    Returns NOT_FOUND when no row exists for *callsign*, SUCCESS when the stored
    token equals *token*, and TOKEN_MISMATCH otherwise.
    """
    cursor.execute("SELECT token from agents WHERE callsign = ?", (callsign,))
    row = cursor.fetchone()
    if row is None:
        return AuthResult.NOT_FOUND
    return AuthResult.SUCCESS if token == row[0] else AuthResult.TOKEN_MISMATCH


async def auth_agent(token: Annotated[str, Depends(oauth2_scheme)]) -> Agent:
    """FastAPI dependency resolving a bearer token to its loaded ``Agent``.

    Raises:
        HTTPException(400): the token is not in the ``agents`` table, or its
            callsign has no loaded in-memory ``Agent``. Previously the latter
            case raised KeyError, which surfaced as a 500.
    """
    cursor.execute("SELECT callsign from agents WHERE token = ?", (token,))
    row = cursor.fetchone()
    agent = agents.get(row[0]) if row is not None else None
    if agent is None:
        raise HTTPException(status_code=400, detail="Authentication failed.")
    return agent


def _verify_agent(callsign: str, token: str) -> Optional[Agent]:
    """Build an ``Agent`` for *token* and check the API reports *callsign* as its symbol.

    Returns the Agent on success. Returns None if constructing it or calling
    ``get_my_agent`` raises, or if the reported symbol differs from *callsign*.
    """
    try:
        agent = Agent(callsign, token)
        symbol = agent.agents_api.get_my_agent().data.symbol
    except Exception:
        logger.warning("Could not verify agent %s against the SpaceTraders API",
                       callsign, exc_info=True)
        return None
    return agent if symbol == callsign else None


def _load_agent_ships(agent: Agent) -> None:
    """Fill ``agent.ships`` from the API, then add DB-only ships marked as MIA.

    If fetching ships or restoring a task raises, every ship loaded so far is
    marked ``task_types.ERROR``.
    """
    try:
        response = agent.fleet_api.get_my_ships()
        for ship_data in response.data:
            agent.ships[ship_data.symbol] = Ship(ship_data.symbol)
        for ship in agent.ships.values():
            ship.load_task()
    except Exception:
        logger.exception("Could not load ships for agent %s", agent.agent_symbol)
        for ship in agent.ships.values():
            ship.set_task(task_types.ERROR)

    # Ships known to the database but not reported by the API are "missing in action".
    # NOTE(review): '_' and '%' in the agent symbol act as LIKE wildcards here.
    cursor.execute("SELECT symbol FROM ships WHERE symbol LIKE ?", (f"{agent.agent_symbol}-%",))
    for (symbol,) in cursor.fetchall():  # fetchall() materializes rows before Ship may reuse the cursor
        if symbol in agent.ships:
            continue
        missing_ship = Ship(symbol)
        missing_ship.set_task(task_types.MIA)
        agent.ships[symbol] = missing_ship


def load_agents_from_database() -> None:
    """Populate ``agents`` from the ``agents`` table and load each agent's ships.

    Every stored (callsign, token) pair is verified with ``_verify_agent``.
    Pairs that fail verification are deleted from the database. Then ships are
    loaded for every agent in ``agents``.

    NOTE(review): any verification failure deletes the stored row, including a
    transient network/API error. Confirm this is intended.
    """
    cursor.execute("SELECT callsign, token from agents")
    invalid_agents = []
    for callsign, token in cursor.fetchall():
        agent = _verify_agent(callsign, token)
        if agent is None:
            invalid_agents.append(callsign)
        else:
            agents[callsign] = agent

    if invalid_agents:
        cursor.executemany("DELETE from agents WHERE callsign = ?",
                           [(invalid,) for invalid in invalid_agents])
        conn.commit()

    for agent in agents.values():
        _load_agent_ships(agent)


@router.post("/init/{callsign}", status_code=201)
async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme)], response: Response):
    """Register *callsign* with *token* after verifying it on the SpaceTraders API.

    Responses:
        200: the callsign is already stored with this token.
        400: the callsign is stored with a different token, or verification failed.
        201: the agent was verified, stored in the database and loaded in memory.

    NOTE(review): the API and database calls below are synchronous inside an
    ``async def``, so they block the event loop while they run.
    """
    result = check_init_agent(callsign, token)
    if result == AuthResult.SUCCESS:
        response.status_code = 200
        return {"result": "Agent already registered."}
    if result == AuthResult.TOKEN_MISMATCH:
        response.status_code = 400
        return {"error": "Agent already registered but with a different token."}

    agent = _verify_agent(callsign, token)
    if agent is None:
        response.status_code = 400
        return {"error": "Agent could not be authenticated on the SpaceTraders API."}

    agents[callsign] = agent
    cursor.execute("INSERT INTO agents (callsign, token) VALUES (?, ?)", (callsign, token))
    conn.commit()
    return {"result": "Agent successfully registered."}