diff --git a/syftbox/client/plugins/sync/endpoints.py b/syftbox/client/plugins/sync/endpoints.py index 8b5a34e9..64ca01ad 100644 --- a/syftbox/client/plugins/sync/endpoints.py +++ b/syftbox/client/plugins/sync/endpoints.py @@ -73,9 +73,7 @@ def get_metadata(client: httpx.Client, path: Path) -> FileMetadata: response_data = handle_json_response("/sync/get_metadata", response) - if len(response_data) == 0: - raise SyftNotFound(f"[/sync/get_metadata] not found on server: {path}") - return FileMetadata(**response_data[0]) + return FileMetadata(**response_data) def get_diff(client: httpx.Client, path: Path, signature: bytes) -> DiffResponse: diff --git a/syftbox/server/sync/db.py b/syftbox/server/sync/db.py index 0ce09625..a37c45b1 100644 --- a/syftbox/server/sync/db.py +++ b/syftbox/server/sync/db.py @@ -67,8 +67,12 @@ def get_all_metadata(conn: sqlite3.Connection, path_like: Optional[str] = None) params = () if path_like: - query += " WHERE path LIKE ?" - params = (path_like,) + if "%" in path_like: + raise ValueError("we don't support % in paths") + path_like = path_like + "%" + escaped_path = path_like.replace("_", "\\_") + query += " WHERE path LIKE ? ESCAPE '\\' " + params = (escaped_path,) cursor = conn.execute(query, params) # would be nice to paginate @@ -84,6 +88,21 @@ def get_all_metadata(conn: sqlite3.Connection, path_like: Optional[str] = None) ] +def get_one_metadata(conn: sqlite3.Connection, path: str) -> FileMetadata: + cursor = conn.execute("SELECT * FROM file_metadata WHERE path = ?", (path,)) + rows = cursor.fetchall() + if len(rows) == 0 or len(rows) > 1: + raise ValueError(f"Expected 1 metadata entry for {path}, got {len(rows)}") + row = rows[0] + return FileMetadata( + path=row[1], + hash=row[2], + signature=row[3], + file_size=row[4], + last_modified=row[5], + ) + + def get_all_datasites(conn: sqlite3.Connection) -> list[str]: # INSTR(path, '/'): Finds the position of the first slash in the path. cursor = conn.execute( diff --git a/syftbox/server/sync/router.py b/syftbox/server/sync/router.py index b8bbd6b0..e813f810 100644 --- a/syftbox/server/sync/router.py +++ b/syftbox/server/sync/router.py @@ -17,6 +17,7 @@ get_all_datasites, get_all_metadata, get_db, + get_one_metadata, move_with_transaction, save_file_metadata, ) @@ -46,7 +47,10 @@ def get_file_metadata( ) -> list[FileMetadata]: # TODO check permissions - return get_all_metadata(conn, path_like=req.path_like) + try: + return get_one_metadata(conn, path=req.path_like) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) router = APIRouter(prefix="/sync", tags=["sync"]) @@ -58,13 +62,11 @@ def get_diff( conn: sqlite3.Connection = Depends(get_db_connection), server_settings: ServerSettings = Depends(get_server_settings), ) -> DiffResponse: - metadata_list = get_all_metadata(conn, path_like=f"{req.path}") - if len(metadata_list) == 0: - raise HTTPException(status_code=404, detail="path not found") - elif len(metadata_list) > 1: - raise HTTPException(status_code=400, detail="too many files to get diff") + try: + metadata = get_one_metadata(conn, path=f"{req.path}") + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) - metadata = metadata_list[0] abs_path = server_settings.snapshot_folder / metadata.path with open(abs_path, "rb") as f: data = f.read() @@ -107,13 +109,13 @@ def dir_state( if dir.is_absolute(): raise HTTPException(status_code=400, detail="dir must be relative") - metadata = get_all_metadata(conn, path_like=f"{dir.as_posix()}%") + metadata = get_all_metadata(conn, path_like=f"{dir.as_posix()}") full_path = server_settings.snapshot_folder / dir # get the top level perm file try: perm_tree = PermissionTree.from_path(full_path, raise_on_corrupted_files=True) except Exception as e: - logger.exception(f"Failed to parse permission tree: {dir}") + logger.warning(f"Failed to parse permission tree: {dir}") raise e # filter the read state for this user by the perm tree @@ -121,10 +123,10 @@ def dir_state( return filtered_metadata -@router.post("/get_metadata", response_model=list[FileMetadata]) +@router.post("/get_metadata", response_model=FileMetadata) def get_metadata( - metadata: list[FileMetadata] = Depends(get_file_metadata), -) -> list[FileMetadata]: + metadata: FileMetadata = Depends(get_file_metadata), +) -> FileMetadata: return metadata @@ -134,14 +136,10 @@ def apply_diffs( conn: sqlite3.Connection = Depends(get_db_connection), server_settings: ServerSettings = Depends(get_server_settings), ) -> ApplyDiffResponse: - metadata_list = get_all_metadata(conn, path_like=f"{req.path}") - - if len(metadata_list) == 0: - raise HTTPException(status_code=404, detail="path not found") - elif len(metadata_list) > 1: - raise HTTPException(status_code=400, detail="found too many files to apply diff") - - metadata = metadata_list[0] + try: + metadata = get_one_metadata(conn, path=f"{req.path}") + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) abs_path = server_settings.snapshot_folder / metadata.path with open(abs_path, "rb") as f: @@ -177,13 +175,10 @@ def delete_file( conn: sqlite3.Connection = Depends(get_db_connection), server_settings: ServerSettings = Depends(get_server_settings), ) -> JSONResponse: - metadata_list = get_all_metadata(conn, path_like=f"{req.path}") - if len(metadata_list) == 0: - raise HTTPException(status_code=404, detail="path not found") - elif len(metadata_list) > 1: - raise HTTPException(status_code=400, detail="too many files to delete") - - metadata = metadata_list[0] + try: + metadata = get_one_metadata(conn, path=f"{req.path}") + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) try: delete_file_metadata(conn, metadata.path.as_posix()) @@ -201,7 +196,9 @@ def create_file( conn: sqlite3.Connection = Depends(get_db_connection), server_settings: ServerSettings = Depends(get_server_settings), ) -> JSONResponse: - # + if "%" in file.filename: + raise HTTPException(status_code=400, detail="filename cannot contain '%'") + relative_path = Path(file.filename) abs_path = server_settings.snapshot_folder / relative_path @@ -217,9 +214,14 @@ def create_file( f.write(contents) cursor = conn.cursor() - metadata = get_all_metadata(cursor, path_like=f"{file.filename}") - if len(metadata) > 0: + try: + get_one_metadata(cursor, path=f"{file.filename}") raise HTTPException(status_code=400, detail="file already exists") + except ValueError: + # this is ok, there should be no metadata in db + pass + + # create a new metadata for db entry metadata = hash_file(abs_path, root_dir=server_settings.snapshot_folder) save_file_metadata(cursor, metadata) conn.commit() @@ -234,13 +236,11 @@ def download_file( conn: sqlite3.Connection = Depends(get_db_connection), server_settings: ServerSettings = Depends(get_server_settings), ) -> FileResponse: - metadata_list = get_all_metadata(conn, path_like=f"{req.path}") - if len(metadata_list) == 0: - raise HTTPException(status_code=404, detail="path not found") - elif len(metadata_list) > 1: - raise HTTPException(status_code=400, detail="too many files to download") + try: + metadata = get_one_metadata(conn, path=f"{req.path}") + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) - metadata = metadata_list[0] abs_path = server_settings.snapshot_folder / metadata.path if not Path(abs_path).exists(): # could be a stale db entry, remove from db @@ -273,11 +273,12 @@ async def get_files( ) -> StreamingResponse: all_metadata = [] for path in req.paths: - metadata_list = get_all_metadata(conn, path_like=f"{path}") - if len(metadata_list) != 1: - logger.warning(f"Expected 1 metadata, got {len(metadata_list)} for {path}") + try: + metadata = get_one_metadata(conn, path=path) + except ValueError as e: + logger.warning(str(e)) continue - metadata = metadata_list[0] + abs_path = server_settings.snapshot_folder / metadata.path if not Path(abs_path).exists() or not Path(abs_path).is_file(): logger.warning(f"File not found: {abs_path}") diff --git a/tests/unit/server/sync_endpoint_test.py b/tests/unit/server/sync_endpoint_test.py index eaf2a160..45f98a5a 100644 --- a/tests/unit/server/sync_endpoint_test.py +++ b/tests/unit/server/sync_endpoint_test.py @@ -10,7 +10,6 @@ from py_fast_rsync import signature from syftbox.client.plugins.sync.endpoints import ( - SyftNotFound, SyftServerError, apply_diff, download_bulk, @@ -24,17 +23,6 @@ from tests.unit.server.conftest import PERMFILE_FILE, TEST_DATASITE_NAME, TEST_FILE -def test_get_all_permissions(client: TestClient): - # TODO: filter permissions and not return everything - response = client.post( - "/sync/get_metadata", - json={"path_like": "%.syftperm"}, - ) - - response.raise_for_status() - assert len(response.json()) - - def test_get_diff_2(client: TestClient): local_data = b"This is my local data" sig = signature.calculate(local_data) @@ -72,11 +60,11 @@ def file_digest(file_path, algorithm="sha256"): def test_syft_client_push_flow(client: TestClient): response = client.post( "/sync/get_metadata", - json={"path_like": f"%{TEST_DATASITE_NAME}/{TEST_FILE}"}, + json={"path_like": f"{TEST_DATASITE_NAME}/{TEST_FILE}"}, ) response.raise_for_status() - server_signature_b85 = response.json()[0]["signature"] + server_signature_b85 = response.json()["signature"] server_signature = base64.b85decode(server_signature_b85) assert server_signature @@ -167,9 +155,8 @@ def test_get_diff(client: TestClient): # diff nonexistent file file_path = Path(TEST_DATASITE_NAME) / "nonexistent_file.txt" - with pytest.raises(SyftServerError) as e: + with pytest.raises(SyftServerError): get_diff(client, file_path, sig) - assert "path not found" in str(e.value) def test_delete_file(client: TestClient): @@ -183,7 +170,7 @@ def test_delete_file(client: TestClient): path = Path(f"{snapshot_folder}/{TEST_DATASITE_NAME}/{TEST_FILE}") assert not path.exists() - with pytest.raises(SyftNotFound): + with pytest.raises(SyftServerError): get_metadata(client, Path(TEST_DATASITE_NAME) / TEST_FILE)