Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 191 additions & 0 deletions neurons/generator/services/fal_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import os
import time
import requests
import bittensor as bt
from typing import Dict, Any, Optional

from .base_service import BaseGenerationService
from ..task_manager import GenerationTask


class FalAIService(BaseGenerationService):
"""
Fal.ai API service for media generation.
Supports:
- Images via FLUX.1 [dev] (fal-ai/flux/dev)
- Video via Kling (fal-ai/kling-video/v1/standard/text-to-video)
"""

def __init__(self, config: Any = None):
super().__init__(config)
self.api_key = os.getenv("FAL_KEY")
self.base_url = "https://queue.fal.run"

# Default models
self.image_model = "fal-ai/flux/dev"
self.video_model = "fal-ai/kling-video/v1/standard/text-to-video"

if self.api_key:
bt.logging.info("FalAIService initialized with API key")
else:
bt.logging.warning("FAL_KEY not found. Fal.ai service will not be available.")

def is_available(self) -> bool:
return self.api_key is not None and self.api_key.strip() != ""

def supports_modality(self, modality: str) -> bool:
return modality in {"image", "video"}

def get_supported_tasks(self) -> Dict[str, list]:
return {
"image": ["image_generation"],
"video": ["video_generation"],
}

def get_api_key_requirements(self) -> Dict[str, str]:
return {"FAL_KEY": "Fal.ai API key for image and video generation"}

def process(self, task: GenerationTask) -> Dict[str, Any]:
if task.modality == "image":
return self._generate_image(task)
elif task.modality == "video":
return self._generate_video(task)
else:
raise ValueError(f"Unsupported modality: {task.modality}")

def _generate_image(self, task: GenerationTask) -> Dict[str, Any]:
"""Generate an image using Fal.ai."""
params = task.parameters or {}
model = params.get("model", self.image_model)

# Map common parameters to Fal.ai specific ones if needed
# FLUX.1 [dev] supports: prompt, image_size, num_inference_steps, seed, guidance_scale, etc.
payload = {
"prompt": task.prompt,
"image_size": params.get("size", "landscape_4_3"), # default to landscape
"num_inference_steps": params.get("steps", 28),
"seed": params.get("seed"),
"guidance_scale": params.get("guidance_scale", 3.5),
"num_images": 1,
"enable_safety_checker": params.get("safety_checker", True),
"sync_mode": False # Use queue for reliability
}

# Remove None values
payload = {k: v for k, v in payload.items() if v is not None}

return self._run_fal_request(model, payload, "image")

def _generate_video(self, task: GenerationTask) -> Dict[str, Any]:
"""Generate a video using Fal.ai."""
params = task.parameters or {}
model = params.get("model", self.video_model)

# Kling supports: prompt, duration, aspect_ratio
payload = {
"prompt": task.prompt,
"duration": str(params.get("duration", "5")), # "5" or "10"
"aspect_ratio": params.get("aspect_ratio", "16:9"),
}

return self._run_fal_request(model, payload, "video")

def _run_fal_request(self, model: str, payload: Dict[str, Any], modality: str) -> Dict[str, Any]:
"""Execute the request against Fal.ai queue and poll for result."""
headers = {
"Authorization": f"Key {self.api_key}",
"Content-Type": "application/json",
}

url = f"{self.base_url}/{model}"

bt.logging.info(f"Fal.ai submitting job to {model}...")

try:
# Submit job
response = requests.post(url, json=payload, headers=headers)
response.raise_for_status()
data = response.json()

request_id = data.get("request_id")
if not request_id:
# Some endpoints might return result immediately if sync_mode=True,
# but we forced sync_mode=False or are using queue.
# If we get immediate result (unlikely with queue URL), handle it.
if "images" in data or "video" in data:
return self._process_completed_response(data, model, modality)
raise RuntimeError(f"No request_id returned from Fal.ai: {data}")

bt.logging.info(f"Fal.ai job submitted. Request ID: {request_id}")

# Poll for status
start_time = time.time()
timeout = 600 # 10 minutes max
poll_interval = 2.0

while time.time() - start_time < timeout:
status_url = f"{self.base_url}/requests/{request_id}/status"
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Status URL missing model path breaks polling

The status URL and result URL are constructed incorrectly. Jobs are submitted to {base_url}/{model} (e.g., https://queue.fal.run/fal-ai/flux/dev), but the status and result URLs use {base_url}/requests/{request_id} instead of {base_url}/{model}/requests/{request_id}. The model path is missing from the status and result polling endpoints, which will cause HTTP 404 errors when checking job status.

Additional Locations (1)

Fix in Cursor Fix in Web

status_res = requests.get(status_url, headers=headers)
status_res.raise_for_status()
status_data = status_res.json()

status = status_data.get("status")

if status == "COMPLETED":
# Fetch final result
result_url = f"{self.base_url}/requests/{request_id}"
result_res = requests.get(result_url, headers=headers)
result_res.raise_for_status()
result_data = result_res.json()

return self._process_completed_response(result_data, model, modality)

elif status == "FAILED":
error = status_data.get("error", "Unknown error")
raise RuntimeError(f"Fal.ai job failed: {error}")

elif status in ["IN_QUEUE", "IN_PROGRESS"]:
time.sleep(poll_interval)
poll_interval = min(poll_interval * 1.5, 10) # Exponential backoff cap at 10s

else:
bt.logging.warning(f"Unknown Fal.ai status: {status}")
time.sleep(poll_interval)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Timeout ineffective when HTTP requests hang indefinitely

The 600-second timeout at line 124 is ineffective because the requests.get() and requests.post() calls (lines 107, 129, 138, 179) have no individual request timeouts. If any HTTP request hangs due to network issues, the code blocks indefinitely since the while loop condition at line 127 is only checked between iterations. The intended 10-minute timeout never triggers if a single request never returns.

Fix in Cursor Fix in Web


raise TimeoutError(f"Fal.ai job timed out after {timeout}s")

except Exception as e:
bt.logging.error(f"Fal.ai request failed: {e}")
raise

def _process_completed_response(self, data: Dict[str, Any], model: str, modality: str) -> Dict[str, Any]:
"""Download and format the result."""

media_url = None
if modality == "image":
# FLUX returns 'images': [{'url': ..., ...}]
if "images" in data and len(data["images"]) > 0:
media_url = data["images"][0]["url"]
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Image URL extraction lacks defensive key check

The image URL extraction at line 169 directly accesses data["images"][0]["url"] without verifying that the "url" key exists in the image object. In contrast, the video handling at lines 172-173 properly checks "url" in data["video"] before accessing it. If the API returns an image object without a "url" key, this code raises a confusing KeyError instead of falling through to the intended RuntimeError with a helpful message about "No media URL found".

Fix in Cursor Fix in Web

elif modality == "video":
# Kling returns 'video': {'url': ...}
if "video" in data and "url" in data["video"]:
media_url = data["video"]["url"]

if not media_url:
raise RuntimeError(f"No media URL found in Fal.ai response: {data}")

bt.logging.info(f"Downloading media from {media_url}...")
media_res = requests.get(media_url)
media_res.raise_for_status()
media_bytes = media_res.content

return {
"data": media_bytes,
"metadata": {
"model": model,
"provider": "fal.ai",
"source_url": media_url,
"size": len(media_bytes)
}
}
10 changes: 7 additions & 3 deletions neurons/generator/services/service_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@
from .openai_service import OpenAIService
from .openrouter_service import OpenRouterService
from .stabilityai_service import StabilityAIService
from .stabilityai_service import StabilityAIService
from .local_service import LocalService
from .fal_service import FalAIService


SERVICE_MAP = {
"openai": OpenAIService,
"openrouter": OpenRouterService,
"local": LocalService,
"stabilityai": StabilityAIService
"local": LocalService,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Duplicate SERVICE_MAP entries introduced by merge error

The diff introduces duplicate entries that appear to be merge artifacts. The "local": LocalService key appears twice in SERVICE_MAP (lines 17-18), and StabilityAIService is imported twice (lines 8-9). While Python dictionaries silently overwrite duplicate keys, this indicates a merge error and the duplicate "local" entry may have been intended to be a different service entry that was lost during the merge.

Additional Locations (1)

Fix in Cursor Fix in Web

"stabilityai": StabilityAIService,
"fal": FalAIService
}


Expand Down Expand Up @@ -144,8 +148,8 @@ def get_available_services(self) -> List[Dict[str, Any]]:
def get_all_api_key_requirements(self) -> Dict[str, str]:
"""Get API key requirements from all services."""
all_requirements = {
"IMAGE_SERVICE": "Service for images: openai, openrouter, local, or none",
"VIDEO_SERVICE": "Service for videos: openai, openrouter, local, or none",
"IMAGE_SERVICE": "Service for images: openai, openrouter, local, fal, or none",
"VIDEO_SERVICE": "Service for videos: openai, openrouter, local, fal, or none",
}

for name, service_class in SERVICE_MAP.items():
Expand Down
150 changes: 150 additions & 0 deletions tests/generator/fal_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import os
import traceback
from PIL import Image
import io
import time
import requests
from unittest.mock import MagicMock, patch

import sys
from unittest.mock import MagicMock

# Mock bittensor before importing services
mock_bt = MagicMock()
sys.modules["bittensor"] = mock_bt

from neurons.generator.services.fal_service import FalAIService
from neurons.generator.task_manager import TaskManager
from gas.verification.c2pa_verification import verify_c2pa

# Set API key if one isn't already set
os.environ.setdefault(
"FAL_KEY",
"key-YOUR-API-KEY" # replace with your test key
)

def save_image(img_bytes, filename):
os.makedirs("outputs", exist_ok=True)
out_path = f"outputs/{filename}"
with open(out_path, "wb") as f:
f.write(img_bytes)
return out_path

def validate_image(img_bytes):
try:
Image.open(io.BytesIO(img_bytes)).verify()
return True
except Exception:
return False


def run_model_test(service, manager, model, modality="image"):
print(f"\n=== Running generation test for model: {model} ({modality}) ===")

task_id = manager.create_task(
modality=modality,
prompt="A futuristic city with flying cars",
parameters={
"model": model,
"seed": 777,
"duration": "5" if modality == "video" else None
},
webhook_url=None,
signed_by="test-suite"
)

task = manager.get_task(task_id)

try:
start_time = time.time()
result = service.process(task)
elapsed = time.time() - start_time

data_bytes = result["data"]
meta = result["metadata"]

print(f"✔ Generated media in {elapsed:.2f}s ({len(data_bytes)/1024:.1f} KB)")
print(f"✔ Metadata keys: {list(meta.keys())}")

# Save output
ext = "mp4" if modality == "video" else "png"
out = save_image(data_bytes, f"{model.replace('/', '_')}.{ext}")
print(f"✔ Saved to {out}")

if modality == "image":
# Validate image integrity
assert validate_image(data_bytes), "Image failed Pillow validation"

# Check C2PA
try:
c2pa_result = verify_c2pa(data_bytes)
if c2pa_result.verified:
print(f"✅ C2PA Verified! Issuer: {c2pa_result.issuer}")
else:
print(f"⚠️ C2PA Not Verified: {c2pa_result.error}")
except Exception as e:
print(f"⚠️ C2PA Check Failed: {e}")

# Validate metadata
assert meta["model"] == model
assert meta["provider"] == "fal.ai"
assert "source_url" in meta

print(f"=== Model {model} PASSED ===")

except Exception:
print(f"=== Model {model} FAILED ===")
print(traceback.format_exc())


def test_invalid_api_key():
print("\n=== Testing invalid API key ===")
original_key = os.environ.get("FAL_KEY")
os.environ["FAL_KEY"] = "invalid-key"

service = FalAIService()
manager = TaskManager()

task_id = manager.create_task(
modality="image",
prompt="test prompt",
parameters={"model": "fal-ai/flux/dev"},
webhook_url=None,
signed_by="test"
)
task = manager.get_task(task_id)

try:
service.process(task)
print("❌ Should have failed with invalid API key!")
except Exception as e:
print(f"✔ Correctly failed: {e}")
finally:
if original_key:
os.environ["FAL_KEY"] = original_key


def run_full_test_suite():
print("\n========== Fal.ai Full Test Suite ==========\n")

service = FalAIService()
manager = TaskManager()

if not service.is_available() or service.api_key == "key-YOUR-API-KEY":
print("❌ API key missing or default — skipping live tests")
print("ℹ️ Set FAL_KEY environment variable to run live tests")
else:
# Test Image Model
run_model_test(service, manager, "fal-ai/flux/dev", "image")

# Test Video Model
run_model_test(service, manager, "fal-ai/kling-video/v1/standard/text-to-video", "video")

# Negative tests
test_invalid_api_key()

print("\n========== All Tests Completed ==========\n")


if __name__ == "__main__":
run_full_test_suite()