summaryrefslogtreecommitdiff
path: root/apirouters/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'apirouters/auth.py')
-rw-r--r--apirouters/auth.py120
1 files changed, 120 insertions, 0 deletions
diff --git a/apirouters/auth.py b/apirouters/auth.py
new file mode 100644
index 0000000..49d3740
--- /dev/null
+++ b/apirouters/auth.py
@@ -0,0 +1,120 @@
+from enum import Enum
+from typing import Annotated, Dict
+
+from fastapi import APIRouter, Depends, Response, HTTPException
+from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
+
+from ..modules.database import cursor, conn
+from ..entities import task_types
+from ..entities.agent import Agent
+from ..entities.ship import Ship
+
+router = APIRouter()
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+agents: Dict[str, Agent] = {}
+
+
+@router.post("/token")
+async def login(form_data: OAuth2PasswordRequestForm = Depends()):
+ passwd = form_data.password
+ return {"access_token": passwd, "token_type": "bearer"}
+
+
+class AuthResult(Enum):
+ SUCCESS = 0
+ NOT_FOUND = 1
+ TOKEN_MISMATCH = 2
+
+
+def check_init_agent(callsign: str, token: str) -> AuthResult:
+ cursor.execute("SELECT token from agents WHERE callsign = ?", (callsign,))
+ row = cursor.fetchone()
+ if row is None:
+ return AuthResult.NOT_FOUND
+
+ return AuthResult.SUCCESS if token == row[0] else AuthResult.TOKEN_MISMATCH
+
+
+async def auth_agent(token: Annotated[str, Depends(oauth2_scheme)]) -> Agent:
+ cursor.execute("SELECT callsign from agents WHERE token = ?", (token,))
+ row = cursor.fetchone()
+ if row is None:
+ raise HTTPException(status_code=400, detail="Authentication failed.")
+
+ return agents[row[0]]
+
+
+def load_agents_from_database() -> None:
+ cursor.execute("SELECT callsign, token from agents")
+ invalid_agents = []
+ for row in cursor.fetchall():
+ callsign = row[0]
+ token = row[1]
+ try:
+ agent = Agent(callsign, token)
+ if agent.agents_api.get_my_agent().data.symbol != callsign:
+ invalid_agents.append(callsign)
+ continue
+ agents[callsign] = agent
+ except:
+ invalid_agents.append(callsign)
+
+ if len(invalid_agents) > 0:
+ for invalid in invalid_agents:
+ cursor.execute("DELETE from agents WHERE callsign = ?", (invalid,))
+ conn.commit()
+
+ for agent in agents.values():
+ try:
+ response = agent.fleet_api.get_my_ships()
+ for ship_data in response.data:
+ agent.ships[ship_data.symbol] = Ship(ship_data.symbol)
+
+ for ship in agent.ships.values():
+ ship.load_task()
+ except:
+ for ship in agent.ships.values():
+ ship.set_task(task_types.ERROR)
+
+ cursor.execute("SELECT symbol FROM ships WHERE symbol LIKE ?", (f"{agent.agent_symbol}-%",))
+ ship_names = agent.ships.keys()
+ for row in cursor.fetchall():
+ if row[0] in ship_names:
+ continue
+
+ missing_ship = Ship(row[0])
+ missing_ship.set_task(task_types.MIA)
+ agent.ships[row[0]] = missing_ship
+
+
+@router.post("/init/{callsign}", status_code=201)
+async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme)], response: Response):
+ result = check_init_agent(callsign, token)
+
+ if result == AuthResult.SUCCESS:
+ response.status_code = 200
+ return {"result": "Agent already registered."}
+
+ if result == AuthResult.TOKEN_MISMATCH:
+ response.status_code = 400
+ return {"error": "Agent already registered but with a different token."}
+
+ success = True
+ agent = None
+
+ try:
+ agent = Agent(callsign, token)
+ if agent.agents_api.get_my_agent().data.symbol != callsign:
+ success = False
+ except:
+ success = False
+
+ if success:
+ agents[callsign] = agent
+ else:
+ response.status_code = 400
+ return {"error": "Agent could not be authenticated on the SpaceTraders API."}
+
+ cursor.execute("INSERT INTO agents (callsign, token) VALUES (?, ?)", (callsign, token))
+ conn.commit()
+ return {"result": "Agent successfully registered."}