Skip to content

Commit

Permalink
add load method to SyftBaseModel
Browse files Browse the repository at this point in the history
  • Loading branch information
abyesilyurt committed Oct 11, 2024
1 parent e80f61a commit 9c57750
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
13 changes: 6 additions & 7 deletions syftbox/client/plugins/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from syftbox.lib import (
DirState,
FileInfo,
PermissionTree,
ResettableTimer,
bintostr,
Expand Down Expand Up @@ -236,7 +235,9 @@ def filter_changes(
return valid_changes, valid_change_files, invalid_changes


def push_changes(client_config: ClientConfig, changes: list[FileChange]):
def push_changes(
client_config: ClientConfig, changes: list[FileChange]
) -> list[FileChange]:
written_changes = []
for change in changes:
try:
Expand Down Expand Up @@ -352,9 +353,7 @@ def get_remote_state(client_config: ClientConfig, sub_path: str):
dir_state = DirState(**state_response["dir_state"])
fix_tree = {}
for key, value in dir_state.tree.items():
fix_tree[key] = (
FileInfo(**value) if isinstance(value, dict) else value
)
fix_tree[key] = value
dir_state.tree = fix_tree
return dir_state
else:
Expand Down Expand Up @@ -428,7 +427,7 @@ def filter_changes_ignore(
return filtered_changes


def sync_up(client_config):
def sync_up(client_config: ClientConfig):
# create a folder to store the change log
change_log_folder = f"{client_config.sync_folder}/{CLIENT_CHANGELOG_FOLDER}"
os.makedirs(change_log_folder, exist_ok=True)
Expand All @@ -453,7 +452,7 @@ def sync_up(client_config):
old_dir_state = DirState.load(dir_filename)
fix_tree = {}
for key, value in old_dir_state.tree.items():
fix_tree[key] = FileInfo(**value) if isinstance(value, dict) else value
fix_tree[key] = value
old_dir_state.tree = fix_tree
except Exception:
pass
Expand Down
7 changes: 7 additions & 0 deletions syftbox/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Optional

from pydantic import BaseModel
from typing_extensions import Self


class SyftBaseModel(BaseModel):
Expand All @@ -16,6 +17,12 @@ def save(self, path: str) -> bool:
f.write(self.model_dump_json())
return self.model_dump(mode="json")

@classmethod
def load(cls, filepath: str) -> Self:
with open(filepath) as f:
data = f.read()
return cls.model_validate_json(data)


class FileChangeKind(Enum):
CREATE: str = "create"
Expand Down

0 comments on commit 9c57750

Please sign in to comment.