Support for writing with S3 file system #539

Merged · 7 commits · May 2, 2024
610 changes: 485 additions & 125 deletions pyathena/filesystem/s3.py

Large diffs are not rendered by default.

438 changes: 408 additions & 30 deletions pyathena/filesystem/s3_object.py

Large diffs are not rendered by default.
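
Since the two core diffs above are collapsed, here is a minimal usage sketch of the behavior this PR adds, inferred from the new tests below. The bucket name and object key are placeholders, and default AWS credentials in the environment are assumed.

import fsspec

# Register PyAthena's S3 filesystem for the s3/s3a schemes,
# exactly as the register_filesystem test fixture below does.
fsspec.register_implementation("s3", "pyathena.filesystem.s3.S3FileSystem", clobber=True)
fsspec.register_implementation("s3a", "pyathena.filesystem.s3.S3FileSystem", clobber=True)

fs = fsspec.filesystem("s3")  # resolves to pyathena.filesystem.s3.S3FileSystem

# Writing and appending to s3:// objects is what this PR enables.
path = "s3://example-bucket/path/to/example.dat"  # placeholder path
with fs.open(path, "wb") as f:
    f.write(b"0123456789")
with fs.open(path, "ab") as f:
    f.write(b"extra")

with fs.open(path, "rb") as f:
    assert f.read() == b"0123456789extra"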

1 change: 0 additions & 1 deletion pyproject.toml
@@ -136,7 +136,6 @@ profile_file = "tests/sqlalchemy/profiles.txt"
 line-length = 100
 exclude = [
     ".venv",
-    "tests"
 ]
 target-version = "py38"
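
(Dropping "tests" from the linter's exclude list presumably brings the rewritten test suite under lint coverage as well.)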
126 changes: 110 additions & 16 deletions tests/pyathena/filesystem/test_s3.py
@@ -1,14 +1,22 @@
 # -*- coding: utf-8 -*-
+import uuid
 from itertools import chain
-from typing import Dict
 
+import fsspec
 import pytest
 
 from pyathena.filesystem.s3 import S3File, S3FileSystem
 from tests import ENV
 from tests.pyathena.conftest import connect
 
 
+@pytest.fixture(scope="class")
+def register_filesystem():
+    fsspec.register_implementation("s3", "pyathena.filesystem.s3.S3FileSystem", clobber=True)
+    fsspec.register_implementation("s3a", "pyathena.filesystem.s3.S3FileSystem", clobber=True)
+
+
+@pytest.mark.usefixtures("register_filesystem")
 class TestS3FileSystem:
     def test_parse_path(self):
         actual = S3FileSystem.parse_path("s3://bucket")
@@ -109,35 +117,34 @@ def test_parse_path_invalid(self):
             S3FileSystem.parse_path("s3a://bucket/path/to/obj?foo=bar")
 
     @pytest.fixture(scope="class")
-    def fs(self) -> Dict[str, S3FileSystem]:
-        fs = {
-            "default": S3FileSystem(connect()),
-            "small_batches": S3FileSystem(connect(), default_block_size=3),
-        }
-        return fs
+    def fs(self, request):
+        if not hasattr(request, "param"):
+            setattr(request, "param", {})
+        return S3FileSystem(connect(), **request.param)
 
     @pytest.mark.parametrize(
-        ["start", "end", "batch_mode", "target_data"],
+        ["fs", "start", "end", "target_data"],
         list(
             chain(
                 *[
                     [
-                        (0, 5, x, b"01234"),
-                        (2, 7, x, b"23456"),
-                        (0, 10, x, b"0123456789"),
+                        ({"default_block_size": x}, 0, 5, b"01234"),
+                        ({"default_block_size": x}, 2, 7, b"23456"),
+                        ({"default_block_size": x}, 0, 10, b"0123456789"),
                     ]
-                    for x in ("default", "small_batches")
+                    for x in (S3FileSystem.DEFAULT_BLOCK_SIZE, 3)
                 ]
             )
         ),
+        indirect=["fs"],
     )
-    def test_read(self, fs, start, end, batch_mode, target_data):
+    def test_read(self, fs, start, end, target_data):
         # lowest level access: use _get_object
-        data = fs[batch_mode]._get_object(
+        data = fs._get_object(
             ENV.s3_staging_bucket, ENV.s3_filesystem_test_file_key, ranges=(start, end)
         )
         assert data == (start, target_data), data
-        with fs[batch_mode].open(
+        with fs.open(
             f"s3://{ENV.s3_staging_bucket}/{ENV.s3_filesystem_test_file_key}", "rb"
         ) as file:
             # mid-level access: use _fetch_range
@@ -148,7 +155,84 @@ def test_read(self, fs, start, end, batch_mode, target_data):
             data = file.read(end - start)
             assert data == target_data, data
 
-    def test_compatibility_with_s3fs(self):
+    @pytest.mark.parametrize(
+        ["base", "exp"],
+        [
+            (1, 2**10),
+            (10, 2**10),
+            (100, 2**10),
+            (1, 2**20),
+            (10, 2**20),
+            (100, 2**20),
+            (1000, 2**20),
+        ],
+    )
+    def test_write(self, fs, base, exp):
+        data = b"a" * (base * exp)
+        path = f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/{uuid.uuid4()}.dat"
+        with fs.open(path, "wb") as f:
+            f.write(data)
+        with fs.open(path, "rb") as f:
+            actual = f.read()
+        assert len(actual) == len(data)
+        assert actual == data
+
+    @pytest.mark.parametrize(
+        ["base", "exp"],
+        [
+            (1, 2**10),
+            (10, 2**10),
+            (100, 2**10),
+            (1, 2**20),
+            (10, 2**20),
+            (100, 2**20),
+            (1000, 2**20),
+        ],
+    )
+    def test_append(self, fs, base, exp):
+        # TODO: Check the metadata is kept.
+        data = b"a" * (base * exp)
+        path = f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/{uuid.uuid4()}.dat"
+        with fs.open(path, "ab") as f:
+            f.write(data)
+        extra = b"extra"
+        with fs.open(path, "ab") as f:
+            f.write(extra)
+        with fs.open(path, "rb") as f:
+            actual = f.read()
+        assert len(actual) == len(data + extra)
+        assert actual == data + extra
+
+    def test_exists(self, fs):
+        path = f"s3://{ENV.s3_staging_bucket}/{ENV.s3_filesystem_test_file_key}"
+        assert fs.exists(path)
+
+        not_exists_path = (
+            f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/{uuid.uuid4()}"
+        )
+        assert not fs.exists(not_exists_path)
+
+    def test_touch(self, fs):
+        path = f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/{uuid.uuid4()}"
+        assert not fs.exists(path)
+        fs.touch(path)
+        assert fs.exists(path)
+        assert fs.size(path) == 0
+
+        with fs.open(path, "wb") as f:
+            f.write(b"data")
+        assert fs.size(path) == 4
+        fs.touch(path, truncate=True)
+        assert fs.size(path) == 0
+
+        with fs.open(path, "wb") as f:
+            f.write(b"data")
+        assert fs.size(path) == 4
+        with pytest.raises(ValueError):
+            fs.touch(path, truncate=False)
+        assert fs.size(path) == 4
+
+    def test_pandas_read_csv(self):
         import pandas
 
         df = pandas.read_csv(
Expand All @@ -158,6 +242,16 @@ def test_compatibility_with_s3fs(self):
         )
         assert [(row["col"],) for _, row in df.iterrows()] == [(123456789,)]
 
+    def test_pandas_write_csv(self):
+        import pandas
+
+        df = pandas.DataFrame({"a": [1], "b": [2]})
+        path = f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/{uuid.uuid4()}.csv"
+        df.to_csv(path, index=False)
+
+        actual = pandas.read_csv(path)
+        pandas.testing.assert_frame_equal(df, actual)
+
 
 class TestS3File:
     @pytest.mark.parametrize(
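
As test_pandas_write_csv above exercises, the registered implementation also lets pandas write directly to s3:// URLs. A short round-trip sketch under the same assumptions as before (registered s3 scheme, default credentials, placeholder bucket):

import fsspec
import pandas

fsspec.register_implementation("s3", "pyathena.filesystem.s3.S3FileSystem", clobber=True)

df = pandas.DataFrame({"a": [1], "b": [2]})
path = "s3://example-bucket/path/to/example.csv"  # placeholder path
df.to_csv(path, index=False)  # pandas routes s3:// I/O through the registered filesystem

actual = pandas.read_csv(path)
pandas.testing.assert_frame_equal(df, actual)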