Skip to content

Commit

Permalink
Added method to obtain URL from asset ID
Browse files Browse the repository at this point in the history
URLs are only valid for a limited time, if the asset ID is known, a
valid URL can be obtained from the API.
  • Loading branch information
igolaizola committed Jan 30, 2024
1 parent f7fb0fe commit 2d270d3
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 24 deletions.
4 changes: 2 additions & 2 deletions cmd/vidai/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,12 @@ func newGenerateCommand() *ffcli.Command {
return fmt.Errorf("image or text is required")
}
c := vidai.New(&cfg)
u, err := c.Generate(ctx, *image, *text, *output, *extend,
id, u, err := c.Generate(ctx, *image, *text, *output, *extend,
*interpolate, *upscale, *watermark)
if err != nil {
return err
}
fmt.Printf("Video URL: %s\n", u)
fmt.Printf("ID: %s URL: %s\n", id, u)
return nil
},
}
Expand Down
43 changes: 31 additions & 12 deletions pkg/runway/runway.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,13 +241,16 @@ type artifact struct {
ID string `json:"id"`
CreatedAt string `json:"createdAt"`
UpdatedAt string `json:"updatedAt"`
Name string `json:"name"`
MediaType string `json:"mediaType"`
UserID int `json:"userId"`
CreatedBy int `json:"createdBy"`
TaskID string `json:"taskId"`
ParentAssetGroupId string `json:"parentAssetGroupId"`
Filename string `json:"filename"`
URL string `json:"url"`
FileSize string `json:"fileSize"`
FileSize any `json:"fileSize"`
FileExtension string `json:"fileExtStandardized"`
IsDirectory bool `json:"isDirectory"`
PreviewURLs []string `json:"previewUrls"`
Private bool `json:"private"`
Expand All @@ -261,10 +264,10 @@ type artifact struct {
} `json:"metadata"`
}

func (c *Client) Generate(ctx context.Context, assetURL, textPrompt string, interpolate, upscale, watermark, extend bool) (string, error) {
func (c *Client) Generate(ctx context.Context, assetURL, textPrompt string, interpolate, upscale, watermark, extend bool) (string, string, error) {
// Load team ID
if err := c.loadTeamID(ctx); err != nil {
return "", fmt.Errorf("runway: couldn't load team id: %w", err)
return "", "", fmt.Errorf("runway: couldn't load team id: %w", err)
}

// Generate seed
Expand Down Expand Up @@ -311,35 +314,36 @@ func (c *Client) Generate(ctx context.Context, assetURL, textPrompt string, inte
}
var taskResp taskResponse
if err := c.do(ctx, "POST", "tasks", createReq, &taskResp); err != nil {
return "", fmt.Errorf("runway: couldn't create task: %w", err)
return "", "", fmt.Errorf("runway: couldn't create task: %w", err)
}

// Wait for task to finish
for {
switch taskResp.Task.Status {
case "SUCCEEDED":
if len(taskResp.Task.Artifacts) == 0 {
return "", fmt.Errorf("runway: no artifacts returned")
return "", "", fmt.Errorf("runway: no artifacts returned")
}
if taskResp.Task.Artifacts[0].URL == "" {
return "", fmt.Errorf("runway: empty artifact url")
artifact := taskResp.Task.Artifacts[0]
if artifact.URL == "" {
return "", "", fmt.Errorf("runway: empty artifact url")
}
return taskResp.Task.Artifacts[0].URL, nil
return artifact.ID, artifact.URL, nil
case "PENDING", "RUNNING":
c.log("runway: task %s: %s", taskResp.Task.ID, taskResp.Task.ProgressRatio)
default:
return "", fmt.Errorf("runway: task failed: %s", taskResp.Task.Status)
return "", "", fmt.Errorf("runway: task failed: %s", taskResp.Task.Status)
}

select {
case <-ctx.Done():
return "", fmt.Errorf("runway: %w", ctx.Err())
return "", "", fmt.Errorf("runway: %w", ctx.Err())
case <-time.After(5 * time.Second):
}

path := fmt.Sprintf("tasks/%s?asTeamId=%d", taskResp.Task.ID, c.teamID)
if err := c.do(ctx, "GET", path, nil, &taskResp); err != nil {
return "", fmt.Errorf("runway: couldn't get task: %w", err)
return "", "", fmt.Errorf("runway: couldn't get task: %w", err)
}
}
}
Expand All @@ -351,7 +355,10 @@ type assetDeleteResponse struct {
Success bool `json:"success"`
}

// TODO: Delete asset by url instead
type assetResponse struct {
Asset artifact `json:"asset"`
}

func (c *Client) DeleteAsset(ctx context.Context, id string) error {
path := fmt.Sprintf("assets/%s", id)
var resp assetDeleteResponse
Expand All @@ -364,6 +371,18 @@ func (c *Client) DeleteAsset(ctx context.Context, id string) error {
return nil
}

func (c *Client) GetAsset(ctx context.Context, id string) (string, error) {
path := fmt.Sprintf("assets/%s", id)
var resp assetResponse
if err := c.do(ctx, "GET", path, nil, &resp); err != nil {
return "", fmt.Errorf("runway: couldn't get asset %s: %w", id, err)
}
if resp.Asset.URL == "" {
return "", fmt.Errorf("runway: empty asset url")
}
return resp.Asset.URL, nil
}

func (c *Client) log(format string, args ...interface{}) {
if c.debug {
format += "\n"
Expand Down
25 changes: 15 additions & 10 deletions vidai.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,30 +48,30 @@ func New(cfg *Config) *Client {

// Generate generates a video from an image and a text prompt.
func (c *Client) Generate(ctx context.Context, image, text, output string,
extend int, interpolate, upscale, watermark bool) (string, error) {
extend int, interpolate, upscale, watermark bool) (string, string, error) {
b, err := os.ReadFile(image)
if err != nil {
return "", fmt.Errorf("vidai: couldn't read image: %w", err)
return "", "", fmt.Errorf("vidai: couldn't read image: %w", err)
}
name := filepath.Base(image)

var imageURL string
if image != "" {
imageURL, err = c.client.Upload(ctx, name, b)
if err != nil {
return "", fmt.Errorf("vidai: couldn't upload image: %w", err)
return "", "", fmt.Errorf("vidai: couldn't upload image: %w", err)
}
}
videoURL, err := c.client.Generate(ctx, imageURL, text, interpolate, upscale, watermark, false)
id, videoURL, err := c.client.Generate(ctx, imageURL, text, interpolate, upscale, watermark, false)
if err != nil {
return "", fmt.Errorf("vidai: couldn't generate video: %w", err)
return "", "", fmt.Errorf("vidai: couldn't generate video: %w", err)
}

// Extend video
for i := 0; i < extend; i++ {
videoURL, err = c.client.Generate(ctx, videoURL, "", interpolate, upscale, watermark, true)
id, videoURL, err = c.client.Generate(ctx, videoURL, "", interpolate, upscale, watermark, true)
if err != nil {
return "", fmt.Errorf("vidai: couldn't extend video: %w", err)
return "", "", fmt.Errorf("vidai: couldn't extend video: %w", err)
}
}

Expand All @@ -85,11 +85,11 @@ func (c *Client) Generate(ctx context.Context, image, text, output string,
// Download video
if videoPath != "" {
if err := c.download(ctx, videoURL, videoPath); err != nil {
return "", fmt.Errorf("vidai: couldn't download video: %w", err)
return "", "", fmt.Errorf("vidai: couldn't download video: %w", err)
}
}

return videoURL, nil
return id, videoURL, nil
}

// Extend extends a video using the previous video.
Expand Down Expand Up @@ -131,7 +131,7 @@ func (c *Client) Extend(ctx context.Context, input, output string, n int,
if err != nil {
return nil, fmt.Errorf("vidai: couldn't upload image: %w", err)
}
videoURL, err := c.client.Generate(ctx, imageURL, "", interpolate, upscale, watermark, false)
_, videoURL, err := c.client.Generate(ctx, imageURL, "", interpolate, upscale, watermark, false)
if err != nil {
return nil, fmt.Errorf("vidai: couldn't generate video: %w", err)
}
Expand Down Expand Up @@ -184,6 +184,11 @@ func (c *Client) Extend(ctx context.Context, input, output string, n int,
return urls, nil
}

// URL returns the URL of a video.
func (c *Client) URL(ctx context.Context, id string) (string, error) {
return c.client.GetAsset(ctx, id)
}

func (c *Client) download(ctx context.Context, url, output string) error {
// Create request
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
Expand Down

0 comments on commit 2d270d3

Please sign in to comment.