Skip to content

Commit

Permalink
feat
Browse files Browse the repository at this point in the history
  • Loading branch information
goldenxinxing committed Nov 22, 2023
1 parent 34d2018 commit 046d6a0
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import ai.starwhale.mlops.api.protocol.storage.FileNode;
import ai.starwhale.mlops.common.IdConverter;
import ai.starwhale.mlops.common.PageParams;
import ai.starwhale.mlops.domain.job.BizType;
import ai.starwhale.mlops.domain.model.ModelService;
import ai.starwhale.mlops.domain.model.bo.ModelQuery;
import ai.starwhale.mlops.domain.model.bo.ModelVersion;
Expand Down Expand Up @@ -246,18 +247,20 @@ ResponseEntity<ResponseMessage<List<ModelViewVo>>> listModelTree(
@Parameter(in = ParameterIn.PATH, required = true, description = "Project url", schema = @Schema())
@PathVariable String projectUrl,
@Parameter(in = ParameterIn.QUERY, description = "Data range", schema = @Schema())
@RequestParam(required = false, defaultValue = "all") DataScope scope
@RequestParam(required = false, defaultValue = "all") DataScope scope,
@RequestParam(required = false) BizType bizType,
@RequestParam(required = false) Long bizId
) {
List<ModelViewVo> list;
switch (scope) {
case all:
list = modelService.listModelVersionView(projectUrl, true, true);
list = modelService.listModelVersionView(projectUrl, true, true, bizType, bizId);
break;
case shared:
list = modelService.listModelVersionView(projectUrl, true, false);
list = modelService.listModelVersionView(projectUrl, true, false, bizType, bizId);
break;
case project:
list = modelService.listModelVersionView(projectUrl, false, true);
list = modelService.listModelVersionView(projectUrl, false, true, bizType, bizId);
break;
default:
list = List.of();
Expand All @@ -274,10 +277,12 @@ ResponseEntity<ResponseMessage<List<ModelViewVo>>> recentModelTree(
@Valid
@Min(value = 1, message = "limit must be greater than or equal to 1")
@Max(value = 50, message = "limit must be less than or equal to 50")
Integer limit
Integer limit,
@RequestParam(required = false) BizType bizType,
@RequestParam(required = false) Long bizId
) {
return ResponseEntity.ok(Code.success.asResponse(
modelService.listRecentlyModelVersionView(projectUrl, limit)
modelService.listRecentlyModelVersionView(projectUrl, limit, bizType, bizId)
));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ public class ModelVersionViewVo {
@NotNull
private Integer shared;

private Boolean draft;

@NotNull
private List<StepSpec> stepSpecs;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import ai.starwhale.mlops.domain.bundle.tag.BundleVersionTagDao;
import ai.starwhale.mlops.domain.bundle.tag.po.BundleVersionTagEntity;
import ai.starwhale.mlops.domain.ft.FineTuneDomainService;
import ai.starwhale.mlops.domain.job.BizType;
import ai.starwhale.mlops.domain.job.cache.HotJobHolder;
import ai.starwhale.mlops.domain.job.spec.JobSpecParser;
import ai.starwhale.mlops.domain.job.status.JobStatus;
Expand Down Expand Up @@ -381,11 +382,12 @@ public void shareModelVersion(String projectUrl, String modelUrl, String version
}

public List<ModelViewVo> listModelVersionView(
String projectUrl, boolean includeShared, boolean includeCurrentProject) {
String projectUrl, boolean includeShared, boolean includeCurrentProject, BizType bizType, Long bizId
) {
var project = projectService.findProject(projectUrl);
var list = new ArrayList<ModelViewVo>();
if (includeCurrentProject) {
var versions = modelVersionMapper.listModelVersionViewByProject(project.getId());
var versions = modelVersionMapper.listModelVersionViewByProject(project.getId(), bizType, bizId);
list.addAll(viewEntityToVo(versions, project));
}
if (includeShared) {
Expand All @@ -395,10 +397,12 @@ public List<ModelViewVo> listModelVersionView(
return list;
}

public List<ModelViewVo> listRecentlyModelVersionView(String projectUrl, Integer limit) {
public List<ModelViewVo> listRecentlyModelVersionView(
String projectUrl, Integer limit, BizType bizType, Long bizId) {
var project = projectService.findProject(projectUrl);
var userId = userService.currentUserDetail().getId();
var list = modelVersionMapper.listModelVersionsByUserRecentlyUsed(project.getId(), userId, limit);
var list = modelVersionMapper.listModelVersionsByUserRecentlyUsed(
project.getId(), userId, limit, bizType, bizId);
return viewEntityToVo(list, project);
}

Expand Down Expand Up @@ -446,6 +450,7 @@ private List<ModelViewVo> viewEntityToVo(List<ModelVersionViewEntity> list, Proj
.latest(entity.getId() != null && entity.getId().equals(latest))
.createdTime(entity.getCreatedTime().getTime())
.shared(toInt(entity.getShared()))
.draft(entity.getDraft())
.builtInRuntime(entity.getBuiltInRuntime())
.stepSpecs(jobSpecParser.parseAndFlattenStepFromYaml(entity.getJobs()))
.build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@

package ai.starwhale.mlops.domain.model.mapper;

import ai.starwhale.mlops.domain.job.BizType;
import ai.starwhale.mlops.domain.model.po.ModelVersionEntity;
import ai.starwhale.mlops.domain.model.po.ModelVersionViewEntity;
import cn.hutool.core.util.StrUtil;
import java.util.List;
import java.util.Objects;
import org.apache.hadoop.yarn.webapp.hamlet.Hamlet.SELECT;
import org.apache.ibatis.annotations.Insert;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Options;
Expand Down Expand Up @@ -65,37 +67,20 @@ List<ModelVersionEntity> list(
@Update("update model_version set model_id = #{modelId}, draft = 0 where id = #{id}")
int updateModelRef(@Param("id") Long id, @Param("modelId") Long modelId);

@Select("select " + VERSION_VIEW_COLUMNS
+ " from model_info as m, model_version as v, project_info as p, user_info as u"
+ " where v.model_id = m.id"
+ " and m.project_id = p.id"
+ " and m.owner_id = u.id"
+ " and m.deleted_time = 0"
+ " and p.is_deleted = 0"
+ " and p.id = #{projectId}"
+ " order by m.id desc, v.version_order desc")
List<ModelVersionViewEntity> listModelVersionViewByProject(@Param("projectId") Long projectId);

@Select({"<script>",
"select " + VERSION_VIEW_COLUMNS + ", MAX(j.id) as job_id",
"from model_version as v",
"inner join model_info as m on m.id = v.model_id",
"inner join job_info as j on j.model_version_id = v.id",
"inner join project_info as p on p.id = m.project_id",
"inner join user_info as u on u.id = m.owner_id",
"where",
// models in current project or other project but is shared
" (m.project_id = #{projectId} or (m.project_id != #{projectId} and v.shared = 1 and v.draft = 0))",
" and m.deleted_time = 0",
" and j.owner_id = #{userId}", // jobs in current user
" and j.project_id = #{projectId}", // jobs in current project
"group by v.id",
"order by job_id desc",
"limit #{limit}", // recently
"</script>"
})
@SelectProvider(type = ModelVersionProvider.class, method = "listInProject")
List<ModelVersionViewEntity> listModelVersionViewByProject(
@Param("projectId") Long projectId,
@Param("bizType") BizType bizType,
@Param("bizId") Long bizId
);

@SelectProvider(type = ModelVersionProvider.class, method = "listUserRecentlyUsed")
List<ModelVersionViewEntity> listModelVersionsByUserRecentlyUsed(
@Param("projectId") Long projectId, @Param("userId") Long userId, @Param("limit") Integer limit
@Param("projectId") Long projectId,
@Param("userId") Long userId,
@Param("limit") Integer limit,
@Param("bizType") BizType bizType,
@Param("bizId") Long bizId
);

@Select("select " + VERSION_VIEW_COLUMNS
Expand Down Expand Up @@ -155,6 +140,81 @@ ModelVersionEntity findByVersionOrder(

class ModelVersionProvider {

public String listInProject(
@Param("projectId") Long projectId,
@Param("bizType") BizType bizType,
@Param("bizId") Long bizId
) {
return new SQL() {
{
SELECT(VERSION_VIEW_COLUMNS);
FROM("model_version as v");
INNER_JOIN("model_info as m on m.id = v.model_id");
INNER_JOIN("project_info as p on p.id = m.project_id");
INNER_JOIN("user_info as u on u.id = m.owner_id");
if (bizType == BizType.FINE_TUNE && bizId != null) {
// non-draft models in current project or draft models in current space
WHERE("(v.draft=0 or v.id in "
+ "(select target_model_version_id from fine_tune where space_id = #{bizId})"
+ ")"
);
} else {
// (non-draft) models in current project
WHERE("v.draft=0");
}
WHERE("m.deleted_time = 0");
WHERE("p.id = #{projectId}");
ORDER_BY("m.id desc");
ORDER_BY("v.version_order desc");
}

}.toString();
}

public String listUserRecentlyUsed(
@Param("projectId") Long projectId,
@Param("userId") Long userId,
@Param("limit") Integer limit,
@Param("bizType") BizType bizType,
@Param("bizId") Long bizId
) {
return new SQL() {
{
SELECT(VERSION_VIEW_COLUMNS + ", MAX(j.id) as job_id");
FROM("model_version as v");
INNER_JOIN("model_info as m on m.id = v.model_id");
INNER_JOIN("job_info as j on j.model_version_id = v.id");
INNER_JOIN("project_info as p on p.id = m.project_id");
INNER_JOIN("user_info as u on u.id = m.owner_id");
if (bizType == BizType.FINE_TUNE && bizId != null) {
// models in current project(non-draft or draft in current space)
// or (non-draft) other project but is shared
WHERE("("
+ "(m.project_id = #{projectId} and "
+ " (v.draft=0 or v.id in "
+ " (select target_model_version_id from fine_tune where space_id = #{bizId})"
+ " )"
+ ") "
+ "or "
+ "(m.project_id != #{projectId} and v.shared = 1 and v.draft = 0)"
+ ")");
} else {
// (non-draft) models in current project or other project but is shared
WHERE("("
+ "(m.project_id = #{projectId} and v.draft=0) or "
+ "(m.project_id != #{projectId} and v.shared = 1 and v.draft = 0)"
+ ")");
}
WHERE("m.deleted_time = 0");
WHERE("j.owner_id = #{userId}"); // jobs in current user
WHERE("j.project_id = #{projectId}"); // jobs in current project
GROUP_BY("v.id");
ORDER_BY("job_id desc");
LIMIT("#{limit}"); // recently
}
}.toString();
}

public String listSql(
@Param("modelId") Long modelId,
@Param("namePrefix") String namePrefix, @Param("draft") Boolean draft
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ public class ModelVersionViewEntity extends BaseEntity implements HasId {

private Boolean shared;

private Boolean draft;

private String storagePath;

private String builtInRuntime;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,17 +190,17 @@ public void testHeadModel() {
@ParameterizedTest
@CsvSource({"all, 2", "project, 1", "shared, 1"})
public void testListModelTree(DataScope scope, int listCount) {
given(modelService.listModelVersionView(anyString(), eq(true), eq(true)))
given(modelService.listModelVersionView(anyString(), eq(true), eq(true), any(), any()))
.willReturn(List.of(
ModelViewVo.builder().projectName("p").build(),
ModelViewVo.builder().projectName("p").build()
));
given(modelService.listModelVersionView(anyString(), eq(false), eq(true)))
given(modelService.listModelVersionView(anyString(), eq(false), eq(true), any(), any()))
.willReturn(List.of(ModelViewVo.builder().projectName("p").build()));
given(modelService.listModelVersionView(anyString(), eq(true), eq(false)))
given(modelService.listModelVersionView(anyString(), eq(true), eq(false), any(), any()))
.willReturn(List.of(ModelViewVo.builder().projectName("p").build()));

var resp = controller.listModelTree("1", scope);
var resp = controller.listModelTree("1", scope, null, null);
assertThat(resp.getStatusCode(), is(HttpStatus.OK));
assertThat(resp.getBody(), notNullValue());
assertThat(resp.getBody().getData(), allOf(
Expand All @@ -211,13 +211,13 @@ public void testListModelTree(DataScope scope, int listCount) {

@Test
public void testRecentListModelTree() {
given(modelService.listRecentlyModelVersionView(anyString(), eq(5)))
given(modelService.listRecentlyModelVersionView(anyString(), eq(5), any(), any()))
.willReturn(List.of(
ModelViewVo.builder().projectName("p").build(),
ModelViewVo.builder().projectName("p").build()
));

var resp = controller.recentModelTree("1", 5);
var resp = controller.recentModelTree("1", 5, null, null);
assertThat(resp.getStatusCode(), is(HttpStatus.OK));
assertThat(resp.getBody(), notNullValue());
assertThat(resp.getBody().getData(), allOf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import ai.starwhale.mlops.domain.blob.BlobService;
import ai.starwhale.mlops.domain.bundle.tag.BundleVersionTagDao;
import ai.starwhale.mlops.domain.ft.FineTuneDomainService;
import ai.starwhale.mlops.domain.job.BizType;
import ai.starwhale.mlops.domain.job.ModelServingService;
import ai.starwhale.mlops.domain.job.cache.HotJobHolder;
import ai.starwhale.mlops.domain.job.spec.JobSpecParser;
Expand Down Expand Up @@ -804,7 +805,7 @@ public void testShareModelVersion() {

@Test
public void testListModelVersionView() {
var res = modelService.listModelVersionView("1", true, true);
var res = modelService.listModelVersionView("1", true, true, null, null);
assertEquals(2, res.size());
assertThat(res.get(1), allOf(hasProperty("projectName", is("starwhale")),
hasProperty("modelName", is("m"))));
Expand All @@ -829,6 +830,18 @@ public void testListModelVersionView() {
allOf(hasProperty("versionName", is("v4")),
hasProperty("alias", is("v4")),
hasProperty("latest", is(true)))));

res = modelService.listModelVersionView("1", false, true, null, null);
assertEquals(2, res.size());

res = modelService.listModelVersionView("1", false, true, BizType.FINE_TUNE, 1L);
assertEquals(2, res.size());

res = modelService.listRecentlyModelVersionView("1", 5, null, null);
assertEquals(0, res.size());

res = modelService.listRecentlyModelVersionView("1", 5, BizType.FINE_TUNE, 1L);
assertEquals(0, res.size());
}

@Test
Expand Down

0 comments on commit 046d6a0

Please sign in to comment.