Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(controller): add more filter for model tree #3013

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions client/starwhale/base/client/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1575,6 +1575,7 @@ class ModelVersionViewVo(SwBaseModel):
latest: bool
tags: Optional[List[str]] = None
shared: int
draft: Optional[bool] = None
step_specs: List[StepSpec] = Field(..., alias='stepSpecs')
built_in_runtime: Optional[str] = Field(None, alias='builtInRuntime')
created_time: int = Field(..., alias='createdTime')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import ai.starwhale.mlops.api.protocol.ResponseMessage;
import ai.starwhale.mlops.api.protocol.ft.FineTuneSpaceCreateRequest;
import ai.starwhale.mlops.api.protocol.ft.FineTuneSpaceVo;
import ai.starwhale.mlops.api.protocol.model.ModelViewVo;
import ai.starwhale.mlops.common.IdConverter;
import ai.starwhale.mlops.configuration.FeaturesProperties;
import ai.starwhale.mlops.domain.ft.FineTuneAppService;
Expand All @@ -30,8 +31,14 @@
import ai.starwhale.mlops.domain.user.UserService;
import com.github.pagehelper.PageInfo;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.enums.ParameterIn;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.util.List;
import javax.validation.Valid;
import javax.validation.constraints.Max;
import javax.validation.constraints.Min;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
Expand Down Expand Up @@ -195,4 +202,38 @@ public ResponseEntity<ResponseMessage<String>> exportEval(
return ResponseEntity.ok(Code.success.asResponse(""));
}

@GetMapping(
value = "/project/{projectId}/ftspace/{spaceId}/model-tree",
produces = MediaType.APPLICATION_JSON_VALUE
)
@PreAuthorize("hasAnyRole('OWNER', 'MAINTAINER', 'GUEST')")
ResponseEntity<ResponseMessage<List<ModelViewVo>>> listModelTree(
@PathVariable("projectId") Long projectId,
@PathVariable("spaceId") Long spaceId
) {
return ResponseEntity.ok(Code.success.asResponse(
fineTuneAppService.listModelVersionView(projectId, spaceId)
));
}

@GetMapping(
value = "/project/{projectId}/ftspace/{spaceId}/recent-model-tree",
produces = MediaType.APPLICATION_JSON_VALUE
)
@PreAuthorize("hasAnyRole('OWNER', 'MAINTAINER', 'GUEST')")
ResponseEntity<ResponseMessage<List<ModelViewVo>>> recentModelTree(
@PathVariable("projectId") Long projectId,
@PathVariable("spaceId") Long spaceId,
@Parameter(in = ParameterIn.QUERY, description = "Data limit", schema = @Schema())
@RequestParam(required = false, defaultValue = "5")
@Valid
@Min(value = 1, message = "limit must be greater than or equal to 1")
@Max(value = 50, message = "limit must be less than or equal to 50")
Integer limit
) {
return ResponseEntity.ok(Code.success.asResponse(
fineTuneAppService.listRecentlyModelVersionView(projectId, spaceId, limit)
));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ public class ModelVersionViewVo {
@NotNull
private Integer shared;

private Boolean draft;

@NotNull
private List<StepSpec> stepSpecs;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static ai.starwhale.mlops.domain.evaluation.EvaluationService.TABLE_NAME_FORMAT;

import ai.starwhale.mlops.api.protocol.job.JobRequest;
import ai.starwhale.mlops.api.protocol.model.ModelViewVo;
import ai.starwhale.mlops.api.protocol.model.ModelVo;
import ai.starwhale.mlops.common.Constants;
import ai.starwhale.mlops.common.IdConverter;
Expand Down Expand Up @@ -372,6 +373,14 @@ public void releaseFt(
}
}

public List<ModelViewVo> listModelVersionView(Long projectId, Long spaceId) {
return modelService.listFtSpaceModelVersionView(String.valueOf(projectId), spaceId);
}

public List<ModelViewVo> listRecentlyModelVersionView(Long projectId, Long spaceId, Integer limit) {
return modelService.listRecentlyModelVersionView(String.valueOf(projectId), spaceId, limit);
}

private void checkFeatureEnabled() throws StarwhaleApiException {
if (!this.featuresProperties.isFineTuneEnabled()) {
throw new StarwhaleApiException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,8 @@ public void shareModelVersion(String projectUrl, String modelUrl, String version
}

public List<ModelViewVo> listModelVersionView(
String projectUrl, boolean includeShared, boolean includeCurrentProject) {
String projectUrl, boolean includeShared, boolean includeCurrentProject
) {
var project = projectService.findProject(projectUrl);
var list = new ArrayList<ModelViewVo>();
if (includeCurrentProject) {
Expand All @@ -402,6 +403,20 @@ public List<ModelViewVo> listRecentlyModelVersionView(String projectUrl, Integer
return viewEntityToVo(list, project);
}

public List<ModelViewVo> listRecentlyModelVersionView(String projectUrl, Long spaceId, Integer limit) {
var project = projectService.findProject(projectUrl);
var userId = userService.currentUserDetail().getId();
var list = modelVersionMapper.listModelVersionsByUserRecentlyUsedInFtSpace(
project.getId(), userId, spaceId, limit);
return viewEntityToVo(list, project);
}

public List<ModelViewVo> listFtSpaceModelVersionView(String projectUrl, Long spaceId) {
var project = projectService.findProject(projectUrl);
var list = modelVersionMapper.listModelVersionViewByFtSpace(project.getId(), spaceId);
return viewEntityToVo(list, project);
}

private List<ModelViewVo> viewEntityToVo(List<ModelVersionViewEntity> list, Project currentProject) {
Map<Long, ModelViewVo> map = new LinkedHashMap<>();
var tags = new HashMap<Long, Map<Long, List<String>>>();
Expand Down Expand Up @@ -446,6 +461,7 @@ private List<ModelViewVo> viewEntityToVo(List<ModelVersionViewEntity> list, Proj
.latest(entity.getId() != null && entity.getId().equals(latest))
.createdTime(entity.getCreatedTime().getTime())
.shared(toInt(entity.getShared()))
.draft(entity.getDraft())
.builtInRuntime(entity.getBuiltInRuntime())
.stepSpecs(jobSpecParser.parseAndFlattenStepFromYaml(entity.getJobs()))
.build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,37 +65,28 @@ List<ModelVersionEntity> list(
@Update("update model_version set model_id = #{modelId}, draft = 0 where id = #{id}")
int updateModelRef(@Param("id") Long id, @Param("modelId") Long modelId);

@Select("select " + VERSION_VIEW_COLUMNS
+ " from model_info as m, model_version as v, project_info as p, user_info as u"
+ " where v.model_id = m.id"
+ " and m.project_id = p.id"
+ " and m.owner_id = u.id"
+ " and m.deleted_time = 0"
+ " and p.is_deleted = 0"
+ " and p.id = #{projectId}"
+ " order by m.id desc, v.version_order desc")
@SelectProvider(type = ModelVersionProvider.class, method = "listInProject")
List<ModelVersionViewEntity> listModelVersionViewByProject(@Param("projectId") Long projectId);

@Select({"<script>",
"select " + VERSION_VIEW_COLUMNS + ", MAX(j.id) as job_id",
"from model_version as v",
"inner join model_info as m on m.id = v.model_id",
"inner join job_info as j on j.model_version_id = v.id",
"inner join project_info as p on p.id = m.project_id",
"inner join user_info as u on u.id = m.owner_id",
"where",
// models in current project or other project but is shared
" (m.project_id = #{projectId} or (m.project_id != #{projectId} and v.shared = 1 and v.draft = 0))",
" and m.deleted_time = 0",
" and j.owner_id = #{userId}", // jobs in current user
" and j.project_id = #{projectId}", // jobs in current project
"group by v.id",
"order by job_id desc",
"limit #{limit}", // recently
"</script>"
})
@SelectProvider(type = ModelVersionProvider.class, method = "listUserRecentlyUsed")
List<ModelVersionViewEntity> listModelVersionsByUserRecentlyUsed(
@Param("projectId") Long projectId, @Param("userId") Long userId, @Param("limit") Integer limit
@Param("projectId") Long projectId,
@Param("userId") Long userId,
@Param("limit") Integer limit
);

@SelectProvider(type = ModelVersionProvider.class, method = "listInFtSpace")
List<ModelVersionViewEntity> listModelVersionViewByFtSpace(
@Param("projectId") Long projectId,
@Param("spaceId") Long spaceId
);

@SelectProvider(type = ModelVersionProvider.class, method = "listUserRecentlyUsedWithFtSpace")
List<ModelVersionViewEntity> listModelVersionsByUserRecentlyUsedInFtSpace(
@Param("projectId") Long projectId,
@Param("userId") Long userId,
@Param("spaceId") Long spaceId,
@Param("limit") Integer limit
);

@Select("select " + VERSION_VIEW_COLUMNS
Expand Down Expand Up @@ -155,6 +146,107 @@ ModelVersionEntity findByVersionOrder(

class ModelVersionProvider {

public String listInProject(@Param("projectId") Long projectId) {
return new SQL() {
{
SELECT(VERSION_VIEW_COLUMNS);
FROM("model_version as v");
INNER_JOIN("model_info as m on m.id = v.model_id");
INNER_JOIN("project_info as p on p.id = m.project_id");
INNER_JOIN("user_info as u on u.id = m.owner_id");
// (non-draft) models in current project
WHERE("v.draft=0");
WHERE("m.deleted_time = 0");
WHERE("p.id = #{projectId}");
ORDER_BY("m.id desc");
ORDER_BY("v.version_order desc");
}

}.toString();
}

public String listInFtSpace(
@Param("projectId") Long projectId,
@Param("spaceId") Long spaceId) {
return new SQL() {
{
SELECT(VERSION_VIEW_COLUMNS);
FROM("model_version as v");
INNER_JOIN("model_info as m on m.id = v.model_id");
INNER_JOIN("project_info as p on p.id = m.project_id");
INNER_JOIN("user_info as u on u.id = m.owner_id");
INNER_JOIN("fine_tune as ft on ft.target_model_version_id = v.id");
// all models in current space
WHERE("ft.space_id = #{spaceId}");
WHERE("m.deleted_time = 0");
WHERE("p.id = #{projectId}");
ORDER_BY("m.id desc");
ORDER_BY("v.version_order desc");
}

}.toString();
}

public String listUserRecentlyUsed(
@Param("projectId") Long projectId,
@Param("userId") Long userId,
@Param("limit") Integer limit) {
return new SQL() {
{
SELECT(VERSION_VIEW_COLUMNS + ", MAX(j.id) as job_id");
FROM("model_version as v");
INNER_JOIN("model_info as m on m.id = v.model_id");
INNER_JOIN("job_info as j on j.model_version_id = v.id");
INNER_JOIN("project_info as p on p.id = m.project_id");
INNER_JOIN("user_info as u on u.id = m.owner_id");
// (non-draft) models in current project or other project but is shared
WHERE("("
+ "(m.project_id = #{projectId} and v.draft=0) "
+ "or (m.project_id != #{projectId} and v.shared = 1 and v.draft = 0)"
+ ")");

WHERE("m.deleted_time = 0");
WHERE("j.owner_id = #{userId}"); // jobs in current user
WHERE("j.project_id = #{projectId}"); // jobs in current project
GROUP_BY("v.id");
ORDER_BY("job_id desc"); // recently
LIMIT("#{limit}");
}
}.toString();
}

public String listUserRecentlyUsedWithFtSpace(
@Param("projectId") Long projectId,
@Param("userId") Long userId,
@Param("spaceId") Long spaceId,
@Param("limit") Integer limit) {
return new SQL() {
{
SELECT(VERSION_VIEW_COLUMNS + ", MAX(j.id) as job_id");
FROM("model_version as v");
INNER_JOIN("model_info as m on m.id = v.model_id");
INNER_JOIN("job_info as j on j.model_version_id = v.id");
INNER_JOIN("project_info as p on p.id = m.project_id");
INNER_JOIN("user_info as u on u.id = m.owner_id");
// non-draft models in current project
// or all in current space
// or (non-draft) other project but is shared
WHERE("("
+ "(m.project_id = #{projectId} and v.draft=0)"
+ "or v.id in "
+ " (select target_model_version_id from fine_tune where space_id = #{spaceId})"
+ "or (m.project_id != #{projectId} and v.shared = 1 and v.draft = 0)"
+ ")");
WHERE("m.deleted_time = 0");
WHERE("j.owner_id = #{userId}"); // jobs in current user
WHERE("j.project_id = #{projectId}"); // jobs in current project
GROUP_BY("v.id");
ORDER_BY("job_id desc"); // recently
LIMIT("#{limit}");
}
}.toString();
}

public String listSql(
@Param("modelId") Long modelId,
@Param("namePrefix") String namePrefix, @Param("draft") Boolean draft
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ public class ModelVersionViewEntity extends BaseEntity implements HasId {

private Boolean shared;

private Boolean draft;

private String storagePath;

private String builtInRuntime;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,24 @@ public void testListModelVersionView() {
allOf(hasProperty("versionName", is("v4")),
hasProperty("alias", is("v4")),
hasProperty("latest", is(true)))));

res = modelService.listModelVersionView("1", false, true);
assertEquals(2, res.size());

res = modelService.listModelVersionView("1", false, true);
assertEquals(2, res.size());

res = modelService.listFtSpaceModelVersionView("1", 1L);
assertEquals(0, res.size());

res = modelService.listFtSpaceModelVersionView("1", 1L);
assertEquals(0, res.size());

res = modelService.listRecentlyModelVersionView("1", 5);
assertEquals(0, res.size());

res = modelService.listRecentlyModelVersionView("1", 1L, 5);
assertEquals(0, res.size());
}

@Test
Expand Down
Loading