Skip to content

Commit

Permalink
Merge pull request #315 from OpenMined/aziz/batch_download_attempt_2
Browse files Browse the repository at this point in the history
batch download at the start
  • Loading branch information
eelcovdw authored Nov 1, 2024
2 parents f7bab05 + fbf0dda commit 10d90c7
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 20 deletions.
35 changes: 15 additions & 20 deletions syftbox/client/plugins/sync/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
import hashlib
import threading
import zipfile
from collections import defaultdict
from enum import Enum
from io import BytesIO
from pathlib import Path
from typing import Dict, Optional
from typing import Optional

import py_fast_rsync
from loguru import logger
Expand All @@ -25,7 +24,7 @@
)
from syftbox.client.plugins.sync.exceptions import FatalSyncError, SyncEnvironmentError
from syftbox.client.plugins.sync.queue import SyncQueue, SyncQueueItem
from syftbox.client.plugins.sync.sync import SyncSide
from syftbox.client.plugins.sync.sync import DatasiteState, SyncSide
from syftbox.lib.lib import Client, SyftPermission
from syftbox.server.sync.hash import hash_file
from syftbox.server.sync.models import FileMetadata
Expand Down Expand Up @@ -79,8 +78,8 @@ def create_local(client: Client, remote_syncstate: FileMetadata):
abs_path.write_bytes(content_bytes)


def create_local_batch(client: Client, remote_syncstates: list[FileMetadata]):
paths = [str(remote_syncstate.data.path) for remote_syncstate in remote_syncstates]
def create_local_batch(client: Client, remote_syncstates: list[Path]):
    """Download many remote files in one request and unpack them locally.

    The server bundles every requested path into a single zip archive;
    extracting that archive into the client's sync folder materializes
    all files in one pass instead of one request per file.
    """
    requested = [str(remote_path) for remote_path in remote_syncstates]
    archive_bytes = download_bulk(client.server_client, requested)
    archive = zipfile.ZipFile(BytesIO(archive_bytes))
    archive.extractall(client.sync_folder)
Expand Down Expand Up @@ -405,14 +404,9 @@ def validate_sync_environment(self):
raise SyncEnvironmentError("Your previous sync state has been deleted by a different process.")

def consume_all(self):
batched_items: Dict[SyncActionType, list[SyncQueueItem]] = defaultdict(list)
while not self.queue.empty():
self.validate_sync_environment()
item = self.queue.get(timeout=0.1)
if self.get_decisions(item).local_decision.action_type == SyncActionType.CREATE_LOCAL:
batched_items[SyncActionType.CREATE_LOCAL].append(item)
continue

try:
self.process_filechange(item)
except FatalSyncError as e:
Expand All @@ -421,17 +415,18 @@ def consume_all(self):
except Exception as e:
logger.exception(f"Failed to sync file {item.data.path}. Reason: {e}")

download_items = batched_items[SyncActionType.CREATE_LOCAL]
if download_items:
self.batch_download(download_items)

def batch_download(self, download_items: list[SyncQueueItem]):
self.validate_sync_environment()
create_local_batch(self.client, download_items)
for item in download_items:
def download_all_missing(self, datasite_states: list[DatasiteState]):
    """Bulk-download every remote file that has no recorded previous state.

    Walks the remote state of each datasite, collects the paths that are
    absent from ``self.previous_state`` (i.e. never synced locally), fetches
    them all in one batched request, then records each file's current local
    state so subsequent sync passes treat them as already downloaded.

    Args:
        datasite_states: remote-state snapshots, one per datasite.
    """
    missing_files: list[Path] = []
    for datasite_state in datasite_states:
        for file in datasite_state.remote_state:
            path = file.path
            # No previous state means the file was never synced locally.
            if not self.previous_state.states.get(path):
                missing_files.append(path)
    # Skip the server round-trip entirely when nothing is missing.
    if not missing_files:
        return
    create_local_batch(self.client, missing_files)
    for path in missing_files:
        self.previous_state.insert(
            path=path,
            # NOTE(review): presumably the freshly-downloaded file's hash
            # state — confirm against get_current_local_syncstate.
            state=self.get_current_local_syncstate(path),
        )

def get_decisions(self, item: SyncQueueItem) -> SyncDecisionTuple:
Expand Down
9 changes: 9 additions & 0 deletions syftbox/client/plugins/sync/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(self, client: Client):
self.sync_interval = 1 # seconds
self.thread: Optional[Thread] = None
self.is_stop_requested = False
self.sync_run_once = False

def is_alive(self) -> bool:
    """Return True iff a sync worker thread exists and is still running."""
    if self.thread is None:
        return False
    return self.thread.is_alive()
Expand Down Expand Up @@ -89,8 +90,16 @@ def run_single_thread(self):
logger.info(f"Syncing {len(datasite_states)} datasites")
logger.debug(f"Datasites: {', '.join([datasite.email for datasite in datasite_states])}")

if not self.sync_run_once:
# Download all missing files at the start
self.consumer.download_all_missing(
datasite_states=datasite_states,
)

for datasite_state in datasite_states:
self.enqueue_datasite_changes(datasite_state)

# TODO stop consumer if self.is_stop_requested
self.consumer.consume_all()

self.sync_run_once = True

0 comments on commit 10d90c7

Please sign in to comment.