From 0381d613d276cb6d79f5f4a66beb23f0b920e524 Mon Sep 17 00:00:00 2001 From: Yash Gorana Date: Sun, 15 Dec 2024 15:38:18 +0530 Subject: [PATCH] move base classes to __init__ --- syftbox/client/base.py | 30 +----------------------- syftbox/client/benchmark/__init__.py | 34 ++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 29 deletions(-) create mode 100644 syftbox/client/benchmark/__init__.py diff --git a/syftbox/client/base.py b/syftbox/client/base.py index 3ebf7b0c..dc262eae 100644 --- a/syftbox/client/base.py +++ b/syftbox/client/base.py @@ -1,8 +1,7 @@ from __future__ import annotations -from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING import httpx from loguru import logger @@ -114,30 +113,3 @@ def from_config(cls, config: SyftClientConfig): headers=cls._make_headers(config), ) return cls(conn) - - -@dataclass -class BaseMetric: - """Base class for all metrics with common fields.""" - - num_runs: int - - -class MetricCollector(Protocol): - """ - Protocol for classes that collect performance metrics. - """ - - client_config: SyftClientConfig - - def collect_metrics(self, num_runs: int) -> BaseMetric: - """Calculate performance metrics.""" - ... - - -class BenchmarkReporter(Protocol): - """Protocol defining the interface for benchmark result reporters.""" - - def generate(self, metrics: dict[str, BaseMetric], report_path: Optional[Path] = None) -> Any: - """Generate the benchmark report.""" - ... diff --git a/syftbox/client/benchmark/__init__.py b/syftbox/client/benchmark/__init__.py new file mode 100644 index 00000000..a3fc6adf --- /dev/null +++ b/syftbox/client/benchmark/__init__.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any, Optional + +from typing_extensions import Protocol + +from syftbox.lib.client_config import SyftClientConfig + + +@dataclass +class BaseMetric: + """Base class for all metrics with common fields.""" + + num_runs: int + + +class MetricCollector(Protocol): + """ + Protocol for classes that collect performance metrics. + """ + + client_config: SyftClientConfig + + def collect_metrics(self, num_runs: int) -> BaseMetric: + """Calculate performance metrics.""" + ... + + +class BenchmarkReporter(Protocol): + """Protocol defining the interface for benchmark result reporters.""" + + def generate(self, metrics: dict[str, BaseMetric], report_path: Optional[Path] = None) -> Any: + """Generate the benchmark report.""" + ...