Skip to content

Commit

Permalink
Set filename parameter when delegating to create or async_create
Browse files Browse the repository at this point in the history
Signed-off-by: Mattt Zmuda <mattt@replicate.com>
  • Loading branch information
mattt committed Aug 30, 2024
1 parent 33ac93a commit 81dc0df
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 1 deletion.
6 changes: 5 additions & 1 deletion replicate/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def create(
"""

if isinstance(file, (str, pathlib.Path)):
file_path = pathlib.Path(file)
params["filename"] = file_path.name
with open(file, "rb") as f:
return self.create(f, **params)
elif not isinstance(file, (io.IOBase, BinaryIO)):
Expand All @@ -92,7 +94,9 @@ async def async_create(
"""Upload a file asynchronously that can be passed as an input when running a model."""

if isinstance(file, (str, pathlib.Path)):
with open(file, "rb") as f:
file_path = pathlib.Path(file)
params["filename"] = file_path.name
with open(file_path, "rb") as f:
return await self.async_create(f, **params)
elif not isinstance(file, (io.IOBase, BinaryIO)):
raise ValueError(
Expand Down
97 changes: 97 additions & 0 deletions tests/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,108 @@

import httpx
import pytest
import respx

import replicate
from replicate.client import Client

from .conftest import skip_if_no_token

router = respx.Router(base_url="https://api.replicate.com/v1")

router.route(
method="POST",
path="/files",
name="files.create",
).mock(
return_value=httpx.Response(
201,
json={
"id": "0ZjcyLWFhZjkNGZiNmY2YzQtMThhZi0tODg4NTY0NWNlMDEy",
"name": "hello.txt",
"size": 14,
"content_type": "text/plain",
"etag": "746308829575e17c3331bbcb00c0898b",
"checksums": {
"md5": "746308829575e17c3331bbcb00c0898b",
"sha256": "d9014c4624844aa5bac314773d6b689ad467fa4e1d1a50a1b8a99d5a95f72ff5",
},
"metadata": {
"foo": "bar",
},
"urls": {
"get": "https://api.replicate.com/v1/files/0ZjcyLWFhZjkNGZiNmY2YzQtMThhZi0tODg4NTY0NWNlMDEy",
},
"created_at": "2024-08-22T12:26:51.079Z",
"expires_at": "2024-08-22T13:26:51.079Z",
},
)
)


@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
@pytest.mark.parametrize("use_path", [True, False])
async def test_file_create(async_flag, use_path):
client = Client(
api_token="test-token", transport=httpx.MockTransport(router.handler)
)

temp_dir = tempfile.mkdtemp()
temp_file_path = os.path.join(temp_dir, "hello.txt")

try:
with open(temp_file_path, "w", encoding="utf-8") as temp_file:
temp_file.write("Hello, world!")

metadata = {"foo": "bar"}

if use_path:
file_arg = temp_file_path
if async_flag:
created_file = await client.files.async_create(
file_arg, metadata=metadata
)
else:
created_file = client.files.create(file_arg, metadata=metadata)
else:
with open(temp_file_path, "rb") as file_arg:
if async_flag:
created_file = await client.files.async_create(
file_arg, metadata=metadata
)
else:
created_file = client.files.create(file_arg, metadata=metadata)

assert router["files.create"].called
request = router["files.create"].calls[0].request

# Check that the request is multipart/form-data
assert request.headers["Content-Type"].startswith("multipart/form-data")

# Check that the filename is included and matches the fixed file name
assert b'filename="hello.txt"' in request.content
assert b"Hello, world!" in request.content

# Check the response
assert created_file.id == "0ZjcyLWFhZjkNGZiNmY2YzQtMThhZi0tODg4NTY0NWNlMDEy"
assert created_file.name == "hello.txt"
assert created_file.size == 14
assert created_file.content_type == "text/plain"
assert created_file.etag == "746308829575e17c3331bbcb00c0898b"
assert created_file.checksums == {
"md5": "746308829575e17c3331bbcb00c0898b",
"sha256": "d9014c4624844aa5bac314773d6b689ad467fa4e1d1a50a1b8a99d5a95f72ff5",
}
assert created_file.metadata == metadata
assert created_file.urls == {
"get": "https://api.replicate.com/v1/files/0ZjcyLWFhZjkNGZiNmY2YzQtMThhZi0tODg4NTY0NWNlMDEy",
}

finally:
os.unlink(temp_file_path)
os.rmdir(temp_dir)


@skip_if_no_token
@pytest.mark.skipif(os.environ.get("CI") is not None, reason="Do not run on CI")
Expand Down

0 comments on commit 81dc0df

Please sign in to comment.