Skip to content

Commit

Permalink
bk
Browse files Browse the repository at this point in the history
  • Loading branch information
jialeicui committed Nov 30, 2022
1 parent a8e199f commit c653847
Show file tree
Hide file tree
Showing 11 changed files with 287 additions and 30 deletions.
8 changes: 8 additions & 0 deletions client/starwhale/api/_impl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from starwhale.api._impl import wrapper
from starwhale.base.type import URIType, RunSubDirType
from starwhale.utils.log import StreamWrapper
from starwhale.api.service import Service
from starwhale.utils.error import FieldTypeOrValueError
from starwhale.api._impl.job import context_holder
from starwhale.core.job.model import STATUS
Expand Down Expand Up @@ -78,6 +79,7 @@ def __init__(
ignore_error: bool = False,
flush_result: bool = False,
) -> None:
self.svc = Service()
self.context: Context = context_holder.context

# TODO: add args for compare result and label directly
Expand Down Expand Up @@ -283,3 +285,9 @@ def _starwhale_internal_run_ppl(self) -> None:
def _update_status(self, status: str) -> None:
fpath = self.status_dir / CURRENT_FNAME
ensure_file(fpath, status)

def add_api(self, input, output, func, name):
self.svc.add_api(input, output, func, name)

def serve(self, addr: str, port: int):
self.svc.serve(addr, port)
99 changes: 99 additions & 0 deletions client/starwhale/api/_impl/service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import json
import base64
import typing as t
from abc import ABC, abstractmethod
from dataclasses import dataclass

from starwhale.core.dataset.type import MIMEType, ArtifactType, BaseArtifact


@dataclass
class Api:
input: t.Any
output: t.Any
func: t.Callable
uri: str

def to_yaml(self):
return self.func.__qualname__


class Request(ABC):
@abstractmethod
def to_dataset_type(self, req: t.Any) -> BaseArtifact:
raise NotImplementedError()


class Response(ABC):
@abstractmethod
def to_json(self, resp: t.Any) -> bytes:
raise NotImplementedError()


class GrayscaleImageRequest(Request):
def to_dataset_type(self, req: t.Any) -> BaseArtifact:
if isinstance(req, (str, bytes)):
raw = base64.b64decode(req)
elif isinstance(req, dict):
if "img" not in req:
raise Exception("can not get image")
raw = base64.b64decode(req["img"])
else:
raise Exception("can not get image from unknown request type")
return BaseArtifact.reflect(
raw,
{
"type": ArtifactType.Image.value,
"mime_type": MIMEType.GRAYSCALE,
"shape": [28, 28, 1],
},
)


class JsonResponse(Response):
def to_json(self, resp: t.Any) -> bytes:
return json.dumps(resp).encode("utf8")


class Service:
def __init__(self):
self.apis = {}

def api(self, input: t.Any, output: t.Any, uri: str = None):
def decorator(func: t.Any) -> t.Any:
self.add_api(input, output, func, uri or func.__name__)
return func

return decorator

# TODO: support checking duplication
def add_api(self, input: t.Any, output: t.Any, func: t.Any, uri: str):
_api = Api(input, output, func, uri)
self.apis[uri] = _api

def add_api_instance(self, api: Api) -> None:
self.apis[api.uri] = api

def serve(self, addr: str, port: int, handler_list: t.Optional[t.List[str]] = None):
"""
Default serve implementation, users can override this method
:param addr
:param port
:param handler_list, use all handlers if None
:return: None
"""
apis = self.apis

import flask

app = flask.Flask(__name__)
if handler_list:
apis = {uri: apis[uri] for uri in apis if uri in handler_list}
for api in apis:
app.add_url_rule(
rule=api,
endpoint=None,
view_func=self.apis[api].func,
methods=["POST"],
)
app.run(addr, port)
3 changes: 3 additions & 0 deletions client/starwhale/api/service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from ._impl.service import Request, Service, JsonResponse, GrayscaleImageRequest

__all__ = ["Service", "Request", "GrayscaleImageRequest", "JsonResponse"]
1 change: 1 addition & 0 deletions client/starwhale/consts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
DEFAULT_MANIFEST_NAME = "_manifest.yaml"
DEFAULT_EVALUATION_JOB_NAME = "default"
DEFAULT_EVALUATION_JOBS_FNAME = "eval_jobs.yaml"
DEFAULT_EVALUATION_SVC_META_FNAME = "svc.yaml"
DEFAULT_EVALUATION_PIPELINE = "starwhale.core.model.default_handler"
DEFAULT_LOCAL_SW_CONTROLLER_ADDR = "localhost:7827"
LOCAL_CONFIG_VERSION = "2.0"
Expand Down
41 changes: 41 additions & 0 deletions client/starwhale/core/model/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,44 @@ def _eval(
task_num=override_task_num,
dataset_uris=datasets,
)


@model_cmd.command("serve")
@click.argument("target")
@click.option(
"-f",
"--model-yaml",
default=DefaultYAMLName.MODEL,
help="Model yaml filename, default use ${MODEL_DIR}/model.yaml file",
)
@click.option(
"-p",
"--project",
envvar=SWEnv.project,
default="",
help=f"project name, env is {SWEnv.project}",
)
@click.option(
"--version",
envvar=SWEnv.eval_version,
default=None,
help=f"Evaluation job version, env is {SWEnv.eval_version}",
)
@click.option(
"--handler",
default=None,
help="List of service handlers, use all by default",
multiple=True,
)
@click.option("--runtime", default="", help="runtime uri")
@click.option("--host", default="", help="The host to listen on")
@click.option("--port", default=8080, help="The port of the server")
def _serve(
target: str,
model_yaml: str,
host: str,
port: int,
runtime: str,
handlers: t.Optional[t.List[str]],
) -> None:
ModelTermView.serve(target, model_yaml, host, port, runtime, handlers)
79 changes: 62 additions & 17 deletions client/starwhale/core/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import typing as t
import inspect
import tarfile
from abc import ABCMeta
from pathlib import Path
Expand All @@ -13,6 +14,7 @@
from fs.copy import copy_fs, copy_file
from fs.tarfs import TarFS

from starwhale import PipelineHandler
from starwhale.utils import console, now_str, load_yaml, gen_uniq_version
from starwhale.consts import (
DefaultYAMLName,
Expand All @@ -24,6 +26,7 @@
DEFAULT_EVALUATION_PIPELINE,
DEFAULT_EVALUATION_JOBS_FNAME,
DEFAULT_STARWHALE_API_VERSION,
DEFAULT_EVALUATION_SVC_META_FNAME,
)
from starwhale.base.tag import StandaloneTag
from starwhale.base.uri import URI
Expand All @@ -32,9 +35,11 @@
from starwhale.base.cloud import CloudRequestMixed, CloudBundleModelMixin
from starwhale.base.mixin import ASDictMixin
from starwhale.utils.http import ignore_error
from starwhale.utils.load import load_module
from starwhale.api.service import Service
from starwhale.base.bundle import BaseBundle, LocalStorageBundleMixin
from starwhale.utils.error import NoSupportError, FileFormatError
from starwhale.api._impl.job import Parser
from starwhale.api._impl.job import Parser, Context, context_holder
from starwhale.core.job.model import STATUS, Generator
from starwhale.utils.progress import run_with_progress_bar
from starwhale.core.eval.store import EvaluationStorage
Expand Down Expand Up @@ -192,9 +197,41 @@ def _gen_steps(self, typ: str, ppl: str) -> None:
if typ == EvalHandlerType.DEFAULT:
# use default
ppl = DEFAULT_EVALUATION_PIPELINE
_f = self.store.snapshot_workdir / "src" / DEFAULT_EVALUATION_JOBS_FNAME
d = self.store.snapshot_workdir / "src"
_f = d / DEFAULT_EVALUATION_JOBS_FNAME
logger.debug(f"job ppl path:{_f}, ppl is {ppl}")
Parser.generate_job_yaml(ppl, self.store.snapshot_workdir / "src", _f)
Parser.generate_job_yaml(ppl, d, _f)
svc = self._get_service(ppl, d)
_f = d / DEFAULT_EVALUATION_SVC_META_FNAME
ensure_file(
_f,
yaml.safe_dump(
{k: v.to_yaml() for k, v in svc.apis.items()}, default_flow_style=False
),
)

@staticmethod
def _get_service(module: str, pkg: Path) -> Service:
m = load_module(module, pkg)
apis = dict()
svc: Service = Service()

# TODO: check duplication
for k, v in m.__dict__.items():
if isinstance(v, Service):
apis.update(v.apis)
svc = v
if inspect.isclass(v) and issubclass(v, PipelineHandler):
# TODO: refine this ugly ad hoc
context_holder.context = Context(
Path("."), version="-1", project="tmp-project-for-build"
)
ins = v()
apis.update(ins.svc.apis)

for api in apis:
svc.add_api_instance(api)
return svc

@classmethod
def get_pipeline_handler(
Expand All @@ -204,9 +241,7 @@ def get_pipeline_handler(
) -> str:
_mp = workdir / yaml_name
_model_config = cls.load_model_config(_mp)
if _model_config.run.typ == EvalHandlerType.DEFAULT:
return DEFAULT_EVALUATION_PIPELINE
return _model_config.run.handler
return cls._get_module(_model_config)

@classmethod
def eval_user_handler(
Expand Down Expand Up @@ -240,22 +275,13 @@ def eval_user_handler(
_run_dir = EvaluationStorage.local_run_dir(_project_uri.project, version)
ensure_dir(_run_dir)

if _model_config.run.typ == EvalHandlerType.DEFAULT:
_module = DEFAULT_EVALUATION_PIPELINE
else:
_module = _model_config.run.handler

_module = cls._get_module(_model_config)
_yaml_path = str(workdir / DEFAULT_EVALUATION_JOBS_FNAME)

# generate if not exists
if not os.path.exists(_yaml_path):
if _model_config.run.typ == EvalHandlerType.DEFAULT:
_ppl = DEFAULT_EVALUATION_PIPELINE
else:
_ppl = _model_config.run.handler

_new_yaml_path = _run_dir / DEFAULT_EVALUATION_JOBS_FNAME
Parser.generate_job_yaml(_ppl, workdir, _new_yaml_path)
Parser.generate_job_yaml(_module, workdir, _new_yaml_path)
_yaml_path = str(_new_yaml_path)

# parse job steps from yaml
Expand Down Expand Up @@ -328,6 +354,13 @@ def eval_user_handler(
f":{100 if _status == STATUS.SUCCESS else 'broken_heart'}: finish run, {_status}!"
)

@classmethod
def _get_module(cls, _model_config):
if _model_config.run.typ == EvalHandlerType.DEFAULT:
return DEFAULT_EVALUATION_PIPELINE
else:
return _model_config.run.handler

def info(self) -> t.Dict[str, t.Any]:
_manifest = self._get_bundle_info()
_store = self.store
Expand Down Expand Up @@ -537,6 +570,18 @@ def _load_config_envs(cls, _config: ModelConfig) -> None:
def extract(self, force: bool = False, target: t.Union[str, Path] = "") -> Path:
return self._do_extract(force, target)

@classmethod
def serve(
cls,
model_yaml: str,
workdir: Path,
host: str,
port: int,
handlers: t.Optional[t.List[str]] = None,
):
svc = cls._get_service(cls.get_pipeline_handler(workdir, model_yaml), workdir)
svc.serve(host, port, handlers)


class CloudModel(CloudBundleModelMixin, Model):
def __init__(self, uri: URI) -> None:
Expand Down
32 changes: 24 additions & 8 deletions client/starwhale/core/model/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,10 @@ def eval(
task_num: int = 0,
runtime_uri: str = "",
) -> None:
if in_production() or (os.path.exists(target) and os.path.isdir(target)):
workdir = Path(target)
else:
_uri = URI(target, URIType.MODEL)
_store = ModelStorage(_uri)
workdir = _store.loc

kw = dict(
project=project,
version=version,
workdir=workdir,
workdir=cls._get_workdir(target),
dataset_uris=dataset_uris,
step_name=step,
task_index=task_index,
Expand All @@ -87,6 +80,16 @@ def eval(
else:
StandaloneModel.eval_user_handler(**kw) # type: ignore

@classmethod
def _get_workdir(cls, target):
if in_production() or (os.path.exists(target) and os.path.isdir(target)):
workdir = Path(target)
else:
_uri = URI(target, URIType.MODEL)
_store = ModelStorage(_uri)
workdir = _store.loc
return workdir

@classmethod
def list(
cls,
Expand Down Expand Up @@ -145,6 +148,19 @@ def tag(self, tags: t.List[str], remove: bool = False, quiet: bool = False) -> N
console.print(f":surfer: add tags [red]{tags}[/] @ {self.uri}...")
self.model.add_tags(tags, quiet)

@classmethod
def serve(
cls,
target: str,
model_yaml: str,
host: str,
port: int,
runtime: str,
handlers: t.Optional[t.List[str]] = None,
):
workdir = cls._get_workdir(target)
StandaloneModel.serve(model_yaml, workdir, host, port, handlers)


class ModelTermViewRich(ModelTermView):
@classmethod
Expand Down
Loading

0 comments on commit c653847

Please sign in to comment.