import logging
import secrets
from enum import Enum
from typing import Annotated, Dict, List, Optional

import openapi_client
from fastapi import APIRouter, Depends, Response
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm

from ..modules import tasks
from ..modules.ships import Ship
from ..modules.database import cursor, conn

logger = logging.getLogger(__name__)

router = APIRouter()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# In-memory registry of agents that passed verification, keyed by callsign.
agents: Dict[str, 'Agent'] = {}


class Agent:
    """A SpaceTraders agent plus generated API clients bound to its token."""

    def __init__(self, agent_symbol, token):
        config = openapi_client.Configuration(access_token=token)
        self.api_client = openapi_client.ApiClient(config)
        self.agents_api = openapi_client.AgentsApi(self.api_client)
        self.fleet_api = openapi_client.FleetApi(self.api_client)
        self.systems_api = openapi_client.SystemsApi(self.api_client)
        self.agent_symbol = agent_symbol
        # Ships of this agent, keyed by ship symbol.
        self.ships: Dict[str, Ship] = {}


@router.post("/token")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """OAuth2 password-flow endpoint used by ``oauth2_scheme``.

    The submitted password is returned unchanged as the bearer token; no
    validation happens here. Endpoints check the token themselves (see
    ``auth_agent``).
    """
    passwd = form_data.password
    return {"access_token": passwd, "token_type": "bearer"}


class AuthResult(Enum):
    """Outcome of checking a callsign/token pair against the ``agents`` table."""
    SUCCESS = 0
    NOT_FOUND = 1
    TOKEN_MISMATCH = 2


def auth_agent(callsign: str, token: str) -> AuthResult:
    """Check ``token`` against the token stored for ``callsign``.

    Returns:
        NOT_FOUND if no row exists for the callsign, SUCCESS if the stored
        token equals ``token``, TOKEN_MISMATCH otherwise.
    """
    cursor.execute("SELECT token from agents WHERE callsign = ?", (callsign,))
    row = cursor.fetchone()
    if row is None:
        return AuthResult.NOT_FOUND
    stored_token = row[0]
    # A non-str stored value (e.g. NULL) can never equal the str token.
    if not isinstance(stored_token, str) or not isinstance(token, str):
        return AuthResult.TOKEN_MISMATCH
    # Constant-time comparison so response timing does not leak how much of
    # the token matched. Compare UTF-8 bytes: compare_digest raises TypeError
    # on str arguments containing non-ASCII characters.
    if secrets.compare_digest(token.encode("utf-8"), stored_token.encode("utf-8")):
        return AuthResult.SUCCESS
    return AuthResult.TOKEN_MISMATCH


def _connect_agent(callsign: str, token: str) -> Optional[Agent]:
    """Build an Agent for ``token`` and confirm it belongs to ``callsign``.

    Returns the Agent on success, or None if the API call fails or reports a
    different agent symbol.
    """
    try:
        agent = Agent(callsign, token)
        symbol = agent.agents_api.get_my_agent().data.symbol
    except Exception:
        logger.exception("Could not authenticate agent %s on the SpaceTraders API", callsign)
        return None
    if symbol != callsign:
        logger.warning("Token for %s belongs to a different agent (%s)", callsign, symbol)
        return None
    return agent


def _like_prefix_pattern(prefix: str) -> str:
    """Return a LIKE pattern matching strings that start with ``prefix`` literally.

    ``%``, ``_`` and the escape character itself are escaped with a
    backslash, so the query must use ``ESCAPE '\\'``.
    """
    escaped = prefix.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    return f"{escaped}%"


def _load_ships(agent: Agent) -> None:
    """Populate ``agent.ships`` from the API, then add DB-known ships as MIA."""
    # NOTE(review): the SpaceTraders /my/ships endpoint is paginated; only the
    # first page is fetched here, so ships beyond it end up marked MIA below.
    # Confirm the client's page/limit parameters before paging.
    try:
        response = agent.fleet_api.get_my_ships()
        for ship_data in response.data:
            agent.ships[ship_data.symbol] = Ship(ship_data.symbol)
        for ship in agent.ships.values():
            ship.load_task()
    except Exception:
        logger.exception("Failed to load ships for agent %s", agent.agent_symbol)
        # NOTE(review): if get_my_ships itself failed, agent.ships is still
        # empty here, so every DB-known ship is marked MIA rather than ERROR.
        for ship in agent.ships.values():
            ship.set_task(tasks.ERROR)

    # Ship symbols are prefixed with "<agent symbol>-". Escape LIKE wildcards
    # so a callsign containing "_" or "%" cannot match other agents' ships.
    # NOTE(review): if the backend is SQLite, LIKE is ASCII case-insensitive
    # by default.
    cursor.execute(
        "SELECT symbol FROM ships WHERE symbol LIKE ? ESCAPE '\\'",
        (_like_prefix_pattern(f"{agent.agent_symbol}-"),),
    )
    for row in cursor.fetchall():
        symbol = row[0]
        if symbol in agent.ships:
            continue
        missing_ship = Ship(symbol)
        missing_ship.set_task(tasks.MIA)
        agent.ships[symbol] = missing_ship


def load_agents_from_database() -> None:
    """Load agents from the DB, drop ones that fail verification, load their ships.

    Every stored agent is verified against the SpaceTraders API. Agents that
    fail are deleted from the ``agents`` table. Verified agents are put in the
    module-level ``agents`` dict and their ships are loaded.
    """
    cursor.execute("SELECT callsign, token from agents")
    invalid_agents: List[str] = []
    for row in cursor.fetchall():
        callsign = row[0]
        token = row[1]
        agent = _connect_agent(callsign, token)
        if agent is None:
            invalid_agents.append(callsign)
        else:
            agents[callsign] = agent

    if invalid_agents:
        # NOTE(review): any failure, including a transient network error at
        # startup, causes the agent row to be deleted. Consider deleting only
        # on an explicit authentication failure from the API.
        cursor.executemany(
            "DELETE from agents WHERE callsign = ?",
            [(invalid,) for invalid in invalid_agents],
        )
        conn.commit()

    for agent in agents.values():
        _load_ships(agent)


@router.post("/{callsign}/init", status_code=201)
async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme)], response: Response):
    """Register ``callsign`` with the bearer ``token``.

    Responses:
        200 {"result": ...} if already registered with this token.
        400 {"error": ...} if registered with a different token, or if the
            token cannot be verified against the SpaceTraders API.
        201 {"result": ...} after a successful registration.
    """
    result = auth_agent(callsign, token)
    if result is AuthResult.SUCCESS:
        response.status_code = 200
        return {"result": "Agent already registered."}
    if result is AuthResult.TOKEN_MISMATCH:
        response.status_code = 400
        return {"error": "Agent already registered but with a different token."}

    # NOTE(review): the client calls below are synchronous and block the event
    # loop while they run.
    agent = _connect_agent(callsign, token)
    if agent is None:
        response.status_code = 400
        return {"error": "Agent could not be authenticated on the SpaceTraders API."}

    # Persist first so the in-memory registry never holds an agent the DB lacks.
    cursor.execute("INSERT INTO agents (callsign, token) VALUES (?, ?)", (callsign, token))
    conn.commit()
    agents[callsign] = agent
    return {"result": "Agent successfully registered."}