From 53ce8566be506b4970b7c791bf6cca9e0fd98676 Mon Sep 17 00:00:00 2001 From: Justin Merrell Date: Fri, 9 Feb 2024 09:07:29 -0500 Subject: [PATCH] feat: remote exec --- api/query.go | 12 ++- cmd/exec.go | 18 +++++ cmd/exec/commands.go | 39 ++++++++++ cmd/exec/functions.go | 25 ++++++ cmd/project/functions.go | 18 ++++- cmd/project/ssh.go | 161 ++++++++++++++++++++------------------- cmd/root.go | 3 + 7 files changed, 192 insertions(+), 84 deletions(-) create mode 100644 cmd/exec.go create mode 100644 cmd/exec/commands.go create mode 100644 cmd/exec/functions.go diff --git a/api/query.go b/api/query.go index a6ca5f4..7908da1 100644 --- a/api/query.go +++ b/api/query.go @@ -3,6 +3,8 @@ package api import ( "bytes" "encoding/json" + "errors" + "fmt" "net/http" "os" "runtime" @@ -19,17 +21,25 @@ type Input struct { func Query(input Input) (res *http.Response, err error) { jsonValue, err := json.Marshal(input) if err != nil { - return + return nil, err } apiUrl := os.Getenv("RUNPOD_API_URL") if apiUrl == "" { apiUrl = viper.GetString("apiUrl") } + apiKey := os.Getenv("RUNPOD_API_KEY") if apiKey == "" { apiKey = viper.GetString("apiKey") } + + // Check if the API key is present + if apiKey == "" { + fmt.Println("API key not found") + return nil, errors.New("API key not found") + } + req, err := http.NewRequest("POST", apiUrl+"?api_key="+apiKey, bytes.NewBuffer(jsonValue)) if err != nil { return diff --git a/cmd/exec.go b/cmd/exec.go new file mode 100644 index 0000000..fa17a64 --- /dev/null +++ b/cmd/exec.go @@ -0,0 +1,18 @@ +package cmd + +import ( + "cli/cmd/exec" + + "github.com/spf13/cobra" +) + +// execCmd represents the base command for executing commands in a pod +var execCmd = &cobra.Command{ + Use: "exec", + Short: "Execute commands in a pod", + Long: `Execute a local file remotely in a pod.`, +} + +func init() { + execCmd.AddCommand(exec.RemotePythonCmd) +} diff --git a/cmd/exec/commands.go b/cmd/exec/commands.go new file mode 100644 index 0000000..aa1898b --- /dev/null +++ b/cmd/exec/commands.go @@ -0,0 +1,39 @@ +package exec + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" +) + +var RemotePythonCmd = &cobra.Command{ + Use: "python [file]", + Short: "Runs a remote Python shell", + Long: `Runs a remote Python shell with a local script file.`, + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + podID, _ := cmd.Flags().GetString("pod_id") + file := args[0] + + // Default to the session pod if no pod_id is provided + // if podID == "" { + // var err error + // podID, err = api.GetSessionPod() + // if err != nil { + // fmt.Fprintf(os.Stderr, "Error retrieving session pod: %v\n", err) + // return + // } + // } + + fmt.Println("Running remote Python shell...") + if err := PythonOverSSH(podID, file); err != nil { + fmt.Fprintf(os.Stderr, "Error executing Python over SSH: %v\n", err) + } + }, +} + +func init() { + RemotePythonCmd.Flags().String("pod_id", "", "The ID of the pod to run the command on.") + RemotePythonCmd.MarkFlagRequired("file") +} diff --git a/cmd/exec/functions.go b/cmd/exec/functions.go new file mode 100644 index 0000000..294ef96 --- /dev/null +++ b/cmd/exec/functions.go @@ -0,0 +1,25 @@ +package exec + +import ( + "cli/cmd/project" + "fmt" +) + +func PythonOverSSH(podID string, file string) error { + sshConn, err := project.PodSSHConnection(podID) + if err != nil { + return fmt.Errorf("getting SSH connection: %w", err) + } + + // Copy the file to the pod using Rsync + if err := sshConn.Rsync(file, "/tmp/"+file, false); err != nil { + return fmt.Errorf("copying file to pod: %w", err) + } + + // Run the file on the pod + if err := sshConn.RunCommand("python3.11 /tmp/" + file); err != nil { + return fmt.Errorf("running Python command: %w", err) + } + + return nil +} diff --git a/cmd/project/functions.go b/cmd/project/functions.go index 3d0b8d7..84e3b19 100644 --- a/cmd/project/functions.go +++ b/cmd/project/functions.go @@ -42,19 +42,31 @@ func copyFiles(files fs.FS, source string, dest string) error { if path == source { return nil } + + relPath, err := filepath.Rel(source, path) + if err != nil { + return err + } + // Generate the corresponding path in the new project folder - newPath := filepath.Join(dest, path[len(source):]) + newPath := filepath.Join(dest, relPath) if d.IsDir() { - return os.MkdirAll(newPath, os.ModePerm) + if err := os.MkdirAll(newPath, os.ModePerm); err != nil { + return err + } } else { content, err := fs.ReadFile(files, path) if err != nil { return err } - return os.WriteFile(newPath, content, 0644) + if err := os.WriteFile(newPath, content, 0644); err != nil { + return err + } } + return nil }) } + func createNewProject(projectName string, cudaVersion string, pythonVersion string, modelType string, modelName string, initCurrentDir bool) { projectFolder, _ := os.Getwd() diff --git a/cmd/project/ssh.go b/cmd/project/ssh.go index 8cfe353..35a24e9 100644 --- a/cmd/project/ssh.go +++ b/cmd/project/ssh.go @@ -5,6 +5,7 @@ import ( "cli/api" "errors" "fmt" + "io" "os" "os/exec" "path/filepath" @@ -15,33 +16,36 @@ import ( "golang.org/x/crypto/ssh" ) -func getPodSSHInfo(podId string) (podIp string, podPort int, err error) { +const ( + pollInterval = 1 * time.Second + maxPollTime = 5 * time.Minute // Adjusted for clarity +) + +func getPodSSHInfo(podID string) (string, int, error) { pods, err := api.GetPods() if err != nil { - return "", 0, err + return "", 0, fmt.Errorf("getting pods: %w", err) } - var pod api.Pod - for _, p := range pods { - if p.Id == podId { - pod = *p + + for _, pod := range pods { + if pod.Id != podID { + continue } - } - //is pod ready for ssh yet? - if pod.DesiredStatus != "RUNNING" { - return "", 0, errors.New("pod desired status not RUNNING") - } - if pod.Runtime == nil { - return "", 0, errors.New("pod runtime is nil") - } - if pod.Runtime.Ports == nil { - return "", 0, errors.New("pod runtime ports is nil") - } - for _, port := range pod.Runtime.Ports { - if port.PrivatePort == 22 { - return port.Ip, port.PublicPort, nil + + if pod.DesiredStatus != "RUNNING" { + return "", 0, fmt.Errorf("pod desired status not RUNNING") + } + if pod.Runtime == nil || pod.Runtime.Ports == nil { + return "", 0, fmt.Errorf("pod runtime information is missing") + } + for _, port := range pod.Runtime.Ports { + if port.PrivatePort == 22 { + return port.Ip, port.PublicPort, nil + } } + } - return "", 0, errors.New("no SSH port exposed on Pod") + return "", 0, fmt.Errorf("no SSH port exposed on pod %s", podID) } type SSHConnection struct { @@ -60,12 +64,14 @@ func (sshConn *SSHConnection) getSshOptions() []string { "-i", sshConn.sshKeyPath, } } + func (sshConn *SSHConnection) Rsync(localDir string, remoteDir string, quiet bool) error { rsyncCmdArgs := []string{"-avz", "--no-owner", "--no-group"} patterns, err := GetIgnoreList() if err != nil { return err } + for _, pat := range patterns { rsyncCmdArgs = append(rsyncCmdArgs, "--exclude", pat) } @@ -76,26 +82,25 @@ func (sshConn *SSHConnection) Rsync(localDir string, remoteDir string, quiet boo sshOptions := strings.Join(sshConn.getSshOptions(), " ") rsyncCmdArgs = append(rsyncCmdArgs, "-e", fmt.Sprintf("ssh %s", sshOptions)) rsyncCmdArgs = append(rsyncCmdArgs, localDir, fmt.Sprintf("root@%s:%s", sshConn.podIp, remoteDir)) + cmd := exec.Command("rsync", rsyncCmdArgs...) cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr if err := cmd.Run(); err != nil { - fmt.Println("could not run rsync command: ", err) - return err + return fmt.Errorf("executing rsync command: %w", err) } return nil } // hasChanges checks if there are any modified files in localDir since lastSyncTime. func hasChanges(localDir string, lastSyncTime time.Time) (bool, string) { - var hasModifications bool - var firstModifiedFile string + var firstModifiedFile string = "" err := filepath.Walk(localDir, func(path string, info os.FileInfo, err error) error { if err != nil { if os.IsNotExist(err) { // Handle the case where a file has been removed fmt.Printf("Detected a removed file at: %s\n", path) - hasModifications = true return errors.New("change detected") // Stop walking } return err @@ -103,7 +108,6 @@ func hasChanges(localDir string, lastSyncTime time.Time) (bool, string) { // Check if the file was modified after the last sync time if info.ModTime().After(lastSyncTime) { - hasModifications = true firstModifiedFile = path return filepath.SkipDir // Skip the rest of the directory if a change is found } @@ -116,7 +120,7 @@ func hasChanges(localDir string, lastSyncTime time.Time) (bool, string) { return false, "" } - return hasModifications, firstModifiedFile + return firstModifiedFile != "", firstModifiedFile } func (sshConn *SSHConnection) SyncDir(localDir string, remoteDir string) { @@ -143,96 +147,94 @@ func (sshConn *SSHConnection) SyncDir(localDir string, remoteDir string) { <-make(chan struct{}) } -func (sshConn *SSHConnection) RunCommand(command string) error { - return sshConn.RunCommands([]string{command}) +// RunCommand runs a command on the remote pod. +func (conn *SSHConnection) RunCommand(command string) error { + return conn.RunCommands([]string{command}) } +// RunCommands runs a list of commands on the remote pod. func (sshConn *SSHConnection) RunCommands(commands []string) error { - - stdoutColor := color.New(color.FgGreen) - stderrColor := color.New(color.FgRed) + stdoutColor, stderrColor := color.New(color.FgGreen), color.New(color.FgRed) for _, command := range commands { - // Create a session session, err := sshConn.client.NewSession() if err != nil { - fmt.Println("Failed to create session: %s", err) - return err + return fmt.Errorf("failed to create SSH session: %w", err) } + defer session.Close() + // Set up pipes for stdout and stderr stdout, err := session.StdoutPipe() if err != nil { - return err + return fmt.Errorf("failed to get stdout pipe: %w", err) } + go scanAndPrint(stdout, stdoutColor, sshConn.podId) + stderr, err := session.StderrPipe() if err != nil { - return err + return fmt.Errorf("failed to get stderr pipe: %w", err) } + go scanAndPrint(stderr, stderrColor, sshConn.podId) - //listen to stdout - go func() { - scanner := bufio.NewScanner(stdout) - for scanner.Scan() { - if showPrefixInPodLogs { - stdoutColor.Printf("[%s] ", sshConn.podId) - } - fmt.Println(scanner.Text()) - } - }() - - //listen to stderr - go func() { - scanner := bufio.NewScanner(stderr) - for scanner.Scan() { - if showPrefixInPodLogs { - stderrColor.Printf("[%s] ", sshConn.podId) - } - fmt.Println(scanner.Text()) - } - }() + // Run the command fullCommand := strings.Join([]string{ "source /root/.bashrc", "source /etc/rp_environment", "while IFS= read -r -d '' line; do export \"$line\"; done < /proc/1/environ", command, }, " && ") - err = session.Run(fullCommand) - if err != nil { - session.Close() - return err + + if err := session.Run(fullCommand); err != nil { + return fmt.Errorf("failed to run command %q: %w", command, err) } - session.Close() } return nil } +// Utility function to scan and print output from SSH sessions. +func scanAndPrint(pipe io.Reader, color *color.Color, podID string) { + scanner := bufio.NewScanner(pipe) + for scanner.Scan() { + color.Printf("[%s] %s\n", podID, scanner.Text()) + } +} + func PodSSHConnection(podId string) (*SSHConnection, error) { - //check ssh key exists - home, _ := os.UserHomeDir() - sshFilePath := filepath.Join(home, ".runpod", "ssh", "RunPod-Key-Go") - privateKeyBytes, err := os.ReadFile(sshFilePath) + homeDir, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("getting user home directory: %w", err) + } + + sshKeyPath := filepath.Join(homeDir, ".runpod", "ssh", "RunPod-Key-Go") + privateKeyBytes, err := os.ReadFile(sshKeyPath) if err != nil { - fmt.Println("failed to get private key") - return nil, err + return nil, fmt.Errorf("reading private SSH key from %s: %w", sshKeyPath, err) } + privateKey, err := ssh.ParsePrivateKey(privateKeyBytes) if err != nil { - fmt.Println("failed to parse private key") - return nil, err + return nil, fmt.Errorf("parsing private SSH key: %w", err) } + //loop until pod ready - pollIntervalSeconds := 1 - maxPollTimeSeconds := 300 - startTime := time.Now() + fmt.Print("Waiting for Pod to come online... ") //look up ip and ssh port for pod id var podIp string var podPort int - for podIp, podPort, err = getPodSSHInfo(podId); err != nil && time.Since(startTime) < time.Duration(maxPollTimeSeconds*int(time.Second)); { - time.Sleep(time.Duration(pollIntervalSeconds * int(time.Second))) + + startTime := time.Now() + for podIp, podPort, err = getPodSSHInfo(podId); err != nil && time.Since(startTime) < maxPollTime; { + time.Sleep(pollInterval) podIp, podPort, err = getPodSSHInfo(podId) } + if err != nil { + return nil, fmt.Errorf("failed to get SSH info for pod %s: %w", podId, err) + } else if time.Since(startTime) >= time.Duration(maxPollTime) { + return nil, fmt.Errorf("timeout waiting for pod %s to come online", podId) + } + // Configure the SSH client config := &ssh.ClientConfig{ User: "root", @@ -246,10 +248,9 @@ func PodSSHConnection(podId string) (*SSHConnection, error) { host := fmt.Sprintf("%s:%d", podIp, podPort) client, err := ssh.Dial("tcp", host, config) if err != nil { - fmt.Println("Failed to dial for SSH conn: %s", err) - return nil, err + return nil, fmt.Errorf("establishing SSH connection to %s: %w", host, err) } - return &SSHConnection{podId: podId, client: client, podIp: podIp, podPort: podPort, sshKeyPath: sshFilePath}, nil + return &SSHConnection{podId: podId, client: client, podIp: podIp, podPort: podPort, sshKeyPath: sshKeyPath}, nil } diff --git a/cmd/root.go b/cmd/root.go index d2e1199..7dbd587 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -45,6 +45,9 @@ func registerCommands() { rootCmd.AddCommand(updateCmd) rootCmd.AddCommand(sshCmd) + // Remote File Execution + rootCmd.AddCommand(execCmd) + // file transfer via croc rootCmd.AddCommand(croc.ReceiveCmd) rootCmd.AddCommand(croc.SendCmd)