diff --git a/main.go b/main.go index c1101f59..3cacb60a 100644 --- a/main.go +++ b/main.go @@ -3,6 +3,8 @@ package main import ( "context" "flag" + "os" + "github.com/hobbyfarm/gargantua/v3/pkg/accesscode" "github.com/hobbyfarm/gargantua/v3/pkg/authserver" hfClientset "github.com/hobbyfarm/gargantua/v3/pkg/client/clientset/versioned" @@ -34,7 +36,6 @@ import ( "github.com/hobbyfarm/gargantua/v3/protos/authn" "github.com/hobbyfarm/gargantua/v3/protos/authr" "github.com/hobbyfarm/gargantua/v3/protos/setting" - "os" "github.com/ebauman/crder" "golang.org/x/sync/errgroup" diff --git a/v3/pkg/apis/hobbyfarm.io/v1/types.go b/v3/pkg/apis/hobbyfarm.io/v1/types.go index 28c7e88a..f460df0a 100644 --- a/v3/pkg/apis/hobbyfarm.io/v1/types.go +++ b/v3/pkg/apis/hobbyfarm.io/v1/types.go @@ -278,6 +278,7 @@ type ScenarioSpec struct { KeepAliveDuration string `json:"keepalive_duration"` PauseDuration string `json:"pause_duration"` Pauseable bool `json:"pauseable"` + Tasks []VirtualMachineTasks `json:"vm_tasks"` } type ScenarioStep struct { @@ -285,6 +286,19 @@ type ScenarioStep struct { Content string `json:"content"` } +type VirtualMachineTasks struct { + VMName string `json:"vm_name"` + Tasks []Task `json:"tasks"` +} +type Task struct { + Name string `json:"name"` + Description string `json:"description"` + Command string `json:"command"` + ExpectedOutputValue string `json:"expected_output_value"` + ExpectedReturnCode int `json:"expected_return_code"` + ReturnType string `json:"return_type"` +} + // +genclient // +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object diff --git a/v3/pkg/apis/hobbyfarm.io/v1/zz_generated.deepcopy.go b/v3/pkg/apis/hobbyfarm.io/v1/zz_generated.deepcopy.go index fc9b4937..d317fbad 100644 --- a/v3/pkg/apis/hobbyfarm.io/v1/zz_generated.deepcopy.go +++ b/v3/pkg/apis/hobbyfarm.io/v1/zz_generated.deepcopy.go @@ -731,6 +731,13 @@ func (in *ScenarioSpec) DeepCopyInto(out *ScenarioSpec) { } } } + if in.Tasks != nil { + in, out := &in.Tasks, &out.Tasks + *out = make([]VirtualMachineTasks, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } return } @@ -1118,6 +1125,22 @@ func (in *SettingList) DeepCopyObject() runtime.Object { return nil } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Task) DeepCopyInto(out *Task) { + *out = *in + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Task. +func (in *Task) DeepCopy() *Task { + if in == nil { + return nil + } + out := new(Task) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *User) DeepCopyInto(out *User) { *out = *in @@ -1529,6 +1552,27 @@ func (in *VirtualMachineStatus) DeepCopy() *VirtualMachineStatus { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *VirtualMachineTasks) DeepCopyInto(out *VirtualMachineTasks) { + *out = *in + if in.Tasks != nil { + in, out := &in.Tasks, &out.Tasks + *out = make([]Task, len(*in)) + copy(*out, *in) + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new VirtualMachineTasks. +func (in *VirtualMachineTasks) DeepCopy() *VirtualMachineTasks { + if in == nil { + return nil + } + out := new(VirtualMachineTasks) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *VirtualMachineTemplate) DeepCopyInto(out *VirtualMachineTemplate) { *out = *in diff --git a/v3/pkg/crd/crd.go b/v3/pkg/crd/crd.go index 1514ec20..27090c73 100644 --- a/v3/pkg/crd/crd.go +++ b/v3/pkg/crd/crd.go @@ -2,7 +2,7 @@ package crd import ( "github.com/ebauman/crder" - "github.com/hobbyfarm/gargantua/v3/pkg/apis/hobbyfarm.io/v1" + v1 "github.com/hobbyfarm/gargantua/v3/pkg/apis/hobbyfarm.io/v1" terraformv1 "github.com/hobbyfarm/gargantua/v3/pkg/apis/terraformcontroller.cattle.io/v1" ) diff --git a/v3/pkg/scenarioserver/scenarioserver.go b/v3/pkg/scenarioserver/scenarioserver.go index dd20aebc..ba3ad0f2 100644 --- a/v3/pkg/scenarioserver/scenarioserver.go +++ b/v3/pkg/scenarioserver/scenarioserver.go @@ -7,6 +7,11 @@ import ( "encoding/base64" "encoding/json" "fmt" + "net/http" + "slices" + "strconv" + "strings" + "github.com/hobbyfarm/gargantua/v3/pkg/accesscode" hfv1 "github.com/hobbyfarm/gargantua/v3/pkg/apis/hobbyfarm.io/v1" hfClientset "github.com/hobbyfarm/gargantua/v3/pkg/client/clientset/versioned" @@ -14,10 +19,6 @@ import ( "github.com/hobbyfarm/gargantua/v3/pkg/courseclient" rbac2 "github.com/hobbyfarm/gargantua/v3/pkg/rbac" "github.com/hobbyfarm/gargantua/v3/pkg/util" - "net/http" - "slices" - "strconv" - "strings" "github.com/hobbyfarm/gargantua/v3/protos/authn" "github.com/hobbyfarm/gargantua/v3/protos/authr" @@ -50,13 +51,14 @@ type PreparedScenarioStep struct { } type PreparedScenario struct { - Id string `json:"id"` - Name string `json:"name"` - Description string `json:"description"` - StepCount int `json:"stepcount"` - VirtualMachines []map[string]string `json:"virtualmachines"` - Pauseable bool `json:"pauseable"` - Printable bool `json:"printable"` + Id string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + StepCount int `json:"stepcount"` + VirtualMachines []map[string]string `json:"virtualmachines"` + Pauseable bool `json:"pauseable"` + Printable bool `json:"printable"` + Tasks []hfv1.VirtualMachineTasks `json:"vm_tasks"` } type AdminPreparedScenario struct { @@ -111,7 +113,7 @@ func (s ScenarioServer) prepareScenario(scenario hfv1.Scenario, printable bool) ps.Pauseable = scenario.Spec.Pauseable ps.Printable = printable ps.StepCount = len(scenario.Spec.Steps) - + ps.Tasks = scenario.Spec.Tasks return ps, nil } @@ -723,6 +725,31 @@ func (s ScenarioServer) CopyFunc(w http.ResponseWriter, r *http.Request) { return } +func VerifyTaskContent(vm_tasks []hfv1.VirtualMachineTasks) error { + //Verify that name, description, command must not empty + for _, vm_task := range vm_tasks { + if vm_task.VMName == "" { + glog.Errorf("error while vm_name empty") + return fmt.Errorf("bad") + } + for _, task := range vm_task.Tasks { + if task.Name == "" { + glog.Errorf("error while Name of task empty") + return fmt.Errorf("bad") + } + if task.Description == "" { + glog.Errorf("error while Description of task empty") + return fmt.Errorf("bad") + } + if task.Command == "" || task.Command == "[]" { + glog.Errorf("error while Command of task empty") + return fmt.Errorf("bad") + } + } + } + return nil +} + func (s ScenarioServer) CreateFunc(w http.ResponseWriter, r *http.Request) { user, err := rbac2.AuthenticateRequest(r, s.authnClient) if err != nil { @@ -813,6 +840,22 @@ func (s ScenarioServer) CreateFunc(w http.ResponseWriter, r *http.Request) { scenario.Spec.Categories = categories scenario.Spec.Tags = tags scenario.Spec.KeepAliveDuration = keepaliveDuration + rawVMTasks := r.PostFormValue("vm_tasks") + if rawVMTasks != "" { + vm_tasks := []hfv1.VirtualMachineTasks{} + + err = json.Unmarshal([]byte(rawVMTasks), &vm_tasks) + if err != nil { + glog.Errorf("error while unmarshaling tasks %v", err) + return + } + err = VerifyTaskContent(vm_tasks) + if err != nil { + glog.Errorf("error tasks content %v", err) + return + } + scenario.Spec.Tasks = vm_tasks + } scenario.Spec.Pauseable = false if pauseable != "" { @@ -875,6 +918,7 @@ func (s ScenarioServer) UpdateFunc(w http.ResponseWriter, r *http.Request) { rawVirtualMachines := r.PostFormValue("virtualmachines") rawCategories := r.PostFormValue("categories") rawTags := r.PostFormValue("tags") + rawVMTasks := r.PostFormValue("vm_tasks") if name != "" { scenario.Spec.Name = name @@ -956,6 +1000,23 @@ func (s ScenarioServer) UpdateFunc(w http.ResponseWriter, r *http.Request) { scenario.Spec.Tags = tagsSlice } + if rawVMTasks != "" { + vm_tasks := []hfv1.VirtualMachineTasks{} + + err = json.Unmarshal([]byte(rawVMTasks), &vm_tasks) + if err != nil { + glog.Errorf("error while unmarshaling tasks %v", err) + return fmt.Errorf("bad") + } + + err = VerifyTaskContent(vm_tasks) + if err != nil { + glog.Errorf("error tasks content %v", err) + return err + } + scenario.Spec.Tasks = vm_tasks + } + _, updateErr := s.hfClientSet.HobbyfarmV1().Scenarios(util.GetReleaseNamespace()).Update(s.ctx, scenario, metav1.UpdateOptions{}) return updateErr }) diff --git a/v3/pkg/shell/shell.go b/v3/pkg/shell/shell.go index 522f9334..a7533090 100644 --- a/v3/pkg/shell/shell.go +++ b/v3/pkg/shell/shell.go @@ -4,10 +4,6 @@ import ( "context" "encoding/json" "fmt" - hfClientset "github.com/hobbyfarm/gargantua/v3/pkg/client/clientset/versioned" - rbac2 "github.com/hobbyfarm/gargantua/v3/pkg/rbac" - "github.com/hobbyfarm/gargantua/v3/pkg/util" - "github.com/hobbyfarm/gargantua/v3/pkg/vmclient" "io" "net/http" "net/http/httputil" @@ -16,15 +12,23 @@ import ( "regexp" "strconv" "strings" + "sync" "time" + hfClientset "github.com/hobbyfarm/gargantua/v3/pkg/client/clientset/versioned" + rbac2 "github.com/hobbyfarm/gargantua/v3/pkg/rbac" + "github.com/hobbyfarm/gargantua/v3/pkg/util" + "github.com/hobbyfarm/gargantua/v3/pkg/vmclient" + "github.com/golang/glog" "github.com/gorilla/mux" "github.com/gorilla/websocket" + hfv1 "github.com/hobbyfarm/gargantua/v3/pkg/apis/hobbyfarm.io/v1" "github.com/hobbyfarm/gargantua/v3/protos/authn" "github.com/hobbyfarm/gargantua/v3/protos/authr" userProto "github.com/hobbyfarm/gargantua/v3/protos/user" "golang.org/x/crypto/ssh" + "golang.org/x/sync/semaphore" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" ) @@ -89,6 +93,7 @@ func NewShellProxy(authnClient authn.AuthNClient, authrClient authr.AuthRClient, func (sp ShellProxy) SetupRoutes(r *mux.Router) { r.HandleFunc("/shell/{vm_id}/connect", sp.ConnectSSHFunc) + r.HandleFunc("/shell/verify", sp.VerifyTasksFuncByVMIdGroupWithSemaphore) r.HandleFunc("/guacShell/{vm_id}/connect", sp.ConnectGuacFunc) r.HandleFunc("/p/{vm_id}/{port}/{rest:.*}", sp.checkCookieAndProxy) r.HandleFunc("/pa/{token}/{vm_id}/{port}/{rest:.*}", sp.authAndProxyFunc) @@ -489,6 +494,338 @@ func copyResponse(rw http.ResponseWriter, resp *http.Response) error { return err } +type VirtualMachineInputTask struct { + VMId string `json:"vm_id"` + VMName string `json:"vm_name"` + Tasks []hfv1.Task `json:"tasks"` +} + +type VirtualMachineOutputTask struct { + VMId string `json:"vm_id"` + VMName string `json:"vm_name"` + TaskOutputs []TaskWithOutput `json:"task_outputs"` +} + +type TaskOutputCommand struct { + ActualOutputValue string `json:"actual_output_value"` + ActualReturnCode int `json:"actual_return_code"` + Success bool `json:"success"` +} + +type TaskWithOutput struct { + Task hfv1.Task `json:"task"` + TaskOutput TaskOutputCommand `json:"task_output"` + Error string `json:"error"` +} + +func isMatchRegex(text, pattern string) bool { + re := regexp.MustCompile(pattern) + return re.MatchString(text) +} + +/* +Function executes a command on a remote server session and checks for success based on the provided task command. +It returns TaskOutputCommand struct containing the actual output value, return code, and success status, depen of ReturnType. +*/ +func VMTaskCommandRun(task_cmd *hfv1.Task, sess *ssh.Session) (*TaskOutputCommand, error) { + out, err := sess.CombinedOutput(task_cmd.Command) + actual_output_value := strings.TrimRight(string(out), "\r\n") + actual_return_code := 0 + if err != nil { + switch err.(type) { + case *ssh.ExitError: + actual_return_code = err.(*ssh.ExitError).ExitStatus() + glog.Infof("%v", actual_return_code) + default: + return nil, err + } + } + + is_task_success := false + switch task_cmd.ReturnType { + case "Return_Code_And_Text": + is_task_success = task_cmd.ExpectedOutputValue == actual_output_value && task_cmd.ExpectedReturnCode == actual_return_code + break + case "Return_Code": + is_task_success = task_cmd.ExpectedReturnCode == actual_return_code + break + case "Return_Text": + is_task_success = task_cmd.ExpectedOutputValue == actual_output_value + break + case "Match_Regex": + if !isMatchRegex(actual_output_value, task_cmd.ExpectedOutputValue) { + actual_output_value = "regex:error" + } + is_task_success = actual_output_value != "regex:error" + break + default: + actual_output_value = "undefined ReturnType" + is_task_success = false + } + + task_cmd_res := &TaskOutputCommand{ + ActualOutputValue: actual_output_value, + ActualReturnCode: actual_return_code, + Success: is_task_success, + } + return task_cmd_res, nil +} + +/* +Function retrieves output tasks from a virtual machine by executing multiple commands concurrently on the SSH connection. +It takes an SSH client connection, a VirtualMachineInputTask representing input tasks for the VM, +and an error channel to report any errors encountered during execution. +It returns a VirtualMachineOutputTask containing the output of the executed tasks, along with any errors encountered. +*/ +func GetVMOutputTask(sshConn *ssh.Client, closure_vm_input_task VirtualMachineInputTask, errorChan chan<- error) (*VirtualMachineOutputTask, error) { + // TODO: settings for define max command go routine run in same time in VM + const MAX_COMMANDS_GO = 3 + // Initialize slice to store task outputs + commands_resp := make([]TaskWithOutput, 0) + // Mutex for synchronizing access to commands_resp slice + var commands_mutex = &sync.Mutex{} + // WaitGroup to wait for all goroutines in VM to finish + var commands_wg sync.WaitGroup + // Semaphore for count goroutine run in same time in VM + // a context is required for the weighted semaphore pkg. + ctx := context.Background() + var commands_semaphore = semaphore.NewWeighted(int64(MAX_COMMANDS_GO)) + + for _, task_command := range closure_vm_input_task.Tasks { + commands_wg.Add(1) + if err := commands_semaphore.Acquire(ctx, 1); err != nil { + glog.Errorf("did not acquire vm_semafore") + } + go func(closure_task_command hfv1.Task) { + defer commands_wg.Done() + defer commands_semaphore.Release(1) + vm_task_with_output, _ := GetTaskWithOutput(sshConn, errorChan, closure_task_command) + commands_mutex.Lock() + commands_resp = append(commands_resp, *vm_task_with_output) + commands_mutex.Unlock() + + }(task_command) + } + commands_wg.Wait() + vm_output_task := &VirtualMachineOutputTask{ + VMId: closure_vm_input_task.VMId, + VMName: closure_vm_input_task.VMName, + TaskOutputs: commands_resp, + } + return vm_output_task, nil +} + +/* +Function executes a task command on the SSH connection with number of attempts MAX_TRY_COMMAND_RUN +to retrieve output when the task command execute with return code 141. +*/ +func GetTaskWithOutput(sshConn *ssh.Client, errorChan chan<- error, task_command hfv1.Task) (*TaskWithOutput, error) { + // TODO: settings for define max try command run in VM if return code 141 + const MAX_TRY_COMMAND_RUN = 5 + count_try_command_run := MAX_TRY_COMMAND_RUN + var errRun error + // try command run again when exit code == 141 + for count_try_command_run > 0 { + task_output, err := GetOutputTask(sshConn, errorChan, task_command) + count_try_command_run -= 1 + if task_output.ActualReturnCode != 141 { + vm_task_with_output := &TaskWithOutput{ + Task: task_command, + TaskOutput: *task_output, + } + return vm_task_with_output, nil + } + if count_try_command_run == 0 { + glog.Errorf("error try run command: %v", err) + vm_task_with_output := &TaskWithOutput{ + Task: task_command, + Error: "error try run command", + } + return vm_task_with_output, err + } + errRun = err + } + return nil, errRun +} + +func GetOutputTask(sshConn *ssh.Client, errorChan chan<- error, closure_task_command hfv1.Task) (*TaskOutputCommand, error) { + sess, err := CreateNewSession(sshConn, errorChan) + if err != nil { + return nil, err + } + task_output, err := VMTaskCommandRun(&closure_task_command, sess) + if err != nil { + glog.Infof("error sending command: %v", err) + if len(errorChan) < cap(errorChan) { + errorChan <- err + } + return nil, err + } + sess.Close() + return task_output, nil +} + +func CreateNewSession(sshConn *ssh.Client, errorChan chan<- error) (*ssh.Session, error) { + sess, err := sshConn.NewSession() + if err != nil { + glog.Errorf("did not setup ssh session properly") + if len(errorChan) < cap(errorChan) { + errorChan <- err + } + return nil, err + } + return sess, nil +} + +func (sp ShellProxy) GetSSHConn(w http.ResponseWriter, r *http.Request, user *userProto.User, vmId string, errorChan chan<- error) (*ssh.Client, error) { + + vm, err := sp.vmClient.GetVirtualMachineById(vmId) + + if err != nil { + glog.Errorf("did not find the right virtual machine ID") + if len(errorChan) < cap(errorChan) { + errorChan <- err + } + return nil, err + } + if vm.Spec.UserId != user.GetId() { + // check if the user has access to access user sessions + // TODO: add permission like 'virtualmachine/shell' similar to 'pod/exec' + impersonatedUserId := user.GetId() + authrResponse, err := rbac2.Authorize(r, sp.authrClient, impersonatedUserId, []*authr.Permission{ + rbac2.HobbyfarmPermission(rbac2.ResourcePluralUser, rbac2.VerbGet), + rbac2.HobbyfarmPermission(rbac2.ResourcePluralSession, rbac2.VerbGet), + rbac2.HobbyfarmPermission(rbac2.ResourcePluralVM, rbac2.VerbGet), + }, rbac2.OperatorAND) + if err != nil || !authrResponse.Success { + glog.Infof("Error doing authGrantWS %s", err) + util.ReturnHTTPMessage(w, r, 403, "forbidden", "access denied to connect to ssh shell session") + return nil, err + } + } + + // ok first get the secret for the vm + secret, err := sp.kubeClient.CoreV1().Secrets(util.GetReleaseNamespace()).Get(sp.ctx, vm.Spec.SecretName, v1.GetOptions{}) // idk? + if err != nil { + glog.Errorf("did not find secret for virtual machine") + util.ReturnHTTPMessage(w, r, 500, "error", "unable to find keypair secret for vm") + return nil, err + } + + // parse the private key + signer, err := ssh.ParsePrivateKey(secret.Data["private_key"]) + if err != nil { + glog.Errorf("did not correctly parse private key") + util.ReturnHTTPMessage(w, r, 500, "error", "unable to parse private key") + return nil, err + } + + sshUsername := vm.Spec.SshUsername + if len(sshUsername) < 1 { + sshUsername = defaultSshUsername + } + + // now use the secret and ssh off to something + config := &ssh.ClientConfig{ + User: sshUsername, + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(signer), + }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + + // get the host and port + host, ok := vm.Annotations["sshEndpoint"] + if !ok { + host = vm.Status.PublicIP + } + port := "22" + if sshDev == "true" { + if sshDevHost != "" { + host = sshDevHost + } + if sshDevPort != "" { + port = sshDevPort + } + } + + // dial the instance + sshConn, err := ssh.Dial("tcp", host+":"+port, config) + if err != nil { + glog.Errorf("did not connect ssh successfully: %s", err) + if len(errorChan) < cap(errorChan) { + errorChan <- err + } + return nil, err + } + return sshConn, err +} + +/* +Function handles the HTTP request to verify tasks for a group of virtual machines using a semaphore for concurrency control. +It authenticates the request, decodes the incoming JSON payload containing VirtualMachineInputTasks, +and executes the tasks concurrently on the corresponding virtual machines. +*/ +func (sp ShellProxy) VerifyTasksFuncByVMIdGroupWithSemaphore(w http.ResponseWriter, r *http.Request) { + user, err := rbac2.AuthenticateRequest(r, sp.authnClient) + if err != nil { + util.ReturnHTTPMessage(w, r, 403, "forbidden", "no access to get vm") + return + } + + // Decode the incoming JSON payload containing VirtualMachineInputTasks + var vm_input_tasks []VirtualMachineInputTask + err = json.NewDecoder(r.Body).Decode(&vm_input_tasks) + if err != nil { + glog.Infof("%s", err) + } + + // Create an error channel to report errors encountered during task execution + errorChan := make(chan error, 1) + + // Initialize slice to store the output tasks for each virtual machine + vm_output_tasks := make([]VirtualMachineOutputTask, 0) + // Mutex for synchronizing access to vm_output_tasks slice + var vm_mutex = &sync.Mutex{} + // WaitGroup to wait for all goroutines of VMs to finish + var vm_wg sync.WaitGroup + + for _, vm_input_task := range vm_input_tasks { + vm_wg.Add(1) + go func(closure_vm_input_task VirtualMachineInputTask) { + defer vm_wg.Done() + + sshConn, err := sp.GetSSHConn(w, r, user, closure_vm_input_task.VMId, errorChan) + if err != nil { + return + } + vm_output_task, err := GetVMOutputTask(sshConn, closure_vm_input_task, errorChan) + if err != nil { + return + } + vm_mutex.Lock() + vm_output_tasks = append(vm_output_tasks, *vm_output_task) + vm_mutex.Unlock() + }(vm_input_task) + } + vm_wg.Wait() + + // Check for errors in the errorChan + select { + case err = <-errorChan: + // Handle the error (log, return HTTP error response) + close(errorChan) + glog.Infof("Error in goroutine: %v", err) + util.ReturnHTTPMessage(w, r, 500, "error", "could send command to vm") + return + default: + // No error in the errorChan + glog.Infof("No Error in goroutine: %v", vm_output_tasks) + jsonStr, _ := json.Marshal(vm_output_tasks) + util.ReturnHTTPContent(w, r, 200, "success", jsonStr) + } +} + /* * This is mainly used for SSH Connections to VMs */