Skip to content

Commit

Permalink
Add training
Browse files Browse the repository at this point in the history
Signed-off-by: Ben Firshman <ben@firshman.com>
  • Loading branch information
bfirsh committed Apr 6, 2023
1 parent 2b0beea commit 0da9814
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 0 deletions.
1 change: 1 addition & 0 deletions replicate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
run = default_client.run
models = default_client.models
predictions = default_client.predictions
trainings = default_client.trainings
5 changes: 5 additions & 0 deletions replicate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from replicate.exceptions import ModelError, ReplicateError
from replicate.model import ModelCollection
from replicate.prediction import PredictionCollection
from replicate.training import TrainingCollection


class Client:
Expand Down Expand Up @@ -107,6 +108,10 @@ def models(self) -> ModelCollection:
def predictions(self) -> PredictionCollection:
return PredictionCollection(client=self)

@property
def trainings(self) -> TrainingCollection:
return TrainingCollection(client=self)

def run(self, model_version, **kwargs) -> Union[Any, Iterator[Any]]:
"""
Run a model in the format owner/name:version.
Expand Down
78 changes: 78 additions & 0 deletions replicate/training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import re
import time
from typing import Any, Dict, Iterator, List, Optional

from replicate.base_model import BaseModel
from replicate.collection import Collection
from replicate.exceptions import ModelError, ReplicateException
from replicate.files import upload_file
from replicate.json import encode_json
from replicate.version import Version


class Training(BaseModel):
completed_at: Optional[str]
created_at: Optional[str]
destination: Optional[str]
error: Optional[str]
id: str
input: Optional[Dict[str, Any]]
logs: Optional[str]
output: Optional[Any]
started_at: Optional[str]
status: str
version: str

def cancel(self):
"""Cancel a running training"""
self._client._request("POST", f"/v1/trainings/{self.id}/cancel")


class TrainingCollection(Collection):
model = Training

def create(
self,
version: str,
input: Dict[str, Any],
destination: str,
webhook: Optional[str] = None,
webhook_events_filter: Optional[List[str]] = None,
) -> Training:
input = encode_json(input, upload_file=upload_file)
body = {
"input": input,
"destination": destination,
}
if webhook is not None:
body["webhook"] = webhook
if webhook_events_filter is not None:
body["webhook_events_filter"] = webhook_events_filter

# Split version in format "username/model_name:version_id"
match = re.match(
r"^(?P<username>[^/]+)/(?P<model_name>[^:]+):(?P<version_id>.+)$", version
)
if not match:
raise ReplicateException(
f"version must be in format username/model_name:version_id"
)
username = match.group("username")
model_name = match.group("model_name")
version_id = match.group("version_id")

resp = self._client._request(
"POST",
f"/v1/models/{username}/{model_name}/versions/{version_id}/trainings",
json=body,
)
obj = resp.json()
return self.prepare_model(obj)

def get(self, id: str) -> Training:
resp = self._client._request(
"GET",
f"/v1/trainings/{id}",
)
obj = resp.json()
return self.prepare_model(obj)

0 comments on commit 0da9814

Please sign in to comment.