summaryrefslogtreecommitdiff
path: root/__main__.py
blob: 2002425d72f1bba55ce84725cfcdfdda5497c72b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os.path
import sqlite3
from pathlib import Path
from enum import Enum

from contextlib import asynccontextmanager
from typing import Annotated

from fastapi import FastAPI, Depends, Response
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from passlib.context import CryptContext

from .config import Config

# Hashing context used to hash new agent tokens and verify presented ones (bcrypt).
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
)
# Shared SQLite cursor for request handlers; None until the lifespan handler
# connects to the database on startup.
cursor: "sqlite3.Cursor | None" = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the agents database connection for the lifetime of the app.

    Startup:
      * ensure the directory containing ``Config.DATABASE_PATH`` exists;
      * refuse to start unless SQLite runs in serialized threading mode;
      * connect, publish a shared cursor via the module-level ``cursor``
        global and create the ``agents`` table if it does not exist.

    Shutdown: commit any transaction still open on the connection (sqlite3
    discards uncommitted changes on close) and close the connection.

    Raises:
        RuntimeError: if ``sqlite3.threadsafety`` is not 3 (serialized).
    """
    global cursor

    db_dir = os.path.dirname(Config.DATABASE_PATH)
    Path(db_dir).mkdir(parents=True, exist_ok=True)

    # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
    if sqlite3.threadsafety != 3:
        raise RuntimeError("SQLite thread safety is not set to 'Serialized'.")

    sq_con = sqlite3.connect(Config.DATABASE_PATH)
    try:
        cursor = sq_con.cursor()
        cursor.execute(
            "CREATE TABLE IF NOT EXISTS agents(primary_key INTEGER PRIMARY KEY, callsign TEXT NOT NULL UNIQUE, token_hash TEXT NOT NULL)")
        yield
        # Persist anything a request left in an open (implicit) transaction.
        sq_con.commit()
    finally:
        # Drop the shared cursor so nothing keeps using a closed connection.
        cursor = None
        sq_con.close()


app = FastAPI(lifespan=lifespan)
# Pulls the bearer token out of incoming requests; ``tokenUrl`` names the
# login endpoint defined just below.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


@app.post("/token")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Answer the OAuth2 password flow by echoing the username as the token.

    No credentials are checked here and the submitted password is ignored.
    The "username" is handed back unchanged as the bearer access token.
    Verification happens later, e.g. ``init_agent`` checks the token with
    ``auth_agent``.
    """
    return {
        "access_token": form_data.username,
        "token_type": "bearer",
    }


class AuthResult(Enum):
    """Outcome of checking a callsign/token pair against the ``agents`` table."""

    SUCCESS = 0         # callsign found and the token verifies against its hash
    NOT_FOUND = 1       # no agent registered under this callsign
    TOKEN_MISMATCH = 2  # callsign found, but the token does not verify


async def auth_agent(callsign: str, token: str) -> AuthResult:
    """Check ``token`` against the hash stored for ``callsign``.

    Args:
        callsign: agent identifier looked up in the ``agents`` table.
        token: plaintext bearer token presented by the client.

    Returns:
        ``AuthResult.NOT_FOUND`` if no row exists for ``callsign``;
        ``AuthResult.SUCCESS`` if ``token`` verifies against the stored hash;
        ``AuthResult.TOKEN_MISMATCH`` otherwise.
    """
    cursor.execute("SELECT token_hash from agents WHERE callsign = ?", (callsign,))
    row = cursor.fetchone()
    if row is None:
        return AuthResult.NOT_FOUND

    # NOTE(review): bcrypt verification is deliberately slow and runs on the
    # event-loop thread here; consider offloading it if latency matters.
    if pwd_context.verify(token, row[0]):
        return AuthResult.SUCCESS
    return AuthResult.TOKEN_MISMATCH


@app.post("/{callsign}/init", status_code=201)
async def init_agent(callsign: str, token: Annotated[str, Depends(oauth2_scheme)], response: Response):
    """Register an agent under ``callsign`` using its bearer ``token``.

    Only a ``pwd_context`` hash of the token is stored, never the token itself.

    Outcomes (HTTP status, returned string):
      * 200: the callsign is already registered with this same token.
      * 400: the callsign is already registered with a different token.
      * 201: a new row was inserted into ``agents`` and committed.

    NOTE(review): every branch returns a pre-serialised JSON *string*. FastAPI
    presumably JSON-encodes it again, so clients would receive a quoted string
    rather than an object. Returning dicts would fix that but changes the
    response body, so confirm with API consumers first.
    """
    result = await auth_agent(callsign, token)

    if result is AuthResult.SUCCESS:
        response.status_code = 200
        return '{"result": "Agent already registered."}'

    if result is AuthResult.TOKEN_MISMATCH:
        response.status_code = 400
        return '{"error": "Agent already registered but with a different token."}'

    # TODO: test token on spacetraders api

    token_hash = pwd_context.hash(token)
    cursor.execute("INSERT INTO agents (callsign, token_hash) VALUES (?, ?)", (callsign, token_hash))
    # sqlite3 opens an implicit transaction before an INSERT. Without this
    # commit the row is visible only on this connection and is discarded
    # when the connection closes.
    cursor.connection.commit()
    return '{"result": "Agent successfully registered."}'


@app.get("/{callsign}/tasks")
async def get_tasks(callsign: str, token: Annotated[str, Depends(oauth2_scheme)]):
    """Placeholder endpoint: echo the path ``callsign`` and the bearer ``token``.

    Returns a JSON-formatted string with ``callsign`` and ``token`` keys, in
    that order.

    NOTE(review): the token is not checked with ``auth_agent`` here, so any
    bearer token is accepted. Confirm this stub is meant to be unauthenticated.
    """
    import json

    # json.dumps escapes quotes, backslashes and control characters that a
    # hand-built f-string would emit verbatim (yielding invalid JSON).
    # ensure_ascii=False leaves non-ASCII text as-is.
    return json.dumps({"callsign": callsign, "token": token}, ensure_ascii=False)

# if __name__ == "__main__":
#     with open(os.path.join(os.path.dirname(__file__), Config.TOKEN_FILE_NAME)) as f:
#         token = f.read().strip()
#
#     d = daemon.Daemon(Config.AGENT_SYMBOL, token)
#     d.run()