diff --git a/neurons/generator/services/fal_service.py b/neurons/generator/services/fal_service.py new file mode 100644 index 00000000..a051a031 --- /dev/null +++ b/neurons/generator/services/fal_service.py @@ -0,0 +1,191 @@ +import os +import time +import requests +import bittensor as bt +from typing import Dict, Any, Optional + +from .base_service import BaseGenerationService +from ..task_manager import GenerationTask + + +class FalAIService(BaseGenerationService): + """ + Fal.ai API service for media generation. + + Supports: + - Images via FLUX.1 [dev] (fal-ai/flux/dev) + - Video via Kling (fal-ai/kling-video/v1/standard/text-to-video) + """ + + def __init__(self, config: Any = None): + super().__init__(config) + self.api_key = os.getenv("FAL_KEY") + self.base_url = "https://queue.fal.run" + + # Default models + self.image_model = "fal-ai/flux/dev" + self.video_model = "fal-ai/kling-video/v1/standard/text-to-video" + + if self.api_key: + bt.logging.info("FalAIService initialized with API key") + else: + bt.logging.warning("FAL_KEY not found. Fal.ai service will not be available.") + + def is_available(self) -> bool: + return self.api_key is not None and self.api_key.strip() != "" + + def supports_modality(self, modality: str) -> bool: + return modality in {"image", "video"} + + def get_supported_tasks(self) -> Dict[str, list]: + return { + "image": ["image_generation"], + "video": ["video_generation"], + } + + def get_api_key_requirements(self) -> Dict[str, str]: + return {"FAL_KEY": "Fal.ai API key for image and video generation"} + + def process(self, task: GenerationTask) -> Dict[str, Any]: + if task.modality == "image": + return self._generate_image(task) + elif task.modality == "video": + return self._generate_video(task) + else: + raise ValueError(f"Unsupported modality: {task.modality}") + + def _generate_image(self, task: GenerationTask) -> Dict[str, Any]: + """Generate an image using Fal.ai.""" + params = task.parameters or {} + model = params.get("model", self.image_model) + + # Map common parameters to Fal.ai specific ones if needed + # FLUX.1 [dev] supports: prompt, image_size, num_inference_steps, seed, guidance_scale, etc. + payload = { + "prompt": task.prompt, + "image_size": params.get("size", "landscape_4_3"), # default to landscape + "num_inference_steps": params.get("steps", 28), + "seed": params.get("seed"), + "guidance_scale": params.get("guidance_scale", 3.5), + "num_images": 1, + "enable_safety_checker": params.get("safety_checker", True), + "sync_mode": False # Use queue for reliability + } + + # Remove None values + payload = {k: v for k, v in payload.items() if v is not None} + + return self._run_fal_request(model, payload, "image") + + def _generate_video(self, task: GenerationTask) -> Dict[str, Any]: + """Generate a video using Fal.ai.""" + params = task.parameters or {} + model = params.get("model", self.video_model) + + # Kling supports: prompt, duration, aspect_ratio + payload = { + "prompt": task.prompt, + "duration": str(params.get("duration", "5")), # "5" or "10" + "aspect_ratio": params.get("aspect_ratio", "16:9"), + } + + return self._run_fal_request(model, payload, "video") + + def _run_fal_request(self, model: str, payload: Dict[str, Any], modality: str) -> Dict[str, Any]: + """Execute the request against Fal.ai queue and poll for result.""" + headers = { + "Authorization": f"Key {self.api_key}", + "Content-Type": "application/json", + } + + url = f"{self.base_url}/{model}" + + bt.logging.info(f"Fal.ai submitting job to {model}...") + + try: + # Submit job + response = requests.post(url, json=payload, headers=headers) + response.raise_for_status() + data = response.json() + + request_id = data.get("request_id") + if not request_id: + # Some endpoints might return result immediately if sync_mode=True, + # but we forced sync_mode=False or are using queue. + # If we get immediate result (unlikely with queue URL), handle it. + if "images" in data or "video" in data: + return self._process_completed_response(data, model, modality) + raise RuntimeError(f"No request_id returned from Fal.ai: {data}") + + bt.logging.info(f"Fal.ai job submitted. Request ID: {request_id}") + + # Poll for status + start_time = time.time() + timeout = 600 # 10 minutes max + poll_interval = 2.0 + + while time.time() - start_time < timeout: + status_url = f"{self.base_url}/requests/{request_id}/status" + status_res = requests.get(status_url, headers=headers) + status_res.raise_for_status() + status_data = status_res.json() + + status = status_data.get("status") + + if status == "COMPLETED": + # Fetch final result + result_url = f"{self.base_url}/requests/{request_id}" + result_res = requests.get(result_url, headers=headers) + result_res.raise_for_status() + result_data = result_res.json() + + return self._process_completed_response(result_data, model, modality) + + elif status == "FAILED": + error = status_data.get("error", "Unknown error") + raise RuntimeError(f"Fal.ai job failed: {error}") + + elif status in ["IN_QUEUE", "IN_PROGRESS"]: + time.sleep(poll_interval) + poll_interval = min(poll_interval * 1.5, 10) # Exponential backoff cap at 10s + + else: + bt.logging.warning(f"Unknown Fal.ai status: {status}") + time.sleep(poll_interval) + + raise TimeoutError(f"Fal.ai job timed out after {timeout}s") + + except Exception as e: + bt.logging.error(f"Fal.ai request failed: {e}") + raise + + def _process_completed_response(self, data: Dict[str, Any], model: str, modality: str) -> Dict[str, Any]: + """Download and format the result.""" + + media_url = None + if modality == "image": + # FLUX returns 'images': [{'url': ..., ...}] + if "images" in data and len(data["images"]) > 0: + media_url = data["images"][0]["url"] + elif modality == "video": + # Kling returns 'video': {'url': ...} + if "video" in data and "url" in data["video"]: + media_url = data["video"]["url"] + + if not media_url: + raise RuntimeError(f"No media URL found in Fal.ai response: {data}") + + bt.logging.info(f"Downloading media from {media_url}...") + media_res = requests.get(media_url) + media_res.raise_for_status() + media_bytes = media_res.content + + return { + "data": media_bytes, + "metadata": { + "model": model, + "provider": "fal.ai", + "source_url": media_url, + "size": len(media_bytes) + } + } diff --git a/neurons/generator/services/service_registry.py b/neurons/generator/services/service_registry.py index 62c43b06..dcb7841d 100644 --- a/neurons/generator/services/service_registry.py +++ b/neurons/generator/services/service_registry.py @@ -6,14 +6,18 @@ from .openai_service import OpenAIService from .openrouter_service import OpenRouterService from .stabilityai_service import StabilityAIService +from .stabilityai_service import StabilityAIService from .local_service import LocalService +from .fal_service import FalAIService SERVICE_MAP = { "openai": OpenAIService, "openrouter": OpenRouterService, "local": LocalService, - "stabilityai": StabilityAIService + "local": LocalService, + "stabilityai": StabilityAIService, + "fal": FalAIService } @@ -144,8 +148,8 @@ def get_available_services(self) -> List[Dict[str, Any]]: def get_all_api_key_requirements(self) -> Dict[str, str]: """Get API key requirements from all services.""" all_requirements = { - "IMAGE_SERVICE": "Service for images: openai, openrouter, local, or none", - "VIDEO_SERVICE": "Service for videos: openai, openrouter, local, or none", + "IMAGE_SERVICE": "Service for images: openai, openrouter, local, fal, or none", + "VIDEO_SERVICE": "Service for videos: openai, openrouter, local, fal, or none", } for name, service_class in SERVICE_MAP.items(): diff --git a/tests/generator/fal_service.py b/tests/generator/fal_service.py new file mode 100644 index 00000000..a448c5d7 --- /dev/null +++ b/tests/generator/fal_service.py @@ -0,0 +1,150 @@ +import os +import traceback +from PIL import Image +import io +import time +import requests +from unittest.mock import MagicMock, patch + +import sys +from unittest.mock import MagicMock + +# Mock bittensor before importing services +mock_bt = MagicMock() +sys.modules["bittensor"] = mock_bt + +from neurons.generator.services.fal_service import FalAIService +from neurons.generator.task_manager import TaskManager +from gas.verification.c2pa_verification import verify_c2pa + +# Set API key if one isn't already set +os.environ.setdefault( + "FAL_KEY", + "key-YOUR-API-KEY" # replace with your test key +) + +def save_image(img_bytes, filename): + os.makedirs("outputs", exist_ok=True) + out_path = f"outputs/{filename}" + with open(out_path, "wb") as f: + f.write(img_bytes) + return out_path + +def validate_image(img_bytes): + try: + Image.open(io.BytesIO(img_bytes)).verify() + return True + except Exception: + return False + + +def run_model_test(service, manager, model, modality="image"): + print(f"\n=== Running generation test for model: {model} ({modality}) ===") + + task_id = manager.create_task( + modality=modality, + prompt="A futuristic city with flying cars", + parameters={ + "model": model, + "seed": 777, + "duration": "5" if modality == "video" else None + }, + webhook_url=None, + signed_by="test-suite" + ) + + task = manager.get_task(task_id) + + try: + start_time = time.time() + result = service.process(task) + elapsed = time.time() - start_time + + data_bytes = result["data"] + meta = result["metadata"] + + print(f"✔ Generated media in {elapsed:.2f}s ({len(data_bytes)/1024:.1f} KB)") + print(f"✔ Metadata keys: {list(meta.keys())}") + + # Save output + ext = "mp4" if modality == "video" else "png" + out = save_image(data_bytes, f"{model.replace('/', '_')}.{ext}") + print(f"✔ Saved to {out}") + + if modality == "image": + # Validate image integrity + assert validate_image(data_bytes), "Image failed Pillow validation" + + # Check C2PA + try: + c2pa_result = verify_c2pa(data_bytes) + if c2pa_result.verified: + print(f"✅ C2PA Verified! Issuer: {c2pa_result.issuer}") + else: + print(f"⚠️ C2PA Not Verified: {c2pa_result.error}") + except Exception as e: + print(f"⚠️ C2PA Check Failed: {e}") + + # Validate metadata + assert meta["model"] == model + assert meta["provider"] == "fal.ai" + assert "source_url" in meta + + print(f"=== Model {model} PASSED ===") + + except Exception: + print(f"=== Model {model} FAILED ===") + print(traceback.format_exc()) + + +def test_invalid_api_key(): + print("\n=== Testing invalid API key ===") + original_key = os.environ.get("FAL_KEY") + os.environ["FAL_KEY"] = "invalid-key" + + service = FalAIService() + manager = TaskManager() + + task_id = manager.create_task( + modality="image", + prompt="test prompt", + parameters={"model": "fal-ai/flux/dev"}, + webhook_url=None, + signed_by="test" + ) + task = manager.get_task(task_id) + + try: + service.process(task) + print("❌ Should have failed with invalid API key!") + except Exception as e: + print(f"✔ Correctly failed: {e}") + finally: + if original_key: + os.environ["FAL_KEY"] = original_key + + +def run_full_test_suite(): + print("\n========== Fal.ai Full Test Suite ==========\n") + + service = FalAIService() + manager = TaskManager() + + if not service.is_available() or service.api_key == "key-YOUR-API-KEY": + print("❌ API key missing or default — skipping live tests") + print("ℹ️ Set FAL_KEY environment variable to run live tests") + else: + # Test Image Model + run_model_test(service, manager, "fal-ai/flux/dev", "image") + + # Test Video Model + run_model_test(service, manager, "fal-ai/kling-video/v1/standard/text-to-video", "video") + + # Negative tests + test_invalid_api_key() + + print("\n========== All Tests Completed ==========\n") + + +if __name__ == "__main__": + run_full_test_suite()