Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(runtime): support runtime replace image registry #2345

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions client/starwhale/core/runtime/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
from starwhale.base.uri.project import Project as ProjectURI
from starwhale.base.uri.resource import Resource, ResourceType

from .store import RuntimeStorage
from .store import RuntimeStorage, get_docker_run_image_by_manifest

RUNTIME_API_VERSION = "1.1"
_TEMPLATE_DIR = Path(__file__).parent / "template"
Expand Down Expand Up @@ -1758,7 +1758,7 @@ def _render_dockerfile(_manifest: t.Dict[str, t.Any]) -> None:
_template = _env.get_template("Dockerfile.tmpl")
_pip = _manifest["configs"].get("pip", {})
_out = _template.render(
base_image=_manifest["base_image"],
base_image=get_docker_run_image_by_manifest(manifest=_manifest),
runtime_name=self.uri.name,
runtime_version=_manifest["version"],
pypi_index_url=_pip.get("index_url", ""),
Expand Down
13 changes: 12 additions & 1 deletion client/tests/core/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -2622,6 +2622,17 @@ def test_dockerize(self, m_check: MagicMock) -> None:
manifest["version"] = version
manifest["configs"]["docker"]["image"] = image

custom_image = "docker.io/sw/base:v1"
manifest["docker"] = {
"custom_run_image": custom_image,
"builtin_run_image": {
"repo": "self-registry/sw",
"name": "starwhale",
"tag": "v2-cuda11.7",
"fullname": "self-registry/sw/starwhale:v2",
},
}

sr = StandaloneRuntime(uri)

ensure_dir(sr.store.snapshot_workdir)
Expand All @@ -2640,7 +2651,7 @@ def test_dockerize(self, m_check: MagicMock) -> None:
assert dockerfile_path.exists()
assert dockerignore_path.exists()
dockerfile_content = dockerfile_path.read_text()
assert f"BASE_IMAGE={manifest['base_image']}" in dockerfile_content
assert f"BASE_IMAGE={custom_image}" in dockerfile_content
assert f"starwhale_runtime_version={version}" in dockerfile_content

assert m_check.call_count == 3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package ai.starwhale.mlops.common;

import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.springframework.util.StringUtils;
Expand All @@ -35,8 +37,47 @@ public DockerImage(String registry, String image) {
this.image = image;
}

/**
* please refer to https://github.com/distribution/distribution/blob/v2.7.1/reference/reference.go
*/
private static final Pattern PATTERN_IMAGE_FULL = Pattern.compile("^(.+?)\\/(.+)$");

/**
* @param imageNameFull such as ghcr.io/starwhale-ai/starwhale:0.3.5-rc123.dev12432344
*/
public DockerImage(String imageNameFull) {
Matcher matcher = PATTERN_IMAGE_FULL.matcher(imageNameFull);
if (!matcher.matches()) {
this.registry = "";
this.image = imageNameFull;
} else {
String candidateRegistry = matcher.group(1);
if (isDomain(candidateRegistry)) {
this.registry = candidateRegistry;
image = matcher.group(2);
} else {
this.registry = "";
this.image = imageNameFull;
}

}
}

private static final Pattern PATTERN_DOMAIN_LOCAL_HOST = Pattern.compile("localhost(:\\d+)?");

private static boolean isDomain(String candidate) {
return candidate.contains(".") || PATTERN_DOMAIN_LOCAL_HOST.matcher(candidate).matches();
}

private static final String SLASH = "/";

public String resolve(String newRegistry) {
if (!StringUtils.hasText(newRegistry)) {
newRegistry = this.registry;
}
return StringUtils.trimTrailingCharacter(newRegistry, '/') + SLASH + image;
}

public String toString() {
return StringUtils.trimTrailingCharacter(registry, '/') + SLASH + image;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,8 @@ private void deploy(
var name = getServiceName(id);

String builtImage = runtime.getBuiltImage();
String image = StringUtils.isNotEmpty(builtImage) ? builtImage : runtime.getImage();
String image = StringUtils.isNotEmpty(builtImage) ? builtImage :
runtime.getImage(systemSettingService.getSystemSetting().getDockerSetting().getRegistryForPull());

var rt = runtimeMapper.find(runtime.getRuntimeId());
var md = modelMapper.find(model.getModelId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ public Job fromEntity(JobEntity jobEntity) {
RuntimeEntity runtimeEntity = runtimeMapper.find(
runtimeVersionEntity.getRuntimeId());
String builtImage = runtimeVersionEntity.getBuiltImage();
String image = StringUtils.hasText(builtImage) ? builtImage : runtimeVersionEntity.getImage();
String image = StringUtils.hasText(builtImage) ? builtImage : runtimeVersionEntity.getImage(
systemSettingService.getSystemSetting().getDockerSetting().getRegistryForPull());

Job job;
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@
import java.util.stream.Collectors;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletResponse;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
Expand All @@ -117,6 +119,7 @@
@Service
public class RuntimeService {

static final String RUNTIME_MANIFEST = "_manifest.yaml";
private final RuntimeMapper runtimeMapper;
private final RuntimeVersionMapper runtimeVersionMapper;
private final StorageService storageService;
Expand Down Expand Up @@ -492,15 +495,12 @@ public void upload(MultipartFile dsFile, ClientRuntimeRequest uploadRequest) {
}
/* create new entity */
if (!entityExists) {
RuntimeManifest runtimeManifestObj;
String runtimeManifest;
try (final InputStream inputStream = dsFile.getInputStream()) {
// only extract the eval job file content
// extract the manifest file content
runtimeManifest = new String(
Objects.requireNonNull(
TarFileUtil.getContentFromTarFile(inputStream, "", "_manifest.yaml")));
runtimeManifestObj = Constants.yamlMapper.readValue(runtimeManifest,
RuntimeManifest.class);
TarFileUtil.getContentFromTarFile(inputStream, "", RUNTIME_MANIFEST)));
} catch (IOException e) {
log.error("upload runtime failed {}", uploadRequest.getRuntime(), e);
throw new StarwhaleApiException(new SwProcessException(ErrorType.SYSTEM),
Expand All @@ -512,7 +512,6 @@ public void upload(MultipartFile dsFile, ClientRuntimeRequest uploadRequest) {
.runtimeId(entity.getId())
.versionName(uploadRequest.version())
.versionMeta(runtimeManifest)
.image(null == runtimeManifestObj ? null : runtimeManifestObj.getBaseImage())
.build();
runtimeVersionMapper.insert(version);
RevertManager.create(bundleManager, runtimeDao)
Expand All @@ -526,6 +525,33 @@ public static final class RuntimeManifest {

@JsonProperty("base_image")
String baseImage;

@JsonProperty("docker")
Docker docker;

@Data
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
public static class Docker {
@JsonProperty("builtin_run_image")
BuiltinImage builtinImage;

@JsonProperty("custom_run_image")
String customImage;
}

@Data
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
public static class BuiltinImage {
@JsonProperty("fullname")
String fullName;
String name;
String repo;
String tag;
}
}

public void pull(String projectUrl, String runtimeUrl, String versionUrl, HttpServletResponse httpResponse) {
Expand Down Expand Up @@ -608,8 +634,11 @@ public BuildImageResult buildImage(String projectUrl, String runtimeUrl, String
// record image to annotations
k8sJobTemplate.updateAnnotations(job.getMetadata(), Map.of("image", image.toString()));

var baseImage = runtimeVersion.getImage(dockerSetting.getRegistryForPull());
Map<String, ContainerOverwriteSpec> ret = new HashMap<>();
List<V1EnvVar> envVars = new ArrayList<>(List.of(
new V1EnvVar().name("SW_IMAGE_REPO").value(
new DockerImage(baseImage).getRegistry()),
new V1EnvVar().name("SW_INSTANCE_URI").value(instanceUri),
new V1EnvVar().name("SW_PROJECT").value(project.getName()),
new V1EnvVar().name("SW_RUNTIME_VERSION").value(
Expand All @@ -631,7 +660,7 @@ public BuildImageResult buildImage(String projectUrl, String runtimeUrl, String
k8sJobTemplate.getInitContainerTemplates(job).forEach(templateContainer -> {
ContainerOverwriteSpec containerOverwriteSpec = new ContainerOverwriteSpec(templateContainer.getName());
containerOverwriteSpec.setEnvs(envVars);
containerOverwriteSpec.setImage(runtimeVersion.getImage());
containerOverwriteSpec.setImage(baseImage);
ret.put(templateContainer.getName(), containerOverwriteSpec);
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import ai.starwhale.mlops.api.protocol.runtime.RuntimeVersionVo;
import ai.starwhale.mlops.common.IdConverter;
import ai.starwhale.mlops.common.VersionAliasConverter;
import ai.starwhale.mlops.configuration.DockerSetting;
import ai.starwhale.mlops.domain.runtime.po.RuntimeVersionEntity;
import ai.starwhale.mlops.exception.ConvertException;
import org.springframework.stereotype.Component;
Expand All @@ -31,10 +32,14 @@ public class RuntimeVersionConverter {
private final IdConverter idConvertor;
private final VersionAliasConverter versionAliasConvertor;

private final DockerSetting dockerSetting;

public RuntimeVersionConverter(IdConverter idConvertor,
VersionAliasConverter versionAliasConvertor) {
VersionAliasConverter versionAliasConvertor,
DockerSetting dockerSetting) {
this.idConvertor = idConvertor;
this.versionAliasConvertor = versionAliasConvertor;
this.dockerSetting = dockerSetting;
}

public RuntimeVersionVo convert(RuntimeVersionEntity entity)
Expand All @@ -50,7 +55,7 @@ public RuntimeVersionVo convert(RuntimeVersionEntity entity, RuntimeVersionEntit
.alias(versionAliasConvertor.convert(entity.getVersionOrder(), latest, entity))
.tag(entity.getVersionTag())
.meta(entity.getVersionMeta())
.image(entity.getImage())
.image(entity.getImage(dockerSetting.getRegistryForPull()))
.builtImage(entity.getBuiltImage())
.shared(toInt(entity.getShared()))
.createdTime(entity.getCreatedTime().getTime())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,23 @@
package ai.starwhale.mlops.domain.runtime.po;

import ai.starwhale.mlops.common.BaseEntity;
import ai.starwhale.mlops.common.Constants;
import ai.starwhale.mlops.common.DockerImage;
import ai.starwhale.mlops.domain.bundle.base.BundleVersionEntity;
import ai.starwhale.mlops.domain.runtime.RuntimeService;
import ai.starwhale.mlops.exception.SwValidationException;
import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.experimental.SuperBuilder;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.StringUtils;

@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
@SuperBuilder
@NoArgsConstructor
@AllArgsConstructor
Expand All @@ -47,6 +55,7 @@ public class RuntimeVersionEntity extends BaseEntity implements BundleVersionEnt

private String storagePath;

@Deprecated
private String image;

private String builtImage;
Expand All @@ -57,4 +66,37 @@ public class RuntimeVersionEntity extends BaseEntity implements BundleVersionEnt
public String getName() {
return versionName;
}

public static String extractImage(String manifest, String replaceableBuiltinRegistry) {
try {
var manifestObj = Constants.yamlMapper.readValue(
manifest, RuntimeService.RuntimeManifest.class);
if (manifestObj == null) {
return null;
}
if (manifestObj.getDocker() != null) {
var docker = manifestObj.getDocker();
if (StringUtils.hasText(docker.getCustomImage())) {
return docker.getCustomImage();
} else {
var dockerImage = new DockerImage(docker.getBuiltinImage().getFullName());
return StringUtils.hasText(replaceableBuiltinRegistry)
? dockerImage.resolve(replaceableBuiltinRegistry) : dockerImage.toString();
}
} else {
return manifestObj.getBaseImage();
}
} catch (JsonProcessingException e) {
log.error("runtime manifest parse error", e);
throw new SwValidationException(SwValidationException.ValidSubject.RUNTIME, "manifest parse error");
}
}

public String getImage() {
return extractImage(this.versionMeta, null);
}

public String getImage(String newRegistry) {
return extractImage(this.versionMeta, newRegistry);
}
}
Loading