-
Notifications
You must be signed in to change notification settings - Fork 44.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(block): Add AI video generator block with Fal txt 2 vid (#8528)
### Background Implements an AI Video Generator Block for text-to-video models hosted on Fal ![image](https://github.com/user-attachments/assets/9cb70015-4174-4419-8c1a-4144f324442f) --------- Co-authored-by: Aarushi <50577581+aarushik93@users.noreply.github.com> Co-authored-by: Aarushi <aarushik93@gmail.com>
- Loading branch information
1 parent
75f9b07
commit 4aa5f53
Showing
6 changed files
with
240 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from typing import Literal | ||
|
||
from pydantic import SecretStr | ||
|
||
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput | ||
|
||
# FAL credentials are plain API-key credentials, exposed under a
# provider-specific alias.
FalCredentials = APIKeyCredentials
# Credentials input restricted to provider "fal" and credential type "api_key".
FalCredentialsInput = CredentialsMetaInput[Literal["fal"], Literal["api_key"]]

# Mock credential object (not a real key) for exercising FAL blocks in tests.
TEST_CREDENTIALS = APIKeyCredentials(
    title="Mock FAL API key",
    provider="fal",
    api_key=SecretStr("mock-fal-api-key"),
    id="01234567-89ab-cdef-0123-456789abcdef",
    expires_at=None,
)
# Metadata-only view of TEST_CREDENTIALS (the secret itself is not included),
# in the shape passed as a block's "credentials" input.
TEST_CREDENTIALS_INPUT = {
    attr: getattr(TEST_CREDENTIALS, attr)
    for attr in ("provider", "id", "type", "title")
}
|
||
|
||
def FalCredentialsField() -> FalCredentialsInput:
    """
    Build the FAL credentials input field for use on a block.

    The field is declared for the "fal" provider with ``api_key`` as its only
    supported credential type.
    """
    field_options = {
        "provider": "fal",
        "supported_credential_types": {"api_key"},
        "description": "The FAL integration can be used with an API Key.",
    }
    return CredentialsField(**field_options)
199 changes: 199 additions & 0 deletions
199
autogpt_platform/backend/backend/blocks/fal/ai_video_generator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,199 @@ | ||
import logging | ||
import time | ||
from enum import Enum | ||
from typing import Any, Dict | ||
|
||
import httpx | ||
|
||
from backend.blocks.fal._auth import ( | ||
TEST_CREDENTIALS, | ||
TEST_CREDENTIALS_INPUT, | ||
FalCredentials, | ||
FalCredentialsField, | ||
FalCredentialsInput, | ||
) | ||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema | ||
from backend.data.model import SchemaField | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class FalModel(str, Enum):
    # Closed set of FAL model identifiers. Members are also `str` instances,
    # so a member's value can be used directly wherever a string is expected.
    # (Documented with comments rather than a docstring so that the class's
    # `__doc__` stays exactly as it was.)
    MOCHI = "fal-ai/mochi-v1"
    LUMA = "fal-ai/luma-dream-machine"
|
||
|
||
class AIVideoGeneratorBlock(Block):
    """Block that turns a text prompt into a video using a FAL-hosted model.

    The prompt is submitted to ``https://queue.fal.run/<model>``; the status
    URL returned by that submission is then polled until the job reports
    completion, and the URL of the resulting video is emitted on the
    ``video_url`` output. Any failure is emitted as text on the ``error``
    output instead of being raised out of ``run``.
    """

    # NOTE: the nested schema classes are documented with comments rather than
    # docstrings so that their `__doc__` attributes stay untouched.
    class Input(BlockSchema):
        # Free-text description of the desired video.
        prompt: str = SchemaField(
            description="Description of the video to generate.",
            placeholder="A dog running in a field.",
        )
        # Which FAL model receives the request; defaults to Mochi.
        model: FalModel = SchemaField(
            title="FAL Model",
            default=FalModel.MOCHI,
            description="The FAL model to use for video generation.",
        )
        credentials: FalCredentialsInput = FalCredentialsField()

    class Output(BlockSchema):
        video_url: str = SchemaField(description="The URL of the generated video.")
        error: str = SchemaField(
            description="Error message if video generation failed."
        )
        # Declared on the schema, but `run` below only ever yields
        # "video_url" or "error".
        logs: list[str] = SchemaField(
            description="Generation progress logs.", optional=True
        )

    def __init__(self):
        super().__init__(
            id="530cf046-2ce0-4854-ae2c-659db17c7a46",
            description="Generate videos using FAL AI models.",
            categories={BlockCategory.AI},
            input_schema=self.Input,
            output_schema=self.Output,
            test_input={
                "prompt": "A dog running in a field.",
                "model": FalModel.MOCHI,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[("video_url", "https://fal.media/files/example/video.mp4")],
            # Replaces `generate_video` with a stub returning a fixed URL.
            test_mock={
                "generate_video": lambda *args, **kwargs: "https://fal.media/files/example/video.mp4"
            },
        )

    def _get_headers(self, api_key: str) -> Dict[str, str]:
        """Return the HTTP headers (auth + JSON content type) for FAL requests."""
        return {
            "Authorization": f"Key {api_key}",
            "Content-Type": "application/json",
        }

    def _submit_request(
        self, url: str, headers: Dict[str, str], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """POST ``data`` as JSON to ``url`` and return the decoded JSON body.

        Not called by ``generate_video``, which issues its own requests.

        Raises:
            RuntimeError: if the request fails or returns an error status;
                the underlying ``httpx.HTTPError`` is attached as the cause.
        """
        try:
            response = httpx.post(url, headers=headers, json=data)
            response.raise_for_status()
            return response.json()
        except httpx.HTTPError as e:
            logger.error("FAL API request failed: %s", e)
            raise RuntimeError(f"Failed to submit request: {str(e)}") from e

    def _poll_status(self, status_url: str, headers: Dict[str, str]) -> Dict[str, Any]:
        """Fetch ``status_url`` once and return the decoded JSON body.

        Despite the name this performs a single GET; it does not loop. Not
        called by ``generate_video``, which issues its own requests.

        Raises:
            RuntimeError: if the request fails or returns an error status;
                the underlying ``httpx.HTTPError`` is attached as the cause.
        """
        try:
            response = httpx.get(status_url, headers=headers)
            response.raise_for_status()
            return response.json()
        except httpx.HTTPError as e:
            logger.error("Failed to get status: %s", e)
            raise RuntimeError(f"Failed to get status: {str(e)}") from e

    def _log_new_entries(self, logs: Any, seen_logs: set) -> None:
        """Debug-log every entry of ``logs`` not already recorded in ``seen_logs``.

        Entries are identified by their timestamp + message; ``seen_logs`` is
        updated in place so repeated polls do not re-log the same entry.
        Anything that is not a non-empty list of dicts is ignored.
        """
        if not logs or not isinstance(logs, list):
            return
        for log in logs:
            if not isinstance(log, dict):
                continue
            message = log.get("message", "")
            log_key = f"{log.get('timestamp', '')}-{message}"
            if log_key in seen_logs:
                continue
            seen_logs.add(log_key)
            if message:
                logger.debug(
                    "[FAL Generation] [%s] [%s] [%s] %s",
                    log.get("level", "INFO"),
                    log.get("source", ""),
                    log.get("timestamp", ""),
                    message,
                )

    def _fetch_video_url(self, result_url: str, headers: Dict[str, str]) -> str:
        """GET the finished job's result and return the video URL it contains.

        Raises:
            httpx.HTTPError: if the request fails (wrapped by the caller).
            ValueError: if the body has no ``video`` dict or no ``url`` in it.
        """
        result_response = httpx.get(result_url, headers=headers)
        result_response.raise_for_status()
        result_data = result_response.json()

        if "video" not in result_data or not isinstance(result_data["video"], dict):
            raise ValueError("Invalid response format - missing video data")

        video_url = result_data["video"].get("url")
        if not video_url:
            raise ValueError("No video URL in response")
        return video_url

    def generate_video(self, input_data: Input, credentials: FalCredentials) -> str:
        """Generate a video with the selected FAL model and return its URL.

        Submits the prompt to the queue endpoint, then polls the returned
        status URL up to 30 times, waiting 5 s, 10 s, 20 s, 40 s and then 60 s
        between polls.

        Raises:
            RuntimeError: on any HTTP failure, when the job reports
                ``FAILED``, or when polling is exhausted.
            ValueError: when the submission or result body lacks required
                fields.
        """
        base_url = "https://queue.fal.run"
        api_key = credentials.api_key.get_secret_value()
        headers = self._get_headers(api_key)

        submit_url = f"{base_url}/{input_data.model.value}"
        submit_data = {"prompt": input_data.prompt}

        # Keys of log entries already emitted, shared across polls.
        seen_logs: set = set()

        try:
            # Submit request to queue
            submit_response = httpx.post(submit_url, headers=headers, json=submit_data)
            submit_response.raise_for_status()
            request_data = submit_response.json()

            # All three fields must be present to be able to follow the job.
            request_id = request_data.get("request_id")
            status_url = request_data.get("status_url")
            result_url = request_data.get("response_url")
            if not all([request_id, status_url, result_url]):
                raise ValueError("Missing required data in submission response")

            # Poll for status with exponential backoff
            max_attempts = 30
            base_wait_time = 5
            max_wait_time = 60  # Cap at 60 seconds

            for attempt in range(max_attempts):
                status_response = httpx.get(f"{status_url}?logs=1", headers=headers)
                status_response.raise_for_status()
                status_data = status_response.json()

                self._log_new_entries(status_data.get("logs", []), seen_logs)

                status = status_data.get("status")
                if status == "COMPLETED":
                    return self._fetch_video_url(result_url, headers)
                elif status == "FAILED":
                    error_msg = status_data.get("error", "No error details provided")
                    raise RuntimeError(f"Video generation failed: {error_msg}")
                elif status == "IN_QUEUE":
                    logger.debug(
                        "[FAL Generation] Status: In queue, position: %s",
                        status_data.get("queue_position", "unknown"),
                    )
                elif status == "IN_PROGRESS":
                    logger.debug(
                        "[FAL Generation] Status: Request is being processed..."
                    )
                else:
                    logger.info("[FAL Generation] Status: Unknown status: %s", status)

                # Only wait when another poll will follow; sleeping after the
                # final attempt would just delay the timeout error below.
                if attempt + 1 < max_attempts:
                    time.sleep(min(base_wait_time * (2**attempt), max_wait_time))

            raise RuntimeError("Maximum polling attempts reached")

        except httpx.HTTPError as e:
            raise RuntimeError(f"API request failed: {str(e)}") from e

    def run(
        self, input_data: Input, *, credentials: FalCredentials, **kwargs
    ) -> BlockOutput:
        """Yield ``("video_url", url)`` on success or ``("error", message)`` on failure.

        Every exception raised by ``generate_video`` is converted to its
        string form and emitted on the ``error`` output rather than propagated.
        """
        try:
            video_url = self.generate_video(input_data, credentials)
            yield "video_url", video_url
        except Exception as e:
            error_message = str(e)
            yield "error", error_message
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters