Skip to content

Commit

Permalink
Automatically upload prediction and training input files (#339)
Browse files Browse the repository at this point in the history
Follow-up to #226

---------

Signed-off-by: Mattt Zmuda <mattt@replicate.com>
  • Loading branch information
mattt authored Aug 22, 2024
1 parent 54f9c32 commit ff41075
Show file tree
Hide file tree
Showing 11 changed files with 30,610 additions and 54 deletions.
31 changes: 31 additions & 0 deletions replicate/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing_extensions import Unpack, deprecated

from replicate.account import Account
from replicate.json import async_encode_json, encode_json
from replicate.pagination import Page
from replicate.prediction import (
Prediction,
Expand Down Expand Up @@ -417,6 +418,13 @@ def create(
Create a new prediction with the deployment.
"""

file_encoding_strategy = params.pop("file_encoding_strategy", None)
if input is not None:
input = encode_json(
input,
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)
body = _create_prediction_body(version=None, input=input, **params)

resp = self._client._request(
Expand All @@ -436,6 +444,13 @@ async def async_create(
Create a new prediction with the deployment.
"""

file_encoding_strategy = params.pop("file_encoding_strategy", None)
if input is not None:
input = await async_encode_json(
input,
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)
body = _create_prediction_body(version=None, input=input, **params)

resp = await self._client._async_request(
Expand Down Expand Up @@ -463,6 +478,14 @@ def create(
"""

url = _create_prediction_url_from_deployment(deployment)

file_encoding_strategy = params.pop("file_encoding_strategy", None)
if input is not None:
input = encode_json(
input,
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)
body = _create_prediction_body(version=None, input=input, **params)

resp = self._client._request(
Expand All @@ -484,6 +507,14 @@ async def async_create(
"""

url = _create_prediction_url_from_deployment(deployment)

file_encoding_strategy = params.pop("file_encoding_strategy", None)
if input is not None:
input = await async_encode_json(
input,
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)
body = _create_prediction_body(version=None, input=input, **params)

resp = await self._client._async_request(
Expand Down
39 changes: 3 additions & 36 deletions replicate/file.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import base64
import io
import json
import mimetypes
import os
import pathlib
from typing import Any, BinaryIO, Dict, List, Optional, TypedDict, Union

import httpx
from typing_extensions import NotRequired, Unpack
from typing_extensions import Literal, NotRequired, Unpack

from replicate.resource import Namespace, Resource

FileEncodingStrategy = Literal["base64", "url"]


class File(Resource):
"""
Expand Down Expand Up @@ -169,36 +169,3 @@ def _create_file_params(

def _json_to_file(json: Dict[str, Any]) -> File: # pylint: disable=redefined-outer-name
return File(**json)


def upload_file(file: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
"""
Upload a file to the server.
Args:
file: A file handle to upload.
output_file_prefix: A string to prepend to the output file name.
Returns:
str: A URL to the uploaded file.
"""
# Lifted straight from cog.files

file.seek(0)

if output_file_prefix is not None:
name = getattr(file, "name", "output")
url = output_file_prefix + os.path.basename(name)
resp = httpx.put(url, files={"file": file}, timeout=None) # type: ignore
resp.raise_for_status()

return url

body = file.read()
# Ensure the file handle is in bytes
body = body.encode("utf-8") if isinstance(body, str) else body
encoded_body = base64.b64encode(body).decode("utf-8")
# Use getattr to avoid mypy complaints about io.IOBase having no attribute name
mime_type = (
mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream"
)
return f"data:{mime_type};base64,{encoded_body}"
84 changes: 77 additions & 7 deletions replicate/json.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import base64
import io
import mimetypes
from pathlib import Path
from types import GeneratorType
from typing import Any, Callable
from typing import TYPE_CHECKING, Any, Optional

if TYPE_CHECKING:
from replicate.client import Client
from replicate.file import FileEncodingStrategy


try:
import numpy as np # type: ignore
Expand All @@ -14,22 +21,62 @@
# pylint: disable=too-many-return-statements
def encode_json(
obj: Any, # noqa: ANN401
upload_file: Callable[[io.IOBase], str],
client: "Client",
file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
) -> Any: # noqa: ANN401
"""
Return a JSON-compatible version of the object.
"""
# Effectively the same thing as cog.json.encode_json.

if isinstance(obj, dict):
return {key: encode_json(value, upload_file) for key, value in obj.items()}
return {
key: encode_json(value, client, file_encoding_strategy)
for key, value in obj.items()
}
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
return [encode_json(value, client, file_encoding_strategy) for value in obj]
if isinstance(obj, Path):
with obj.open("rb") as file:
return encode_json(file, client, file_encoding_strategy)
if isinstance(obj, io.IOBase):
if file_encoding_strategy == "base64":
return base64.b64encode(obj.read()).decode("utf-8")
else:
return client.files.create(obj).urls["get"]
if HAS_NUMPY:
if isinstance(obj, np.integer): # type: ignore
return int(obj)
if isinstance(obj, np.floating): # type: ignore
return float(obj)
if isinstance(obj, np.ndarray): # type: ignore
return obj.tolist()
return obj


async def async_encode_json(
obj: Any, # noqa: ANN401
client: "Client",
file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
) -> Any: # noqa: ANN401
"""
Asynchronously return a JSON-compatible version of the object.
"""

if isinstance(obj, dict):
return {
key: (await async_encode_json(value, client, file_encoding_strategy))
for key, value in obj.items()
}
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
return [encode_json(value, upload_file) for value in obj]
return [
(await async_encode_json(value, client, file_encoding_strategy))
for value in obj
]
if isinstance(obj, Path):
with obj.open("rb") as file:
return upload_file(file)
return encode_json(file, client, file_encoding_strategy)
if isinstance(obj, io.IOBase):
return upload_file(obj)
return (await client.files.async_create(obj)).urls["get"]
if HAS_NUMPY:
if isinstance(obj, np.integer): # type: ignore
return int(obj)
Expand All @@ -38,3 +85,26 @@ def encode_json(
if isinstance(obj, np.ndarray): # type: ignore
return obj.tolist()
return obj


def base64_encode_file(file: io.IOBase) -> str:
"""
Base64 encode a file.
Args:
file: A file handle to upload.
Returns:
str: A base64-encoded data URI.
"""

file.seek(0)
body = file.read()

# Ensure the file handle is in bytes
body = body.encode("utf-8") if isinstance(body, str) else body
encoded_body = base64.b64encode(body).decode("utf-8")

mime_type = (
mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream"
)
return f"data:{mime_type};base64,{encoded_body}"
17 changes: 17 additions & 0 deletions replicate/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from replicate.exceptions import ReplicateException
from replicate.identifier import ModelVersionIdentifier
from replicate.json import async_encode_json, encode_json
from replicate.pagination import Page
from replicate.prediction import (
Prediction,
Expand Down Expand Up @@ -391,6 +392,14 @@ def create(
"""

url = _create_prediction_url_from_model(model)

file_encoding_strategy = params.pop("file_encoding_strategy", None)
if input is not None:
input = encode_json(
input,
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)
body = _create_prediction_body(version=None, input=input, **params)

resp = self._client._request(
Expand All @@ -412,6 +421,14 @@ async def async_create(
"""

url = _create_prediction_url_from_model(model)

file_encoding_strategy = params.pop("file_encoding_strategy", None)
if input is not None:
input = await async_encode_json(
input,
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)
body = _create_prediction_body(version=None, input=input, **params)

resp = await self._client._async_request(
Expand Down
24 changes: 21 additions & 3 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from typing_extensions import NotRequired, TypedDict, Unpack

from replicate.exceptions import ModelError, ReplicateError
from replicate.file import upload_file
from replicate.json import encode_json
from replicate.file import FileEncodingStrategy
from replicate.json import async_encode_json, encode_json
from replicate.pagination import Page
from replicate.resource import Namespace, Resource
from replicate.stream import EventSource
Expand Down Expand Up @@ -383,6 +383,9 @@ class CreatePredictionParams(TypedDict):
stream: NotRequired[bool]
"""Enable streaming of prediction output."""

file_encoding_strategy: NotRequired[FileEncodingStrategy]
"""The strategy to use for encoding files in the prediction input."""

@overload
def create(
self,
Expand Down Expand Up @@ -453,6 +456,13 @@ def create( # type: ignore
**params,
)

file_encoding_strategy = params.pop("file_encoding_strategy", None)
if input is not None:
input = encode_json(
input,
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)
body = _create_prediction_body(
version,
input,
Expand Down Expand Up @@ -537,6 +547,13 @@ async def async_create( # type: ignore
**params,
)

file_encoding_strategy = params.pop("file_encoding_strategy", None)
if input is not None:
input = await async_encode_json(
input,
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)
body = _create_prediction_body(
version,
input,
Expand Down Expand Up @@ -593,11 +610,12 @@ def _create_prediction_body( # pylint: disable=too-many-arguments
webhook_completed: Optional[str] = None,
webhook_events_filter: Optional[List[str]] = None,
stream: Optional[bool] = None,
**_kwargs,
) -> Dict[str, Any]:
body = {}

if input is not None:
body["input"] = encode_json(input, upload_file=upload_file)
body["input"] = input

if version is not None:
body["version"] = version.id if isinstance(version, Version) else version
Expand Down
Loading

0 comments on commit ff41075

Please sign in to comment.