-
Notifications
You must be signed in to change notification settings - Fork 28
Add Fal.ai generation service wrapper #312
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,191 @@ | ||
| import os | ||
| import time | ||
| import requests | ||
| import bittensor as bt | ||
| from typing import Dict, Any, Optional | ||
|
|
||
| from .base_service import BaseGenerationService | ||
| from ..task_manager import GenerationTask | ||
|
|
||
|
|
||
| class FalAIService(BaseGenerationService): | ||
| """ | ||
| Fal.ai API service for media generation. | ||
| Supports: | ||
| - Images via FLUX.1 [dev] (fal-ai/flux/dev) | ||
| - Video via Kling (fal-ai/kling-video/v1/standard/text-to-video) | ||
| """ | ||
|
|
||
| def __init__(self, config: Any = None): | ||
| super().__init__(config) | ||
| self.api_key = os.getenv("FAL_KEY") | ||
| self.base_url = "https://queue.fal.run" | ||
|
|
||
| # Default models | ||
| self.image_model = "fal-ai/flux/dev" | ||
| self.video_model = "fal-ai/kling-video/v1/standard/text-to-video" | ||
|
|
||
| if self.api_key: | ||
| bt.logging.info("FalAIService initialized with API key") | ||
| else: | ||
| bt.logging.warning("FAL_KEY not found. Fal.ai service will not be available.") | ||
|
|
||
| def is_available(self) -> bool: | ||
| return self.api_key is not None and self.api_key.strip() != "" | ||
|
|
||
| def supports_modality(self, modality: str) -> bool: | ||
| return modality in {"image", "video"} | ||
|
|
||
| def get_supported_tasks(self) -> Dict[str, list]: | ||
| return { | ||
| "image": ["image_generation"], | ||
| "video": ["video_generation"], | ||
| } | ||
|
|
||
| def get_api_key_requirements(self) -> Dict[str, str]: | ||
| return {"FAL_KEY": "Fal.ai API key for image and video generation"} | ||
|
|
||
| def process(self, task: GenerationTask) -> Dict[str, Any]: | ||
| if task.modality == "image": | ||
| return self._generate_image(task) | ||
| elif task.modality == "video": | ||
| return self._generate_video(task) | ||
| else: | ||
| raise ValueError(f"Unsupported modality: {task.modality}") | ||
|
|
||
| def _generate_image(self, task: GenerationTask) -> Dict[str, Any]: | ||
| """Generate an image using Fal.ai.""" | ||
| params = task.parameters or {} | ||
| model = params.get("model", self.image_model) | ||
|
|
||
| # Map common parameters to Fal.ai specific ones if needed | ||
| # FLUX.1 [dev] supports: prompt, image_size, num_inference_steps, seed, guidance_scale, etc. | ||
| payload = { | ||
| "prompt": task.prompt, | ||
| "image_size": params.get("size", "landscape_4_3"), # default to landscape | ||
| "num_inference_steps": params.get("steps", 28), | ||
| "seed": params.get("seed"), | ||
| "guidance_scale": params.get("guidance_scale", 3.5), | ||
| "num_images": 1, | ||
| "enable_safety_checker": params.get("safety_checker", True), | ||
| "sync_mode": False # Use queue for reliability | ||
| } | ||
|
|
||
| # Remove None values | ||
| payload = {k: v for k, v in payload.items() if v is not None} | ||
|
|
||
| return self._run_fal_request(model, payload, "image") | ||
|
|
||
| def _generate_video(self, task: GenerationTask) -> Dict[str, Any]: | ||
| """Generate a video using Fal.ai.""" | ||
| params = task.parameters or {} | ||
| model = params.get("model", self.video_model) | ||
|
|
||
| # Kling supports: prompt, duration, aspect_ratio | ||
| payload = { | ||
| "prompt": task.prompt, | ||
| "duration": str(params.get("duration", "5")), # "5" or "10" | ||
| "aspect_ratio": params.get("aspect_ratio", "16:9"), | ||
| } | ||
|
|
||
| return self._run_fal_request(model, payload, "video") | ||
|
|
||
| def _run_fal_request(self, model: str, payload: Dict[str, Any], modality: str) -> Dict[str, Any]: | ||
| """Execute the request against Fal.ai queue and poll for result.""" | ||
| headers = { | ||
| "Authorization": f"Key {self.api_key}", | ||
| "Content-Type": "application/json", | ||
| } | ||
|
|
||
| url = f"{self.base_url}/{model}" | ||
|
|
||
| bt.logging.info(f"Fal.ai submitting job to {model}...") | ||
|
|
||
| try: | ||
| # Submit job | ||
| response = requests.post(url, json=payload, headers=headers) | ||
| response.raise_for_status() | ||
| data = response.json() | ||
|
|
||
| request_id = data.get("request_id") | ||
| if not request_id: | ||
| # Some endpoints might return result immediately if sync_mode=True, | ||
| # but we forced sync_mode=False or are using queue. | ||
| # If we get immediate result (unlikely with queue URL), handle it. | ||
| if "images" in data or "video" in data: | ||
| return self._process_completed_response(data, model, modality) | ||
| raise RuntimeError(f"No request_id returned from Fal.ai: {data}") | ||
|
|
||
| bt.logging.info(f"Fal.ai job submitted. Request ID: {request_id}") | ||
|
|
||
| # Poll for status | ||
| start_time = time.time() | ||
| timeout = 600 # 10 minutes max | ||
| poll_interval = 2.0 | ||
|
|
||
| while time.time() - start_time < timeout: | ||
| status_url = f"{self.base_url}/requests/{request_id}/status" | ||
| status_res = requests.get(status_url, headers=headers) | ||
| status_res.raise_for_status() | ||
| status_data = status_res.json() | ||
|
|
||
| status = status_data.get("status") | ||
|
|
||
| if status == "COMPLETED": | ||
| # Fetch final result | ||
| result_url = f"{self.base_url}/requests/{request_id}" | ||
| result_res = requests.get(result_url, headers=headers) | ||
| result_res.raise_for_status() | ||
| result_data = result_res.json() | ||
|
|
||
| return self._process_completed_response(result_data, model, modality) | ||
|
|
||
| elif status == "FAILED": | ||
| error = status_data.get("error", "Unknown error") | ||
| raise RuntimeError(f"Fal.ai job failed: {error}") | ||
|
|
||
| elif status in ["IN_QUEUE", "IN_PROGRESS"]: | ||
| time.sleep(poll_interval) | ||
| poll_interval = min(poll_interval * 1.5, 10) # Exponential backoff cap at 10s | ||
|
|
||
| else: | ||
| bt.logging.warning(f"Unknown Fal.ai status: {status}") | ||
| time.sleep(poll_interval) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bug: Timeout ineffective when HTTP requests hang indefinitelyThe 600-second timeout at line 124 is ineffective because the |
||
|
|
||
| raise TimeoutError(f"Fal.ai job timed out after {timeout}s") | ||
|
|
||
| except Exception as e: | ||
| bt.logging.error(f"Fal.ai request failed: {e}") | ||
| raise | ||
|
|
||
| def _process_completed_response(self, data: Dict[str, Any], model: str, modality: str) -> Dict[str, Any]: | ||
| """Download and format the result.""" | ||
|
|
||
| media_url = None | ||
| if modality == "image": | ||
| # FLUX returns 'images': [{'url': ..., ...}] | ||
| if "images" in data and len(data["images"]) > 0: | ||
| media_url = data["images"][0]["url"] | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bug: Image URL extraction lacks defensive key checkThe image URL extraction at line 169 directly accesses |
||
| elif modality == "video": | ||
| # Kling returns 'video': {'url': ...} | ||
| if "video" in data and "url" in data["video"]: | ||
| media_url = data["video"]["url"] | ||
|
|
||
| if not media_url: | ||
| raise RuntimeError(f"No media URL found in Fal.ai response: {data}") | ||
|
|
||
| bt.logging.info(f"Downloading media from {media_url}...") | ||
| media_res = requests.get(media_url) | ||
| media_res.raise_for_status() | ||
| media_bytes = media_res.content | ||
|
|
||
| return { | ||
| "data": media_bytes, | ||
| "metadata": { | ||
| "model": model, | ||
| "provider": "fal.ai", | ||
| "source_url": media_url, | ||
| "size": len(media_bytes) | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,14 +6,18 @@ | |
| from .openai_service import OpenAIService | ||
| from .openrouter_service import OpenRouterService | ||
| from .stabilityai_service import StabilityAIService | ||
| from .stabilityai_service import StabilityAIService | ||
| from .local_service import LocalService | ||
| from .fal_service import FalAIService | ||
|
|
||
|
|
||
| SERVICE_MAP = { | ||
| "openai": OpenAIService, | ||
| "openrouter": OpenRouterService, | ||
| "local": LocalService, | ||
| "stabilityai": StabilityAIService | ||
| "local": LocalService, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bug: Duplicate SERVICE_MAP entries introduced by merge errorThe diff introduces duplicate entries that appear to be merge artifacts. The Additional Locations (1) |
||
| "stabilityai": StabilityAIService, | ||
| "fal": FalAIService | ||
| } | ||
|
|
||
|
|
||
|
|
@@ -144,8 +148,8 @@ def get_available_services(self) -> List[Dict[str, Any]]: | |
| def get_all_api_key_requirements(self) -> Dict[str, str]: | ||
| """Get API key requirements from all services.""" | ||
| all_requirements = { | ||
| "IMAGE_SERVICE": "Service for images: openai, openrouter, local, or none", | ||
| "VIDEO_SERVICE": "Service for videos: openai, openrouter, local, or none", | ||
| "IMAGE_SERVICE": "Service for images: openai, openrouter, local, fal, or none", | ||
| "VIDEO_SERVICE": "Service for videos: openai, openrouter, local, fal, or none", | ||
| } | ||
|
|
||
| for name, service_class in SERVICE_MAP.items(): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,150 @@ | ||
| import os | ||
| import traceback | ||
| from PIL import Image | ||
| import io | ||
| import time | ||
| import requests | ||
| from unittest.mock import MagicMock, patch | ||
|
|
||
| import sys | ||
| from unittest.mock import MagicMock | ||
|
|
||
| # Mock bittensor before importing services | ||
| mock_bt = MagicMock() | ||
| sys.modules["bittensor"] = mock_bt | ||
|
|
||
| from neurons.generator.services.fal_service import FalAIService | ||
| from neurons.generator.task_manager import TaskManager | ||
| from gas.verification.c2pa_verification import verify_c2pa | ||
|
|
||
| # Set API key if one isn't already set | ||
| os.environ.setdefault( | ||
| "FAL_KEY", | ||
| "key-YOUR-API-KEY" # replace with your test key | ||
| ) | ||
|
|
||
| def save_image(img_bytes, filename): | ||
| os.makedirs("outputs", exist_ok=True) | ||
| out_path = f"outputs/{filename}" | ||
| with open(out_path, "wb") as f: | ||
| f.write(img_bytes) | ||
| return out_path | ||
|
|
||
| def validate_image(img_bytes): | ||
| try: | ||
| Image.open(io.BytesIO(img_bytes)).verify() | ||
| return True | ||
| except Exception: | ||
| return False | ||
|
|
||
|
|
||
| def run_model_test(service, manager, model, modality="image"): | ||
| print(f"\n=== Running generation test for model: {model} ({modality}) ===") | ||
|
|
||
| task_id = manager.create_task( | ||
| modality=modality, | ||
| prompt="A futuristic city with flying cars", | ||
| parameters={ | ||
| "model": model, | ||
| "seed": 777, | ||
| "duration": "5" if modality == "video" else None | ||
| }, | ||
| webhook_url=None, | ||
| signed_by="test-suite" | ||
| ) | ||
|
|
||
| task = manager.get_task(task_id) | ||
|
|
||
| try: | ||
| start_time = time.time() | ||
| result = service.process(task) | ||
| elapsed = time.time() - start_time | ||
|
|
||
| data_bytes = result["data"] | ||
| meta = result["metadata"] | ||
|
|
||
| print(f"✔ Generated media in {elapsed:.2f}s ({len(data_bytes)/1024:.1f} KB)") | ||
| print(f"✔ Metadata keys: {list(meta.keys())}") | ||
|
|
||
| # Save output | ||
| ext = "mp4" if modality == "video" else "png" | ||
| out = save_image(data_bytes, f"{model.replace('/', '_')}.{ext}") | ||
| print(f"✔ Saved to {out}") | ||
|
|
||
| if modality == "image": | ||
| # Validate image integrity | ||
| assert validate_image(data_bytes), "Image failed Pillow validation" | ||
|
|
||
| # Check C2PA | ||
| try: | ||
| c2pa_result = verify_c2pa(data_bytes) | ||
| if c2pa_result.verified: | ||
| print(f"✅ C2PA Verified! Issuer: {c2pa_result.issuer}") | ||
| else: | ||
| print(f"⚠️ C2PA Not Verified: {c2pa_result.error}") | ||
| except Exception as e: | ||
| print(f"⚠️ C2PA Check Failed: {e}") | ||
|
|
||
| # Validate metadata | ||
| assert meta["model"] == model | ||
| assert meta["provider"] == "fal.ai" | ||
| assert "source_url" in meta | ||
|
|
||
| print(f"=== Model {model} PASSED ===") | ||
|
|
||
| except Exception: | ||
| print(f"=== Model {model} FAILED ===") | ||
| print(traceback.format_exc()) | ||
|
|
||
|
|
||
| def test_invalid_api_key(): | ||
| print("\n=== Testing invalid API key ===") | ||
| original_key = os.environ.get("FAL_KEY") | ||
| os.environ["FAL_KEY"] = "invalid-key" | ||
|
|
||
| service = FalAIService() | ||
| manager = TaskManager() | ||
|
|
||
| task_id = manager.create_task( | ||
| modality="image", | ||
| prompt="test prompt", | ||
| parameters={"model": "fal-ai/flux/dev"}, | ||
| webhook_url=None, | ||
| signed_by="test" | ||
| ) | ||
| task = manager.get_task(task_id) | ||
|
|
||
| try: | ||
| service.process(task) | ||
| print("❌ Should have failed with invalid API key!") | ||
| except Exception as e: | ||
| print(f"✔ Correctly failed: {e}") | ||
| finally: | ||
| if original_key: | ||
| os.environ["FAL_KEY"] = original_key | ||
|
|
||
|
|
||
| def run_full_test_suite(): | ||
| print("\n========== Fal.ai Full Test Suite ==========\n") | ||
|
|
||
| service = FalAIService() | ||
| manager = TaskManager() | ||
|
|
||
| if not service.is_available() or service.api_key == "key-YOUR-API-KEY": | ||
| print("❌ API key missing or default — skipping live tests") | ||
| print("ℹ️ Set FAL_KEY environment variable to run live tests") | ||
| else: | ||
| # Test Image Model | ||
| run_model_test(service, manager, "fal-ai/flux/dev", "image") | ||
|
|
||
| # Test Video Model | ||
| run_model_test(service, manager, "fal-ai/kling-video/v1/standard/text-to-video", "video") | ||
|
|
||
| # Negative tests | ||
| test_invalid_api_key() | ||
|
|
||
| print("\n========== All Tests Completed ==========\n") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| run_full_test_suite() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bug: Status URL missing model path breaks polling
The status URL and result URL are constructed incorrectly. Jobs are submitted to
{base_url}/{model}(e.g.,https://queue.fal.run/fal-ai/flux/dev), but the status and result URLs use{base_url}/requests/{request_id}instead of{base_url}/{model}/requests/{request_id}. The model path is missing from the status and result polling endpoints, which will cause HTTP 404 errors when checking job status.Additional Locations (1)
neurons/generator/services/fal_service.py#L136-L137