diff options
Diffstat (limited to 'apirouters/agents.py')
-rw-r--r-- | apirouters/agents.py | 68 |
1 files changed, 57 insertions, 11 deletions
diff --git a/apirouters/agents.py b/apirouters/agents.py index b16ab6d..89f736f 100644 --- a/apirouters/agents.py +++ b/apirouters/agents.py @@ -1,21 +1,33 @@ from enum import Enum from typing import Annotated +import openapi_client from fastapi import APIRouter, Depends, Response from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm -from passlib.context import CryptContext from ..modules.database import cursor, conn router = APIRouter() -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") +agents = {} + + +class Agent: + def __init__(self, agent_symbol, token): + config = openapi_client.Configuration(access_token=token) + self.api_client = openapi_client.ApiClient(config) + + self.agents_api = openapi_client.AgentsApi(self.api_client) + self.fleet_api = openapi_client.FleetApi(self.api_client) + self.systems_api = openapi_client.SystemsApi(self.api_client) + + self.agent_symbol = agent_symbol @router.post("/token") async def login(form_data: OAuth2PasswordRequestForm = Depends()): - user = form_data.username - return {"access_token": user, "token_type": "bearer"} + passwd = form_data.password + return {"access_token": passwd, "token_type": "bearer"} class AuthResult(Enum): @@ -24,18 +36,39 @@ class AuthResult(Enum): TOKEN_MISMATCH = 2 -async def auth_agent(callsign: str, token: str) -> AuthResult: - cursor.execute("SELECT token_hash from agents WHERE callsign = ?", (callsign,)) +def auth_agent(callsign: str, token: str) -> AuthResult: + cursor.execute("SELECT token from agents WHERE callsign = ?", (callsign,)) row = cursor.fetchone() if row is None: return AuthResult.NOT_FOUND - return AuthResult.SUCCESS if pwd_context.verify(token, row[0]) else AuthResult.TOKEN_MISMATCH + return AuthResult.SUCCESS if token == row[0] else AuthResult.TOKEN_MISMATCH + + +def load_agents_from_database() -> None: + cursor.execute("SELECT callsign, token from agents") + invalid_agents = [] + for row in cursor.fetchall(): + callsign = row[0] + token = row[1] + try: + agent = Agent(callsign, token) + if agent.agents_api.get_my_agent().data.symbol != callsign: + invalid_agents.append(callsign) + continue + agents[callsign] = agent + except: + invalid_agents.append(callsign) + + if len(invalid_agents) > 0: + for invalid in invalid_agents: + cursor.execute("DELETE from agents WHERE callsign = ?", (invalid,)) + conn.commit() @router.post("/{callsign}/init", status_code=201) async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme)], response: Response): - result = await auth_agent(callsign, token) + result = auth_agent(callsign, token) if result == AuthResult.SUCCESS: response.status_code = 200 @@ -45,9 +78,22 @@ async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme) response.status_code = 400 return '{"error": "Agent already registered but with a different token."}' - # TODO: test token on spacetraders api + success = True + agent = None + + try: + agent = Agent(callsign, token) + if agent.agents_api.get_my_agent().data.symbol != callsign: + success = False + except: + success = False + + if success: + agents[callsign] = agent + else: + response.status_code = 400 + return '{"error": "Agent could not be authenticated on the SpaceTraders API."}' - token_hash = pwd_context.hash(token) - cursor.execute("INSERT INTO agents (callsign, token_hash) VALUES (?, ?)", (callsign, token_hash)) + cursor.execute("INSERT INTO agents (callsign, token) VALUES (?, ?)", (callsign, token)) conn.commit() return '{"result": "Agent successfully registered."}' |