Skip to content

Commit

Permalink
Update stream interface to always use FileOutput
Browse files Browse the repository at this point in the history
  • Loading branch information
aron authored and zeke committed Oct 9, 2024
1 parent 9921f4c commit 57e7255
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
16 changes: 10 additions & 6 deletions replicate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ def run(
self,
ref: str,
input: Optional[Dict[str, Any]] = None,
use_file_output: Optional[bool] = None,
*,
use_file_output: Optional[bool] = True,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Union[Any, Iterator[Any]]: # noqa: ANN401
"""
Expand All @@ -177,7 +178,8 @@ async def async_run(
self,
ref: str,
input: Optional[Dict[str, Any]] = None,
use_file_output: Optional[bool] = None,
*,
use_file_output: Optional[bool] = True,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401
"""
Expand All @@ -191,28 +193,30 @@ async def async_run(
def stream(
self,
ref: str,
*,
input: Optional[Dict[str, Any]] = None,
use_file_output: Optional[bool] = None,
use_file_output: Optional[bool] = True,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Iterator["ServerSentEvent"]:
"""
Stream a model's output.
"""

return stream(self, ref, input, use_file_output, **params)
return stream(self, ref, input, use_file_output=use_file_output, **params)

async def async_stream(
self,
ref: str,
input: Optional[Dict[str, Any]] = None,
use_file_output: Optional[bool] = None,
*,
use_file_output: Optional[bool] = True,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> AsyncIterator["ServerSentEvent"]:
"""
Stream a model's output asynchronously.
"""

return async_stream(self, ref, input, use_file_output, **params)
return async_stream(self, ref, input, use_file_output=use_file_output, **params)


# Adapted from https://github.com/encode/httpx/issues/108#issuecomment-1132753155
Expand Down
11 changes: 7 additions & 4 deletions replicate/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,12 @@ def __init__(
self,
client: "Client",
response: "httpx.Response",
use_file_output: Optional[bool] = None,
*,
use_file_output: Optional[bool] = True,
) -> None:
self.client = client
self.response = response
self.use_file_output = use_file_output or False
self.use_file_output = use_file_output or True
content_type, _, _ = response.headers["content-type"].partition(";")
if content_type != "text/event-stream":
raise ValueError(
Expand Down Expand Up @@ -193,7 +194,8 @@ def stream(
client: "Client",
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
input: Optional[Dict[str, Any]] = None,
use_file_output: Optional[bool] = None,
*,
use_file_output: Optional[bool] = True,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Iterator[ServerSentEvent]:
"""
Expand Down Expand Up @@ -234,7 +236,8 @@ async def async_stream(
client: "Client",
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
input: Optional[Dict[str, Any]] = None,
use_file_output: Optional[bool] = None,
*,
use_file_output: Optional[bool] = True,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> AsyncIterator[ServerSentEvent]:
"""
Expand Down

0 comments on commit 57e7255

Please sign in to comment.