From 965c0e3401a4743c4f5e91e2368fd04b0b24aa02 Mon Sep 17 00:00:00 2001 From: Botond Hende Date: Sun, 24 Nov 2024 23:53:13 +0100 Subject: input_handler can be specified from config --- __main__.py | 13 +++-- config.yaml | 1 + modules/config.py | 13 +++++ modules/input_handlers/input_handler.py | 11 ++++ modules/input_handlers/pipewire_record.py | 91 ++++++++++++++++--------------- modules/input_handlers/stdin_input.py | 19 +++++-- 6 files changed, 93 insertions(+), 55 deletions(-) create mode 100644 modules/input_handlers/input_handler.py diff --git a/__main__.py b/__main__.py index b51e70b..797ec56 100644 --- a/__main__.py +++ b/__main__.py @@ -1,15 +1,16 @@ import os -import sys import yaml from pathlib import Path +from modules.config import Config from .modules.config import load_config from .modules.hassil.recognize import recognize from .modules.hassil.util import merge_dict from .modules.hassil.intents import Intents, TextSlotList -from .modules.input_handlers.stdin_input import get_input_stdin -from .modules.input_handlers.pipewire_record import get_input_pw_record, cleanup + +from .modules.input_handlers.stdin_input import StdinInput +from .modules.input_handlers.pipewire_record import PipeWireRecord from .modules.intents import * @@ -40,9 +41,11 @@ def main(): intents = Intents.from_dict(input_dict) + input_handler = PipeWireRecord() if config.input_handler == Config.INPUT_PW else StdinInput() + try: # TODO select input type from config - for input_text in get_input_pw_record(): + for input_text in input_handler.get_input(): result = recognize(input_text, intents, slot_lists=slot_lists) if result is not None: result_dict = { @@ -55,7 +58,7 @@ def main(): else: print("") finally: - cleanup() + input_handler.cleanup() if __name__ == '__main__': main() \ No newline at end of file diff --git a/config.yaml b/config.yaml index 858eb31..0903321 100644 --- a/config.yaml +++ b/config.yaml @@ -1,3 +1,4 @@ +input_mode: pw-record # other option: stdin intents_dir: sentences applications_dir: /usr/bin diff --git a/modules/config.py b/modules/config.py index 609e4d1..83243a0 100644 --- a/modules/config.py +++ b/modules/config.py @@ -1,10 +1,18 @@ +import sys + import __main__ import os.path import yaml class Config: + INPUT_PW = "pw-record" + INPUT_STDIN = "stdin" + INPUT_MODES = [INPUT_PW, INPUT_STDIN] + def __init__(self): + self.input_mode = "" + self.intents_dir = "" self.applications_dir = "" @@ -16,6 +24,10 @@ class Config: if not self.intents_dir.startswith("/"): self.intents_dir = os.path.join(os.path.dirname(__main__.__file__), self.intents_dir) + def validate(self): + if self.input_mode not in self.INPUT_MODES: + sys.exit(f"Invalid input_mode '{self.input_mode}', valid options: {", ".join(self.INPUT_MODES)}") + def load_config(): config = Config() with open(os.path.join(os.path.dirname(__main__.__file__), "config.yaml")) as stream: @@ -26,4 +38,5 @@ def load_config(): with open(user_config) as stream: config.update(**yaml.safe_load(stream)) + config.validate() return config \ No newline at end of file diff --git a/modules/input_handlers/input_handler.py b/modules/input_handlers/input_handler.py new file mode 100644 index 0000000..82a2830 --- /dev/null +++ b/modules/input_handlers/input_handler.py @@ -0,0 +1,11 @@ +from abc import ABC, abstractmethod + +class InputHandler(ABC): + + @abstractmethod + def get_input(self): + pass + + @abstractmethod + def cleanup(self): + pass \ No newline at end of file diff --git a/modules/input_handlers/pipewire_record.py b/modules/input_handlers/pipewire_record.py index 147d77b..3ad295c 100644 --- a/modules/input_handlers/pipewire_record.py +++ b/modules/input_handlers/pipewire_record.py @@ -1,69 +1,72 @@ import subprocess import os.path import signal -import sys from time import sleep import whisper +from modules.input_handlers.input_handler import InputHandler + FIFO_PATH = "/tmp/hestia-listening" RECORD_PATH = "/tmp/hestia-record.mp3" -def cleanup(): - if os.path.exists(FIFO_PATH): - os.remove(FIFO_PATH) - +class PipeWireRecord(InputHandler): + def cleanup(self): + if os.path.exists(FIFO_PATH): + os.remove(FIFO_PATH) -def get_input_pw_record(): - device = get_device() - cleanup() - os.mkfifo(FIFO_PATH) + def get_input(self): + device = PipeWireRecord.get_device() - while True: - with open(FIFO_PATH): - pass - # TODO "I'm listening" + self.cleanup() + os.mkfifo(FIFO_PATH) - try: - ps = subprocess.Popen((f"pw-record --target {device} {RECORD_PATH}",), shell=True) + while True: with open(FIFO_PATH): - print("finished") - ps.send_signal(signal.SIGINT) - # TODO "acknowledged" - except: - if "ps" in locals(): - ps.kill() - # TODO "error" - # TODO exit gracefully or try to recover - raise StopIteration + pass + # TODO "I'm listening" + + try: + ps = subprocess.Popen((f"pw-record --target {device} {RECORD_PATH}",), shell=True) + with open(FIFO_PATH): + print("finished") + ps.send_signal(signal.SIGINT) + # TODO "acknowledged" + except: + if "ps" in locals(): + ps.kill() + # TODO "error" + # TODO exit gracefully or try to recover + raise StopIteration - model = whisper.load_model("base") + model = whisper.load_model("base") - audio = whisper.load_audio(RECORD_PATH) - audio = whisper.pad_or_trim(audio) + audio = whisper.load_audio(RECORD_PATH) + audio = whisper.pad_or_trim(audio) - mel = whisper.log_mel_spectrogram(audio).to(model.device) - options = whisper.DecodingOptions(language="en", fp16=False) - result = whisper.decode(model, mel, options) - result_text = result.text.replace(",", "").replace(".", "").lower() + mel = whisper.log_mel_spectrogram(audio).to(model.device) + options = whisper.DecodingOptions(language="en", fp16=False) + result = whisper.decode(model, mel, options) + result_text = result.text.replace(",", "").replace(".", "").lower() - print(result_text) + print(result_text) - yield result_text + yield result_text -def get_device() -> str: - already_warned = False + @staticmethod + def get_device() -> str: + already_warned = False - while True: - ps = subprocess.Popen(('pw-cli ls | \\grep -Poi "(?<=node.name = \\").*mic.*(?=\\")"',), shell=True, stdout=subprocess.PIPE) - ps.wait() + while True: + ps = subprocess.Popen(('pw-cli ls | \\grep -Poi "(?<=node.name = \\").*mic.*(?=\\")"',), shell=True, stdout=subprocess.PIPE) + ps.wait() - if ps.returncode == 0: - return ps.stdout.read().decode().strip() + if ps.returncode == 0: + return ps.stdout.read().decode().strip() - elif not already_warned: - already_warned = True - # TODO warn about device not found + elif not already_warned: + already_warned = True + # TODO warn about device not found - sleep(3) \ No newline at end of file + sleep(3) \ No newline at end of file diff --git a/modules/input_handlers/stdin_input.py b/modules/input_handlers/stdin_input.py index 044ac0d..5c50511 100644 --- a/modules/input_handlers/stdin_input.py +++ b/modules/input_handlers/stdin_input.py @@ -1,9 +1,16 @@ import sys -def get_input_stdin(): - for line in sys.stdin: - line = line.strip() - if not line: - continue +from modules.input_handlers.input_handler import InputHandler - yield line \ No newline at end of file + +class StdinInput(InputHandler): + def cleanup(self): + pass + + def get_input(self): + for line in sys.stdin: + line = line.strip() + if not line: + continue + + yield line \ No newline at end of file -- cgit v1.2.3-70-g09d2