diff --git a/clients/ui/bff/README.md b/clients/ui/bff/README.md index 932da820f..b0e0a3e57 100644 --- a/clients/ui/bff/README.md +++ b/clients/ui/bff/README.md @@ -67,6 +67,7 @@ make docker-build | PATCH /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id} | UpdateModelVersionHandler | Update a ModelVersion entity by ID | | GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}/versions | GetAllModelVersionsForRegisteredModelHandler | Get all ModelVersion entities by RegisteredModel ID | | POST /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}/versions | CreateModelVersionForRegisteredModelHandler | Create a ModelVersion entity for a specific RegisteredModel | +| GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts | GetAllModelArtifactsByModelVersionHandler | Get all ModelArtifact entities by ModelVersion ID | ### Sample local calls ``` @@ -184,4 +185,8 @@ curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/regi "state": "LIVE", "author": "alex" }}' +``` +``` +# GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts +curl -i http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1/artifacts ``` \ No newline at end of file diff --git a/clients/ui/bff/api/app.go b/clients/ui/bff/api/app.go index e9c1b8309..3d5fd0fe9 100644 --- a/clients/ui/bff/api/app.go +++ b/clients/ui/bff/api/app.go @@ -13,19 +13,20 @@ import ( ) const ( - Version = "1.0.0" - PathPrefix = "/api/v1" - ModelRegistryId = "model_registry_id" - RegisteredModelId = "registered_model_id" - ModelVersionId = "model_version_id" - HealthCheckPath = PathPrefix + "/healthcheck" - ModelRegistryListPath = PathPrefix + "/model_registry" - ModelRegistryPath = ModelRegistryListPath + "/:" + ModelRegistryId - RegisteredModelListPath = ModelRegistryPath + "/registered_models" - RegisteredModelPath = RegisteredModelListPath + "/:" + RegisteredModelId - RegisteredModelVersionsPath = RegisteredModelPath + "/versions" - ModelVersionListPath = ModelRegistryPath + "/model_versions" - ModelVersionPath = ModelVersionListPath + "/:" + ModelVersionId + Version = "1.0.0" + PathPrefix = "/api/v1" + ModelRegistryId = "model_registry_id" + RegisteredModelId = "registered_model_id" + ModelVersionId = "model_version_id" + HealthCheckPath = PathPrefix + "/healthcheck" + ModelRegistryListPath = PathPrefix + "/model_registry" + ModelRegistryPath = ModelRegistryListPath + "/:" + ModelRegistryId + RegisteredModelListPath = ModelRegistryPath + "/registered_models" + RegisteredModelPath = RegisteredModelListPath + "/:" + RegisteredModelId + RegisteredModelVersionsPath = RegisteredModelPath + "/versions" + ModelVersionListPath = ModelRegistryPath + "/model_versions" + ModelVersionPath = ModelVersionListPath + "/:" + ModelVersionId + ModelVersionArtifactListPath = ModelVersionPath + "/artifacts" ) type App struct { @@ -89,6 +90,7 @@ func (app *App) Routes() http.Handler { router.GET(ModelVersionPath, app.AttachRESTClient(app.GetModelVersionHandler)) router.POST(ModelVersionListPath, app.AttachRESTClient(app.CreateModelVersionHandler)) router.PATCH(ModelVersionPath, app.AttachRESTClient(app.UpdateModelVersionHandler)) + router.GET(ModelVersionArtifactListPath, app.AttachRESTClient(app.GetAllModelArtifactsByModelVersionHandler)) // Kubernetes client routes router.GET(ModelRegistryListPath, app.ModelRegistryHandler) diff --git a/clients/ui/bff/api/model_versions_handler.go b/clients/ui/bff/api/model_versions_handler.go index f65dc2674..33dc001db 100644 --- a/clients/ui/bff/api/model_versions_handler.go +++ b/clients/ui/bff/api/model_versions_handler.go @@ -13,6 +13,7 @@ import ( type ModelVersionEnvelope Envelope[*openapi.ModelVersion, None] type ModelVersionListEnvelope Envelope[*openapi.ModelVersionList, None] +type ModelArtifactListEnvelope Envelope[*openapi.ModelArtifactList, None] func (app *App) GetModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) @@ -148,3 +149,26 @@ func (app *App) UpdateModelVersionHandler(w http.ResponseWriter, r *http.Request return } } + +func (app *App) GetAllModelArtifactsByModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) + if !ok { + app.serverErrorResponse(w, r, errors.New("REST client not found")) + return + } + + data, err := app.modelRegistryClient.GetModelArtifactsByModelVersion(client, ps.ByName(ModelVersionId)) + if err != nil { + app.serverErrorResponse(w, r, err) + return + } + + result := ModelArtifactListEnvelope{ + Data: data, + } + + err = app.WriteJSON(w, http.StatusOK, result, nil) + if err != nil { + app.serverErrorResponse(w, r, err) + } +} diff --git a/clients/ui/bff/api/model_versions_handler_test.go b/clients/ui/bff/api/model_versions_handler_test.go index 1b845b31a..3c426392e 100644 --- a/clients/ui/bff/api/model_versions_handler_test.go +++ b/clients/ui/bff/api/model_versions_handler_test.go @@ -45,3 +45,17 @@ func TestUpdateModelVersionHandler(t *testing.T) { assert.Equal(t, http.StatusOK, rs.StatusCode) assert.Equal(t, expected.Data.Name, actual.Data.Name) } + +func TestGetAllModelArtifactsByModelVersionHandler(t *testing.T) { + data := mocks.GetModelArtifactListMock() + expected := ModelArtifactListEnvelope{Data: &data} + + actual, rs, err := setupApiTest[ModelArtifactListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions/1/artifacts", nil) + assert.NoError(t, err) + + assert.Equal(t, http.StatusOK, rs.StatusCode) + assert.Equal(t, expected.Data.Size, actual.Data.Size) + assert.Equal(t, expected.Data.PageSize, actual.Data.PageSize) + assert.Equal(t, expected.Data.NextPageToken, actual.Data.NextPageToken) + assert.Equal(t, len(expected.Data.Items), len(actual.Data.Items)) +} diff --git a/clients/ui/bff/data/model_version.go b/clients/ui/bff/data/model_version.go index a8c0fa628..f93318049 100644 --- a/clients/ui/bff/data/model_version.go +++ b/clients/ui/bff/data/model_version.go @@ -10,11 +10,13 @@ import ( ) const modelVersionPath = "/model_versions" +const artifactsByModelVersionPath = "/artifacts" type ModelVersionInterface interface { GetModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersion, error) CreateModelVersion(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.ModelVersion, error) UpdateModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error) + GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) } type ModelVersion struct { @@ -75,3 +77,23 @@ func (m ModelVersion) UpdateModelVersion(client integrations.HTTPClientInterface return &model, nil } + +func (m ModelVersion) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) { + path, err := url.JoinPath(modelVersionPath, id, artifactsByModelVersionPath) + + if err != nil { + return nil, err + } + + responseData, err := client.GET(path) + if err != nil { + return nil, fmt.Errorf("error fetching model version artifacts: %w", err) + } + + var model openapi.ModelArtifactList + if err := json.Unmarshal(responseData, &model); err != nil { + return nil, fmt.Errorf("error decoding response data: %w", err) + } + + return &model, nil +} diff --git a/clients/ui/bff/data/model_version_test.go b/clients/ui/bff/data/model_version_test.go index ce6193fef..25236c77a 100644 --- a/clients/ui/bff/data/model_version_test.go +++ b/clients/ui/bff/data/model_version_test.go @@ -88,3 +88,29 @@ func TestUpdateModelVersion(t *testing.T) { mockClient.AssertExpectations(t) } + +func TestGetModelArtifactsByModelVersion(t *testing.T) { + gofakeit.Seed(0) + + expected := mocks.GenerateMockModelArtifactList() + + mockData, err := json.Marshal(expected) + assert.NoError(t, err) + + modelVersion := ModelVersion{} + + path, err := url.JoinPath(modelVersionPath, "1", artifactsByModelVersionPath) + assert.NoError(t, err) + + mockClient := new(mocks.MockHTTPClient) + mockClient.On(http.MethodGet, path, mock.Anything).Return(mockData, nil) + + actual, err := modelVersion.GetModelArtifactsByModelVersion(mockClient, "1") + assert.NoError(t, err) + + assert.NotNil(t, actual) + assert.Equal(t, expected.Size, actual.Size) + assert.Equal(t, expected.NextPageToken, actual.NextPageToken) + assert.Equal(t, expected.PageSize, actual.PageSize) + assert.Equal(t, len(expected.Items), len(actual.Items)) +} diff --git a/clients/ui/bff/internals/mocks/model_registry_client_mock.go b/clients/ui/bff/internals/mocks/model_registry_client_mock.go index 4e4a4dd0a..a1819cac8 100644 --- a/clients/ui/bff/internals/mocks/model_registry_client_mock.go +++ b/clients/ui/bff/internals/mocks/model_registry_client_mock.go @@ -59,3 +59,8 @@ func (m *ModelRegistryClientMock) CreateModelVersionForRegisteredModel(client in mockData := GetModelVersionMocks()[0] return &mockData, nil } + +func (m *ModelRegistryClientMock) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) { + mockData := GetModelArtifactListMock() + return &mockData, nil +} diff --git a/clients/ui/bff/internals/mocks/static_data_mock.go b/clients/ui/bff/internals/mocks/static_data_mock.go index 2edb872a1..c659da347 100644 --- a/clients/ui/bff/internals/mocks/static_data_mock.go +++ b/clients/ui/bff/internals/mocks/static_data_mock.go @@ -111,3 +111,65 @@ func GetModelVersionListMock() openapi.ModelVersionList { Size: 2, } } + +func GetModelArtifactMocks() []openapi.ModelArtifact { + artifact1 := openapi.ModelArtifact{ + ArtifactType: "TYPE_ONE", + CustomProperties: newCustomProperties(), + Description: stringToPointer("This artifact can do more than you would expect"), + ExternalId: stringToPointer("1000001"), + Uri: stringToPointer("http://localhost/artifacts/1"), + State: stateToPointer(openapi.ARTIFACTSTATE_LIVE), + Name: stringToPointer("Artifact One"), + Id: stringToPointer("1"), + CreateTimeSinceEpoch: stringToPointer("1725282249921"), + LastUpdateTimeSinceEpoch: stringToPointer("1725282249921"), + ModelFormatName: stringToPointer("ONNX"), + StorageKey: stringToPointer("key1"), + StoragePath: stringToPointer("/artifacts/1"), + ModelFormatVersion: stringToPointer("1.0.0"), + ServiceAccountName: stringToPointer("service-1"), + } + + artifact2 := openapi.ModelArtifact{ + ArtifactType: "TYPE_TWO", + CustomProperties: newCustomProperties(), + Description: stringToPointer("This artifact can do more than you would expect, but less than you would hope"), + ExternalId: stringToPointer("1000002"), + Uri: stringToPointer("http://localhost/artifacts/2"), + State: stateToPointer(openapi.ARTIFACTSTATE_PENDING), + Name: stringToPointer("Artifact Two"), + Id: stringToPointer("2"), + CreateTimeSinceEpoch: stringToPointer("1725282249921"), + LastUpdateTimeSinceEpoch: stringToPointer("1725282249921"), + ModelFormatName: stringToPointer("TensorFlow"), + StorageKey: stringToPointer("key2"), + StoragePath: stringToPointer("/artifacts/2"), + ModelFormatVersion: stringToPointer("1.0.0"), + ServiceAccountName: stringToPointer("service-2"), + } + + return []openapi.ModelArtifact{artifact1, artifact2} +} + +func GetModelArtifactListMock() openapi.ModelArtifactList { + return openapi.ModelArtifactList{ + NextPageToken: "abcdefgh", + PageSize: 2, + Items: GetModelArtifactMocks(), + Size: 2, + } +} + +func newCustomProperties() *map[string]openapi.MetadataValue { + result := map[string]openapi.MetadataValue{ + "my-label9": { + MetadataStringValue: &openapi.MetadataStringValue{ + StringValue: "property9", + MetadataType: "string", + }, + }, + } + + return &result +} diff --git a/clients/ui/bff/internals/mocks/types_mock.go b/clients/ui/bff/internals/mocks/types_mock.go index 11f9739b0..f1d81a9a7 100644 --- a/clients/ui/bff/internals/mocks/types_mock.go +++ b/clients/ui/bff/internals/mocks/types_mock.go @@ -35,8 +35,8 @@ func GenerateMockRegisteredModel() openapi.RegisteredModel { ExternalId: stringToPointer(gofakeit.UUID()), Name: gofakeit.Name(), Id: stringToPointer(gofakeit.UUID()), - CreateTimeSinceEpoch: stringToPointer(fmt.Sprintf("%d", gofakeit.Date().UnixMilli())), - LastUpdateTimeSinceEpoch: stringToPointer(fmt.Sprintf("%d", gofakeit.Date().UnixMilli())), + CreateTimeSinceEpoch: randomEpochTime(), + LastUpdateTimeSinceEpoch: randomEpochTime(), Owner: stringToPointer(gofakeit.Name()), State: stateToPointer(openapi.RegisteredModelState(gofakeit.RandomString([]string{string(openapi.REGISTEREDMODELSTATE_LIVE), string(openapi.REGISTEREDMODELSTATE_ARCHIVED)}))), } @@ -57,8 +57,8 @@ func GenerateMockModelVersion() openapi.ModelVersion { ExternalId: stringToPointer(gofakeit.UUID()), Name: gofakeit.Name(), Id: stringToPointer(gofakeit.UUID()), - CreateTimeSinceEpoch: stringToPointer(fmt.Sprintf("%d", gofakeit.Date().UnixMilli())), - LastUpdateTimeSinceEpoch: stringToPointer(fmt.Sprintf("%d", gofakeit.Date().UnixMilli())), + CreateTimeSinceEpoch: randomEpochTime(), + LastUpdateTimeSinceEpoch: randomEpochTime(), Author: stringToPointer(gofakeit.Name()), State: stateToPointer(openapi.ModelVersionState(gofakeit.RandomString([]string{string(openapi.MODELVERSIONSTATE_LIVE), string(openapi.MODELVERSIONSTATE_ARCHIVED)}))), } @@ -81,6 +81,66 @@ func GenerateMockModelVersionList() openapi.ModelVersionList { } } +func GenerateMockModelArtifact() openapi.ModelArtifact { + artifact := openapi.ModelArtifact{ + ArtifactType: gofakeit.Word(), + CustomProperties: &map[string]openapi.MetadataValue{ + "example_key": { + MetadataStringValue: &openapi.MetadataStringValue{ + StringValue: gofakeit.Sentence(3), + MetadataType: "string", + }, + }, + }, + Description: stringToPointer(gofakeit.Sentence(5)), + ExternalId: stringToPointer(gofakeit.UUID()), + Uri: stringToPointer(gofakeit.URL()), + State: randomArtifactState(), + Name: stringToPointer(gofakeit.Name()), + Id: stringToPointer(gofakeit.UUID()), + CreateTimeSinceEpoch: randomEpochTime(), + LastUpdateTimeSinceEpoch: randomEpochTime(), + ModelFormatName: stringToPointer(gofakeit.Name()), + StorageKey: stringToPointer(gofakeit.Word()), + StoragePath: stringToPointer("/" + gofakeit.Word() + "/" + gofakeit.Word()), + ModelFormatVersion: stringToPointer(gofakeit.AppVersion()), + ServiceAccountName: stringToPointer(gofakeit.Username()), + } + return artifact +} + +func GenerateMockModelArtifactList() openapi.ModelArtifactList { + var artifacts []openapi.ModelArtifact + + for i := 0; i < 2; i++ { + artifact := GenerateMockModelArtifact() + artifacts = append(artifacts, artifact) + } + + return openapi.ModelArtifactList{ + NextPageToken: gofakeit.UUID(), + PageSize: int32(gofakeit.Number(1, 20)), + Size: int32(len(artifacts)), + Items: artifacts, + } +} + +func randomEpochTime() *string { + return stringToPointer(fmt.Sprintf("%d", gofakeit.Date().UnixMilli())) +} + +func randomArtifactState() *openapi.ArtifactState { + return stateToPointer(openapi.ArtifactState(gofakeit.RandomString([]string{ + string(openapi.ARTIFACTSTATE_LIVE), + string(openapi.ARTIFACTSTATE_DELETED), + string(openapi.ARTIFACTSTATE_ABANDONED), + string(openapi.ARTIFACTSTATE_MARKED_FOR_DELETION), + string(openapi.ARTIFACTSTATE_PENDING), + string(openapi.ARTIFACTSTATE_REFERENCE), + string(openapi.ARTIFACTSTATE_UNKNOWN), + }))) +} + func stateToPointer[T any](s T) *T { return &s }