summaryrefslogtreecommitdiff
path: root/apirouters/agents.py
blob: 89f736f3ff76f4743e7f613dfc16ad32eff9052e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from enum import Enum
from typing import Annotated

import openapi_client
from fastapi import APIRouter, Depends, Response
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm

from ..modules.database import cursor, conn

# Router that the endpoint decorators in this module attach their routes to.
router = APIRouter()
# Bearer-token dependency; "token" is the relative URL of the login endpoint.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# In-memory registry of verified agents, keyed by callsign (values are Agent).
agents = {}


class Agent:
    """API handles for a single agent, all sharing one authenticated client."""

    def __init__(self, agent_symbol, token):
        """Build the shared API client and the per-area API wrappers.

        Args:
            agent_symbol: Value stored unchanged as ``self.agent_symbol``.
            token: Access token handed to ``openapi_client.Configuration``.
        """
        self.agent_symbol = agent_symbol

        # One ApiClient is created per agent and reused by every wrapper below.
        client = openapi_client.ApiClient(
            openapi_client.Configuration(access_token=token)
        )
        self.api_client = client

        self.agents_api = openapi_client.AgentsApi(client)
        self.fleet_api = openapi_client.FleetApi(client)
        self.systems_api = openapi_client.SystemsApi(client)


@router.post("/token")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Return a bearer-token payload whose token is the submitted password.

    No credential check happens here: the ``password`` form field is echoed
    back verbatim as ``access_token`` with ``token_type`` set to ``bearer``.
    """
    token_payload = {
        "access_token": form_data.password,
        "token_type": "bearer",
    }
    return token_payload


class AuthResult(Enum):
    """Outcome of checking a callsign/token pair in ``auth_agent``."""

    # A row exists for the callsign and its token equals the supplied one.
    SUCCESS = 0
    # No row exists for the callsign.
    NOT_FOUND = 1
    # A row exists for the callsign but its token differs from the supplied one.
    TOKEN_MISMATCH = 2


def auth_agent(callsign: str, token: str) -> AuthResult:
    """Check a callsign/token pair against the ``agents`` table.

    Args:
        callsign: Callsign to look up.
        token: Token supplied by the caller, compared with the stored one.

    Returns:
        ``AuthResult.NOT_FOUND`` when no row exists for ``callsign``,
        ``AuthResult.SUCCESS`` when the stored token equals ``token``, and
        ``AuthResult.TOKEN_MISMATCH`` otherwise.
    """
    # Function-scope import so the block carries its own dependency.
    import hmac

    cursor.execute("SELECT token from agents WHERE callsign = ?", (callsign,))
    row = cursor.fetchone()
    if row is None:
        return AuthResult.NOT_FOUND

    stored = row[0]
    # A non-string value can never equal a str token; bail out before encoding
    # so the result matches what a plain ``==`` comparison would have given.
    if not isinstance(stored, str) or not isinstance(token, str):
        return AuthResult.TOKEN_MISMATCH

    # Constant-time comparison: the token is a secret taken from an untrusted
    # request, so avoid leaking the matching-prefix length through timing.
    # Encoding to bytes lets compare_digest accept non-ASCII text as well.
    matches = hmac.compare_digest(
        token.encode("utf-8", "surrogatepass"),
        stored.encode("utf-8", "surrogatepass"),
    )
    return AuthResult.SUCCESS if matches else AuthResult.TOKEN_MISMATCH


def load_agents_from_database() -> None:
    """Fill the in-memory ``agents`` registry from the ``agents`` table.

    For every stored (callsign, token) pair an ``Agent`` is created and the
    remote API is asked for the agent's symbol. Agents whose symbol equals
    the stored callsign are registered in ``agents``; all others are deleted
    from the table, followed by a single commit.

    NOTE(review): any ``Exception`` during verification (which presumably
    includes transient network errors) also leads to deletion of the row.
    """
    cursor.execute("SELECT callsign, token from agents")
    invalid_agents = []
    for callsign, token in cursor.fetchall():
        try:
            agent = Agent(callsign, token)
            symbol = agent.agents_api.get_my_agent().data.symbol
        except Exception:
            # Deliberately not a bare ``except``: KeyboardInterrupt/SystemExit
            # must propagate instead of being treated as an invalid agent.
            invalid_agents.append(callsign)
            continue

        if symbol != callsign:
            invalid_agents.append(callsign)
            continue

        agents[callsign] = agent

    if invalid_agents:
        for invalid in invalid_agents:
            cursor.execute("DELETE from agents WHERE callsign = ?", (invalid,))
        # One commit covers every deletion issued above.
        conn.commit()


@router.post("/{callsign}/init", status_code=201)
async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme)], response: Response):
    """Register an agent under ``callsign`` using the caller's bearer token.

    Args:
        callsign: Callsign taken from the URL path.
        token: Bearer token resolved through ``oauth2_scheme``.
        response: Response object used to override the default 201 status.

    Returns:
        A pre-serialized JSON string describing the outcome. The status is
        200 when the callsign is already stored with the same token, 400 when
        it is stored with a different token or the remote verification fails,
        and 201 (the route default) after a new registration.

    NOTE(review): the bodies are JSON strings rather than dicts; FastAPI's
    default JSON response will likely encode them a second time, so clients
    receive a JSON string literal. Kept as-is because returning dicts would
    change the wire format — confirm with API consumers before switching.
    """
    result = auth_agent(callsign, token)

    if result == AuthResult.SUCCESS:
        response.status_code = 200
        return '{"result": "Agent already registered."}'

    if result == AuthResult.TOKEN_MISMATCH:
        response.status_code = 400
        return '{"error": "Agent already registered but with a different token."}'

    agent = None
    try:
        agent = Agent(callsign, token)
        verified = agent.agents_api.get_my_agent().data.symbol == callsign
    except Exception:
        # Not a bare ``except``: KeyboardInterrupt/SystemExit must propagate
        # rather than being reported as an authentication failure.
        verified = False

    if not verified:
        response.status_code = 400
        return '{"error": "Agent could not be authenticated on the SpaceTraders API."}'

    cursor.execute("INSERT INTO agents (callsign, token) VALUES (?, ?)", (callsign, token))
    conn.commit()
    # Register in memory only once the row is committed, so a failed insert
    # cannot leave the in-memory registry ahead of the database.
    agents[callsign] = agent
    return '{"result": "Agent successfully registered."}'