summaryrefslogtreecommitdiff
path: root/apirouters/auth.py
blob: 49d374090d384c058cefbfd34881c123fec30efa (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import logging
from enum import Enum
from typing import Annotated, Dict

from fastapi import APIRouter, Depends, Response, HTTPException
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm

from ..modules.database import cursor, conn
from ..entities import task_types
from ..entities.agent import Agent
from ..entities.ship import Ship

# Router that carries the authentication endpoints of this module.
router = APIRouter()

# Extracts the bearer token from the Authorization header; clients are pointed
# at the relative "token" URL to obtain one.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# In-memory registry of loaded agents, keyed by callsign.
agents: Dict[str, Agent] = dict()


@router.post("/token")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Answer the OAuth2 password flow by echoing the password as the token.

    No credential check happens here and the submitted username is ignored:
    whatever was sent as the password is returned verbatim as the bearer
    access token. Presumably clients send their SpaceTraders token as the
    password; it is only validated by the endpoints that consume it.
    """
    access_token = form_data.password  # returned unchanged
    return dict(access_token=access_token, token_type="bearer")


class AuthResult(Enum):
    """Outcome of checking a callsign/token pair against the agents table."""

    SUCCESS = 0  # callsign is stored and its token matches
    NOT_FOUND = 1  # no stored row for the callsign
    TOKEN_MISMATCH = 2  # callsign is stored but with a different token


def check_init_agent(callsign: str, token: str) -> AuthResult:
    """Compare *token* with the token stored for *callsign* in the agents table.

    Returns:
        AuthResult.NOT_FOUND when no row exists for the callsign,
        AuthResult.SUCCESS when the stored token equals *token*,
        AuthResult.TOKEN_MISMATCH otherwise.
    """
    cursor.execute("SELECT token from agents WHERE callsign = ?", (callsign,))
    stored = cursor.fetchone()

    if stored is None:
        return AuthResult.NOT_FOUND
    if token == stored[0]:
        return AuthResult.SUCCESS
    return AuthResult.TOKEN_MISMATCH


async def auth_agent(token: Annotated[str, Depends(oauth2_scheme)]) -> Agent:
    """Resolve the request's bearer token to a loaded Agent.

    The token is looked up in the agents table, and the matching callsign is
    used to fetch the in-memory Agent from ``agents``.

    Raises:
        HTTPException: status 400 when no stored agent has this token, or when
            the callsign is stored but no Agent is loaded in memory for it.
    """
    cursor.execute("SELECT callsign from agents WHERE token = ?", (token,))
    row = cursor.fetchone()

    # A stored row without a loaded Agent (e.g. added to the database after
    # startup loading) used to raise an unhandled KeyError, i.e. a 500 response.
    agent = agents.get(row[0]) if row is not None else None
    if agent is None:
        raise HTTPException(status_code=400, detail="Authentication failed.")

    return agent


def load_agents_from_database() -> None:
    """Populate ``agents`` from the agents table and attach each agent's ships.

    Each stored (callsign, token) pair is checked via ``get_my_agent``. Rows
    whose reported symbol differs from the callsign, or whose check raises,
    are deleted from the agents table (and committed).

    Then, for every agent in ``agents``:
    - ships returned by ``get_my_ships`` are registered and their tasks loaded;
      if that fails, every ship currently in ``agent.ships`` is set to
      ``task_types.ERROR``;
    - ships found in the ships table under the "<agent symbol>-" prefix that
      are not yet in ``agent.ships`` are added with task ``task_types.MIA``.
    """
    invalid_agents = _validate_stored_agents()
    if invalid_agents:
        for invalid in invalid_agents:
            cursor.execute("DELETE from agents WHERE callsign = ?", (invalid,))
        conn.commit()

    for agent in agents.values():
        _load_reported_ships(agent)
        _add_missing_ships(agent)


def _validate_stored_agents() -> list[str]:
    """Load every verifiable stored agent into ``agents``; return the callsigns that failed."""
    cursor.execute("SELECT callsign, token from agents")
    invalid_agents = []
    for callsign, token in cursor.fetchall():
        try:
            agent = Agent(callsign, token)
            symbol = agent.agents_api.get_my_agent().data.symbol
        except Exception:
            # NOTE(review): any error here, including what may be a transient
            # network failure, gets the stored agent deleted by the caller.
            logging.getLogger(__name__).warning(
                "Could not verify stored agent %s", callsign, exc_info=True
            )
            invalid_agents.append(callsign)
            continue

        if symbol != callsign:
            invalid_agents.append(callsign)
            continue
        agents[callsign] = agent
    return invalid_agents


def _load_reported_ships(agent: Agent) -> None:
    """Register the ships returned by ``get_my_ships`` for *agent* and load their tasks."""
    try:
        # NOTE(review): the SpaceTraders fleet endpoint is paginated; if this
        # call only returns the first page, the remaining ships will end up
        # flagged MIA by _add_missing_ships — confirm.
        response = agent.fleet_api.get_my_ships()
        for ship_data in response.data:
            agent.ships[ship_data.symbol] = Ship(ship_data.symbol)

        for ship in agent.ships.values():
            ship.load_task()
    except Exception:
        logging.getLogger(__name__).warning(
            "Could not load ships for agent %s", agent.agent_symbol, exc_info=True
        )
        for ship in agent.ships.values():
            ship.set_task(task_types.ERROR)


def _add_missing_ships(agent: Agent) -> None:
    """Add ships stored under *agent*'s symbol prefix but not in ``agent.ships``, marked MIA."""
    # Escape LIKE wildcards so that a "_" or "%" in the agent symbol matches
    # literally instead of pulling in other agents' ships.
    escaped_symbol = (
        str(agent.agent_symbol)
        .replace("\\", "\\\\")
        .replace("%", "\\%")
        .replace("_", "\\_")
    )
    cursor.execute(
        "SELECT symbol FROM ships WHERE symbol LIKE ? ESCAPE '\\'",
        (f"{escaped_symbol}-%",),
    )
    for (symbol,) in cursor.fetchall():
        if symbol in agent.ships:
            continue

        missing_ship = Ship(symbol)
        missing_ship.set_task(task_types.MIA)
        agent.ships[symbol] = missing_ship


@router.post("/init/{callsign}", status_code=201)
async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme)], response: Response):
    """Register *callsign* with the request's bearer token.

    Responses:
        200 when the callsign is already stored with this token;
        400 when it is stored with a different token, or when the token cannot
            be verified against the SpaceTraders API for this callsign;
        201 (route default) after the agent is verified, cached in ``agents``
            and inserted into the agents table.
    """
    result = check_init_agent(callsign, token)

    if result == AuthResult.SUCCESS:
        response.status_code = 200
        return {"result": "Agent already registered."}

    if result == AuthResult.TOKEN_MISMATCH:
        response.status_code = 400
        return {"error": "Agent already registered but with a different token."}

    # NOTE(review): these calls are not awaited, so they run synchronously on
    # the event loop (presumably blocking network I/O).
    try:
        agent = Agent(callsign, token)
        verified = agent.agents_api.get_my_agent().data.symbol == callsign
    except Exception:  # any failure to construct or query the agent
        verified = False

    if not verified:
        response.status_code = 400
        return {"error": "Agent could not be authenticated on the SpaceTraders API."}

    agents[callsign] = agent
    cursor.execute("INSERT INTO agents (callsign, token) VALUES (?, ?)", (callsign, token))
    conn.commit()
    return {"result": "Agent successfully registered."}