diff --git a/config/agent_input.yaml b/config/agent_input.yaml new file mode 100644 index 0000000..c122a8b --- /dev/null +++ b/config/agent_input.yaml @@ -0,0 +1,34 @@ +id: agent-io +hostname: localhost + +huri: + hostname: localhost + router: + port: 3000 + event-proxy: + xsub: 5555 + xpub: 5556 + log-puller: + port: 8008 + +forwarder-proxy: + down-xsub: 6665 + up-xpub: 6666 + +logging: INFO + +modules: + inp: + name: INP + logging: INFO + out: + name: OUT + logging: INFO + mod: + name: MOD + logging: INFO + rag: + name: RAG + args: + model: deepseek-v2:16b + logging: INFO diff --git a/config/agent_io.yaml b/config/agent_io.yaml new file mode 100644 index 0000000..c9a5646 --- /dev/null +++ b/config/agent_io.yaml @@ -0,0 +1,36 @@ +id: agent-io +hostname: localhost + +huri: + hostname: localhost + router: + port: 3000 + event-proxy: + xsub: 5555 + xpub: 5556 + log-puller: + port: 8008 + +forwarder-proxy: + down-xsub: 6665 + up-xpub: 6666 + +logging: INFO + +modules: + mic: + name: mic + args: + sample_rate: 18000 + logging: INFO + stt: + name: stt + args: + sample_rate: 18000 + logging: INFO + # tts: + # name: vibe + # args: + # model: vibe-voice + # voice: adrien + # logging: DEBUG diff --git a/config/huri.yaml b/config/huri.yaml new file mode 100644 index 0000000..13f06b1 --- /dev/null +++ b/config/huri.yaml @@ -0,0 +1,11 @@ +hostname: localhost + +router: + port: 3000 + +event-proxy: + xsub: 5555 + xpub: 5556 + +log-puller: + port: 8008 diff --git a/quick_launch.sh b/quick_launch.sh new file mode 100755 index 0000000..a76da2a --- /dev/null +++ b/quick_launch.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env bash + +set -e + +# Check args +if [ "$#" -lt 2 ]; then + echo "Usage: $0 [CLEAN]" + exit 1 +fi + +HURI_CONFIG="$1" +AGENT_CONFIG="$2" + +LOG_DIR="./tmp/log" + +if [[ " $* " == *" CLEAN "* ]]; then + echo "Cleaning previous logs in ${LOG_DIR}" + rm -rf "${LOG_DIR}" +fi + +mkdir -p "$LOG_DIR" + +TIMESTAMP=$(date +"%Y%m%d-%H%M%S") +HURI_LOG="${LOG_DIR}/huri-${TIMESTAMP}.log" + + +# Run huri with output redirected +python -m src.launch_huri --config "$HURI_CONFIG" > "$HURI_LOG" 2>&1 & +HURI_PID=$! 
+echo "HURI started in background (PID=${HURI_PID}), logging to ${HURI_LOG}" + +# Run agent +python -m src.launch_agent --config "$AGENT_CONFIG" + +# Ensure HURI is killed on script exit (normal or Ctrl+C) +cleanup() { + echo "Stopping HURI (PID=${HURI_PID})" + kill "${HURI_PID}" 2>/dev/null || true +} +trap cleanup EXIT INT TERM \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/core/__init__.py b/src/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/core/agent.py b/src/core/agent.py new file mode 100644 index 0000000..df2c182 --- /dev/null +++ b/src/core/agent.py @@ -0,0 +1,277 @@ +import multiprocessing as mp +import signal +import threading +from dataclasses import dataclass +from multiprocessing.synchronize import Event +from typing import Any, Dict, Mapping + +from src.modules.factory import ModuleFactory +from src.tools.logger import logging, setup_logger + +from .huri import HuriConfig +from .zmq.control_channel import Command, Dealer +from .zmq.event_proxy import EventProxy +from .zmq.log_channel import LogPusher + + +@dataclass +class ForwarderProxyConfig: + down_xsub: int + up_xpub: int + + @classmethod + def from_dict(cls, raw: dict): + return cls( + down_xsub=raw["down-xsub"], + up_xpub=raw["up-xpub"], + ) + + +@dataclass +class ModuleConfig: + name: str + args: Mapping[str, Any] + logging: int + + @classmethod + def from_dict(cls, raw: dict): + level = logging._nameToLevel.get( + raw.get("logging", "INFO"), + logging.INFO, + ) + return cls( + name=raw["name"], + args=raw.get("args", {}), + logging=level, + ) + + +@dataclass +class AgentConfig: + id: str + hostname: str + huri: HuriConfig + logging: int + forwarder_proxy: ForwarderProxyConfig + modules: Dict[str, ModuleConfig] + + @classmethod + def from_dict(cls, raw: dict): + level = logging._nameToLevel.get( + raw.get("logging", "INFO").upper(), + logging.INFO, + ) + modules = { + module_id: ModuleConfig.from_dict(mod_raw) + for module_id, mod_raw in raw.get("modules", {}).items() + } + return cls( + id=raw["id"], + hostname=raw["hostname"], + huri=HuriConfig.from_dict(raw["huri"]), + forwarder_proxy=ForwarderProxyConfig.from_dict(raw["forwarder-proxy"]), + logging=level, + modules=modules, + ) + + +class Agent: + """Control Modules and communication with HuRI""" + + def __init__(self, config: AgentConfig) -> None: + self.modules: Dict[str, ModuleConfig] = config.modules + self.config = config + + self.processes: Dict[str, mp.Process] = {} + self.stop_events: Dict[str, Event] = {} + + self.threads: Dict[str, threading.Thread] = {} + + self.log_pusher = LogPusher( + hostname=config.huri.hostname, port=config.huri.log_puller.port + ) + + self.dealer = Dealer( + hostname=config.huri.hostname, + port=config.huri.router.port, + executor=self._command_handler, + logger=setup_logger("Dealer", log_queue=self.log_pusher.log_queue), + ) + + self.up_proxy = EventProxy( + hostname=config.hostname, + connect_hostname=config.huri.hostname, + xpub_port=config.huri.event_proxy.xsub, + xsub_port=config.forwarder_proxy.up_xpub, + logger=setup_logger("UpProxy", log_queue=self.log_pusher.log_queue), + ) + self.down_proxy = EventProxy( + hostname=config.hostname, + connect_hostname=config.huri.hostname, + xpub_port=config.forwarder_proxy.down_xsub, + xsub_port=config.huri.event_proxy.xpub, + logger=setup_logger("DownProxy", log_queue=self.log_pusher.log_queue), + ) + + self.logger = setup_logger( + f"Agent {self.dealer.identity}", 
log_queue=self.log_pusher.log_queue + ) + + def _command_handler(self, command: Command) -> bool: + match command.cmd: + case "START": + return self.start_module(*command.args) + case "STOP": + return self.stop_module(*command.args) + case "STATUS": + return self.status() + case _: + return False # todo log + + @staticmethod + def _start_module( + name: str, + module_config: ModuleConfig, + agent_config: AgentConfig, + log_queue: mp.Queue, + stop_event: Event, + ) -> None: + """Helper function to start module in child process.""" + logger = setup_logger( + module_config.name, level=module_config.logging, log_queue=log_queue + ) + + module = ModuleFactory.create(name, module_config.args) + module.set_custom_logger(logger) + + def handle_sigint(signum, frame): + logger.info("Ctrl+C ignored in child module") + + signal.signal(signal.SIGINT, handle_sigint) + + module.start_module( + agent_config.hostname, + agent_config.forwarder_proxy.up_xpub, + agent_config.forwarder_proxy.down_xsub, + stop_event=stop_event, + ) + + def start_module(self, name) -> None: + """Check if module is registered and not already running, and start a child process.""" + if name not in self.modules: + self.logger.warning( + f"{name} is not in the registered Modules: {self.modules.keys()}" + ) + return + if name in self.processes: + self.logger.warning( + f"{name} is already running (PID={self.processes[name].pid})" + ) + return + + module_config = self.modules[name] + stop_event = mp.Event() + p = mp.Process( + target=self._start_module, + args=( + name, + module_config, + self.config, + self.log_pusher.log_queue, + stop_event, + ), + daemon=True, + ) + self.processes[name] = p + self.stop_events[name] = stop_event + self.log_pusher.level_filter.add_level(name) + + p.start() + self.logger.info(f"{name} ({module_config.name}) started (PID={p.pid})") + + def stop_module(self, name) -> None: + if name in self.processes: + self.logger.info(f"Stopping {name}...") + self.stop_events[name].set() + self.processes[name].join(timeout=5) + if self.processes[name].is_alive(): + self.logger.warning(f"{name} did not stop in time, killing") + self.processes[name].kill() + self.logger.info(f"{name} stopped") + del self.processes[name] + del self.stop_events[name] + self.log_pusher.level_filter.del_level(name) + + def stop_all(self) -> None: + for name in list(self.processes.keys()): + self.stop_module(name) + + self.dealer.stop() + self.up_proxy.stop() + self.down_proxy.stop() + for name, thread in self.threads.items(): + self.logger.info(f"Stopping {name} thread...") + thread.join(timeout=5) + self.logger.info(f"{name} thread stopped") + self.log_pusher.level_filter.del_level(name) + + self.log_pusher.stop() + print("Fully stopped") + + def status(self) -> None: + """Print status of all modules and router.""" + print("=== Module Status ===") + for name in self.modules: + process = self.processes.get(name) + if process: + state = "alive" if process.is_alive() else "stopped" + print(f"- {name}: {state} (PID={process.pid})") + else: + print(f"- {name}: stopped") + print("=====================") + + def set_root_log_level(self, level: int) -> None: + self.log_pusher.level_filter.set_root_level(level) + + def set_log_level(self, name: str, level: int) -> None: + self.log_pusher.level_filter.set_level(name, level) + + def set_log_levels(self, level: int) -> None: + self.log_pusher.level_filter.set_levels(level) + + def _connect_to_huri(self) -> None: + self.log_pusher.level_filter.add_level("Dealer") + self.threads["Dealer"] = 
threading.Thread(target=self.dealer.start) + self.threads["Dealer"].start() + + def _start_event_proxies(self) -> None: + """Used to handle inter-module communication, though events""" + self.log_pusher.level_filter.add_level("UpProxy") + self.log_pusher.level_filter.add_level("DownProxy") + self.threads["UpProxy"] = threading.Thread( + target=self.up_proxy.start, args=[True, False] + ) + self.threads["DownProxy"] = threading.Thread( + target=self.down_proxy.start, args=[False, True] + ) + + self.threads["UpProxy"].start() + self.threads["DownProxy"].start() + + def run(self) -> None: + """Start event router and modules""" # TODO config (also logs levels) + + try: + self.log_pusher.start() + self._connect_to_huri() + self._start_event_proxies() + except Exception as e: + self.logger.error(e) + return + + for name in self.modules: + self.start_module(name) + + while True: + data = input() + self.down_proxy.publish("std.in", data) diff --git a/src/core/events.py b/src/core/events.py new file mode 100644 index 0000000..e69de29 diff --git a/src/core/huri.py b/src/core/huri.py new file mode 100644 index 0000000..18f52f3 --- /dev/null +++ b/src/core/huri.py @@ -0,0 +1,97 @@ +import sys +import threading +from dataclasses import dataclass +from typing import Dict + +from src.tools.logger import setup_logger + +from .zmq.control_channel import Router +from .zmq.event_proxy import EventProxy +from .zmq.log_channel import LogPuller + + +@dataclass +class RouterConfig: + port: int + + +@dataclass +class EventProxyConfig: + xsub: int + xpub: int + + +@dataclass +class LogPullerConfig: + port: int + + +@dataclass +class HuriConfig: + hostname: str + router: RouterConfig + event_proxy: EventProxyConfig + log_puller: LogPullerConfig + + @classmethod + def from_dict(cls, raw: dict): + return cls( + hostname=raw["hostname"], + router=RouterConfig(**raw["router"]), + event_proxy=EventProxyConfig(**raw["event-proxy"]), + log_puller=LogPullerConfig(**raw["log-puller"]), + ) + + +class HuRI: + """Wait for Agent to connect, handle module communication and Logging""" + + def __init__(self, config: HuriConfig) -> None: + self.router = Router(config.hostname, config.router.port) + self.event_proxy = EventProxy( + config.hostname, "", config.event_proxy.xpub, config.event_proxy.xsub + ) + self.log_channel = LogPuller(config.hostname, config.log_puller.port) + + self.threads: Dict[str, threading.Thread] = {} + + self.logger = setup_logger("HuRI") + + def _start_router(self) -> None: + """Used to handle Agent registration and control""" + self.threads["Router"] = threading.Thread(target=self.router.start) + self.threads["Router"].start() + + def _start_event_proxy(self) -> None: + """Used to handle inter-module communication, though events""" + self.threads["EventProxy"] = threading.Thread( + target=self.event_proxy.start, args=[False, False] + ) + self.threads["EventProxy"].start() + + def _start_log_channel(self) -> None: + """Used to handle Agent registration and control""" + self.threads["LogChannel"] = threading.Thread(target=self.log_channel.start) + self.threads["LogChannel"].start() + + def run(self) -> None: + self._start_log_channel() + self._start_router() + self._start_event_proxy() + + if not sys.stdin.isatty(): + threading.Event().wait() + return + + from src.core.shell import RobotShell + + RobotShell(self).cmdloop() + + def stop(self) -> None: + self.router.stop() + self.event_proxy.stop() + self.log_channel.stop() + for name, thread in self.threads.items(): + self.logger.info(f"Stopping {name} 
thread...") + thread.join(timeout=5) + self.logger.info(f"{name} thread stopped") diff --git a/src/core/module.py b/src/core/module.py new file mode 100644 index 0000000..4be1246 --- /dev/null +++ b/src/core/module.py @@ -0,0 +1,159 @@ +import json +import threading +from abc import ABC, abstractmethod +from multiprocessing.synchronize import Event +from typing import Callable, Dict, final + +import zmq + +from src.tools.logger import logging + + +class Module(ABC): + def __init__(self): + """Child Modules must call super.__init__() in their __init__() function.""" + self.ctx = None + self.pub_socket = None + self.connect_hostname = None + self.xpub_port = None + self.xsub_port = None + self.subs: Dict[str, zmq.Socket[bytes]] = {} + self.callbacks = {} + self._poller_running = False + self.poller = None + self.logger = logging.getLogger(__name__) + + @final + def _initialize(self) -> None: + """ + Called inside start_module() or manually before usage. + This function exist because ctx cannot be set in __init__, because of multi-processing. maybe deprecated + """ + self.ctx = zmq.Context() + self.pub_socket = self.ctx.socket(zmq.PUB) + self.pub_socket.connect(f"tcp://{self.connect_hostname}:{self.xpub_port}") + self.poller = threading.Thread(target=self._poll_loop, daemon=True) + self.set_subscriptions() + + @abstractmethod + def set_subscriptions(self) -> None: + """Child module must define this funcction with subscriptions""" + ... + + @final + def subscribe(self, topic: str, callback: Callable) -> None: + sub_socket = self.ctx.socket(zmq.SUB) + sub_socket.connect(f"tcp://{self.connect_hostname}:{self.xsub_port}") + sub_socket.setsockopt_string(zmq.SUBSCRIBE, topic) + self.subs[topic] = sub_socket + self.callbacks[topic] = callback + self.logger.info(f"Subscribe: {topic}") + + @final + def publish( + self, topic: str, msg: object, content_type: str = "str" + ) -> None: # TODO content type enum + if content_type == "json": + payload = json.dumps(msg).encode() + elif content_type == "bytes": + payload = msg + elif content_type == "str": + payload = msg.encode() + else: + raise ValueError(f"Unsupported content_type: {content_type}") + + self.pub_socket.send_multipart([topic.encode(), content_type.encode(), payload]) + self.logger.info(f"Publish: {topic} {content_type}") + + @final + def _start_polling(self) -> None: + self._poller_running = True + self.poller.start() + + @final + def _poll_loop(self) -> None: + poller = zmq.Poller() + for sub in self.subs.values(): + poller.register(sub, zmq.POLLIN) + + while self._poller_running: + events = dict(poller.poll(100)) + for _, sub in self.subs.items(): + if sub in events: + topic, content_type, payload = sub.recv_multipart() + topic_str = topic.decode() + content_type_str = content_type.decode() + self.logger.info(f"Receive: {topic_str} {content_type_str}") + if content_type_str == "json": + kwargs = json.loads(payload.decode()) + self.callbacks[topic_str]( + **kwargs + ) # TODO better and cleaner way ? 
+                    elif content_type_str == "bytes":
+                        data = payload
+                        self.callbacks[topic_str](data)
+                    elif content_type_str == "str":
+                        data = payload.decode()
+                        self.callbacks[topic_str](data)
+
+    @final
+    def start_module(
+        self,
+        connect_hostname: str,
+        xpub_port: int,
+        xsub_port: int,
+        stop_event: Event = None,
+    ) -> None:
+        self.connect_hostname = connect_hostname
+        self.xpub_port = xpub_port
+        self.xsub_port = xsub_port
+        self._initialize()
+        if self.subs != {}:
+            self._start_polling()
+        try:
+            self.run_module(stop_event)
+        except KeyboardInterrupt:
+            self.logger.info("Ctrl+C pressed, exiting cleanly")
+        except Exception as e:
+            self.logger.error(e)
+        finally:
+            self.stop_module()
+
+    @final
+    def stop_module(self) -> None:
+        """Stop the module gracefully."""
+
+        if self._poller_running:
+            self._poller_running = False
+            self.poller.join()
+
+        for topic, sub in self.subs.items():
+            try:
+                sub.close(0)
+            except Exception as e:
+                self.logger.error(f"Error closing SUB socket for '{topic}': {e}")
+
+        self.subs.clear()
+        self.callbacks.clear()
+
+        try:
+            self.pub_socket.close(0)
+        except Exception as e:
+            self.logger.error(f"Error closing PUB socket: {e}")
+
+        try:
+            self.ctx.term()
+        except Exception as e:
+            self.logger.error(f"Error terminating ZMQ context: {e}")
+
+        self.logger.info("Module stopped gracefully.")
+
+    def run_module(self, stop_event: Event = None) -> None:
+        """Child modules override this instead of run(). Default: idle wait."""
+        if stop_event:
+            stop_event.wait()
+
+    @final
+    def set_custom_logger(self, logger) -> None:
+        """The default logger is set in __init__."""
+        self.logger = logger
diff --git a/src/core/shell.py b/src/core/shell.py
new file mode 100644
index 0000000..6473357
--- /dev/null
+++ b/src/core/shell.py
@@ -0,0 +1,31 @@
+import cmd
+
+from src.core.huri import HuRI
+from src.core.zmq.control_channel import Command
+
+
+class RobotShell(cmd.Cmd):
+    intro = "HuRI's shell. Type 'help' to see the list of commands."
+    prompt = "(HuRI) "
+
+    def __init__(self, huri: HuRI) -> None:
+        super().__init__()
+        self.huri = huri
+
+    def do_status(self, arg) -> None:
+        "Display modules and router status."
+        self.huri.router.send_commands(Command("STATUS", []))
+
+    def do_start(self, arg) -> None:
+        "Start a module."
+        self.huri.router.send_commands(Command("START", [arg.strip()]))
+
+    def do_stop(self, arg) -> None:
+        "Stop a module."
+        self.huri.router.send_commands(Command("STOP", [arg.strip()]))
+
+    def do_exit(self, arg) -> bool:
+        "Exit HuRI."
+        self.huri.router.send_commands(Command("EXIT", []))
+        print("Bye !")
+        return True
diff --git a/src/core/zmq/__init__.py b/src/core/zmq/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/core/zmq/control_channel.py b/src/core/zmq/control_channel.py
new file mode 100644
index 0000000..89ef839
--- /dev/null
+++ b/src/core/zmq/control_channel.py
@@ -0,0 +1,151 @@
+import json
+import uuid
+from dataclasses import asdict, dataclass
+from typing import Any, Callable, Dict, List, Optional
+
+import zmq
+
+from src.tools.logger import logging, setup_logger
+
+
+@dataclass
+class Command:
+    cmd: str  # "STOP", "START", "STATUS", ...
+    args: List[Any]  # JSON-serializable arguments
+
+    def to_bytes(self) -> bytes:
+        return json.dumps(asdict(self)).encode("utf-8")
+
+    @staticmethod
+    def from_bytes(data: bytes) -> "Command":
+        obj = json.loads(data.decode("utf-8"))
+        return Command(**obj)
+
+
+@dataclass
+class Result:
+    success: bool
+    result: List[Any]
+
+    def to_bytes(self) -> bytes:
+        return json.dumps(asdict(self)).encode("utf-8")
+
+    @staticmethod
+    def from_bytes(data: bytes) -> "Result":
+        obj = json.loads(data.decode("utf-8"))
+        return Result(**obj)
+
+
+class Router:
+    def __init__(
+        self,
+        hostname: str,
+        port: int,
+        logger: Optional[logging.Logger] = setup_logger("Router"),
+    ):
+
+        self.ctx = zmq.Context.instance()
+        self.router = self.ctx.socket(zmq.ROUTER)
+        self.hostname = hostname
+        self.port = port
+
+        self.logger = logger or logging.getLogger(__name__)
+
+        self.dealers: Dict[bytes, bool] = {}
+
+    def start(self):
+        self.router.bind(f"tcp://{self.hostname}:{self.port}")
+        self.logger.info("Router started")
+
+        try:
+            while True:
+                identity, *frames = self.router.recv_multipart()
+
+                if not frames:
+                    continue
+
+                command = frames[0]
+
+                if command == b"REGISTER":
+                    self.dealers[identity] = True
+                    self.logger.info(f"Dealer registered: {identity}")
+
+                elif command == b"RESULT":
+                    payload = frames[1] if len(frames) > 1 else b""
+                    self.logger.info(f"Result from {identity}: {payload.decode()}")
+        except Exception as e:
+            self.logger.exception(e)
+        finally:
+            self.router.close()
+
+    def stop(self) -> None:
+        self.router.close()
+
+    def send_command(self, dealer_id: bytes, command: Command) -> None:
+        if dealer_id not in self.dealers:
+            raise ValueError("Dealer not registered")
+
+        self.router.send_multipart([dealer_id, b"COMMAND", command.to_bytes()])
+
+    def send_commands(self, command: Command) -> None:
+        for dealer_id, _ in self.dealers.items():
+            self.send_command(dealer_id, command)
+
+
+class Dealer:
+    def __init__(
+        self,
+        hostname: str,
+        port: int,
+        executor: Callable[[Command], bool],
+        logger: Optional[logging.Logger] = None,
+        identity: Optional[str] = None,
+    ):
+        self.ctx = zmq.Context.instance()
+        self.dealer = self.ctx.socket(zmq.DEALER)
+
+        self.hostname = hostname
+        self.port = port
+
+        self.executor = executor
+        self.identity = (identity or str(uuid.uuid4())).encode()  # TODO agent name
+
+        self.logger = logger or logging.getLogger(f"Dealer {self.identity}")
+
+    def start(self):
+        # The identity must be set before connecting so the Router addresses this Dealer by it
+        self.dealer.setsockopt(zmq.IDENTITY, self.identity)
+        self.dealer.connect(f"tcp://{self.hostname}:{self.port}")
+        self.logger.info(f"Dealer started: {self.identity}")
+
+        try:
+            self.dealer.send(b"REGISTER")
+
+            while True:
+                frames = self.dealer.recv_multipart()
+
+                command = frames[0]
+
+                if command == b"COMMAND":
+                    self.logger.info("received command")
+                    payload = frames[1] if len(frames) > 1 else b""
+                    result = self.execute(Command.from_bytes(payload))
+
+                    self.dealer.send_multipart([b"RESULT", result])
+        except Exception as e:
+            self.logger.exception(e)
+        finally:
+            self.dealer.close()
+
+    def execute(self, command: Command) -> bytes:
+        """
+        Execute command sent by Router
+        """
+        self.executor(command)
+
+        # Example execution
+        result = f"Executed: {command.cmd}"
+        return result.encode()
+
+    def stop(self) -> None:
+        self.dealer.close(linger=0)
diff --git a/src/core/zmq/event_proxy.py b/src/core/zmq/event_proxy.py
new file mode 100644
index 0000000..9447de8
--- /dev/null
+++ b/src/core/zmq/event_proxy.py
@@ -0,0 +1,58 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import zmq
+
+from
src.tools.logger import logging, setup_logger + + +@dataclass +class ZMQEventPorts: + xpub: str + xsub: str + + +class EventProxy: + def __init__( + self, + hostname: str, + connect_hostname: str, + xpub_port: int, + xsub_port: int, + logger: Optional[logging.Logger] = setup_logger("EventProxy"), + ): + + self.ctx = zmq.Context.instance() + self.xpub = self.ctx.socket(zmq.XPUB) + self.xsub = self.ctx.socket(zmq.XSUB) + + self.hostname = hostname + self.connect_hostname = connect_hostname + self.xpub_port = xpub_port + self.xsub_port = xsub_port + + self.logger = logger or logging.getLogger(__name__) + + def start(self, xpub_connect: bool, xsub_connect: bool): + if xpub_connect: + self.xpub.connect(f"tcp://{self.connect_hostname}:{self.xpub_port}") + else: + self.xpub.bind(f"tcp://{self.hostname}:{self.xpub_port}") + if xsub_connect: + self.xsub.connect(f"tcp://{self.connect_hostname}:{self.xsub_port}") + else: + self.xsub.bind(f"tcp://{self.hostname}:{self.xsub_port}") + + try: + self.logger.info("Correctly initialized, starting proxy") + zmq.proxy(self.xsub, self.xpub) + except Exception as e: + self.logger.error(e) + + def stop(self) -> None: + self.xsub.close(linger=0) + self.xpub.close(linger=0) + + def publish(self, topic: str, msg: str) -> None: + self.xpub.send_multipart([topic.encode(), "str".encode(), msg.encode()]) + self.logger.info(f"Publish: {topic} str") diff --git a/src/core/zmq/log_channel.py b/src/core/zmq/log_channel.py new file mode 100644 index 0000000..6eb2a8e --- /dev/null +++ b/src/core/zmq/log_channel.py @@ -0,0 +1,142 @@ +import json +import time +from typing import Any, Dict, Optional + +import zmq + +from src.tools.logger import ( + LevelFilter, + QueueListener, + logging, + mp, + setup_log_listener, + setup_logger, +) + + +def record_to_dict(record: logging.LogRecord) -> Dict[str, Any]: + return { + "name": record.name, + "levelno": record.levelno, + "levelname": record.levelname, + "message": record.getMessage(), + "created": record.created, + "asctime": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.created)), + "process": record.process, + "processName": record.processName, + "thread": record.thread, + "threadName": record.threadName, + "module": record.module, + "filename": record.filename, + "pathname": record.pathname, + "lineno": record.lineno, + "funcName": record.funcName, + } + + +def dict_to_record(data: Dict[str, Any]) -> logging.LogRecord: + record = logging.LogRecord( + name=data["name"], + level=data["levelno"], + pathname=data["pathname"], + lineno=data["lineno"], + msg=data["message"], + args=(), + exc_info=None, + func=data["funcName"], + ) + + # Restore metadata + record.created = data["created"] + record.process = data["process"] + record.processName = data["processName"] + record.thread = data["thread"] + record.threadName = data["threadName"] + record.module = data["module"] + record.filename = data["filename"] + + return record + + +class LogPuller: + def __init__( + self, + hostname: str, + port: int, + logger: Optional[logging.Logger] = setup_logger("LogPuller"), + ) -> None: + self.ctx = zmq.Context.instance() + self.pull = self.ctx.socket(zmq.PULL) + + self.hostname = hostname + self.port = port + + self.logger = logger or logging.getLogger(__name__) + + def start(self) -> None: + self.pull.bind(f"tcp://{self.hostname}:{self.port}") + + self.logger.info("started") + while True: + payload = self.pull.recv() + + self.logger.handle(dict_to_record(json.loads(payload.decode()))) + + def stop(self) -> None: + self.pull.close() + + 
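+# LogPusher is the agent-side counterpart of LogPuller: module loggers write records to a
+# multiprocessing queue, a QueueListener filters them by level, and LogPusherHandler
+# serializes each record and PUSHes it to HuRI's LogPuller over ZMQ.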
+class LogPusher:
+    class LogPusherHandler(logging.Handler):
+        def __init__(
+            self,
+            hostname: str,
+            port: int,
+        ):
+            super().__init__()
+            self.ctx = zmq.Context.instance()
+            self.socket = self.ctx.socket(zmq.PUSH)
+
+            self.hostname = hostname
+            self.port = port
+
+        def emit(self, record: logging.LogRecord) -> None:
+            try:
+                payload = json.dumps(record_to_dict(record)).encode()
+                self.socket.send(payload)
+            except Exception:
+                self.handleError(record)
+
+        def start(self) -> None:
+            self.socket.connect(f"tcp://{self.hostname}:{self.port}")
+
+        def stop(self) -> None:
+            self.socket.close()
+
+    def __init__(
+        self,
+        hostname: str,
+        port: int,
+    ):
+
+        self.log_queue = mp.Queue()
+
+        self.log_handler = self.LogPusherHandler(hostname, port)
+        self.level_filter = LevelFilter(logging.DEBUG)
+        self.log_listener: QueueListener = setup_log_listener(
+            self.log_queue, self.level_filter, self.log_handler
+        )
+
+        self.logger = setup_logger("LogPusher", log_queue=self.log_queue)
+
+    def start(self) -> None:
+        self.log_handler.start()
+        self.log_listener.start()
+
+    def stop(self):
+        self.logger.info("stopping")
+        time.sleep(0.2)
+        self.log_listener.stop()
+        self.log_handler.stop()
diff --git a/src/emotional_hub/input_analysis.py b/src/emotional_hub/input_analysis.py
index 2605010..7391578 100644
--- a/src/emotional_hub/input_analysis.py
+++ b/src/emotional_hub/input_analysis.py
@@ -1,5 +1,5 @@
-import torch
 import numpy as np
+import torch
 from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
 
 MODEL_NAME = "superb/hubert-large-superb-er"
diff --git a/src/launch_agent.py b/src/launch_agent.py
new file mode 100644
index 0000000..efc63c4
--- /dev/null
+++ b/src/launch_agent.py
@@ -0,0 +1,43 @@
+import argparse
+import logging
+import time
+
+import yaml
+
+from src.core.agent import Agent, AgentConfig
+from src.modules.factory import build_module_factory
+
+
+def load_config(path: str) -> AgentConfig:
+    with open(path) as f:
+        raw = yaml.safe_load(f)
+
+    return AgentConfig.from_dict(raw)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="HuRI agent")
+    parser.add_argument(
+        "--config",
+        required=True,
+        help="Path to Agent config file (YAML)",
+    )
+
+    args = parser.parse_args()
+
+    config = load_config(args.config)
+
+    build_module_factory()
+
+    agent = Agent(config)
+    time.sleep(0.1)
+    try:
+        agent.run()
+    except KeyboardInterrupt:
+        agent.stop_all()
+    except Exception as e:
+        logging.getLogger(__name__).error(e)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/launch_huri.py b/src/launch_huri.py
new file mode 100644
index 0000000..b4f9fd0
--- /dev/null
+++ b/src/launch_huri.py
@@ -0,0 +1,43 @@
+import argparse
+import logging
+import time
+
+import yaml
+
+from src.core.huri import HuRI, HuriConfig
+from src.modules.factory import build_module_factory
+
+
+def load_config(path: str) -> HuriConfig:
+    with open(path) as f:
+        raw = yaml.safe_load(f)
+
+    return HuriConfig.from_dict(raw)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="HuRI core")
+    parser.add_argument(
+        "--config",
+        required=True,
+        help="Path to HuRI config file (YAML)",
+    )
+
+    args = parser.parse_args()
+
+    config = load_config(args.config)
+
+    build_module_factory()
+
+    huri = HuRI(config)
+    time.sleep(0.1)
+    try:
+        huri.run()
+    except KeyboardInterrupt:
+        huri.stop()
+    except Exception as e:
+        logging.getLogger(__name__).error(e)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/main.py
b/src/main.py deleted file mode 100644 index 5489238..0000000 --- a/src/main.py +++ /dev/null @@ -1,60 +0,0 @@ -from enum import Enum -import soundfile as sf -import simpleaudio as sa -from emotional_hub.input_analysis import predict_emotion -from speech_to_text.speech_to_text import SpeechToText -from rag.rag import Rag - - -class Modes(Enum): - EXIT = 0 - LLM = 1 - CONTEXT = 2 - RAG = 3 - - -def loop(stt: SpeechToText, tts: None, mode: Modes, mode_function): - while mode: - prompt, audio = stt.get_prompt() - print(prompt) - if "switch llm" in prompt.lower(): - mode = Modes.LLM - elif "switch context" in prompt.lower(): - mode = Modes.CONTEXT - elif "switch rag" in prompt.lower(): - mode = Modes.RAG - elif "bye bye" in prompt.lower(): - mode = Modes.EXIT - elif prompt.strip() == "": - continue - else: - stt.pause() - emotion = predict_emotion(audio) - print("Predicted Emotion:", emotion) - answer = mode_function[mode](f"in a {emotion} emotion: {prompt}") - print(answer) - stt.pause(False) - - -def main(): - stt = SpeechToText() - rag = Rag(model="deepseek-v2:16b") - rag.ragLoader("tests/rag/docsRag", "txt") - mode = Modes.LLM - mode_function = { - Modes.LLM: rag.ragQuestion, - Modes.RAG: rag.ragLoader, - Modes.CONTEXT: lambda x: "Context mode not implemented yet.", - } - stt.start() - try: - loop(stt, None, mode, mode_function) - except KeyboardInterrupt: - print("CTRL+C detected. Stopping the program.") - except Exception as e: - print("Unexpected Error:", e) - stt.stop() - - -if __name__ == "__main__": - main() diff --git a/src/modules/factory.py b/src/modules/factory.py new file mode 100644 index 0000000..74bba03 --- /dev/null +++ b/src/modules/factory.py @@ -0,0 +1,33 @@ +from typing import Any, Mapping + +from src.core.module import Module + +from .rag.mode_controller import ModeController +from .rag.rag import Rag +from .speech_to_text.record_speech import RecordSpeech +from .speech_to_text.speech_to_text import SpeechToText +from .textIO.input import TextInput +from .textIO.output import TextOutput + + +class ModuleFactory: + _registry = {} + + @classmethod + def register(cls, name: str, module_cls): + cls._registry[name] = module_cls + + @classmethod + def create(cls, name: str, args: Mapping[str, Any] | None = None) -> Module: + if name not in cls._registry: + raise ValueError(f"Unknown module '{name}'") + return cls._registry[name](**args) + + +def build_module_factory() -> None: + ModuleFactory.register("mic", RecordSpeech) + ModuleFactory.register("stt", SpeechToText) + ModuleFactory.register("inp", TextInput) + ModuleFactory.register("out", TextOutput) + ModuleFactory.register("rag", Rag) + ModuleFactory.register("mod", ModeController) diff --git a/src/modules/rag/__init__.py b/src/modules/rag/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/rag/mode_controller.py b/src/modules/rag/mode_controller.py new file mode 100644 index 0000000..9ed0606 --- /dev/null +++ b/src/modules/rag/mode_controller.py @@ -0,0 +1,37 @@ +from enum import Enum + +from src.core.module import Module + + +class Modes(Enum): + LLM = 0 + CONTEXT = 1 + RAG = 2 + + +class ModeController(Module): + def __init__(self, default_mode: Modes = Modes.LLM): + super().__init__() + self.mode = default_mode + + def switchMode(self, mode: str) -> None: + self.mode = mode + + def processTextInput(self, text: str): + if "switch llm" in text.lower(): + self.switchMode(Modes.LLM) + elif "switch context" in text.lower(): + self.switchMode(Modes.CONTEXT) + elif "switch rag" in text.lower(): + 
self.switchMode(Modes.RAG) + elif "bye bye" in text.lower(): + self.publish("exit", "") # TODO handle (manager being a module) usefull ? + elif text.strip() == "": + return + else: + topic = f"{str(self.mode.name).lower()}.in" + self.publish(topic, text) + + def set_subscriptions(self): + self.subscribe("text.in", self.processTextInput) + self.subscribe("mode.switch", self.switchMode) diff --git a/src/rag/rag.py b/src/modules/rag/rag.py similarity index 65% rename from src/rag/rag.py rename to src/modules/rag/rag.py index 733ae0c..f2fc736 100644 --- a/src/rag/rag.py +++ b/src/modules/rag/rag.py @@ -1,23 +1,28 @@ -from langchain_community.document_loaders import TextLoader +import json +import pathlib + +from langchain.chains import create_retrieval_chain +from langchain.chains.combine_documents import create_stuff_documents_chain from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_chroma import Chroma +from langchain_community.document_loaders import TextLoader +from langchain_core.documents import Document +from langchain_core.prompts import ChatPromptTemplate from langchain_ollama.embeddings import OllamaEmbeddings from langchain_ollama.llms import OllamaLLM -from langchain.chains import create_retrieval_chain -from langchain.chains.combine_documents import create_stuff_documents_chain -from langchain_core.prompts import ChatPromptTemplate from langgraph.checkpoint.memory import MemorySaver -import json -import pathlib +from src.core.module import Module -class Rag: + +class Rag(Module): def __init__( self, - model="deepseek-r1:7b", - collectionName="vectorStore", - vectorstorePath="src/rag/vectorStore", + model: str = "deepseek-v2:16b", + collectionName: str = "vectorStore", + vectorstorePath: str = "src/rag/vectorStore", ): + super().__init__() self.memory = MemorySaver() self.embeddings = OllamaEmbeddings(model=model) self.llm = OllamaLLM(model=model) @@ -44,18 +49,21 @@ def __init__( self.conversation = [] self.conversation_log = {"conversation": []} - def ragLoader(self, text: str): - self.documents += self.textSplitter.split_documents(text) + def ragFill(self, text: str) -> None: + self.documents += self.textSplitter.split_documents( + [Document(page_content=text)] + ) self.vectorstore.add_documents(self.documents) - def ragLoader(self, pathFolder: str, fileType: str): + def ragLoad(self, folderPath: str, fileType: str) -> None: if fileType == "txt": - for file in pathlib.Path(pathFolder).rglob("*.txt"): - fileLoader = TextLoader(file_path=pathFolder + "/" + file.name) + for file in pathlib.Path(folderPath).rglob("*.txt"): + fileLoader = TextLoader(file_path=folderPath + "/" + file.name) self.documents += self.textSplitter.split_documents(fileLoader.load()) self.vectorstore.add_documents(self.documents) - def ragQuestion(self, question: str): + def ragQuestion(self, question: str) -> None: + self.logger.debug("question:", question) history = "\n".join( [ f"Human: {qa['question']}\nAI: {qa['answer']}" @@ -64,13 +72,25 @@ def ragQuestion(self, question: str): ) helpingContext = "Answer with just your message like in a conversation. 
" question = helpingContext + question + self.logger.debug("full question:", question) response = self.qaChain.invoke({"history": history, "input": question}) answer = response["answer"] + self.logger.debug("answer:", answer) self.conversation_log["conversation"].append( {"question": question.split(helpingContext)[1:], "answer": answer} ) - return answer + self.publish("llm.response", answer) - def saveConversation(self, filename="conversation_log.json"): + def saveConversation(self, filename: str = "conversation_log.json"): with open(filename, "w") as f: json.dump(self.conversation_log, f, indent=4) + + def set_subscriptions(self) -> None: + self.subscribe("rag.load", self.ragLoad) + self.subscribe("llm.in", self.ragQuestion) + self.subscribe("rag.in", self.ragFill) + self.subscribe("rag.save", self.saveConversation) + + def run_module(self, stop_event=None) -> None: + self.ragLoad("tests/rag/docsRag", "txt") + super().run_module(stop_event) diff --git a/src/modules/speech_to_text/__init__.py b/src/modules/speech_to_text/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/speech_to_text/speech_to_text.py b/src/modules/speech_to_text/record_speech.py similarity index 53% rename from src/speech_to_text/speech_to_text.py rename to src/modules/speech_to_text/record_speech.py index fd3ce39..a361ac3 100644 --- a/src/speech_to_text/speech_to_text.py +++ b/src/modules/speech_to_text/record_speech.py @@ -1,31 +1,24 @@ -import sounddevice as sd -import numpy as np -import io -import soundfile as sf -import whisper -from typing import List, Optional, Callable -import threading import queue +import threading import time +from typing import List, Optional + +import numpy as np +import sounddevice as sd +from src.core.module import Event, Module -class SpeechToText: + +class RecordSpeech(Module): def __init__( self, - model_name: str = "base.en", - device: str = "cpu", threshold: int = 0, silence_duration: float = 1.0, chunk_duration: float = 0.5, sample_rate: int = 16000, ): - if device == "cpu": - import warnings + super().__init__() - warnings.filterwarnings( - "ignore", message="FP16 is not supported on CPU; using FP32 instead" - ) - self.model: whisper.Whisper = whisper.load_model(model_name, device=device) self.THRESHOLD: int = threshold self.SILENCE_DURATION: float = silence_duration self.CHUNK_DURATION: float = chunk_duration @@ -41,6 +34,7 @@ def __init__( def reduce_noise(self, chunk: np.ndarray) -> np.ndarray: if np.abs(chunk).mean() <= self.THRESHOLD: return chunk + return np.clip(chunk - self.noise_profile, -32768, 32767).astype(np.int16) def record_chunk(self) -> np.ndarray: @@ -50,92 +44,64 @@ def record_chunk(self) -> np.ndarray: samplerate=self.SAMPLE_RATE, channels=1, dtype="int16", - ) + ).ravel() sd.wait() self.pause_record.release() return self.reduce_noise(chunk) def calculate_noise_level(self) -> None: - print("Listening for 10 seconds to calculate noise level...") + self.logger.info("Listening for 10 seconds to calculate noise level...") noise_chunk: np.ndarray = sd.rec( int(10 * self.SAMPLE_RATE), samplerate=self.SAMPLE_RATE, channels=1, dtype="int16", - ) + ).ravel() sd.wait() self.noise_profile = noise_chunk.mean(axis=0) self.THRESHOLD = np.abs(self.reduce_noise(noise_chunk)).mean() - print(f"Threshold: {self.THRESHOLD}") - - def process_audio(self, buffer: List[np.ndarray]) -> None: - if not buffer: - return - - audio_data: np.ndarray = np.concatenate(buffer, axis=0) - input_buffer: io.BytesIO = io.BytesIO() - sf.write(input_buffer, audio_data, 
self.SAMPLE_RATE, format="WAV")
-        input_buffer.seek(0)
-        audio_array, _ = sf.read(input_buffer, dtype="float32")
-
-        result: dict = self.model.transcribe(audio_array, language="en")
-        result["text"] = result["text"].strip()
-        if not result["text"]:
-            return
+
-        self.transcriptions.put([result["text"], audio_array])
-        self.prompt_available.release()
-
-    def record_audio(self, starting_chunk) -> None:
+    def record_audio(self, starting_chunk, stop_event: Event = None) -> None:
         buffer: List[np.ndarray] = [starting_chunk]
         silence_start: Optional[float] = None
-        while self.running:
+        while stop_event is None or not stop_event.is_set():
             chunk = self.record_chunk()
             buffer.append(chunk)
+
             if np.abs(chunk).mean() <= self.THRESHOLD:
                 if silence_start is None:
                     silence_start = time.time()
                 elif time.time() - silence_start >= self.SILENCE_DURATION:
-                    self.audio_queue.put(buffer)
-                    self.audio_to_process.release()
+                    if buffer == []:
+                        break
+                    speech = np.concatenate(buffer, axis=0)
+                    self.publish("speech.in", speech.tobytes(), "bytes")
                     break
             else:
                 silence_start = None
 
-    def listen_audio(self) -> None:
-        self.running = True
-        while self.running:
-            chunk: np.ndarray = self.record_chunk()
-            if np.abs(chunk).mean() > self.THRESHOLD:
-                self.record_audio(chunk)
+    def set_subscriptions(self) -> None:
+        # Register callables so pause/resume run when a control message arrives
+        self.subscribe("speech.in.pause", lambda _: self.pause())
+        self.subscribe("speech.in.resume", lambda _: self.pause(False))
 
-    def process_queue(self) -> None:
-        self.audio_to_process.acquire()
-        while self.running:
-            buffer = self.audio_queue.get()
-            self.process_audio(buffer)
-            self.audio_to_process.acquire()
-
-    def start(self) -> None:
+    def run_module(self, stop_event: Event = None) -> None:
         if not self.THRESHOLD:
             self.calculate_noise_level()
+        else:
+            self.noise_profile = np.zeros(
+                int(self.CHUNK_DURATION * self.SAMPLE_RATE), dtype=np.int16
+            )
 
-        self.running = True
-        threading.Thread(target=self.listen_audio).start()
-        threading.Thread(target=self.process_queue).start()
+        while stop_event is None or not stop_event.is_set():
+            chunk: np.ndarray = self.record_chunk()
+
+            if np.abs(chunk).mean() > self.THRESHOLD:
+                self.record_audio(chunk, stop_event)
 
     def pause(self, true: bool = True) -> None:
         if true:
             self.pause_record.acquire()
         else:
             self.pause_record.release()
-
-    def stop(self) -> None:
-        self.running = False
-        self.audio_to_process.release()
-        self.pause_record.release()
-
-    def get_prompt(self) -> tuple[str, np.ndarray]:
-        self.prompt_available.acquire()
-        return self.transcriptions.get()
diff --git a/src/modules/speech_to_text/speech_to_text.py b/src/modules/speech_to_text/speech_to_text.py
new file mode 100644
index 0000000..a086a2c
--- /dev/null
+++ b/src/modules/speech_to_text/speech_to_text.py
@@ -0,0 +1,50 @@
+import queue
+import threading
+
+import numpy as np
+import whisper
+
+from src.core.module import Module
+
+
+class SpeechToText(Module):
+    def __init__(
+        self,
+        model_name: str = "base.en",
+        device: str = "cpu",
+        sample_rate: int = 16000,
+    ):
+        super().__init__()
+        self.logger.debug(f"Loading whisper model '{model_name}' on {device}")
+        if device == "cpu":
+            import warnings
+
+            warnings.filterwarnings(
+                "ignore", message="FP16 is not supported on CPU; using FP32 instead"
+            )
+        self.model: whisper.Whisper = whisper.load_model(model_name, device=device)
+        self.SAMPLE_RATE: int = sample_rate
+        self.running: bool = False
+        self.audio_queue: queue.Queue = queue.Queue()
+        self.transcriptions: queue.Queue = queue.Queue()
+        self.pause_record = threading.Semaphore(1)
+        self.audio_to_process = threading.Semaphore(0)
+ self.prompt_available = threading.Semaphore(0) + self.noise_profile: np.ndarray + + def process_audio(self, buffer: bytes) -> None: + if not buffer: + return + + audio_array = np.frombuffer(buffer, dtype=np.int16) + audio_array = audio_array.astype(np.float32) / 32768.0 + + result: dict = self.model.transcribe(audio_array, language="en") + result["text"] = result["text"].strip() + if not result["text"] or result["text"] == "": + return + + self.publish("text.in", result["text"]) + + def set_subscriptions(self) -> None: + self.subscribe("speech.in", self.process_audio) diff --git a/src/modules/textIO/input.py b/src/modules/textIO/input.py new file mode 100644 index 0000000..6440238 --- /dev/null +++ b/src/modules/textIO/input.py @@ -0,0 +1,17 @@ +from src.core.module import Module + + +class TextInput(Module): + def set_subscriptions(self): + self.subscribe("std.in", self.stdin_to_text) + self.subscribe("std.out", lambda _: print(">> ", end="", flush=True)) + + def stdin_to_text(self, data): + print(">> ", end="", flush=True) + if data == "": + return + self.publish("text.in", data) + + def run_module(self, stop_event=None): + print(">> ", end="", flush=True) + stop_event.wait() diff --git a/src/modules/textIO/output.py b/src/modules/textIO/output.py new file mode 100644 index 0000000..c68ca76 --- /dev/null +++ b/src/modules/textIO/output.py @@ -0,0 +1,10 @@ +from src.core.module import Module + + +class TextOutput(Module): + def set_subscriptions(self) -> None: + self.subscribe("llm.response", self.print_response) + + def print_response(self, text: str) -> None: + print(f"\r<< {text}") + self.publish("std.out", "") diff --git a/src/text_to_speech/tts.py b/src/text_to_speech/tts.py deleted file mode 100644 index 0f8497a..0000000 --- a/src/text_to_speech/tts.py +++ /dev/null @@ -1,15 +0,0 @@ -import torch -from parler_tts import ParlerTTSForConditionalGeneration -from transformers import AutoTokenizer - - -def get_tts_model(): - device = "cuda" if torch.cuda.is_available() else "cpu" - return ParlerTTSForConditionalGeneration.from_pretrained( - "parler-tts/parler-tts-mini-v1" - ).to(device) - - -def tokenize_text(text, tokenizer): - device = "cuda" if torch.cuda.is_available() else "cpu" - return tokenizer(text, return_tensors="pt").input_ids.to(device) diff --git a/src/tools/logger.py b/src/tools/logger.py new file mode 100644 index 0000000..1b7d5c5 --- /dev/null +++ b/src/tools/logger.py @@ -0,0 +1,104 @@ +import logging +import multiprocessing as mp +from logging.handlers import QueueHandler, QueueListener +from typing import IO, Dict, Optional + + +def setup_handler( + stream: Optional[IO] = None, + filename: Optional[str] = None, + log_queue: Optional[mp.Queue] = None, + formatter: logging.Formatter = logging.Formatter( + "[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s", datefmt="%H:%M:%S" + ), +) -> logging.Handler: + if stream is not None: + handler = logging.StreamHandler(stream) + elif filename is not None: + handler = logging.FileHandler(filename) + elif log_queue is not None: + return QueueHandler(log_queue) + else: + # Default: stdout + handler = logging.StreamHandler() + + handler.setFormatter(formatter) + + return handler + + +def setup_logger( + name: str, + level: int = logging.DEBUG, + stream: Optional[IO] = None, + filename: Optional[str] = None, + log_queue: Optional[mp.Queue] = None, +) -> logging.Logger: + """ + Creates and returns a logger with optional output: + - log_queue (multiprocessing-safe queue, preferred for child processes) + - stream (e.g., 
sys.stdout) + - filename (log file) + - defaults to stdout if none is given + """ + logger = logging.getLogger(name) + logger.setLevel(level) + if log_queue: + logger.propagate = False + + logger.handlers.clear() + handler = setup_handler(stream, filename, log_queue) + logger.addHandler(handler) + + return logger + + +class LevelFilter(logging.Filter): + def __init__(self, root_level: int = logging.WARNING): + self.root_level = root_level + self.log_levels: Dict[str, int] = {} + + def filter(self, record: logging.LogRecord) -> bool: + """the root level has priority over custom levels""" + level = self.log_levels.get(record.name, self.root_level) + + return self.root_level <= record.levelno and level <= record.levelno + + def set_root_level(self, level: int) -> None: + self.root_level = level + + def add_level(self, name: str) -> None: + self.log_levels[name] = self.root_level + + def set_level(self, name: str, level: int) -> None: + if name not in self.log_levels: + raise ValueError(f"{name} has no linked log level") + self.log_levels[name] = level + + def set_levels(self, level: int) -> None: + self.set_root_level(level) + for name in self.log_levels: + self.set_level(name, level) + + def del_level(self, name: str) -> None: + del self.log_levels[name] + + +def setup_log_listener( + log_queue: mp.Queue, + filter: logging.Filter, + custom_handler: Optional[logging.Handler] = None, +) -> QueueListener: + """ + Starts a central logging listener that reads LogRecords from a queue + and emits them using normal loggers/handlers. + """ + formatter = logging.Formatter( + "[%(asctime)s] [%(processName)s] [%(name)s] [%(levelname)s] %(message)s", + datefmt="%H:%M:%S", + ) + handler = custom_handler or setup_handler(formatter=formatter) + handler.addFilter(filter) + + listener = QueueListener(log_queue, handler) + return listener diff --git a/tests/rag/rag.py b/tests/rag/rag.py index 0054820..9cc489a 100644 --- a/tests/rag/rag.py +++ b/tests/rag/rag.py @@ -7,7 +7,7 @@ def test_main(): rag = Rag(vectorstorePath=f"{__path__}/src/rag/vectorStore") - rag.ragLoader(f"{__path__}/tests/rag/docsRag", "txt") + rag.ragLoad(f"{__path__}/tests/rag/docsRag", "txt") print( rag.ragQuestion( "The new capital of France is Edimburgh and what is the capital of Spain?"