diff --git a/clients/ui/bff/Makefile b/clients/ui/bff/Makefile index 36da91c58..3049355f6 100644 --- a/clients/ui/bff/Makefile +++ b/clients/ui/bff/Makefile @@ -2,6 +2,7 @@ CONTAINER_TOOL ?= docker IMG ?= model-registry-bff:latest PORT ?= 4000 MOCK_K8S_CLIENT ?= false +MOCK_MR_CLIENT ?= false .PHONY: all all: build @@ -32,7 +33,7 @@ build: fmt vet test .PHONY: run run: fmt vet - go run ./cmd/main.go --port=$(PORT) --mock-k8s-client=$(MOCK_K8S_CLIENT) + go run ./cmd/main.go --port=$(PORT) --mock-k8s-client=$(MOCK_K8S_CLIENT) --mock-mr-client=$(MOCK_MR_CLIENT) .PHONY: docker-build docker-build: diff --git a/clients/ui/bff/README.md b/clients/ui/bff/README.md index 8333af39a..5a566e4ac 100644 --- a/clients/ui/bff/README.md +++ b/clients/ui/bff/README.md @@ -31,9 +31,9 @@ After building it, you can run our app with: ```shell make run ``` -If you want to use a different port or mock kubernetes client, useful for front-end development, you can run: +If you want to use a different port, mock kubernetes client or model registry client - useful for front-end development, you can run: ```shell -make run PORT=8000 MOCK_K8S_CLIENT=true +make run PORT=8000 MOCK_K8S_CLIENT=true MOCK_MR_CLIENT=true ``` # Building and Deploying diff --git a/clients/ui/bff/api/app.go b/clients/ui/bff/api/app.go index 603d74397..41d1a52ab 100644 --- a/clients/ui/bff/api/app.go +++ b/clients/ui/bff/api/app.go @@ -23,10 +23,11 @@ const ( ) type App struct { - config config.EnvConfig - logger *slog.Logger - models data.Models - kubernetesClient integrations.KubernetesClientInterface + config config.EnvConfig + logger *slog.Logger + models data.Models + kubernetesClient integrations.KubernetesClientInterface + modelRegistryClient data.ModelRegistryClientInterface } func NewApp(cfg config.EnvConfig, logger *slog.Logger) (*App, error) { @@ -43,10 +44,23 @@ func NewApp(cfg config.EnvConfig, logger *slog.Logger) (*App, error) { return nil, fmt.Errorf("failed to create Kubernetes client: %w", err) } + var mrClient data.ModelRegistryClientInterface + + if cfg.MockMRClient { + mrClient, err = mocks.NewModelRegistryClient(logger) + } else { + mrClient, err = data.NewModelRegistryClient(logger) + } + + if err != nil { + return nil, fmt.Errorf("failed to create ModelRegistry client: %w", err) + } + app := &App{ - config: cfg, - logger: logger, - kubernetesClient: k8sClient, + config: cfg, + logger: logger, + kubernetesClient: k8sClient, + modelRegistryClient: mrClient, } return app, nil } @@ -59,7 +73,7 @@ func (app *App) Routes() http.Handler { // HTTP client routes router.GET(HealthCheckPath, app.HealthcheckHandler) - router.GET(RegisteredModelsPath, app.AttachRESTClient(app.GetRegisteredModelsHandler)) + router.GET(RegisteredModelsPath, app.AttachRESTClient(app.GetAllRegisteredModelsHandler)) router.GET(RegisteredModelPath, app.AttachRESTClient(app.GetRegisteredModelHandler)) router.POST(RegisteredModelsPath, app.AttachRESTClient(app.CreateRegisteredModelHandler)) diff --git a/clients/ui/bff/api/helpers.go b/clients/ui/bff/api/helpers.go index a53c3c4bb..f35851af2 100644 --- a/clients/ui/bff/api/helpers.go +++ b/clients/ui/bff/api/helpers.go @@ -11,6 +11,8 @@ import ( type Envelope map[string]interface{} +type TypedEnvelope[T any] map[string]T + func (app *App) WriteJSON(w http.ResponseWriter, status int, data any, headers http.Header) error { js, err := json.MarshalIndent(data, "", "\t") diff --git a/clients/ui/bff/api/registered_models_handler.go b/clients/ui/bff/api/registered_models_handler.go index 9e5f768b6..9e47d833a 100644 --- a/clients/ui/bff/api/registered_models_handler.go +++ b/clients/ui/bff/api/registered_models_handler.go @@ -6,13 +6,12 @@ import ( "fmt" "github.com/julienschmidt/httprouter" "github.com/kubeflow/model-registry/pkg/openapi" - "github.com/kubeflow/model-registry/ui/bff/data" "github.com/kubeflow/model-registry/ui/bff/integrations" "github.com/kubeflow/model-registry/ui/bff/validation" "net/http" ) -func (app *App) GetRegisteredModelsHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { //TODO (ederign) implement pagination client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) if !ok { @@ -20,14 +19,14 @@ func (app *App) GetRegisteredModelsHandler(w http.ResponseWriter, r *http.Reques return } - modelList, err := data.GetAllRegisteredModels(client) + modelList, err := app.modelRegistryClient.GetAllRegisteredModels(client) if err != nil { app.serverErrorResponse(w, r, err) return } modelRegistryRes := Envelope{ - "registered_models": modelList, + "registered_model_list": modelList, } err = app.WriteJSON(w, http.StatusOK, modelRegistryRes, nil) @@ -60,7 +59,7 @@ func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Requ return } - createdModel, err := data.CreateRegisteredModel(client, jsonData) + createdModel, err := app.modelRegistryClient.CreateRegisteredModel(client, jsonData) if err != nil { var httpErr *integrations.HTTPError if errors.As(err, &httpErr) { @@ -91,13 +90,13 @@ func (app *App) GetRegisteredModelHandler(w http.ResponseWriter, r *http.Request return } - model, err := data.GetRegisteredModel(client, ps.ByName(RegisteredModelId)) + model, err := app.modelRegistryClient.GetRegisteredModel(client, ps.ByName(RegisteredModelId)) if err != nil { app.serverErrorResponse(w, r, err) return } - if _, ok := model.GetNameOk(); !ok { + if _, ok := model.GetIdOk(); !ok { app.notFoundResponse(w, r) return } diff --git a/clients/ui/bff/api/registered_models_handler_test.go b/clients/ui/bff/api/registered_models_handler_test.go new file mode 100644 index 000000000..ec1eb0392 --- /dev/null +++ b/clients/ui/bff/api/registered_models_handler_test.go @@ -0,0 +1,135 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "github.com/kubeflow/model-registry/pkg/openapi" + "github.com/kubeflow/model-registry/ui/bff/internals/mocks" + "github.com/stretchr/testify/assert" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +func TestGetRegisteredModelHandler(t *testing.T) { + mockMRClient, _ := mocks.NewModelRegistryClient(nil) + mockClient := new(mocks.MockHTTPClient) + + testApp := App{ + modelRegistryClient: mockMRClient, + } + + req, err := http.NewRequest(http.MethodGet, + "/api/v1/model-registry/model-registry/registered_models/1", nil) + assert.NoError(t, err) + + ctx := context.WithValue(req.Context(), httpClientKey, mockClient) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + + testApp.GetRegisteredModelHandler(rr, req, nil) + rs := rr.Result() + + defer rs.Body.Close() + + body, err := io.ReadAll(rs.Body) + assert.NoError(t, err) + var registeredModelRes TypedEnvelope[openapi.RegisteredModel] + err = json.Unmarshal(body, ®isteredModelRes) + assert.NoError(t, err) + + assert.Equal(t, http.StatusOK, rr.Code) + + var expected = TypedEnvelope[openapi.RegisteredModel]{ + "registered_model": mocks.GetRegisteredModelMocks()[0], + } + + //TODO assert the full structure, I couldn't get unmarshalling to work for the full customProperties values + // this issue is in the test only + assert.Equal(t, expected["registered_model"].Name, registeredModelRes["registered_model"].Name) +} + +func TestGetAllRegisteredModelsHandler(t *testing.T) { + mockMRClient, _ := mocks.NewModelRegistryClient(nil) + mockClient := new(mocks.MockHTTPClient) + + testApp := App{ + modelRegistryClient: mockMRClient, + } + + req, err := http.NewRequest(http.MethodGet, + "/api/v1/model-registry/model-registry/registered_models", nil) + assert.NoError(t, err) + + ctx := context.WithValue(req.Context(), httpClientKey, mockClient) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + + testApp.GetAllRegisteredModelsHandler(rr, req, nil) + rs := rr.Result() + + defer rs.Body.Close() + + body, err := io.ReadAll(rs.Body) + assert.NoError(t, err) + var registeredModelsListRes TypedEnvelope[openapi.RegisteredModelList] + err = json.Unmarshal(body, ®isteredModelsListRes) + assert.NoError(t, err) + + assert.Equal(t, http.StatusOK, rr.Code) + + var expected = TypedEnvelope[openapi.RegisteredModelList]{ + "registered_model_list": mocks.GetRegisteredModelListMock(), + } + + assert.Equal(t, expected["registered_model_list"].Size, registeredModelsListRes["registered_model_list"].Size) + assert.Equal(t, expected["registered_model_list"].PageSize, registeredModelsListRes["registered_model_list"].PageSize) + assert.Equal(t, expected["registered_model_list"].NextPageToken, registeredModelsListRes["registered_model_list"].NextPageToken) + assert.Equal(t, len(expected["registered_model_list"].Items), len(registeredModelsListRes["registered_model_list"].Items)) +} + +func TestCreateRegisteredModelHandler(t *testing.T) { + mockMRClient, _ := mocks.NewModelRegistryClient(nil) + mockClient := new(mocks.MockHTTPClient) + + testApp := App{ + modelRegistryClient: mockMRClient, + } + + newModel := openapi.NewRegisteredModelCreate("Model One") + newModelJSON, err := newModel.MarshalJSON() + assert.NoError(t, err) + + reqBody := bytes.NewReader(newModelJSON) + + req, err := http.NewRequest(http.MethodPost, + "/api/v1/model-registry/model-registry/registered_models", reqBody) + assert.NoError(t, err) + + ctx := context.WithValue(req.Context(), httpClientKey, mockClient) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + + testApp.CreateRegisteredModelHandler(rr, req, nil) + rs := rr.Result() + + defer rs.Body.Close() + + body, err := io.ReadAll(rs.Body) + assert.NoError(t, err) + var registeredModelRes openapi.RegisteredModel + err = json.Unmarshal(body, ®isteredModelRes) + assert.NoError(t, err) + + assert.Equal(t, http.StatusCreated, rr.Code) + + var expected = mocks.GetRegisteredModelMocks()[0] + + assert.Equal(t, expected.Name, registeredModelRes.Name) + assert.NotEmpty(t, rs.Header.Get("location")) +} diff --git a/clients/ui/bff/cmd/main.go b/clients/ui/bff/cmd/main.go index 5b52d9079..4b4b96ab1 100644 --- a/clients/ui/bff/cmd/main.go +++ b/clients/ui/bff/cmd/main.go @@ -17,6 +17,7 @@ func main() { var cfg config.EnvConfig flag.IntVar(&cfg.Port, "port", getEnvAsInt("PORT", 4000), "API server port") flag.BoolVar(&cfg.MockK8Client, "mock-k8s-client", false, "Use mock Kubernetes client") + flag.BoolVar(&cfg.MockMRClient, "mock-mr-client", false, "Use mock Model Registry client") flag.Parse() logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) diff --git a/clients/ui/bff/config/environment.go b/clients/ui/bff/config/environment.go index 0abcd067b..f7b10bf63 100644 --- a/clients/ui/bff/config/environment.go +++ b/clients/ui/bff/config/environment.go @@ -3,4 +3,5 @@ package config type EnvConfig struct { Port int MockK8Client bool + MockMRClient bool } diff --git a/clients/ui/bff/data/model_registry_client.go b/clients/ui/bff/data/model_registry_client.go new file mode 100644 index 000000000..ccc4ff9ed --- /dev/null +++ b/clients/ui/bff/data/model_registry_client.go @@ -0,0 +1,18 @@ +package data + +import ( + "log/slog" +) + +type ModelRegistryClientInterface interface { + RegisteredModelInterface +} + +type ModelRegistryClient struct { + logger *slog.Logger + RegisteredModel +} + +func NewModelRegistryClient(logger *slog.Logger) (ModelRegistryClientInterface, error) { + return &ModelRegistryClient{logger: logger}, nil +} diff --git a/clients/ui/bff/data/registered_model.go b/clients/ui/bff/data/registered_model.go index b6b4c9bef..e8a797f26 100644 --- a/clients/ui/bff/data/registered_model.go +++ b/clients/ui/bff/data/registered_model.go @@ -11,7 +11,17 @@ import ( const registerModelPath = "/registered_models" -func GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error) { +type RegisteredModelInterface interface { + GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error) + CreateRegisteredModel(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.RegisteredModel, error) + GetRegisteredModel(client integrations.HTTPClientInterface, id string) (*openapi.RegisteredModel, error) +} + +type RegisteredModel struct { + RegisteredModelInterface +} + +func (m RegisteredModel) GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error) { responseData, err := client.GET(registerModelPath) if err != nil { @@ -26,7 +36,7 @@ func GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.R return &modelList, nil } -func CreateRegisteredModel(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.RegisteredModel, error) { +func (m RegisteredModel) CreateRegisteredModel(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.RegisteredModel, error) { responseData, err := client.POST(registerModelPath, bytes.NewBuffer(jsonData)) if err != nil { @@ -41,7 +51,7 @@ func CreateRegisteredModel(client integrations.HTTPClientInterface, jsonData []b return &model, nil } -func GetRegisteredModel(client integrations.HTTPClientInterface, id string) (*openapi.RegisteredModel, error) { +func (m RegisteredModel) GetRegisteredModel(client integrations.HTTPClientInterface, id string) (*openapi.RegisteredModel, error) { path, err := url.JoinPath(registerModelPath, id) if err != nil { return nil, err diff --git a/clients/ui/bff/data/registered_model_test.go b/clients/ui/bff/data/registered_model_test.go index aa0d85de2..6a8d75a66 100644 --- a/clients/ui/bff/data/registered_model_test.go +++ b/clients/ui/bff/data/registered_model_test.go @@ -17,10 +17,12 @@ func TestGetAllRegisteredModels(t *testing.T) { mockData, err := json.Marshal(expected) assert.NoError(t, err) + mrClient := ModelRegistryClient{} + mockClient := new(mocks.MockHTTPClient) mockClient.On("GET", registerModelPath).Return(mockData, nil) - actual, err := GetAllRegisteredModels(mockClient) + actual, err := mrClient.GetAllRegisteredModels(mockClient) assert.NoError(t, err) assert.NotNil(t, actual) assert.Equal(t, expected.NextPageToken, actual.NextPageToken) @@ -39,13 +41,15 @@ func TestCreateRegisteredModel(t *testing.T) { mockData, err := json.Marshal(expected) assert.NoError(t, err) + mrClient := ModelRegistryClient{} + mockClient := new(mocks.MockHTTPClient) mockClient.On("POST", registerModelPath, mock.Anything).Return(mockData, nil) jsonInput, err := json.Marshal(expected) assert.NoError(t, err) - actual, err := CreateRegisteredModel(mockClient, jsonInput) + actual, err := mrClient.CreateRegisteredModel(mockClient, jsonInput) assert.NoError(t, err) assert.NotNil(t, actual) assert.Equal(t, expected.Name, actual.Name) @@ -62,10 +66,12 @@ func TestGetRegisteredModel(t *testing.T) { mockData, err := json.Marshal(expected) assert.NoError(t, err) + mrClient := ModelRegistryClient{} + mockClient := new(mocks.MockHTTPClient) mockClient.On("GET", registerModelPath+"/"+expected.GetId()).Return(mockData, nil) - actual, err := GetRegisteredModel(mockClient, expected.GetId()) + actual, err := mrClient.GetRegisteredModel(mockClient, expected.GetId()) assert.NoError(t, err) assert.NotNil(t, actual) assert.Equal(t, expected.Name, actual.Name) diff --git a/clients/ui/bff/internals/mocks/model_registry_client_mock.go b/clients/ui/bff/internals/mocks/model_registry_client_mock.go new file mode 100644 index 000000000..966ab8ec2 --- /dev/null +++ b/clients/ui/bff/internals/mocks/model_registry_client_mock.go @@ -0,0 +1,31 @@ +package mocks + +import ( + "github.com/kubeflow/model-registry/pkg/openapi" + "github.com/kubeflow/model-registry/ui/bff/integrations" + "github.com/stretchr/testify/mock" + "log/slog" +) + +type ModelRegistryClientMock struct { + mock.Mock +} + +func NewModelRegistryClient(logger *slog.Logger) (*ModelRegistryClientMock, error) { + return &ModelRegistryClientMock{}, nil +} + +func (m *ModelRegistryClientMock) GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error) { + mockData := GetRegisteredModelListMock() + return &mockData, nil +} + +func (m *ModelRegistryClientMock) CreateRegisteredModel(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.RegisteredModel, error) { + mockData := GetRegisteredModelMocks()[0] + return &mockData, nil +} + +func (m *ModelRegistryClientMock) GetRegisteredModel(client integrations.HTTPClientInterface, id string) (*openapi.RegisteredModel, error) { + mockData := GetRegisteredModelMocks()[0] + return &mockData, nil +} diff --git a/clients/ui/bff/internals/mocks/static_data_mock.go b/clients/ui/bff/internals/mocks/static_data_mock.go new file mode 100644 index 000000000..732210030 --- /dev/null +++ b/clients/ui/bff/internals/mocks/static_data_mock.go @@ -0,0 +1,58 @@ +package mocks + +import ( + "github.com/kubeflow/model-registry/pkg/openapi" +) + +func GetRegisteredModelMocks() []openapi.RegisteredModel { + model1 := openapi.RegisteredModel{ + CustomProperties: &map[string]openapi.MetadataValue{ + "my-label9": { + MetadataStringValue: &openapi.MetadataStringValue{ + StringValue: "property9", + MetadataType: "string", + }, + }, + }, + Name: "Model One", + Description: stringToPointer("This model does things and stuff"), + ExternalId: stringToPointer("934589798"), + Id: stringToPointer("1"), + CreateTimeSinceEpoch: stringToPointer("1725282249921"), + LastUpdateTimeSinceEpoch: stringToPointer("1725282249921"), + Owner: stringToPointer("Sherlock Holmes"), + State: stateToPointer(openapi.REGISTEREDMODELSTATE_LIVE), + } + + model2 := openapi.RegisteredModel{ + CustomProperties: &map[string]openapi.MetadataValue{ + "my-label9": { + MetadataStringValue: &openapi.MetadataStringValue{ + StringValue: "property9", + MetadataType: "string", + }, + }, + }, + Name: "Model Two", + Description: stringToPointer("This model does things and stuff"), + ExternalId: stringToPointer("345235987"), + Id: stringToPointer("2"), + CreateTimeSinceEpoch: stringToPointer("1725282249921"), + LastUpdateTimeSinceEpoch: stringToPointer("1725282249921"), + Owner: stringToPointer("John Watson"), + State: stateToPointer(openapi.REGISTEREDMODELSTATE_LIVE), + } + + return []openapi.RegisteredModel{model1, model2} +} + +func GetRegisteredModelListMock() openapi.RegisteredModelList { + models := GetRegisteredModelMocks() + + return openapi.RegisteredModelList{ + NextPageToken: "abcdefgh", + PageSize: 2, + Size: int32(len(models)), + Items: models, + } +}