diff --git a/charts/kaito/workspace/templates/clusterrole.yaml b/charts/kaito/workspace/templates/clusterrole.yaml index 742c1bbed..5e24c46e1 100644 --- a/charts/kaito/workspace/templates/clusterrole.yaml +++ b/charts/kaito/workspace/templates/clusterrole.yaml @@ -30,6 +30,9 @@ rules: - apiGroups: [ "apps" ] resources: ["deployments" ] verbs: ["get","list","watch","create", "delete","update", "patch"] + - apiGroups: [ "apps" ] + resources: ["controllerrevisions" ] + verbs: [ "get","list","watch","create", "delete","update", "patch"] - apiGroups: [ "apps" ] resources: [ "statefulsets" ] verbs: [ "get","list","watch","create", "delete","update", "patch" ] diff --git a/cmd/main.go b/cmd/main.go index 36adac881..e4b062626 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -26,6 +26,7 @@ import ( "knative.dev/pkg/injection/sharedmain" "knative.dev/pkg/signals" "knative.dev/pkg/webhook" + // Import all Kubernetes client auth plugins (e.g. Azure, GCP, OIDC, etc.) // to ensure that exec-entrypoint and run can make use of them. _ "k8s.io/client-go/plugin/pkg/client/auth" @@ -116,12 +117,10 @@ func main() { k8sclient.SetGlobalClient(mgr.GetClient()) kClient := k8sclient.GetGlobalClient() - if err = (&controllers.WorkspaceReconciler{ - Client: kClient, - Log: log.Log.WithName("controllers").WithName("Workspace"), - Scheme: mgr.GetScheme(), - Recorder: mgr.GetEventRecorderFor("KAITO-Workspace-controller"), - }).SetupWithManager(mgr); err != nil { + workspaceReconciler := controllers.NewWorkspaceReconciler(k8sclient.GetGlobalClient(), + mgr.GetScheme(), log.Log.WithName("controllers").WithName("Workspace"), mgr.GetEventRecorderFor("KAITO-Workspace-controller")) + + if err = workspaceReconciler.SetupWithManager(mgr); err != nil { klog.ErrorS(err, "unable to create controller", "controller", "Workspace") exitWithErrorFunc() } diff --git a/pkg/controllers/workspace_controller.go b/pkg/controllers/workspace_controller.go index cd5f63920..3441079e9 100644 --- a/pkg/controllers/workspace_controller.go +++ b/pkg/controllers/workspace_controller.go @@ -5,6 +5,9 @@ package controllers import ( "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" "fmt" "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" "sort" @@ -46,8 +49,10 @@ import ( ) const ( - gpuSkuPrefix = "Standard_N" - nodePluginInstallTimeout = 60 * time.Second + gpuSkuPrefix = "Standard_N" + nodePluginInstallTimeout = 60 * time.Second + WorkspaceRevisionAnnotation = "workspace.kaito.io/revision" + WorkspaceNameLabel = "workspace.kaito.io/name" ) type WorkspaceReconciler struct { @@ -57,6 +62,15 @@ type WorkspaceReconciler struct { Recorder record.EventRecorder } +func NewWorkspaceReconciler(client client.Client, scheme *runtime.Scheme, log logr.Logger, Recorder record.EventRecorder) *WorkspaceReconciler { + return &WorkspaceReconciler{ + Client: client, + Scheme: scheme, + Log: log, + Recorder: Recorder, + } +} + func (c *WorkspaceReconciler) Reconcile(ctx context.Context, req reconcile.Request) (reconcile.Result, error) { workspaceObj := &kaitov1alpha1.Workspace{} if err := c.Client.Get(ctx, req.NamespacedName, workspaceObj); err != nil { @@ -83,7 +97,17 @@ func (c *WorkspaceReconciler) Reconcile(ctx context.Context, req reconcile.Reque } } - return c.addOrUpdateWorkspace(ctx, workspaceObj) + result, err := c.addOrUpdateWorkspace(ctx, workspaceObj) + if err != nil { + return result, err + } + + if err := c.updateControllerRevision(ctx, workspaceObj); err != nil { + klog.ErrorS(err, "failed to update ControllerRevision", "workspace", klog.KObj(workspaceObj)) + return reconcile.Result{}, nil + } + + return result, nil } func (c *WorkspaceReconciler) ensureFinalizer(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace) error { @@ -180,6 +204,80 @@ func (c *WorkspaceReconciler) deleteWorkspace(ctx context.Context, wObj *kaitov1 return c.garbageCollectWorkspace(ctx, wObj) } +func (c *WorkspaceReconciler) updateControllerRevision(ctx context.Context, wObj *kaitov1alpha1.Workspace) error { + currentHash := computeHash(wObj) + annotations := wObj.GetAnnotations() + + latestHash, exists := annotations[WorkspaceRevisionAnnotation] + if exists && latestHash == currentHash { + return nil + } + + data := map[string]string{"hash": currentHash} + jsonData, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal revision data: %w", err) + } + + revisions := &appsv1.ControllerRevisionList{} + if err := c.List(ctx, revisions, client.InNamespace(wObj.Namespace), client.MatchingLabels{WorkspaceNameLabel: wObj.Name}); err != nil { + return fmt.Errorf("failed to list revisions: %w", err) + } + sort.Slice(revisions.Items, func(i, j int) bool { + return revisions.Items[i].Revision < revisions.Items[j].Revision + }) + + revisionNum := int64(1) + + if len(revisions.Items) > 0 { + revisionNum = revisions.Items[len(revisions.Items)-1].Revision + 1 + } + + newRevision := &appsv1.ControllerRevision{ + ObjectMeta: metav1.ObjectMeta{ + Name: fmt.Sprintf("%s-%s", wObj.Name, currentHash[:8]), + Namespace: wObj.Namespace, + Labels: map[string]string{ + WorkspaceNameLabel: wObj.Name, + }, + Annotations: map[string]string{ + WorkspaceRevisionAnnotation: currentHash, + }, + OwnerReferences: []metav1.OwnerReference{ + *metav1.NewControllerRef(wObj, kaitov1alpha1.GroupVersion.WithKind("Workspace")), + }, + }, + Revision: revisionNum, + Data: runtime.RawExtension{Raw: jsonData}, + } + + annotations[WorkspaceRevisionAnnotation] = currentHash + wObj.SetAnnotations(annotations) + if err := c.Update(ctx, wObj); err != nil { + return fmt.Errorf("failed to update Workspace annotations: %w", err) + } + + if err := c.Create(ctx, newRevision); err != nil { + return fmt.Errorf("failed to create new ControllerRevision: %w", err) + } + + if len(revisions.Items) > consts.MaxRevisionHistoryLimit { + if err := c.Delete(ctx, &revisions.Items[0]); err != nil { + return fmt.Errorf("failed to delete old revision: %w", err) + } + } + return nil +} + +func computeHash(w *kaitov1alpha1.Workspace) string { + hasher := sha256.New() + encoder := json.NewEncoder(hasher) + encoder.Encode(w.Resource) + encoder.Encode(w.Inference) + encoder.Encode(w.Tuning) + return hex.EncodeToString(hasher.Sum(nil)) +} + func (c *WorkspaceReconciler) selectWorkspaceNodes(qualified []*corev1.Node, preferred []string, previous []string, count int) []*corev1.Node { sort.Slice(qualified, func(i, j int) bool { @@ -679,6 +777,7 @@ func (c *WorkspaceReconciler) SetupWithManager(mgr ctrl.Manager) error { builder := ctrl.NewControllerManagedBy(mgr). For(&kaitov1alpha1.Workspace{}). + Owns(&appsv1.ControllerRevision{}). Owns(&appsv1.Deployment{}). Owns(&appsv1.StatefulSet{}). Owns(&batchv1.Job{}). diff --git a/pkg/controllers/workspace_controller_test.go b/pkg/controllers/workspace_controller_test.go index b132da54a..a86078a06 100644 --- a/pkg/controllers/workspace_controller_test.go +++ b/pkg/controllers/workspace_controller_test.go @@ -6,6 +6,7 @@ package controllers import ( "context" "errors" + "fmt" "os" "reflect" "sort" @@ -954,3 +955,112 @@ func TestApplyWorkspaceResource(t *testing.T) { }) } } + +func TestUpdateControllerRevision(t *testing.T) { + testcases := map[string]struct { + callMocks func(c *test.MockClient) + workspace v1alpha1.Workspace + expectedError error + verifyCalls func(c *test.MockClient) + }{ + "No new revision needed": { + callMocks: func(c *test.MockClient) { + c.On("List", mock.IsType(context.Background()), mock.Anything, mock.IsType(&appsv1.ControllerRevisionList{}), mock.Anything).Return(errors.New("should not be called")) + }, + workspace: test.MockWorkspaceWithComputeHash, + expectedError: nil, + verifyCalls: func(c *test.MockClient) { + c.AssertNumberOfCalls(t, "List", 0) + c.AssertNumberOfCalls(t, "Create", 0) + c.AssertNumberOfCalls(t, "Update", 0) + c.AssertNumberOfCalls(t, "Delete", 0) + }, + }, + "Fail to create ControllerRevision": { + callMocks: func(c *test.MockClient) { + c.On("List", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevisionList{}), mock.Anything).Return(nil) + c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(errors.New("failed to create ControllerRevision")) + c.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.Workspace{}), mock.Anything).Return(nil) + }, + workspace: test.MockWorkspaceFailToCreateCR, + expectedError: errors.New("failed to create new ControllerRevision: failed to create ControllerRevision"), + verifyCalls: func(c *test.MockClient) { + c.AssertNumberOfCalls(t, "List", 1) + c.AssertNumberOfCalls(t, "Create", 1) + c.AssertNumberOfCalls(t, "Update", 1) + c.AssertNumberOfCalls(t, "Delete", 0) + }, + }, + "Successfully create new ControllerRevision": { + callMocks: func(c *test.MockClient) { + c.On("List", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevisionList{}), mock.Anything).Return(nil) + c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(nil) + c.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.Workspace{}), mock.Anything).Return(nil) + }, + workspace: test.MockWorkspaceSuccessful, + expectedError: nil, + verifyCalls: func(c *test.MockClient) { + c.AssertNumberOfCalls(t, "List", 1) + c.AssertNumberOfCalls(t, "Create", 1) + c.AssertNumberOfCalls(t, "Update", 1) + c.AssertNumberOfCalls(t, "Delete", 0) + }, + }, + "Successfully delete old ControllerRevision": { + callMocks: func(c *test.MockClient) { + revisions := &appsv1.ControllerRevisionList{} + for i := 0; i <= consts.MaxRevisionHistoryLimit; i++ { + revision := &appsv1.ControllerRevision{ + ObjectMeta: v1.ObjectMeta{ + Name: fmt.Sprintf("revision-%d", i), + }, + Revision: int64(i), + } + revisions.Items = append(revisions.Items, *revision) + } + relevantMap := c.CreateMapWithType(revisions) + + for _, obj := range revisions.Items { + m := obj + objKey := client.ObjectKeyFromObject(&m) + relevantMap[objKey] = &m + } + + c.On("List", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevisionList{}), mock.Anything).Return(nil) + c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(nil) + c.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.Workspace{}), mock.Anything).Return(nil) + c.On("Delete", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(nil) + }, + workspace: test.MockWorkspaceWithDeleteOldCR, + expectedError: nil, + verifyCalls: func(c *test.MockClient) { + c.AssertNumberOfCalls(t, "List", 1) + c.AssertNumberOfCalls(t, "Create", 1) + c.AssertNumberOfCalls(t, "Update", 1) + c.AssertNumberOfCalls(t, "Delete", 1) + }, + }, + } + for k, tc := range testcases { + t.Run(k, func(t *testing.T) { + mockClient := test.NewClient() + tc.callMocks(mockClient) + + reconciler := &WorkspaceReconciler{ + Client: mockClient, + Scheme: test.NewTestScheme(), + } + ctx := context.Background() + + err := reconciler.updateControllerRevision(ctx, &tc.workspace) + if tc.expectedError == nil { + assert.Check(t, err == nil, "Not expected to return error") + } else { + assert.Equal(t, tc.expectedError.Error(), err.Error()) + } + if tc.verifyCalls != nil { + tc.verifyCalls(mockClient) + } + }) + } +} diff --git a/pkg/utils/consts/consts.go b/pkg/utils/consts/consts.go index d5b15c1ca..31ba780b7 100644 --- a/pkg/utils/consts/consts.go +++ b/pkg/utils/consts/consts.go @@ -12,4 +12,5 @@ const ( AWSCloudName = "aws" GPUString = "gpu" SKUString = "sku" + MaxRevisionHistoryLimit = 10 ) diff --git a/pkg/utils/test/mockClient.go b/pkg/utils/test/mockClient.go index 98ec00ec0..4d2b7da79 100644 --- a/pkg/utils/test/mockClient.go +++ b/pkg/utils/test/mockClient.go @@ -9,6 +9,7 @@ import ( "github.com/aws/karpenter-core/pkg/apis/v1alpha5" "github.com/stretchr/testify/mock" + appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/meta" "k8s.io/apimachinery/pkg/runtime" @@ -83,7 +84,6 @@ func (m *MockClient) Get(ctx context.Context, key types.NamespacedName, obj k8sC } func (m *MockClient) List(ctx context.Context, list k8sClient.ObjectList, opts ...k8sClient.ListOption) error { - v := reflect.ValueOf(list).Elem() newList := m.getObjectListFromMap(list) v.Set(reflect.ValueOf(newList).Elem()) @@ -121,6 +121,14 @@ func (m *MockClient) getObjectListFromMap(list k8sClient.ObjectList) k8sClient.O } } return nodeClaimList + case *appsv1.ControllerRevisionList: + controllerRevisionList := &appsv1.ControllerRevisionList{} + for _, obj := range relevantMap { + if m, ok := obj.(*appsv1.ControllerRevision); ok { + controllerRevisionList.Items = append(controllerRevisionList.Items, *m) + } + } + return controllerRevisionList } //add additional object lists as needed return nil diff --git a/pkg/utils/test/testUtils.go b/pkg/utils/test/testUtils.go index bf766482f..d6cabcab3 100644 --- a/pkg/utils/test/testUtils.go +++ b/pkg/utils/test/testUtils.go @@ -72,6 +72,114 @@ var ( } ) +var ( + MockWorkspaceWithDeleteOldCR = v1alpha1.Workspace{ + ObjectMeta: metav1.ObjectMeta{ + Name: "testWorkspace", + Namespace: "kaito", + Annotations: map[string]string{ + "workspace.kaito.io/revision": "1171dc5d15043c92e684c8f06689eb241763a735181fdd2b59c8bd8fd6eecdd4", + }, + }, + Resource: v1alpha1.ResourceSpec{ + Count: &gpuNodeCount, + InstanceType: "Standard_NC12s_v3", + LabelSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "workspace.kaito.io/name": "testWorkspace", + }, + }, + }, + Inference: &v1alpha1.InferenceSpec{ + Preset: &v1alpha1.PresetSpec{ + PresetMeta: v1alpha1.PresetMeta{ + Name: "test-model-DeleteOldCR", // presetMeta name is changed + }, + }, + }, + } +) + +var ( + MockWorkspaceFailToCreateCR = v1alpha1.Workspace{ + ObjectMeta: metav1.ObjectMeta{ + Name: "testWorkspace-failedtocreateCR", + Namespace: "kaito", + Annotations: map[string]string{}, + }, + Resource: v1alpha1.ResourceSpec{ + Count: &gpuNodeCount, + InstanceType: "Standard_NC12s_v3", + LabelSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "workspace.kaito.io/name": "testWorkspace", + }, + }, + }, + Inference: &v1alpha1.InferenceSpec{ + Preset: &v1alpha1.PresetSpec{ + PresetMeta: v1alpha1.PresetMeta{ + Name: "test-model", + }, + }, + }, + } +) + +var ( + MockWorkspaceSuccessful = v1alpha1.Workspace{ + ObjectMeta: metav1.ObjectMeta{ + Name: "testWorkspace-successful", + Namespace: "kaito", + Annotations: map[string]string{}, + }, + Resource: v1alpha1.ResourceSpec{ + Count: &gpuNodeCount, + InstanceType: "Standard_NC12s_v3", + LabelSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "workspace.kaito.io/name": "testWorkspace", + }, + }, + }, + Inference: &v1alpha1.InferenceSpec{ + Preset: &v1alpha1.PresetSpec{ + PresetMeta: v1alpha1.PresetMeta{ + Name: "test-model", + }, + }, + }, + } +) + +var ( + MockWorkspaceWithComputeHash = v1alpha1.Workspace{ + ObjectMeta: metav1.ObjectMeta{ + Name: "testWorkspace", + Namespace: "kaito", + Annotations: map[string]string{ + "workspace.kaito.io/revision": "1171dc5d15043c92e684c8f06689eb241763a735181fdd2b59c8bd8fd6eecdd4", + }, + }, + Resource: v1alpha1.ResourceSpec{ + Count: &gpuNodeCount, + InstanceType: "Standard_NC12s_v3", + LabelSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "workspace.kaito.io/name": "testWorkspace", + }, + }, + }, + Inference: &v1alpha1.InferenceSpec{ + Preset: &v1alpha1.PresetSpec{ + PresetMeta: v1alpha1.PresetMeta{ + Name: "test-model", + }, + }, + }, + } +) + var ( MockWorkspaceWithInferenceTemplate = &v1alpha1.Workspace{ ObjectMeta: metav1.ObjectMeta{ diff --git a/test/e2e/utils/utils.go b/test/e2e/utils/utils.go index b479b5df1..16c3cfea1 100644 --- a/test/e2e/utils/utils.go +++ b/test/e2e/utils/utils.go @@ -8,8 +8,6 @@ import ( "fmt" "io" "io/ioutil" - "k8s.io/client-go/rest" - "k8s.io/client-go/tools/clientcmd" "log" "math/rand" "os" @@ -17,6 +15,9 @@ import ( "strings" "time" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/clientcmd" + kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1" "github.com/samber/lo" "gopkg.in/yaml.v2" @@ -35,7 +36,7 @@ var ( // PollInterval defines the interval time for a poll operation. PollInterval = 2 * time.Second // PollTimeout defines the time after which the poll operation times out. - PollTimeout = 60 * time.Second + PollTimeout = 120 * time.Second ) func GetEnv(envVar string) string {