From 046d6a09703c1fc105535f5bb96bb35d183295ac Mon Sep 17 00:00:00 2001 From: gaoxinxing <15031259256@163.com> Date: Wed, 22 Nov 2023 14:31:11 +0800 Subject: [PATCH] feat --- .../starwhale/mlops/api/ModelController.java | 17 ++- .../protocol/model/ModelVersionViewVo.java | 2 + .../mlops/domain/model/ModelService.java | 13 +- .../model/mapper/ModelVersionMapper.java | 120 +++++++++++++----- .../model/po/ModelVersionViewEntity.java | 2 + .../mlops/api/ModelControllerTest.java | 12 +- .../mlops/domain/model/ModelServiceTest.java | 15 ++- 7 files changed, 134 insertions(+), 47 deletions(-) diff --git a/server/controller/src/main/java/ai/starwhale/mlops/api/ModelController.java b/server/controller/src/main/java/ai/starwhale/mlops/api/ModelController.java index 29986bffaf..7e0687f6f3 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/api/ModelController.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/api/ModelController.java @@ -36,6 +36,7 @@ import ai.starwhale.mlops.api.protocol.storage.FileNode; import ai.starwhale.mlops.common.IdConverter; import ai.starwhale.mlops.common.PageParams; +import ai.starwhale.mlops.domain.job.BizType; import ai.starwhale.mlops.domain.model.ModelService; import ai.starwhale.mlops.domain.model.bo.ModelQuery; import ai.starwhale.mlops.domain.model.bo.ModelVersion; @@ -246,18 +247,20 @@ ResponseEntity>> listModelTree( @Parameter(in = ParameterIn.PATH, required = true, description = "Project url", schema = @Schema()) @PathVariable String projectUrl, @Parameter(in = ParameterIn.QUERY, description = "Data range", schema = @Schema()) - @RequestParam(required = false, defaultValue = "all") DataScope scope + @RequestParam(required = false, defaultValue = "all") DataScope scope, + @RequestParam(required = false) BizType bizType, + @RequestParam(required = false) Long bizId ) { List list; switch (scope) { case all: - list = modelService.listModelVersionView(projectUrl, true, true); + list = modelService.listModelVersionView(projectUrl, true, true, bizType, bizId); break; case shared: - list = modelService.listModelVersionView(projectUrl, true, false); + list = modelService.listModelVersionView(projectUrl, true, false, bizType, bizId); break; case project: - list = modelService.listModelVersionView(projectUrl, false, true); + list = modelService.listModelVersionView(projectUrl, false, true, bizType, bizId); break; default: list = List.of(); @@ -274,10 +277,12 @@ ResponseEntity>> recentModelTree( @Valid @Min(value = 1, message = "limit must be greater than or equal to 1") @Max(value = 50, message = "limit must be less than or equal to 50") - Integer limit + Integer limit, + @RequestParam(required = false) BizType bizType, + @RequestParam(required = false) Long bizId ) { return ResponseEntity.ok(Code.success.asResponse( - modelService.listRecentlyModelVersionView(projectUrl, limit) + modelService.listRecentlyModelVersionView(projectUrl, limit, bizType, bizId) )); } diff --git a/server/controller/src/main/java/ai/starwhale/mlops/api/protocol/model/ModelVersionViewVo.java b/server/controller/src/main/java/ai/starwhale/mlops/api/protocol/model/ModelVersionViewVo.java index c3edf8d972..0f9b695ffb 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/api/protocol/model/ModelVersionViewVo.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/api/protocol/model/ModelVersionViewVo.java @@ -47,6 +47,8 @@ public class ModelVersionViewVo { @NotNull private Integer shared; + private Boolean draft; + @NotNull private List stepSpecs; diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/model/ModelService.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/model/ModelService.java index 30152c77ad..afc8e52b03 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/model/ModelService.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/model/ModelService.java @@ -44,6 +44,7 @@ import ai.starwhale.mlops.domain.bundle.tag.BundleVersionTagDao; import ai.starwhale.mlops.domain.bundle.tag.po.BundleVersionTagEntity; import ai.starwhale.mlops.domain.ft.FineTuneDomainService; +import ai.starwhale.mlops.domain.job.BizType; import ai.starwhale.mlops.domain.job.cache.HotJobHolder; import ai.starwhale.mlops.domain.job.spec.JobSpecParser; import ai.starwhale.mlops.domain.job.status.JobStatus; @@ -381,11 +382,12 @@ public void shareModelVersion(String projectUrl, String modelUrl, String version } public List listModelVersionView( - String projectUrl, boolean includeShared, boolean includeCurrentProject) { + String projectUrl, boolean includeShared, boolean includeCurrentProject, BizType bizType, Long bizId + ) { var project = projectService.findProject(projectUrl); var list = new ArrayList(); if (includeCurrentProject) { - var versions = modelVersionMapper.listModelVersionViewByProject(project.getId()); + var versions = modelVersionMapper.listModelVersionViewByProject(project.getId(), bizType, bizId); list.addAll(viewEntityToVo(versions, project)); } if (includeShared) { @@ -395,10 +397,12 @@ public List listModelVersionView( return list; } - public List listRecentlyModelVersionView(String projectUrl, Integer limit) { + public List listRecentlyModelVersionView( + String projectUrl, Integer limit, BizType bizType, Long bizId) { var project = projectService.findProject(projectUrl); var userId = userService.currentUserDetail().getId(); - var list = modelVersionMapper.listModelVersionsByUserRecentlyUsed(project.getId(), userId, limit); + var list = modelVersionMapper.listModelVersionsByUserRecentlyUsed( + project.getId(), userId, limit, bizType, bizId); return viewEntityToVo(list, project); } @@ -446,6 +450,7 @@ private List viewEntityToVo(List list, Proj .latest(entity.getId() != null && entity.getId().equals(latest)) .createdTime(entity.getCreatedTime().getTime()) .shared(toInt(entity.getShared())) + .draft(entity.getDraft()) .builtInRuntime(entity.getBuiltInRuntime()) .stepSpecs(jobSpecParser.parseAndFlattenStepFromYaml(entity.getJobs())) .build()); diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/model/mapper/ModelVersionMapper.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/model/mapper/ModelVersionMapper.java index 8739134012..acc8f61c90 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/model/mapper/ModelVersionMapper.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/model/mapper/ModelVersionMapper.java @@ -16,11 +16,13 @@ package ai.starwhale.mlops.domain.model.mapper; +import ai.starwhale.mlops.domain.job.BizType; import ai.starwhale.mlops.domain.model.po.ModelVersionEntity; import ai.starwhale.mlops.domain.model.po.ModelVersionViewEntity; import cn.hutool.core.util.StrUtil; import java.util.List; import java.util.Objects; +import org.apache.hadoop.yarn.webapp.hamlet.Hamlet.SELECT; import org.apache.ibatis.annotations.Insert; import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Options; @@ -65,37 +67,20 @@ List list( @Update("update model_version set model_id = #{modelId}, draft = 0 where id = #{id}") int updateModelRef(@Param("id") Long id, @Param("modelId") Long modelId); - @Select("select " + VERSION_VIEW_COLUMNS - + " from model_info as m, model_version as v, project_info as p, user_info as u" - + " where v.model_id = m.id" - + " and m.project_id = p.id" - + " and m.owner_id = u.id" - + " and m.deleted_time = 0" - + " and p.is_deleted = 0" - + " and p.id = #{projectId}" - + " order by m.id desc, v.version_order desc") - List listModelVersionViewByProject(@Param("projectId") Long projectId); - - @Select({"" - }) + @SelectProvider(type = ModelVersionProvider.class, method = "listInProject") + List listModelVersionViewByProject( + @Param("projectId") Long projectId, + @Param("bizType") BizType bizType, + @Param("bizId") Long bizId + ); + + @SelectProvider(type = ModelVersionProvider.class, method = "listUserRecentlyUsed") List listModelVersionsByUserRecentlyUsed( - @Param("projectId") Long projectId, @Param("userId") Long userId, @Param("limit") Integer limit + @Param("projectId") Long projectId, + @Param("userId") Long userId, + @Param("limit") Integer limit, + @Param("bizType") BizType bizType, + @Param("bizId") Long bizId ); @Select("select " + VERSION_VIEW_COLUMNS @@ -155,6 +140,81 @@ ModelVersionEntity findByVersionOrder( class ModelVersionProvider { + public String listInProject( + @Param("projectId") Long projectId, + @Param("bizType") BizType bizType, + @Param("bizId") Long bizId + ) { + return new SQL() { + { + SELECT(VERSION_VIEW_COLUMNS); + FROM("model_version as v"); + INNER_JOIN("model_info as m on m.id = v.model_id"); + INNER_JOIN("project_info as p on p.id = m.project_id"); + INNER_JOIN("user_info as u on u.id = m.owner_id"); + if (bizType == BizType.FINE_TUNE && bizId != null) { + // non-draft models in current project or draft models in current space + WHERE("(v.draft=0 or v.id in " + + "(select target_model_version_id from fine_tune where space_id = #{bizId})" + + ")" + ); + } else { + // (non-draft) models in current project + WHERE("v.draft=0"); + } + WHERE("m.deleted_time = 0"); + WHERE("p.id = #{projectId}"); + ORDER_BY("m.id desc"); + ORDER_BY("v.version_order desc"); + } + + }.toString(); + } + + public String listUserRecentlyUsed( + @Param("projectId") Long projectId, + @Param("userId") Long userId, + @Param("limit") Integer limit, + @Param("bizType") BizType bizType, + @Param("bizId") Long bizId + ) { + return new SQL() { + { + SELECT(VERSION_VIEW_COLUMNS + ", MAX(j.id) as job_id"); + FROM("model_version as v"); + INNER_JOIN("model_info as m on m.id = v.model_id"); + INNER_JOIN("job_info as j on j.model_version_id = v.id"); + INNER_JOIN("project_info as p on p.id = m.project_id"); + INNER_JOIN("user_info as u on u.id = m.owner_id"); + if (bizType == BizType.FINE_TUNE && bizId != null) { + // models in current project(non-draft or draft in current space) + // or (non-draft) other project but is shared + WHERE("(" + + "(m.project_id = #{projectId} and " + + " (v.draft=0 or v.id in " + + " (select target_model_version_id from fine_tune where space_id = #{bizId})" + + " )" + + ") " + + "or " + + "(m.project_id != #{projectId} and v.shared = 1 and v.draft = 0)" + + ")"); + } else { + // (non-draft) models in current project or other project but is shared + WHERE("(" + + "(m.project_id = #{projectId} and v.draft=0) or " + + "(m.project_id != #{projectId} and v.shared = 1 and v.draft = 0)" + + ")"); + } + WHERE("m.deleted_time = 0"); + WHERE("j.owner_id = #{userId}"); // jobs in current user + WHERE("j.project_id = #{projectId}"); // jobs in current project + GROUP_BY("v.id"); + ORDER_BY("job_id desc"); + LIMIT("#{limit}"); // recently + } + }.toString(); + } + public String listSql( @Param("modelId") Long modelId, @Param("namePrefix") String namePrefix, @Param("draft") Boolean draft diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/model/po/ModelVersionViewEntity.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/model/po/ModelVersionViewEntity.java index 43150c4b12..900e9359c3 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/model/po/ModelVersionViewEntity.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/model/po/ModelVersionViewEntity.java @@ -49,6 +49,8 @@ public class ModelVersionViewEntity extends BaseEntity implements HasId { private Boolean shared; + private Boolean draft; + private String storagePath; private String builtInRuntime; diff --git a/server/controller/src/test/java/ai/starwhale/mlops/api/ModelControllerTest.java b/server/controller/src/test/java/ai/starwhale/mlops/api/ModelControllerTest.java index bc79412286..d2419db3f7 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/api/ModelControllerTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/api/ModelControllerTest.java @@ -190,17 +190,17 @@ public void testHeadModel() { @ParameterizedTest @CsvSource({"all, 2", "project, 1", "shared, 1"}) public void testListModelTree(DataScope scope, int listCount) { - given(modelService.listModelVersionView(anyString(), eq(true), eq(true))) + given(modelService.listModelVersionView(anyString(), eq(true), eq(true), any(), any())) .willReturn(List.of( ModelViewVo.builder().projectName("p").build(), ModelViewVo.builder().projectName("p").build() )); - given(modelService.listModelVersionView(anyString(), eq(false), eq(true))) + given(modelService.listModelVersionView(anyString(), eq(false), eq(true), any(), any())) .willReturn(List.of(ModelViewVo.builder().projectName("p").build())); - given(modelService.listModelVersionView(anyString(), eq(true), eq(false))) + given(modelService.listModelVersionView(anyString(), eq(true), eq(false), any(), any())) .willReturn(List.of(ModelViewVo.builder().projectName("p").build())); - var resp = controller.listModelTree("1", scope); + var resp = controller.listModelTree("1", scope, null, null); assertThat(resp.getStatusCode(), is(HttpStatus.OK)); assertThat(resp.getBody(), notNullValue()); assertThat(resp.getBody().getData(), allOf( @@ -211,13 +211,13 @@ public void testListModelTree(DataScope scope, int listCount) { @Test public void testRecentListModelTree() { - given(modelService.listRecentlyModelVersionView(anyString(), eq(5))) + given(modelService.listRecentlyModelVersionView(anyString(), eq(5), any(), any())) .willReturn(List.of( ModelViewVo.builder().projectName("p").build(), ModelViewVo.builder().projectName("p").build() )); - var resp = controller.recentModelTree("1", 5); + var resp = controller.recentModelTree("1", 5, null, null); assertThat(resp.getStatusCode(), is(HttpStatus.OK)); assertThat(resp.getBody(), notNullValue()); assertThat(resp.getBody().getData(), allOf( diff --git a/server/controller/src/test/java/ai/starwhale/mlops/domain/model/ModelServiceTest.java b/server/controller/src/test/java/ai/starwhale/mlops/domain/model/ModelServiceTest.java index 040fa975d2..1dd75646c6 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/domain/model/ModelServiceTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/domain/model/ModelServiceTest.java @@ -53,6 +53,7 @@ import ai.starwhale.mlops.domain.blob.BlobService; import ai.starwhale.mlops.domain.bundle.tag.BundleVersionTagDao; import ai.starwhale.mlops.domain.ft.FineTuneDomainService; +import ai.starwhale.mlops.domain.job.BizType; import ai.starwhale.mlops.domain.job.ModelServingService; import ai.starwhale.mlops.domain.job.cache.HotJobHolder; import ai.starwhale.mlops.domain.job.spec.JobSpecParser; @@ -804,7 +805,7 @@ public void testShareModelVersion() { @Test public void testListModelVersionView() { - var res = modelService.listModelVersionView("1", true, true); + var res = modelService.listModelVersionView("1", true, true, null, null); assertEquals(2, res.size()); assertThat(res.get(1), allOf(hasProperty("projectName", is("starwhale")), hasProperty("modelName", is("m")))); @@ -829,6 +830,18 @@ public void testListModelVersionView() { allOf(hasProperty("versionName", is("v4")), hasProperty("alias", is("v4")), hasProperty("latest", is(true))))); + + res = modelService.listModelVersionView("1", false, true, null, null); + assertEquals(2, res.size()); + + res = modelService.listModelVersionView("1", false, true, BizType.FINE_TUNE, 1L); + assertEquals(2, res.size()); + + res = modelService.listRecentlyModelVersionView("1", 5, null, null); + assertEquals(0, res.size()); + + res = modelService.listRecentlyModelVersionView("1", 5, BizType.FINE_TUNE, 1L); + assertEquals(0, res.size()); } @Test