From edb176f011b29394e458a1ee92e7c6d661efdd22 Mon Sep 17 00:00:00 2001 From: ffforest Date: Wed, 17 Apr 2024 18:26:37 +0800 Subject: [PATCH] refactor: offload heavy logic from handler to manager --- pkg/engine/api/apply_test.go | 125 ++++ pkg/engine/api/destroy_test.go | 166 +++++ pkg/engine/api/preview_test.go | 104 ++++ pkg/engine/api/source/source.go | 2 +- pkg/server/handler/backend/handler.go | 129 +--- pkg/server/handler/backend/handler_test.go | 299 +++++++++ pkg/server/handler/backend/types.go | 20 +- pkg/server/handler/organization/handler.go | 130 +--- .../handler/organization/handler_test.go | 302 +++++++++ pkg/server/handler/organization/types.go | 20 +- pkg/server/handler/project/handler.go | 168 +---- pkg/server/handler/project/handler_test.go | 333 ++++++++++ pkg/server/handler/project/types.go | 28 +- pkg/server/handler/source/handler.go | 143 +---- pkg/server/handler/source/handler_test.go | 7 +- pkg/server/handler/source/types.go | 20 +- pkg/server/handler/stack/execute.go | 588 ++---------------- pkg/server/handler/stack/handler.go | 169 +---- pkg/server/handler/stack/handler_test.go | 330 ++++++++++ pkg/server/handler/stack/types.go | 42 +- pkg/server/handler/types.go | 1 - pkg/server/handler/util.go | 3 +- pkg/server/handler/workspace/handler.go | 138 +--- pkg/server/handler/workspace/handler_test.go | 307 +++++++++ pkg/server/handler/workspace/types.go | 24 +- pkg/server/manager/backend/backend_manager.go | 86 +++ pkg/server/manager/backend/types.go | 23 + .../organization/organization_manager.go | 91 +++ pkg/server/manager/organization/types.go | 23 + pkg/server/manager/project/project_manager.go | 120 ++++ pkg/server/manager/project/types.go | 29 + pkg/server/manager/source/source_manager.go | 101 +++ pkg/server/manager/source/types.go | 23 + pkg/server/manager/stack/stack_manager.go | 335 ++++++++++ pkg/server/manager/stack/types.go | 30 + pkg/server/{handler => manager}/stack/util.go | 167 +++-- pkg/server/manager/workspace/types.go | 26 + .../manager/workspace/workspace_manager.go | 95 +++ pkg/server/route/route.go | 47 +- pkg/server/route/route_test.go | 59 ++ 40 files changed, 3438 insertions(+), 1415 deletions(-) create mode 100644 pkg/engine/api/apply_test.go create mode 100644 pkg/engine/api/destroy_test.go create mode 100644 pkg/engine/api/preview_test.go create mode 100644 pkg/server/handler/backend/handler_test.go create mode 100644 pkg/server/handler/organization/handler_test.go create mode 100644 pkg/server/handler/project/handler_test.go create mode 100644 pkg/server/handler/stack/handler_test.go create mode 100644 pkg/server/handler/workspace/handler_test.go create mode 100644 pkg/server/manager/backend/backend_manager.go create mode 100644 pkg/server/manager/backend/types.go create mode 100644 pkg/server/manager/organization/organization_manager.go create mode 100644 pkg/server/manager/organization/types.go create mode 100644 pkg/server/manager/project/project_manager.go create mode 100644 pkg/server/manager/project/types.go create mode 100644 pkg/server/manager/source/source_manager.go create mode 100644 pkg/server/manager/source/types.go create mode 100644 pkg/server/manager/stack/stack_manager.go create mode 100644 pkg/server/manager/stack/types.go rename pkg/server/{handler => manager}/stack/util.go (51%) create mode 100644 pkg/server/manager/workspace/types.go create mode 100644 pkg/server/manager/workspace/workspace_manager.go create mode 100644 pkg/server/route/route_test.go diff --git a/pkg/engine/api/apply_test.go b/pkg/engine/api/apply_test.go new file mode 100644 index 00000000..2639e74f --- /dev/null +++ b/pkg/engine/api/apply_test.go @@ -0,0 +1,125 @@ +// Copyright 2024 KusionStack Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "errors" + "os" + "path/filepath" + "testing" + + "github.com/bytedance/mockey" + "github.com/stretchr/testify/assert" + + apiv1 "kusionstack.io/kusion/pkg/apis/api.kusion.io/v1" + v1 "kusionstack.io/kusion/pkg/apis/status/v1" + "kusionstack.io/kusion/pkg/engine/operation" + "kusionstack.io/kusion/pkg/engine/operation/models" + statestorages "kusionstack.io/kusion/pkg/engine/state/storages" +) + +func TestApply(t *testing.T) { + stateStorage := statestorages.NewLocalStorage(filepath.Join("", "state.yaml")) + mockey.PatchConvey("dry run", t, func() { + planResources := &apiv1.Spec{Resources: []apiv1.Resource{sa1}} + order := &models.ChangeOrder{ + StepKeys: []string{sa1.ID}, + ChangeSteps: map[string]*models.ChangeStep{ + sa1.ID: { + ID: sa1.ID, + Action: models.Create, + From: sa1, + }, + }, + } + changes := models.NewChanges(proj, stack, order) + o := &APIOptions{} + o.DryRun = true + err := Apply(o, stateStorage, planResources, changes, os.Stdout) + assert.Nil(t, err) + }) + mockey.PatchConvey("apply success", t, func() { + mockOperationApply(models.Success) + o := &APIOptions{} + planResources := &apiv1.Spec{Resources: []apiv1.Resource{sa1, sa2}} + order := &models.ChangeOrder{ + StepKeys: []string{sa1.ID, sa2.ID}, + ChangeSteps: map[string]*models.ChangeStep{ + sa1.ID: { + ID: sa1.ID, + Action: models.Create, + From: &sa1, + }, + sa2.ID: { + ID: sa2.ID, + Action: models.UnChanged, + From: &sa2, + }, + }, + } + changes := models.NewChanges(proj, stack, order) + + err := Apply(o, stateStorage, planResources, changes, os.Stdout) + assert.Nil(t, err) + }) + mockey.PatchConvey("apply failed", t, func() { + mockOperationApply(models.Failed) + + o := &APIOptions{} + planResources := &apiv1.Spec{Resources: []apiv1.Resource{sa1}} + order := &models.ChangeOrder{ + StepKeys: []string{sa1.ID}, + ChangeSteps: map[string]*models.ChangeStep{ + sa1.ID: { + ID: sa1.ID, + Action: models.Create, + From: &sa1, + }, + }, + } + changes := models.NewChanges(proj, stack, order) + + err := Apply(o, stateStorage, planResources, changes, os.Stdout) + assert.NotNil(t, err) + }) +} + +func mockOperationApply(res models.OpResult) { + mockey.Mock((*operation.ApplyOperation).Apply).To( + func(o *operation.ApplyOperation, request *operation.ApplyRequest) (*operation.ApplyResponse, v1.Status) { + var err error + if res == models.Failed { + err = errors.New("mock error") + } + for _, r := range request.Intent.Resources { + // ing -> $res + o.MsgCh <- models.Message{ + ResourceID: r.ResourceKey(), + OpResult: "", + OpErr: nil, + } + o.MsgCh <- models.Message{ + ResourceID: r.ResourceKey(), + OpResult: res, + OpErr: err, + } + } + close(o.MsgCh) + if res == models.Failed { + return nil, v1.NewErrorStatus(err) + } + return &operation.ApplyResponse{}, nil + }).Build() +} diff --git a/pkg/engine/api/destroy_test.go b/pkg/engine/api/destroy_test.go new file mode 100644 index 00000000..b610bde4 --- /dev/null +++ b/pkg/engine/api/destroy_test.go @@ -0,0 +1,166 @@ +// Copyright 2024 KusionStack Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + "errors" + "path/filepath" + "testing" + + "github.com/bytedance/mockey" + "github.com/stretchr/testify/assert" + + apiv1 "kusionstack.io/kusion/pkg/apis/api.kusion.io/v1" + v1 "kusionstack.io/kusion/pkg/apis/status/v1" + "kusionstack.io/kusion/pkg/engine/operation" + "kusionstack.io/kusion/pkg/engine/operation/models" + "kusionstack.io/kusion/pkg/engine/runtime" + "kusionstack.io/kusion/pkg/engine/runtime/kubernetes" + statestorages "kusionstack.io/kusion/pkg/engine/state/storages" +) + +func TestDestroyPreview(t *testing.T) { + stateStorage := statestorages.NewLocalStorage(filepath.Join("", "state.yaml")) + mockey.PatchConvey("preview success", t, func() { + mockNewKubernetesRuntime() + mockOperationPreview() + + o := &APIOptions{} + _, err := DestroyPreview(o, &apiv1.Spec{Resources: []apiv1.Resource{sa1}}, proj, stack, stateStorage) + assert.Nil(t, err) + }) +} + +func mockNewKubernetesRuntime() { + mockey.Mock(kubernetes.NewKubernetesRuntime).To(func() (runtime.Runtime, error) { + return &fakerRuntime{}, nil + }).Build() +} + +var _ runtime.Runtime = (*fakerRuntime)(nil) + +type fakerRuntime struct{} + +func (f *fakerRuntime) Import(_ context.Context, request *runtime.ImportRequest) *runtime.ImportResponse { + return &runtime.ImportResponse{Resource: request.PlanResource} +} + +func (f *fakerRuntime) Apply(_ context.Context, request *runtime.ApplyRequest) *runtime.ApplyResponse { + return &runtime.ApplyResponse{ + Resource: request.PlanResource, + Status: nil, + } +} + +func (f *fakerRuntime) Read(_ context.Context, request *runtime.ReadRequest) *runtime.ReadResponse { + if request.PlanResource.ResourceKey() == "fake-id" { + return &runtime.ReadResponse{ + Resource: nil, + Status: nil, + } + } + return &runtime.ReadResponse{ + Resource: request.PlanResource, + Status: nil, + } +} + +func (f *fakerRuntime) Delete(_ context.Context, _ *runtime.DeleteRequest) *runtime.DeleteResponse { + return nil +} + +func (f *fakerRuntime) Watch(_ context.Context, _ *runtime.WatchRequest) *runtime.WatchResponse { + return nil +} + +func TestDestroy(t *testing.T) { + stateStorage := statestorages.NewLocalStorage(filepath.Join("", "state.yaml")) + mockey.PatchConvey("destroy success", t, func() { + mockNewKubernetesRuntime() + mockOperationDestroy(models.Success) + + o := &APIOptions{} + planResources := &apiv1.Spec{Resources: []apiv1.Resource{sa2}} + order := &models.ChangeOrder{ + StepKeys: []string{sa1.ID, sa2.ID}, + ChangeSteps: map[string]*models.ChangeStep{ + sa1.ID: { + ID: sa1.ID, + Action: models.Delete, + From: nil, + }, + sa2.ID: { + ID: sa2.ID, + Action: models.UnChanged, + From: &sa2, + }, + }, + } + changes := models.NewChanges(proj, stack, order) + + err := Destroy(o, planResources, changes, stateStorage) + assert.Nil(t, err) + }) + mockey.PatchConvey("destroy failed", t, func() { + mockNewKubernetesRuntime() + mockOperationDestroy(models.Failed) + + o := &APIOptions{} + planResources := &apiv1.Spec{Resources: []apiv1.Resource{sa1}} + order := &models.ChangeOrder{ + StepKeys: []string{sa1.ID}, + ChangeSteps: map[string]*models.ChangeStep{ + sa1.ID: { + ID: sa1.ID, + Action: models.Delete, + From: nil, + }, + }, + } + changes := models.NewChanges(proj, stack, order) + + err := Destroy(o, planResources, changes, stateStorage) + assert.NotNil(t, err) + }) +} + +func mockOperationDestroy(res models.OpResult) { + mockey.Mock((*operation.DestroyOperation).Destroy).To( + func(o *operation.DestroyOperation, request *operation.DestroyRequest) v1.Status { + var err error + if res == models.Failed { + err = errors.New("mock error") + } + for _, r := range request.Intent.Resources { + // ing -> $res + o.MsgCh <- models.Message{ + ResourceID: r.ResourceKey(), + OpResult: "", + OpErr: nil, + } + o.MsgCh <- models.Message{ + ResourceID: r.ResourceKey(), + OpResult: res, + OpErr: err, + } + } + close(o.MsgCh) + if res == models.Failed { + return v1.NewErrorStatus(err) + } + return nil + }).Build() +} diff --git a/pkg/engine/api/preview_test.go b/pkg/engine/api/preview_test.go new file mode 100644 index 00000000..dd2ce974 --- /dev/null +++ b/pkg/engine/api/preview_test.go @@ -0,0 +1,104 @@ +// Copyright 2024 KusionStack Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "path/filepath" + "testing" + + "github.com/bytedance/mockey" + "github.com/stretchr/testify/assert" + + apiv1 "kusionstack.io/kusion/pkg/apis/api.kusion.io/v1" + v1 "kusionstack.io/kusion/pkg/apis/status/v1" + "kusionstack.io/kusion/pkg/engine" + "kusionstack.io/kusion/pkg/engine/operation" + "kusionstack.io/kusion/pkg/engine/operation/models" + statestorages "kusionstack.io/kusion/pkg/engine/state/storages" +) + +var ( + apiVersion = "v1" + kind = "ServiceAccount" + namespace = "test-ns" + + proj = &apiv1.Project{ + Name: "testdata", + } + stack = &apiv1.Stack{ + Name: "dev", + } + + sa1 = newSA("sa1") + sa2 = newSA("sa2") + sa3 = newSA("sa3") +) + +func TestPreview(t *testing.T) { + stateStorage := statestorages.NewLocalStorage(filepath.Join("", "state.yaml")) + t.Run("preview success", func(t *testing.T) { + m := mockOperationPreview() + defer m.UnPatch() + + o := &APIOptions{} + _, err := Preview(o, stateStorage, &apiv1.Spec{Resources: []apiv1.Resource{sa1, sa2, sa3}}, proj, stack) + assert.Nil(t, err) + }) +} + +func mockOperationPreview() *mockey.Mocker { + return mockey.Mock((*operation.PreviewOperation).Preview).To(func( + *operation.PreviewOperation, + *operation.PreviewRequest, + ) (rsp *operation.PreviewResponse, s v1.Status) { + return &operation.PreviewResponse{ + Order: &models.ChangeOrder{ + StepKeys: []string{sa1.ID, sa2.ID, sa3.ID}, + ChangeSteps: map[string]*models.ChangeStep{ + sa1.ID: { + ID: sa1.ID, + Action: models.Create, + From: &sa1, + }, + sa2.ID: { + ID: sa2.ID, + Action: models.UnChanged, + From: &sa2, + }, + sa3.ID: { + ID: sa3.ID, + Action: models.Undefined, + From: &sa1, + }, + }, + }, + }, nil + }).Build() +} + +func newSA(name string) apiv1.Resource { + return apiv1.Resource{ + ID: engine.BuildID(apiVersion, kind, namespace, name), + Type: "Kubernetes", + Attributes: map[string]interface{}{ + "apiVersion": apiVersion, + "kind": kind, + "metadata": map[string]interface{}{ + "name": name, + "namespace": namespace, + }, + }, + } +} diff --git a/pkg/engine/api/source/source.go b/pkg/engine/api/source/source.go index 44abb9f8..2f0be4d0 100644 --- a/pkg/engine/api/source/source.go +++ b/pkg/engine/api/source/source.go @@ -31,7 +31,7 @@ func Pull(ctx context.Context, source *entity.Source) (string, error) { return directory, nil } -// Cleanup() is a method that cleans up tje temporary source code from the source provider. +// Cleanup() is a method that cleans up the temporary source code from the source provider. func Cleanup(ctx context.Context, localDirectory string) { logger := util.GetLogger(ctx) logger.Info("Cleaning up temp directory...") diff --git a/pkg/server/handler/backend/handler.go b/pkg/server/handler/backend/handler.go index 2e310c8c..9c23e934 100644 --- a/pkg/server/handler/backend/handler.go +++ b/pkg/server/handler/backend/handler.go @@ -1,18 +1,16 @@ package backend import ( - "errors" + "context" "net/http" "strconv" - "time" "github.com/go-chi/chi/v5" "github.com/go-chi/render" - "github.com/jinzhu/copier" - "gorm.io/gorm" - "kusionstack.io/kusion/pkg/domain/entity" + "github.com/go-logr/logr" "kusionstack.io/kusion/pkg/domain/request" "kusionstack.io/kusion/pkg/server/handler" + backendmanager "kusionstack.io/kusion/pkg/server/manager/backend" "kusionstack.io/kusion/pkg/server/util" ) @@ -42,22 +40,7 @@ func (h *Handler) CreateBackend() http.HandlerFunc { return } - // Convert request payload to domain model - var createdEntity entity.Backend - if err := copier.Copy(&createdEntity, &requestPayload); err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - // The default state is UnSynced - createdEntity.CreationTimestamp = time.Now() - createdEntity.UpdateTimestamp = time.Now() - - // Create backend with repository - err := h.backendRepo.Create(ctx, &createdEntity) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } + createdEntity, err := h.backendManager.CreateBackend(ctx, requestPayload) handler.HandleResult(w, r, ctx, err, createdEntity) } } @@ -77,22 +60,14 @@ func (h *Handler) CreateBackend() http.HandlerFunc { func (h *Handler) DeleteBackend() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Getting stuff from context - ctx := r.Context() - logger := util.GetLogger(ctx) - logger.Info("Deleting source...") - backendID := chi.URLParam(r, "backendID") - - // Delete backend with repository - id, err := strconv.Atoi(backendID) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, ErrInvalidBackendID)) - return - } - err = h.backendRepo.Delete(ctx, uint(id)) + ctx, logger, params, err := requestHelper(r) if err != nil { render.Render(w, r, handler.FailureResponse(ctx, err)) return } + logger.Info("Deleting backend...", "backendID", params.BackendID) + + err = h.backendManager.DeleteBackendByID(ctx, params.BackendID) handler.HandleResult(w, r, ctx, err, "Deletion Success") } } @@ -112,17 +87,12 @@ func (h *Handler) DeleteBackend() http.HandlerFunc { func (h *Handler) UpdateBackend() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Getting stuff from context - ctx := r.Context() - logger := util.GetLogger(ctx) - logger.Info("Updating backend...") - backendID := chi.URLParam(r, "backendID") - - // convert backend ID to int - id, err := strconv.Atoi(backendID) + ctx, logger, params, err := requestHelper(r) if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, ErrInvalidBackendID)) + render.Render(w, r, handler.FailureResponse(ctx, err)) return } + logger.Info("Updating backend..., backendID", params.BackendID) // Decode the request body into the payload. var requestPayload request.UpdateBackendRequest @@ -131,35 +101,7 @@ func (h *Handler) UpdateBackend() http.HandlerFunc { return } - // Convert request payload to domain model - var requestEntity entity.Backend - if err := copier.Copy(&requestEntity, &requestPayload); err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Get the existing backend by id - updatedEntity, err := h.backendRepo.Get(ctx, uint(id)) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - render.Render(w, r, handler.FailureResponse(ctx, ErrUpdatingNonExistingBackend)) - return - } - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Overwrite non-zero values in request entity to existing entity - copier.CopyWithOption(updatedEntity, requestEntity, copier.Option{IgnoreEmpty: true}) - - // Update backend with repository - err = h.backendRepo.Update(ctx, updatedEntity) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Return updated backend + updatedEntity, err := h.backendManager.UpdateBackendByID(ctx, params.BackendID, requestPayload) handler.HandleResult(w, r, ctx, err, updatedEntity) } } @@ -178,28 +120,14 @@ func (h *Handler) UpdateBackend() http.HandlerFunc { func (h *Handler) GetBackend() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Getting stuff from context - ctx := r.Context() - logger := util.GetLogger(ctx) - logger.Info("Getting backend...") - backendID := chi.URLParam(r, "backendID") - - // Get backend with repository - id, err := strconv.Atoi(backendID) + ctx, logger, params, err := requestHelper(r) if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, ErrInvalidBackendID)) - return - } - existingEntity, err := h.backendRepo.Get(ctx, uint(id)) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - render.Render(w, r, handler.FailureResponse(ctx, ErrGettingNonExistingBackend)) - return - } render.Render(w, r, handler.FailureResponse(ctx, err)) return } + logger.Info("Getting backend...", "backendID", params.BackendID) - // Return found backend + existingEntity, err := h.backendManager.GetBackendByID(ctx, params.BackendID) handler.HandleResult(w, r, ctx, err, existingEntity) } } @@ -221,17 +149,22 @@ func (h *Handler) ListBackends() http.HandlerFunc { logger := util.GetLogger(ctx) logger.Info("Listing backend...") - backendEntities, err := h.backendRepo.List(ctx) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - render.Render(w, r, handler.FailureResponse(ctx, ErrGettingNonExistingBackend)) - return - } - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Return found backends + backendEntities, err := h.backendManager.ListBackends(ctx) handler.HandleResult(w, r, ctx, err, backendEntities) } } + +func requestHelper(r *http.Request) (context.Context, *logr.Logger, *BackendRequestParams, error) { + ctx := r.Context() + backendID := chi.URLParam(r, "backendID") + // Get stack with repository + id, err := strconv.Atoi(backendID) + if err != nil { + return nil, nil, nil, backendmanager.ErrInvalidBackendID + } + logger := util.GetLogger(ctx) + params := BackendRequestParams{ + BackendID: uint(id), + } + return ctx, &logger, ¶ms, nil +} diff --git a/pkg/server/handler/backend/handler_test.go b/pkg/server/handler/backend/handler_test.go new file mode 100644 index 00000000..de75ae19 --- /dev/null +++ b/pkg/server/handler/backend/handler_test.go @@ -0,0 +1,299 @@ +package backend + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "kusionstack.io/kusion/pkg/domain/request" + "kusionstack.io/kusion/pkg/infra/persistence" + "kusionstack.io/kusion/pkg/server/handler" + backendmanager "kusionstack.io/kusion/pkg/server/manager/backend" +) + +func TestBackendHandler(t *testing.T) { + backendName := "test-backend" + backendNameSecond := "test-backend-2" + backendNameUpdated := "test-backend-updated" + t.Run("ListBackends", func(t *testing.T) { + sqlMock, fakeGDB, recorder, backendHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). + AddRow(1, backendName). + AddRow(2, backendNameSecond)) + + // Create a new HTTP request + req, err := http.NewRequest("GET", "/backends", nil) + assert.NoError(t, err) + + // Call the ListBackends handler function + backendHandler.ListBackends()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, 2, len(resp.Data.([]any))) + }) + + t.Run("GetBackend", func(t *testing.T) { + sqlMock, fakeGDB, recorder, backendHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). + AddRow(1, backendName)) + + // Create a new HTTP request + req, err := http.NewRequest("GET", "/backend/{backendID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("backendID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Call the ListBackends handler function + backendHandler.GetBackend()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshal the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, float64(1), resp.Data.(map[string]any)["id"]) + assert.Equal(t, backendName, resp.Data.(map[string]any)["name"]) + }) + + t.Run("CreateBackend", func(t *testing.T) { + sqlMock, fakeGDB, recorder, backendHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Create a new HTTP request + req, err := http.NewRequest("POST", "/backend/{backendID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("backendID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Set request body + requestPayload := request.CreateBackendRequest{ + Name: backendName, + } + reqBody, err := json.Marshal(requestPayload) + assert.NoError(t, err) + req.Body = io.NopCloser(bytes.NewReader(reqBody)) + req.Header.Add("Content-Type", "application/json") + + sqlMock.ExpectBegin() + sqlMock.ExpectExec("INSERT"). + WillReturnResult(sqlmock.NewResult(int64(1), int64(1))) + sqlMock.ExpectCommit() + + // Call the CreateBackend handler function + backendHandler.CreateBackend()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshal the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, float64(1), resp.Data.(map[string]any)["id"]) + assert.Equal(t, backendName, resp.Data.(map[string]any)["name"]) + }) + + t.Run("UpdateExistingBackend", func(t *testing.T) { + sqlMock, fakeGDB, recorder, backendHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Update a new HTTP request + req, err := http.NewRequest("POST", "/backend/{backendID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("backendID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Set request body + requestPayload := request.UpdateBackendRequest{ + // Set your request payload fields here + ID: 1, + Name: backendNameUpdated, + } + reqBody, err := json.Marshal(requestPayload) + assert.NoError(t, err) + req.Body = io.NopCloser(bytes.NewReader(reqBody)) + req.Header.Add("Content-Type", "application/json") + + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "Backend__id"}). + AddRow(1, "test-backend-updated", 1)) + sqlMock.ExpectExec("UPDATE"). + WillReturnResult(sqlmock.NewResult(int64(1), int64(1))) + + // Call the ListBackends handler function + backendHandler.UpdateBackend()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, float64(1), resp.Data.(map[string]any)["id"]) + assert.Equal(t, backendNameUpdated, resp.Data.(map[string]any)["name"]) + }) + + t.Run("Delete Existing Backend", func(t *testing.T) { + sqlMock, fakeGDB, recorder, backendHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Create a new HTTP request + req, err := http.NewRequest("DELETE", "/backend/{backendID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("backendID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Mock the Delete method of the backend repository + sqlMock.ExpectBegin() + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id"}). + AddRow(1)) + sqlMock.ExpectExec("UPDATE"). + WillReturnResult(sqlmock.NewResult(1, 1)) + sqlMock.ExpectCommit() + + // Call the DeleteBackend handler function + backendHandler.DeleteBackend()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + assert.Equal(t, "Deletion Success", resp.Data) + }) + + t.Run("Delete Nonexisting Backend", func(t *testing.T) { + sqlMock, fakeGDB, recorder, backendHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Create a new HTTP request + req, err := http.NewRequest("DELETE", "/backend/{backendID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("backendID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + sqlMock.ExpectBegin() + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id"})) + + // Call the DeleteBackend handler function + backendHandler.DeleteBackend()(recorder, req) + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, false, resp.Success) + assert.Equal(t, backendmanager.ErrGettingNonExistingBackend.Error(), resp.Message) + }) + + t.Run("Update Nonexisting Backend", func(t *testing.T) { + sqlMock, fakeGDB, recorder, backendHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Update a new HTTP request + req, err := http.NewRequest("POST", "/backend/{backendID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("backendID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Set request body + requestPayload := request.UpdateBackendRequest{ + // Set your request payload fields here + ID: 1, + Name: "test-backend-updated", + } + reqBody, err := json.Marshal(requestPayload) + assert.NoError(t, err) + req.Body = io.NopCloser(bytes.NewReader(reqBody)) + req.Header.Add("Content-Type", "application/json") + + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id"})) + + // Call the UpdateBackend handler function + backendHandler.UpdateBackend()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, false, resp.Success) + assert.Equal(t, backendmanager.ErrUpdatingNonExistingBackend.Error(), resp.Message) + }) +} + +func setupTest(t *testing.T) (sqlmock.Sqlmock, *gorm.DB, *httptest.ResponseRecorder, *Handler) { + fakeGDB, sqlMock, err := persistence.GetMockDB() + require.NoError(t, err) + backendRepo := persistence.NewBackendRepository(fakeGDB) + backendHandler := &Handler{ + backendManager: backendmanager.NewBackendManager(backendRepo), + } + recorder := httptest.NewRecorder() + return sqlMock, fakeGDB, recorder, backendHandler +} diff --git a/pkg/server/handler/backend/types.go b/pkg/server/handler/backend/types.go index ab360477..effccdba 100644 --- a/pkg/server/handler/backend/types.go +++ b/pkg/server/handler/backend/types.go @@ -1,25 +1,21 @@ package backend import ( - "errors" - - "kusionstack.io/kusion/pkg/domain/repository" -) - -var ( - ErrGettingNonExistingBackend = errors.New("the backend does not exist") - ErrUpdatingNonExistingBackend = errors.New("the backend to update does not exist") - ErrInvalidBackendID = errors.New("the backend ID should be a uuid") + backendmanager "kusionstack.io/kusion/pkg/server/manager/backend" ) func NewHandler( - backendRepo repository.BackendRepository, + backendManager *backendmanager.BackendManager, ) (*Handler, error) { return &Handler{ - backendRepo: backendRepo, + backendManager: backendManager, }, nil } type Handler struct { - backendRepo repository.BackendRepository + backendManager *backendmanager.BackendManager +} + +type BackendRequestParams struct { + BackendID uint } diff --git a/pkg/server/handler/organization/handler.go b/pkg/server/handler/organization/handler.go index 25e5190b..fc6614a0 100644 --- a/pkg/server/handler/organization/handler.go +++ b/pkg/server/handler/organization/handler.go @@ -1,18 +1,16 @@ package organization import ( - "errors" + "context" "net/http" "strconv" - "time" "github.com/go-chi/chi/v5" "github.com/go-chi/render" - "github.com/jinzhu/copier" - "gorm.io/gorm" - "kusionstack.io/kusion/pkg/domain/entity" + "github.com/go-logr/logr" "kusionstack.io/kusion/pkg/domain/request" "kusionstack.io/kusion/pkg/server/handler" + organizationmanager "kusionstack.io/kusion/pkg/server/manager/organization" "kusionstack.io/kusion/pkg/server/util" ) @@ -42,22 +40,7 @@ func (h *Handler) CreateOrganization() http.HandlerFunc { return } - // Convert request payload to domain model - var createdEntity entity.Organization - if err := copier.Copy(&createdEntity, &requestPayload); err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - // The default state is UnSynced - createdEntity.CreationTimestamp = time.Now() - createdEntity.UpdateTimestamp = time.Now() - - // Create organization with repository - err := h.organizationRepo.Create(ctx, &createdEntity) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } + createdEntity, err := h.organizationManager.CreateOrganization(ctx, requestPayload) handler.HandleResult(w, r, ctx, err, createdEntity) } } @@ -72,27 +55,18 @@ func (h *Handler) CreateOrganization() http.HandlerFunc { // @Failure 429 {object} errors.DetailError "Too Many Requests" // @Failure 404 {object} errors.DetailError "Not Found" // @Failure 500 {object} errors.DetailError "Internal Server Error" -// @Router /api/v1/organization/{organizationName} [delete] // @Router /api/v1/organization/{organizationID} [delete] func (h *Handler) DeleteOrganization() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Getting stuff from context - ctx := r.Context() - logger := util.GetLogger(ctx) - logger.Info("Deleting source...") - organizationID := chi.URLParam(r, "organizationID") - - // Delete organization with repository - id, err := strconv.Atoi(organizationID) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, ErrInvalidOrganizationID)) - return - } - err = h.organizationRepo.Delete(ctx, uint(id)) + ctx, logger, params, err := requestHelper(r) if err != nil { render.Render(w, r, handler.FailureResponse(ctx, err)) return } + logger.Info("Deleting organization...") + + err = h.organizationManager.DeleteOrganizationByID(ctx, params.OrganizationID) handler.HandleResult(w, r, ctx, err, "Deletion Success") } } @@ -112,17 +86,12 @@ func (h *Handler) DeleteOrganization() http.HandlerFunc { func (h *Handler) UpdateOrganization() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Getting stuff from context - ctx := r.Context() - logger := util.GetLogger(ctx) - logger.Info("Updating organization...") - organizationID := chi.URLParam(r, "organizationID") - - // convert organization ID to int - id, err := strconv.Atoi(organizationID) + ctx, logger, params, err := requestHelper(r) if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, ErrInvalidOrganizationID)) + render.Render(w, r, handler.FailureResponse(ctx, err)) return } + logger.Info("Updating organization...") // Decode the request body into the payload. var requestPayload request.UpdateOrganizationRequest @@ -131,35 +100,7 @@ func (h *Handler) UpdateOrganization() http.HandlerFunc { return } - // Convert request payload to domain model - var requestEntity entity.Organization - if err := copier.Copy(&requestEntity, &requestPayload); err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Get the existing organization by id - updatedEntity, err := h.organizationRepo.Get(ctx, uint(id)) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - render.Render(w, r, handler.FailureResponse(ctx, ErrUpdatingNonExistingOrganization)) - return - } - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Overwrite non-zero values in request entity to existing entity - copier.CopyWithOption(updatedEntity, requestEntity, copier.Option{IgnoreEmpty: true}) - - // Update organization with repository - err = h.organizationRepo.Update(ctx, updatedEntity) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Return updated organization + updatedEntity, err := h.organizationManager.UpdateOrganizationByID(ctx, params.OrganizationID, requestPayload) handler.HandleResult(w, r, ctx, err, updatedEntity) } } @@ -178,28 +119,14 @@ func (h *Handler) UpdateOrganization() http.HandlerFunc { func (h *Handler) GetOrganization() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Getting stuff from context - ctx := r.Context() - logger := util.GetLogger(ctx) - logger.Info("Getting organization...") - organizationID := chi.URLParam(r, "organizationID") - - // Get organization with repository - id, err := strconv.Atoi(organizationID) + ctx, logger, params, err := requestHelper(r) if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, ErrInvalidOrganizationID)) - return - } - existingEntity, err := h.organizationRepo.Get(ctx, uint(id)) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - render.Render(w, r, handler.FailureResponse(ctx, ErrGettingNonExistingOrganization)) - return - } render.Render(w, r, handler.FailureResponse(ctx, err)) return } + logger.Info("Getting organization...") - // Return found organization + existingEntity, err := h.organizationManager.GetOrganizationByID(ctx, params.OrganizationID) handler.HandleResult(w, r, ctx, err, existingEntity) } } @@ -221,17 +148,22 @@ func (h *Handler) ListOrganizations() http.HandlerFunc { logger := util.GetLogger(ctx) logger.Info("Listing organization...") - organizationEntities, err := h.organizationRepo.List(ctx) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - render.Render(w, r, handler.FailureResponse(ctx, ErrGettingNonExistingOrganization)) - return - } - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Return found organizations + organizationEntities, err := h.organizationManager.ListOrganizations(ctx) handler.HandleResult(w, r, ctx, err, organizationEntities) } } + +func requestHelper(r *http.Request) (context.Context, *logr.Logger, *OrganizationRequestParams, error) { + ctx := r.Context() + organizationID := chi.URLParam(r, "organizationID") + // Get stack with repository + id, err := strconv.Atoi(organizationID) + if err != nil { + return nil, nil, nil, organizationmanager.ErrInvalidOrganizationID + } + logger := util.GetLogger(ctx) + params := OrganizationRequestParams{ + OrganizationID: uint(id), + } + return ctx, &logger, ¶ms, nil +} diff --git a/pkg/server/handler/organization/handler_test.go b/pkg/server/handler/organization/handler_test.go new file mode 100644 index 00000000..32794f49 --- /dev/null +++ b/pkg/server/handler/organization/handler_test.go @@ -0,0 +1,302 @@ +package organization + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "kusionstack.io/kusion/pkg/domain/request" + "kusionstack.io/kusion/pkg/infra/persistence" + "kusionstack.io/kusion/pkg/server/handler" + organizationmanager "kusionstack.io/kusion/pkg/server/manager/organization" +) + +func TestOrganizationHandler(t *testing.T) { + var ( + orgName = "test-org" + orgNameSecond = "test-org-2" + orgNameUpdated = "test-org-updated" + ) + t.Run("ListOrganizations", func(t *testing.T) { + sqlMock, fakeGDB, recorder, organizationHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "Backend__id"}). + AddRow(1, orgName, 1). + AddRow(2, orgNameSecond, 2)) + + // Create a new HTTP request + req, err := http.NewRequest("GET", "/organizations", nil) + assert.NoError(t, err) + + // Call the ListOrganizations handler function + organizationHandler.ListOrganizations()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, 2, len(resp.Data.([]any))) + }) + + t.Run("GetOrganization", func(t *testing.T) { + sqlMock, fakeGDB, recorder, organizationHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). + AddRow(1, orgName)) + + // Create a new HTTP request + req, err := http.NewRequest("GET", "/organization/{organizationID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("organizationID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Call the ListOrganizations handler function + organizationHandler.GetOrganization()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshal the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, float64(1), resp.Data.(map[string]any)["id"]) + assert.Equal(t, orgName, resp.Data.(map[string]any)["name"]) + }) + + t.Run("CreateOrganization", func(t *testing.T) { + sqlMock, fakeGDB, recorder, organizationHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Create a new HTTP request + req, err := http.NewRequest("POST", "/organization/{organizationID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("organizationID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Set request body + requestPayload := request.CreateOrganizationRequest{ + Name: orgName, + Owners: []string{"hua.li", "xiaoming.li"}, + } + reqBody, err := json.Marshal(requestPayload) + assert.NoError(t, err) + req.Body = io.NopCloser(bytes.NewReader(reqBody)) + req.Header.Add("Content-Type", "application/json") + + sqlMock.ExpectBegin() + sqlMock.ExpectExec("INSERT"). + WillReturnResult(sqlmock.NewResult(int64(1), int64(1))) + sqlMock.ExpectCommit() + + // Call the CreateOrganization handler function + organizationHandler.CreateOrganization()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshal the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, float64(1), resp.Data.(map[string]any)["id"]) + assert.Equal(t, orgName, resp.Data.(map[string]any)["name"]) + }) + + t.Run("UpdateExistingOrganization", func(t *testing.T) { + sqlMock, fakeGDB, recorder, organizationHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Update a new HTTP request + req, err := http.NewRequest("POST", "/organization/{organizationID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("organizationID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Set request body + requestPayload := request.UpdateOrganizationRequest{ + // Set your request payload fields here + ID: 1, + Name: orgNameUpdated, + } + reqBody, err := json.Marshal(requestPayload) + assert.NoError(t, err) + req.Body = io.NopCloser(bytes.NewReader(reqBody)) + req.Header.Add("Content-Type", "application/json") + + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "Backend__id"}). + AddRow(1, orgName, 1)) + sqlMock.ExpectExec("UPDATE"). + WillReturnResult(sqlmock.NewResult(int64(1), int64(1))) + + // Call the ListOrganizations handler function + organizationHandler.UpdateOrganization()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, float64(1), resp.Data.(map[string]any)["id"]) + assert.Equal(t, orgNameUpdated, resp.Data.(map[string]any)["name"]) + }) + + t.Run("Delete Existing Organization", func(t *testing.T) { + sqlMock, fakeGDB, recorder, organizationHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Create a new HTTP request + req, err := http.NewRequest("DELETE", "/organization/{organizationID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("organizationID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Mock the Delete method of the organization repository + sqlMock.ExpectBegin() + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id"}). + AddRow(1)) + sqlMock.ExpectExec("UPDATE"). + WillReturnResult(sqlmock.NewResult(1, 1)) + sqlMock.ExpectCommit() + + // Call the DeleteOrganization handler function + organizationHandler.DeleteOrganization()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + assert.Equal(t, "Deletion Success", resp.Data) + }) + + t.Run("Delete Nonexisting Organization", func(t *testing.T) { + sqlMock, fakeGDB, recorder, organizationHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Create a new HTTP request + req, err := http.NewRequest("DELETE", "/organization/{organizationID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("organizationID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + sqlMock.ExpectBegin() + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id"})) + + // Call the DeleteOrganization handler function + organizationHandler.DeleteOrganization()(recorder, req) + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, false, resp.Success) + assert.Equal(t, organizationmanager.ErrGettingNonExistingOrganization.Error(), resp.Message) + }) + + t.Run("Update Nonexisting Organization", func(t *testing.T) { + sqlMock, fakeGDB, recorder, organizationHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Update a new HTTP request + req, err := http.NewRequest("POST", "/organization/{organizationID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("organizationID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Set request body + requestPayload := request.UpdateOrganizationRequest{ + // Set your request payload fields here + ID: 1, + Name: orgNameUpdated, + } + reqBody, err := json.Marshal(requestPayload) + assert.NoError(t, err) + req.Body = io.NopCloser(bytes.NewReader(reqBody)) + req.Header.Add("Content-Type", "application/json") + + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id"})) + + // Call the UpdateOrganization handler function + organizationHandler.UpdateOrganization()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, false, resp.Success) + assert.Equal(t, organizationmanager.ErrUpdatingNonExistingOrganization.Error(), resp.Message) + }) +} + +func setupTest(t *testing.T) (sqlmock.Sqlmock, *gorm.DB, *httptest.ResponseRecorder, *Handler) { + fakeGDB, sqlMock, err := persistence.GetMockDB() + require.NoError(t, err) + organizationRepo := persistence.NewOrganizationRepository(fakeGDB) + organizationHandler := &Handler{ + organizationManager: organizationmanager.NewOrganizationManager(organizationRepo), + } + recorder := httptest.NewRecorder() + return sqlMock, fakeGDB, recorder, organizationHandler +} diff --git a/pkg/server/handler/organization/types.go b/pkg/server/handler/organization/types.go index 3ee4e557..bdc1c8f0 100644 --- a/pkg/server/handler/organization/types.go +++ b/pkg/server/handler/organization/types.go @@ -1,25 +1,21 @@ package organization import ( - "errors" - - "kusionstack.io/kusion/pkg/domain/repository" -) - -var ( - ErrGettingNonExistingOrganization = errors.New("the organization does not exist") - ErrUpdatingNonExistingOrganization = errors.New("the organization to update does not exist") - ErrInvalidOrganizationID = errors.New("the organization ID should be a uuid") + organizationmanager "kusionstack.io/kusion/pkg/server/manager/organization" ) func NewHandler( - organizationRepo repository.OrganizationRepository, + organizationManager *organizationmanager.OrganizationManager, ) (*Handler, error) { return &Handler{ - organizationRepo: organizationRepo, + organizationManager: organizationManager, }, nil } type Handler struct { - organizationRepo repository.OrganizationRepository + organizationManager *organizationmanager.OrganizationManager +} + +type OrganizationRequestParams struct { + OrganizationID uint } diff --git a/pkg/server/handler/project/handler.go b/pkg/server/handler/project/handler.go index 1156165e..99e5a408 100644 --- a/pkg/server/handler/project/handler.go +++ b/pkg/server/handler/project/handler.go @@ -1,18 +1,16 @@ package project import ( - "errors" + "context" "net/http" "strconv" - "time" "github.com/go-chi/chi/v5" "github.com/go-chi/render" - "github.com/jinzhu/copier" - "gorm.io/gorm" - "kusionstack.io/kusion/pkg/domain/entity" + "github.com/go-logr/logr" "kusionstack.io/kusion/pkg/domain/request" "kusionstack.io/kusion/pkg/server/handler" + projectmanager "kusionstack.io/kusion/pkg/server/manager/project" "kusionstack.io/kusion/pkg/server/util" ) @@ -42,43 +40,7 @@ func (h *Handler) CreateProject() http.HandlerFunc { return } - // Convert request payload to domain model - var createdEntity entity.Project - if err := copier.Copy(&createdEntity, &requestPayload); err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - createdEntity.CreationTimestamp = time.Now() - createdEntity.UpdateTimestamp = time.Now() - - // Get source by id - sourceEntity, err := h.sourceRepo.Get(ctx, requestPayload.SourceID) - if err != nil && err == gorm.ErrRecordNotFound { - render.Render(w, r, handler.FailureResponse(ctx, ErrSourceNotFound)) - return - } else if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - createdEntity.Source = sourceEntity - - // Get org by id - organizationEntity, err := h.organizationRepo.Get(ctx, requestPayload.OrganizationID) - if err != nil && err == gorm.ErrRecordNotFound { - render.Render(w, r, handler.FailureResponse(ctx, ErrOrgNotFound)) - return - } else if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - createdEntity.Organization = organizationEntity - - // Create project with repository - err = h.projectRepo.Create(ctx, &createdEntity) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } + createdEntity, err := h.projectManager.CreateProject(ctx, requestPayload) handler.HandleResult(w, r, ctx, err, createdEntity) } } @@ -98,22 +60,14 @@ func (h *Handler) CreateProject() http.HandlerFunc { func (h *Handler) DeleteProject() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Getting stuff from context - ctx := r.Context() - logger := util.GetLogger(ctx) - logger.Info("Deleting source...") - projectID := chi.URLParam(r, "projectID") - - // Delete project with repository - id, err := strconv.Atoi(projectID) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, ErrInvalidProjectID)) - return - } - err = h.projectRepo.Delete(ctx, uint(id)) + ctx, logger, params, err := requestHelper(r) if err != nil { render.Render(w, r, handler.FailureResponse(ctx, err)) return } + logger.Info("Deleting source...", "projectID", params.ProjectID) + + err = h.projectManager.DeleteProjectByID(ctx, params.ProjectID) handler.HandleResult(w, r, ctx, err, "Deletion Success") } } @@ -133,17 +87,12 @@ func (h *Handler) DeleteProject() http.HandlerFunc { func (h *Handler) UpdateProject() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Getting stuff from context - ctx := r.Context() - logger := util.GetLogger(ctx) - logger.Info("Updating project...") - projectID := chi.URLParam(r, "projectID") - - // convert project ID to int - id, err := strconv.Atoi(projectID) + ctx, logger, params, err := requestHelper(r) if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, ErrInvalidProjectID)) + render.Render(w, r, handler.FailureResponse(ctx, err)) return } + logger.Info("Updating project...", "projectID", params.ProjectID) // Decode the request body into the payload. var requestPayload request.UpdateProjectRequest @@ -151,54 +100,8 @@ func (h *Handler) UpdateProject() http.HandlerFunc { render.Render(w, r, handler.FailureResponse(ctx, err)) return } - // fmt.Printf("requestPayload.SourceID: %v; requestPayload.Organization: %v", requestPayload.SourceID, requestPayload.OrganizationID) - - // Convert request payload to domain model - var requestEntity entity.Project - if err := copier.Copy(&requestEntity, &requestPayload); err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - // Get source by id - sourceEntity, err := handler.GetSourceByID(ctx, h.sourceRepo, requestPayload.SourceID) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - requestEntity.Source = sourceEntity - - // Get organization by id - organizationEntity, err := handler.GetOrganizationByID(ctx, h.organizationRepo, requestPayload.OrganizationID) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - requestEntity.Organization = organizationEntity - - // Get the existing project by id - updatedEntity, err := h.projectRepo.Get(ctx, uint(id)) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - render.Render(w, r, handler.FailureResponse(ctx, ErrUpdatingNonExistingProject)) - return - } - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Overwrite non-zero values in request entity to existing entity - copier.CopyWithOption(updatedEntity, requestEntity, copier.Option{IgnoreEmpty: true}) - // fmt.Printf("updatedEntity.Source: %v; updatedEntity.Organization: %v", updatedEntity.Source, updatedEntity.Organization) - - // Update project with repository - err = h.projectRepo.Update(ctx, updatedEntity) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Return updated project + updatedEntity, err := h.projectManager.UpdateProjectByID(ctx, params.ProjectID, requestPayload) handler.HandleResult(w, r, ctx, err, updatedEntity) } } @@ -217,28 +120,14 @@ func (h *Handler) UpdateProject() http.HandlerFunc { func (h *Handler) GetProject() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Getting stuff from context - ctx := r.Context() - logger := util.GetLogger(ctx) - logger.Info("Getting project...") - projectID := chi.URLParam(r, "projectID") - - // Get project with repository - id, err := strconv.Atoi(projectID) + ctx, logger, params, err := requestHelper(r) if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, ErrInvalidProjectID)) - return - } - existingEntity, err := h.projectRepo.Get(ctx, uint(id)) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - render.Render(w, r, handler.FailureResponse(ctx, ErrGettingNonExistingProject)) - return - } render.Render(w, r, handler.FailureResponse(ctx, err)) return } + logger.Info("Getting project...", "projectID", params.ProjectID) - // Return found project + existingEntity, err := h.projectManager.GetProjectByID(ctx, params.ProjectID) handler.HandleResult(w, r, ctx, err, existingEntity) } } @@ -260,17 +149,22 @@ func (h *Handler) ListProjects() http.HandlerFunc { logger := util.GetLogger(ctx) logger.Info("Listing project...") - projectEntities, err := h.projectRepo.List(ctx) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - render.Render(w, r, handler.FailureResponse(ctx, ErrGettingNonExistingProject)) - return - } - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Return found projects + projectEntities, err := h.projectManager.ListProjects(ctx) handler.HandleResult(w, r, ctx, err, projectEntities) } } + +func requestHelper(r *http.Request) (context.Context, *logr.Logger, *ProjectRequestParams, error) { + ctx := r.Context() + projectID := chi.URLParam(r, "projectID") + // Get stack with repository + id, err := strconv.Atoi(projectID) + if err != nil { + return nil, nil, nil, projectmanager.ErrInvalidProjectID + } + logger := util.GetLogger(ctx) + params := ProjectRequestParams{ + ProjectID: uint(id), + } + return ctx, &logger, ¶ms, nil +} diff --git a/pkg/server/handler/project/handler_test.go b/pkg/server/handler/project/handler_test.go new file mode 100644 index 00000000..c2aa75cb --- /dev/null +++ b/pkg/server/handler/project/handler_test.go @@ -0,0 +1,333 @@ +package project + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "kusionstack.io/kusion/pkg/domain/constant" + "kusionstack.io/kusion/pkg/domain/request" + "kusionstack.io/kusion/pkg/infra/persistence" + "kusionstack.io/kusion/pkg/server/handler" + projectmanager "kusionstack.io/kusion/pkg/server/manager/project" +) + +func TestProjectHandler(t *testing.T) { + var ( + projectName = "test-project" + projectNameSecond = "test-project-2" + projectPath = "/path/to/project" + projectNameUpdated = "test-project-updated" + projectPathUpdated = "/path/to/project/updated" + owners = persistence.MultiString{"hua.li", "xiaoming.li"} + ) + t.Run("ListProjects", func(t *testing.T) { + sqlMock, fakeGDB, recorder, projectHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "path", "Organization__id", "Organization__name", "Organization__owners", "Source__id", "Source__remote", "Source__source_provider"}). + AddRow(1, projectName, projectPath, 1, "test-org", owners, 1, "https://github.com/test/repo", constant.SourceProviderTypeGithub). + AddRow(2, projectNameSecond, projectPath, 2, "test-org-2", owners, 1, "https://github.com/test/repo", constant.SourceProviderTypeGithub)) + + // Create a new HTTP request + req, err := http.NewRequest("GET", "/projects", nil) + assert.NoError(t, err) + + // Call the ListProjects handler function + projectHandler.ListProjects()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, 2, len(resp.Data.([]any))) + }) + + t.Run("GetProject", func(t *testing.T) { + sqlMock, fakeGDB, recorder, projectHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "path", "Organization__id", "Organization__name", "Organization__owners", "Source__id", "Source__remote", "Source__source_provider"}). + AddRow(1, projectName, projectPath, 1, "test-org", owners, 1, "https://github.com/test/repo", constant.SourceProviderTypeGithub)) + + // Create a new HTTP request + req, err := http.NewRequest("GET", "/project/{projectID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("projectID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Call the ListProjects handler function + projectHandler.GetProject()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshal the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, float64(1), resp.Data.(map[string]any)["id"]) + assert.Equal(t, projectName, resp.Data.(map[string]any)["name"]) + }) + + t.Run("CreateProject", func(t *testing.T) { + sqlMock, fakeGDB, recorder, projectHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Create a new HTTP request + req, err := http.NewRequest("POST", "/project/{projectID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("projectID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Set request body + requestPayload := request.CreateProjectRequest{ + Name: projectName, + Path: projectPath, + SourceID: uint(1), + OrganizationID: uint(1), + } + reqBody, err := json.Marshal(requestPayload) + assert.NoError(t, err) + req.Body = io.NopCloser(bytes.NewReader(reqBody)) + req.Header.Add("Content-Type", "application/json") + + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id", "remote", "source_provider"}). + AddRow(1, "https://github.com/test/repo", constant.SourceProviderTypeGithub)) + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owners"}). + AddRow(1, "test-org", owners)) + sqlMock.ExpectBegin() + sqlMock.ExpectExec("INSERT"). + WillReturnResult(sqlmock.NewResult(int64(1), int64(1))) + sqlMock.ExpectCommit() + + // Call the CreateProject handler function + projectHandler.CreateProject()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshal the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, float64(1), resp.Data.(map[string]any)["id"]) + assert.Equal(t, projectName, resp.Data.(map[string]any)["name"]) + }) + + t.Run("UpdateExistingProject", func(t *testing.T) { + sqlMock, fakeGDB, recorder, projectHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Update a new HTTP request + req, err := http.NewRequest("PUT", "/project/{projectID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("projectID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Set request body + requestPayload := request.UpdateProjectRequest{ + // Set your request payload fields here + ID: 1, + Name: projectNameUpdated, + Path: projectPathUpdated, + OrganizationID: 1, + SourceID: 1, + } + reqBody, err := json.Marshal(requestPayload) + assert.NoError(t, err) + req.Body = io.NopCloser(bytes.NewReader(reqBody)) + req.Header.Add("Content-Type", "application/json") + + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id", "remote", "source_provider"}). + AddRow(1, "https://github.com/test/repo", constant.SourceProviderTypeGithub)) + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owners"}). + AddRow(1, "test-org", owners)) + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "path", "Organization__id", "Organization__name", "Organization__owners", "Source__id", "Source__remote", "Source__source_provider"}). + AddRow(1, projectName, projectPath, 1, "test-org", owners, 1, "https://github.com/test/repo", constant.SourceProviderTypeGithub)) + sqlMock.ExpectExec("UPDATE"). + WillReturnResult(sqlmock.NewResult(int64(1), int64(1))) + + // Call the ListProjects handler function + projectHandler.UpdateProject()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, float64(1), resp.Data.(map[string]any)["id"]) + assert.Equal(t, projectNameUpdated, resp.Data.(map[string]any)["name"]) + assert.Equal(t, projectPathUpdated, resp.Data.(map[string]any)["path"]) + }) + + t.Run("Delete Existing Project", func(t *testing.T) { + sqlMock, fakeGDB, recorder, projectHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Create a new HTTP request + req, err := http.NewRequest("DELETE", "/project/{projectID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("projectID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Mock the Delete method of the project repository + sqlMock.ExpectBegin() + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id"}). + AddRow(1)) + sqlMock.ExpectExec("UPDATE"). + WillReturnResult(sqlmock.NewResult(1, 1)) + sqlMock.ExpectCommit() + + // Call the DeleteProject handler function + projectHandler.DeleteProject()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + assert.Equal(t, "Deletion Success", resp.Data) + }) + + t.Run("Delete Nonexisting Project", func(t *testing.T) { + sqlMock, fakeGDB, recorder, projectHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Create a new HTTP request + req, err := http.NewRequest("DELETE", "/project/{projectID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("projectID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + sqlMock.ExpectBegin() + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id"})) + + // Call the DeleteProject handler function + projectHandler.DeleteProject()(recorder, req) + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, false, resp.Success) + assert.Equal(t, projectmanager.ErrGettingNonExistingProject.Error(), resp.Message) + }) + + t.Run("Update Nonexisting Project", func(t *testing.T) { + sqlMock, fakeGDB, recorder, projectHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Update a new HTTP request + req, err := http.NewRequest("POST", "/project/{projectID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("projectID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Set request body + requestPayload := request.UpdateProjectRequest{ + // Set your request payload fields here + ID: 1, + Name: "test-project-updated", + Path: projectPathUpdated, + } + reqBody, err := json.Marshal(requestPayload) + assert.NoError(t, err) + req.Body = io.NopCloser(bytes.NewReader(reqBody)) + req.Header.Add("Content-Type", "application/json") + + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id", "remote", "source_provider"}). + AddRow(1, "https://github.com/test/repo", constant.SourceProviderTypeGithub)) + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owners"}). + AddRow(1, "test-org", owners)) + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id"})) + + // Call the UpdateProject handler function + projectHandler.UpdateProject()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, false, resp.Success) + assert.Equal(t, projectmanager.ErrUpdatingNonExistingProject.Error(), resp.Message) + }) +} + +func setupTest(t *testing.T) (sqlmock.Sqlmock, *gorm.DB, *httptest.ResponseRecorder, *Handler) { + fakeGDB, sqlMock, err := persistence.GetMockDB() + require.NoError(t, err) + projectRepo := persistence.NewProjectRepository(fakeGDB) + sourceRepo := persistence.NewSourceRepository(fakeGDB) + organizationRepo := persistence.NewOrganizationRepository(fakeGDB) + projectHandler := &Handler{ + projectManager: projectmanager.NewProjectManager(projectRepo, organizationRepo, sourceRepo), + } + recorder := httptest.NewRecorder() + return sqlMock, fakeGDB, recorder, projectHandler +} diff --git a/pkg/server/handler/project/types.go b/pkg/server/handler/project/types.go index 1d706b91..214db948 100644 --- a/pkg/server/handler/project/types.go +++ b/pkg/server/handler/project/types.go @@ -1,33 +1,21 @@ package project import ( - "errors" - - "kusionstack.io/kusion/pkg/domain/repository" -) - -var ( - ErrGettingNonExistingProject = errors.New("the project does not exist") - ErrUpdatingNonExistingProject = errors.New("the project to update does not exist") - ErrSourceNotFound = errors.New("the specified source does not exist") - ErrOrgNotFound = errors.New("the specified org does not exist") - ErrInvalidProjectID = errors.New("the project ID should be a uuid") + projectmanager "kusionstack.io/kusion/pkg/server/manager/project" ) func NewHandler( - organizationRepo repository.OrganizationRepository, - projectRepo repository.ProjectRepository, - sourceRepo repository.SourceRepository, + projectManager *projectmanager.ProjectManager, ) (*Handler, error) { return &Handler{ - organizationRepo: organizationRepo, - projectRepo: projectRepo, - sourceRepo: sourceRepo, + projectManager: projectManager, }, nil } type Handler struct { - organizationRepo repository.OrganizationRepository - projectRepo repository.ProjectRepository - sourceRepo repository.SourceRepository + projectManager *projectmanager.ProjectManager +} + +type ProjectRequestParams struct { + ProjectID uint } diff --git a/pkg/server/handler/source/handler.go b/pkg/server/handler/source/handler.go index 184c47a1..e223a986 100644 --- a/pkg/server/handler/source/handler.go +++ b/pkg/server/handler/source/handler.go @@ -1,18 +1,16 @@ package source import ( + "context" "net/http" - "net/url" "strconv" "github.com/go-chi/chi/v5" "github.com/go-chi/render" - "github.com/jinzhu/copier" - "github.com/pkg/errors" - "gorm.io/gorm" - "kusionstack.io/kusion/pkg/domain/entity" + "github.com/go-logr/logr" "kusionstack.io/kusion/pkg/domain/request" "kusionstack.io/kusion/pkg/server/handler" + sourcemanager "kusionstack.io/kusion/pkg/server/manager/source" "kusionstack.io/kusion/pkg/server/util" ) @@ -42,29 +40,8 @@ func (h *Handler) CreateSource() http.HandlerFunc { return } - // Convert request payload to domain model - var createdEntity entity.Source - if err := copier.Copy(&createdEntity, &requestPayload); err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Convert Remote string to URL - remote, err := url.Parse(requestPayload.Remote) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - createdEntity.Remote = remote - - // Create source with repository - err = h.sourceRepo.Create(ctx, &createdEntity) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - // Return created entity + createdEntity, err := h.sourceManager.CreateSource(ctx, requestPayload) handler.HandleResult(w, r, ctx, err, createdEntity) } } @@ -83,22 +60,14 @@ func (h *Handler) CreateSource() http.HandlerFunc { func (h *Handler) DeleteSource() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Getting stuff from context - ctx := r.Context() - logger := util.GetLogger(ctx) - logger.Info("Deleting source...") - sourceID := chi.URLParam(r, "sourceID") - - // Delete source with repository - id, err := strconv.Atoi(sourceID) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, ErrInvalidSourceID)) - return - } - err = h.sourceRepo.Delete(ctx, uint(id)) + ctx, logger, params, err := requestHelper(r) if err != nil { render.Render(w, r, handler.FailureResponse(ctx, err)) return } + logger.Info("Deleting source...") + + err = h.sourceManager.DeleteSourceByID(ctx, params.SourceID) handler.HandleResult(w, r, ctx, err, "Deletion Success") } } @@ -118,17 +87,12 @@ func (h *Handler) DeleteSource() http.HandlerFunc { func (h *Handler) UpdateSource() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Getting stuff from context - ctx := r.Context() - logger := util.GetLogger(ctx) - logger.Info("Updating source...") - sourceID := chi.URLParam(r, "sourceID") - - // Convert sourceID to int - id, err := strconv.Atoi(sourceID) + ctx, logger, params, err := requestHelper(r) if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, ErrInvalidSourceID)) + render.Render(w, r, handler.FailureResponse(ctx, err)) return } + logger.Info("Updating source...") // Decode the request body into the payload. var requestPayload request.UpdateSourceRequest @@ -137,43 +101,8 @@ func (h *Handler) UpdateSource() http.HandlerFunc { return } - // Convert request payload to domain model - var requestEntity entity.Source - if err := copier.Copy(&requestEntity, &requestPayload); err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Convert Remote string to URL - remote, err := url.Parse(requestPayload.Remote) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - requestEntity.Remote = remote - - // Get the existing source by id - updatedEntity, err := h.sourceRepo.Get(ctx, uint(id)) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - render.Render(w, r, handler.FailureResponse(ctx, ErrUpdatingNonExistingSource)) - return - } - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Overwrite non-zero values in request entity to existing entity - copier.CopyWithOption(updatedEntity, requestEntity, copier.Option{IgnoreEmpty: true}) - - // Update source with repository - err = h.sourceRepo.Update(ctx, updatedEntity) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - // Return updated source + updatedEntity, err := h.sourceManager.UpdateSourceByID(ctx, params.SourceID, requestPayload) handler.HandleResult(w, r, ctx, err, updatedEntity) } } @@ -192,28 +121,14 @@ func (h *Handler) UpdateSource() http.HandlerFunc { func (h *Handler) GetSource() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Getting stuff from context - ctx := r.Context() - logger := util.GetLogger(ctx) - logger.Info("Getting source...") - sourceID := chi.URLParam(r, "sourceID") - - // Get source with repository - id, err := strconv.Atoi(sourceID) + ctx, logger, params, err := requestHelper(r) if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, ErrInvalidSourceID)) - return - } - existingEntity, err := h.sourceRepo.Get(ctx, uint(id)) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - render.Render(w, r, handler.FailureResponse(ctx, ErrGettingNonExistingSource)) - return - } render.Render(w, r, handler.FailureResponse(ctx, err)) return } + logger.Info("Getting source...") - // Return found source + existingEntity, err := h.sourceManager.GetSourceByID(ctx, params.SourceID) handler.HandleResult(w, r, ctx, err, existingEntity) } } @@ -235,17 +150,23 @@ func (h *Handler) ListSources() http.HandlerFunc { logger := util.GetLogger(ctx) logger.Info("Listing source...") - existingEntity, err := h.sourceRepo.List(ctx) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - render.Render(w, r, handler.FailureResponse(ctx, ErrGettingNonExistingSource)) - return - } - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } + // List sources + sourceEntities, err := h.sourceManager.ListSources(ctx) + handler.HandleResult(w, r, ctx, err, sourceEntities) + } +} - // Return found source - handler.HandleResult(w, r, ctx, err, existingEntity) +func requestHelper(r *http.Request) (context.Context, *logr.Logger, *SourceRequestParams, error) { + ctx := r.Context() + sourceID := chi.URLParam(r, "sourceID") + // Get stack with repository + id, err := strconv.Atoi(sourceID) + if err != nil { + return nil, nil, nil, sourcemanager.ErrInvalidSourceID + } + logger := util.GetLogger(ctx) + params := SourceRequestParams{ + SourceID: uint(id), } + return ctx, &logger, ¶ms, nil } diff --git a/pkg/server/handler/source/handler_test.go b/pkg/server/handler/source/handler_test.go index 3ad3d886..23376a60 100644 --- a/pkg/server/handler/source/handler_test.go +++ b/pkg/server/handler/source/handler_test.go @@ -18,6 +18,7 @@ import ( "kusionstack.io/kusion/pkg/domain/request" "kusionstack.io/kusion/pkg/infra/persistence" "kusionstack.io/kusion/pkg/server/handler" + sourcemanager "kusionstack.io/kusion/pkg/server/manager/source" ) func TestSourceHandler(t *testing.T) { @@ -242,7 +243,7 @@ func TestSourceHandler(t *testing.T) { assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, resp.Success, false) - assert.Equal(t, resp.Message, gorm.ErrRecordNotFound.Error()) + assert.Equal(t, resp.Message, sourcemanager.ErrGettingNonExistingSource.Error()) }) t.Run("Update Nonexisting Source", func(t *testing.T) { @@ -287,7 +288,7 @@ func TestSourceHandler(t *testing.T) { // Assertion assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, resp.Success, false) - assert.Equal(t, resp.Message, ErrUpdatingNonExistingSource.Error()) + assert.Equal(t, resp.Message, sourcemanager.ErrUpdatingNonExistingSource.Error()) }) } @@ -296,7 +297,7 @@ func setupTest(t *testing.T) (sqlmock.Sqlmock, *gorm.DB, *httptest.ResponseRecor require.NoError(t, err) repo := persistence.NewSourceRepository(fakeGDB) sourceHandler := &Handler{ - sourceRepo: repo, + sourceManager: sourcemanager.NewSourceManager(repo), } recorder := httptest.NewRecorder() return sqlMock, fakeGDB, recorder, sourceHandler diff --git a/pkg/server/handler/source/types.go b/pkg/server/handler/source/types.go index 784532ba..1f58b1b5 100644 --- a/pkg/server/handler/source/types.go +++ b/pkg/server/handler/source/types.go @@ -1,25 +1,21 @@ package source import ( - "errors" - - "kusionstack.io/kusion/pkg/domain/repository" -) - -var ( - ErrGettingNonExistingSource = errors.New("the source does not exist") - ErrUpdatingNonExistingSource = errors.New("the source to update does not exist") - ErrInvalidSourceID = errors.New("the source ID should be a uuid") + sourcemanager "kusionstack.io/kusion/pkg/server/manager/source" ) func NewHandler( - sourceRepo repository.SourceRepository, + sourceManager *sourcemanager.SourceManager, ) (*Handler, error) { return &Handler{ - sourceRepo: sourceRepo, + sourceManager: sourceManager, }, nil } type Handler struct { - sourceRepo repository.SourceRepository + sourceManager *sourcemanager.SourceManager +} + +type SourceRequestParams struct { + SourceID uint } diff --git a/pkg/server/handler/stack/execute.go b/pkg/server/handler/stack/execute.go index d86a4a86..304ca44e 100644 --- a/pkg/server/handler/stack/execute.go +++ b/pkg/server/handler/stack/execute.go @@ -1,24 +1,16 @@ package stack import ( - "encoding/json" - "errors" - "fmt" + "context" "net/http" - "os" "strconv" - "time" "github.com/go-chi/chi/v5" "github.com/go-chi/render" + "github.com/go-logr/logr" yamlv2 "gopkg.in/yaml.v2" - "gorm.io/gorm" - apiv1 "kusionstack.io/kusion/pkg/apis/api.kusion.io/v1" - "kusionstack.io/kusion/pkg/backend" - "kusionstack.io/kusion/pkg/domain/constant" - engineapi "kusionstack.io/kusion/pkg/engine/api" - sourceapi "kusionstack.io/kusion/pkg/engine/api/source" "kusionstack.io/kusion/pkg/server/handler" + stackmanager "kusionstack.io/kusion/pkg/server/manager/stack" "kusionstack.io/kusion/pkg/server/util" ) @@ -36,148 +28,31 @@ import ( func (h *Handler) PreviewStack() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Getting stuff from context - ctx := r.Context() - logger := util.GetLogger(ctx) - logger.Info("Previewing stack...") - // Get params from URL parameter - stackID := chi.URLParam(r, "stackID") - - // Get params from query parameter - formatParam := r.URL.Query().Get("output") - // TODO: Define default behaviors - detailParam, _ := strconv.ParseBool(r.URL.Query().Get("detail")) - // kpmParam, _ := strconv.ParseBool(r.URL.Query().Get("kpm")) - // TODO: Should match automatically eventually - workspaceParam := r.URL.Query().Get("workspace") - - // Get stack with repository - id, err := strconv.Atoi(stackID) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, ErrInvalidStacktID)) - return - } - stackEntity, err := h.stackRepo.Get(ctx, uint(id)) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - render.Render(w, r, handler.FailureResponse(ctx, ErrGettingNonExistingStack)) - return - } - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Get project by id - project, err := stackEntity.Project.ConvertToCore() - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Get stack by id - stack, err := stackEntity.ConvertToCore() - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // get workspace configurations - bk, err := backend.NewBackend("") - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - wsStorage, err := bk.WorkspaceStorage() - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - ws, err := wsStorage.Get(workspaceParam) + ctx, logger, params, err := requestHelper(r) if err != nil { render.Render(w, r, handler.FailureResponse(ctx, err)) return } + logger.Info("Previewing stack...", "stackID", params.StackID) - // Build API inputs - // get project to get source and workdir - projectEntity, err := handler.GetProjectByID(ctx, h.projectRepo, stackEntity.Project.ID) + // Call preview stack + changes, err := h.stackManager.PreviewStack(ctx, params.StackID, params.Workspace) if err != nil { render.Render(w, r, handler.FailureResponse(ctx, err)) return } - directory, workDir, err := getWorkDirFromSource(ctx, stackEntity, projectEntity) + previewChanges, err := stackmanager.ProcessChanges(ctx, w, changes, params.Format, params.Detail) if err != nil { render.Render(w, r, handler.FailureResponse(ctx, err)) return } - previewOptions := buildOptions(false) - stack.Path = workDir - - // Cleanup - defer sourceapi.Cleanup(ctx, directory) - - // Generate spec - sp, err := engineapi.GenerateSpecWithSpinner(project, stack, ws, true) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // return immediately if no resource found in stack - // todo: if there is no resource, should still do diff job; for now, if output is json format, there is no hint - if sp == nil || len(sp.Resources) == 0 { - if formatParam != engineapi.JSONOutput { - logger.Info("No resource change found in this stack...") - render.Render(w, r, handler.SuccessResponse(ctx, "No resource change found in this stack.")) - return - } - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Compute state storage - // TODO: this local storage is temporary, will support remote later - stateStorage := bk.StateStorage(project.Name, stack.Name, ws.Name) - logger.Info("Local state storage found", "Path", stateStorage) - - // Compute changes for preview - changes, err := engineapi.Preview(previewOptions, stateStorage, sp, project, stack) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // If output format is json, return details without any summary or formatting - if formatParam == engineapi.JSONOutput { - var previewChanges []byte - previewChanges, err = json.Marshal(changes) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - logger.Info(string(previewChanges)) - render.Render(w, r, handler.SuccessResponse(ctx, string(previewChanges))) - return - } - - if changes.AllUnChange() { - logger.Info("All resources are reconciled. No diff found") - render.Render(w, r, handler.SuccessResponse(ctx, "All resources are reconciled. No diff found")) - return - } - - // Summary preview table - changes.Summary(w, true) - - // Detail detection - if detailParam { - render.Render(w, r, handler.SuccessResponse(ctx, changes.Diffs(true))) - } + render.Render(w, r, handler.SuccessResponse(ctx, previewChanges)) } } -// @Summary Build stack -// @Description Build stack information by stack ID +// @Summary Generate stack +// @Description Generate stack information by stack ID // @Produce json // @Param id path int true "Stack ID" // @Success 200 {object} entity.Stack "Success" @@ -186,100 +61,25 @@ func (h *Handler) PreviewStack() http.HandlerFunc { // @Failure 429 {object} errors.DetailError "Too Many Requests" // @Failure 404 {object} errors.DetailError "Not Found" // @Failure 500 {object} errors.DetailError "Internal Server Error" -// @Router /api/v1/stack/{stackID}/build [post] -func (h *Handler) BuildStack() http.HandlerFunc { +// @Router /api/v1/stack/{stackID}/generate [post] +func (h *Handler) GenerateStack() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Getting stuff from context - ctx := r.Context() - logger := util.GetLogger(ctx) - logger.Info("Building stack...") - // Get params from URL parameter - stackID := chi.URLParam(r, "stackID") - // TODO: Define default behaviors - // kpmParam, _ := strconv.ParseBool(r.URL.Query().Get("kpm")) - // TODO: Should match automatically eventually - workspaceParam := r.URL.Query().Get("workspace") - - // Get stack with repository - id, err := strconv.Atoi(stackID) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, ErrInvalidStacktID)) - return - } - stackEntity, err := h.stackRepo.Get(ctx, uint(id)) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - render.Render(w, r, handler.FailureResponse(ctx, ErrGettingNonExistingStack)) - return - } - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Get project by id - project, err := stackEntity.Project.ConvertToCore() - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Get stack by id - stack, err := stackEntity.ConvertToCore() - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // get workspace configurations - bk, err := backend.NewBackend("") - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - wsStorage, err := bk.WorkspaceStorage() - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - ws, err := wsStorage.Get(workspaceParam) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Build API inputs - // get project to get source and workdir - projectEntity, err := handler.GetProjectByID(ctx, h.projectRepo, stackEntity.Project.ID) + ctx, logger, params, err := requestHelper(r) if err != nil { render.Render(w, r, handler.FailureResponse(ctx, err)) return } + logger.Info("Generating stack...", "stackID", params.StackID) - directory, workDir, err := getWorkDirFromSource(ctx, stackEntity, projectEntity) - logger.Info("workDir derived", "workDir", workDir) - logger.Info("directory derived", "directory", directory) - - stack.Path = workDir - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - // intentOptions, _ := buildOptions(workDir, kpmParam, false) - // Cleanup - defer sourceapi.Cleanup(ctx, directory) - - // Generate spec - sp, err := engineapi.GenerateSpecWithSpinner(project, stack, ws, true) + // Call generate stack + sp, err := h.stackManager.GenerateStack(ctx, params.StackID, params.Workspace) if err != nil { render.Render(w, r, handler.FailureResponse(ctx, err)) return } yaml, err := yamlv2.Marshal(sp) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } handler.HandleResult(w, r, ctx, err, string(yaml)) } } @@ -298,205 +98,25 @@ func (h *Handler) BuildStack() http.HandlerFunc { func (h *Handler) ApplyStack() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Getting stuff from context - ctx := r.Context() - logger := util.GetLogger(ctx) - logger.Info("Applying stack...") - // Get params from URL parameter - stackID := chi.URLParam(r, "stackID") - - // Get params from query parameter - formatParam := r.URL.Query().Get("output") - dryRunParam, _ := strconv.ParseBool(r.URL.Query().Get("dryrun")) - // TODO: Define default behaviors - detailParam, _ := strconv.ParseBool(r.URL.Query().Get("detail")) - // kpmParam, _ := strconv.ParseBool(r.URL.Query().Get("kpm")) - // TODO: Should match automatically eventually - workspaceParam := r.URL.Query().Get("workspace") - - // Get stack with repository - id, err := strconv.Atoi(stackID) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, ErrInvalidStacktID)) - return - } - stackEntity, err := h.stackRepo.Get(ctx, uint(id)) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - render.Render(w, r, handler.FailureResponse(ctx, ErrGettingNonExistingStack)) - return - } - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Get project by id - project, err := stackEntity.Project.ConvertToCore() - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Get stack by id - stack, err := stackEntity.ConvertToCore() - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // get workspace configurations - localBackend, err := backend.NewBackend("") - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - // wsStorage, err := bk.WorkspaceStorage() - // if err != nil { - // render.Render(w, r, handler.FailureResponse(ctx, err)) - // return - // } - // ws, err := wsStorage.Get(workspaceParam) - // if err != nil { - // render.Render(w, r, handler.FailureResponse(ctx, err)) - // return - // } - - // // Get backend by id - // workspaceEntity, err := h.workspaceRepo.GetByName(ctx, workspaceParam) - // if err != nil && err == gorm.ErrRecordNotFound { - // render.Render(w, r, handler.FailureResponse(ctx, ErrWorkspaceNotFound)) - // return - // } else if err != nil { - // render.Render(w, r, handler.FailureResponse(ctx, err)) - // return - // } - // // Generate backend from entity - // remoteBackend, err := NewBackendFromEntity(*workspaceEntity.Backend) - // if err != nil { - // render.Render(w, r, handler.FailureResponse(ctx, err)) - // return - // } - - remoteBackend, err := h.GetBackendFromWorkspaceName(ctx, workspaceParam) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Get workspace configurations from backend - // TODO: temporarily local for now, should be replaced by variable sets - wsStorage, err := localBackend.WorkspaceStorage() - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - ws, err := wsStorage.Get(workspaceParam) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Build API inputs - // get project to get source and workdir - projectEntity, err := handler.GetProjectByID(ctx, h.projectRepo, stackEntity.Project.ID) + ctx, logger, params, err := requestHelper(r) if err != nil { render.Render(w, r, handler.FailureResponse(ctx, err)) return } + logger.Info("Applying stack...", "stackID", params.StackID) - directory, workDir, err := getWorkDirFromSource(ctx, stackEntity, projectEntity) + err = h.stackManager.ApplyStack(ctx, params.StackID, params.Workspace, params.Format, params.Detail, params.Dryrun, w) if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - // Cleanup - defer sourceapi.Cleanup(ctx, directory) - - executeOptions := buildOptions(dryRunParam) - stack.Path = workDir - - // Generate spec - sp, err := engineapi.GenerateSpecWithSpinner(project, stack, ws, true) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // return immediately if no resource found in stack - // todo: if there is no resource, should still do diff job; for now, if output is json format, there is no hint - if sp == nil || len(sp.Resources) == 0 { - if formatParam != engineapi.JSONOutput { - logger.Info("No resource change found in this stack...") - render.Render(w, r, handler.SuccessResponse(ctx, "No resource change found in this stack.")) + if err == stackmanager.ErrDryrunDestroy { + render.Render(w, r, handler.SuccessResponse(ctx, "Dry-run mode enabled, the above resources will be destroyed if dryrun is set to false")) return - } - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Compute state storage - // TODO: this local storage is temporary, will support remote later - stateStorage := remoteBackend.StateStorage(project.Name, stack.Name, workspaceParam) - logger.Info("Remote state storage found", "Remote", stateStorage) - // logger.Info("Local state storage found", "Path", stateStorage) - - // Compute changes for preview - changes, err := engineapi.Preview(executeOptions, stateStorage, sp, project, stack) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // If output format is json, return details without any summary or formatting - if formatParam == engineapi.JSONOutput { - var previewChanges []byte - previewChanges, err = json.Marshal(changes) - if err != nil { + } else { render.Render(w, r, handler.FailureResponse(ctx, err)) return } - logger.Info(string(previewChanges)) - render.Render(w, r, handler.SuccessResponse(ctx, string(previewChanges))) - return - } - - if changes.AllUnChange() { - logger.Info("All resources are reconciled. No diff found") - render.Render(w, r, handler.SuccessResponse(ctx, "All resources are reconciled. No diff found")) - return - } - - // Summary preview table - changes.Summary(w, true) - // detail detection - if detailParam { - changes.OutputDiff("all") - } - - logger.Info("Start applying diffs ...") - if err = engineapi.Apply(executeOptions, stateStorage, sp, changes, os.Stdout); err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // if dry run, print the hint - if dryRunParam { - fmt.Printf("NOTE: Currently running in the --dry-run mode, the above configuration does not really take effect") - render.Render(w, r, handler.SuccessResponse(ctx, "NOTE: Currently running in the --dry-run mode, the above configuration does not really take effect")) - return - } - - // Update LastSyncTimestamp to current time and set stack syncState to synced - stackEntity.LastSyncTimestamp = time.Now() - stackEntity.SyncState = constant.StackStateSynced - - // Update stack with repository - err = h.stackRepo.Update(ctx, stackEntity) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return } - // Destroy completed + // Apply completed logger.Info("apply completed") render.Render(w, r, handler.SuccessResponse(ctx, "apply completed")) @@ -524,143 +144,51 @@ func (h *Handler) ApplyStack() http.HandlerFunc { func (h *Handler) DestroyStack() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Getting stuff from context - ctx := r.Context() - logger := util.GetLogger(ctx) - logger.Info("Destroying stack...") - // Get params from URL parameter - stackID := chi.URLParam(r, "stackID") - // TODO: Define default behaviors - // kpmParam, _ := strconv.ParseBool(r.URL.Query().Get("kpm")) - // TODO: Should match automatically eventually - workspaceParam := r.URL.Query().Get("workspace") - detailParam, _ := strconv.ParseBool(r.URL.Query().Get("detail")) - dryRunParam, _ := strconv.ParseBool(r.URL.Query().Get("dryrun")) - - // Get stack with repository - id, err := strconv.Atoi(stackID) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, ErrInvalidStacktID)) - return - } - stackEntity, err := h.stackRepo.Get(ctx, uint(id)) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - render.Render(w, r, handler.FailureResponse(ctx, ErrGettingNonExistingStack)) - return - } - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Get project by id - project, err := stackEntity.Project.ConvertToCore() - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Get stack by id - stack, err := stackEntity.ConvertToCore() - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // get workspace configurations - // localBackend, err := backend.NewBackend("") - // if err != nil { - // render.Render(w, r, handler.FailureResponse(ctx, err)) - // return - // } - // wsStorage, err := bk.WorkspaceStorage() - // if err != nil { - // render.Render(w, r, handler.FailureResponse(ctx, err)) - // return - // } - // ws, err := wsStorage.Get(workspaceParam) - // if err != nil { - // render.Render(w, r, handler.FailureResponse(ctx, err)) - // return - // } - - remoteBackend, err := h.GetBackendFromWorkspaceName(ctx, workspaceParam) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Build API inputs - // get project to get source and workdir - projectEntity, err := handler.GetProjectByID(ctx, h.projectRepo, stackEntity.Project.ID) + ctx, logger, params, err := requestHelper(r) if err != nil { render.Render(w, r, handler.FailureResponse(ctx, err)) return } + logger.Info("Destroying stack...", "stackID", params.StackID) - directory, workDir, err := getWorkDirFromSource(ctx, stackEntity, projectEntity) + err = h.stackManager.DestroyStack(ctx, params.StackID, params.Workspace, params.Detail, params.Dryrun, w) if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - destroyOptions := buildOptions(dryRunParam) - stack.Path = workDir - - // Cleanup - defer sourceapi.Cleanup(ctx, directory) - - // Compute state storage - // TODO: this local storage is temporary, will support remote later - stateStorage := remoteBackend.StateStorage(project.Name, stack.Name, workspaceParam) - // logger.Info("Local state storage found", "Path", stateStorage) - logger.Info("Remote state storage found", "Remote", stateStorage) - - priorState, err := stateStorage.Get() - if err != nil || priorState == nil { - logger.Info("can't find state", "project", project.Name, "stack", stack.Name, "workspace", workspaceParam) - render.Render(w, r, handler.FailureResponse(ctx, ErrGettingNonExistingStateForStack)) - return - } - destroyResources := priorState.Resources - - if destroyResources == nil || len(priorState.Resources) == 0 { - render.Render(w, r, handler.SuccessResponse(ctx, "No managed resources to destroy")) - return - } - - // compute changes for preview - i := &apiv1.Spec{Resources: destroyResources} - changes, err := engineapi.DestroyPreview(destroyOptions, i, project, stack, stateStorage) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Summary preview table - changes.Summary(w, true) - // detail detection - if detailParam { - changes.OutputDiff("all") - } - - // if dryrun, print the hint - if dryRunParam { - fmt.Printf("Dry-run mode enabled, the above resources will be destroyed if dryrun is set to false") - render.Render(w, r, handler.SuccessResponse(ctx, "Dry-run mode enabled, the above resources will be destroyed if dryrun is set to false")) - return - } - - // Destroy - logger.Info("Start destroying resources......") - if err = engineapi.Destroy(destroyOptions, i, changes, stateStorage); err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return + if err == stackmanager.ErrDryrunDestroy { + render.Render(w, r, handler.SuccessResponse(ctx, "Dry-run mode enabled, the above resources will be destroyed if dryrun is set to false")) + return + } else { + render.Render(w, r, handler.FailureResponse(ctx, err)) + return + } } // Destroy completed logger.Info("destroy completed") render.Render(w, r, handler.SuccessResponse(ctx, "destroy completed")) + } +} - // Cleanup - sourceapi.Cleanup(ctx, directory) +func requestHelper(r *http.Request) (context.Context, *logr.Logger, *StackRequestParams, error) { + ctx := r.Context() + stackID := chi.URLParam(r, "stackID") + // Get stack with repository + id, err := strconv.Atoi(stackID) + if err != nil { + return nil, nil, nil, stackmanager.ErrInvalidStackID + } + logger := util.GetLogger(ctx) + // Get Params + detailParam, _ := strconv.ParseBool(r.URL.Query().Get("detail")) + dryrunParam, _ := strconv.ParseBool(r.URL.Query().Get("dryrun")) + outputParam := r.URL.Query().Get("output") + // TODO: Should match automatically eventually??? + workspaceParam := r.URL.Query().Get("workspace") + params := StackRequestParams{ + StackID: uint(id), + Workspace: workspaceParam, + Detail: detailParam, + Dryrun: dryrunParam, + Format: outputParam, } + return ctx, &logger, ¶ms, nil } diff --git a/pkg/server/handler/stack/handler.go b/pkg/server/handler/stack/handler.go index f217c0c0..920e6521 100644 --- a/pkg/server/handler/stack/handler.go +++ b/pkg/server/handler/stack/handler.go @@ -1,17 +1,9 @@ package stack import ( - "errors" "net/http" - "strconv" - "time" - "github.com/go-chi/chi/v5" "github.com/go-chi/render" - "github.com/jinzhu/copier" - "gorm.io/gorm" - "kusionstack.io/kusion/pkg/domain/constant" - "kusionstack.io/kusion/pkg/domain/entity" "kusionstack.io/kusion/pkg/domain/request" "kusionstack.io/kusion/pkg/server/handler" "kusionstack.io/kusion/pkg/server/util" @@ -35,7 +27,6 @@ func (h *Handler) CreateStack() http.HandlerFunc { ctx := r.Context() logger := util.GetLogger(ctx) logger.Info("Creating stack...") - // workspaceParam := chi.URLParam(r, "workspaceName") // Decode the request body into the payload. var requestPayload request.CreateStackRequest @@ -44,50 +35,7 @@ func (h *Handler) CreateStack() http.HandlerFunc { return } - // Convert request payload to domain model - var createdEntity entity.Stack - if err := copier.Copy(&createdEntity, &requestPayload); err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - // The default state is UnSynced - createdEntity.SyncState = constant.StackStateUnSynced - createdEntity.CreationTimestamp = time.Now() - createdEntity.UpdateTimestamp = time.Now() - createdEntity.LastSyncTimestamp = time.Unix(0, 0) // default to none - - // TODO: Only project ID should be needed here. Not source and org IDs. - // Get source by id - // sourceEntity, err := handler.GetSourceByID(ctx, h.sourceRepo, requestPayload.SourceID) - // if err != nil { - // render.Render(w, r, handler.FailureResponse(ctx, err)) - // return - // } - // createdEntity.Source = sourceEntity - - // Get project by id - projectEntity, err := handler.GetProjectByID(ctx, h.projectRepo, requestPayload.ProjectID) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - createdEntity.Project = projectEntity - - // // Get organization by id - // organizationEntity, err := handler.GetOrganizationByID(ctx, h.orgRepository, requestPayload.OrganizationID) - // if err != nil { - // render.Render(w, r, handler.FailureResponse(ctx, err)) - // return - // } - // createdEntity.Organization = organizationEntity - // TODO: Only project ID should be needed here. Not source and org IDs. - - // Create stack with repository - err = h.stackRepo.Create(ctx, &createdEntity) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } + createdEntity, err := h.stackManager.CreateStack(ctx, requestPayload) handler.HandleResult(w, r, ctx, err, createdEntity) } } @@ -107,22 +55,14 @@ func (h *Handler) CreateStack() http.HandlerFunc { func (h *Handler) DeleteStack() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Getting stuff from context - ctx := r.Context() - logger := util.GetLogger(ctx) - logger.Info("Deleting source...") - stackID := chi.URLParam(r, "stackID") - - // Delete stack with repository - id, err := strconv.Atoi(stackID) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, ErrInvalidStacktID)) - return - } - err = h.stackRepo.Delete(ctx, uint(id)) + ctx, logger, params, err := requestHelper(r) if err != nil { render.Render(w, r, handler.FailureResponse(ctx, err)) return } + logger.Info("Deleting source...", "stackID", params.StackID) + + err = h.stackManager.DeleteStackByID(ctx, params.StackID) handler.HandleResult(w, r, ctx, err, "Deletion Success") } } @@ -142,17 +82,12 @@ func (h *Handler) DeleteStack() http.HandlerFunc { func (h *Handler) UpdateStack() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Getting stuff from context - ctx := r.Context() - logger := util.GetLogger(ctx) - logger.Info("Updating stack...") - stackID := chi.URLParam(r, "stackID") - - // convert stack ID to int - id, err := strconv.Atoi(stackID) + ctx, logger, params, err := requestHelper(r) if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, ErrInvalidStacktID)) + render.Render(w, r, handler.FailureResponse(ctx, err)) return } + logger.Info("Updating stack...", "stackID", params.StackID) // Decode the request body into the payload. var requestPayload request.UpdateStackRequest @@ -161,61 +96,7 @@ func (h *Handler) UpdateStack() http.HandlerFunc { return } - // Convert request payload to domain model - var requestEntity entity.Stack - if err := copier.Copy(&requestEntity, &requestPayload); err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // TODO: Only project ID should be needed here. Not source and org IDs. - // Get source by id - // sourceEntity, err := handler.GetSourceByID(ctx, h.sourceRepo, requestPayload.SourceID) - // if err != nil { - // render.Render(w, r, handler.FailureResponse(ctx, err)) - // return - // } - // requestEntity.Source = sourceEntity - - // Get project by id - projectEntity, err := handler.GetProjectByID(ctx, h.projectRepo, requestPayload.ProjectID) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - requestEntity.Project = projectEntity - - // // Get organization by id - // organizationEntity, err := handler.GetOrganizationByID(ctx, h.orgRepository, requestPayload.OrganizationID) - // if err != nil { - // render.Render(w, r, handler.FailureResponse(ctx, err)) - // return - // } - // requestEntity.Organization = organizationEntity - // TODO: Only project ID should be needed here. Not source and org IDs. - - // Get the existing stack by id - updatedEntity, err := h.stackRepo.Get(ctx, uint(id)) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - render.Render(w, r, handler.FailureResponse(ctx, ErrUpdatingNonExistingStack)) - return - } - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Overwrite non-zero values in request entity to existing entity - copier.CopyWithOption(updatedEntity, requestEntity, copier.Option{IgnoreEmpty: true}) - - // Update stack with repository - err = h.stackRepo.Update(ctx, updatedEntity) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Return updated stack + updatedEntity, err := h.stackManager.UpdateStackByID(ctx, params.StackID, requestPayload) handler.HandleResult(w, r, ctx, err, updatedEntity) } } @@ -234,28 +115,14 @@ func (h *Handler) UpdateStack() http.HandlerFunc { func (h *Handler) GetStack() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Getting stuff from context - ctx := r.Context() - logger := util.GetLogger(ctx) - logger.Info("Getting stack...") - stackID := chi.URLParam(r, "stackID") - - // Get stack with repository - id, err := strconv.Atoi(stackID) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, ErrInvalidStacktID)) - return - } - existingEntity, err := h.stackRepo.Get(ctx, uint(id)) + ctx, logger, params, err := requestHelper(r) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - render.Render(w, r, handler.FailureResponse(ctx, ErrGettingNonExistingStack)) - return - } render.Render(w, r, handler.FailureResponse(ctx, err)) return } + logger.Info("Getting stack...", "stackID", params.StackID) - // Return found stack + existingEntity, err := h.stackManager.GetStackByID(ctx, params.StackID) handler.HandleResult(w, r, ctx, err, existingEntity) } } @@ -277,17 +144,7 @@ func (h *Handler) ListStacks() http.HandlerFunc { logger := util.GetLogger(ctx) logger.Info("Listing stack...") - stackEntities, err := h.stackRepo.List(ctx) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - render.Render(w, r, handler.FailureResponse(ctx, ErrGettingNonExistingStack)) - return - } - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Return found stacks + stackEntities, err := h.stackManager.ListStacks(ctx) handler.HandleResult(w, r, ctx, err, stackEntities) } } diff --git a/pkg/server/handler/stack/handler_test.go b/pkg/server/handler/stack/handler_test.go new file mode 100644 index 00000000..acf18350 --- /dev/null +++ b/pkg/server/handler/stack/handler_test.go @@ -0,0 +1,330 @@ +package stack + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "kusionstack.io/kusion/pkg/domain/constant" + "kusionstack.io/kusion/pkg/domain/request" + "kusionstack.io/kusion/pkg/infra/persistence" + "kusionstack.io/kusion/pkg/server/handler" + stackmanager "kusionstack.io/kusion/pkg/server/manager/stack" +) + +func TestStackHandler(t *testing.T) { + var ( + stackName = "test-stack" + stackNameSecond = "test-stack-2" + projectName = "test-project" + projectPath = "/path/to/project" + stackPath = "/path/to/stack" + stackNameUpdated = "test-stack-updated" + stackPathUpdated = "/path/to/stack/updated" + owners = persistence.MultiString{"hua.li", "xiaoming.li"} + ) + t.Run("ListStacks", func(t *testing.T) { + sqlMock, fakeGDB, recorder, stackHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "path", "sync_state", "Project__id", "Project__name", "Project__path"}). + AddRow(1, stackName, stackPath, constant.StackStateUnSynced, 1, projectName, projectPath). + AddRow(2, stackNameSecond, stackPath, constant.StackStateUnSynced, 2, projectName, projectPath)) + + // Create a new HTTP request + req, err := http.NewRequest("GET", "/stacks", nil) + assert.NoError(t, err) + + // Call the ListStacks handler function + stackHandler.ListStacks()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, 2, len(resp.Data.([]any))) + }) + + t.Run("GetStack", func(t *testing.T) { + sqlMock, fakeGDB, recorder, stackHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "path", "sync_state", "Project__id", "Project__name", "Project__path"}). + AddRow(1, stackName, stackPath, constant.StackStateUnSynced, 1, projectName, projectPath)) + + // Create a new HTTP request + req, err := http.NewRequest("GET", "/stack/{stackID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("stackID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Call the ListStacks handler function + stackHandler.GetStack()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshal the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, float64(1), resp.Data.(map[string]any)["id"]) + assert.Equal(t, stackName, resp.Data.(map[string]any)["name"]) + assert.Equal(t, stackPath, resp.Data.(map[string]any)["path"]) + assert.Equal(t, float64(1), resp.Data.(map[string]any)["project"].(map[string]any)["id"]) + }) + + t.Run("CreateStack", func(t *testing.T) { + sqlMock, fakeGDB, recorder, stackHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Create a new HTTP request + req, err := http.NewRequest("POST", "/stack/{stackID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("stackID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Set request body + requestPayload := request.CreateStackRequest{ + Name: stackName, + Path: stackPath, + DesiredVersion: "latest", + ProjectID: 1, + } + reqBody, err := json.Marshal(requestPayload) + assert.NoError(t, err) + req.Body = io.NopCloser(bytes.NewReader(reqBody)) + req.Header.Add("Content-Type", "application/json") + + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "path", "Organization__id", "Organization__name", "Organization__owners", "Source__id", "Source__remote", "Source__source_provider"}). + AddRow(1, projectName, projectPath, 1, "test-org", owners, 1, "https://github.com/test/repo", constant.SourceProviderTypeGithub)) + sqlMock.ExpectBegin() + sqlMock.ExpectExec("INSERT"). + WillReturnResult(sqlmock.NewResult(int64(1), int64(1))) + sqlMock.ExpectCommit() + + // Call the CreateStack handler function + stackHandler.CreateStack()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshal the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, float64(1), resp.Data.(map[string]any)["id"]) + assert.Equal(t, stackName, resp.Data.(map[string]any)["name"]) + assert.Equal(t, stackPath, resp.Data.(map[string]any)["path"]) + assert.Equal(t, "latest", resp.Data.(map[string]any)["desiredVersion"]) + assert.Equal(t, float64(1), resp.Data.(map[string]any)["project"].(map[string]any)["id"]) + }) + + t.Run("UpdateExistingStack", func(t *testing.T) { + sqlMock, fakeGDB, recorder, stackHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Update a new HTTP request + req, err := http.NewRequest("PUT", "/stack/{stackID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("stackID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Set request body + requestPayload := request.UpdateStackRequest{ + // Set your request payload fields here + ID: 1, + Name: stackNameUpdated, + Path: stackPathUpdated, + ProjectID: 1, + } + reqBody, err := json.Marshal(requestPayload) + assert.NoError(t, err) + req.Body = io.NopCloser(bytes.NewReader(reqBody)) + req.Header.Add("Content-Type", "application/json") + + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "path", "Organization__id", "Organization__name", "Organization__owners", "Source__id", "Source__remote", "Source__source_provider"}). + AddRow(1, stackName, stackPath, 1, "test-org", owners, 1, "https://github.com/test/repo", constant.SourceProviderTypeGithub)) + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "path", "sync_state", "Project__id", "Project__name", "Project__path"}). + AddRow(1, stackName, stackPath, constant.StackStateUnSynced, 1, projectName, projectPath)) + sqlMock.ExpectExec("UPDATE"). + WillReturnResult(sqlmock.NewResult(int64(1), int64(1))) + + // Call the ListStacks handler function + stackHandler.UpdateStack()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, float64(1), resp.Data.(map[string]any)["id"]) + assert.Equal(t, stackNameUpdated, resp.Data.(map[string]any)["name"]) + assert.Equal(t, stackPathUpdated, resp.Data.(map[string]any)["path"]) + }) + + t.Run("Delete Existing Stack", func(t *testing.T) { + sqlMock, fakeGDB, recorder, stackHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Create a new HTTP request + req, err := http.NewRequest("DELETE", "/stack/{stackID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("stackID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Mock the Delete method of the stack repository + sqlMock.ExpectBegin() + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id"}). + AddRow(1)) + sqlMock.ExpectExec("UPDATE"). + WillReturnResult(sqlmock.NewResult(1, 1)) + sqlMock.ExpectCommit() + + // Call the DeleteStack handler function + stackHandler.DeleteStack()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + assert.Equal(t, "Deletion Success", resp.Data) + }) + + t.Run("Delete Nonexisting Stack", func(t *testing.T) { + sqlMock, fakeGDB, recorder, stackHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Create a new HTTP request + req, err := http.NewRequest("DELETE", "/stack/{stackID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("stackID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + sqlMock.ExpectBegin() + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id"})) + + // Call the DeleteStack handler function + stackHandler.DeleteStack()(recorder, req) + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, false, resp.Success) + assert.Equal(t, stackmanager.ErrGettingNonExistingStack.Error(), resp.Message) + }) + + t.Run("Update Nonexisting Stack", func(t *testing.T) { + sqlMock, fakeGDB, recorder, stackHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Update a new HTTP request + req, err := http.NewRequest("POST", "/stack/{stackID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("stackID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Set request body + requestPayload := request.UpdateStackRequest{ + // Set your request payload fields here + ID: 1, + Name: "test-stack-updated", + Path: stackPathUpdated, + } + reqBody, err := json.Marshal(requestPayload) + assert.NoError(t, err) + req.Body = io.NopCloser(bytes.NewReader(reqBody)) + req.Header.Add("Content-Type", "application/json") + + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "path", "Organization__id", "Organization__name", "Organization__owners", "Source__id", "Source__remote", "Source__source_provider"}). + AddRow(1, stackName, stackPath, 1, "test-org", owners, 1, "https://github.com/test/repo", constant.SourceProviderTypeGithub)) + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id"})) + + // Call the UpdateStack handler function + stackHandler.UpdateStack()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, false, resp.Success) + assert.Equal(t, stackmanager.ErrUpdatingNonExistingStack.Error(), resp.Message) + }) +} + +func setupTest(t *testing.T) (sqlmock.Sqlmock, *gorm.DB, *httptest.ResponseRecorder, *Handler) { + fakeGDB, sqlMock, err := persistence.GetMockDB() + require.NoError(t, err) + stackRepo := persistence.NewStackRepository(fakeGDB) + projectRepo := persistence.NewProjectRepository(fakeGDB) + workspaceRepo := persistence.NewWorkspaceRepository(fakeGDB) + stackHandler := &Handler{ + stackManager: stackmanager.NewStackManager(stackRepo, projectRepo, workspaceRepo), + } + recorder := httptest.NewRecorder() + return sqlMock, fakeGDB, recorder, stackHandler +} diff --git a/pkg/server/handler/stack/types.go b/pkg/server/handler/stack/types.go index 549a3279..5f7c41f5 100644 --- a/pkg/server/handler/stack/types.go +++ b/pkg/server/handler/stack/types.go @@ -1,43 +1,25 @@ package stack import ( - "errors" - - "kusionstack.io/kusion/pkg/domain/repository" -) - -const Stdout = "stdout" - -var ( - ErrGettingNonExistingStack = errors.New("the stack does not exist") - ErrUpdatingNonExistingStack = errors.New("the stack to update does not exist") - ErrSourceNotFound = errors.New("the specified source does not exist") - ErrWorkspaceNotFound = errors.New("the specified workspace does not exist") - ErrProjectNotFound = errors.New("the specified project does not exist") - ErrInvalidStacktID = errors.New("the stack ID should be a uuid") - ErrGettingNonExistingStateForStack = errors.New("can not find State in this stack") + stackmanager "kusionstack.io/kusion/pkg/server/manager/stack" ) func NewHandler( - orgRepository repository.OrganizationRepository, - projectRepo repository.ProjectRepository, - stackRepo repository.StackRepository, - sourceRepo repository.SourceRepository, - workspaceRepo repository.WorkspaceRepository, + stackManager *stackmanager.StackManager, ) (*Handler, error) { return &Handler{ - orgRepository: orgRepository, - stackRepo: stackRepo, - projectRepo: projectRepo, - sourceRepo: sourceRepo, - workspaceRepo: workspaceRepo, + stackManager: stackManager, }, nil } type Handler struct { - orgRepository repository.OrganizationRepository - projectRepo repository.ProjectRepository - stackRepo repository.StackRepository - sourceRepo repository.SourceRepository - workspaceRepo repository.WorkspaceRepository + stackManager *stackmanager.StackManager +} + +type StackRequestParams struct { + StackID uint + Workspace string + Format string + Detail bool + Dryrun bool } diff --git a/pkg/server/handler/types.go b/pkg/server/handler/types.go index 091a91ee..3d55744d 100644 --- a/pkg/server/handler/types.go +++ b/pkg/server/handler/types.go @@ -8,7 +8,6 @@ import ( ) var ( - ErrSourceDoesNotExist = errors.New("the source does not exist") ErrProjectDoesNotExist = errors.New("the project does not exist") ErrOrganizationDoesNotExist = errors.New("the organization does not exist") ErrStackDoesNotExist = errors.New("the stack does not exist") diff --git a/pkg/server/handler/util.go b/pkg/server/handler/util.go index 7a617df2..37ce1cec 100644 --- a/pkg/server/handler/util.go +++ b/pkg/server/handler/util.go @@ -8,6 +8,7 @@ import ( "gorm.io/gorm" "kusionstack.io/kusion/pkg/domain/entity" "kusionstack.io/kusion/pkg/domain/repository" + sourcemanager "kusionstack.io/kusion/pkg/server/manager/source" ) func HandleResult(w http.ResponseWriter, r *http.Request, ctx context.Context, err error, data any) { @@ -22,7 +23,7 @@ func GetSourceByID(ctx context.Context, sourceRepo repository.SourceRepository, // Get source by id sourceEntity, err := sourceRepo.Get(ctx, id) if err != nil && err == gorm.ErrRecordNotFound { - return nil, ErrSourceDoesNotExist + return nil, sourcemanager.ErrGettingNonExistingSource } else if err != nil { return nil, err } diff --git a/pkg/server/handler/workspace/handler.go b/pkg/server/handler/workspace/handler.go index 61db7c43..4050ee39 100644 --- a/pkg/server/handler/workspace/handler.go +++ b/pkg/server/handler/workspace/handler.go @@ -1,18 +1,16 @@ package workspace import ( - "errors" + "context" "net/http" "strconv" - "time" "github.com/go-chi/chi/v5" "github.com/go-chi/render" - "github.com/jinzhu/copier" - "gorm.io/gorm" - "kusionstack.io/kusion/pkg/domain/entity" + "github.com/go-logr/logr" "kusionstack.io/kusion/pkg/domain/request" "kusionstack.io/kusion/pkg/server/handler" + workspacemanager "kusionstack.io/kusion/pkg/server/manager/workspace" "kusionstack.io/kusion/pkg/server/util" ) @@ -42,33 +40,7 @@ func (h *Handler) CreateWorkspace() http.HandlerFunc { return } - // Convert request payload to domain model - var createdEntity entity.Workspace - if err := copier.Copy(&createdEntity, &requestPayload); err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - // The default state is UnSynced - createdEntity.CreationTimestamp = time.Now() - createdEntity.UpdateTimestamp = time.Now() - - // Get backend by id - backendEntity, err := h.backendRepo.Get(ctx, requestPayload.BackendID) - if err != nil && err == gorm.ErrRecordNotFound { - render.Render(w, r, handler.FailureResponse(ctx, ErrBackendNotFound)) - return - } else if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - createdEntity.Backend = backendEntity - - // Create workspace with repository - err = h.workspaceRepo.Create(ctx, &createdEntity) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } + createdEntity, err := h.workspaceManager.CreateWorkspace(ctx, requestPayload) handler.HandleResult(w, r, ctx, err, createdEntity) } } @@ -88,22 +60,14 @@ func (h *Handler) CreateWorkspace() http.HandlerFunc { func (h *Handler) DeleteWorkspace() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Getting stuff from context - ctx := r.Context() - logger := util.GetLogger(ctx) - logger.Info("Deleting source...") - workspaceID := chi.URLParam(r, "workspaceID") - - // Delete workspace with repository - id, err := strconv.Atoi(workspaceID) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, ErrInvalidWorkspaceID)) - return - } - err = h.workspaceRepo.Delete(ctx, uint(id)) + ctx, logger, params, err := requestHelper(r) if err != nil { render.Render(w, r, handler.FailureResponse(ctx, err)) return } + logger.Info("Deleting source...", "workspaceID", params.WorkspaceID) + + err = h.workspaceManager.DeleteWorkspaceByID(ctx, params.WorkspaceID) handler.HandleResult(w, r, ctx, err, "Deletion Success") } } @@ -123,17 +87,12 @@ func (h *Handler) DeleteWorkspace() http.HandlerFunc { func (h *Handler) UpdateWorkspace() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Getting stuff from context - ctx := r.Context() - logger := util.GetLogger(ctx) - logger.Info("Updating workspace...") - workspaceID := chi.URLParam(r, "workspaceID") - - // convert workspace ID to int - id, err := strconv.Atoi(workspaceID) + ctx, logger, params, err := requestHelper(r) if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, ErrInvalidWorkspaceID)) + render.Render(w, r, handler.FailureResponse(ctx, err)) return } + logger.Info("Updating workspace...", "workspaceID", params.WorkspaceID) // Decode the request body into the payload. var requestPayload request.UpdateWorkspaceRequest @@ -142,35 +101,7 @@ func (h *Handler) UpdateWorkspace() http.HandlerFunc { return } - // Convert request payload to domain model - var requestEntity entity.Workspace - if err := copier.Copy(&requestEntity, &requestPayload); err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Get the existing workspace by id - updatedEntity, err := h.workspaceRepo.Get(ctx, uint(id)) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - render.Render(w, r, handler.FailureResponse(ctx, ErrUpdatingNonExistingWorkspace)) - return - } - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Overwrite non-zero values in request entity to existing entity - copier.CopyWithOption(updatedEntity, requestEntity, copier.Option{IgnoreEmpty: true}) - - // Update workspace with repository - err = h.workspaceRepo.Update(ctx, updatedEntity) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - - // Return updated workspace + updatedEntity, err := h.workspaceManager.UpdateWorkspaceByID(ctx, params.WorkspaceID, requestPayload) handler.HandleResult(w, r, ctx, err, updatedEntity) } } @@ -189,28 +120,15 @@ func (h *Handler) UpdateWorkspace() http.HandlerFunc { func (h *Handler) GetWorkspace() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Getting stuff from context - ctx := r.Context() - logger := util.GetLogger(ctx) - logger.Info("Getting workspace...") - workspaceID := chi.URLParam(r, "workspaceID") - - // Get workspace with repository - id, err := strconv.Atoi(workspaceID) - if err != nil { - render.Render(w, r, handler.FailureResponse(ctx, ErrInvalidWorkspaceID)) - return - } - existingEntity, err := h.workspaceRepo.Get(ctx, uint(id)) + ctx, logger, params, err := requestHelper(r) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - render.Render(w, r, handler.FailureResponse(ctx, ErrGettingNonExistingWorkspace)) - return - } render.Render(w, r, handler.FailureResponse(ctx, err)) return } + logger.Info("Getting workspace...", "workspaceID", params.WorkspaceID) // Return found workspace + existingEntity, err := h.workspaceManager.GetWorkspaceByID(ctx, params.WorkspaceID) handler.HandleResult(w, r, ctx, err, existingEntity) } } @@ -232,17 +150,23 @@ func (h *Handler) ListWorkspaces() http.HandlerFunc { logger := util.GetLogger(ctx) logger.Info("Listing workspace...") - workspaceEntities, err := h.workspaceRepo.List(ctx) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - render.Render(w, r, handler.FailureResponse(ctx, ErrGettingNonExistingWorkspace)) - return - } - render.Render(w, r, handler.FailureResponse(ctx, err)) - return - } - // Return found workspaces + workspaceEntities, err := h.workspaceManager.ListWorkspaces(ctx) handler.HandleResult(w, r, ctx, err, workspaceEntities) } } + +func requestHelper(r *http.Request) (context.Context, *logr.Logger, *WorkspaceRequestParams, error) { + ctx := r.Context() + workspaceID := chi.URLParam(r, "workspaceID") + // Get stack with repository + id, err := strconv.Atoi(workspaceID) + if err != nil { + return nil, nil, nil, workspacemanager.ErrInvalidWorkspaceID + } + logger := util.GetLogger(ctx) + params := WorkspaceRequestParams{ + WorkspaceID: uint(id), + } + return ctx, &logger, ¶ms, nil +} diff --git a/pkg/server/handler/workspace/handler_test.go b/pkg/server/handler/workspace/handler_test.go new file mode 100644 index 00000000..12a499b9 --- /dev/null +++ b/pkg/server/handler/workspace/handler_test.go @@ -0,0 +1,307 @@ +package workspace + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "kusionstack.io/kusion/pkg/domain/request" + "kusionstack.io/kusion/pkg/infra/persistence" + "kusionstack.io/kusion/pkg/server/handler" + workspacemanager "kusionstack.io/kusion/pkg/server/manager/workspace" +) + +func TestWorkspaceHandler(t *testing.T) { + var ( + wsName = "test-ws" + wsNameUpdated = "test-ws-updated" + ) + t.Run("ListWorkspaces", func(t *testing.T) { + sqlMock, fakeGDB, recorder, workspaceHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "Backend__id"}). + AddRow(1, "test-ws", 1). + AddRow(2, "test-ws-2", 2)) + + // Create a new HTTP request + req, err := http.NewRequest("GET", "/workspaces", nil) + assert.NoError(t, err) + + // Call the ListWorkspaces handler function + workspaceHandler.ListWorkspaces()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, 2, len(resp.Data.([]any))) + }) + + t.Run("GetWorkspace", func(t *testing.T) { + sqlMock, fakeGDB, recorder, workspaceHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "Backend__id"}). + AddRow(1, wsName, 1)) + + // Create a new HTTP request + req, err := http.NewRequest("GET", "/workspace/{workspaceID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("workspaceID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Call the ListWorkspaces handler function + workspaceHandler.GetWorkspace()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshal the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, float64(1), resp.Data.(map[string]any)["id"]) + assert.Equal(t, wsName, resp.Data.(map[string]any)["name"]) + }) + + t.Run("CreateWorkspace", func(t *testing.T) { + sqlMock, fakeGDB, recorder, workspaceHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Create a new HTTP request + req, err := http.NewRequest("POST", "/workspace/{workspaceID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("workspaceID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Set request body + requestPayload := request.CreateWorkspaceRequest{ + Name: wsName, + BackendID: 1, + Owners: []string{"hua.li", "xiaoming.li"}, + } + reqBody, err := json.Marshal(requestPayload) + assert.NoError(t, err) + req.Body = io.NopCloser(bytes.NewReader(reqBody)) + req.Header.Add("Content-Type", "application/json") + + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id"}). + AddRow(1)) + sqlMock.ExpectBegin() + sqlMock.ExpectExec("INSERT"). + WillReturnResult(sqlmock.NewResult(int64(1), int64(1))) + sqlMock.ExpectCommit() + + // Call the CreateWorkspace handler function + workspaceHandler.CreateWorkspace()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshal the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, float64(1), resp.Data.(map[string]any)["id"]) + assert.Equal(t, wsName, resp.Data.(map[string]any)["name"]) + }) + + t.Run("UpdateExistingWorkspace", func(t *testing.T) { + sqlMock, fakeGDB, recorder, workspaceHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Update a new HTTP request + req, err := http.NewRequest("POST", "/workspace/{workspaceID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("workspaceID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Set request body + requestPayload := request.UpdateWorkspaceRequest{ + ID: 1, + Name: wsNameUpdated, + BackendID: 1, + } + reqBody, err := json.Marshal(requestPayload) + assert.NoError(t, err) + req.Body = io.NopCloser(bytes.NewReader(reqBody)) + req.Header.Add("Content-Type", "application/json") + + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "Backend__id"}). + AddRow(1, "test-ws-updated", 1)) + sqlMock.ExpectExec("UPDATE"). + WillReturnResult(sqlmock.NewResult(int64(1), int64(1))) + + // Call the ListWorkspaces handler function + workspaceHandler.UpdateWorkspace()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, float64(1), resp.Data.(map[string]any)["id"]) + assert.Equal(t, wsNameUpdated, resp.Data.(map[string]any)["name"]) + }) + + t.Run("Delete Existing Workspace", func(t *testing.T) { + sqlMock, fakeGDB, recorder, workspaceHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Create a new HTTP request + req, err := http.NewRequest("DELETE", "/workspace/{workspaceID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("workspaceID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Mock the Delete method of the workspace repository + sqlMock.ExpectBegin() + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id"}). + AddRow(1)) + sqlMock.ExpectExec("UPDATE"). + WillReturnResult(sqlmock.NewResult(1, 1)) + sqlMock.ExpectCommit() + + // Call the DeleteWorkspace handler function + workspaceHandler.DeleteWorkspace()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + assert.Equal(t, "Deletion Success", resp.Data) + }) + + t.Run("Delete Nonexisting Workspace", func(t *testing.T) { + sqlMock, fakeGDB, recorder, workspaceHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Create a new HTTP request + req, err := http.NewRequest("DELETE", "/workspace/{workspaceID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("workspaceID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + sqlMock.ExpectBegin() + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id"})) + + // Call the DeleteWorkspace handler function + workspaceHandler.DeleteWorkspace()(recorder, req) + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, false, resp.Success) + assert.Equal(t, workspacemanager.ErrGettingNonExistingWorkspace.Error(), resp.Message) + }) + + t.Run("Update Nonexisting Workspace", func(t *testing.T) { + sqlMock, fakeGDB, recorder, workspaceHandler := setupTest(t) + defer persistence.CloseDB(t, fakeGDB) + defer sqlMock.ExpectClose() + + // Update a new HTTP request + req, err := http.NewRequest("POST", "/workspace/{workspaceID}", nil) + assert.NoError(t, err) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("workspaceID", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + // Set request body + requestPayload := request.UpdateWorkspaceRequest{ + // Set your request payload fields here + ID: 1, + Name: "test-ws-updated", + BackendID: 1, + } + reqBody, err := json.Marshal(requestPayload) + assert.NoError(t, err) + req.Body = io.NopCloser(bytes.NewReader(reqBody)) + req.Header.Add("Content-Type", "application/json") + + sqlMock.ExpectQuery("SELECT"). + WillReturnRows(sqlmock.NewRows([]string{"id"})) + + // Call the UpdateWorkspace handler function + workspaceHandler.UpdateWorkspace()(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Unmarshall the response body + var resp handler.Response + err = json.Unmarshal(recorder.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Assertion + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, false, resp.Success) + assert.Equal(t, workspacemanager.ErrUpdatingNonExistingWorkspace.Error(), resp.Message) + }) +} + +func setupTest(t *testing.T) (sqlmock.Sqlmock, *gorm.DB, *httptest.ResponseRecorder, *Handler) { + fakeGDB, sqlMock, err := persistence.GetMockDB() + require.NoError(t, err) + workspaceRepo := persistence.NewWorkspaceRepository(fakeGDB) + backendRepo := persistence.NewBackendRepository(fakeGDB) + workspaceHandler := &Handler{ + workspaceManager: workspacemanager.NewWorkspaceManager(workspaceRepo, backendRepo), + } + recorder := httptest.NewRecorder() + return sqlMock, fakeGDB, recorder, workspaceHandler +} diff --git a/pkg/server/handler/workspace/types.go b/pkg/server/handler/workspace/types.go index ae848580..01291793 100644 --- a/pkg/server/handler/workspace/types.go +++ b/pkg/server/handler/workspace/types.go @@ -1,29 +1,21 @@ package workspace import ( - "errors" - - "kusionstack.io/kusion/pkg/domain/repository" -) - -var ( - ErrGettingNonExistingWorkspace = errors.New("the workspace does not exist") - ErrUpdatingNonExistingWorkspace = errors.New("the workspace to update does not exist") - ErrInvalidWorkspaceID = errors.New("the workspace ID should be a uuid") - ErrBackendNotFound = errors.New("the specified backend does not exist") + workspacemanager "kusionstack.io/kusion/pkg/server/manager/workspace" ) func NewHandler( - workspaceRepo repository.WorkspaceRepository, - backendRepo repository.BackendRepository, + workspaceManager *workspacemanager.WorkspaceManager, ) (*Handler, error) { return &Handler{ - workspaceRepo: workspaceRepo, - backendRepo: backendRepo, + workspaceManager: workspaceManager, }, nil } type Handler struct { - workspaceRepo repository.WorkspaceRepository - backendRepo repository.BackendRepository + workspaceManager *workspacemanager.WorkspaceManager +} + +type WorkspaceRequestParams struct { + WorkspaceID uint } diff --git a/pkg/server/manager/backend/backend_manager.go b/pkg/server/manager/backend/backend_manager.go new file mode 100644 index 00000000..cd6e6813 --- /dev/null +++ b/pkg/server/manager/backend/backend_manager.go @@ -0,0 +1,86 @@ +package backend + +import ( + "context" + "errors" + + "github.com/jinzhu/copier" + "gorm.io/gorm" + "kusionstack.io/kusion/pkg/domain/entity" + "kusionstack.io/kusion/pkg/domain/request" +) + +func (m *BackendManager) ListBackends(ctx context.Context) ([]*entity.Backend, error) { + backendEntities, err := m.backendRepo.List(ctx) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrGettingNonExistingBackend + } + return nil, err + } + return backendEntities, nil +} + +func (m *BackendManager) GetBackendByID(ctx context.Context, id uint) (*entity.Backend, error) { + existingEntity, err := m.backendRepo.Get(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrGettingNonExistingBackend + } + return nil, err + } + return existingEntity, nil +} + +func (m *BackendManager) DeleteBackendByID(ctx context.Context, id uint) error { + err := m.backendRepo.Delete(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrGettingNonExistingBackend + } + return err + } + return nil +} + +func (m *BackendManager) UpdateBackendByID(ctx context.Context, id uint, requestPayload request.UpdateBackendRequest) (*entity.Backend, error) { + // Convert request payload to domain model + var requestEntity entity.Backend + if err := copier.Copy(&requestEntity, &requestPayload); err != nil { + return nil, err + } + + // Get the existing backend by id + updatedEntity, err := m.backendRepo.Get(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrUpdatingNonExistingBackend + } + return nil, err + } + + // Overwrite non-zero values in request entity to existing entity + copier.CopyWithOption(updatedEntity, requestEntity, copier.Option{IgnoreEmpty: true}) + + // Update backend with repository + err = m.backendRepo.Update(ctx, updatedEntity) + if err != nil { + return nil, err + } + return updatedEntity, nil +} + +func (m *BackendManager) CreateBackend(ctx context.Context, requestPayload request.CreateBackendRequest) (*entity.Backend, error) { + // Convert request payload to domain model + var createdEntity entity.Backend + if err := copier.Copy(&createdEntity, &requestPayload); err != nil { + return nil, err + } + + // Create backend with repository + err := m.backendRepo.Create(ctx, &createdEntity) + if err != nil { + return nil, err + } + return &createdEntity, nil +} diff --git a/pkg/server/manager/backend/types.go b/pkg/server/manager/backend/types.go new file mode 100644 index 00000000..bf137f91 --- /dev/null +++ b/pkg/server/manager/backend/types.go @@ -0,0 +1,23 @@ +package backend + +import ( + "errors" + + "kusionstack.io/kusion/pkg/domain/repository" +) + +var ( + ErrGettingNonExistingBackend = errors.New("the backend does not exist") + ErrUpdatingNonExistingBackend = errors.New("the backend to update does not exist") + ErrInvalidBackendID = errors.New("the backend ID should be a uuid") +) + +type BackendManager struct { + backendRepo repository.BackendRepository +} + +func NewBackendManager(backendRepo repository.BackendRepository) *BackendManager { + return &BackendManager{ + backendRepo: backendRepo, + } +} diff --git a/pkg/server/manager/organization/organization_manager.go b/pkg/server/manager/organization/organization_manager.go new file mode 100644 index 00000000..8e7d62c0 --- /dev/null +++ b/pkg/server/manager/organization/organization_manager.go @@ -0,0 +1,91 @@ +package organization + +import ( + "context" + "errors" + "time" + + "github.com/jinzhu/copier" + "gorm.io/gorm" + "kusionstack.io/kusion/pkg/domain/entity" + "kusionstack.io/kusion/pkg/domain/request" +) + +func (m *OrganizationManager) ListOrganizations(ctx context.Context) ([]*entity.Organization, error) { + organizationEntities, err := m.organizationRepo.List(ctx) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrGettingNonExistingOrganization + } + return nil, err + } + return organizationEntities, nil +} + +func (m *OrganizationManager) GetOrganizationByID(ctx context.Context, id uint) (*entity.Organization, error) { + existingEntity, err := m.organizationRepo.Get(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrGettingNonExistingOrganization + } + return nil, err + } + return existingEntity, nil +} + +func (m *OrganizationManager) DeleteOrganizationByID(ctx context.Context, id uint) error { + err := m.organizationRepo.Delete(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrGettingNonExistingOrganization + } + return err + } + return nil +} + +func (m *OrganizationManager) UpdateOrganizationByID(ctx context.Context, id uint, requestPayload request.UpdateOrganizationRequest) (*entity.Organization, error) { + // Convert request payload to domain model + var requestEntity entity.Organization + if err := copier.Copy(&requestEntity, &requestPayload); err != nil { + return nil, err + } + + // Get the existing organization by id + updatedEntity, err := m.organizationRepo.Get(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrUpdatingNonExistingOrganization + } + return nil, err + } + + // Overwrite non-zero values in request entity to existing entity + copier.CopyWithOption(updatedEntity, requestEntity, copier.Option{IgnoreEmpty: true}) + + // Update organization with repository + err = m.organizationRepo.Update(ctx, updatedEntity) + if err != nil { + return nil, err + } + + return updatedEntity, nil +} + +func (m *OrganizationManager) CreateOrganization(ctx context.Context, requestPayload request.CreateOrganizationRequest) (*entity.Organization, error) { + // Convert request payload to domain model + var createdEntity entity.Organization + if err := copier.Copy(&createdEntity, &requestPayload); err != nil { + return nil, err + } + // The default state is UnSynced + createdEntity.CreationTimestamp = time.Now() + createdEntity.UpdateTimestamp = time.Now() + + // Create organization with repository + err := m.organizationRepo.Create(ctx, &createdEntity) + if err != nil { + return nil, err + } + return &createdEntity, nil +} diff --git a/pkg/server/manager/organization/types.go b/pkg/server/manager/organization/types.go new file mode 100644 index 00000000..bb72fbd5 --- /dev/null +++ b/pkg/server/manager/organization/types.go @@ -0,0 +1,23 @@ +package organization + +import ( + "errors" + + "kusionstack.io/kusion/pkg/domain/repository" +) + +var ( + ErrGettingNonExistingOrganization = errors.New("the organization does not exist") + ErrUpdatingNonExistingOrganization = errors.New("the organization to update does not exist") + ErrInvalidOrganizationID = errors.New("the organization ID should be a uuid") +) + +type OrganizationManager struct { + organizationRepo repository.OrganizationRepository +} + +func NewOrganizationManager(organizationRepo repository.OrganizationRepository) *OrganizationManager { + return &OrganizationManager{ + organizationRepo: organizationRepo, + } +} diff --git a/pkg/server/manager/project/project_manager.go b/pkg/server/manager/project/project_manager.go new file mode 100644 index 00000000..fe55d376 --- /dev/null +++ b/pkg/server/manager/project/project_manager.go @@ -0,0 +1,120 @@ +package project + +import ( + "context" + "errors" + + "github.com/jinzhu/copier" + "gorm.io/gorm" + "kusionstack.io/kusion/pkg/domain/entity" + "kusionstack.io/kusion/pkg/domain/request" + "kusionstack.io/kusion/pkg/server/handler" +) + +func (m *ProjectManager) ListProjects(ctx context.Context) ([]*entity.Project, error) { + projectEntities, err := m.projectRepo.List(ctx) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrGettingNonExistingProject + } + return nil, err + } + return projectEntities, nil +} + +func (m *ProjectManager) GetProjectByID(ctx context.Context, id uint) (*entity.Project, error) { + existingEntity, err := m.projectRepo.Get(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrGettingNonExistingProject + } + return nil, err + } + return existingEntity, nil +} + +func (m *ProjectManager) DeleteProjectByID(ctx context.Context, id uint) error { + err := m.projectRepo.Delete(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrGettingNonExistingProject + } + return err + } + return nil +} + +func (m *ProjectManager) UpdateProjectByID(ctx context.Context, id uint, requestPayload request.UpdateProjectRequest) (*entity.Project, error) { + // Convert request payload to domain model + var requestEntity entity.Project + if err := copier.Copy(&requestEntity, &requestPayload); err != nil { + return nil, err + } + + // Get source by id + sourceEntity, err := handler.GetSourceByID(ctx, m.sourceRepo, requestPayload.SourceID) + if err != nil { + return nil, err + } + requestEntity.Source = sourceEntity + + // Get organization by id + organizationEntity, err := handler.GetOrganizationByID(ctx, m.organizationRepo, requestPayload.OrganizationID) + if err != nil { + return nil, err + } + requestEntity.Organization = organizationEntity + + // Get the existing project by id + updatedEntity, err := m.projectRepo.Get(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrUpdatingNonExistingProject + } + return nil, err + } + + // Overwrite non-zero values in request entity to existing entity + copier.CopyWithOption(updatedEntity, requestEntity, copier.Option{IgnoreEmpty: true}) + // fmt.Printf("updatedEntity.Source: %v; updatedEntity.Organization: %v", updatedEntity.Source, updatedEntity.Organization) + + // Update project with repository + err = m.projectRepo.Update(ctx, updatedEntity) + if err != nil { + return nil, err + } + return updatedEntity, nil +} + +func (m *ProjectManager) CreateProject(ctx context.Context, requestPayload request.CreateProjectRequest) (*entity.Project, error) { + // Convert request payload to domain model + var createdEntity entity.Project + if err := copier.Copy(&createdEntity, &requestPayload); err != nil { + return nil, err + } + + // Get source by id + sourceEntity, err := m.sourceRepo.Get(ctx, requestPayload.SourceID) + if err != nil && err == gorm.ErrRecordNotFound { + return nil, ErrSourceNotFound + } else if err != nil { + return nil, err + } + createdEntity.Source = sourceEntity + + // Get org by id + organizationEntity, err := m.organizationRepo.Get(ctx, requestPayload.OrganizationID) + if err != nil && err == gorm.ErrRecordNotFound { + return nil, ErrOrgNotFound + } else if err != nil { + return nil, err + } + createdEntity.Organization = organizationEntity + + // Create project with repository + err = m.projectRepo.Create(ctx, &createdEntity) + if err != nil { + return nil, err + } + return &createdEntity, nil +} diff --git a/pkg/server/manager/project/types.go b/pkg/server/manager/project/types.go new file mode 100644 index 00000000..da18ceed --- /dev/null +++ b/pkg/server/manager/project/types.go @@ -0,0 +1,29 @@ +package project + +import ( + "errors" + + "kusionstack.io/kusion/pkg/domain/repository" +) + +var ( + ErrGettingNonExistingProject = errors.New("the project does not exist") + ErrUpdatingNonExistingProject = errors.New("the project to update does not exist") + ErrSourceNotFound = errors.New("the specified source does not exist") + ErrOrgNotFound = errors.New("the specified org does not exist") + ErrInvalidProjectID = errors.New("the project ID should be a uuid") +) + +type ProjectManager struct { + projectRepo repository.ProjectRepository + organizationRepo repository.OrganizationRepository + sourceRepo repository.SourceRepository +} + +func NewProjectManager(projectRepo repository.ProjectRepository, organizationRepo repository.OrganizationRepository, sourceRepo repository.SourceRepository) *ProjectManager { + return &ProjectManager{ + projectRepo: projectRepo, + organizationRepo: organizationRepo, + sourceRepo: sourceRepo, + } +} diff --git a/pkg/server/manager/source/source_manager.go b/pkg/server/manager/source/source_manager.go new file mode 100644 index 00000000..d38235ae --- /dev/null +++ b/pkg/server/manager/source/source_manager.go @@ -0,0 +1,101 @@ +package source + +import ( + "context" + "errors" + "net/url" + + "github.com/jinzhu/copier" + "gorm.io/gorm" + "kusionstack.io/kusion/pkg/domain/entity" + "kusionstack.io/kusion/pkg/domain/request" +) + +func (m *SourceManager) ListSources(ctx context.Context) ([]*entity.Source, error) { + sourceEntities, err := m.sourceRepo.List(ctx) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrGettingNonExistingSource + } + return nil, err + } + return sourceEntities, nil +} + +func (m *SourceManager) GetSourceByID(ctx context.Context, id uint) (*entity.Source, error) { + existingEntity, err := m.sourceRepo.Get(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrGettingNonExistingSource + } + return nil, err + } + return existingEntity, nil +} + +func (m *SourceManager) DeleteSourceByID(ctx context.Context, id uint) error { + err := m.sourceRepo.Delete(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrGettingNonExistingSource + } + return err + } + return nil +} + +func (m *SourceManager) UpdateSourceByID(ctx context.Context, id uint, requestPayload request.UpdateSourceRequest) (*entity.Source, error) { + // Convert request payload to domain model + var requestEntity entity.Source + if err := copier.Copy(&requestEntity, &requestPayload); err != nil { + return nil, err + } + + // Convert Remote string to URL + remote, err := url.Parse(requestPayload.Remote) + if err != nil { + return nil, err + } + requestEntity.Remote = remote + + // Get the existing source by id + updatedEntity, err := m.sourceRepo.Get(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrUpdatingNonExistingSource + } + return nil, err + } + + // Overwrite non-zero values in request entity to existing entity + copier.CopyWithOption(updatedEntity, requestEntity, copier.Option{IgnoreEmpty: true}) + + // Update source with repository + err = m.sourceRepo.Update(ctx, updatedEntity) + if err != nil { + return nil, err + } + return updatedEntity, nil +} + +func (m *SourceManager) CreateSource(ctx context.Context, requestPayload request.CreateSourceRequest) (*entity.Source, error) { + // Convert request payload to domain model + var createdEntity entity.Source + if err := copier.Copy(&createdEntity, &requestPayload); err != nil { + return nil, err + } + + // Convert Remote string to URL + remote, err := url.Parse(requestPayload.Remote) + if err != nil { + return nil, err + } + createdEntity.Remote = remote + + // Create source with repository + err = m.sourceRepo.Create(ctx, &createdEntity) + if err != nil { + return nil, err + } + return &createdEntity, nil +} diff --git a/pkg/server/manager/source/types.go b/pkg/server/manager/source/types.go new file mode 100644 index 00000000..a8e5741d --- /dev/null +++ b/pkg/server/manager/source/types.go @@ -0,0 +1,23 @@ +package source + +import ( + "errors" + + "kusionstack.io/kusion/pkg/domain/repository" +) + +var ( + ErrGettingNonExistingSource = errors.New("the source does not exist") + ErrUpdatingNonExistingSource = errors.New("the source to update does not exist") + ErrInvalidSourceID = errors.New("the source ID should be a uuid") +) + +type SourceManager struct { + sourceRepo repository.SourceRepository +} + +func NewSourceManager(sourceRepo repository.SourceRepository) *SourceManager { + return &SourceManager{ + sourceRepo: sourceRepo, + } +} diff --git a/pkg/server/manager/stack/stack_manager.go b/pkg/server/manager/stack/stack_manager.go new file mode 100644 index 00000000..71396b44 --- /dev/null +++ b/pkg/server/manager/stack/stack_manager.go @@ -0,0 +1,335 @@ +package stack + +import ( + "context" + "errors" + "net/http" + "os" + "time" + + "github.com/jinzhu/copier" + "gorm.io/gorm" + v1 "kusionstack.io/kusion/pkg/apis/api.kusion.io/v1" + "kusionstack.io/kusion/pkg/backend" + "kusionstack.io/kusion/pkg/domain/constant" + "kusionstack.io/kusion/pkg/domain/entity" + "kusionstack.io/kusion/pkg/domain/repository" + "kusionstack.io/kusion/pkg/domain/request" + + engineapi "kusionstack.io/kusion/pkg/engine/api" + "kusionstack.io/kusion/pkg/engine/operation/models" + + sourceapi "kusionstack.io/kusion/pkg/engine/api/source" + "kusionstack.io/kusion/pkg/server/handler" + "kusionstack.io/kusion/pkg/server/util" +) + +func NewStackManager(stackRepo repository.StackRepository, projectRepo repository.ProjectRepository, workspaceRepo repository.WorkspaceRepository) *StackManager { + return &StackManager{ + stackRepo: stackRepo, + projectRepo: projectRepo, + workspaceRepo: workspaceRepo, + } +} + +func (m *StackManager) GenerateStack(ctx context.Context, id uint, workspaceName string) (*v1.Spec, error) { + logger := util.GetLogger(ctx) + logger.Info("Starting generating spec in StackManager ...") + + // Generate a stack + stackEntity, err := m.stackRepo.Get(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrGettingNonExistingStack + } + return nil, err + } + + // Get project by id + project, err := stackEntity.Project.ConvertToCore() + if err != nil { + return nil, err + } + + // Get stack by id + stack, err := stackEntity.ConvertToCore() + if err != nil { + return nil, err + } + + // get workspace configurations + bk, err := backend.NewBackend("") + if err != nil { + return nil, err + } + wsStorage, err := bk.WorkspaceStorage() + if err != nil { + return nil, err + } + ws, err := wsStorage.Get(workspaceName) + if err != nil { + return nil, err + } + + // Build API inputs + // get project to get source and workdir + projectEntity, err := handler.GetProjectByID(ctx, m.projectRepo, stackEntity.Project.ID) + if err != nil { + return nil, err + } + + directory, workDir, err := GetWorkDirFromSource(ctx, stackEntity, projectEntity) + logger.Info("workDir derived", "workDir", workDir) + logger.Info("directory derived", "directory", directory) + + stack.Path = workDir + if err != nil { + return nil, err + } + // intentOptions, _ := buildOptions(workDir, kpmParam, false) + // Cleanup + defer sourceapi.Cleanup(ctx, directory) + + // Generate spec + return engineapi.GenerateSpecWithSpinner(project, stack, ws, true) +} + +func (m *StackManager) PreviewStack(ctx context.Context, id uint, workspaceName string) (*models.Changes, error) { + logger := util.GetLogger(ctx) + logger.Info("Starting previewing stack in StackManager ...") + _, changes, _, err := m.previewHelper(ctx, id, workspaceName) + return changes, err +} + +func (m *StackManager) ApplyStack(ctx context.Context, id uint, workspaceName, format string, detail, dryrun bool, w http.ResponseWriter) error { + logger := util.GetLogger(ctx) + logger.Info("Starting applying stack in StackManager ...") + + // Get the stack entity by id + stackEntity, err := m.stackRepo.Get(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrGettingNonExistingStack + } + return err + } + + // Preview a stack + sp, changes, stateStorage, err := m.previewHelper(ctx, id, workspaceName) + if err != nil { + return err + } + + _, err = ProcessChanges(ctx, w, changes, format, detail) + if err != nil { + return err + } + + // if dry run, print the hint + if dryrun { + logger.Info("NOTE: Currently running in the --dry-run mode, the above configuration does not really take effect") + return ErrDryrunDestroy + } + + logger.Info("Dryrun set to false. Start applying diffs ...") + executeOptions := BuildOptions(dryrun) + if err = engineapi.Apply(executeOptions, stateStorage, sp, changes, os.Stdout); err != nil { + return err + } + + // Update LastSyncTimestamp to current time and set stack syncState to synced + stackEntity.LastSyncTimestamp = time.Now() + stackEntity.SyncState = constant.StackStateSynced + + // Update stack with repository + err = m.stackRepo.Update(ctx, stackEntity) + if err != nil { + return err + } + + return nil +} + +func (m *StackManager) DestroyStack(ctx context.Context, id uint, workspaceName string, detail, dryrun bool, w http.ResponseWriter) error { + logger := util.GetLogger(ctx) + logger.Info("Starting applying stack in StackManager ...") + + // Get the stack entity by id + stackEntity, err := m.stackRepo.Get(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrGettingNonExistingStack + } + return err + } + + // Get project by id + project, err := stackEntity.Project.ConvertToCore() + if err != nil { + return err + } + + // Get stack by id + stack, err := stackEntity.ConvertToCore() + if err != nil { + return err + } + + stateBackend, err := m.getBackendFromWorkspaceName(ctx, workspaceName) + if err != nil { + return err + } + + // Build API inputs + // get project to get source and workdir + projectEntity, err := handler.GetProjectByID(ctx, m.projectRepo, stackEntity.Project.ID) + if err != nil { + return err + } + + directory, workDir, err := GetWorkDirFromSource(ctx, stackEntity, projectEntity) + if err != nil { + return err + } + destroyOptions := BuildOptions(dryrun) + stack.Path = workDir + + // Cleanup + defer sourceapi.Cleanup(ctx, directory) + + // Compute state storage + stateStorage := stateBackend.StateStorage(project.Name, stack.Name, workspaceName) + logger.Info("Remote state storage found", "Remote", stateStorage) + + priorState, err := stateStorage.Get() + if err != nil || priorState == nil { + logger.Info("can't find state", "project", project.Name, "stack", stack.Name, "workspace", workspaceName) + return ErrGettingNonExistingStateForStack + } + destroyResources := priorState.Resources + + if destroyResources == nil || len(priorState.Resources) == 0 { + return ErrNoManagedResourceToDestroy + } + + // compute changes for preview + i := &v1.Spec{Resources: destroyResources} + changes, err := engineapi.DestroyPreview(destroyOptions, i, project, stack, stateStorage) + if err != nil { + return err + } + + // Summary preview table + changes.Summary(w, true) + // detail detection + if detail { + changes.OutputDiff("all") + } + + // if dryrun, print the hint + if dryrun { + logger.Info("Dry-run mode enabled, the above resources will be destroyed if dryrun is set to false") + return ErrDryrunDestroy + } + + // Destroy + logger.Info("Start destroying resources......") + if err = engineapi.Destroy(destroyOptions, i, changes, stateStorage); err != nil { + return err + } + return nil +} + +func (m *StackManager) ListStacks(ctx context.Context) ([]*entity.Stack, error) { + stackEntities, err := m.stackRepo.List(ctx) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrGettingNonExistingStack + } + return nil, err + } + return stackEntities, nil +} + +func (m *StackManager) GetStackByID(ctx context.Context, id uint) (*entity.Stack, error) { + existingEntity, err := m.stackRepo.Get(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrGettingNonExistingStack + } + return nil, err + } + return existingEntity, nil +} + +func (m *StackManager) DeleteStackByID(ctx context.Context, id uint) error { + err := m.stackRepo.Delete(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrGettingNonExistingStack + } + return err + } + return nil +} + +func (m *StackManager) UpdateStackByID(ctx context.Context, id uint, requestPayload request.UpdateStackRequest) (*entity.Stack, error) { + // Convert request payload to domain model + var requestEntity entity.Stack + if err := copier.Copy(&requestEntity, &requestPayload); err != nil { + return nil, err + } + + // Get project by id + projectEntity, err := handler.GetProjectByID(ctx, m.projectRepo, requestPayload.ProjectID) + if err != nil { + return nil, err + } + requestEntity.Project = projectEntity + + // Get the existing stack by id + updatedEntity, err := m.stackRepo.Get(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrUpdatingNonExistingStack + } + return nil, err + } + + // Overwrite non-zero values in request entity to existing entity + copier.CopyWithOption(updatedEntity, requestEntity, copier.Option{IgnoreEmpty: true}) + + // Update stack with repository + err = m.stackRepo.Update(ctx, updatedEntity) + if err != nil { + return nil, err + } + return updatedEntity, nil +} + +func (m *StackManager) CreateStack(ctx context.Context, requestPayload request.CreateStackRequest) (*entity.Stack, error) { + // Convert request payload to domain model + var createdEntity entity.Stack + if err := copier.Copy(&createdEntity, &requestPayload); err != nil { + return nil, err + } + // The default state is UnSynced + createdEntity.SyncState = constant.StackStateUnSynced + createdEntity.CreationTimestamp = time.Now() + createdEntity.UpdateTimestamp = time.Now() + createdEntity.LastSyncTimestamp = time.Unix(0, 0) // default to none + + // Get project by id + projectEntity, err := handler.GetProjectByID(ctx, m.projectRepo, requestPayload.ProjectID) + if err != nil { + return nil, err + } + createdEntity.Project = projectEntity + + // Create stack with repository + err = m.stackRepo.Create(ctx, &createdEntity) + if err != nil { + return nil, err + } + return &createdEntity, nil +} diff --git a/pkg/server/manager/stack/types.go b/pkg/server/manager/stack/types.go new file mode 100644 index 00000000..114f31b4 --- /dev/null +++ b/pkg/server/manager/stack/types.go @@ -0,0 +1,30 @@ +package stack + +import ( + "errors" + + "kusionstack.io/kusion/pkg/domain/repository" +) + +const ( + Stdout = "stdout" + NoDiffFound = "All resources are reconciled. No diff found" +) + +var ( + ErrGettingNonExistingStack = errors.New("the stack does not exist") + ErrUpdatingNonExistingStack = errors.New("the stack to update does not exist") + ErrSourceNotFound = errors.New("the specified source does not exist") + ErrWorkspaceNotFound = errors.New("the specified workspace does not exist") + ErrProjectNotFound = errors.New("the specified project does not exist") + ErrInvalidStackID = errors.New("the stack ID should be a uuid") + ErrGettingNonExistingStateForStack = errors.New("can not find State in this stack") + ErrNoManagedResourceToDestroy = errors.New("no managed resources to destroy") + ErrDryrunDestroy = errors.New("dryrun-mode is enabled, no resources will be destroyed") +) + +type StackManager struct { + stackRepo repository.StackRepository + projectRepo repository.ProjectRepository + workspaceRepo repository.WorkspaceRepository +} diff --git a/pkg/server/handler/stack/util.go b/pkg/server/manager/stack/util.go similarity index 51% rename from pkg/server/handler/stack/util.go rename to pkg/server/manager/stack/util.go index 26fc5027..71b6d762 100644 --- a/pkg/server/handler/stack/util.go +++ b/pkg/server/manager/stack/util.go @@ -2,10 +2,14 @@ package stack import ( "context" + "encoding/json" + "errors" "fmt" + "net/http" "path/filepath" "gorm.io/gorm" + apiv1 "kusionstack.io/kusion/pkg/apis/api.kusion.io/v1" v1 "kusionstack.io/kusion/pkg/apis/internal.kusion.io/v1" "kusionstack.io/kusion/pkg/backend" "kusionstack.io/kusion/pkg/backend/storages" @@ -13,35 +17,26 @@ import ( "kusionstack.io/kusion/pkg/domain/entity" engineapi "kusionstack.io/kusion/pkg/engine/api" sourceapi "kusionstack.io/kusion/pkg/engine/api/source" + "kusionstack.io/kusion/pkg/engine/operation/models" + "kusionstack.io/kusion/pkg/engine/state" + "kusionstack.io/kusion/pkg/server/handler" "kusionstack.io/kusion/pkg/server/util" ) -func buildOptions(dryrun bool) *engineapi.APIOptions { - // Construct intent options - // intentOptions := &buildersapi.Options{ - // IsKclPkg: kpmParam, - // WorkDir: workDir, - // Arguments: map[string]string{}, - // NoStyle: true, - // } - // Construct preview api option - // TODO: Complete preview options - // TODO: Operator should be derived from auth info - // TODO: Cluster should be derived from workspace config - previewOptions := &engineapi.APIOptions{ +func BuildOptions(dryrun bool) *engineapi.APIOptions { + executeOptions := &engineapi.APIOptions{ // Operator: "operator", // Cluster: "cluster", // IgnoreFields: []string{}, DryRun: dryrun, } - // return intentOptions, previewOptions - return previewOptions + return executeOptions } // getWorkDirFromSource returns the workdir based on the source // if the source type is local, it will return the path as an absolute path on the local filesystem // if the source type is remote (git for example), it will pull the source and return the path to the pulled source -func getWorkDirFromSource(ctx context.Context, stack *entity.Stack, project *entity.Project) (string, string, error) { +func GetWorkDirFromSource(ctx context.Context, stack *entity.Stack, project *entity.Project) (string, string, error) { logger := util.GetLogger(ctx) logger.Info("Getting workdir from stack source...") // TODO: Also copy the local workdir to /tmp directory? @@ -63,30 +58,6 @@ func getWorkDirFromSource(ctx context.Context, stack *entity.Stack, project *ent } func NewBackendFromEntity(backendEntity entity.Backend) (backend.Backend, error) { - // var emptyCfg bool - // cfg, err := config.GetConfig() - // if errors.Is(err, config.ErrEmptyConfig) { - // emptyCfg = true - // } else if err != nil { - // return nil, err - // } else if cfg.Backends == nil { - // emptyCfg = true - // } - - // var bkCfg *v1.BackendConfig - // if name == "" && (emptyCfg || cfg.Backends.Current == "") { - // // if empty backends config or empty current backend, use default local storage - // bkCfg = &v1.BackendConfig{Type: v1.BackendTypeLocal} - // } else { - // if name == "" { - // name = cfg.Backends.Current - // } - // bkCfg = cfg.Backends.Backends[name] - // if bkCfg == nil { - // return nil, fmt.Errorf("config of backend %s does not exist", name) - // } - // } - // TODO: refactor this so backend.NewBackend() share the same common logic var storage backend.Backend var err error @@ -133,11 +104,38 @@ func NewBackendFromEntity(backendEntity entity.Backend) (backend.Backend, error) return storage, nil } -func (h *Handler) GetBackendFromWorkspaceName(ctx context.Context, workspaceName string) (backend.Backend, error) { +func ProcessChanges(ctx context.Context, w http.ResponseWriter, changes *models.Changes, format string, detail bool) (string, error) { + logger := util.GetLogger(ctx) + logger.Info("Starting previewing stack in StackManager ...") + + if format == engineapi.JSONOutput { + previewChanges, err := json.Marshal(changes) + if err != nil { + return "", err + } + logger.Info(string(previewChanges)) + return string(previewChanges), nil + } + + if changes.AllUnChange() { + logger.Info(NoDiffFound) + return NoDiffFound, nil + } + + // Summary preview table + changes.Summary(w, true) + // detail detection + if detail { + return changes.Diffs(true), nil + } + return "", nil +} + +func (m *StackManager) getBackendFromWorkspaceName(ctx context.Context, workspaceName string) (backend.Backend, error) { logger := util.GetLogger(ctx) logger.Info("Getting backend based on workspace name...") // Get backend by id - workspaceEntity, err := h.workspaceRepo.GetByName(ctx, workspaceName) + workspaceEntity, err := m.workspaceRepo.GetByName(ctx, workspaceName) if err != nil && err == gorm.ErrRecordNotFound { return nil, err } else if err != nil { @@ -150,3 +148,88 @@ func (h *Handler) GetBackendFromWorkspaceName(ctx context.Context, workspaceName } return remoteBackend, nil } + +func (m *StackManager) previewHelper(ctx context.Context, id uint, workspaceName string) (*apiv1.Spec, *models.Changes, state.Storage, error) { + logger := util.GetLogger(ctx) + logger.Info("Starting previewing stack in StackManager ...") + + // Get the stack entity by id + stackEntity, err := m.stackRepo.Get(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil, nil, ErrGettingNonExistingStack + } + return nil, nil, nil, err + } + + // Get project by id + project, err := stackEntity.Project.ConvertToCore() + if err != nil { + return nil, nil, nil, err + } + + // Get stack by id + stack, err := stackEntity.ConvertToCore() + if err != nil { + return nil, nil, nil, err + } + + // Temp: LocalBackend for ws config and Remote Backend for state storage + // TODO: should use variable set eventually and remove this localBackend eventually + wsBackend, err := backend.NewBackend("") + if err != nil { + return nil, nil, nil, err + } + stateBackend, err := m.getBackendFromWorkspaceName(ctx, workspaceName) + if err != nil { + return nil, nil, nil, err + } + + // Get workspace configurations from backend + // TODO: temporarily local for now, should be replaced by variable sets + wsStorage, err := wsBackend.WorkspaceStorage() + if err != nil { + return nil, nil, nil, err + } + ws, err := wsStorage.Get(workspaceName) + if err != nil { + return nil, nil, nil, err + } + + // Build API inputs + // get project to get source and workdir + projectEntity, err := handler.GetProjectByID(ctx, m.projectRepo, stackEntity.Project.ID) + if err != nil { + return nil, nil, nil, err + } + + directory, workDir, err := GetWorkDirFromSource(ctx, stackEntity, projectEntity) + if err != nil { + return nil, nil, nil, err + } + executeOptions := BuildOptions(false) + stack.Path = workDir + + // Cleanup + defer sourceapi.Cleanup(ctx, directory) + + // Generate spec + sp, err := engineapi.GenerateSpecWithSpinner(project, stack, ws, true) + if err != nil { + return nil, nil, nil, err + } + + // return immediately if no resource found in stack + // todo: if there is no resource, should still do diff job; for now, if output is json format, there is no hint + if sp == nil || len(sp.Resources) == 0 { + logger.Info("No resource change found in this stack...") + return nil, nil, nil, nil + } + + // Compute state storage + stateStorage := stateBackend.StateStorage(project.Name, stack.Name, ws.Name) + logger.Info("Local state storage found", "Path", stateStorage) + + changes, err := engineapi.Preview(executeOptions, stateStorage, sp, project, stack) + return sp, changes, stateStorage, err +} diff --git a/pkg/server/manager/workspace/types.go b/pkg/server/manager/workspace/types.go new file mode 100644 index 00000000..7c78d256 --- /dev/null +++ b/pkg/server/manager/workspace/types.go @@ -0,0 +1,26 @@ +package workspace + +import ( + "errors" + + "kusionstack.io/kusion/pkg/domain/repository" +) + +var ( + ErrGettingNonExistingWorkspace = errors.New("the workspace does not exist") + ErrUpdatingNonExistingWorkspace = errors.New("the workspace to update does not exist") + ErrInvalidWorkspaceID = errors.New("the workspace ID should be a uuid") + ErrBackendNotFound = errors.New("the specified backend does not exist") +) + +type WorkspaceManager struct { + workspaceRepo repository.WorkspaceRepository + backendRepo repository.BackendRepository +} + +func NewWorkspaceManager(workspaceRepo repository.WorkspaceRepository, backendRepo repository.BackendRepository) *WorkspaceManager { + return &WorkspaceManager{ + workspaceRepo: workspaceRepo, + backendRepo: backendRepo, + } +} diff --git a/pkg/server/manager/workspace/workspace_manager.go b/pkg/server/manager/workspace/workspace_manager.go new file mode 100644 index 00000000..1adafa28 --- /dev/null +++ b/pkg/server/manager/workspace/workspace_manager.go @@ -0,0 +1,95 @@ +package workspace + +import ( + "context" + "errors" + + "github.com/jinzhu/copier" + "gorm.io/gorm" + "kusionstack.io/kusion/pkg/domain/entity" + "kusionstack.io/kusion/pkg/domain/request" +) + +func (m *WorkspaceManager) ListWorkspaces(ctx context.Context) ([]*entity.Workspace, error) { + workspaceEntities, err := m.workspaceRepo.List(ctx) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrGettingNonExistingWorkspace + } + return nil, err + } + return workspaceEntities, nil +} + +func (m *WorkspaceManager) GetWorkspaceByID(ctx context.Context, id uint) (*entity.Workspace, error) { + existingEntity, err := m.workspaceRepo.Get(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrGettingNonExistingWorkspace + } + return nil, err + } + return existingEntity, nil +} + +func (m *WorkspaceManager) DeleteWorkspaceByID(ctx context.Context, id uint) error { + err := m.workspaceRepo.Delete(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrGettingNonExistingWorkspace + } + return err + } + return nil +} + +func (m *WorkspaceManager) UpdateWorkspaceByID(ctx context.Context, id uint, requestPayload request.UpdateWorkspaceRequest) (*entity.Workspace, error) { + // Convert request payload to domain model + var requestEntity entity.Workspace + if err := copier.Copy(&requestEntity, &requestPayload); err != nil { + return nil, err + } + + // Get the existing workspace by id + updatedEntity, err := m.workspaceRepo.Get(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrUpdatingNonExistingWorkspace + } + return nil, err + } + + // Overwrite non-zero values in request entity to existing entity + copier.CopyWithOption(updatedEntity, requestEntity, copier.Option{IgnoreEmpty: true}) + + // Update workspace with repository + err = m.workspaceRepo.Update(ctx, updatedEntity) + if err != nil { + return nil, err + } + return updatedEntity, nil +} + +func (m *WorkspaceManager) CreateWorkspace(ctx context.Context, requestPayload request.CreateWorkspaceRequest) (*entity.Workspace, error) { + // Convert request payload to domain model + var createdEntity entity.Workspace + if err := copier.Copy(&createdEntity, &requestPayload); err != nil { + return nil, err + } + + // Get backend by id + backendEntity, err := m.backendRepo.Get(ctx, requestPayload.BackendID) + if err != nil && err == gorm.ErrRecordNotFound { + return nil, ErrBackendNotFound + } else if err != nil { + return nil, err + } + createdEntity.Backend = backendEntity + + // Create workspace with repository + err = m.workspaceRepo.Create(ctx, &createdEntity) + if err != nil { + return nil, err + } + return &createdEntity, nil +} diff --git a/pkg/server/route/route.go b/pkg/server/route/route.go index 71594efc..b5703443 100644 --- a/pkg/server/route/route.go +++ b/pkg/server/route/route.go @@ -19,6 +19,12 @@ import ( "kusionstack.io/kusion/pkg/server/handler/source" "kusionstack.io/kusion/pkg/server/handler/stack" "kusionstack.io/kusion/pkg/server/handler/workspace" + backendmanager "kusionstack.io/kusion/pkg/server/manager/backend" + organizationmanager "kusionstack.io/kusion/pkg/server/manager/organization" + projectmanager "kusionstack.io/kusion/pkg/server/manager/project" + sourcemanager "kusionstack.io/kusion/pkg/server/manager/source" + stackmanager "kusionstack.io/kusion/pkg/server/manager/stack" + workspacemanager "kusionstack.io/kusion/pkg/server/manager/workspace" appmiddleware "kusionstack.io/kusion/pkg/server/middleware" "kusionstack.io/kusion/pkg/server/util" @@ -87,33 +93,40 @@ func setupRestAPIV1( workspaceRepo := persistence.NewWorkspaceRepository(config.DB) backendRepo := persistence.NewBackendRepository(config.DB) + stackManager := stackmanager.NewStackManager(stackRepo, projectRepo, workspaceRepo) + sourceManager := sourcemanager.NewSourceManager(sourceRepo) + organizationManager := organizationmanager.NewOrganizationManager(organizationRepo) + backendManager := backendmanager.NewBackendManager(backendRepo) + workspaceManager := workspacemanager.NewWorkspaceManager(workspaceRepo, backendRepo) + projectManager := projectmanager.NewProjectManager(projectRepo, organizationRepo, sourceRepo) + // Set up the handlers for the resources. - sourceHandler, err := source.NewHandler(sourceRepo) + sourceHandler, err := source.NewHandler(sourceManager) if err != nil { logger.Error(err, "Error creating source handler...", "error", err) return } - orgHandler, err := organization.NewHandler(organizationRepo) + orgHandler, err := organization.NewHandler(organizationManager) if err != nil { logger.Error(err, "Error creating org handler...", "error", err) return } - projectHandler, err := project.NewHandler(organizationRepo, projectRepo, sourceRepo) + projectHandler, err := project.NewHandler(projectManager) if err != nil { logger.Error(err, "Error creating project handler...", "error", err) return } - stackHandler, err := stack.NewHandler(organizationRepo, projectRepo, stackRepo, sourceRepo, workspaceRepo) + stackHandler, err := stack.NewHandler(stackManager) if err != nil { logger.Error(err, "Error creating stack handler...", "error", err) return } - workspaceHandler, err := workspace.NewHandler(workspaceRepo, backendRepo) + workspaceHandler, err := workspace.NewHandler(workspaceManager) if err != nil { logger.Error(err, "Error creating workspace handler...", "error", err) return } - backendHandler, err := backend.NewHandler(backendRepo) + backendHandler, err := backend.NewHandler(backendManager) if err != nil { logger.Error(err, "Error creating backend handler...", "error", err) return @@ -132,7 +145,7 @@ func setupRestAPIV1( r.Route("/stack", func(r chi.Router) { r.Route("/{stackID}", func(r chi.Router) { r.Post("/", stackHandler.CreateStack()) - r.Post("/build", stackHandler.BuildStack()) + r.Post("/generate", stackHandler.GenerateStack()) r.Post("/preview", stackHandler.PreviewStack()) r.Post("/apply", stackHandler.ApplyStack()) r.Post("/destroy", stackHandler.DestroyStack()) @@ -178,24 +191,4 @@ func setupRestAPIV1( }) r.Get("/", backendHandler.ListBackends()) }) - // r.Route("/project", func(r chi.Router) { - // //r.Get("/", projectHandler.ListProjects()) - // r.Route("/{projectName}", func(r chi.Router) { - // // r.Post("/", projectHandler.CreateProject()) - // // r.Get("/", projectHandler.GetProject()) - // // r.Put("/", projectHandler.UpdateProject()) - // // r.Delete("/", projectHandler.DeleteProject()) - // r.Route("/stack", func(r chi.Router) { - // //r.Get("/", stackHandler.ListStacks()) - // r.Route("/{stackName}", func(r chi.Router) { - // r.Post("/", stackHandler.CreateStack()) - // // r.Get("/", stackHandler.GetStack()) - // // r.Put("/", stackHandler.UpdateStack()) - // // r.Delete("/", stackHandler.DeleteStack()) - // r.Post("/preview", stack.ExecutePreview()) - // //r.Post("/apply", stack.ExecuteApply()) - // }) - // }) - // }) - // }) } diff --git a/pkg/server/route/route_test.go b/pkg/server/route/route_test.go new file mode 100644 index 00000000..d7764e20 --- /dev/null +++ b/pkg/server/route/route_test.go @@ -0,0 +1,59 @@ +package route + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "kusionstack.io/kusion/pkg/infra/persistence" + "kusionstack.io/kusion/pkg/server" +) + +// TestNewCoreRoute will test the NewCoreRoute function with different +// configurations. +func TestNewCoreRoute(t *testing.T) { + // Mock the NewSearchStorage function to return a mock storage instead of + // actual implementation. + + fakeGDB, _, err := persistence.GetMockDB() + require.NoError(t, err) + tests := []struct { + name string + config server.Config + expectError bool + expectRoutes []string + }{ + { + name: "route test", + config: server.Config{ + DB: fakeGDB, + }, + expectError: false, + expectRoutes: []string{ + "/endpoints", + "/server-configs", + "/api/v1/stack", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + router, err := NewCoreRoute(&tt.config) + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + for _, route := range tt.expectRoutes { + req := httptest.NewRequest(http.MethodGet, route, nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + // Assert status code is not 404 to ensure the route exists. + require.NotEqual(t, http.StatusNotFound, rr.Code, "Route should exist: %s", route) + } + } + }) + } +}