Support for writing with s3 file system (#539)
* Support for writing with s3 filesystem (fix #465)

* Fix test cases

* Apply fmt

* Remove unnecessary vars

* Fix fixture args

* Add test cases for file writing

* Add type hints
laughingman7743 authored May 2, 2024
1 parent ee7748e commit 4fbced8
Showing 9 changed files with 1,486 additions and 409 deletions.
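
Taken together, the change set makes PyAthena's S3FileSystem usable as a writable fsspec filesystem. A minimal sketch of what the commit enables (the bucket and key below are placeholders; the tests construct the filesystem from a PyAthena connection, while a bare construction picking up default AWS credentials is assumed here):

from pyathena.filesystem.s3 import S3FileSystem

fs = S3FileSystem()  # the tests use S3FileSystem(connect()); no-arg construction is an assumption
path = "s3://my-bucket/path/to/out.dat"  # hypothetical location

with fs.open(path, "wb") as f:
    f.write(b"hello, s3")

with fs.open(path, "rb") as f:
    assert f.read() == b"hello, s3"
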
610 changes: 485 additions & 125 deletions pyathena/filesystem/s3.py

Large diffs are not rendered by default.

438 changes: 408 additions & 30 deletions pyathena/filesystem/s3_object.py

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion pyproject.toml
@@ -136,7 +136,6 @@ profile_file = "tests/sqlalchemy/profiles.txt"
line-length = 100
exclude = [
".venv",
"tests"
]
target-version = "py38"

126 changes: 110 additions & 16 deletions tests/pyathena/filesystem/test_s3.py
@@ -1,14 +1,22 @@
# -*- coding: utf-8 -*-
import uuid
from itertools import chain
from typing import Dict

import fsspec
import pytest

from pyathena.filesystem.s3 import S3File, S3FileSystem
from tests import ENV
from tests.pyathena.conftest import connect


@pytest.fixture(scope="class")
def register_filesystem():
    fsspec.register_implementation("s3", "pyathena.filesystem.s3.S3FileSystem", clobber=True)
    fsspec.register_implementation("s3a", "pyathena.filesystem.s3.S3FileSystem", clobber=True)


@pytest.mark.usefixtures("register_filesystem")
class TestS3FileSystem:
    def test_parse_path(self):
        actual = S3FileSystem.parse_path("s3://bucket")
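
The new register_filesystem fixture points fsspec's s3 and s3a schemes at PyAthena's implementation, clobbering whatever was registered before (typically s3fs), so every URL-based open in this test class resolves to pyathena.filesystem.s3.S3FileSystem. The effect, sketched outside the test suite:

import fsspec

from pyathena.filesystem.s3 import S3FileSystem

# clobber=True replaces any implementation already registered for the scheme.
fsspec.register_implementation("s3", "pyathena.filesystem.s3.S3FileSystem", clobber=True)

fs = fsspec.filesystem("s3")  # instantiates the class registered above
assert isinstance(fs, S3FileSystem)
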
@@ -109,35 +117,34 @@ def test_parse_path_invalid(self):
            S3FileSystem.parse_path("s3a://bucket/path/to/obj?foo=bar")

    @pytest.fixture(scope="class")
    def fs(self) -> Dict[str, S3FileSystem]:
        fs = {
            "default": S3FileSystem(connect()),
            "small_batches": S3FileSystem(connect(), default_block_size=3),
        }
        return fs
    def fs(self, request):
        if not hasattr(request, "param"):
            setattr(request, "param", {})
        return S3FileSystem(connect(), **request.param)

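The rewritten fs fixture reads constructor kwargs from request.param, so each case can select a block size through indirect parametrization instead of the old string-keyed dict. The pattern in isolation (the returned dict below is a stand-in for S3FileSystem(connect(), **kwargs)):

import pytest


@pytest.fixture
def fs(request):
    # Cases that don't parametrize this fixture provide no request.param at all.
    kwargs = getattr(request, "param", {})
    return dict(kwargs)  # stand-in for S3FileSystem(connect(), **kwargs)


@pytest.mark.parametrize("fs", [{"default_block_size": 3}], indirect=["fs"])
def test_block_size(fs):
    assert fs["default_block_size"] == 3
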
    @pytest.mark.parametrize(
        ["start", "end", "batch_mode", "target_data"],
        ["fs", "start", "end", "target_data"],
        list(
            chain(
                *[
                    [
                        (0, 5, x, b"01234"),
                        (2, 7, x, b"23456"),
                        (0, 10, x, b"0123456789"),
                        ({"default_block_size": x}, 0, 5, b"01234"),
                        ({"default_block_size": x}, 2, 7, b"23456"),
                        ({"default_block_size": x}, 0, 10, b"0123456789"),
                    ]
                    for x in ("default", "small_batches")
                    for x in (S3FileSystem.DEFAULT_BLOCK_SIZE, 3)
                ]
            )
        ),
        indirect=["fs"],
    )
    def test_read(self, fs, start, end, batch_mode, target_data):
    def test_read(self, fs, start, end, target_data):
        # lowest level access: use _get_object
        data = fs[batch_mode]._get_object(
        data = fs._get_object(
            ENV.s3_staging_bucket, ENV.s3_filesystem_test_file_key, ranges=(start, end)
        )
        assert data == (start, target_data), data
        with fs[batch_mode].open(
        with fs.open(
            f"s3://{ENV.s3_staging_bucket}/{ENV.s3_filesystem_test_file_key}", "rb"
        ) as file:
            # mid-level access: use _fetch_range
@@ -148,7 +155,84 @@ def test_read(self, fs, start, end, batch_mode, target_data):
            data = file.read(end - start)
            assert data == target_data, data

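test_read now exercises three layers with the same expectations: the raw _get_object call, _fetch_range via seeks, and the buffered read. Note that the (start, end) range is end-exclusive, as the (2, 7, b"23456") case shows. The equivalent seek/read pattern against fsspec's in-memory filesystem (used here so the snippet runs without S3):

import fsspec

with fsspec.open("memory://demo.dat", "wb") as f:
    f.write(b"0123456789")

with fsspec.open("memory://demo.dat", "rb") as f:
    f.seek(2)
    assert f.read(7 - 2) == b"23456"  # bytes [2, 7), matching the (2, 7) case above
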
    def test_compatibility_with_s3fs(self):
    @pytest.mark.parametrize(
        ["base", "exp"],
        [
            (1, 2**10),
            (10, 2**10),
            (100, 2**10),
            (1, 2**20),
            (10, 2**20),
            (100, 2**20),
            (1000, 2**20),
        ],
    )
    def test_write(self, fs, base, exp):
        data = b"a" * (base * exp)
        path = f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/{uuid.uuid4()}.dat"
        with fs.open(path, "wb") as f:
            f.write(data)
        with fs.open(path, "rb") as f:
            actual = f.read()
        assert len(actual) == len(data)
        assert actual == data

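The write cases run from 1 KiB up to 1000 MiB, which brackets S3's 5 MiB minimum multipart part size: small payloads can go out as a single PutObject, while larger ones must be flushed in parts. A sketch of the buffering such a file object typically uses (an assumption modeled on fsspec's AbstractBufferedFile, not a claim about PyAthena's exact internals):

from typing import List


class BufferedS3Writer:
    """Illustrative sketch: buffer writes and cut multipart parts at the block size."""

    MIN_PART_SIZE = 5 * 2**20  # S3 parts must be at least 5 MiB, except the last

    def __init__(self, blocksize: int = MIN_PART_SIZE) -> None:
        self.blocksize = max(blocksize, self.MIN_PART_SIZE)
        self.buffer = bytearray()
        self.parts: List[bytes] = []

    def write(self, data: bytes) -> int:
        self.buffer += data
        while len(self.buffer) >= self.blocksize:
            # A real filesystem would call UploadPart here.
            self.parts.append(bytes(self.buffer[: self.blocksize]))
            del self.buffer[: self.blocksize]
        return len(data)

    def close(self) -> None:
        if self.parts:
            self.parts.append(bytes(self.buffer))  # final part may be under 5 MiB
        # else: a single PutObject of the whole buffer suffices
        self.buffer.clear()
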
    @pytest.mark.parametrize(
        ["base", "exp"],
        [
            (1, 2**10),
            (10, 2**10),
            (100, 2**10),
            (1, 2**20),
            (10, 2**20),
            (100, 2**20),
            (1000, 2**20),
        ],
    )
    def test_append(self, fs, base, exp):
        # TODO: Check the metadata is kept.
        data = b"a" * (base * exp)
        path = f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/{uuid.uuid4()}.dat"
        with fs.open(path, "ab") as f:
            f.write(data)
        extra = b"extra"
        with fs.open(path, "ab") as f:
            f.write(extra)
        with fs.open(path, "rb") as f:
            actual = f.read()
        assert len(actual) == len(data + extra)
        assert actual == data + extra

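Append gets its own test because S3 objects are immutable: an "ab" open cannot map to a native append, so the implementation presumably rewrites the object from the old content plus the new bytes, which is also why the TODO about preserving metadata is there. From the caller's side the mode behaves like any other fsspec append (in-memory filesystem used so the sketch runs without S3):

import fsspec

with fsspec.open("memory://log.dat", "wb") as f:
    f.write(b"first ")
with fsspec.open("memory://log.dat", "ab") as f:  # re-opens and extends the object
    f.write(b"second")

with fsspec.open("memory://log.dat", "rb") as f:
    assert f.read() == b"first second"
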
    def test_exists(self, fs):
        path = f"s3://{ENV.s3_staging_bucket}/{ENV.s3_filesystem_test_file_key}"
        assert fs.exists(path)

        not_exists_path = (
            f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/{uuid.uuid4()}"
        )
        assert not fs.exists(not_exists_path)

    def test_touch(self, fs):
        path = f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/{uuid.uuid4()}"
        assert not fs.exists(path)
        fs.touch(path)
        assert fs.exists(path)
        assert fs.size(path) == 0

        with fs.open(path, "wb") as f:
            f.write(b"data")
        assert fs.size(path) == 4
        fs.touch(path, truncate=True)
        assert fs.size(path) == 0

        with fs.open(path, "wb") as f:
            f.write(b"data")
        assert fs.size(path) == 4
        with pytest.raises(ValueError):
            fs.touch(path, truncate=False)
        assert fs.size(path) == 4

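touch therefore covers three behaviors: create an empty object, empty an existing one when truncate=True, and raise ValueError when truncate=False and the object already exists, rather than silently clobbering it. The first two match fsspec's generic touch contract, demonstrated here against the in-memory filesystem:

import fsspec

fs = fsspec.filesystem("memory")
fs.pipe("/exists.dat", b"data")         # create a 4-byte file
assert fs.size("/exists.dat") == 4
fs.touch("/exists.dat", truncate=True)  # re-create it empty
assert fs.size("/exists.dat") == 0
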
    def test_pandas_read_csv(self):
        import pandas

        df = pandas.read_csv(
Expand All @@ -158,6 +242,16 @@ def test_compatibility_with_s3fs(self):
        )
        assert [(row["col"],) for _, row in df.iterrows()] == [(123456789,)]

    def test_pandas_write_csv(self):
        import pandas

        df = pandas.DataFrame({"a": [1], "b": [2]})
        path = f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/{uuid.uuid4()}.csv"
        df.to_csv(path, index=False)

        actual = pandas.read_csv(path)
        pandas.testing.assert_frame_equal(df, actual)

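Because the class-scoped fixture registered PyAthena's implementation for the s3 scheme, pandas' fsspec integration routes to_csv and read_csv through it directly, with no s3fs involved. Outside the test suite the registration has to happen first (the bucket below is a placeholder):

import fsspec
import pandas

fsspec.register_implementation("s3", "pyathena.filesystem.s3.S3FileSystem", clobber=True)

df = pandas.DataFrame({"a": [1], "b": [2]})
# Hypothetical bucket; pandas resolves the s3:// URL via the registered filesystem.
df.to_csv("s3://my-bucket/tmp/example.csv", index=False)
roundtripped = pandas.read_csv("s3://my-bucket/tmp/example.csv")
pandas.testing.assert_frame_equal(df, roundtripped)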

class TestS3File:
    @pytest.mark.parametrize(
(Diffs for the remaining five changed files are not rendered.)
