Skip to content

Commit

Permalink
Add runway extend feature
Browse files Browse the repository at this point in the history
  • Loading branch information
igolaizola committed Jan 23, 2024
1 parent 34d3b80 commit f7fb0fe
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 48 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ This is a CLI tool for [RunwayML Gen-2](https://runwayml.com/) that adds some ex
## 🚀 Features

- Generate videos directly from the command line using a text or image prompt.
- Use RunwayML's extend feature to generate longer videos.
- Create or extend videos longer than 4 seconds by reusing the last frame of the video as the input for the next generation.
- Other handy tools to edit videos, like generating loops or resizing videos.

Expand Down Expand Up @@ -42,6 +43,12 @@ Generate a video from a text prompt:
vidai generate --token RUNWAYML_TOKEN --text "a car in the middle of the road" --output car.mp4
```

Generate a video from a image prompt and extend it twice (using RunwayML's extend feature):

```bash
vidai generate --token RUNWAYML_TOKEN --image car.jpg --output car.mp4 --extend 2
```

Extend a video by reusing the last frame twice:

```bash
Expand Down
10 changes: 2 additions & 8 deletions cmd/vidai/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,12 @@ func newGenerateCommand() *ffcli.Command {
return fmt.Errorf("image or text is required")
}
c := vidai.New(&cfg)
urls, err := c.Generate(ctx, *image, *text, *output, *extend,
u, err := c.Generate(ctx, *image, *text, *output, *extend,
*interpolate, *upscale, *watermark)
if err != nil {
return err
}
if len(urls) == 1 {
fmt.Printf("Video URL: %s\n", urls[0])
} else {
for i, u := range urls {
fmt.Printf("Video URL %d: %s\n", i+1, u)
}
}
fmt.Printf("Video URL: %s\n", u)
return nil
},
}
Expand Down
57 changes: 36 additions & 21 deletions pkg/runway/runway.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,18 @@ type createTaskRequest struct {
}

type gen2Options struct {
Interpolate bool `json:"interpolate"`
Seed int `json:"seed"`
Upscale bool `json:"upscale"`
TextPrompt string `json:"text_prompt"`
Watermark bool `json:"watermark"`
ImagePrompt string `json:"image_prompt"`
InitImage string `json:"init_image"`
Mode string `json:"mode"`
Interpolate bool `json:"interpolate"`
Seed int `json:"seed"`
Upscale bool `json:"upscale"`
TextPrompt string `json:"text_prompt"`
Watermark bool `json:"watermark"`
ImagePrompt string `json:"image_prompt,omitempty"`
InitImage string `json:"init_image,omitempty"`
Mode string `json:"mode"`
InitVideo string `json:"init_video,omitempty"`
MotionScore int `json:"motion_score"`
UseMotionScore bool `json:"use_motion_score"`
UseMotionVectors bool `json:"use_motion_vectors"`
}

type taskResponse struct {
Expand Down Expand Up @@ -243,21 +247,21 @@ type artifact struct {
ParentAssetGroupId string `json:"parentAssetGroupId"`
Filename string `json:"filename"`
URL string `json:"url"`
FileSize int `json:"fileSize"`
FileSize string `json:"fileSize"`
IsDirectory bool `json:"isDirectory"`
PreviewURLs []string `json:"previewUrls"`
Private bool `json:"private"`
PrivateInTeam bool `json:"privateInTeam"`
Deleted bool `json:"deleted"`
Reported bool `json:"reported"`
Metadata struct {
FrameRate int `json:"frameRate"`
Duration int `json:"duration"`
Dimensions []int `json:"dimensions"`
FrameRate int `json:"frameRate"`
Duration float32 `json:"duration"`
Dimensions []int `json:"dimensions"`
} `json:"metadata"`
}

func (c *Client) Generate(ctx context.Context, imageURL, textPrompt string, interpolate, upscale, watermark bool) (string, error) {
func (c *Client) Generate(ctx context.Context, assetURL, textPrompt string, interpolate, upscale, watermark, extend bool) (string, error) {
// Load team ID
if err := c.loadTeamID(ctx); err != nil {
return "", fmt.Errorf("runway: couldn't load team id: %w", err)
Expand All @@ -266,6 +270,14 @@ func (c *Client) Generate(ctx context.Context, imageURL, textPrompt string, inte
// Generate seed
seed := rand.Intn(1000000000)

var imageURL string
var videoURL string
if extend {
videoURL = assetURL
} else {
imageURL = assetURL
}

// Create task
createReq := &createTaskRequest{
TaskType: "gen2",
Expand All @@ -279,14 +291,17 @@ func (c *Client) Generate(ctx context.Context, imageURL, textPrompt string, inte
}{
Seconds: 4,
Gen2Options: gen2Options{
Interpolate: interpolate,
Seed: seed,
Upscale: upscale,
TextPrompt: textPrompt,
Watermark: watermark,
ImagePrompt: imageURL,
InitImage: imageURL,
Mode: "gen2",
Interpolate: interpolate,
Seed: seed,
Upscale: upscale,
TextPrompt: textPrompt,
Watermark: watermark,
ImagePrompt: imageURL,
InitImage: imageURL,
InitVideo: videoURL,
Mode: "gen2",
UseMotionScore: true,
MotionScore: 22,
},
Name: fmt.Sprintf("Gen-2, %d", seed),
AssetGroupName: "Gen-2",
Expand Down
87 changes: 87 additions & 0 deletions pkg/runway/runway_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package runway

import (
"encoding/json"
"testing"
)

func TestUnmarshal(t *testing.T) {
js := `{
"task": {
"id": "00000000-0000-0000-0000-000000000000",
"name": "Gen-2, 100000",
"image": null,
"createdAt": "2024-01-01T01:01:01.001Z",
"updatedAt": "2024-01-01T01:01:01.001Z",
"taskType": "gen2",
"options": {
"seconds": 4,
"gen2Options": {
"interpolate": true,
"seed": 100000,
"upscale": true,
"text_prompt": "",
"watermark": false,
"image_prompt": "https://a.url.test",
"init_image": "https://a.url.test",
"mode": "gen2",
"motion_score": 22,
"use_motion_score": true,
"use_motion_vectors": false
},
"name": "Gen-2, 100000",
"assetGroupName": "Gen-2",
"exploreMode": false,
"recordingEnabled": true
},
"status": "SUCCEEDED",
"error": null,
"progressText": null,
"progressRatio": "1",
"placeInLine": null,
"estimatedTimeToStartSeconds": null,
"artifacts": [
{
"id": "00000000-0000-0000-0000-000000000000",
"createdAt": "2024-01-01T01:01:01.001Z",
"updatedAt": "2024-01-01T01:01:01.001Z",
"userId": 100000,
"createdBy": 100000,
"taskId": "00000000-0000-0000-0000-000000000000",
"parentAssetGroupId": "00000000-0000-0000-0000-000000000000",
"filename": "Gen-2, 100000.mp4",
"url": "https://a.url.test",
"fileSize": "100000",
"isDirectory": false,
"previewUrls": [
"https://a.url.test",
"https://a.url.test",
"https://a.url.test",
"https://a.url.test"
],
"private": true,
"privateInTeam": true,
"deleted": false,
"reported": false,
"metadata": {
"frameRate": 24,
"duration": 8.1,
"dimensions": [
2816,
1536
],
"size": {
"width": 2816,
"height": 1536
}
}
}
],
"sharedAsset": null
}
}`
var resp taskResponse
if err := json.Unmarshal([]byte(js), &resp); err != nil {
t.Fatal(err)
}
}
36 changes: 17 additions & 19 deletions vidai.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,23 +48,31 @@ func New(cfg *Config) *Client {

// Generate generates a video from an image and a text prompt.
func (c *Client) Generate(ctx context.Context, image, text, output string,
extend int, interpolate, upscale, watermark bool) ([]string, error) {
extend int, interpolate, upscale, watermark bool) (string, error) {
b, err := os.ReadFile(image)
if err != nil {
return nil, fmt.Errorf("vidai: couldn't read image: %w", err)
return "", fmt.Errorf("vidai: couldn't read image: %w", err)
}
name := filepath.Base(image)

var imageURL string
if image != "" {
imageURL, err = c.client.Upload(ctx, name, b)
if err != nil {
return nil, fmt.Errorf("vidai: couldn't upload image: %w", err)
return "", fmt.Errorf("vidai: couldn't upload image: %w", err)
}
}
videoURL, err := c.client.Generate(ctx, imageURL, text, interpolate, upscale, watermark)
videoURL, err := c.client.Generate(ctx, imageURL, text, interpolate, upscale, watermark, false)
if err != nil {
return nil, fmt.Errorf("vidai: couldn't generate video: %w", err)
return "", fmt.Errorf("vidai: couldn't generate video: %w", err)
}

// Extend video
for i := 0; i < extend; i++ {
videoURL, err = c.client.Generate(ctx, videoURL, "", interpolate, upscale, watermark, true)
if err != nil {
return "", fmt.Errorf("vidai: couldn't extend video: %w", err)
}
}

// Use temp file if no output is set and we need to extend the video
Expand All @@ -77,24 +85,14 @@ func (c *Client) Generate(ctx context.Context, image, text, output string,
// Download video
if videoPath != "" {
if err := c.download(ctx, videoURL, videoPath); err != nil {
return nil, fmt.Errorf("vidai: couldn't download video: %w", err)
}
}

// Extend video
if extend > 0 {
extendURLs, err := c.Extend(ctx, videoPath, output, extend,
interpolate, upscale, watermark)
if err != nil {
return nil, fmt.Errorf("vidai: couldn't extend video: %w", err)
return "", fmt.Errorf("vidai: couldn't download video: %w", err)
}
return append([]string{output}, extendURLs...), nil
}

return []string{videoURL}, nil
return videoURL, nil
}

// Extend extends a video using the last frame of the previous video.
// Extend extends a video using the previous video.
func (c *Client) Extend(ctx context.Context, input, output string, n int,
interpolate, upscale, watermark bool) ([]string, error) {
base := strings.TrimSuffix(filepath.Base(input), filepath.Ext(input))
Expand Down Expand Up @@ -133,7 +131,7 @@ func (c *Client) Extend(ctx context.Context, input, output string, n int,
if err != nil {
return nil, fmt.Errorf("vidai: couldn't upload image: %w", err)
}
videoURL, err := c.client.Generate(ctx, imageURL, "", interpolate, upscale, watermark)
videoURL, err := c.client.Generate(ctx, imageURL, "", interpolate, upscale, watermark, false)
if err != nil {
return nil, fmt.Errorf("vidai: couldn't generate video: %w", err)
}
Expand Down

0 comments on commit f7fb0fe

Please sign in to comment.