Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically upload prediction and training input files #339

Merged
merged 6 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions replicate/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing_extensions import Unpack, deprecated

from replicate.account import Account
from replicate.json import async_encode_json, encode_json
from replicate.pagination import Page
from replicate.prediction import (
Prediction,
Expand Down Expand Up @@ -417,6 +418,13 @@ def create(
Create a new prediction with the deployment.
"""

file_encoding_strategy = params.pop("file_encoding_strategy", None)
if input is not None:
input = encode_json(
input,
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)
body = _create_prediction_body(version=None, input=input, **params)

resp = self._client._request(
Expand All @@ -436,6 +444,13 @@ async def async_create(
Create a new prediction with the deployment.
"""

file_encoding_strategy = params.pop("file_encoding_strategy", None)
if input is not None:
input = await async_encode_json(
input,
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)
body = _create_prediction_body(version=None, input=input, **params)

resp = await self._client._async_request(
Expand Down Expand Up @@ -463,6 +478,14 @@ def create(
"""

url = _create_prediction_url_from_deployment(deployment)

file_encoding_strategy = params.pop("file_encoding_strategy", None)
if input is not None:
input = encode_json(
input,
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)
body = _create_prediction_body(version=None, input=input, **params)

resp = self._client._request(
Expand All @@ -484,6 +507,14 @@ async def async_create(
"""

url = _create_prediction_url_from_deployment(deployment)

file_encoding_strategy = params.pop("file_encoding_strategy", None)
if input is not None:
input = await async_encode_json(
input,
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)
body = _create_prediction_body(version=None, input=input, **params)

resp = await self._client._async_request(
Expand Down
39 changes: 3 additions & 36 deletions replicate/file.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import base64
import io
import json
import mimetypes
import os
import pathlib
from typing import Any, BinaryIO, Dict, List, Optional, TypedDict, Union

import httpx
from typing_extensions import NotRequired, Unpack
from typing_extensions import Literal, NotRequired, Unpack

from replicate.resource import Namespace, Resource

FileEncodingStrategy = Literal["base64", "url"]


class File(Resource):
"""
Expand Down Expand Up @@ -169,36 +169,3 @@ def _create_file_params(

def _json_to_file(json: Dict[str, Any]) -> File:  # pylint: disable=redefined-outer-name
    """
    Build a `File` from a decoded JSON object.

    Every key of `json` is forwarded to the `File` constructor as a keyword
    argument.
    """
    file = File(**json)
    return file


def upload_file(file: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
    """
    Upload a file to the server, or inline it as a base64 data URI.

    Args:
        file: A file handle to upload. It is rewound to the start before use.
        output_file_prefix: A string to prepend to the output file name. When
            given, the file is sent with an HTTP PUT to the resulting URL;
            when omitted, nothing is uploaded and the contents are encoded
            into a data URI instead.
    Returns:
        str: A URL to the uploaded file, or a `data:` URI holding its contents.
    Raises:
        httpx.HTTPStatusError: If the PUT request returns an error status.
    """
    # Lifted straight from cog.files

    file.seek(0)

    # `name` is not always a path: handles opened from a file descriptor (and
    # `tempfile.TemporaryFile()` on POSIX) expose an int, which both
    # `os.path.basename` and `mimetypes.guess_type` reject with a TypeError.
    # Use getattr to avoid mypy complaints about io.IOBase having no attribute name
    raw_name = getattr(file, "name", None)
    name = raw_name if isinstance(raw_name, (str, os.PathLike)) else None

    if output_file_prefix is not None:
        url = output_file_prefix + os.path.basename("output" if name is None else name)
        resp = httpx.put(url, files={"file": file}, timeout=None)  # type: ignore
        resp.raise_for_status()

        return url

    body = file.read()
    # Ensure the file handle is in bytes
    body = body.encode("utf-8") if isinstance(body, str) else body
    encoded_body = base64.b64encode(body).decode("utf-8")
    mime_type = (
        mimetypes.guess_type("" if name is None else name)[0]
        or "application/octet-stream"
    )
    return f"data:{mime_type};base64,{encoded_body}"
84 changes: 77 additions & 7 deletions replicate/json.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import base64
import io
import mimetypes
from pathlib import Path
from types import GeneratorType
from typing import Any, Callable
from typing import TYPE_CHECKING, Any, Optional

if TYPE_CHECKING:
from replicate.client import Client
from replicate.file import FileEncodingStrategy


try:
import numpy as np # type: ignore
Expand All @@ -14,22 +21,62 @@
# pylint: disable=too-many-return-statements
def encode_json(
    obj: Any,  # noqa: ANN401
    client: "Client",
    file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
) -> Any:  # noqa: ANN401
    """
    Return a JSON-compatible version of the object.

    Containers are walked recursively: dicts keep their keys, while lists,
    sets, frozensets, tuples and generators all become lists. Paths are
    opened in binary mode and handled like file handles. NumPy scalars and
    arrays are converted to plain Python values when NumPy is installed.
    Anything else is returned unchanged.

    Args:
        obj: The value to encode.
        client: The client used to upload file handles.
        file_encoding_strategy: When "base64", file handles are inlined as
            base64 data URIs; otherwise they are uploaded through
            `client.files.create` and replaced by the resulting "get" URL.
    Returns:
        A JSON-compatible copy of `obj`.
    """

    if isinstance(obj, dict):
        return {
            key: encode_json(value, client, file_encoding_strategy)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
        return [encode_json(value, client, file_encoding_strategy) for value in obj]
    if isinstance(obj, Path):
        with obj.open("rb") as file:
            return encode_json(file, client, file_encoding_strategy)
    if isinstance(obj, io.IOBase):
        if file_encoding_strategy == "base64":
            # Emit a full data URI (with MIME type) rather than bare base64
            # text, and let the helper cope with text-mode handles.
            return base64_encode_file(obj)
        return client.files.create(obj).urls["get"]
    if HAS_NUMPY:
        if isinstance(obj, np.integer):  # type: ignore
            return int(obj)
        if isinstance(obj, np.floating):  # type: ignore
            return float(obj)
        if isinstance(obj, np.ndarray):  # type: ignore
            return obj.tolist()
    return obj


async def async_encode_json(
obj: Any, # noqa: ANN401
client: "Client",
file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
) -> Any: # noqa: ANN401
"""
Asynchronously return a JSON-compatible version of the object.
"""

if isinstance(obj, dict):
return {
key: (await async_encode_json(value, client, file_encoding_strategy))
for key, value in obj.items()
}
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
return [encode_json(value, upload_file) for value in obj]
return [
(await async_encode_json(value, client, file_encoding_strategy))
for value in obj
]
if isinstance(obj, Path):
with obj.open("rb") as file:
return upload_file(file)
return encode_json(file, client, file_encoding_strategy)
if isinstance(obj, io.IOBase):
return upload_file(obj)
return (await client.files.async_create(obj)).urls["get"]
if HAS_NUMPY:
if isinstance(obj, np.integer): # type: ignore
return int(obj)
Expand All @@ -38,3 +85,26 @@ def encode_json(
if isinstance(obj, np.ndarray): # type: ignore
return obj.tolist()
return obj


def base64_encode_file(file: io.IOBase) -> str:
    """
    Base64 encode a file.

    The handle is rewound to the start and read in full. Text-mode contents
    are encoded as UTF-8 first. The MIME type is guessed from the handle's
    `name` attribute, falling back to "application/octet-stream".

    Args:
        file: A file handle to encode.
    Returns:
        str: A base64-encoded data URI.
    """

    file.seek(0)
    body = file.read()

    # Ensure the file handle is in bytes
    body = body.encode("utf-8") if isinstance(body, str) else body
    encoded_body = base64.b64encode(body).decode("utf-8")

    # `name` is not always a path: handles opened from a file descriptor (and
    # `tempfile.TemporaryFile()` on POSIX) expose an int, which
    # `mimetypes.guess_type` rejects with a TypeError.
    try:
        guessed_type = mimetypes.guess_type(getattr(file, "name", ""))[0]
    except TypeError:
        guessed_type = None

    mime_type = guessed_type or "application/octet-stream"
    return f"data:{mime_type};base64,{encoded_body}"
17 changes: 17 additions & 0 deletions replicate/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from replicate.exceptions import ReplicateException
from replicate.identifier import ModelVersionIdentifier
from replicate.json import async_encode_json, encode_json
from replicate.pagination import Page
from replicate.prediction import (
Prediction,
Expand Down Expand Up @@ -391,6 +392,14 @@ def create(
"""

url = _create_prediction_url_from_model(model)

file_encoding_strategy = params.pop("file_encoding_strategy", None)
if input is not None:
input = encode_json(
input,
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)
body = _create_prediction_body(version=None, input=input, **params)

resp = self._client._request(
Expand All @@ -412,6 +421,14 @@ async def async_create(
"""

url = _create_prediction_url_from_model(model)

file_encoding_strategy = params.pop("file_encoding_strategy", None)
if input is not None:
input = await async_encode_json(
input,
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)
body = _create_prediction_body(version=None, input=input, **params)

resp = await self._client._async_request(
Expand Down
24 changes: 21 additions & 3 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from typing_extensions import NotRequired, TypedDict, Unpack

from replicate.exceptions import ModelError, ReplicateError
from replicate.file import upload_file
from replicate.json import encode_json
from replicate.file import FileEncodingStrategy
from replicate.json import async_encode_json, encode_json
from replicate.pagination import Page
from replicate.resource import Namespace, Resource
from replicate.stream import EventSource
Expand Down Expand Up @@ -383,6 +383,9 @@ class CreatePredictionParams(TypedDict):
stream: NotRequired[bool]
"""Enable streaming of prediction output."""

file_encoding_strategy: NotRequired[FileEncodingStrategy]
"""The strategy to use for encoding files in the prediction input."""

@overload
def create(
self,
Expand Down Expand Up @@ -453,6 +456,13 @@ def create( # type: ignore
**params,
)

file_encoding_strategy = params.pop("file_encoding_strategy", None)
if input is not None:
input = encode_json(
input,
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)
body = _create_prediction_body(
version,
input,
Expand Down Expand Up @@ -537,6 +547,13 @@ async def async_create( # type: ignore
**params,
)

file_encoding_strategy = params.pop("file_encoding_strategy", None)
if input is not None:
input = await async_encode_json(
input,
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)
body = _create_prediction_body(
version,
input,
Expand Down Expand Up @@ -593,11 +610,12 @@ def _create_prediction_body( # pylint: disable=too-many-arguments
webhook_completed: Optional[str] = None,
webhook_events_filter: Optional[List[str]] = None,
stream: Optional[bool] = None,
**_kwargs,
) -> Dict[str, Any]:
body = {}

if input is not None:
body["input"] = encode_json(input, upload_file=upload_file)
body["input"] = input

if version is not None:
body["version"] = version.id if isinstance(version, Version) else version
Expand Down
Loading