Skip to content

Commit

Permalink
feat: remote exec
Browse files Browse the repository at this point in the history
  • Loading branch information
justinmerrell committed Feb 9, 2024
1 parent f5cde85 commit 53ce856
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 84 deletions.
12 changes: 11 additions & 1 deletion api/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package api
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"runtime"
Expand All @@ -19,17 +21,25 @@ type Input struct {
func Query(input Input) (res *http.Response, err error) {
jsonValue, err := json.Marshal(input)
if err != nil {
return
return nil, err
}

apiUrl := os.Getenv("RUNPOD_API_URL")
if apiUrl == "" {
apiUrl = viper.GetString("apiUrl")
}

apiKey := os.Getenv("RUNPOD_API_KEY")
if apiKey == "" {
apiKey = viper.GetString("apiKey")
}

// Check if the API key is present
if apiKey == "" {
fmt.Println("API key not found")
return nil, errors.New("API key not found")
}

req, err := http.NewRequest("POST", apiUrl+"?api_key="+apiKey, bytes.NewBuffer(jsonValue))
if err != nil {
return
Expand Down
18 changes: 18 additions & 0 deletions cmd/exec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package cmd

import (
"cli/cmd/exec"

"github.com/spf13/cobra"
)

// execCmd represents the base command for executing commands in a pod
var execCmd = &cobra.Command{
Use: "exec",
Short: "Execute commands in a pod",
Long: `Execute a local file remotely in a pod.`,
}

func init() {
execCmd.AddCommand(exec.RemotePythonCmd)
}
39 changes: 39 additions & 0 deletions cmd/exec/commands.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package exec

import (
"fmt"
"os"

"github.com/spf13/cobra"
)

var RemotePythonCmd = &cobra.Command{
Use: "python [file]",
Short: "Runs a remote Python shell",
Long: `Runs a remote Python shell with a local script file.`,
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
podID, _ := cmd.Flags().GetString("pod_id")
file := args[0]

// Default to the session pod if no pod_id is provided
// if podID == "" {
// var err error
// podID, err = api.GetSessionPod()
// if err != nil {
// fmt.Fprintf(os.Stderr, "Error retrieving session pod: %v\n", err)
// return
// }
// }

fmt.Println("Running remote Python shell...")
if err := PythonOverSSH(podID, file); err != nil {
fmt.Fprintf(os.Stderr, "Error executing Python over SSH: %v\n", err)
}
},
}

func init() {
RemotePythonCmd.Flags().String("pod_id", "", "The ID of the pod to run the command on.")
RemotePythonCmd.MarkFlagRequired("file")
}
25 changes: 25 additions & 0 deletions cmd/exec/functions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package exec

import (
"cli/cmd/project"
"fmt"
)

func PythonOverSSH(podID string, file string) error {
sshConn, err := project.PodSSHConnection(podID)
if err != nil {
return fmt.Errorf("getting SSH connection: %w", err)
}

// Copy the file to the pod using Rsync
if err := sshConn.Rsync(file, "/tmp/"+file, false); err != nil {
return fmt.Errorf("copying file to pod: %w", err)
}

// Run the file on the pod
if err := sshConn.RunCommand("python3.11 /tmp/" + file); err != nil {
return fmt.Errorf("running Python command: %w", err)
}

return nil
}
18 changes: 15 additions & 3 deletions cmd/project/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,31 @@ func copyFiles(files fs.FS, source string, dest string) error {
if path == source {
return nil
}

relPath, err := filepath.Rel(source, path)
if err != nil {
return err
}

// Generate the corresponding path in the new project folder
newPath := filepath.Join(dest, path[len(source):])
newPath := filepath.Join(dest, relPath)
if d.IsDir() {
return os.MkdirAll(newPath, os.ModePerm)
if err := os.MkdirAll(newPath, os.ModePerm); err != nil {
return err
}
} else {
content, err := fs.ReadFile(files, path)
if err != nil {
return err
}
return os.WriteFile(newPath, content, 0644)
if err := os.WriteFile(newPath, content, 0644); err != nil {
return err
}
}
return nil
})
}

func createNewProject(projectName string, cudaVersion string,
pythonVersion string, modelType string, modelName string, initCurrentDir bool) {
projectFolder, _ := os.Getwd()
Expand Down
Loading

0 comments on commit 53ce856

Please sign in to comment.