Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add controllerrevision for workspaceController #524

Merged
merged 4 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions charts/kaito/workspace/templates/clusterrole.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ rules:
- apiGroups: [ "apps" ]
resources: ["deployments" ]
verbs: ["get","list","watch","create", "delete","update", "patch"]
- apiGroups: [ "apps" ]
resources: ["controllerrevisions" ]
verbs: [ "get","list","watch","create", "delete","update", "patch"]
- apiGroups: [ "apps" ]
resources: [ "statefulsets" ]
verbs: [ "get","list","watch","create", "delete","update", "patch" ]
Expand Down
11 changes: 5 additions & 6 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"knative.dev/pkg/injection/sharedmain"
"knative.dev/pkg/signals"
"knative.dev/pkg/webhook"

// Import all Kubernetes client auth plugins (e.g. Azure, GCP, OIDC, etc.)
// to ensure that exec-entrypoint and run can make use of them.
_ "k8s.io/client-go/plugin/pkg/client/auth"
Expand Down Expand Up @@ -116,12 +117,10 @@ func main() {
k8sclient.SetGlobalClient(mgr.GetClient())
kClient := k8sclient.GetGlobalClient()

if err = (&controllers.WorkspaceReconciler{
Client: kClient,
Log: log.Log.WithName("controllers").WithName("Workspace"),
Scheme: mgr.GetScheme(),
Recorder: mgr.GetEventRecorderFor("KAITO-Workspace-controller"),
}).SetupWithManager(mgr); err != nil {
workspaceReconciler := controllers.NewWorkspaceReconciler(k8sclient.GetGlobalClient(),
mgr.GetScheme(), log.Log.WithName("controllers").WithName("Workspace"), mgr.GetEventRecorderFor("KAITO-Workspace-controller"))

if err = workspaceReconciler.SetupWithManager(mgr); err != nil {
klog.ErrorS(err, "unable to create controller", "controller", "Workspace")
exitWithErrorFunc()
}
Expand Down
105 changes: 102 additions & 3 deletions pkg/controllers/workspace_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
"sort"
Expand Down Expand Up @@ -46,8 +49,10 @@
)

const (
gpuSkuPrefix = "Standard_N"
nodePluginInstallTimeout = 60 * time.Second
gpuSkuPrefix = "Standard_N"
nodePluginInstallTimeout = 60 * time.Second
WorkspaceRevisionAnnotation = "workspace.kaito.io/revision"
WorkspaceNameLabel = "workspace.kaito.io/name"
)

type WorkspaceReconciler struct {
Expand All @@ -57,6 +62,15 @@
Recorder record.EventRecorder
}

func NewWorkspaceReconciler(client client.Client, scheme *runtime.Scheme, log logr.Logger, Recorder record.EventRecorder) *WorkspaceReconciler {
return &WorkspaceReconciler{
Client: client,
Scheme: scheme,
Log: log,
Recorder: Recorder,

Check warning on line 70 in pkg/controllers/workspace_controller.go

View check run for this annotation

Codecov / codecov/patch

pkg/controllers/workspace_controller.go#L65-L70

Added lines #L65 - L70 were not covered by tests
}
}

func (c *WorkspaceReconciler) Reconcile(ctx context.Context, req reconcile.Request) (reconcile.Result, error) {
workspaceObj := &kaitov1alpha1.Workspace{}
if err := c.Client.Get(ctx, req.NamespacedName, workspaceObj); err != nil {
Expand All @@ -83,7 +97,17 @@
}
}

return c.addOrUpdateWorkspace(ctx, workspaceObj)
result, err := c.addOrUpdateWorkspace(ctx, workspaceObj)
if err != nil {
return result, err

Check warning on line 102 in pkg/controllers/workspace_controller.go

View check run for this annotation

Codecov / codecov/patch

pkg/controllers/workspace_controller.go#L100-L102

Added lines #L100 - L102 were not covered by tests
}

if err := c.updateControllerRevision(ctx, workspaceObj); err != nil {
klog.ErrorS(err, "failed to update ControllerRevision", "workspace", klog.KObj(workspaceObj))
return reconcile.Result{}, nil

Check warning on line 107 in pkg/controllers/workspace_controller.go

View check run for this annotation

Codecov / codecov/patch

pkg/controllers/workspace_controller.go#L105-L107

Added lines #L105 - L107 were not covered by tests
}

return result, nil

Check warning on line 110 in pkg/controllers/workspace_controller.go

View check run for this annotation

Codecov / codecov/patch

pkg/controllers/workspace_controller.go#L110

Added line #L110 was not covered by tests
}

func (c *WorkspaceReconciler) ensureFinalizer(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace) error {
Expand Down Expand Up @@ -180,6 +204,80 @@
return c.garbageCollectWorkspace(ctx, wObj)
}

func (c *WorkspaceReconciler) updateControllerRevision(ctx context.Context, wObj *kaitov1alpha1.Workspace) error {
currentHash := computeHash(wObj)
annotations := wObj.GetAnnotations()

latestHash, exists := annotations[WorkspaceRevisionAnnotation]
if exists && latestHash == currentHash {
return nil
}

data := map[string]string{"hash": currentHash}
jsonData, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("failed to marshal revision data: %w", err)

Check warning on line 219 in pkg/controllers/workspace_controller.go

View check run for this annotation

Codecov / codecov/patch

pkg/controllers/workspace_controller.go#L219

Added line #L219 was not covered by tests
}

revisions := &appsv1.ControllerRevisionList{}
if err := c.List(ctx, revisions, client.InNamespace(wObj.Namespace), client.MatchingLabels{WorkspaceNameLabel: wObj.Name}); err != nil {
return fmt.Errorf("failed to list revisions: %w", err)

Check warning on line 224 in pkg/controllers/workspace_controller.go

View check run for this annotation

Codecov / codecov/patch

pkg/controllers/workspace_controller.go#L224

Added line #L224 was not covered by tests
}
sort.Slice(revisions.Items, func(i, j int) bool {
return revisions.Items[i].Revision < revisions.Items[j].Revision
})

revisionNum := int64(1)

if len(revisions.Items) > 0 {
revisionNum = revisions.Items[len(revisions.Items)-1].Revision + 1
}

newRevision := &appsv1.ControllerRevision{
ObjectMeta: metav1.ObjectMeta{
Name: fmt.Sprintf("%s-%s", wObj.Name, currentHash[:8]),
Namespace: wObj.Namespace,
Labels: map[string]string{
WorkspaceNameLabel: wObj.Name,
},
Annotations: map[string]string{
WorkspaceRevisionAnnotation: currentHash,
},
OwnerReferences: []metav1.OwnerReference{
*metav1.NewControllerRef(wObj, kaitov1alpha1.GroupVersion.WithKind("Workspace")),
},
},
Revision: revisionNum,
Data: runtime.RawExtension{Raw: jsonData},
}

annotations[WorkspaceRevisionAnnotation] = currentHash
wObj.SetAnnotations(annotations)
if err := c.Update(ctx, wObj); err != nil {
return fmt.Errorf("failed to update Workspace annotations: %w", err)

Check warning on line 257 in pkg/controllers/workspace_controller.go

View check run for this annotation

Codecov / codecov/patch

pkg/controllers/workspace_controller.go#L257

Added line #L257 was not covered by tests
}

if err := c.Create(ctx, newRevision); err != nil {
return fmt.Errorf("failed to create new ControllerRevision: %w", err)
}

if len(revisions.Items) > consts.MaxRevisionHistoryLimit {
if err := c.Delete(ctx, &revisions.Items[0]); err != nil {
return fmt.Errorf("failed to delete old revision: %w", err)

Check warning on line 266 in pkg/controllers/workspace_controller.go

View check run for this annotation

Codecov / codecov/patch

pkg/controllers/workspace_controller.go#L266

Added line #L266 was not covered by tests
}
}
return nil
}

func computeHash(w *kaitov1alpha1.Workspace) string {
hasher := sha256.New()
encoder := json.NewEncoder(hasher)
encoder.Encode(w.Resource)
encoder.Encode(w.Inference)
encoder.Encode(w.Tuning)
return hex.EncodeToString(hasher.Sum(nil))
}

func (c *WorkspaceReconciler) selectWorkspaceNodes(qualified []*corev1.Node, preferred []string, previous []string, count int) []*corev1.Node {

sort.Slice(qualified, func(i, j int) bool {
Expand Down Expand Up @@ -679,6 +777,7 @@

builder := ctrl.NewControllerManagedBy(mgr).
For(&kaitov1alpha1.Workspace{}).
Owns(&appsv1.ControllerRevision{}).

Check warning on line 780 in pkg/controllers/workspace_controller.go

View check run for this annotation

Codecov / codecov/patch

pkg/controllers/workspace_controller.go#L780

Added line #L780 was not covered by tests
Owns(&appsv1.Deployment{}).
Owns(&appsv1.StatefulSet{}).
Owns(&batchv1.Job{}).
Expand Down
110 changes: 110 additions & 0 deletions pkg/controllers/workspace_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package controllers
import (
"context"
"errors"
"fmt"
"os"
"reflect"
"sort"
Expand Down Expand Up @@ -954,3 +955,112 @@ func TestApplyWorkspaceResource(t *testing.T) {
})
}
}

func TestUpdateControllerRevision(t *testing.T) {
testcases := map[string]struct {
callMocks func(c *test.MockClient)
workspace v1alpha1.Workspace
expectedError error
verifyCalls func(c *test.MockClient)
}{
"No new revision needed": {
callMocks: func(c *test.MockClient) {
c.On("List", mock.IsType(context.Background()), mock.Anything, mock.IsType(&appsv1.ControllerRevisionList{}), mock.Anything).Return(errors.New("should not be called"))
},
workspace: test.MockWorkspaceWithComputeHash,
expectedError: nil,
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 0)
c.AssertNumberOfCalls(t, "Create", 0)
c.AssertNumberOfCalls(t, "Update", 0)
c.AssertNumberOfCalls(t, "Delete", 0)
},
},
"Fail to create ControllerRevision": {
callMocks: func(c *test.MockClient) {
c.On("List", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevisionList{}), mock.Anything).Return(nil)
c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(errors.New("failed to create ControllerRevision"))
c.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.Workspace{}), mock.Anything).Return(nil)
},
workspace: test.MockWorkspaceFailToCreateCR,
expectedError: errors.New("failed to create new ControllerRevision: failed to create ControllerRevision"),
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 1)
c.AssertNumberOfCalls(t, "Create", 1)
c.AssertNumberOfCalls(t, "Update", 1)
c.AssertNumberOfCalls(t, "Delete", 0)
},
},
"Successfully create new ControllerRevision": {
callMocks: func(c *test.MockClient) {
c.On("List", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevisionList{}), mock.Anything).Return(nil)
c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(nil)
c.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.Workspace{}), mock.Anything).Return(nil)
},
workspace: test.MockWorkspaceSuccessful,
expectedError: nil,
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 1)
c.AssertNumberOfCalls(t, "Create", 1)
c.AssertNumberOfCalls(t, "Update", 1)
c.AssertNumberOfCalls(t, "Delete", 0)
},
},
"Successfully delete old ControllerRevision": {
callMocks: func(c *test.MockClient) {
revisions := &appsv1.ControllerRevisionList{}
for i := 0; i <= consts.MaxRevisionHistoryLimit; i++ {
revision := &appsv1.ControllerRevision{
ObjectMeta: v1.ObjectMeta{
Name: fmt.Sprintf("revision-%d", i),
},
Revision: int64(i),
}
revisions.Items = append(revisions.Items, *revision)
}
relevantMap := c.CreateMapWithType(revisions)

for _, obj := range revisions.Items {
m := obj
objKey := client.ObjectKeyFromObject(&m)
relevantMap[objKey] = &m
}

c.On("List", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevisionList{}), mock.Anything).Return(nil)
c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(nil)
c.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.Workspace{}), mock.Anything).Return(nil)
c.On("Delete", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(nil)
},
workspace: test.MockWorkspaceWithDeleteOldCR,
expectedError: nil,
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 1)
c.AssertNumberOfCalls(t, "Create", 1)
c.AssertNumberOfCalls(t, "Update", 1)
c.AssertNumberOfCalls(t, "Delete", 1)
},
},
}
for k, tc := range testcases {
t.Run(k, func(t *testing.T) {
mockClient := test.NewClient()
tc.callMocks(mockClient)

reconciler := &WorkspaceReconciler{
Client: mockClient,
Scheme: test.NewTestScheme(),
}
ctx := context.Background()

err := reconciler.updateControllerRevision(ctx, &tc.workspace)
if tc.expectedError == nil {
assert.Check(t, err == nil, "Not expected to return error")
} else {
assert.Equal(t, tc.expectedError.Error(), err.Error())
}
if tc.verifyCalls != nil {
tc.verifyCalls(mockClient)
}
})
}
}
1 change: 1 addition & 0 deletions pkg/utils/consts/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ const (
AWSCloudName = "aws"
GPUString = "gpu"
SKUString = "sku"
MaxRevisionHistoryLimit = 10
)
10 changes: 9 additions & 1 deletion pkg/utils/test/mockClient.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/aws/karpenter-core/pkg/apis/v1alpha5"
"github.com/stretchr/testify/mock"
appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/meta"
"k8s.io/apimachinery/pkg/runtime"
Expand Down Expand Up @@ -83,7 +84,6 @@ func (m *MockClient) Get(ctx context.Context, key types.NamespacedName, obj k8sC
}

func (m *MockClient) List(ctx context.Context, list k8sClient.ObjectList, opts ...k8sClient.ListOption) error {

v := reflect.ValueOf(list).Elem()
newList := m.getObjectListFromMap(list)
v.Set(reflect.ValueOf(newList).Elem())
Expand Down Expand Up @@ -121,6 +121,14 @@ func (m *MockClient) getObjectListFromMap(list k8sClient.ObjectList) k8sClient.O
}
}
return nodeClaimList
case *appsv1.ControllerRevisionList:
controllerRevisionList := &appsv1.ControllerRevisionList{}
for _, obj := range relevantMap {
if m, ok := obj.(*appsv1.ControllerRevision); ok {
controllerRevisionList.Items = append(controllerRevisionList.Items, *m)
}
}
return controllerRevisionList
}
//add additional object lists as needed
return nil
Expand Down
Loading
Loading