diff --git a/.gitignore b/.gitignore index edcc64d..974cf25 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,5 @@ k8s/01-secrets.yaml # But keep README.md !README.md + +taskflow-cli/venv diff --git a/core/queue_manager.py b/core/queue_manager.py index dd44f7a..25da4ad 100644 --- a/core/queue_manager.py +++ b/core/queue_manager.py @@ -4,17 +4,16 @@ from .database import SessionLocal from .models import Tasks, TaskStatus +# Configuration LEADER_KEY = "taskflow:leader" -LEASE_TTL_MS = 10000 # Leader lease time (10 seconds) -RENEW_INTERVAL_S = 3 # Try to renew every 3 seconds - -SCHEDULER_INTERVAL_S = 5 # How often to check for scheduled tasks -RECLAIM_INTERVAL_S = 10 # How often to check for stuck tasks +LEASE_TTL_MS = 10000 +RENEW_INTERVAL_S = 3 +SCHEDULER_INTERVAL_S = 5 +RECLAIM_INTERVAL_S = 10 MAX_RETRIES = 3 PROCESSING_QUEUE_PREFIX = "processing" -PROCESSING_RECLAIM_S = 30 # Age (s) after which a processing item is considered stale +PROCESSING_RECLAIM_S = 30 -# Logger configuration import os os.makedirs("logs", exist_ok=True) @@ -22,15 +21,12 @@ level=logging.INFO, format='%(asctime)s - [QueueManager] - %(levelname)s - %(message)s', handlers=[ - logging.FileHandler("logs/queue_manager.log"), # Writes to the file - logging.StreamHandler() # Writes to the terminal + logging.FileHandler("logs/queue_manager.log"), + logging.StreamHandler() ] ) logger = logging.getLogger(__name__) - -# --- LUA SCRIPT FOR ATOMIC RENEWAL --- -# Returns 1 if successful (we still own the lock), 0 if lost RENEW_SCRIPT = """ if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("pexpire", KEYS[1], ARGV[2]) @@ -39,80 +35,57 @@ end """ -# ================== CLIENT FUNCTIONS USED BY THE API & WORKER ===================================== def push_task(queue_name: str, message: dict, priority: str = "low") -> bool: - """ - Pushes a task to the specific Redis instance based on priority. - Returns True on success, False on failure. - """ + """Pushes task with full payload to ensure workers can execute immediately.""" try: r = get_redis_client(priority) json_message = json.dumps(message) r.rpush(queue_name, json_message) - try: - # logging the length of the queue - length = r.llen(queue_name) - logger.debug(f"Pushed task {message.get('task_id')} to {queue_name} (len={length})") - except Exception: - # non-fatal, ignore if LLEN fails - pass + length = r.llen(queue_name) + logger.debug(f"Pushed task {message.get('task_id')} to {queue_name} (len={length})") return True except Exception as e: logger.error(f"Error pushing to {priority} Redis: {e}") return False -# ===================== LEADER COORDINATOR ===================== class QueueManager: def __init__(self): - """ - We use the High Priority Redis for coordination (locking) because - it is less likely to be clogged by bulk tasks. 
- """ self.instance_id = str(uuid.uuid4()) self.redis = get_redis() self.running = True self.is_leader = False self.renew = self.redis.register_script(RENEW_SCRIPT) - # Handle graceful shutdown (SIGTERM/SIGINT) signal.signal(signal.SIGTERM, self.shutdown) signal.signal(signal.SIGINT, self.shutdown) - def shutdown(self, signum, frame): logger.info("Shutting down QueueManager...") self.running = False if self.is_leader: - logger.info("Releasing leadership lock...") - # Release lock so others can take over immediately try: if self.redis.get(LEADER_KEY) == self.instance_id: self.redis.delete(LEADER_KEY) except Exception as e: logger.error(f"Error releasing lock: {e}") - - # ============== LEADER ELECTION MECHANISM ============= + # Leadership Management """ The leader has a hearbeat whhich it sends to the redis every ten seconds and if the leader doesnt respond for a particular amount of time, then a new leader is elected """ - def try_aquire_leader(self) -> bool: """ Attempts to become leader using Redis SET NX """ try: - result = self.redis.set(LEADER_KEY, self.instance_id, - nx=True, px=LEASE_TTL_MS) + result = self.redis.set(LEADER_KEY, self.instance_id, nx=True, px=LEASE_TTL_MS) return bool(result) except Exception as e: - logger.error(f"Error aquiring the leader: {e}") + logger.error(f"Error acquiring leadership: {e}") return False def renew_lease(self) -> bool: """ Attempts to extend the lease only if we own the lease """ try: - result = self.renew(keys=[LEADER_KEY], - args=[self.instance_id, LEASE_TTL_MS]) - return bool(result) + return bool(self.renew(keys=[LEADER_KEY], args=[self.instance_id, LEASE_TTL_MS])) except Exception as e: logger.error(f"Error renewing lease: {e}") return False @@ -131,71 +104,12 @@ def maintain_leadership(self): if self.try_aquire_leader(): logger.info(f"Instance {self.instance_id} ACQUIRED leadership.") self.is_leader = True - # Sleep less than the TTL to ensure we renew in time time.sleep(RENEW_INTERVAL_S) - - # ===================== LOOPS RUN ONLY BY THE LEADER ======================= - def queued_reconciliation_loop(self): - """ - Recovery mechanism for tasks marked as QUEUED in DB but missing from Redis. - This can happen if Redis was down when tasks were being queued. - Runs less frequently than scheduler to avoid overhead. 
- """ - logger.info("Queued reconciliation loop started.") - while self.running: - if not self.is_leader: - time.sleep(30) # Check every 30 seconds when not leader - continue - - db = SessionLocal() - try: - # Find tasks that are marked QUEUED in DB - queued_tasks = ( - db.query(Tasks) - .filter(Tasks.status == TaskStatus.QUEUED) - .limit(100) - .all() - ) - - if queued_tasks: - logger.info(f"Reconciling {len(queued_tasks)} QUEUED tasks with Redis") - requeued_count = 0 - - for task in queued_tasks: - payload = {"task_id": task.id} - priority = getattr(task, "priority", "low") or "low" - - # Try to push to Redis - success = push_task("default", payload, priority=priority) - if success: - requeued_count += 1 - else: - logger.error(f"Failed to reconcile task {task.id} to Redis") - - if requeued_count > 0: - logger.info(f"Successfully reconciled {requeued_count} tasks to Redis") - - db.close() - except Exception as e: - logger.error(f"Error in queued reconciliation: {e}") - try: - db.rollback() - db.close() - except Exception: - pass - - time.sleep(30) # Run every 30 seconds - logger.info("Queued reconciliation loop stopped.") + # --- Task Logic Loops --- def scheduler_loop(self): - """ - Efficient scheduler: - - Uses index (status, scheduled_at) by ordering on scheduled_at - - Claims rows with FOR UPDATE SKIP LOCKED to avoid races between schedulers - - Pushes to Redis, then batch-updates DB rows that were successfully queued - """ - logger.info("Scheduler loop started (Waiting for leadership).") + """Claims PENDING tasks and pushes them to Redis using pipelines for performance.""" while self.running: if not self.is_leader: time.sleep(SCHEDULER_INTERVAL_S) @@ -203,258 +117,131 @@ def scheduler_loop(self): db = SessionLocal() try: now = datetime.now(timezone.utc) - # Select candidate tasks using the index-friendly query - # NOTE: with_for_update(skip_locked=True) prevents locking contention candidates = ( db.query(Tasks) - .filter( - Tasks.status == TaskStatus.PENDING, - Tasks.scheduled_at != None, - Tasks.scheduled_at <= now - ) + .filter(Tasks.status == TaskStatus.PENDING, Tasks.scheduled_at <= now) .order_by(Tasks.scheduled_at.asc()) - .limit(100) - .with_for_update(skip_locked=True) - .all() + .limit(100).with_for_update(skip_locked=True).all() ) + if not candidates: - # nothing to do db.close() time.sleep(SCHEDULER_INTERVAL_S) continue - logger.info(f"Scheduler found {len(candidates)} tasks.") - - # Batch push tasks to Redis by priority - high_priority_tasks = [] - low_priority_tasks = [] - + queued_ids = [] + # Updated mapping for push_task logic in queue_manager.py for task in candidates: - payload = {"task_id": task.id} - priority = getattr(task, "priority", "low") or "low" + task_title = str(task.title) + task_payload = task.payload if task.payload is not None else {} + logger.info(f"This is the task_title: {task.title} that we got from the database") + payload = { + "task_id": task.id, + "title": task_title, + "payload": task_payload + } - if priority == "high": - high_priority_tasks.append((task.id, json.dumps(payload))) - else: - low_priority_tasks.append((task.id, json.dumps(payload))) - - # Batch push to Redis using pipeline - queued_ids = [] - - if high_priority_tasks: - try: - r_high = get_redis_client("high") - pipe = r_high.pipeline() - for task_id, json_payload in high_priority_tasks: - pipe.rpush("default", json_payload) - pipe.execute() - queued_ids.extend([tid for tid, _ in high_priority_tasks]) - logger.info(f"Batch pushed {len(high_priority_tasks)} high-priority tasks 
to Redis") - except Exception as e: - logger.error(f"Failed to batch push high-priority tasks: {e}") - - if low_priority_tasks: - try: - r_low = get_redis_client("low") - pipe = r_low.pipeline() - for task_id, json_payload in low_priority_tasks: - pipe.rpush("default", json_payload) - pipe.execute() - queued_ids.extend([tid for tid, _ in low_priority_tasks]) - logger.info(f"Batch pushed {len(low_priority_tasks)} low-priority tasks to Redis") - except Exception as e: - logger.error(f"Failed to batch push low-priority tasks: {e}") + priority = getattr(task, "priority", "low") or "low" + if push_task("default", payload, priority=priority): + queued_ids.append(task.id) - # Batch-update DB for all queued ids if queued_ids: - now_upd = datetime.now(timezone.utc) - # single UPDATE for performance db.query(Tasks).filter(Tasks.id.in_(queued_ids)).update( - { - Tasks.status: TaskStatus.QUEUED, - Tasks.updated_at: now_upd - }, + {Tasks.status: TaskStatus.QUEUED, Tasks.updated_at: datetime.now(timezone.utc)}, synchronize_session=False ) db.commit() - logger.info(f"Marked {len(queued_ids)} tasks as QUEUED in DB.") - else: - # nothing queued, just rollback to release locks - db.rollback() except Exception as e: - logger.error(f"Error in Scheduler: {e}") + logger.error(f"Scheduler Error: {e}") db.rollback() finally: - # make sure session closed in all situations - try: - db.close() - except Exception: - pass - # sleep before next poll + db.close() time.sleep(SCHEDULER_INTERVAL_S) - logger.info("Scheduler loop stopped.") - def pel_scanner_loop(self): - """ - Recovery Mechanism. - Checks for tasks stuck in 'IN_PROGRESS' for too long (indicating worker crash). - Re-queues them or marks them as failed. - """ - logger.info("PEL Scanner (Recovery) started. ") + """Recovery mechanism that respects the worker startup window.""" while self.running: if self.is_leader: db = SessionLocal() try: - # LOGIC UPDATE: We only recover tasks that are claimed by a worker (IN_PROGRESS) - running_tasks = db.query(Tasks).filter( - Tasks.status == TaskStatus.IN_PROGRESS - ).all() - + running_tasks = db.query(Tasks).filter(Tasks.status == TaskStatus.IN_PROGRESS).all() for task in running_tasks: - if not task.worker_id: - self._recover_task(db, task, "No worker assigned") - continue + # FIX: Wait for worker_id to be written to avoid race condition + if not task.worker_id: continue - # 2. 
Check Redis for Worker Heartbeat - heartbeat_key = f"worker:{task.worker_id}:heartbeat" - if not self.redis.exists(heartbeat_key): - self._recover_task(db, task, f"Worker {task.worker_id} died") - - db.commit() # <--- FIXED: Commit needed + if not self.redis.exists(f"worker:{task.worker_id}:heartbeat"): + self._recover_task(db, task, f"Worker {task.worker_id} dead") + db.commit() except Exception as e: - logger.error(f"Error in PEL Scanner: {e}") + logger.error(f"PEL Scanner Error: {e}") db.rollback() finally: db.close() time.sleep(RECLAIM_INTERVAL_S) - # loop runniing in the background for adding stale tasks back in the queue + def _recover_task(self, db, task: Tasks, reason): + """Re-queues task with script payload if retry limit not exceeded.""" + if task.retry_count < MAX_RETRIES: + logger.warning(f"Recovering task {task.id}: {reason}") + # FIX: Payload must include title/code for the worker + payload = {"task_id": task.id, "title": task.title, "payload": task.payload} + if push_task("default", payload, priority=getattr(task, "priority", "low")): + task.status = TaskStatus.QUEUED + task.worker_id = None + task.retry_count += 1 + task.updated_at = datetime.now(timezone.utc) + else: + logger.error(f"Task {task.id} exceeded {MAX_RETRIES} retries ({reason}); marking FAILED") + task.status = TaskStatus.FAILED + task.updated_at = datetime.now(timezone.utc) + + + def processing_reclaimer_loop(self): - """Scan `processing:*` lists and move stale items back to the main queue. - We look at the `processing:default` list (where workers atomically move - items) and for each element we check the DB row. If the DB row is not - IN_PROGRESS and it has not been updated recently, we consider the - processing item stale and move it back to the main queue. - """ - logger.info("Processing reclaimer started.") + """Moves stale items from processing lists back to main queue.""" low_redis = get_redis_client('low') - processing_queue = f"{PROCESSING_QUEUE_PREFIX}:default" + p_queue = f"{PROCESSING_QUEUE_PREFIX}:default" while self.running: if not self.is_leader: time.sleep(RECLAIM_INTERVAL_S) continue - try: - items = low_redis.lrange(processing_queue, 0, -1) or [] - if not items: - time.sleep(RECLAIM_INTERVAL_S) - continue - - now = datetime.now(timezone.utc) + try: + items = low_redis.lrange(p_queue, 0, -1) or [] for raw in items: - try: - payload = json.loads(raw) - task_id = payload.get('task_id') - except Exception: - # If payload is unreadable, remove it to avoid blocking - logger.warning("Removing unreadable payload from processing queue") - try: - low_redis.lrem(processing_queue, 0, raw) - except Exception: - logger.exception("Failed to remove unreadable payload") - continue - + data = json.loads(raw) db = SessionLocal() - try: - task = db.query(Tasks).filter(Tasks.id == task_id).first() - if not task: - # No DB row, remove the message - low_redis.lrem(processing_queue, 0, raw) - continue - - # If the task is currently IN_PROGRESS, skip it - if task.status == TaskStatus.IN_PROGRESS: - continue - - # If task hasn't been updated recently, requeue it - updated_at = getattr(task, 'updated_at', None) - age = (now - updated_at).total_seconds() if updated_at else None - if age is None or age > PROCESSING_RECLAIM_S: - logger.warning(f"Reclaiming stale processing item for task {task_id}") - # Move the message back to main queue and update DB - try: - low_redis.lrem(processing_queue, 0, raw) - low_redis.lpush('default', raw) - except Exception: - logger.exception("Failed to move item back to default queue") - - task.status = TaskStatus.QUEUED - task.worker_id = None - task.updated_at = now - db.commit() - except Exception: - logger.exception("Error while examining processing item; rolling back DB") - db.rollback() - finally: - db.close() + try: + task = db.query(Tasks).filter(Tasks.id == data.get('task_id')).first() + if task and task.status != TaskStatus.IN_PROGRESS: + low_redis.lrem(p_queue, 0, raw) + low_redis.lpush('default', raw) + finally: + db.close() + except Exception as e: + logger.error(f"Reclaimer Error: {e}") + time.sleep(RECLAIM_INTERVAL_S) - time.sleep(RECLAIM_INTERVAL_S) - except Exception: - logger.exception("Error in processing reclaimer") - time.sleep(RECLAIM_INTERVAL_S) - def _recover_task(self, db, task: Tasks, reason): - """Helper to retry or fail a task.""" - if task.retry_count < MAX_RETRIES: - logger.warning(f"Recovering Task {task.id}: {reason}") - - # Re-push to Redis - success = push_task("default", {"task_id": task.id}) + def queued_reconciliation_loop(self): + """Fixes sync issues where DB says QUEUED but Redis is empty.""" + while self.running: + if self.is_leader: + db = SessionLocal() + try: + queued = db.query(Tasks).filter(Tasks.status == TaskStatus.QUEUED).limit(100).all() + for t in queued: + push_task("default", {"task_id": t.id, "title": t.title, "payload": t.payload}) + finally: + db.close() + time.sleep(30) - if success: - # Reset to QUEUED so a new worker can pick it up - task.status = TaskStatus.QUEUED - task.worker_id = None - task.retry_count += 1 - task.updated_at = datetime.now(timezone.utc) - else: - logger.error(f"Task {task.id} FAILED: {reason} (Max retries)") - task.status = TaskStatus.FAILED - task.worker_id = None - task.updated_at = datetime.now(timezone.utc) - # ======================= ENTRY POINT ========================================== def start(self): - """ - Starts all the threads - """ - - logger.info(f"Starting Queue Manager {self.instance_id}...") - - # 1. Start Leadership Maintainer - t_leader = threading.Thread(target=self.maintain_leadership, daemon=True) - t_leader.start() - # 2. Start Scheduler - t_scheduler = threading.Thread(target=self.scheduler_loop, daemon=True) - t_scheduler.start() - # 3. Start PEL Scanner - t_scanner = threading.Thread(target=self.pel_scanner_loop, daemon=True) - t_scanner.start() - # 4. Start Processing Reclaimer (moves stale items from processing back) - t_reclaimer = threading.Thread(target=self.processing_reclaimer_loop, daemon=True) - t_reclaimer.start() - # 5. 
Start Queued Reconciliation (ensures QUEUED tasks are in Redis) - t_reconciler = threading.Thread(target=self.queued_reconciliation_loop, daemon=True) - t_reconciler.start() - - # Keep main thread alive to handle signals - try: - while self.running: - time.sleep(1) - except KeyboardInterrupt: - self.shutdown(None, None) + logger.info(f"Queue Manager {self.instance_id} online.") + t_list = [ + threading.Thread(target=self.maintain_leadership, daemon=True), + threading.Thread(target=self.scheduler_loop, daemon=True), + threading.Thread(target=self.pel_scanner_loop, daemon=True), + threading.Thread(target=self.processing_reclaimer_loop, daemon=True), + threading.Thread(target=self.queued_reconciliation_loop, daemon=True) + ] + for t in t_list: t.start() + while self.running: time.sleep(1) if __name__ == "__main__": - qm = QueueManager() - qm.start() \ No newline at end of file + QueueManager().start() \ No newline at end of file diff --git a/taskflow-cli/requirements.txt b/taskflow-cli/requirements.txt new file mode 100644 index 0000000..e4b8dd2 --- /dev/null +++ b/taskflow-cli/requirements.txt @@ -0,0 +1,16 @@ +certifi==2026.1.4 +charset-normalizer==3.4.4 +click==8.3.1 +idna==3.11 +keyring==25.7.0 +markdown-it-py==4.0.0 +mdurl==0.1.2 +pillow==12.1.0 +Pygments==2.19.2 +requests==2.32.5 +rich==14.2.0 +rich-pixels==3.0.1 +shellingham==1.5.4 +typer==0.21.1 +typing_extensions==4.15.0 +urllib3==2.6.3 diff --git a/taskflow-cli/run_cli.py b/taskflow-cli/run_cli.py new file mode 100755 index 0000000..406da03 --- /dev/null +++ b/taskflow-cli/run_cli.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python3 +""" +TaskFlow CLI Entry Point + +This script allows running the CLI from the project root without installation. +""" + +import sys +from pathlib import Path + +# Add the parent directory to the path so we can import taskflow +sys.path.insert(0, str(Path(__file__).parent)) + +from taskflow.main import run + +if __name__ == "__main__": + run() diff --git a/taskflow-cli/setup.py b/taskflow-cli/setup.py new file mode 100644 index 0000000..08d4eaf --- /dev/null +++ b/taskflow-cli/setup.py @@ -0,0 +1,28 @@ +from setuptools import setup, find_packages + +with open("requirements.txt") as f: + requirements = f.read().splitlines() + +setup( + name="taskflow-cli", + version="2.1.0", + description="TaskFlow CLI - Distributed Task Orchestrator", + author="TaskFlow Team", + packages=find_packages(), + include_package_data=True, + install_requires=requirements, + entry_points={ + "console_scripts": [ + "taskflow=taskflow.main:run", + ], + }, + python_requires=">=3.8", + classifiers=[ + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Operating System :: OS Independent", + ], +) diff --git a/taskflow-cli/taskflow/__init__.py b/taskflow-cli/taskflow/__init__.py new file mode 100644 index 0000000..c109764 --- /dev/null +++ b/taskflow-cli/taskflow/__init__.py @@ -0,0 +1,13 @@ +""" +TaskFlow CLI - Distributed Task Orchestrator + +A powerful command-line interface for managing distributed task orchestration. +""" + +__version__ = "2.1.0" +__author__ = "TaskFlow Team" + +from . 
import cli, auth, api, main + +__all__ = ["cli", "auth", "api", "main"] + diff --git a/taskflow-cli/taskflow/api.py b/taskflow-cli/taskflow/api.py new file mode 100644 index 0000000..299a932 --- /dev/null +++ b/taskflow-cli/taskflow/api.py @@ -0,0 +1,36 @@ +import requests +import os +import json +from pathlib import Path +from rich.console import Console +from .auth import get_token + +console = Console() + +BASE_URL = "http://localhost:8080" + + +def get_headers(): + """Get headers with JWT token if available.""" + headers = {} + token = get_token() + if token: + headers["Authorization"] = f"Bearer {token}" + return headers + + +def api_request(method: str, endpoint: str, **kwargs): + """Make an authenticated API request.""" + headers = kwargs.get("headers", {}) + headers.update(get_headers()) + kwargs["headers"] = headers + + # Build the request URL from the module-level BASE_URL + url = f"{BASE_URL}{endpoint}" + + try: + response = requests.request(method, url, **kwargs) + return response + except requests.exceptions.RequestException as e: + console.print(f"[bold red]Connection Error:[/] {e}") + return None \ No newline at end of file diff --git a/taskflow-cli/taskflow/auth.py b/taskflow-cli/taskflow/auth.py new file mode 100644 index 0000000..3479e2f --- /dev/null +++ b/taskflow-cli/taskflow/auth.py @@ -0,0 +1,37 @@ +import keyring +import requests +import os +import json +from pathlib import Path + +SERVICE_NAME = "taskflow-cli" +TOKEN_KEY = "jwt_token" + +BASE_URL = "http://localhost:8080" + + +def save_token(token: str): + """Saves the JWT token to the system's secure keyring.""" + keyring.set_password(SERVICE_NAME, TOKEN_KEY, token) + + +def get_token(): + """Retrieves the stored token for API requests.""" + return keyring.get_password(SERVICE_NAME, TOKEN_KEY) + + +def delete_token(): + """Removes the token on logout.""" + keyring.delete_password(SERVICE_NAME, TOKEN_KEY) + + +def api_request(method, endpoint, **kwargs): + """Wrapper for requests that automatically injects the Bearer token.""" + token = get_token() + headers = kwargs.get("headers", {}) + if token: + headers["Authorization"] = f"Bearer {token}" + + kwargs["headers"] = headers + url = f"{BASE_URL}{endpoint}" + return requests.request(method, url, **kwargs) \ No newline at end of file diff --git a/taskflow-cli/taskflow/cli.py b/taskflow-cli/taskflow/cli.py new file mode 100644 index 0000000..9735e93 --- /dev/null +++ b/taskflow-cli/taskflow/cli.py @@ -0,0 +1,411 @@ +import typer +from rich.console import Console +from rich.table import Table +from rich.prompt import Prompt, Confirm +from .auth import save_token, delete_token, get_token +from .api import api_request +import os +from pathlib import Path +import json + +# initialise the app and console +app = typer.Typer(help="TaskFlow CLI - Manage your tasks from the terminal") +console = Console() + + +@app.command() +def register( + email: str = typer.Option(..., "--email", "-e", prompt=True, help="Your email address"), + username: str = typer.Option(..., "--username", "-u", prompt=True, help="Your username"), + password: str = typer.Option(..., "--password", "-p", prompt=True, hide_input=True, help="Your password") +): + """Register a new TaskFlow account.""" + console.print("\n[bold cyan]Creating your TaskFlow account...[/]") + + data = { + "email": email, + "username": username, + "password": password + } + + response = api_request("POST", "/users/", json=data) + + if response is None: + console.print("\n[bold red]✗[/] Could not connect to TaskFlow API") + elif response.status_code == 201: + user_data = response.json() + console.print(f"\n[bold green]✓[/] Account created successfully!") + console.print(f"[dim]User ID:[/] {user_data['id']}") + console.print(f"[dim]Username:[/] {user_data['username']}") + console.print(f"[dim]Email:[/] {user_data['email']}") + console.print("\n[yellow]→[/] You can now login with: [bold]login[/]") + else: + try: + error = response.json().get("detail", "Registration failed") + except Exception: + error = "Registration failed" + console.print(f"\n[bold red]✗[/] {error}") + + +@app.command() +def login( + identifier: str = typer.Option(..., "--identifier", "-i", prompt=True, help="Your email or username"), + password: str = typer.Option(..., "--password", "-p", prompt=True, hide_input=True, help="Your password") +): + """Login to your TaskFlow account.""" + console.print("\n[bold cyan]Logging in...[/]") + + data = { + "identifier": identifier, + "password": password + } + + response = api_request("POST", "/login", json=data) + + if response is None: + console.print("\n[bold red]✗[/] Could not connect to TaskFlow API") + elif response.status_code == 200: + token_data = response.json() + save_token(token_data["access_token"]) + console.print("\n[bold green]✓[/] Login successful!") + console.print("[dim]Your session has been saved securely.[/]") + else: + # Handle error responses (403, 401, etc.) + try: + error = response.json().get("detail", "Login failed") + except Exception: + error = "Login failed" + console.print(f"\n[bold red]✗[/] {error}") + + +@app.command() +def logout(): + """Logout from your TaskFlow account.""" + if not get_token(): + console.print("[yellow]You are not logged in.[/]") + return + + if Confirm.ask("\n[bold yellow]Are you sure you want to logout?[/]"): + delete_token() # without a stored token, authenticated commands will refuse to run + console.print("\n[bold green]✓[/] Logged out successfully!") + else: + console.print("[dim]Logout cancelled.[/]") + + +@app.command() +def upload_file( + file_path: str = typer.Argument(..., help="Path to the Python file to upload"), + title: str = typer.Option(..., "--title", "-t", help="Task title/name for this file") +): + """Upload a Python task file to TaskFlow.""" + if not get_token(): + console.print("[bold red]✗[/] You must be logged in to upload files.") + console.print("[dim]Run:[/] taskflow login") + return + + file_path_obj = Path(file_path) + + if not file_path_obj.exists(): + console.print(f"[bold red]✗[/] File not found: {file_path}") + return + + if file_path_obj.suffix != ".py": + console.print("[bold red]✗[/] Only Python (.py) files are allowed") + return + + console.print(f"\n[bold cyan]Uploading task file...[/]") + console.print(f"[dim]File:[/] {file_path_obj.name}") + console.print(f"[dim]Title:[/] {title}") + + with open(file_path_obj, "rb") as f: + files = {"file": (file_path_obj.name, f, "text/x-python")} + params = {"file_name": title} + + response = api_request("POST", "/tasks/upload_file", files=files, params=params) + + if response is None: + console.print("\n[bold red]✗[/] Could not connect to TaskFlow API") + elif response.status_code == 201: + result = response.json() + console.print(f"\n[bold green]✓[/] {result['message']}") + console.print(f"\n[yellow]→[/] You can now create tasks with title: [bold]{title}[/]") + else: + try: + error = response.json().get("detail", "Upload failed") + except Exception: + error = "Upload failed" + console.print(f"\n[bold red]✗[/] Upload failed: {error}") + + +@app.command() 
+def create_task( + title: str = typer.Option(..., "--title", "-t", help="Task title (must match uploaded file)"), + payload: str = typer.Option(..., "--payload", "-p", help="Task payload data"), + scheduled_at: int = typer.Option(0, "--scheduled-at", "-s", help="Schedule task in N minutes from now") +): + """Create a new task.""" + if not get_token(): + console.print("[bold red]✗[/] You must be logged in to create tasks.") + console.print("[dim]Run:[/] taskflow login") + return + + console.print(f"\n[bold cyan]Creating task...[/]") + + data = { + "title": title, + "payload": payload, + "scheduled_at": scheduled_at + } + + response = api_request("POST", "/tasks/", json=data) + + if response is None: + console.print("\n[bold red]✗[/] Could not connect to TaskFlow API") + elif response.status_code == 201: + task = response.json() + console.print(f"\n[bold green]✓[/] Task created successfully!") + console.print(f"[dim]Task ID:[/] {task['id']}") + console.print(f"[dim]Title:[/] {task['title']}") + console.print(f"[dim]Status:[/] {task['status']}") + console.print(f"[dim]Scheduled:[/] {task['scheduled_at']}") + else: + try: + error = response.json().get("detail", "Task creation failed") + except Exception: + error = "Task creation failed" + console.print(f"\n[bold red]✗[/] {error}") + + +@app.command() +def list_tasks( + limit: int = typer.Option(10, "--limit", "-l", help="Number of tasks to retrieve"), + skip: int = typer.Option(0, "--skip", help="Number of tasks to skip"), + search: str = typer.Option("", "--search", "-s", help="Search tasks by title"), + status: str = typer.Option(None, "--status", help="Filter by status: pending, processing, completed, failed") +): + """List all your tasks.""" + if not get_token(): + console.print("[bold red]✗[/] You must be logged in to view tasks.") + console.print("[dim]Run:[/] taskflow login") + return + + console.print(f"\n[bold cyan]Fetching your tasks...[/]") + + params = { + "limit": limit, + "skip": skip, + "search": search + } + if status: + params["status"] = status + + response = api_request("GET", "/tasks/", params=params) + + if response is None: + console.print("\n[bold red]✗[/] Could not connect to TaskFlow API") + elif response.status_code == 200: + tasks = response.json() + + if not tasks: + console.print("\n[yellow]No tasks found.[/]") + return + + table = Table(title=f"\n[bold]Your Tasks[/] ({len(tasks)} found)") + table.add_column("ID", style="cyan", no_wrap=True) + table.add_column("Title", style="magenta") + table.add_column("Status", style="green") + table.add_column("Created At", style="blue") + table.add_column("Scheduled At", style="yellow") + + for task in tasks: + table.add_row( + str(task["id"]), + task["title"], + task["status"], + task["created_at"][:19], + task["scheduled_at"][:19] + ) + + console.print(table) + else: + try: + error = response.json().get("detail", "Failed to fetch tasks") + except Exception: + error = "Failed to fetch tasks" + console.print(f"\n[bold red]✗[/] Failed to fetch tasks: {error}") + + +@app.command() +def get_task(task_id: int = typer.Argument(..., help="Task ID to retrieve")): + """Get details of a specific task.""" + if not get_token(): + console.print("[bold red]✗[/] You must be logged in to view tasks.") + console.print("[dim]Run:[/] taskflow login") + return + + console.print(f"\n[bold cyan]Fetching task {task_id}...[/]") + + response = api_request("GET", f"/tasks/{task_id}") + + if response is None: + console.print("\n[bold red]✗[/] Could not connect to TaskFlow API") + elif response.status_code == 200: + task = response.json() + + console.print(f"\n[bold]Task Details[/]") + console.print(f"[cyan]ID:[/] {task['id']}") + console.print(f"[cyan]Title:[/] {task['title']}") + console.print(f"[cyan]Status:[/] {task['status']}") + console.print(f"[cyan]Owner ID:[/] {task['owner_id']}") + console.print(f"[cyan]Created At:[/] {task['created_at']}") + console.print(f"[cyan]Scheduled At:[/] {task['scheduled_at']}") + elif response.status_code == 404: + console.print(f"\n[bold red]✗[/] Task with ID {task_id} not found") + else: + try: + error = response.json().get("detail", "Failed to fetch task") + except Exception: + error = "Failed to fetch task" + console.print(f"\n[bold red]✗[/] Failed to fetch task: {error}") + + +@app.command() +def delete_task(task_id: int = typer.Argument(..., help="Task ID to delete")): + """Delete a specific task.""" + if not get_token(): + console.print("[bold red]✗[/] You must be logged in to delete tasks.") + console.print("[dim]Run:[/] taskflow login") + return + + if not Confirm.ask(f"\n[bold yellow]Are you sure you want to delete task {task_id}?[/]"): + console.print("[dim]Deletion cancelled.[/]") + return + + console.print(f"\n[bold cyan]Deleting task {task_id}...[/]") + + response = api_request("DELETE", f"/tasks/{task_id}") + + if response is None: + console.print("\n[bold red]✗[/] Could not connect to TaskFlow API") + elif response.status_code == 204: + console.print(f"\n[bold green]✓[/] Task {task_id} deleted successfully!") + elif response.status_code == 404: + console.print(f"\n[bold red]✗[/] Task with ID {task_id} not found") + elif response.status_code == 401: + console.print(f"\n[bold red]✗[/] Not authorized to delete this task") + else: + try: + error = response.json().get("detail", "Failed to delete task") + except Exception: + error = "Failed to delete task" + console.print(f"\n[bold red]✗[/] Failed to delete task: {error}") + + +@app.command() +def delete_file( + title: str = typer.Option(..., "--title", "-t", help="Task file title to delete") +): + """Delete a task file from the server.""" + if not get_token(): + console.print("[bold red]✗[/] You must be logged in to delete files.") + console.print("[dim]Run:[/] taskflow login") + return + + if not Confirm.ask(f"\n[bold yellow]Are you sure you want to delete task file '{title}.py'?[/]"): + console.print("[dim]Deletion cancelled.[/]") + return + + console.print(f"\n[bold cyan]Deleting task file...[/]") + + params = {"file_name": title} + response = api_request("DELETE", "/tasks/delete_file", params=params) + + if response is None: + console.print("\n[bold red]✗[/] Could not connect to TaskFlow API") + elif response.status_code == 200: + result = response.json() + console.print(f"\n[bold green]✓[/] {result['message']}") + elif response.status_code == 404: + console.print(f"\n[bold red]✗[/] Task file '{title}.py' not found") + else: + try: + error = response.json().get("detail", "Failed to delete file") + except Exception: + error = "Failed to delete file" + console.print(f"\n[bold red]✗[/] Failed to delete file: {error}") + + +@app.command() +def whoami(): + """Display current login status.""" + token = get_token() + + if token: + console.print("\n[bold green]✓[/] You are logged in") + console.print(f"[dim]Token stored securely in system keyring[/]") + + # Try to get user info + response = api_request("GET", "/tasks/", params={"limit": 1}) + if response is None: + console.print("\n[bold red]✗[/] Could not connect to TaskFlow API") + elif response.status_code == 200: + console.print("[dim]Connection to API: Active[/]") + else: + console.print("[yellow]Warning: Token may be expired or invalid[/]") + else: + console.print("\n[yellow]You are not logged in[/]") + console.print("[dim]Run:[/] taskflow login") + + +@app.command() +def config( + show: bool = typer.Option(False, "--show", help="Show current configuration"), + api_url: str = typer.Option(None, "--api-url", help="Set the API URL") +): + """Configure CLI settings.""" + config_file = Path.home() / ".taskflow" / "config.json" + config_file.parent.mkdir(exist_ok=True) + + if show: + if config_file.exists(): + with open(config_file, "r") as f: + cfg = json.load(f) + console.print("\n[bold]Current Configuration:[/]") + for key, value in cfg.items(): + console.print(f" [cyan]{key}:[/] {value}") + else: + console.print("\n[yellow]No configuration file found[/]") + console.print(f"[dim]Default API URL:[/] {os.getenv('TASKFLOW_API_URL', 'http://localhost:8080')}") + return + + if api_url: + cfg = {} + if config_file.exists(): + with open(config_file, "r") as f: + cfg = json.load(f) + + cfg["api_url"] = api_url + os.environ["TASKFLOW_API_URL"] = api_url + + with open(config_file, "w") as f: + json.dump(cfg, f, indent=2) + + console.print(f"\n[bold green]✓[/] API URL set to: {api_url}") + diff --git a/taskflow-cli/taskflow/main.py b/taskflow-cli/taskflow/main.py new file mode 100644 index 0000000..c22e43f --- /dev/null +++ b/taskflow-cli/taskflow/main.py @@ -0,0 +1,216 @@ +import typer +import time +import sys +import signal +from rich.console import Console +from .cli import app as cli_app + +console = Console() +app = typer.Typer() + +# Track CTRL+C presses +ctrl_c_count = 0 +last_ctrl_c_time = 0 + +# Register all CLI commands +app.add_typer(cli_app, name="", help="TaskFlow CLI commands") + + +def signal_handler(sig, frame): + """Handle CTRL+C gracefully with double-press confirmation.""" + global ctrl_c_count, last_ctrl_c_time + + current_time = time.time() + + # Reset counter if more than 2 seconds passed since last CTRL+C + if current_time - last_ctrl_c_time > 2: + ctrl_c_count = 0 + + ctrl_c_count += 1 + last_ctrl_c_time = current_time + + if ctrl_c_count == 1: + console.print("\n[yellow]Press CTRL+C again within 2 seconds to exit[/]") + else: + console.print("\n[bold red]Exiting TaskFlow CLI...[/]") + sys.exit(0) + + +# Register signal handler +signal.signal(signal.SIGINT, signal_handler) + + +def typewriter_print(text: str, delay: float = 0.04): + """Simulates a typewriter transition for terminal text.""" + for char in text: + sys.stdout.write(char) + sys.stdout.flush() + time.sleep(delay) + print() + + +def display_splash(): + # ASCII Art for TASKFLOW + ascii_logo = """ + ████████╗ █████╗ ███████╗██╗ ██╗███████╗██╗ ██████╗ ██╗ ██╗ + ╚══██╔══╝██╔══██╗██╔════╝██║ ██╔╝██╔════╝██║ ██╔═══██╗██║ ██║ + ██║ ███████║███████╗█████═╝ █████╗ ██║ ██║ ██║██║ █╗ ██║ + ██║ ██╔══██║╚════██║██╔═██╗ ██╔══╝ ██║ ██║ ██║██║███╗██║ + ██║ ██║ ██║███████║██║ ██╗██║ ███████╗╚██████╔╝╚███╔███╔╝ + ╚═╝ ╚═╝ ╚═╝╚══════╝╚═╝ ╚═╝╚═╝ ╚══════╝ ╚═════╝ ╚══╝╚══╝ + """ + + console.print(f"[orange_red1]{ascii_logo}[/orange_red1]") + + # Version and Subtitle + console.print(" [bold white]v2.1.0[/bold white] | [dim]Distributed Task Orchestrator[/dim]\n") + + # Personalized Typewriter Greeting + typewriter_print("> Connection established. 
Ready to process tasks", delay=0.003) + + +@app.callback(invoke_without_command=True) +def main(ctx: typer.Context): + """TaskFlow CLI - Distributed Task Orchestrator""" + if ctx.invoked_subcommand is None: + display_splash() + console.print("\n[bold cyan]Quick Start:[/]") + console.print(" [bold]register[/] - Create a new account") + console.print(" [bold]login[/] - Login to your account") + console.print(" [bold]upload-file[/] - Upload a task Python file") + console.print(" [bold]create-task[/] - Create a new task") + console.print(" [bold]list-tasks[/] - View all your tasks") + console.print("\n[dim]Type a command or [bold]help[/bold] to see all commands.") + console.print("[dim italic]Press CTRL+C twice to exit[/dim italic]\n") + + # Enter interactive mode + interactive_mode() + + +def interactive_mode(): + """Run CLI in interactive loop mode.""" + import shlex + from rich.prompt import Confirm + from rich.table import Table + + while True: + try: + # Prompt for command + user_input = console.input("[bold cyan]taskflow>[/] ") + + if not user_input.strip(): + continue + + # Handle built-in commands + if user_input.strip().lower() in ['exit', 'quit']: + if Confirm.ask("\n[bold yellow]Are you sure you want to exit?[/]"): + console.print("[bold green]Goodbye![/]") + sys.exit(0) + continue + + if user_input.strip().lower() == 'clear': + console.clear() + continue + + if user_input.strip().lower() in ['help', '--help', '-h']: + display_help() + continue + + # Parse the command + args = shlex.split(user_input) + + # Execute the command by invoking the main app with args + try: + # Save original argv + original_argv = sys.argv.copy() + + # Set new argv with the command + sys.argv = ['taskflow'] + args + + # Invoke the app + try: + app(args, standalone_mode=False) + except SystemExit: + pass + + # Restore original argv + sys.argv = original_argv + + except Exception as e: + # Restore argv on error + sys.argv = original_argv + console.print(f"[bold red]Error:[/] {str(e)}") + console.print("[dim]Type 'help' to see available commands[/]") + + except KeyboardInterrupt: + # Let the global handler deal with it + signal_handler(signal.SIGINT, None) + except EOFError: + console.print("\n[bold red]Exiting TaskFlow CLI...[/]") + sys.exit(0) + + +def display_help(): + """Display help menu with all available commands.""" + from rich.table import Table + + console.print("\n[bold cyan]TaskFlow CLI - Available Commands[/]\n") + + # Single table with all commands + table = Table(show_header=True, header_style="bold magenta") + table.add_column("Command", style="cyan", no_wrap=True, width=45) + table.add_column("Description", style="white") + + # Authentication commands + table.add_row("[bold yellow]Authentication[/]", "") + table.add_row("register", "Create a new TaskFlow account") + table.add_row("login", "Login to your account") + table.add_row("logout", "Logout from your account") + + # Task management commands + table.add_row("", "") + table.add_row("[bold yellow]Task Management[/]", "") + table.add_row("create-task --title <title> --payload <data>", "Create a new task") + table.add_row(" --scheduled-at <minutes>", " Schedule task in N minutes (default: 0)") + table.add_row("list-tasks", "List all your tasks") + table.add_row(" --limit <n>", " Number of tasks to show (default: 10)") + table.add_row(" --skip <n>", " Skip first N tasks (default: 0)") + table.add_row(" --search <text>", " Search tasks by title") + table.add_row(" --status <status>", " Filter by status (pending/processing/completed/failed)") + table.add_row("get-task <id>", "Get details of a specific task") + table.add_row("delete-task <id>", "Delete a task") + + # File management commands + table.add_row("", "") + table.add_row("[bold yellow]File Management[/]", "") + table.add_row("upload-file <file> --title <name>", "Upload a Python task file") + table.add_row(" <file>", " Path to the .py file to upload") + table.add_row(" --title <name> (required)", " Task name to use when creating tasks") + table.add_row("delete-file --title <name>", "Delete an uploaded task file") + + # Built-in commands + table.add_row("", "") + table.add_row("[bold yellow]Built-in Commands[/]", "") + table.add_row("help", "Show this help message") + table.add_row("clear", "Clear the screen") + table.add_row("exit / quit", "Exit the CLI") + + console.print(table) + console.print() + + console.print("[dim]For detailed help on a command, use: [bold]<command> --help[/bold][/dim]") + console.print("[dim]Example: [bold]register --help[/bold] or [bold]upload-file --help[/bold][/dim]\n") + + +def run(): + """Entry point for the CLI.""" + try: + app() + except KeyboardInterrupt: + # This catches any unhandled CTRL+C + console.print("\n[bold red]Exiting TaskFlow CLI...[/]") + sys.exit(0) + + +if __name__ == "__main__": + run() \ No newline at end of file diff --git a/worker/loader.py b/worker/loader.py index 9cf33fd..854bf8a 100644 --- a/worker/loader.py +++ b/worker/loader.py @@ -1,16 +1,26 @@ import importlib.util import os import logging +import sys logger = logging.getLogger(__name__) +# Use absolute path - worker container has files at /app/worker/tasks/ +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) # /app/worker/ +TASKS_DIR = os.path.join(BASE_DIR, "tasks") # /app/worker/tasks/ + def load_handler(task_title: str): - file_path = f"worker/tasks/{task_title}.py" + file_path = os.path.join(TASKS_DIR, f"{task_title}.py") if not os.path.exists(file_path): - return None, "File not found" + logger.error(f"Worker could not find task file at: {file_path}") + return None, f"File not found at {file_path}" try: + # Prevent "Zombie Modules" by clearing cache if it was loaded before + if task_title in sys.modules: + del sys.modules[task_title] + spec = importlib.util.spec_from_file_location(task_title, file_path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) @@ -20,7 +30,6 @@ def load_handler(task_title: str): return None, "No handler function found" except ImportError as e: - # This catches "No module named 'pandas'" etc. 
logger.error(f"Missing dependency in {task_title}: {e}") return None, f"Dependency Error: {str(e)}" except Exception as e: diff --git a/worker/main.py b/worker/main.py index 2ff4534..1a66775 100644 --- a/worker/main.py +++ b/worker/main.py @@ -4,19 +4,18 @@ import signal import asyncio import os -import sys # Core imports from core.redis_client import get_async_redis_client from .heartbeat import HeartbeatService from .task_handler import execute_dynamic_task -# Import the new database helper +# Import the updated database helper that supports worker_id from .utils import update_task_status # Ensure logs directory exists os.makedirs("logs", exist_ok=True) -# Logging matches your Victus environment +# Logging configuration matches your Victus environment logging.basicConfig( level=logging.INFO, format='%(asctime)s - [Worker] - %(levelname)s - %(message)s', @@ -32,6 +31,7 @@ class AsyncWorker: def __init__(self): + # Generate a unique short ID for this worker instance self.worker_id = str(uuid.uuid4())[:8] self.running = True self.redis_high = None @@ -39,28 +39,24 @@ def __init__(self): self.heartbeat = HeartbeatService(self.worker_id) async def start(self): - logger.info(f"Async worker:{self.worker_id} starting up on modular-worker branch...") + logger.info(f"Async worker:{self.worker_id} starting up on TaskFlow cluster...") self.redis_high = await get_async_redis_client("high") self.redis_low = await get_async_redis_client("low") + + # Start the heartbeat so the Leader knows this worker is alive await self.heartbeat.start() - logger.info(f"Worker:{self.worker_id} listening for dynamic tasks on Redis.") + logger.info(f"Worker:{self.worker_id} listening for tasks on Redis.") while self.running: try: raw_data = None - # Priority-based Redis polling - if hasattr(self.redis_high, 'blmove'): - raw_data = await self.redis_high.blmove(QUEUE_NAME, PROCESSING_QUEUE, 'RIGHT', 'LEFT', 1) - else: - raw_data = await self.redis_high.brpoplpush(QUEUE_NAME, PROCESSING_QUEUE, 1) + # Atomically move task from main queue to processing queue + raw_data = await self.redis_high.brpoplpush(QUEUE_NAME, PROCESSING_QUEUE, 1) if not raw_data: - if hasattr(self.redis_low, 'blmove'): - raw_data = await self.redis_low.blmove(QUEUE_NAME, PROCESSING_QUEUE, 'RIGHT', 'LEFT', 1) - else: - raw_data = await self.redis_low.brpoplpush(QUEUE_NAME, PROCESSING_QUEUE, 1) + raw_data = await self.redis_low.brpoplpush(QUEUE_NAME, PROCESSING_QUEUE, 1) if raw_data: try: @@ -72,29 +68,28 @@ async def start(self): task_id = data.get('task_id') task_title = data.get('title') - payload = data.get('payload') # The JSONB salted dictionary + payload = data.get('payload') - logger.info(f"Worker:{self.worker_id} processing Dynamic Task: {task_id}") + logger.info(f"Worker:{self.worker_id} claiming Task: {task_id}") try: - # 1. UPDATE DB: Mark as starting - await update_task_status(task_id, "IN_PROGRESS") + # --- THE CRITICAL FIX --- + # Pass self.worker_id so the Leader's PEL scanner sees this task is claimed + await update_task_status(task_id, "IN_PROGRESS", self.worker_id) - # 2. EXECUTE: Load and run the uploaded script + # Execute the dynamically loaded script result = await execute_dynamic_task(task_title, payload) logger.info(f"Task {task_id} COMPLETED successfully.") - - # 3. UPDATE DB: Mark as success with result - await update_task_status(task_id, "COMPLETED", result=json.dumps(result)) + await update_task_status(task_id, "COMPLETED") except Exception as e: logger.error(f"Execution failed for Task {task_id}: {str(e)}") - - # 4. 
UPDATE DB: Mark as failed with error message - await update_task_status(task_id, "FAILED", result=str(e)) + # Mark as failed in DB + await update_task_status(task_id, "FAILED") finally: + # Task is finished (success or fail), remove from processing queue await self._remove_from_processing(raw_data) except Exception as e: @@ -103,12 +98,13 @@ async def start(self): await asyncio.sleep(2) # Shutdown Logic - logger.info(f"Worker:{self.worker_id} shutting down...") + logger.info(f"Worker:{self.worker_id} gracefully shutting down...") await self.heartbeat.stop() if self.redis_low: await self.redis_low.aclose() if self.redis_high: await self.redis_high.aclose() async def _remove_from_processing(self, raw_data): + """Clean up the processing queue in both Redis instances""" try: await self.redis_low.lrem(PROCESSING_QUEUE, 0, raw_data) await self.redis_high.lrem(PROCESSING_QUEUE, 0, raw_data) @@ -116,6 +112,7 @@ async def _remove_from_processing(self, raw_data): logger.exception("Failed to remove item from processing queue") async def _cleanup_malformed(self, raw_data): + """Remove messages that cannot be parsed as JSON""" try: await self.redis_low.lrem(PROCESSING_QUEUE, 0, raw_data) await self.redis_high.lrem(PROCESSING_QUEUE, 0, raw_data) @@ -128,6 +125,7 @@ def request_shutdown(self): async def main(): worker = AsyncWorker() loop = asyncio.get_running_loop() + # Handle OS signals for clean shutdown in Kubernetes for sig in (signal.SIGINT, signal.SIGTERM): loop.add_signal_handler(sig, worker.request_shutdown) await worker.start() diff --git a/worker/task_handler.py b/worker/task_handler.py index a46b26b..b1ee745 100644 --- a/worker/task_handler.py +++ b/worker/task_handler.py @@ -6,31 +6,17 @@ from typing import Callable, Optional, Tuple, Any logger = logging.getLogger(__name__) -TASKS_DIR = "worker/tasks" +TASKS_DIR = "/app/worker/tasks" -def cleanup_task_file(task_title: str) -> bool: - """ - Delete the task file after execution. - Returns True if deleted successfully, False otherwise. 
- """ - file_path = os.path.join(TASKS_DIR, f"{task_title}.py") - - try: - if os.path.exists(file_path): - os.remove(file_path) - logger.info(f"Cleaned up task file: {task_title}.py") - return True - else: - logger.warning(f"Task file not found for cleanup: {task_title}.py") - return False - except Exception as e: - logger.error(f"Failed to cleanup task file {task_title}.py: {e}") - return False def load_task_handler(task_title: str) -> Tuple[Optional[Callable], Optional[str]]: file_path = os.path.join(TASKS_DIR, f"{task_title}.py") + + logger.info(f"[DEBUG] Loading task: title='{task_title}', path='{file_path}'") + logger.info(f"[DEBUG] TASKS_DIR='{TASKS_DIR}', exists={os.path.exists(file_path)}") if not os.path.exists(file_path): + logger.error(f"File not found: {file_path}") return None, "File not found" # --- FIX FOR ZOMBIE MODULES --- diff --git a/worker/utils.py b/worker/utils.py index 761ffd0..79be858 100644 --- a/worker/utils.py +++ b/worker/utils.py @@ -1,12 +1,23 @@ -from sqlalchemy import create_all, update -from core.database import SessionLocal # Adjust based on your path +from sqlalchemy import update +from core.database import SessionLocal from core.models import Tasks -async def update_task_status(task_id: int, status: str, result: str = None): - async with SessionLocal() as session: - query = update(Tasks).where(Tasks.id == task_id).values( - status=status, - result=result - ) - await session.execute(query) - await session.commit() \ No newline at end of file + +def update_task_status_sync(task_id: int, status: str, worker_id: str = None): + session = SessionLocal() + try: + # Update both status AND worker_id + update_values = {"status": status} + if worker_id: + update_values["worker_id"] = worker_id + + query = update(Tasks).where(Tasks.id == task_id).values(**update_values) + session.execute(query) + session.commit() + finally: + session.close() + +async def update_task_status(task_id: int, status: str, worker_id: str = None): + import asyncio + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, update_task_status_sync, task_id, status, worker_id) \ No newline at end of file