summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--apirouters/agents.py6
-rw-r--r--apirouters/tasks.py12
-rw-r--r--modules/ships.py8
-rw-r--r--modules/task_type.py10
-rw-r--r--modules/task_types.py21
5 files changed, 34 insertions, 23 deletions
diff --git a/apirouters/agents.py b/apirouters/agents.py
index a1e015b..1b927e7 100644
--- a/apirouters/agents.py
+++ b/apirouters/agents.py
@@ -5,7 +5,7 @@ import openapi_client
from fastapi import APIRouter, Depends, Response, HTTPException
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
-from ..modules import task_type
+from ..modules import task_types
from ..modules.database import cursor, conn
from ..modules.ships import Ship
@@ -87,7 +87,7 @@ def load_agents_from_database() -> None:
ship.load_task()
except:
for ship in agent.ships.values():
- ship.set_task(tasks.ERROR)
+ ship.set_task(task_types.ERROR)
cursor.execute("SELECT symbol FROM ships WHERE symbol LIKE ?", (f"{agent.agent_symbol}-%",))
ship_names = agent.ships.keys()
@@ -96,7 +96,7 @@ def load_agents_from_database() -> None:
continue
missing_ship = Ship(row[0])
- missing_ship.set_task(tasks.MIA)
+ missing_ship.set_task(task_types.MIA)
agent.ships[row[0]] = missing_ship
diff --git a/apirouters/tasks.py b/apirouters/tasks.py
index 293ab52..6cc9a25 100644
--- a/apirouters/tasks.py
+++ b/apirouters/tasks.py
@@ -27,15 +27,15 @@ async def get_tasks(ship_symbol: str, agent: Annotated[Agent, Depends(auth_agent
return {"error": "Unknown ship symbol."}
-@router.post("/task/{ship_symbol}/set/{task_type}")
-async def get_tasks(ship_symbol: str, task: str, agent: Annotated[Agent, Depends(auth_agent)]):
- if task not in task_type.task_types:
- return {"error": "Invalid task."}
+class SetTaskBody(BaseModel):
+ task: task_types.task_type
+
+@router.post("/task/{ship_symbol}/set")
+async def get_tasks(ship_symbol: str, set_task: SetTaskBody, agent: Annotated[Agent, Depends(auth_agent)]):
for current_ship in agent.ships.values():
if current_ship.symbol == ship_symbol:
- current_ship.set_task(task)
+ current_ship.set_task(set_task.task)
return current_ship.get_data()
return {"error": "Unknown ship symbol."}
-
diff --git a/modules/ships.py b/modules/ships.py
index e35cdd2..05ce6c1 100644
--- a/modules/ships.py
+++ b/modules/ships.py
@@ -1,6 +1,6 @@
from typing import Dict
-from . import task_type
+from . import task_types
from .database import cursor, conn
@@ -20,10 +20,10 @@ class Ship:
row = cursor.fetchone()
if row is None:
cursor.execute("INSERT INTO ships (symbol, task, params, name) VALUES (?, ?, ?, ?)",
- (self.symbol, task_type.IDLE, None, self.symbol))
+ (self.symbol, task_types.IDLE, None, self.symbol))
conn.commit()
- self.task = task_type.IDLE
+ self.task = task_types.IDLE
self.task = self.symbol
else:
self.name = row[0]
@@ -31,7 +31,7 @@ class Ship:
def set_task(self, task):
self.task = task
- if task != task_type.ERROR:
+ if task != task_types.ERROR:
cursor.execute("UPDATE ships SET task = ? WHERE symbol = ?", (task, self.symbol))
conn.commit()
diff --git a/modules/task_type.py b/modules/task_type.py
deleted file mode 100644
index 7579ad0..0000000
--- a/modules/task_type.py
+++ /dev/null
@@ -1,10 +0,0 @@
-IDLE = 'IDLE'
-MINING = 'MINING'
-
-MIA = 'MIA'
-ERROR = 'ERROR'
-
-task_types = [
- IDLE,
- MINING,
-]
diff --git a/modules/task_types.py b/modules/task_types.py
new file mode 100644
index 0000000..d23ce55
--- /dev/null
+++ b/modules/task_types.py
@@ -0,0 +1,21 @@
+from typing import Annotated
+from pydantic.functional_validators import AfterValidator
+
+IDLE = 'IDLE'
+MINING = 'MINING'
+
+MIA = 'MIA'
+ERROR = 'ERROR'
+
+task_types = [
+ IDLE,
+ MINING,
+]
+
+
+def is_task_type(task: str):
+ assert task in task_types, f"'{task}' is not a valid task type."
+ return task
+
+
+task_type = Annotated[str, AfterValidator(is_task_type)]