From 5f67692c2424ce54c8dd754ef87d8d753759700e Mon Sep 17 00:00:00 2001 From: Botond Hende Date: Mon, 2 Sep 2024 00:30:16 +0200 Subject: renamed/refactored code --- __main__.py | 6 +-- apirouters/agents.py | 120 ------------------------------------------- apirouters/auth.py | 120 +++++++++++++++++++++++++++++++++++++++++++ apirouters/customize_ship.py | 2 +- apirouters/tasks.py | 5 +- modules/ships.py | 4 +- 6 files changed, 129 insertions(+), 128 deletions(-) delete mode 100644 apirouters/agents.py create mode 100644 apirouters/auth.py diff --git a/__main__.py b/__main__.py index 662d7f3..55fa740 100644 --- a/__main__.py +++ b/__main__.py @@ -2,7 +2,7 @@ from contextlib import asynccontextmanager from fastapi import FastAPI -from .apirouters import agents, tasks, customize_ship +from .apirouters import auth, tasks, customize_ship from .modules.database import cursor, conn @@ -13,12 +13,12 @@ async def lifespan(application: FastAPI): cursor.execute( "CREATE TABLE IF NOT EXISTS ships(primary_key INTEGER PRIMARY KEY, symbol TEXT NOT NULL UNIQUE, task TEXT NOT NULL, params TEXT, name TEXT)") conn.commit() - agents.load_agents_from_database() + auth.load_agents_from_database() yield app = FastAPI(lifespan=lifespan) -app.include_router(agents.router) +app.include_router(auth.router) app.include_router(tasks.router) app.include_router(customize_ship.router) diff --git a/apirouters/agents.py b/apirouters/agents.py deleted file mode 100644 index 49d3740..0000000 --- a/apirouters/agents.py +++ /dev/null @@ -1,120 +0,0 @@ -from enum import Enum -from typing import Annotated, Dict - -from fastapi import APIRouter, Depends, Response, HTTPException -from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm - -from ..modules.database import cursor, conn -from ..entities import task_types -from ..entities.agent import Agent -from ..entities.ship import Ship - -router = APIRouter() -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") -agents: Dict[str, Agent] = {} - - -@router.post("/token") -async def login(form_data: OAuth2PasswordRequestForm = Depends()): - passwd = form_data.password - return {"access_token": passwd, "token_type": "bearer"} - - -class AuthResult(Enum): - SUCCESS = 0 - NOT_FOUND = 1 - TOKEN_MISMATCH = 2 - - -def check_init_agent(callsign: str, token: str) -> AuthResult: - cursor.execute("SELECT token from agents WHERE callsign = ?", (callsign,)) - row = cursor.fetchone() - if row is None: - return AuthResult.NOT_FOUND - - return AuthResult.SUCCESS if token == row[0] else AuthResult.TOKEN_MISMATCH - - -async def auth_agent(token: Annotated[str, Depends(oauth2_scheme)]) -> Agent: - cursor.execute("SELECT callsign from agents WHERE token = ?", (token,)) - row = cursor.fetchone() - if row is None: - raise HTTPException(status_code=400, detail="Authentication failed.") - - return agents[row[0]] - - -def load_agents_from_database() -> None: - cursor.execute("SELECT callsign, token from agents") - invalid_agents = [] - for row in cursor.fetchall(): - callsign = row[0] - token = row[1] - try: - agent = Agent(callsign, token) - if agent.agents_api.get_my_agent().data.symbol != callsign: - invalid_agents.append(callsign) - continue - agents[callsign] = agent - except: - invalid_agents.append(callsign) - - if len(invalid_agents) > 0: - for invalid in invalid_agents: - cursor.execute("DELETE from agents WHERE callsign = ?", (invalid,)) - conn.commit() - - for agent in agents.values(): - try: - response = agent.fleet_api.get_my_ships() - for ship_data in response.data: - agent.ships[ship_data.symbol] = Ship(ship_data.symbol) - - for ship in agent.ships.values(): - ship.load_task() - except: - for ship in agent.ships.values(): - ship.set_task(task_types.ERROR) - - cursor.execute("SELECT symbol FROM ships WHERE symbol LIKE ?", (f"{agent.agent_symbol}-%",)) - ship_names = agent.ships.keys() - for row in cursor.fetchall(): - if row[0] in ship_names: - continue - - missing_ship = Ship(row[0]) - missing_ship.set_task(task_types.MIA) - agent.ships[row[0]] = missing_ship - - -@router.post("/init/{callsign}", status_code=201) -async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme)], response: Response): - result = check_init_agent(callsign, token) - - if result == AuthResult.SUCCESS: - response.status_code = 200 - return {"result": "Agent already registered."} - - if result == AuthResult.TOKEN_MISMATCH: - response.status_code = 400 - return {"error": "Agent already registered but with a different token."} - - success = True - agent = None - - try: - agent = Agent(callsign, token) - if agent.agents_api.get_my_agent().data.symbol != callsign: - success = False - except: - success = False - - if success: - agents[callsign] = agent - else: - response.status_code = 400 - return {"error": "Agent could not be authenticated on the SpaceTraders API."} - - cursor.execute("INSERT INTO agents (callsign, token) VALUES (?, ?)", (callsign, token)) - conn.commit() - return {"result": "Agent successfully registered."} diff --git a/apirouters/auth.py b/apirouters/auth.py new file mode 100644 index 0000000..49d3740 --- /dev/null +++ b/apirouters/auth.py @@ -0,0 +1,120 @@ +from enum import Enum +from typing import Annotated, Dict + +from fastapi import APIRouter, Depends, Response, HTTPException +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm + +from ..modules.database import cursor, conn +from ..entities import task_types +from ..entities.agent import Agent +from ..entities.ship import Ship + +router = APIRouter() +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") +agents: Dict[str, Agent] = {} + + +@router.post("/token") +async def login(form_data: OAuth2PasswordRequestForm = Depends()): + passwd = form_data.password + return {"access_token": passwd, "token_type": "bearer"} + + +class AuthResult(Enum): + SUCCESS = 0 + NOT_FOUND = 1 + TOKEN_MISMATCH = 2 + + +def check_init_agent(callsign: str, token: str) -> AuthResult: + cursor.execute("SELECT token from agents WHERE callsign = ?", (callsign,)) + row = cursor.fetchone() + if row is None: + return AuthResult.NOT_FOUND + + return AuthResult.SUCCESS if token == row[0] else AuthResult.TOKEN_MISMATCH + + +async def auth_agent(token: Annotated[str, Depends(oauth2_scheme)]) -> Agent: + cursor.execute("SELECT callsign from agents WHERE token = ?", (token,)) + row = cursor.fetchone() + if row is None: + raise HTTPException(status_code=400, detail="Authentication failed.") + + return agents[row[0]] + + +def load_agents_from_database() -> None: + cursor.execute("SELECT callsign, token from agents") + invalid_agents = [] + for row in cursor.fetchall(): + callsign = row[0] + token = row[1] + try: + agent = Agent(callsign, token) + if agent.agents_api.get_my_agent().data.symbol != callsign: + invalid_agents.append(callsign) + continue + agents[callsign] = agent + except: + invalid_agents.append(callsign) + + if len(invalid_agents) > 0: + for invalid in invalid_agents: + cursor.execute("DELETE from agents WHERE callsign = ?", (invalid,)) + conn.commit() + + for agent in agents.values(): + try: + response = agent.fleet_api.get_my_ships() + for ship_data in response.data: + agent.ships[ship_data.symbol] = Ship(ship_data.symbol) + + for ship in agent.ships.values(): + ship.load_task() + except: + for ship in agent.ships.values(): + ship.set_task(task_types.ERROR) + + cursor.execute("SELECT symbol FROM ships WHERE symbol LIKE ?", (f"{agent.agent_symbol}-%",)) + ship_names = agent.ships.keys() + for row in cursor.fetchall(): + if row[0] in ship_names: + continue + + missing_ship = Ship(row[0]) + missing_ship.set_task(task_types.MIA) + agent.ships[row[0]] = missing_ship + + +@router.post("/init/{callsign}", status_code=201) +async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme)], response: Response): + result = check_init_agent(callsign, token) + + if result == AuthResult.SUCCESS: + response.status_code = 200 + return {"result": "Agent already registered."} + + if result == AuthResult.TOKEN_MISMATCH: + response.status_code = 400 + return {"error": "Agent already registered but with a different token."} + + success = True + agent = None + + try: + agent = Agent(callsign, token) + if agent.agents_api.get_my_agent().data.symbol != callsign: + success = False + except: + success = False + + if success: + agents[callsign] = agent + else: + response.status_code = 400 + return {"error": "Agent could not be authenticated on the SpaceTraders API."} + + cursor.execute("INSERT INTO agents (callsign, token) VALUES (?, ?)", (callsign, token)) + conn.commit() + return {"result": "Agent successfully registered."} diff --git a/apirouters/customize_ship.py b/apirouters/customize_ship.py index ca3ef83..5f0666e 100644 --- a/apirouters/customize_ship.py +++ b/apirouters/customize_ship.py @@ -13,7 +13,7 @@ class RenameBody(BaseModel): @router.post("/customize_ship/{ship_symbol}/rename") -async def rename(ship_symbol: str, rename_body: RenameBody, ship: Annotated[Ship, Depends(ships.get_ship)]): +async def rename(rename_body: RenameBody, ship: Annotated[Ship, Depends(ships.get_ship)]): ship.rename(rename_body.name) return ship.get_data() diff --git a/apirouters/tasks.py b/apirouters/tasks.py index 684cf60..87fdb6e 100644 --- a/apirouters/tasks.py +++ b/apirouters/tasks.py @@ -2,16 +2,17 @@ from typing import Annotated from fastapi import APIRouter, Depends from pydantic import BaseModel -from . import agents +from . import auth from ..modules import ships from ..entities import task_types from ..entities.ship import Ship +from ..entities.agent import Agent router = APIRouter() @router.get("/tasks") -async def get_tasks(agent: Annotated[agents.Agent, Depends(agents.auth_agent)]): +async def get_tasks(agent: Annotated[Agent, Depends(auth.auth_agent)]): ret_list = [] for current_ship in agent.ships.values(): ret_list.append(current_ship.get_data()) diff --git a/modules/ships.py b/modules/ships.py index 9b0d4bd..4a6d0d6 100644 --- a/modules/ships.py +++ b/modules/ships.py @@ -2,12 +2,12 @@ from typing import Annotated from fastapi import Depends, HTTPException -from ..apirouters import agents +from ..apirouters import auth from ..entities.agent import Agent from ..entities.ship import Ship -async def get_ship(ship_symbol: str, agent: Annotated[Agent, Depends(agents.auth_agent)]) -> Ship: +async def get_ship(ship_symbol: str, agent: Annotated[Agent, Depends(auth.auth_agent)]) -> Ship: for ship in agent.ships.values(): if ship.symbol == ship_symbol: return ship -- cgit v1.2.3-70-g09d2