summaryrefslogtreecommitdiff
path: root/apirouters/agents.py
blob: 55825e6b6740735cd90d70bad80a1d5305642e53 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import logging
from enum import Enum
from typing import Annotated, Dict, List

import openapi_client
from fastapi import APIRouter, Depends, Response, HTTPException
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm

from ..modules import tasks
from ..modules.ships import Ship
from ..modules.database import cursor, conn

# Router collecting the endpoints defined in this module.
router = APIRouter()

# Bearer-token extractor; tokenUrl points at the /token endpoint below, which
# echoes the submitted password back as the access token.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# Agents that passed SpaceTraders verification, keyed by callsign.
agents: Dict[str, "Agent"] = dict()


class Agent:
    def __init__(self, agent_symbol, token):
        config = openapi_client.Configuration(access_token=token)
        self.api_client = openapi_client.ApiClient(config)

        self.agents_api = openapi_client.AgentsApi(self.api_client)
        self.fleet_api = openapi_client.FleetApi(self.api_client)
        self.systems_api = openapi_client.SystemsApi(self.api_client)

        self.agent_symbol = agent_symbol
        self.ships: Dict[str, Ship] = {}


@router.post("/token")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """OAuth2 password-flow token endpoint.

    The submitted password is returned verbatim as the bearer access token;
    the username is not read. Subsequent requests are authenticated by
    looking that token up in the ``agents`` table.
    """
    return dict(access_token=form_data.password, token_type="bearer")


class AuthResult(Enum):
    """Outcome of checking a (callsign, token) pair against stored agents."""

    # The callsign is registered and its stored token matches.
    SUCCESS = 0
    # No agent with this callsign is registered.
    NOT_FOUND = 1
    # The callsign is registered, but under a different token.
    TOKEN_MISMATCH = 2


def check_init_agent(callsign: str, token: str) -> AuthResult:
    """Compare ``token`` with the token stored for ``callsign``.

    Returns:
        ``AuthResult.NOT_FOUND`` if the ``agents`` table has no row for the
        callsign, ``AuthResult.SUCCESS`` if the stored token equals
        ``token``, otherwise ``AuthResult.TOKEN_MISMATCH``.
    """
    cursor.execute("SELECT token from agents WHERE callsign = ?", (callsign,))
    match = cursor.fetchone()

    if match is None:
        return AuthResult.NOT_FOUND
    if token == match[0]:
        return AuthResult.SUCCESS
    return AuthResult.TOKEN_MISMATCH


async def auth_agent(token: Annotated[str, Depends(oauth2_scheme)]) -> Agent:
    """Resolve the request's bearer token to its registered, loaded ``Agent``.

    The token is extracted by ``oauth2_scheme`` and looked up in the
    ``agents`` table; the matching callsign is then resolved against the
    in-memory ``agents`` registry.

    Raises:
        HTTPException: status 400 if the token is not registered, or if the
            registered agent is not currently loaded in memory (instead of
            letting a ``KeyError`` surface as an HTTP 500).
    """
    cursor.execute("SELECT callsign from agents WHERE token = ?", (token,))
    row = cursor.fetchone()
    agent = agents.get(row[0]) if row is not None else None
    if agent is None:
        # NOTE(review): 401 with a "WWW-Authenticate: Bearer" header is the
        # conventional status for a bad bearer token; 400 is kept so existing
        # clients see unchanged behaviour.
        raise HTTPException(status_code=400, detail="Authentication failed.")
    return agent


def load_agents_from_database() -> None:
    """Restore registered agents and their ships from the database.

    Every stored (callsign, token) pair is re-verified against the
    SpaceTraders API. An agent is deleted from the ``agents`` table only when
    the API positively rejects it: the token belongs to a different symbol,
    or the request fails with HTTP 401. Any other failure (network error,
    server error, client construction error) is logged and the agent is
    skipped for this run but kept in the database, so a transient outage no
    longer permanently drops valid registrations.

    Each successfully verified agent is added to ``agents`` and its ships are
    loaded via ``_load_agent_ships``.
    """
    log = logging.getLogger(__name__)
    cursor.execute("SELECT callsign, token from agents")
    invalid_agents = []

    for callsign, token in cursor.fetchall():
        try:
            agent = Agent(callsign, token)
            symbol = agent.agents_api.get_my_agent().data.symbol
        except Exception as exc:
            # Assumes the generated client's ApiException exposes the HTTP
            # ``status`` (openapi-generator convention) -- confirm. Only an
            # explicit 401 marks the stored token as invalid.
            if getattr(exc, "status", None) == 401:
                invalid_agents.append(callsign)
            else:
                log.warning(
                    "Could not verify agent %s; keeping it registered",
                    callsign,
                    exc_info=True,
                )
            continue

        if symbol != callsign:
            invalid_agents.append(callsign)
            continue
        agents[callsign] = agent

    if invalid_agents:
        cursor.executemany(
            "DELETE from agents WHERE callsign = ?",
            [(invalid,) for invalid in invalid_agents],
        )
        conn.commit()

    for agent in agents.values():
        _load_agent_ships(agent)


def _load_agent_ships(agent: "Agent") -> None:
    """Populate ``agent.ships`` from the API, then add DB-only ships as MIA.

    If fetching ships or restoring their tasks fails, every ship collected so
    far is marked ``tasks.ERROR``. Ships stored in the ``ships`` table under
    this agent's ``"<symbol>-"`` prefix but not reported by the API are added
    with ``tasks.MIA``.
    """
    log = logging.getLogger(__name__)
    try:
        response = agent.fleet_api.get_my_ships()
        for ship_data in response.data:
            agent.ships[ship_data.symbol] = Ship(ship_data.symbol)
        for ship in agent.ships.values():
            ship.load_task()
    except Exception:
        log.exception("Failed to load ships for agent %s", agent.agent_symbol)
        for ship in agent.ships.values():
            ship.set_task(tasks.ERROR)

    prefix = f"{agent.agent_symbol}-"
    # LIKE is only a coarse prefilter: '_' and '%' inside a symbol act as
    # wildcards and SQLite's LIKE is case-insensitive for ASCII, so the exact
    # prefix is re-checked in Python to avoid adopting another agent's ships.
    cursor.execute("SELECT symbol FROM ships WHERE symbol LIKE ?", (f"{prefix}%",))
    for (symbol,) in cursor.fetchall():
        if not symbol.startswith(prefix) or symbol in agent.ships:
            continue
        missing_ship = Ship(symbol)
        missing_ship.set_task(tasks.MIA)
        agent.ships[symbol] = missing_ship


@router.post("/init/{callsign}", status_code=201)
async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme)], response: Response):
    """Register a SpaceTraders agent with this service.

    The bearer token is verified by asking the SpaceTraders API which agent
    it belongs to; registration succeeds only if that symbol equals
    ``callsign``.

    Responses:
        201 -- agent verified, stored in the database and loaded in memory.
        200 -- this callsign is already registered with the same token.
        400 -- the callsign is registered with a different token, or the
               token could not be verified for this callsign.
    """
    result = check_init_agent(callsign, token)

    if result == AuthResult.SUCCESS:
        response.status_code = 200
        return {"result": "Agent already registered."}

    if result == AuthResult.TOKEN_MISMATCH:
        response.status_code = 400
        return {"error": "Agent already registered but with a different token."}

    # NOTE(review): get_my_agent() is not awaited, so this HTTP call runs on
    # the event-loop thread and blocks other requests while it is in flight.
    try:
        agent = Agent(callsign, token)
        verified = agent.agents_api.get_my_agent().data.symbol == callsign
    except Exception:
        logging.getLogger(__name__).warning(
            "SpaceTraders verification failed for %s", callsign, exc_info=True
        )
        verified = False

    if not verified:
        response.status_code = 400
        return {"error": "Agent could not be authenticated on the SpaceTraders API."}

    # Persist before publishing in memory so a failed INSERT cannot leave an
    # agent in the registry that the database does not know about.
    cursor.execute("INSERT INTO agents (callsign, token) VALUES (?, ?)", (callsign, token))
    conn.commit()
    agents[callsign] = agent
    return {"result": "Agent successfully registered."}