From 046d6a09703c1fc105535f5bb96bb35d183295ac Mon Sep 17 00:00:00 2001 From: gaoxinxing <15031259256@163.com> Date: Wed, 22 Nov 2023 14:31:11 +0800 Subject: [PATCH 1/3] feat --- .../starwhale/mlops/api/ModelController.java | 17 ++- .../protocol/model/ModelVersionViewVo.java | 2 + .../mlops/domain/model/ModelService.java | 13 +- .../model/mapper/ModelVersionMapper.java | 120 +++++++++++++----- .../model/po/ModelVersionViewEntity.java | 2 + .../mlops/api/ModelControllerTest.java | 12 +- .../mlops/domain/model/ModelServiceTest.java | 15 ++- 7 files changed, 134 insertions(+), 47 deletions(-) diff --git a/server/controller/src/main/java/ai/starwhale/mlops/api/ModelController.java b/server/controller/src/main/java/ai/starwhale/mlops/api/ModelController.java index 29986bffaf..7e0687f6f3 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/api/ModelController.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/api/ModelController.java @@ -36,6 +36,7 @@ import ai.starwhale.mlops.api.protocol.storage.FileNode; import ai.starwhale.mlops.common.IdConverter; import ai.starwhale.mlops.common.PageParams; +import ai.starwhale.mlops.domain.job.BizType; import ai.starwhale.mlops.domain.model.ModelService; import ai.starwhale.mlops.domain.model.bo.ModelQuery; import ai.starwhale.mlops.domain.model.bo.ModelVersion; @@ -246,18 +247,20 @@ ResponseEntity>> listModelTree( @Parameter(in = ParameterIn.PATH, required = true, description = "Project url", schema = @Schema()) @PathVariable String projectUrl, @Parameter(in = ParameterIn.QUERY, description = "Data range", schema = @Schema()) - @RequestParam(required = false, defaultValue = "all") DataScope scope + @RequestParam(required = false, defaultValue = "all") DataScope scope, + @RequestParam(required = false) BizType bizType, + @RequestParam(required = false) Long bizId ) { List list; switch (scope) { case all: - list = modelService.listModelVersionView(projectUrl, true, true); + list = modelService.listModelVersionView(projectUrl, true, true, bizType, bizId); break; case shared: - list = modelService.listModelVersionView(projectUrl, true, false); + list = modelService.listModelVersionView(projectUrl, true, false, bizType, bizId); break; case project: - list = modelService.listModelVersionView(projectUrl, false, true); + list = modelService.listModelVersionView(projectUrl, false, true, bizType, bizId); break; default: list = List.of(); @@ -274,10 +277,12 @@ ResponseEntity>> recentModelTree( @Valid @Min(value = 1, message = "limit must be greater than or equal to 1") @Max(value = 50, message = "limit must be less than or equal to 50") - Integer limit + Integer limit, + @RequestParam(required = false) BizType bizType, + @RequestParam(required = false) Long bizId ) { return ResponseEntity.ok(Code.success.asResponse( - modelService.listRecentlyModelVersionView(projectUrl, limit) + modelService.listRecentlyModelVersionView(projectUrl, limit, bizType, bizId) )); } diff --git a/server/controller/src/main/java/ai/starwhale/mlops/api/protocol/model/ModelVersionViewVo.java b/server/controller/src/main/java/ai/starwhale/mlops/api/protocol/model/ModelVersionViewVo.java index c3edf8d972..0f9b695ffb 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/api/protocol/model/ModelVersionViewVo.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/api/protocol/model/ModelVersionViewVo.java @@ -47,6 +47,8 @@ public class ModelVersionViewVo { @NotNull private Integer shared; + private Boolean draft; + @NotNull private List stepSpecs; diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/model/ModelService.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/model/ModelService.java index 30152c77ad..afc8e52b03 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/model/ModelService.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/model/ModelService.java @@ -44,6 +44,7 @@ import ai.starwhale.mlops.domain.bundle.tag.BundleVersionTagDao; import ai.starwhale.mlops.domain.bundle.tag.po.BundleVersionTagEntity; import ai.starwhale.mlops.domain.ft.FineTuneDomainService; +import ai.starwhale.mlops.domain.job.BizType; import ai.starwhale.mlops.domain.job.cache.HotJobHolder; import ai.starwhale.mlops.domain.job.spec.JobSpecParser; import ai.starwhale.mlops.domain.job.status.JobStatus; @@ -381,11 +382,12 @@ public void shareModelVersion(String projectUrl, String modelUrl, String version } public List listModelVersionView( - String projectUrl, boolean includeShared, boolean includeCurrentProject) { + String projectUrl, boolean includeShared, boolean includeCurrentProject, BizType bizType, Long bizId + ) { var project = projectService.findProject(projectUrl); var list = new ArrayList(); if (includeCurrentProject) { - var versions = modelVersionMapper.listModelVersionViewByProject(project.getId()); + var versions = modelVersionMapper.listModelVersionViewByProject(project.getId(), bizType, bizId); list.addAll(viewEntityToVo(versions, project)); } if (includeShared) { @@ -395,10 +397,12 @@ public List listModelVersionView( return list; } - public List listRecentlyModelVersionView(String projectUrl, Integer limit) { + public List listRecentlyModelVersionView( + String projectUrl, Integer limit, BizType bizType, Long bizId) { var project = projectService.findProject(projectUrl); var userId = userService.currentUserDetail().getId(); - var list = modelVersionMapper.listModelVersionsByUserRecentlyUsed(project.getId(), userId, limit); + var list = modelVersionMapper.listModelVersionsByUserRecentlyUsed( + project.getId(), userId, limit, bizType, bizId); return viewEntityToVo(list, project); } @@ -446,6 +450,7 @@ private List viewEntityToVo(List list, Proj .latest(entity.getId() != null && entity.getId().equals(latest)) .createdTime(entity.getCreatedTime().getTime()) .shared(toInt(entity.getShared())) + .draft(entity.getDraft()) .builtInRuntime(entity.getBuiltInRuntime()) .stepSpecs(jobSpecParser.parseAndFlattenStepFromYaml(entity.getJobs())) .build()); diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/model/mapper/ModelVersionMapper.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/model/mapper/ModelVersionMapper.java index 8739134012..acc8f61c90 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/model/mapper/ModelVersionMapper.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/model/mapper/ModelVersionMapper.java @@ -16,11 +16,13 @@ package ai.starwhale.mlops.domain.model.mapper; +import ai.starwhale.mlops.domain.job.BizType; import ai.starwhale.mlops.domain.model.po.ModelVersionEntity; import ai.starwhale.mlops.domain.model.po.ModelVersionViewEntity; import cn.hutool.core.util.StrUtil; import java.util.List; import java.util.Objects; +import org.apache.hadoop.yarn.webapp.hamlet.Hamlet.SELECT; import org.apache.ibatis.annotations.Insert; import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Options; @@ -65,37 +67,20 @@ List list( @Update("update model_version set model_id = #{modelId}, draft = 0 where id = #{id}") int updateModelRef(@Param("id") Long id, @Param("modelId") Long modelId); - @Select("select " + VERSION_VIEW_COLUMNS - + " from model_info as m, model_version as v, project_info as p, user_info as u" - + " where v.model_id = m.id" - + " and m.project_id = p.id" - + " and m.owner_id = u.id" - + " and m.deleted_time = 0" - + " and p.is_deleted = 0" - + " and p.id = #{projectId}" - + " order by m.id desc, v.version_order desc") - List listModelVersionViewByProject(@Param("projectId") Long projectId); - - @Select({"" - }) + @SelectProvider(type = ModelVersionProvider.class, method = "listInProject") + List listModelVersionViewByProject( + @Param("projectId") Long projectId, + @Param("bizType") BizType bizType, + @Param("bizId") Long bizId + ); + + @SelectProvider(type = ModelVersionProvider.class, method = "listUserRecentlyUsed") List listModelVersionsByUserRecentlyUsed( - @Param("projectId") Long projectId, @Param("userId") Long userId, @Param("limit") Integer limit + @Param("projectId") Long projectId, + @Param("userId") Long userId, + @Param("limit") Integer limit, + @Param("bizType") BizType bizType, + @Param("bizId") Long bizId ); @Select("select " + VERSION_VIEW_COLUMNS @@ -155,6 +140,81 @@ ModelVersionEntity findByVersionOrder( class ModelVersionProvider { + public String listInProject( + @Param("projectId") Long projectId, + @Param("bizType") BizType bizType, + @Param("bizId") Long bizId + ) { + return new SQL() { + { + SELECT(VERSION_VIEW_COLUMNS); + FROM("model_version as v"); + INNER_JOIN("model_info as m on m.id = v.model_id"); + INNER_JOIN("project_info as p on p.id = m.project_id"); + INNER_JOIN("user_info as u on u.id = m.owner_id"); + if (bizType == BizType.FINE_TUNE && bizId != null) { + // non-draft models in current project or draft models in current space + WHERE("(v.draft=0 or v.id in " + + "(select target_model_version_id from fine_tune where space_id = #{bizId})" + + ")" + ); + } else { + // (non-draft) models in current project + WHERE("v.draft=0"); + } + WHERE("m.deleted_time = 0"); + WHERE("p.id = #{projectId}"); + ORDER_BY("m.id desc"); + ORDER_BY("v.version_order desc"); + } + + }.toString(); + } + + public String listUserRecentlyUsed( + @Param("projectId") Long projectId, + @Param("userId") Long userId, + @Param("limit") Integer limit, + @Param("bizType") BizType bizType, + @Param("bizId") Long bizId + ) { + return new SQL() { + { + SELECT(VERSION_VIEW_COLUMNS + ", MAX(j.id) as job_id"); + FROM("model_version as v"); + INNER_JOIN("model_info as m on m.id = v.model_id"); + INNER_JOIN("job_info as j on j.model_version_id = v.id"); + INNER_JOIN("project_info as p on p.id = m.project_id"); + INNER_JOIN("user_info as u on u.id = m.owner_id"); + if (bizType == BizType.FINE_TUNE && bizId != null) { + // models in current project(non-draft or draft in current space) + // or (non-draft) other project but is shared + WHERE("(" + + "(m.project_id = #{projectId} and " + + " (v.draft=0 or v.id in " + + " (select target_model_version_id from fine_tune where space_id = #{bizId})" + + " )" + + ") " + + "or " + + "(m.project_id != #{projectId} and v.shared = 1 and v.draft = 0)" + + ")"); + } else { + // (non-draft) models in current project or other project but is shared + WHERE("(" + + "(m.project_id = #{projectId} and v.draft=0) or " + + "(m.project_id != #{projectId} and v.shared = 1 and v.draft = 0)" + + ")"); + } + WHERE("m.deleted_time = 0"); + WHERE("j.owner_id = #{userId}"); // jobs in current user + WHERE("j.project_id = #{projectId}"); // jobs in current project + GROUP_BY("v.id"); + ORDER_BY("job_id desc"); + LIMIT("#{limit}"); // recently + } + }.toString(); + } + public String listSql( @Param("modelId") Long modelId, @Param("namePrefix") String namePrefix, @Param("draft") Boolean draft diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/model/po/ModelVersionViewEntity.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/model/po/ModelVersionViewEntity.java index 43150c4b12..900e9359c3 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/model/po/ModelVersionViewEntity.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/model/po/ModelVersionViewEntity.java @@ -49,6 +49,8 @@ public class ModelVersionViewEntity extends BaseEntity implements HasId { private Boolean shared; + private Boolean draft; + private String storagePath; private String builtInRuntime; diff --git a/server/controller/src/test/java/ai/starwhale/mlops/api/ModelControllerTest.java b/server/controller/src/test/java/ai/starwhale/mlops/api/ModelControllerTest.java index bc79412286..d2419db3f7 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/api/ModelControllerTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/api/ModelControllerTest.java @@ -190,17 +190,17 @@ public void testHeadModel() { @ParameterizedTest @CsvSource({"all, 2", "project, 1", "shared, 1"}) public void testListModelTree(DataScope scope, int listCount) { - given(modelService.listModelVersionView(anyString(), eq(true), eq(true))) + given(modelService.listModelVersionView(anyString(), eq(true), eq(true), any(), any())) .willReturn(List.of( ModelViewVo.builder().projectName("p").build(), ModelViewVo.builder().projectName("p").build() )); - given(modelService.listModelVersionView(anyString(), eq(false), eq(true))) + given(modelService.listModelVersionView(anyString(), eq(false), eq(true), any(), any())) .willReturn(List.of(ModelViewVo.builder().projectName("p").build())); - given(modelService.listModelVersionView(anyString(), eq(true), eq(false))) + given(modelService.listModelVersionView(anyString(), eq(true), eq(false), any(), any())) .willReturn(List.of(ModelViewVo.builder().projectName("p").build())); - var resp = controller.listModelTree("1", scope); + var resp = controller.listModelTree("1", scope, null, null); assertThat(resp.getStatusCode(), is(HttpStatus.OK)); assertThat(resp.getBody(), notNullValue()); assertThat(resp.getBody().getData(), allOf( @@ -211,13 +211,13 @@ public void testListModelTree(DataScope scope, int listCount) { @Test public void testRecentListModelTree() { - given(modelService.listRecentlyModelVersionView(anyString(), eq(5))) + given(modelService.listRecentlyModelVersionView(anyString(), eq(5), any(), any())) .willReturn(List.of( ModelViewVo.builder().projectName("p").build(), ModelViewVo.builder().projectName("p").build() )); - var resp = controller.recentModelTree("1", 5); + var resp = controller.recentModelTree("1", 5, null, null); assertThat(resp.getStatusCode(), is(HttpStatus.OK)); assertThat(resp.getBody(), notNullValue()); assertThat(resp.getBody().getData(), allOf( diff --git a/server/controller/src/test/java/ai/starwhale/mlops/domain/model/ModelServiceTest.java b/server/controller/src/test/java/ai/starwhale/mlops/domain/model/ModelServiceTest.java index 040fa975d2..1dd75646c6 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/domain/model/ModelServiceTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/domain/model/ModelServiceTest.java @@ -53,6 +53,7 @@ import ai.starwhale.mlops.domain.blob.BlobService; import ai.starwhale.mlops.domain.bundle.tag.BundleVersionTagDao; import ai.starwhale.mlops.domain.ft.FineTuneDomainService; +import ai.starwhale.mlops.domain.job.BizType; import ai.starwhale.mlops.domain.job.ModelServingService; import ai.starwhale.mlops.domain.job.cache.HotJobHolder; import ai.starwhale.mlops.domain.job.spec.JobSpecParser; @@ -804,7 +805,7 @@ public void testShareModelVersion() { @Test public void testListModelVersionView() { - var res = modelService.listModelVersionView("1", true, true); + var res = modelService.listModelVersionView("1", true, true, null, null); assertEquals(2, res.size()); assertThat(res.get(1), allOf(hasProperty("projectName", is("starwhale")), hasProperty("modelName", is("m")))); @@ -829,6 +830,18 @@ public void testListModelVersionView() { allOf(hasProperty("versionName", is("v4")), hasProperty("alias", is("v4")), hasProperty("latest", is(true))))); + + res = modelService.listModelVersionView("1", false, true, null, null); + assertEquals(2, res.size()); + + res = modelService.listModelVersionView("1", false, true, BizType.FINE_TUNE, 1L); + assertEquals(2, res.size()); + + res = modelService.listRecentlyModelVersionView("1", 5, null, null); + assertEquals(0, res.size()); + + res = modelService.listRecentlyModelVersionView("1", 5, BizType.FINE_TUNE, 1L); + assertEquals(0, res.size()); } @Test From da7adb7e499acf100e164a0a717413c0f99df835 Mon Sep 17 00:00:00 2001 From: gaoxinxing <15031259256@163.com> Date: Wed, 22 Nov 2023 17:31:28 +0800 Subject: [PATCH 2/3] split api --- .../mlops/api/FineTuneController.java | 41 ++++++ .../starwhale/mlops/api/ModelController.java | 17 +-- .../mlops/domain/ft/FineTuneAppService.java | 9 ++ .../mlops/domain/model/ModelService.java | 25 +++- .../model/mapper/ModelVersionMapper.java | 130 +++++++++++------- .../mlops/api/ModelControllerTest.java | 12 +- .../mlops/domain/model/ModelServiceTest.java | 17 ++- 7 files changed, 172 insertions(+), 79 deletions(-) diff --git a/server/controller/src/main/java/ai/starwhale/mlops/api/FineTuneController.java b/server/controller/src/main/java/ai/starwhale/mlops/api/FineTuneController.java index 1f72c01988..0fd2d4fe93 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/api/FineTuneController.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/api/FineTuneController.java @@ -20,6 +20,7 @@ import ai.starwhale.mlops.api.protocol.ResponseMessage; import ai.starwhale.mlops.api.protocol.ft.FineTuneSpaceCreateRequest; import ai.starwhale.mlops.api.protocol.ft.FineTuneSpaceVo; +import ai.starwhale.mlops.api.protocol.model.ModelViewVo; import ai.starwhale.mlops.common.IdConverter; import ai.starwhale.mlops.configuration.FeaturesProperties; import ai.starwhale.mlops.domain.ft.FineTuneAppService; @@ -30,8 +31,14 @@ import ai.starwhale.mlops.domain.user.UserService; import com.github.pagehelper.PageInfo; import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.Parameter; +import io.swagger.v3.oas.annotations.enums.ParameterIn; +import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.tags.Tag; import java.util.List; +import javax.validation.Valid; +import javax.validation.constraints.Max; +import javax.validation.constraints.Min; import lombok.extern.slf4j.Slf4j; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; @@ -195,4 +202,38 @@ public ResponseEntity> exportEval( return ResponseEntity.ok(Code.success.asResponse("")); } + @GetMapping( + value = "/project/{projectId}/ftspace/{spaceId}/model-tree", + produces = MediaType.APPLICATION_JSON_VALUE + ) + @PreAuthorize("hasAnyRole('OWNER', 'MAINTAINER', 'GUEST')") + ResponseEntity>> listModelTree( + @PathVariable("projectId") Long projectId, + @PathVariable("spaceId") Long spaceId + ) { + return ResponseEntity.ok(Code.success.asResponse( + fineTuneAppService.listModelVersionView(projectId, spaceId) + )); + } + + @GetMapping( + value = "/project/{projectId}/ftspace/{spaceId}/recent-model-tree", + produces = MediaType.APPLICATION_JSON_VALUE + ) + @PreAuthorize("hasAnyRole('OWNER', 'MAINTAINER', 'GUEST')") + ResponseEntity>> recentModelTree( + @PathVariable("projectId") Long projectId, + @PathVariable("spaceId") Long spaceId, + @Parameter(in = ParameterIn.QUERY, description = "Data limit", schema = @Schema()) + @RequestParam(required = false, defaultValue = "5") + @Valid + @Min(value = 1, message = "limit must be greater than or equal to 1") + @Max(value = 50, message = "limit must be less than or equal to 50") + Integer limit + ) { + return ResponseEntity.ok(Code.success.asResponse( + fineTuneAppService.listRecentlyModelVersionView(projectId, spaceId, limit) + )); + } + } diff --git a/server/controller/src/main/java/ai/starwhale/mlops/api/ModelController.java b/server/controller/src/main/java/ai/starwhale/mlops/api/ModelController.java index 7e0687f6f3..29986bffaf 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/api/ModelController.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/api/ModelController.java @@ -36,7 +36,6 @@ import ai.starwhale.mlops.api.protocol.storage.FileNode; import ai.starwhale.mlops.common.IdConverter; import ai.starwhale.mlops.common.PageParams; -import ai.starwhale.mlops.domain.job.BizType; import ai.starwhale.mlops.domain.model.ModelService; import ai.starwhale.mlops.domain.model.bo.ModelQuery; import ai.starwhale.mlops.domain.model.bo.ModelVersion; @@ -247,20 +246,18 @@ ResponseEntity>> listModelTree( @Parameter(in = ParameterIn.PATH, required = true, description = "Project url", schema = @Schema()) @PathVariable String projectUrl, @Parameter(in = ParameterIn.QUERY, description = "Data range", schema = @Schema()) - @RequestParam(required = false, defaultValue = "all") DataScope scope, - @RequestParam(required = false) BizType bizType, - @RequestParam(required = false) Long bizId + @RequestParam(required = false, defaultValue = "all") DataScope scope ) { List list; switch (scope) { case all: - list = modelService.listModelVersionView(projectUrl, true, true, bizType, bizId); + list = modelService.listModelVersionView(projectUrl, true, true); break; case shared: - list = modelService.listModelVersionView(projectUrl, true, false, bizType, bizId); + list = modelService.listModelVersionView(projectUrl, true, false); break; case project: - list = modelService.listModelVersionView(projectUrl, false, true, bizType, bizId); + list = modelService.listModelVersionView(projectUrl, false, true); break; default: list = List.of(); @@ -277,12 +274,10 @@ ResponseEntity>> recentModelTree( @Valid @Min(value = 1, message = "limit must be greater than or equal to 1") @Max(value = 50, message = "limit must be less than or equal to 50") - Integer limit, - @RequestParam(required = false) BizType bizType, - @RequestParam(required = false) Long bizId + Integer limit ) { return ResponseEntity.ok(Code.success.asResponse( - modelService.listRecentlyModelVersionView(projectUrl, limit, bizType, bizId) + modelService.listRecentlyModelVersionView(projectUrl, limit) )); } diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/ft/FineTuneAppService.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/ft/FineTuneAppService.java index bcb9ccd4d8..4ededfb76f 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/ft/FineTuneAppService.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/ft/FineTuneAppService.java @@ -19,6 +19,7 @@ import static ai.starwhale.mlops.domain.evaluation.EvaluationService.TABLE_NAME_FORMAT; import ai.starwhale.mlops.api.protocol.job.JobRequest; +import ai.starwhale.mlops.api.protocol.model.ModelViewVo; import ai.starwhale.mlops.api.protocol.model.ModelVo; import ai.starwhale.mlops.common.Constants; import ai.starwhale.mlops.common.IdConverter; @@ -372,6 +373,14 @@ public void releaseFt( } } + public List listModelVersionView(Long projectId, Long spaceId) { + return modelService.listFtSpaceModelVersionView(String.valueOf(projectId), spaceId); + } + + public List listRecentlyModelVersionView(Long projectId, Long spaceId, Integer limit) { + return modelService.listRecentlyModelVersionView(String.valueOf(projectId), spaceId, limit); + } + private void checkFeatureEnabled() throws StarwhaleApiException { if (!this.featuresProperties.isFineTuneEnabled()) { throw new StarwhaleApiException( diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/model/ModelService.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/model/ModelService.java index afc8e52b03..a03d0b3035 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/model/ModelService.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/model/ModelService.java @@ -44,7 +44,6 @@ import ai.starwhale.mlops.domain.bundle.tag.BundleVersionTagDao; import ai.starwhale.mlops.domain.bundle.tag.po.BundleVersionTagEntity; import ai.starwhale.mlops.domain.ft.FineTuneDomainService; -import ai.starwhale.mlops.domain.job.BizType; import ai.starwhale.mlops.domain.job.cache.HotJobHolder; import ai.starwhale.mlops.domain.job.spec.JobSpecParser; import ai.starwhale.mlops.domain.job.status.JobStatus; @@ -382,12 +381,12 @@ public void shareModelVersion(String projectUrl, String modelUrl, String version } public List listModelVersionView( - String projectUrl, boolean includeShared, boolean includeCurrentProject, BizType bizType, Long bizId + String projectUrl, boolean includeShared, boolean includeCurrentProject ) { var project = projectService.findProject(projectUrl); var list = new ArrayList(); if (includeCurrentProject) { - var versions = modelVersionMapper.listModelVersionViewByProject(project.getId(), bizType, bizId); + var versions = modelVersionMapper.listModelVersionViewByProject(project.getId()); list.addAll(viewEntityToVo(versions, project)); } if (includeShared) { @@ -397,12 +396,24 @@ public List listModelVersionView( return list; } - public List listRecentlyModelVersionView( - String projectUrl, Integer limit, BizType bizType, Long bizId) { + public List listRecentlyModelVersionView(String projectUrl, Integer limit) { var project = projectService.findProject(projectUrl); var userId = userService.currentUserDetail().getId(); - var list = modelVersionMapper.listModelVersionsByUserRecentlyUsed( - project.getId(), userId, limit, bizType, bizId); + var list = modelVersionMapper.listModelVersionsByUserRecentlyUsed(project.getId(), userId, limit); + return viewEntityToVo(list, project); + } + + public List listRecentlyModelVersionView(String projectUrl, Long spaceId, Integer limit) { + var project = projectService.findProject(projectUrl); + var userId = userService.currentUserDetail().getId(); + var list = modelVersionMapper.listModelVersionsByUserRecentlyUsedInFtSpace( + project.getId(), userId, spaceId, limit); + return viewEntityToVo(list, project); + } + + public List listFtSpaceModelVersionView(String projectUrl, Long spaceId) { + var project = projectService.findProject(projectUrl); + var list = modelVersionMapper.listModelVersionViewByFtSpace(project.getId(), spaceId); return viewEntityToVo(list, project); } diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/model/mapper/ModelVersionMapper.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/model/mapper/ModelVersionMapper.java index acc8f61c90..a760e152bd 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/model/mapper/ModelVersionMapper.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/model/mapper/ModelVersionMapper.java @@ -16,13 +16,11 @@ package ai.starwhale.mlops.domain.model.mapper; -import ai.starwhale.mlops.domain.job.BizType; import ai.starwhale.mlops.domain.model.po.ModelVersionEntity; import ai.starwhale.mlops.domain.model.po.ModelVersionViewEntity; import cn.hutool.core.util.StrUtil; import java.util.List; import java.util.Objects; -import org.apache.hadoop.yarn.webapp.hamlet.Hamlet.SELECT; import org.apache.ibatis.annotations.Insert; import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Options; @@ -68,19 +66,27 @@ List list( int updateModelRef(@Param("id") Long id, @Param("modelId") Long modelId); @SelectProvider(type = ModelVersionProvider.class, method = "listInProject") - List listModelVersionViewByProject( - @Param("projectId") Long projectId, - @Param("bizType") BizType bizType, - @Param("bizId") Long bizId - ); + List listModelVersionViewByProject(@Param("projectId") Long projectId); @SelectProvider(type = ModelVersionProvider.class, method = "listUserRecentlyUsed") List listModelVersionsByUserRecentlyUsed( @Param("projectId") Long projectId, @Param("userId") Long userId, - @Param("limit") Integer limit, - @Param("bizType") BizType bizType, - @Param("bizId") Long bizId + @Param("limit") Integer limit + ); + + @SelectProvider(type = ModelVersionProvider.class, method = "listInFtSpace") + List listModelVersionViewByFtSpace( + @Param("projectId") Long projectId, + @Param("spaceId") Long spaceId + ); + + @SelectProvider(type = ModelVersionProvider.class, method = "listUserRecentlyUsedWithFtSpace") + List listModelVersionsByUserRecentlyUsedInFtSpace( + @Param("projectId") Long projectId, + @Param("userId") Long userId, + @Param("spaceId") Long spaceId, + @Param("limit") Integer limit ); @Select("select " + VERSION_VIEW_COLUMNS @@ -140,11 +146,28 @@ ModelVersionEntity findByVersionOrder( class ModelVersionProvider { - public String listInProject( + public String listInProject(@Param("projectId") Long projectId) { + return new SQL() { + { + SELECT(VERSION_VIEW_COLUMNS); + FROM("model_version as v"); + INNER_JOIN("model_info as m on m.id = v.model_id"); + INNER_JOIN("project_info as p on p.id = m.project_id"); + INNER_JOIN("user_info as u on u.id = m.owner_id"); + // (non-draft) models in current project + WHERE("v.draft=0"); + WHERE("m.deleted_time = 0"); + WHERE("p.id = #{projectId}"); + ORDER_BY("m.id desc"); + ORDER_BY("v.version_order desc"); + } + + }.toString(); + } + + public String listInFtSpace( @Param("projectId") Long projectId, - @Param("bizType") BizType bizType, - @Param("bizId") Long bizId - ) { + @Param("spaceId") Long spaceId) { return new SQL() { { SELECT(VERSION_VIEW_COLUMNS); @@ -152,16 +175,9 @@ public String listInProject( INNER_JOIN("model_info as m on m.id = v.model_id"); INNER_JOIN("project_info as p on p.id = m.project_id"); INNER_JOIN("user_info as u on u.id = m.owner_id"); - if (bizType == BizType.FINE_TUNE && bizId != null) { - // non-draft models in current project or draft models in current space - WHERE("(v.draft=0 or v.id in " - + "(select target_model_version_id from fine_tune where space_id = #{bizId})" - + ")" - ); - } else { - // (non-draft) models in current project - WHERE("v.draft=0"); - } + INNER_JOIN("fine_tune as ft on ft.target_model_version_id = v.id"); + // all models in current space + WHERE("ft.space_id = #{spaceId}"); WHERE("m.deleted_time = 0"); WHERE("p.id = #{projectId}"); ORDER_BY("m.id desc"); @@ -174,10 +190,7 @@ public String listInProject( public String listUserRecentlyUsed( @Param("projectId") Long projectId, @Param("userId") Long userId, - @Param("limit") Integer limit, - @Param("bizType") BizType bizType, - @Param("bizId") Long bizId - ) { + @Param("limit") Integer limit) { return new SQL() { { SELECT(VERSION_VIEW_COLUMNS + ", MAX(j.id) as job_id"); @@ -186,31 +199,50 @@ public String listUserRecentlyUsed( INNER_JOIN("job_info as j on j.model_version_id = v.id"); INNER_JOIN("project_info as p on p.id = m.project_id"); INNER_JOIN("user_info as u on u.id = m.owner_id"); - if (bizType == BizType.FINE_TUNE && bizId != null) { - // models in current project(non-draft or draft in current space) - // or (non-draft) other project but is shared - WHERE("(" - + "(m.project_id = #{projectId} and " - + " (v.draft=0 or v.id in " - + " (select target_model_version_id from fine_tune where space_id = #{bizId})" - + " )" - + ") " - + "or " - + "(m.project_id != #{projectId} and v.shared = 1 and v.draft = 0)" - + ")"); - } else { - // (non-draft) models in current project or other project but is shared - WHERE("(" - + "(m.project_id = #{projectId} and v.draft=0) or " - + "(m.project_id != #{projectId} and v.shared = 1 and v.draft = 0)" - + ")"); - } + // (non-draft) models in current project or other project but is shared + WHERE("(" + + "(m.project_id = #{projectId} and v.draft=0) " + + "or (m.project_id != #{projectId} and v.shared = 1 and v.draft = 0)" + + ")"); + + WHERE("m.deleted_time = 0"); + WHERE("j.owner_id = #{userId}"); // jobs in current user + WHERE("j.project_id = #{projectId}"); // jobs in current project + GROUP_BY("v.id"); + ORDER_BY("job_id desc"); // recently + LIMIT("#{limit}"); + } + }.toString(); + } + + public String listUserRecentlyUsedWithFtSpace( + @Param("projectId") Long projectId, + @Param("userId") Long userId, + @Param("spaceId") Long spaceId, + @Param("limit") Integer limit) { + return new SQL() { + { + SELECT(VERSION_VIEW_COLUMNS + ", MAX(j.id) as job_id"); + FROM("model_version as v"); + INNER_JOIN("model_info as m on m.id = v.model_id"); + INNER_JOIN("job_info as j on j.model_version_id = v.id"); + INNER_JOIN("project_info as p on p.id = m.project_id"); + INNER_JOIN("user_info as u on u.id = m.owner_id"); + // non-draft models in current project + // or all in current space + // or (non-draft) other project but is shared + WHERE("(" + + "(m.project_id = #{projectId} and v.draft=0)" + + "or v.id in " + + " (select target_model_version_id from fine_tune where space_id = #{spaceId})" + + "or (m.project_id != #{projectId} and v.shared = 1 and v.draft = 0)" + + ")"); WHERE("m.deleted_time = 0"); WHERE("j.owner_id = #{userId}"); // jobs in current user WHERE("j.project_id = #{projectId}"); // jobs in current project GROUP_BY("v.id"); - ORDER_BY("job_id desc"); - LIMIT("#{limit}"); // recently + ORDER_BY("job_id desc"); // recently + LIMIT("#{limit}"); } }.toString(); } diff --git a/server/controller/src/test/java/ai/starwhale/mlops/api/ModelControllerTest.java b/server/controller/src/test/java/ai/starwhale/mlops/api/ModelControllerTest.java index d2419db3f7..bc79412286 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/api/ModelControllerTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/api/ModelControllerTest.java @@ -190,17 +190,17 @@ public void testHeadModel() { @ParameterizedTest @CsvSource({"all, 2", "project, 1", "shared, 1"}) public void testListModelTree(DataScope scope, int listCount) { - given(modelService.listModelVersionView(anyString(), eq(true), eq(true), any(), any())) + given(modelService.listModelVersionView(anyString(), eq(true), eq(true))) .willReturn(List.of( ModelViewVo.builder().projectName("p").build(), ModelViewVo.builder().projectName("p").build() )); - given(modelService.listModelVersionView(anyString(), eq(false), eq(true), any(), any())) + given(modelService.listModelVersionView(anyString(), eq(false), eq(true))) .willReturn(List.of(ModelViewVo.builder().projectName("p").build())); - given(modelService.listModelVersionView(anyString(), eq(true), eq(false), any(), any())) + given(modelService.listModelVersionView(anyString(), eq(true), eq(false))) .willReturn(List.of(ModelViewVo.builder().projectName("p").build())); - var resp = controller.listModelTree("1", scope, null, null); + var resp = controller.listModelTree("1", scope); assertThat(resp.getStatusCode(), is(HttpStatus.OK)); assertThat(resp.getBody(), notNullValue()); assertThat(resp.getBody().getData(), allOf( @@ -211,13 +211,13 @@ public void testListModelTree(DataScope scope, int listCount) { @Test public void testRecentListModelTree() { - given(modelService.listRecentlyModelVersionView(anyString(), eq(5), any(), any())) + given(modelService.listRecentlyModelVersionView(anyString(), eq(5))) .willReturn(List.of( ModelViewVo.builder().projectName("p").build(), ModelViewVo.builder().projectName("p").build() )); - var resp = controller.recentModelTree("1", 5, null, null); + var resp = controller.recentModelTree("1", 5); assertThat(resp.getStatusCode(), is(HttpStatus.OK)); assertThat(resp.getBody(), notNullValue()); assertThat(resp.getBody().getData(), allOf( diff --git a/server/controller/src/test/java/ai/starwhale/mlops/domain/model/ModelServiceTest.java b/server/controller/src/test/java/ai/starwhale/mlops/domain/model/ModelServiceTest.java index 1dd75646c6..d174744888 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/domain/model/ModelServiceTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/domain/model/ModelServiceTest.java @@ -53,7 +53,6 @@ import ai.starwhale.mlops.domain.blob.BlobService; import ai.starwhale.mlops.domain.bundle.tag.BundleVersionTagDao; import ai.starwhale.mlops.domain.ft.FineTuneDomainService; -import ai.starwhale.mlops.domain.job.BizType; import ai.starwhale.mlops.domain.job.ModelServingService; import ai.starwhale.mlops.domain.job.cache.HotJobHolder; import ai.starwhale.mlops.domain.job.spec.JobSpecParser; @@ -805,7 +804,7 @@ public void testShareModelVersion() { @Test public void testListModelVersionView() { - var res = modelService.listModelVersionView("1", true, true, null, null); + var res = modelService.listModelVersionView("1", true, true); assertEquals(2, res.size()); assertThat(res.get(1), allOf(hasProperty("projectName", is("starwhale")), hasProperty("modelName", is("m")))); @@ -831,16 +830,22 @@ public void testListModelVersionView() { hasProperty("alias", is("v4")), hasProperty("latest", is(true))))); - res = modelService.listModelVersionView("1", false, true, null, null); + res = modelService.listModelVersionView("1", false, true); assertEquals(2, res.size()); - res = modelService.listModelVersionView("1", false, true, BizType.FINE_TUNE, 1L); + res = modelService.listModelVersionView("1", false, true); assertEquals(2, res.size()); - res = modelService.listRecentlyModelVersionView("1", 5, null, null); + res = modelService.listFtSpaceModelVersionView("1", 1L); assertEquals(0, res.size()); - res = modelService.listRecentlyModelVersionView("1", 5, BizType.FINE_TUNE, 1L); + res = modelService.listFtSpaceModelVersionView("1", 1L); + assertEquals(0, res.size()); + + res = modelService.listRecentlyModelVersionView("1", 5); + assertEquals(0, res.size()); + + res = modelService.listRecentlyModelVersionView("1", 1L, 5); assertEquals(0, res.size()); } From 142ade0e76d610bab34ba6a7450cce992bda8eb6 Mon Sep 17 00:00:00 2001 From: gaoxinxing <15031259256@163.com> Date: Wed, 22 Nov 2023 17:51:17 +0800 Subject: [PATCH 3/3] gen model --- client/starwhale/base/client/models/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/client/starwhale/base/client/models/models.py b/client/starwhale/base/client/models/models.py index 77c7563b1f..11a793bea2 100644 --- a/client/starwhale/base/client/models/models.py +++ b/client/starwhale/base/client/models/models.py @@ -1575,6 +1575,7 @@ class ModelVersionViewVo(SwBaseModel): latest: bool tags: Optional[List[str]] = None shared: int + draft: Optional[bool] = None step_specs: List[StepSpec] = Field(..., alias='stepSpecs') built_in_runtime: Optional[str] = Field(None, alias='builtInRuntime') created_time: int = Field(..., alias='createdTime')