Skip to content

Commit

Permalink
Merge pull request #317 from OpenMined/aziz/bug_fix
Browse files Browse the repository at this point in the history
fix a bug
  • Loading branch information
abyesilyurt authored Nov 1, 2024
2 parents 3de9bb5 + 9edddd0 commit 4b1b0fa
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 62 deletions.
4 changes: 1 addition & 3 deletions syftbox/client/plugins/sync/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,7 @@ def get_metadata(client: httpx.Client, path: Path) -> FileMetadata:

response_data = handle_json_response("/sync/get_metadata", response)

if len(response_data) == 0:
raise SyftNotFound(f"[/sync/get_metadata] not found on server: {path}")
return FileMetadata(**response_data[0])
return FileMetadata(**response_data)


def get_diff(client: httpx.Client, path: Path, signature: bytes) -> DiffResponse:
Expand Down
23 changes: 21 additions & 2 deletions syftbox/server/sync/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,12 @@ def get_all_metadata(conn: sqlite3.Connection, path_like: Optional[str] = None)
params = ()

if path_like:
query += " WHERE path LIKE ?"
params = (path_like,)
if "%" in path_like:
raise ValueError("we don't support % in paths")
path_like = path_like + "%"
escaped_path = path_like.replace("_", "\\_")
query += " WHERE path LIKE ? ESCAPE '\\' "
params = (escaped_path,)

cursor = conn.execute(query, params)
# would be nice to paginate
Expand All @@ -84,6 +88,21 @@ def get_all_metadata(conn: sqlite3.Connection, path_like: Optional[str] = None)
]


def get_one_metadata(conn: sqlite3.Connection, path: str) -> FileMetadata:
cursor = conn.execute("SELECT * FROM file_metadata WHERE path = ?", (path,))
rows = cursor.fetchall()
if len(rows) == 0 or len(rows) > 1:
raise ValueError(f"Expected 1 metadata entry for {path}, got {len(rows)}")
row = rows[0]
return FileMetadata(
path=row[1],
hash=row[2],
signature=row[3],
file_size=row[4],
last_modified=row[5],
)


def get_all_datasites(conn: sqlite3.Connection) -> list[str]:
# INSTR(path, '/'): Finds the position of the first slash in the path.
cursor = conn.execute(
Expand Down
81 changes: 41 additions & 40 deletions syftbox/server/sync/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
get_all_datasites,
get_all_metadata,
get_db,
get_one_metadata,
move_with_transaction,
save_file_metadata,
)
Expand Down Expand Up @@ -46,7 +47,10 @@ def get_file_metadata(
) -> list[FileMetadata]:
# TODO check permissions

return get_all_metadata(conn, path_like=req.path_like)
try:
return get_one_metadata(conn, path=req.path_like)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))


router = APIRouter(prefix="/sync", tags=["sync"])
Expand All @@ -58,13 +62,11 @@ def get_diff(
conn: sqlite3.Connection = Depends(get_db_connection),
server_settings: ServerSettings = Depends(get_server_settings),
) -> DiffResponse:
metadata_list = get_all_metadata(conn, path_like=f"{req.path}")
if len(metadata_list) == 0:
raise HTTPException(status_code=404, detail="path not found")
elif len(metadata_list) > 1:
raise HTTPException(status_code=400, detail="too many files to get diff")
try:
metadata = get_one_metadata(conn, path=f"{req.path}")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))

metadata = metadata_list[0]
abs_path = server_settings.snapshot_folder / metadata.path
with open(abs_path, "rb") as f:
data = f.read()
Expand Down Expand Up @@ -107,24 +109,24 @@ def dir_state(
if dir.is_absolute():
raise HTTPException(status_code=400, detail="dir must be relative")

metadata = get_all_metadata(conn, path_like=f"{dir.as_posix()}%")
metadata = get_all_metadata(conn, path_like=f"{dir.as_posix()}")
full_path = server_settings.snapshot_folder / dir
# get the top level perm file
try:
perm_tree = PermissionTree.from_path(full_path, raise_on_corrupted_files=True)
except Exception as e:
logger.exception(f"Failed to parse permission tree: {dir}")
logger.warning(f"Failed to parse permission tree: {dir}")
raise e

# filter the read state for this user by the perm tree
filtered_metadata = filter_metadata(email, metadata, perm_tree, server_settings.snapshot_folder)
return filtered_metadata


@router.post("/get_metadata", response_model=list[FileMetadata])
@router.post("/get_metadata", response_model=FileMetadata)
def get_metadata(
metadata: list[FileMetadata] = Depends(get_file_metadata),
) -> list[FileMetadata]:
metadata: FileMetadata = Depends(get_file_metadata),
) -> FileMetadata:
return metadata


Expand All @@ -134,14 +136,10 @@ def apply_diffs(
conn: sqlite3.Connection = Depends(get_db_connection),
server_settings: ServerSettings = Depends(get_server_settings),
) -> ApplyDiffResponse:
metadata_list = get_all_metadata(conn, path_like=f"{req.path}")

if len(metadata_list) == 0:
raise HTTPException(status_code=404, detail="path not found")
elif len(metadata_list) > 1:
raise HTTPException(status_code=400, detail="found too many files to apply diff")

metadata = metadata_list[0]
try:
metadata = get_one_metadata(conn, path=f"{req.path}")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))

abs_path = server_settings.snapshot_folder / metadata.path
with open(abs_path, "rb") as f:
Expand Down Expand Up @@ -177,13 +175,10 @@ def delete_file(
conn: sqlite3.Connection = Depends(get_db_connection),
server_settings: ServerSettings = Depends(get_server_settings),
) -> JSONResponse:
metadata_list = get_all_metadata(conn, path_like=f"{req.path}")
if len(metadata_list) == 0:
raise HTTPException(status_code=404, detail="path not found")
elif len(metadata_list) > 1:
raise HTTPException(status_code=400, detail="too many files to delete")

metadata = metadata_list[0]
try:
metadata = get_one_metadata(conn, path=f"{req.path}")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))

try:
delete_file_metadata(conn, metadata.path.as_posix())
Expand All @@ -201,7 +196,9 @@ def create_file(
conn: sqlite3.Connection = Depends(get_db_connection),
server_settings: ServerSettings = Depends(get_server_settings),
) -> JSONResponse:
#
if "%" in file.filename:
raise HTTPException(status_code=400, detail="filename cannot contain '%'")

relative_path = Path(file.filename)
abs_path = server_settings.snapshot_folder / relative_path

Expand All @@ -217,9 +214,14 @@ def create_file(
f.write(contents)

cursor = conn.cursor()
metadata = get_all_metadata(cursor, path_like=f"{file.filename}")
if len(metadata) > 0:
try:
get_one_metadata(cursor, path=f"{file.filename}")
raise HTTPException(status_code=400, detail="file already exists")
except ValueError:
# this is ok, there should be no metadata in db
pass

# create a new metadata for db entry
metadata = hash_file(abs_path, root_dir=server_settings.snapshot_folder)
save_file_metadata(cursor, metadata)
conn.commit()
Expand All @@ -234,13 +236,11 @@ def download_file(
conn: sqlite3.Connection = Depends(get_db_connection),
server_settings: ServerSettings = Depends(get_server_settings),
) -> FileResponse:
metadata_list = get_all_metadata(conn, path_like=f"{req.path}")
if len(metadata_list) == 0:
raise HTTPException(status_code=404, detail="path not found")
elif len(metadata_list) > 1:
raise HTTPException(status_code=400, detail="too many files to download")
try:
metadata = get_one_metadata(conn, path=f"{req.path}")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))

metadata = metadata_list[0]
abs_path = server_settings.snapshot_folder / metadata.path
if not Path(abs_path).exists():
# could be a stale db entry, remove from db
Expand Down Expand Up @@ -273,11 +273,12 @@ async def get_files(
) -> StreamingResponse:
all_metadata = []
for path in req.paths:
metadata_list = get_all_metadata(conn, path_like=f"{path}")
if len(metadata_list) != 1:
logger.warning(f"Expected 1 metadata, got {len(metadata_list)} for {path}")
try:
metadata = get_one_metadata(conn, path=path)
except ValueError as e:
logger.warning(str(e))
continue
metadata = metadata_list[0]

abs_path = server_settings.snapshot_folder / metadata.path
if not Path(abs_path).exists() or not Path(abs_path).is_file():
logger.warning(f"File not found: {abs_path}")
Expand Down
21 changes: 4 additions & 17 deletions tests/unit/server/sync_endpoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from py_fast_rsync import signature

from syftbox.client.plugins.sync.endpoints import (
SyftNotFound,
SyftServerError,
apply_diff,
download_bulk,
Expand All @@ -24,17 +23,6 @@
from tests.unit.server.conftest import PERMFILE_FILE, TEST_DATASITE_NAME, TEST_FILE


def test_get_all_permissions(client: TestClient):
# TODO: filter permissions and not return everything
response = client.post(
"/sync/get_metadata",
json={"path_like": "%.syftperm"},
)

response.raise_for_status()
assert len(response.json())


def test_get_diff_2(client: TestClient):
local_data = b"This is my local data"
sig = signature.calculate(local_data)
Expand Down Expand Up @@ -72,11 +60,11 @@ def file_digest(file_path, algorithm="sha256"):
def test_syft_client_push_flow(client: TestClient):
response = client.post(
"/sync/get_metadata",
json={"path_like": f"%{TEST_DATASITE_NAME}/{TEST_FILE}"},
json={"path_like": f"{TEST_DATASITE_NAME}/{TEST_FILE}"},
)

response.raise_for_status()
server_signature_b85 = response.json()[0]["signature"]
server_signature_b85 = response.json()["signature"]
server_signature = base64.b85decode(server_signature_b85)
assert server_signature

Expand Down Expand Up @@ -167,9 +155,8 @@ def test_get_diff(client: TestClient):

# diff nonexistent file
file_path = Path(TEST_DATASITE_NAME) / "nonexistent_file.txt"
with pytest.raises(SyftServerError) as e:
with pytest.raises(SyftServerError):
get_diff(client, file_path, sig)
assert "path not found" in str(e.value)


def test_delete_file(client: TestClient):
Expand All @@ -183,7 +170,7 @@ def test_delete_file(client: TestClient):
path = Path(f"{snapshot_folder}/{TEST_DATASITE_NAME}/{TEST_FILE}")
assert not path.exists()

with pytest.raises(SyftNotFound):
with pytest.raises(SyftServerError):
get_metadata(client, Path(TEST_DATASITE_NAME) / TEST_FILE)


Expand Down

0 comments on commit 4b1b0fa

Please sign in to comment.