From d920f8ec57fb2029c398acc11ac0d8155225401c Mon Sep 17 00:00:00 2001
From: Jie Pu <pujie2@huawei.com>
Date: Wed, 25 Aug 2021 19:22:16 +0800
Subject: [PATCH 1/2] add transmitter, client_choose, aggregation interface

1. add transmitter, client_choose, and aggregation interfaces to the lib.
2. add an example of how to use the newly added interfaces
   (a minimal usage sketch follows).
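
A minimal usage sketch of the new interfaces, mirroring the example added under
examples/federated_learning/yolov5_coco128_mistnet in this patch (Dataset and
Estimator are the wrappers defined in that example's interface.py):

    from interface import Dataset, Estimator  # example-defined wrappers
    from sedna.algorithms.aggregation import MistNet
    from sedna.algorithms.client_choose import SimpleClientChoose
    from sedna.core.federated_learning import FederatedLearning

    mistnet = MistNet(cut_layer=4, epsilon=100)     # aggregation algorithm
    chooser = SimpleClientChoose(per_round=1)       # used on the aggregation side
    transmitter = FederatedLearning.get_transmitter_from_config()  # ws or s3

    # training worker side; the cloud side wires the same objects into
    # AggregationServer (see aggregate.py in this patch)
    fl = FederatedLearning(data=Dataset(), estimator=Estimator(),
                           aggregation=mistnet, transmitter=transmitter)
    fl.train()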

Signed-off-by: Jie Pu <pujie2@huawei.com>
Signed-off-by: XinYao1994 <xyao@cs.hku.hk>

---
 .../federatedlearningjob_yolo_v1alpha1.yaml   |  54 ++++
 .../yolov5_coco128_mistnet/README.md          | 238 ++++++++++++++++++
 .../yolov5_coco128_mistnet/aggregate.py       |  35 +++
 .../yolov5_coco128_mistnet/interface.py       | 149 +++++++++++
 .../yolov5_coco128_mistnet/train.py           |  33 +++
 lib/sedna/algorithms/aggregation/__init__.py  |   1 +
 .../algorithms/aggregation/aggregation.py     |  11 +
 .../algorithms/client_choose/__init__.py      |  15 ++
 .../algorithms/client_choose/client_choose.py |  36 +++
 lib/sedna/algorithms/transmitter/__init__.py  |  15 ++
 .../algorithms/transmitter/transmitter.py     |  64 +++++
 lib/sedna/common/config.py                    |   7 +
 .../federated_learning/federated_learning.py  |  79 +++++-
 13 files changed, 732 insertions(+), 5 deletions(-)
 create mode 100644 build/crd-samples/sedna/federatedlearningjob_yolo_v1alpha1.yaml
 create mode 100644 examples/federated_learning/yolov5_coco128_mistnet/README.md
 create mode 100644 examples/federated_learning/yolov5_coco128_mistnet/aggregate.py
 create mode 100644 examples/federated_learning/yolov5_coco128_mistnet/interface.py
 create mode 100644 examples/federated_learning/yolov5_coco128_mistnet/train.py
 create mode 100644 lib/sedna/algorithms/client_choose/__init__.py
 create mode 100644 lib/sedna/algorithms/client_choose/client_choose.py
 create mode 100644 lib/sedna/algorithms/transmitter/__init__.py
 create mode 100644 lib/sedna/algorithms/transmitter/transmitter.py

diff --git a/build/crd-samples/sedna/federatedlearningjob_yolo_v1alpha1.yaml b/build/crd-samples/sedna/federatedlearningjob_yolo_v1alpha1.yaml
new file mode 100644
index 000000000..9864951a9
--- /dev/null
+++ b/build/crd-samples/sedna/federatedlearningjob_yolo_v1alpha1.yaml
@@ -0,0 +1,54 @@
+apiVersion: sedna.io/v1alpha1
+kind: FederatedLearningJob
+metadata:
+  name: yolo-v5
+spec:
+  pretrainedModel: # option
+    name: "yolo-v5-pretrained-model"
+  transimitter: # option
+    ws: { } # option, by default
+    s3: # option, but at least one
+      aggDataPath: "s3://sedna/fl/aggregation_data"
+      credentialName: mysecret
+  aggregationWorker:
+    model:
+      name: "yolo-v5-model"
+    template:
+      spec:
+        nodeName: "sedna-control-plane"
+        containers:
+          - image: kubeedge/sedna-fl-aggregation:mistnetyolo
+            name: agg-worker
+            imagePullPolicy: IfNotPresent
+            env: # user defined environments
+              - name: "cut_layer"
+                value: "4"
+              - name: "epsilon"
+                value: "100"
+              - name: "aggregation_algorithm"
+                value: "mistnet"
+              - name: "batch_size"
+            resources: # user defined resources
+              limits:
+                memory: 8Gi
+  trainingWorkers:
+    - dataset:
+        name: "coco-dataset"
+      template:
+        spec:
+          nodeName: "edge-node"
+          containers:
+            - image: kubeedge/sedna-fl-train:mistnetyolo
+              name: train-worker
+              imagePullPolicy: IfNotPresent
+              args: [ "-i", "1" ]
+              env: # user defined environments
+                - name: "batch_size"
+                  value: "32"
+                - name: "learning_rate"
+                  value: "0.001"
+                - name: "epochs"
+                  value: "1"
+              resources: # user defined resources
+                limits:
+                  memory: 2Gi
\ No newline at end of file
diff --git a/examples/federated_learning/yolov5_coco128_mistnet/README.md b/examples/federated_learning/yolov5_coco128_mistnet/README.md
new file mode 100644
index 000000000..509734ae8
--- /dev/null
+++ b/examples/federated_learning/yolov5_coco128_mistnet/README.md
@@ -0,0 +1,238 @@
+# Collaboratively Train Yolo-v5 Using MistNet on COCO128 Dataset
+
+This case introduces how to train a federated learning job with an aggregation algorithm named MistNet in an object
+detection scenario on the COCO128 dataset. Data is scattered in different places (such as edge nodes, cameras, and
+others) and cannot be aggregated at the server due to data privacy and bandwidth limits. As a result, we cannot use
+all the data for training. In some cases, edge nodes have limited computing resources and even have no training
+capability, so they cannot produce updated weights through local training. Therefore, traditional algorithms (e.g.,
+federated averaging), which aggregate the updated weights trained by different edge clients, cannot work in this
+scenario. MistNet is proposed to address this issue.
+
+MistNet partitions a DNN model into two parts, a lightweight feature extractor at the edge side to generate meaningful
+features from the raw data, and a classifier containing most of the model layers at the cloud to be iteratively trained for
+specific tasks. MistNet achieves acceptable model utility while greatly reducing privacy leakage from the released
+intermediate features.
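+
+The `cut_layer` parameter used by the workers below selects where this partition happens. A conceptual sketch
+(illustrative only, not part of the Sedna API; layer names are hypothetical):
+
+```python
+# illustrative only: layer names are hypothetical
+def split_model(layers, cut_layer):
+    """The edge keeps layers[:cut_layer] as the feature extractor; the cloud trains the rest."""
+    return layers[:cut_layer], layers[cut_layer:]
+
+layers = ["focus", "conv1", "c3_1", "conv2", "c3_2", "detect_head"]
+edge_extractor, cloud_layers = split_model(layers, cut_layer=4)
+```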
+
+## Object Detection Experiment
+
+> Assume that there are two edge nodes and a cloud node. Data on the edge nodes cannot be migrated to the cloud due to privacy issues.
+> Based on this scenario, we will demonstrate the COCO128 object detection example.
+
+### Prepare Nodes
+
+```
+CLOUD_NODE="cloud-node-name"
+EDGE1_NODE="edge1-node-name"
+EDGE2_NODE="edge2-node-name"
+```
+
+### Install Sedna
+
+Follow the [Sedna installation document](/docs/setup/install.md) to install Sedna.
+
+### Prepare Dataset
+
+Download [dataset](https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip) and do data partition
+
+```
+wget https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip
+unzip coco128.zip -d data
+rm coco128.zip
+python partition.py ./data 2
+```
+
+move ```./data/1``` to `/data` of ```EDGE1_NODE```.
+
+```
+mkdir -p /data
+cd /data
+mv ./data/1 ./
+```
+
+move ```./data/2``` to `/data` of ```EDGE2_NODE```.
+
+```
+mkdir -p /data
+cd /data
+mv ./data/2 ./
+```
+
+### Prepare Images
+
+This example uses these images:
+
+1. aggregation worker: ```kubeedge/sedna-example-federated-learning-mistnet:v0.3.0```
+2. train worker: ```kubeedge/sedna-example-federated-learning-mistnet-client:v0.3.0```
+
+These images are generated by the script [build_image.sh](/examples/build_image.sh).
+
+### Create Federated Learning Job
+
+#### Create Dataset
+
+create dataset for `$EDGE1_NODE`
+
+```n
+kubectl create -f - <<EOF
+apiVersion: sedna.io/v1alpha1
+kind: Dataset
+metadata:
+  name: "coco-dataset"
+spec:
+  url: "/data/test.txt"
+  format: "txt"
+  nodeName: edge-node
+EOF
+```
+
+create dataset for `$EDGE2_NODE`
+
+```
+kubectl create -f - <<EOF
+apiVersion: sedna.io/v1alpha1
+kind: Dataset
+metadata:
+  name: "coco-dataset"
+spec:
+  url: "/data/test.txt"
+  format: "txt"
+  nodeName: edge-node
+EOF
+```
+
+#### Create Model
+
+create the directory `/model` in the host of `$EDGE1_NODE`
+
+```
+mkdir /model
+```
+
+create the directory `/model` in the host of `$EDGE2_NODE`
+
+```
+mkdir /model
+```
+
+```
+TODO: put pretrained model on nodes.
+```
+
+create model
+
+```
+kubectl create -f - <<EOF
+apiVersion: sedna.io/v1alpha1
+kind: Model
+metadata:
+  name: "yolo-v5-model"
+spec:
+  url: "/model/yolo.pb"
+  format: "pb"
+EOF
+```
+
+#### Start Federated Learning Job
+
+```
+kubectl create -f - <<EOF
+apiVersion: sedna.io/v1alpha1
+kind: FederatedLearningJob
+metadata:
+  name: mistnet-on-mnist-dataset
+spec:
+  stopCondition:
+    operator: "or" # and
+      conditions:
+        - operator: ">"
+          threshold: 100
+          metric: rounds
+        - operator: ">"
+          threshold: 0.95
+          metric: targetAccuracy
+        - operator: "<"
+          threshold: 0.03
+          metric: deltaLoss
+  aggregationTrigger:
+    condition:
+      operator: ">"
+      threshold: 5
+      metric: num_of_ready_clients
+  aggregationWorker:
+    model:
+      name: "mistnet-on-mnist-model"
+    template:
+      spec:
+        nodeName: $CLOUD_NODE
+        containers:
+          - image: kubeedge/sedna-example-federated-learning-mistnet-on-mnist-dataset-aggregation:v0.4.0
+            name:  agg-worker
+            imagePullPolicy: IfNotPresent
+            env: # user defined environments
+              - name: "cut_layer"
+                value: "4"
+              - name: "epsilon"
+                value: "100"
+              - name: "aggregation_algorithm"
+                value: "mistnet"
+              - name: "batch_size"
+                value: "10"
+            resources:  # user defined resources
+              limits:
+                memory: 2Gi
+  trainingWorkers:
+    - dataset:
+        name: "edge1-surface-defect-detection-dataset"
+      template:
+        spec:
+          nodeName: $EDGE1_NODE
+          containers:
+            - image: kubeedge/sedna-example-federated-learning-mistnet-on-mnist-dataset-train:v0.4.0
+              name:  train-worker
+              imagePullPolicy: IfNotPresent
+              env:  # user defined environments
+                - name: "batch_size"
+                  value: "32"
+                - name: "learning_rate"
+                  value: "0.001"
+                - name: "epochs"
+                  value: "2"
+              resources:  # user defined resources
+                limits:
+                  memory: 2Gi
+    - dataset:
+        name: "edge2-surface-defect-detection-dataset"
+      template:
+        spec:
+          nodeName: $EDGE2_NODE
+          containers:
+            - image: kubeedge/sedna-example-federated-learning-mistnet-on-mnist-dataset-train:v0.4.0
+              name:  train-worker
+              imagePullPolicy: IfNotPresent
+              env:  # user defined environments
+                - name: "batch_size"
+                  value: "32"
+                - name: "learning_rate"
+                  value: "0.001"
+                - name: "epochs"
+                  value: "2"
+              resources:  # user defined resources
+                limits:
+                  memory: 2Gi
+EOF
+```
+
+```
+TODO: show the benifit of mistnet. for example, the compared results of fedavg & mistnet.
+
+```
+
+### Check Federated Learning Status
+
+```
+kubectl get federatedlearningjob surface-defect-detection
+```
+
+### Check Federated Learning Train Result
+
+After the job completed, you will find the model generated on the directory `/model` in `$EDGE1_NODE` and `$EDGE2_NODE`.
diff --git a/examples/federated_learning/yolov5_coco128_mistnet/aggregate.py b/examples/federated_learning/yolov5_coco128_mistnet/aggregate.py
new file mode 100644
index 000000000..0ba8d558f
--- /dev/null
+++ b/examples/federated_learning/yolov5_coco128_mistnet/aggregate.py
@@ -0,0 +1,35 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from interface import mistnet, s3_transmitter, simple_chooser
+from interface import Dataset, Estimator
+from sedna.service.server import AggregationServer
+
+
+def run_server():
+    data = Dataset()
+    estimator = Estimator()
+
+    server = AggregationServer(
+        data=data,
+        estimator=estimator,
+        aggregation=mistnet,
+        transmitter=s3_transmitter,
+        chooser=simple_chooser)
+
+    server.start()
+
+
+if __name__ == '__main__':
+    run_server()
diff --git a/examples/federated_learning/yolov5_coco128_mistnet/interface.py b/examples/federated_learning/yolov5_coco128_mistnet/interface.py
new file mode 100644
index 000000000..6c654f58c
--- /dev/null
+++ b/examples/federated_learning/yolov5_coco128_mistnet/interface.py
@@ -0,0 +1,149 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from sedna.algorithms.aggregation import MistNet
+from sedna.algorithms.client_choose import SimpleClientChoose
+from sedna.common.config import Context
+from sedna.core.federated_learning import FederatedLearning
+
+simple_chooser = SimpleClientChoose(per_round=1)
+
+# This example uses the MistNet aggregation algorithm.
+mistnet = MistNet(cut_layer=Context.get_parameters("cut_layer"),
+                  epsilon=Context.get_parameters("epsilon"))
+
+# `get_transmitter_from_config()` returns a WSTransmitter or S3Transmitter.
+s3_transmitter = FederatedLearning.get_transmitter_from_config()
+
+
+class Dataset:
+    def __init__(self) -> None:
+        self.parameters = {
+            "datasource": "YOLO",
+            "data_params": "./coco128.yaml",
+            # Where the dataset is located
+            "data_path": "./data/COCO",
+            "train_path": "./data/COCO/coco128/images/train2017/",
+            "test_path": "./data/COCO/coco128/images/train2017/",
+            # number of training examples
+            "num_train_examples": 128,
+            # number of testing examples
+            "num_test_examples": 128,
+            # number of classes
+            "num_classes": 80,
+            # image size
+            "image_size": 640,
+            "classes":
+                [
+                    "person",
+                    "bicycle",
+                    "car",
+                    "motorcycle",
+                    "airplane",
+                    "bus",
+                    "train",
+                    "truck",
+                    "boat",
+                    "traffic light",
+                    "fire hydrant",
+                    "stop sign",
+                    "parking meter",
+                    "bench",
+                    "bird",
+                    "cat",
+                    "dog",
+                    "horse",
+                    "sheep",
+                    "cow",
+                    "elephant",
+                    "bear",
+                    "zebra",
+                    "giraffe",
+                    "backpack",
+                    "umbrella",
+                    "handbag",
+                    "tie",
+                    "suitcase",
+                    "frisbee",
+                    "skis",
+                    "snowboard",
+                    "sports ball",
+                    "kite",
+                    "baseball bat",
+                    "baseball glove",
+                    "skateboard",
+                    "surfboard",
+                    "tennis racket",
+                    "bottle",
+                    "wine glass",
+                    "cup",
+                    "fork",
+                    "knife",
+                    "spoon",
+                    "bowl",
+                    "banana",
+                    "apple",
+                    "sandwich",
+                    "orange",
+                    "broccoli",
+                    "carrot",
+                    "hot dog",
+                    "pizza",
+                    "donut",
+                    "cake",
+                    "chair",
+                    "couch",
+                    "potted plant",
+                    "bed",
+                    "dining table",
+                    "toilet",
+                    "tv",
+                    "laptop",
+                    "mouse",
+                    "remote",
+                    "keyboard",
+                    "cell phone",
+                    "microwave",
+                    "oven",
+                    "toaster",
+                    "sink",
+                    "refrigerator",
+                    "book",
+                    "clock",
+                    "vase",
+                    "scissors",
+                    "teddy bear",
+                    "hair drier",
+                    "toothbrush",
+                ],
+            "partition_size": 128,
+        }
+
+
+class Estimator:
+    def __init__(self) -> None:
+        self.model = None
+        self.hyperparameters = {
+            "type": "yolov5",
+            "rounds": 1,
+            "target_accuracy": 0.99,
+            "epochs": 500,
+            "batch_size": 16,
+            "optimizer": "SGD",
+            "linear_lr": False,
+            # The machine learning model
+            "model_name": "yolov5",
+            "model_config": "./yolov5s.yaml",
+            "train_params": "./hyp.scratch.yaml"
+        }
diff --git a/examples/federated_learning/yolov5_coco128_mistnet/train.py b/examples/federated_learning/yolov5_coco128_mistnet/train.py
new file mode 100644
index 000000000..62886cb34
--- /dev/null
+++ b/examples/federated_learning/yolov5_coco128_mistnet/train.py
@@ -0,0 +1,33 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from interface import mistnet, s3_transmitter
+from interface import Dataset, Estimator
+from sedna.core.federated_learning import FederatedLearning
+
+
+def main():
+    data = Dataset()
+    estimator = Estimator()
+
+    fl_model = FederatedLearning(
+        estimator=estimator,
+        aggregation=mistnet,
+        transmitter=s3_transmitter)
+
+    fl_model.train(data)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/lib/sedna/algorithms/aggregation/__init__.py b/lib/sedna/algorithms/aggregation/__init__.py
index eba0a1881..96d7da5e5 100644
--- a/lib/sedna/algorithms/aggregation/__init__.py
+++ b/lib/sedna/algorithms/aggregation/__init__.py
@@ -13,3 +13,4 @@
 # limitations under the License.
 
 from . import aggregation
+from .aggregation import FedAvg, MistNet
diff --git a/lib/sedna/algorithms/aggregation/aggregation.py b/lib/sedna/algorithms/aggregation/aggregation.py
index 3b998814f..6f13962f3 100644
--- a/lib/sedna/algorithms/aggregation/aggregation.py
+++ b/lib/sedna/algorithms/aggregation/aggregation.py
@@ -104,3 +104,14 @@ def aggregate(self, clients: List[AggClient]):
             updates.append(row.tolist())
         self.weights = deepcopy(updates)
         return updates
+
+
+@ClassFactory.register(ClassType.FL_AGG)
+class MistNet(BaseAggregation, abc.ABC):
+    def __init__(self, cut_layer, epsilon=100):
+        super().__init__()
+        self.cut_layer = cut_layer
+        self.epsilon = epsilon
+
+    def aggregate(self, clients: List[AggClient]):
+        pass
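
For reference, the example interface.py added in this patch builds the MistNet
aggregator from the worker environment defined in the CRD sample:

    from sedna.algorithms.aggregation import MistNet
    from sedna.common.config import Context

    # "cut_layer" and "epsilon" come from the aggregation worker env
    # (e.g. cut_layer=4, epsilon=100 in federatedlearningjob_yolo_v1alpha1.yaml)
    mistnet = MistNet(cut_layer=Context.get_parameters("cut_layer"),
                      epsilon=Context.get_parameters("epsilon"))
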
diff --git a/lib/sedna/algorithms/client_choose/__init__.py b/lib/sedna/algorithms/client_choose/__init__.py
new file mode 100644
index 000000000..d5f58d4a0
--- /dev/null
+++ b/lib/sedna/algorithms/client_choose/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .client_choose import SimpleClientChoose
diff --git a/lib/sedna/algorithms/client_choose/client_choose.py b/lib/sedna/algorithms/client_choose/client_choose.py
new file mode 100644
index 000000000..4bd4756a6
--- /dev/null
+++ b/lib/sedna/algorithms/client_choose/client_choose.py
@@ -0,0 +1,36 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import abc
+
+
+class AbstractClientChoose(metaclass=abc.ABCMeta):
+    """
+    Abstract class of ClientChoose, which provides base client choose
+    algorithm interfaces in federated learning.
+    """
+
+    def __init__(self):
+        pass
+
+
+class SimpleClientChoose(AbstractClientChoose):
+    """
+    A Simple Implementation of Client Choose.
+    """
+
+    def __init__(self, per_round=1):
+        super().__init__()
+        self.per_round = per_round
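
Usage, as in the example interface.py added in this patch:

    from sedna.algorithms.client_choose import SimpleClientChoose

    # choose one client per aggregation round
    simple_chooser = SimpleClientChoose(per_round=1)
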
diff --git a/lib/sedna/algorithms/transmitter/__init__.py b/lib/sedna/algorithms/transmitter/__init__.py
new file mode 100644
index 000000000..b71ccf1f7
--- /dev/null
+++ b/lib/sedna/algorithms/transmitter/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .transmitter import S3Transmitter, WSTransmitter
diff --git a/lib/sedna/algorithms/transmitter/transmitter.py b/lib/sedna/algorithms/transmitter/transmitter.py
new file mode 100644
index 000000000..865398995
--- /dev/null
+++ b/lib/sedna/algorithms/transmitter/transmitter.py
@@ -0,0 +1,64 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+
+
+class AbstractTransmitter(ABC):
+    """
+    Abstract class of Transmitter, which provides base transmission
+    interfaces between edge and cloud.
+    """
+
+    @abstractmethod
+    def recv(self):
+        pass
+
+    @abstractmethod
+    def send(self, data):
+        pass
+
+
+class WSTransmitter(AbstractTransmitter, ABC):
+    """
+    An implementation of Transmitter based on WebSocket.
+    """
+
+    def recv(self):
+        pass
+
+    def send(self, data):
+        pass
+
+
+class S3Transmitter(AbstractTransmitter, ABC):
+    """
+    An implementation of Transmitter based on S3 protocol.
+    """
+
+    def __init__(self,
+                 s3_endpoint_url,
+                 access_key,
+                 secret_key,
+                 transmitter_url):
+        self.s3_endpoint_url = s3_endpoint_url
+        self.access_key = access_key
+        self.secret_key = secret_key
+        self.transmitter_url = transmitter_url
+
+    def recv(self):
+        pass
+
+    def send(self, data):
+        pass
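
A hedged sketch of constructing the two transmitters directly (the S3 endpoint
and keys are placeholders; in practice `get_transmitter_from_config()` below
builds them from the environment):

    from sedna.algorithms.transmitter import S3Transmitter, WSTransmitter

    # default transport: WebSocket between worker and aggregator
    ws_transmitter = WSTransmitter()

    # S3-based transport; values below are placeholders
    s3_transmitter = S3Transmitter(
        s3_endpoint_url="https://s3.example.com",
        access_key="ACCESS_KEY_ID",
        secret_key="SECRET_ACCESS_KEY",
        transmitter_url="s3://sedna/fl/aggregation_data")
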
diff --git a/lib/sedna/common/config.py b/lib/sedna/common/config.py
index 769ef1a45..ad6c62d4a 100644
--- a/lib/sedna/common/config.py
+++ b/lib/sedna/common/config.py
@@ -269,9 +269,16 @@ class BaseConfig(ConfigSerializable):
     # the name of FederatedLearningJob and others Job
     job_name = os.getenv("JOB_NAME", "sedna")
 
+    pretrained_model_url = os.getenv("PRETRAINED_MODEL_URL", "./")
     model_url = os.getenv("MODEL_URL")
     model_name = os.getenv("MODEL_NAME")
 
+    transmitter = os.getenv("TRANSMITTER", "ws")
+    agg_data_path = os.getenv("AGG_DATA_PATH", "./")
+    s3_endpoint_url = os.getenv("S3_ENDPOINT_URL", "")
+    access_key_id = os.getenv("ACCESS_KEY_ID", "")
+    secret_access_key = os.getenv("SECRET_ACCESS_KEY", "")
+
     # user parameter
     parameters = os.getenv("PARAMETERS")
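
The new settings are read from the worker environment. A hedged sketch of the
variables the s3 path relies on (values are placeholders; a FederatedLearningJob
normally injects them from the `transmitter` spec and the referenced secret):

    import os

    # set these before importing sedna.common.config, which reads env on import
    os.environ["TRANSMITTER"] = "s3"        # "ws" (default) or "s3"
    os.environ["AGG_DATA_PATH"] = "s3://sedna/fl/aggregation_data"
    os.environ["S3_ENDPOINT_URL"] = "https://s3.example.com"
    os.environ["ACCESS_KEY_ID"] = "XXXX"
    os.environ["SECRET_ACCESS_KEY"] = "XXXXXXXX"
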
 
diff --git a/lib/sedna/core/federated_learning/federated_learning.py b/lib/sedna/core/federated_learning/federated_learning.py
index eec652dd1..f49825dda 100644
--- a/lib/sedna/core/federated_learning/federated_learning.py
+++ b/lib/sedna/core/federated_learning/federated_learning.py
@@ -13,17 +13,23 @@
 # limitations under the License.
 
 
+import asyncio
+import sys
 import time
 
-from sedna.core.base import JobBase
-from sedna.common.config import Context
-from sedna.common.file_ops import FileOps
+from plato.clients import registry as client_registry
+from plato.config import Config
+
+from sedna.algorithms.transmitter import S3Transmitter, WSTransmitter
 from sedna.common.class_factory import ClassFactory, ClassType
-from sedna.service.client import AggregationClient
+from sedna.common.config import BaseConfig, Context
 from sedna.common.constant import K8sResourceKindStatus
+from sedna.common.file_ops import FileOps
+from sedna.core.base import JobBase
+from sedna.service.client import AggregationClient
 
 
-class FederatedLearning(JobBase):
+class FederatedLearningV0(JobBase):
     """
     Federated learning enables multiple actors to build a common, robust
     machine learning model without sharing data, thus allowing to address
@@ -50,6 +56,7 @@ class FederatedLearning(JobBase):
             aggregation="FedAvg"
         )
     """
+
     def __init__(self, estimator, aggregation="FedAvg"):
 
         protocol = Context.get_parameters("AGG_PROTOCOL", "ws")
@@ -178,3 +185,65 @@ def train(self, train_data,
                     task_info,
                     K8sResourceKindStatus.RUNNING.value,
                     task_info_res)
+
+
+class FederatedLearning:
+    def __init__(self, data=None, estimator=None, aggregation=None, transmitter=None) -> None:
+        # set parameters
+        server = Config.server._asdict()
+        clients = Config.clients._asdict()
+        datastore = Config.data._asdict()
+        train = Config.trainer._asdict()
+
+        if data is not None:
+            for xkey in data.parameters:
+                datastore[xkey] = data.parameters[xkey]
+            Config.data = Config.namedtuple_from_dict(datastore)
+
+        self.model = None
+        if estimator is not None:
+            self.model = estimator.model
+            for xkey in estimator.hyperparameters:
+                train[xkey] = estimator.hyperparameters[xkey]
+            Config.trainer = Config.namedtuple_from_dict(train)
+
+        if aggregation is not None:
+            Config.algorithm = Config.namedtuple_from_dict(aggregation.parameters)
+            if aggregation.parameters["type"] == "mistnet":
+                clients["type"] = "mistnet"
+                server["type"] = "mistnet"
+
+        if isinstance(transmitter, S3Transmitter):
+            server["address"] = Context.get_parameters("AGG_IP")
+            server["port"] = Context.get_parameters("AGG_PORT")
+            server["s3_endpoint_url"] = transmitter.s3_endpoint_url
+            server["s3_bucket"] = transmitter.s3_bucket
+            server["access_key"] = transmitter.access_key
+            server["secret_key"] = transmitter.secret_key
+        elif isinstance(transmitter, WSTransmitter):
+            pass
+
+        Config.server = Config.namedtuple_from_dict(server)
+        Config.clients = Config.namedtuple_from_dict(clients)
+
+        # Config.store()
+        # create a client
+        self.client = client_registry.get(model=self.model)
+        self.client.configure()
+
+    @classmethod
+    def get_transmitter_from_config(cls):
+        if BaseConfig.transmitter == "ws":
+            return WSTransmitter()
+        elif BaseConfig.transmitter == "s3":
+            return S3Transmitter(s3_endpoint_url=BaseConfig.s3_endpoint_url,
+                                 access_key=BaseConfig.access_key_id,
+                                 secret_key=BaseConfig.secret_access_key,
+                                 transmitter_url=BaseConfig.agg_data_path)
+
+    def train(self):
+        if sys.version_info < (3, 7):
+            loop = asyncio.get_event_loop()
+            loop.run_until_complete(self.client.start_client())
+        else:
+            asyncio.run(self.client.start_client())

From 2488b3f94bfd70a9291d3765bab7a9571da6ccf4 Mon Sep 17 00:00:00 2001
From: XinYao1994 <xyao@cs.hku.hk>
Date: Wed, 8 Sep 2021 16:47:24 +0800
Subject: [PATCH 2/2] add federated learning implementation based on plato

Signed-off-by: XinYao1994 <xyao@cs.hku.hk>
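
A minimal sketch of the resulting Plato-backed flow, following the updated
example train.py in this patch:

    from interface import mistnet, s3_transmitter, Dataset, Estimator
    from sedna.core.federated_learning import FederatedLearningV2

    # data/estimator hyperparameters are mapped onto the Plato Config
    fl = FederatedLearningV2(data=Dataset(), estimator=Estimator(),
                             aggregation=mistnet, transmitter=s3_transmitter)
    fl.train()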
---
 .../federatedlearningjob_yolo_v1alpha1.yaml   |  41 +++-
 examples/build_image.sh                       |   4 +-
 ...earning-mistnet-yolo-aggregator.Dockerfile |  23 ++
 ...ed-learning-mistnet-yolo-client.Dockerfile |  23 ++
 .../training_worker/train.py                  |   2 +-
 .../yolov5_coco128_mistnet/README.md          | 205 +++++++++---------
 .../yolov5_coco128_mistnet/aggregate.py       |   8 +-
 .../yolov5_coco128_mistnet/coco128.yaml       |  28 +++
 .../yolov5_coco128_mistnet/hyp.scratch.yaml   |  33 +++
 .../yolov5_coco128_mistnet/interface.py       |   6 +-
 .../yolov5_coco128_mistnet/train.py           |  16 +-
 .../yolov5_coco128_mistnet/yolov5s.yaml       |  48 ++++
 lib/sedna/algorithms/aggregation/__init__.py  |   2 +-
 .../algorithms/aggregation/aggregation.py     |  14 +-
 .../algorithms/client_choose/client_choose.py |   4 +-
 .../algorithms/transmitter/transmitter.py     |  13 +-
 lib/sedna/core/federated_learning/__init__.py |   1 +
 .../federated_learning/federated_learning.py  |  38 ++--
 lib/sedna/service/server/aggregation.py       |  69 +++++-
 19 files changed, 427 insertions(+), 151 deletions(-)
 create mode 100644 examples/federated-learning-mistnet-yolo-aggregator.Dockerfile
 create mode 100644 examples/federated-learning-mistnet-yolo-client.Dockerfile
 create mode 100644 examples/federated_learning/yolov5_coco128_mistnet/coco128.yaml
 create mode 100644 examples/federated_learning/yolov5_coco128_mistnet/hyp.scratch.yaml
 create mode 100644 examples/federated_learning/yolov5_coco128_mistnet/yolov5s.yaml

diff --git a/build/crd-samples/sedna/federatedlearningjob_yolo_v1alpha1.yaml b/build/crd-samples/sedna/federatedlearningjob_yolo_v1alpha1.yaml
index 9864951a9..8230be0d6 100644
--- a/build/crd-samples/sedna/federatedlearningjob_yolo_v1alpha1.yaml
+++ b/build/crd-samples/sedna/federatedlearningjob_yolo_v1alpha1.yaml
@@ -5,7 +5,7 @@ metadata:
 spec:
   pretrainedModel: # option
     name: "yolo-v5-pretrained-model"
-  transimitter: # option
+  transmitter: # option
     ws: { } # option, by default
     s3: # option, but at least one
       aggDataPath: "s3://sedna/fl/aggregation_data"
@@ -17,7 +17,7 @@ spec:
       spec:
         nodeName: "sedna-control-plane"
         containers:
-          - image: kubeedge/sedna-fl-aggregation:mistnetyolo
+          - image: kubeedge/sedna-example-federated-learning-mistnet-yolo-aggregator:v0.4.0
             name: agg-worker
             imagePullPolicy: IfNotPresent
             env: # user defined environments
@@ -28,21 +28,54 @@ spec:
               - name: "aggregation_algorithm"
                 value: "mistnet"
               - name: "batch_size"
+                value: "32"
             resources: # user defined resources
               limits:
                 memory: 8Gi
   trainingWorkers:
     - dataset:
-        name: "coco-dataset"
+        name: "coco-dataset-1"
       template:
         spec:
           nodeName: "edge-node"
           containers:
-            - image: kubeedge/sedna-fl-train:mistnetyolo
+            - image: kubeedge/sedna-example-federated-learning-mistnet-yolo-client:v0.4.0
               name: train-worker
               imagePullPolicy: IfNotPresent
               args: [ "-i", "1" ]
               env: # user defined environments
+                - name: "cut_layer"
+                  value: "4"
+                - name: "epsilon"
+                  value: "100"
+                - name: "aggregation_algorithm"
+                  value: "mistnet"
+                - name: "batch_size"
+                  value: "32"
+                - name: "learning_rate"
+                  value: "0.001"
+                - name: "epochs"
+                  value: "1"
+              resources: # user defined resources
+                limits:
+                  memory: 2Gi
+    - dataset:
+        name: "coco-dataset-2"
+      template:
+        spec:
+          nodeName: "edge-node"
+          containers:
+            - image: kubeedge/sedna-example-federated-learning-mistnet-yolo-client:v0.4.0
+              name: train-worker
+              imagePullPolicy: IfNotPresent
+              args: [ "-i", "2" ]
+              env: # user defined environments
+                - name: "cut_layer"
+                  value: "4"
+                - name: "epsilon"
+                  value: "100"
+                - name: "aggregation_algorithm"
+                  value: "mistnet"
                 - name: "batch_size"
                   value: "32"
                 - name: "learning_rate"
diff --git a/examples/build_image.sh b/examples/build_image.sh
index 6a154d845..fb05c2c3f 100644
--- a/examples/build_image.sh
+++ b/examples/build_image.sh
@@ -17,11 +17,13 @@
 cd "$(dirname "${BASH_SOURCE[0]}")"
 
 IMAGE_REPO=${IMAGE_REPO:-kubeedge}
-IMAGE_TAG=${IMAGE_TAG:-v0.3.0}
+IMAGE_TAG=${IMAGE_TAG:-v0.4.0}
 
 EXAMPLE_REPO_PREFIX=${IMAGE_REPO}/sedna-example-
 
 dockerfiles=(
+federated-learning-mistnet-yolo-aggregator.Dockerfile
+federated-learning-mistnet-yolo-client.Dockerfile
 federated-learning-surface-defect-detection-aggregation.Dockerfile
 federated-learning-surface-defect-detection-train.Dockerfile
 incremental-learning-helmet-detection.Dockerfile
diff --git a/examples/federated-learning-mistnet-yolo-aggregator.Dockerfile b/examples/federated-learning-mistnet-yolo-aggregator.Dockerfile
new file mode 100644
index 000000000..e316f6eb3
--- /dev/null
+++ b/examples/federated-learning-mistnet-yolo-aggregator.Dockerfile
@@ -0,0 +1,23 @@
+FROM tensorflow/tensorflow:1.15.4
+
+RUN apt update \
+  && apt install -y libgl1-mesa-glx git
+
+COPY ./lib/requirements.txt /home
+
+RUN python -m pip install --upgrade pip
+
+RUN pip install -r /home/requirements.txt
+
+ENV PYTHONPATH "/home/lib:/home/plato:/home/plato/packages/yolov5"
+
+COPY ./lib /home/lib
+RUN git clone https://github.com/TL-System/plato.git /home/plato
+
+RUN pip install -r /home/plato/requirements.txt
+RUN pip install -r /home/plato/packages/yolov5/requirements.txt
+
+WORKDIR /home/work
+COPY examples/federated_learning/yolov5_coco128_mistnet  /home/work/
+
+CMD ["/bin/sh", "-c", "ulimit -n 50000; python aggregate.py"]
diff --git a/examples/federated-learning-mistnet-yolo-client.Dockerfile b/examples/federated-learning-mistnet-yolo-client.Dockerfile
new file mode 100644
index 000000000..b1e7aa356
--- /dev/null
+++ b/examples/federated-learning-mistnet-yolo-client.Dockerfile
@@ -0,0 +1,23 @@
+FROM tensorflow/tensorflow:1.15.4
+
+RUN apt update \
+  && apt install -y libgl1-mesa-glx git
+
+COPY ./lib/requirements.txt /home
+
+RUN python -m pip install --upgrade pip
+
+RUN pip install -r /home/requirements.txt
+
+ENV PYTHONPATH "/home/lib:/home/plato:/home/plato/packages/yolov5"
+
+COPY ./lib /home/lib
+RUN git clone https://github.com/TL-System/plato.git /home/plato
+
+RUN pip install -r /home/plato/requirements.txt
+RUN pip install -r /home/plato/packages/yolov5/requirements.txt
+
+WORKDIR /home/work
+COPY examples/federated_learning/yolov5_coco128_mistnet   /home/work/
+
+ENTRYPOINT ["python", "train.py"]
diff --git a/examples/federated_learning/surface_defect_detection/training_worker/train.py b/examples/federated_learning/surface_defect_detection/training_worker/train.py
index 4fd9a1122..37f21d0cc 100644
--- a/examples/federated_learning/surface_defect_detection/training_worker/train.py
+++ b/examples/federated_learning/surface_defect_detection/training_worker/train.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import os
 
 import numpy as np
@@ -74,6 +73,7 @@ def main():
         learning_rate=learning_rate,
         validation_split=validation_split
     )
+    
     return train_jobs
 
 
diff --git a/examples/federated_learning/yolov5_coco128_mistnet/README.md b/examples/federated_learning/yolov5_coco128_mistnet/README.md
index 509734ae8..0bc26fa0c 100644
--- a/examples/federated_learning/yolov5_coco128_mistnet/README.md
+++ b/examples/federated_learning/yolov5_coco128_mistnet/README.md
@@ -32,37 +32,32 @@ Follow the [Sedna installation document](/docs/setup/install.md) to install Sedn
 
 ### Prepare Dataset
 
-Download [dataset](https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip) and do data partition
+Download [dataset](https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip) 
 
-```
-wget https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip
-unzip coco128.zip -d data
-rm coco128.zip
-python partition.py ./data 2
-```
+Create data interface for ```EDGE1_NODE```.
 
-move ```./data/1``` to `/data` of ```EDGE1_NODE```.
-
-```
-mkdir -p /data
-cd /data
-mv ./data/1 ./
+```shell
+mkdir -p /data/1
+cd /data/1
+wget https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip
+unzip coco128.zip -d COCO
 ```
 
-move ```./data/2``` to `/data` of ```EDGE2_NODE```.
+Create data interface for ```EDGE2_NODE```.
 
-```
-mkdir -p /data
-cd /data
-mv ./data/2 ./
+```shell
+mkdir -p /data/2
+cd /data/2
+wget https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip
+unzip coco128.zip -d COCO
 ```
 
 ### Prepare Images
 
 This example uses these images:
 
-1. aggregation worker: ```kubeedge/sedna-example-federated-learning-mistnet:v0.3.0```
-2. train worker: ```kubeedge/sedna-example-federated-learning-mistnet-client:v0.3.0```
+1. aggregation worker: ```kubeedge/sedna-example-federated-learning-mistnet-yolo-aggregator:v0.4.0```
+2. train worker: ```kubeedge/sedna-example-federated-learning-mistnet-yolo-client:v0.4.0```
 
 These images are generated by the script [build_image.sh](/examples/build_image.sh).
 
@@ -70,103 +65,118 @@ These images are generated by the script [build_images.sh](/examples/build_image
 
 #### Create Dataset
 
-create dataset for `$EDGE1_NODE`
+create dataset for `$EDGE1_NODE` and `$EDGE2_NODE`
 
-```n
+```bash
 kubectl create -f - <<EOF
 apiVersion: sedna.io/v1alpha1
 kind: Dataset
 metadata:
-  name: "coco-dataset"
+  name: "coco-dataset-1"
 spec:
-  url: "/data/test.txt"
-  format: "txt"
-  nodeName: edge-node
+  url: "/data/1/COCO"
+  format: "dir"
+  nodeName: $EDGE1_NODE
 EOF
 ```
 
-create dataset for `$EDGE2_NODE`
-
-```
+```bash
 kubectl create -f - <<EOF
 apiVersion: sedna.io/v1alpha1
 kind: Dataset
 metadata:
-  name: "coco-dataset"
+  name: "coco-dataset-2"
 spec:
-  url: "/data/test.txt"
-  format: "txt"
-  nodeName: edge-node
+  url: "/data/2/COCO"
+  format: "dir"
+  nodeName: $EDGE2_NODE
 EOF
 ```
 
 #### Create Model
-
-create the directory `/model` in the host of `$EDGE1_NODE`
-
-```
-mkdir /model
+create the directory `/model` and `/pretrained` in `$EDGE1_NODE` and `$EDGE2_NODE`.
+```bash
+mkdir -p /model
+mkdir -p /pretrained
 ```
 
-create the directory `/model` in the host of `$EDGE2_NODE`
+create the directory `/model` and `/pretrained` in the host of `$CLOUD_NODE` (download links [here](https://kubeedge.obs.cn-north-1.myhuaweicloud.com/examples/yolov5_coco128_mistnet/yolov5.pth))
 
-```
-mkdir /model
-```
 
-```
-TODO: put pretrained model on nodes.
+```bash
+# on the cloud side
+mkdir -p /model
+mkdir -p /pretrained
+cd /pretrained
+wget https://kubeedge.obs.cn-north-1.myhuaweicloud.com/examples/yolov5_coco128_mistnet/yolov5.pth
 ```
 
 create model
 
-```
+```bash
 kubectl create -f - <<EOF
 apiVersion: sedna.io/v1alpha1
 kind: Model
 metadata:
   name: "yolo-v5-model"
 spec:
-  url: "/model/yolo.pb"
-  format: "pb"
+  url: "/model/yolov5.pth"
+  format: "pth"
+EOF
+
+kubectl create -f - <<EOF
+apiVersion: sedna.io/v1alpha1
+kind: Model
+metadata:
+  name: "yolo-v5-pretrained-model"
+spec:
+  url: "/pretrained/yolov5.pth"
+  format: "pth"
 EOF
 ```
 
-#### Start Federated Learning Job
+#### Create a Secret with Your S3 User Credential (Optional)
 
+```shell
+kubectl create -f - <<EOF
+apiVersion: v1
+kind: Secret
+metadata:
+  name: mysecret
+  annotations:
+    s3-endpoint: s3.amazonaws.com 
+    s3-usehttps: "1" 
+stringData: 
+  ACCESS_KEY_ID: XXXX
+  SECRET_ACCESS_KEY: XXXXXXXX
+EOF
 ```
+
+#### Start Federated Learning Job
+
+```bash
 kubectl create -f - <<EOF
 apiVersion: sedna.io/v1alpha1
 kind: FederatedLearningJob
 metadata:
-  name: mistnet-on-mnist-dataset
+  name: yolo-v5
 spec:
-  stopCondition:
-    operator: "or" # and
-      conditions:
-        - operator: ">"
-          threshold: 100
-          metric: rounds
-        - operator: ">"
-          threshold: 0.95
-          metric: targetAccuracy
-        - operator: "<"
-          threshold: 0.03
-          metric: deltaLoss
-  aggregationTrigger:
-    condition:
-      operator: ">"
-      threshold: 5
-      metric: num_of_ready_clients
+  pretrainedModel: # option
+    name: "yolo-v5-pretrained-model"
+  transmitter: # option
+    ws: { } # option, by default
+    s3: # optional, but at least one
+      aggDataPath: "s3://sedna/fl/aggregation_data"
+      credentialName: mysecret
   aggregationWorker:
     model:
-      name: "mistnet-on-mnist-model"
+      name: "yolo-v5-model"
     template:
       spec:
         nodeName: $CLOUD_NODE
         containers:
-          - image: kubeedge/sedna-example-federated-learning-mistnet-on-mnist-dataset-aggregation:v0.4.0
-            name:  agg-worker
+          - image: kubeedge/sedna-example-federated-learning-mistnet-yolo-aggregator:v0.4.0
+            name: agg-worker
             imagePullPolicy: IfNotPresent
             env: # user defined environments
               - name: "cut_layer"
@@ -176,63 +186,64 @@ spec:
               - name: "aggregation_algorithm"
                 value: "mistnet"
               - name: "batch_size"
-                value: "10"
-            resources:  # user defined resources
+                value: "32"
+            resources: # user defined resources
               limits:
-                memory: 2Gi
+                memory: 8Gi
   trainingWorkers:
     - dataset:
-        name: "edge1-surface-defect-detection-dataset"
+        name: "coco-dataset-1"
       template:
         spec:
           nodeName: $EDGE1_NODE
           containers:
-            - image: kubeedge/sedna-example-federated-learning-mistnet-on-mnist-dataset-train:v0.4.0
-              name:  train-worker
+            - image: kubeedge/sedna-example-federated-learning-mistnet-yolo-client:v0.4.0
+              name: train-worker
               imagePullPolicy: IfNotPresent
-              env:  # user defined environments
+              args: [ "-i", "1" ]
+              env: # user defined environments
+                - name: "cut_layer"
+                  value: "4"
+                - name: "epsilon"
+                  value: "100"
+                - name: "aggregation_algorithm"
+                  value: "mistnet"
                 - name: "batch_size"
                   value: "32"
                 - name: "learning_rate"
                   value: "0.001"
                 - name: "epochs"
-                  value: "2"
-              resources:  # user defined resources
+                  value: "1"
+              resources: # user defined resources
                 limits:
                   memory: 2Gi
     - dataset:
-        name: "edge2-surface-defect-detection-dataset"
+        name: "coco-dataset-2"
       template:
         spec:
           nodeName: $EDGE2_NODE
           containers:
-            - image: kubeedge/sedna-example-federated-learning-mistnet-on-mnist-dataset-train:v0.4.0
-              name:  train-worker
+            - image: kubeedge/sedna-example-federated-learning-mistnet-yolo-client:v0.4.0
+              name: train-worker
               imagePullPolicy: IfNotPresent
-              env:  # user defined environments
+              args: [ "-i", "2" ]
+              env: # user defined environments
+                - name: "cut_layer"
+                  value: "4"
+                - name: "epsilon"
+                  value: "100"
+                - name: "aggregation_algorithm"
+                  value: "mistnet"
                 - name: "batch_size"
                   value: "32"
                 - name: "learning_rate"
                   value: "0.001"
                 - name: "epochs"
-                  value: "2"
-              resources:  # user defined resources
+                  value: "1"
+              resources: # user defined resources
                 limits:
                   memory: 2Gi
 EOF
 ```
 
-```
-TODO: show the benifit of mistnet. for example, the compared results of fedavg & mistnet.
-
-```
-
-### Check Federated Learning Status
-
-```
-kubectl get federatedlearningjob surface-defect-detection
-```
-
-### Check Federated Learning Train Result
 
-After the job completed, you will find the model generated on the directory `/model` in `$EDGE1_NODE` and `$EDGE2_NODE`.
diff --git a/examples/federated_learning/yolov5_coco128_mistnet/aggregate.py b/examples/federated_learning/yolov5_coco128_mistnet/aggregate.py
index 0ba8d558f..1e6b314e6 100644
--- a/examples/federated_learning/yolov5_coco128_mistnet/aggregate.py
+++ b/examples/federated_learning/yolov5_coco128_mistnet/aggregate.py
@@ -14,14 +14,16 @@
 
 from interface import mistnet, s3_transmitter, simple_chooser
 from interface import Dataset, Estimator
-from sedna.service.server import AggregationServer
-
+from sedna.service.server import AggregationServerV2
+from sedna.common.config import BaseConfig
 
 def run_server():
     data = Dataset()
     estimator = Estimator()
 
-    server = AggregationServer(
+    estimator.pretrained = BaseConfig.pretrained_model_url.replace("yolov5.pth", "")
+
+    server = AggregationServerV2(
         data=data,
         estimator=estimator,
         aggregation=mistnet,
diff --git a/examples/federated_learning/yolov5_coco128_mistnet/coco128.yaml b/examples/federated_learning/yolov5_coco128_mistnet/coco128.yaml
new file mode 100644
index 000000000..cabd6e133
--- /dev/null
+++ b/examples/federated_learning/yolov5_coco128_mistnet/coco128.yaml
@@ -0,0 +1,28 @@
+# COCO 2017 dataset http://cocodataset.org - first 128 training images
+# Train command: python train.py --data coco128.yaml
+# Default dataset location is next to YOLOv5:
+#   /parent_folder
+#     /coco128
+#     /yolov5
+
+
+# download command/URL (optional)
+download: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip
+
+# train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/]
+train: ./data/COCO/coco128/images/train2017/  # 128 images
+val: ./data/COCO/coco128/images/train2017/ # 128 images
+
+# number of classes
+nc: 80
+
+# class names
+names: [ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
+         'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
+         'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
+         'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
+         'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
+         'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
+         'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
+         'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
+         'hair drier', 'toothbrush' ]
diff --git a/examples/federated_learning/yolov5_coco128_mistnet/hyp.scratch.yaml b/examples/federated_learning/yolov5_coco128_mistnet/hyp.scratch.yaml
new file mode 100644
index 000000000..44f26b665
--- /dev/null
+++ b/examples/federated_learning/yolov5_coco128_mistnet/hyp.scratch.yaml
@@ -0,0 +1,33 @@
+# Hyperparameters for COCO training from scratch
+# python train.py --batch 40 --cfg yolov5m.yaml --weights '' --data coco.yaml --img 640 --epochs 300
+# See tutorials for hyperparameter evolution https://github.com/ultralytics/yolov5#tutorials
+
+
+lr0: 0.01  # initial learning rate (SGD=1E-2, Adam=1E-3)
+lrf: 0.2  # final OneCycleLR learning rate (lr0 * lrf)
+momentum: 0.937  # SGD momentum/Adam beta1
+weight_decay: 0.0005  # optimizer weight decay 5e-4
+warmup_epochs: 3.0  # warmup epochs (fractions ok)
+warmup_momentum: 0.8  # warmup initial momentum
+warmup_bias_lr: 0.1  # warmup initial bias lr
+box: 0.05  # box loss gain
+cls: 0.5  # cls loss gain
+cls_pw: 1.0  # cls BCELoss positive_weight
+obj: 1.0  # obj loss gain (scale with pixels)
+obj_pw: 1.0  # obj BCELoss positive_weight
+iou_t: 0.20  # IoU training threshold
+anchor_t: 4.0  # anchor-multiple threshold
+# anchors: 3  # anchors per output layer (0 to ignore)
+fl_gamma: 0.0  # focal loss gamma (efficientDet default gamma=1.5)
+hsv_h: 0.015  # image HSV-Hue augmentation (fraction)
+hsv_s: 0.7  # image HSV-Saturation augmentation (fraction)
+hsv_v: 0.4  # image HSV-Value augmentation (fraction)
+degrees: 0.0  # image rotation (+/- deg)
+translate: 0.1  # image translation (+/- fraction)
+scale: 0.5  # image scale (+/- gain)
+shear: 0.0  # image shear (+/- deg)
+perspective: 0.0  # image perspective (+/- fraction), range 0-0.001
+flipud: 0.0  # image flip up-down (probability)
+fliplr: 0.5  # image flip left-right (probability)
+mosaic: 1.0  # image mosaic (probability)
+mixup: 0.0  # image mixup (probability)
diff --git a/examples/federated_learning/yolov5_coco128_mistnet/interface.py b/examples/federated_learning/yolov5_coco128_mistnet/interface.py
index 6c654f58c..0f56a3a42 100644
--- a/examples/federated_learning/yolov5_coco128_mistnet/interface.py
+++ b/examples/federated_learning/yolov5_coco128_mistnet/interface.py
@@ -15,7 +15,7 @@
 from sedna.algorithms.aggregation import MistNet
 from sedna.algorithms.client_choose import SimpleClientChoose
 from sedna.common.config import Context
-from sedna.core.federated_learning import FederatedLearning
+from sedna.core.federated_learning import FederatedLearningV2
 
 simple_chooser = SimpleClientChoose(per_round=1)
 
@@ -24,7 +24,7 @@
                   epsilon=Context.get_parameters("epsilon"))
 
 # `get_transmitter_from_config()` returns a WSTransmitter or S3Transmitter.
-s3_transmitter = FederatedLearning.get_transmitter_from_config()
+s3_transmitter = FederatedLearningV2.get_transmitter_from_config()
 
 
 class Dataset:
@@ -44,6 +44,7 @@ def __init__(self) -> None:
             "num_classes": 80,
             # image size
             "image_size": 640,
+            "download_urls": ["https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip",],
             "classes":
                 [
                     "person",
@@ -134,6 +135,7 @@ def __init__(self) -> None:
 class Estimator:
     def __init__(self) -> None:
         self.model = None
+        self.pretrained = None
         self.hyperparameters = {
             "type": "yolov5",
             "rounds": 1,
diff --git a/examples/federated_learning/yolov5_coco128_mistnet/train.py b/examples/federated_learning/yolov5_coco128_mistnet/train.py
index 62886cb34..99406dd21 100644
--- a/examples/federated_learning/yolov5_coco128_mistnet/train.py
+++ b/examples/federated_learning/yolov5_coco128_mistnet/train.py
@@ -11,23 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import os
 from interface import mistnet, s3_transmitter
 from interface import Dataset, Estimator
-from sedna.core.federated_learning import FederatedLearning
-
+from sedna.common.config import BaseConfig
+from sedna.core.federated_learning import FederatedLearningV2
 
 def main():
     data = Dataset()
     estimator = Estimator()
-
-    fl_model = FederatedLearning(
+    data.parameters["data_path"] = BaseConfig.train_dataset_url.replace("robot.txt", "")
+    data.parameters["train_path"] = os.path.join(data.parameters["data_path"], "./coco128/images/train2017/")
+    data.parameters["test_path"] = data.parameters["train_path"]
+    fl_model = FederatedLearningV2(
+        data=data,
         estimator=estimator,
         aggregation=mistnet,
         transmitter=s3_transmitter)
 
-    fl_model.train(data)
-
+    fl_model.train()
 
 if __name__ == '__main__':
     main()
diff --git a/examples/federated_learning/yolov5_coco128_mistnet/yolov5s.yaml b/examples/federated_learning/yolov5_coco128_mistnet/yolov5s.yaml
new file mode 100644
index 000000000..e4e9e4dde
--- /dev/null
+++ b/examples/federated_learning/yolov5_coco128_mistnet/yolov5s.yaml
@@ -0,0 +1,48 @@
+# parameters
+nc: 80  # number of classes
+depth_multiple: 0.33  # model depth multiple
+width_multiple: 0.50  # layer channel multiple
+
+# anchors
+anchors:
+  - [ 10,13, 16,30, 33,23 ]  # P3/8
+  - [ 30,61, 62,45, 59,119 ]  # P4/16
+  - [ 116,90, 156,198, 373,326 ]  # P5/32
+
+# YOLOv5 backbone
+backbone:
+  # [from, number, module, args]
+  [ [ -1, 1, Focus, [ 64, 3 ] ],  # 0-P1/2
+    [ -1, 1, Conv, [ 128, 3, 2 ] ],  # 1-P2/4
+    [ -1, 3, C3, [ 128 ] ],
+    [ -1, 1, Conv, [ 256, 3, 2 ] ],  # 3-P3/8
+    [ -1, 9, C3, [ 256 ] ],
+    [ -1, 1, Conv, [ 512, 3, 2 ] ],  # 5-P4/16
+    [ -1, 9, C3, [ 512 ] ],
+    [ -1, 1, Conv, [ 1024, 3, 2 ] ],  # 7-P5/32
+    [ -1, 1, SPP, [ 1024, [ 5, 9, 13 ] ] ],
+    [ -1, 3, C3, [ 1024, False ] ],  # 9
+  ]
+
+# YOLOv5 head
+head:
+  [ [ -1, 1, Conv, [ 512, 1, 1 ] ],
+    [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ],
+    [ [ -1, 6 ], 1, Concat, [ 1 ] ],  # cat backbone P4
+    [ -1, 3, C3, [ 512, False ] ],  # 13
+
+    [ -1, 1, Conv, [ 256, 1, 1 ] ],
+    [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ],
+    [ [ -1, 4 ], 1, Concat, [ 1 ] ],  # cat backbone P3
+    [ -1, 3, C3, [ 256, False ] ],  # 17 (P3/8-small)
+
+    [ -1, 1, Conv, [ 256, 3, 2 ] ],
+    [ [ -1, 14 ], 1, Concat, [ 1 ] ],  # cat head P4
+    [ -1, 3, C3, [ 512, False ] ],  # 20 (P4/16-medium)
+
+    [ -1, 1, Conv, [ 512, 3, 2 ] ],
+    [ [ -1, 10 ], 1, Concat, [ 1 ] ],  # cat head P5
+    [ -1, 3, C3, [ 1024, False ] ],  # 23 (P5/32-large)
+
+    [ [ 17, 20, 23 ], 1, Detect, [ nc, anchors ] ],  # Detect(P3, P4, P5)
+  ]
diff --git a/lib/sedna/algorithms/aggregation/__init__.py b/lib/sedna/algorithms/aggregation/__init__.py
index 96d7da5e5..4725746ab 100644
--- a/lib/sedna/algorithms/aggregation/__init__.py
+++ b/lib/sedna/algorithms/aggregation/__init__.py
@@ -13,4 +13,4 @@
 # limitations under the License.
 
 from . import aggregation
-from .aggregation import FedAvg, MistNet
+from .aggregation import FedAvg, MistNet, AggClient
diff --git a/lib/sedna/algorithms/aggregation/aggregation.py b/lib/sedna/algorithms/aggregation/aggregation.py
index 6f13962f3..d2116eadd 100644
--- a/lib/sedna/algorithms/aggregation/aggregation.py
+++ b/lib/sedna/algorithms/aggregation/aggregation.py
@@ -110,8 +110,18 @@ def aggregate(self, clients: List[AggClient]):
 class MistNet(BaseAggregation, abc.ABC):
     def __init__(self, cut_layer, epsilon=100):
         super().__init__()
-        self.cut_layer = cut_layer
-        self.epsilon = epsilon
+        self.parameters = {
+            "type": "mistnet",
+            "cut_layer": cut_layer,
+            "epsilon": epsilon
+        }
+        if isinstance(self.parameters["cut_layer"], str):
+            if self.parameters["cut_layer"].isdigit():
+                self.parameters["cut_layer"] = int(cut_layer)
+
+        if isinstance(self.parameters["epsilon"], str):
+            if self.parameters["epsilon"].isdigit():
+                self.parameters["epsilon"] = int(cut_layer)
 
     def aggregate(self, clients: List[AggClient]):
         pass
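The string-to-int conversion above matters because the cut layer and epsilon usually arrive as environment-variable strings. A minimal sketch of the resulting parameters dict, with made-up values:

```python
# Minimal sketch (not part of this patch): numeric strings are normalized,
# so "4" and "100" end up as ints in MistNet's parameters dict.
from sedna.algorithms.aggregation import MistNet

mistnet = MistNet(cut_layer="4", epsilon="100")
assert mistnet.parameters == {"type": "mistnet", "cut_layer": 4, "epsilon": 100}
```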
diff --git a/lib/sedna/algorithms/client_choose/client_choose.py b/lib/sedna/algorithms/client_choose/client_choose.py
index 4bd4756a6..f7e10ca1b 100644
--- a/lib/sedna/algorithms/client_choose/client_choose.py
+++ b/lib/sedna/algorithms/client_choose/client_choose.py
@@ -33,4 +33,6 @@ class SimpleClientChoose(AbstractClientChoose):
 
     def __init__(self, per_round=1):
         super().__init__()
-        self.per_round = per_round
+        self.parameters = {
+            "per_round": per_round
+        }
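SimpleClientChoose now follows the same parameters convention as the other pluggable components; a minimal sketch:

```python
# Minimal sketch: the per-round client count lives in the parameters dict,
# which AggregationServerV2 copies into clients["per_round"].
from sedna.algorithms.client_choose.client_choose import SimpleClientChoose

chooser = SimpleClientChoose(per_round=1)
print(chooser.parameters["per_round"])  # 1
```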
diff --git a/lib/sedna/algorithms/transmitter/transmitter.py b/lib/sedna/algorithms/transmitter/transmitter.py
index 865398995..0aaac0c74 100644
--- a/lib/sedna/algorithms/transmitter/transmitter.py
+++ b/lib/sedna/algorithms/transmitter/transmitter.py
@@ -35,6 +35,9 @@ class WSTransmitter(AbstractTransmitter, ABC):
     An implementation of Transmitter based on WebSocket.
     """
 
+    def __init__(self):
+        self.parameters = {}
+
     def recv(self):
         pass
 
@@ -52,10 +55,12 @@ def __init__(self,
                  access_key,
                  secret_key,
                  transmitter_url):
-        self.s3_endpoint_url = s3_endpoint_url
-        self.access_key = access_key
-        self.secret_key = secret_key
-        self.transmitter_url = transmitter_url
+        self.parameters = {
+            "s3_endpoint_url": s3_endpoint_url,
+            "s3_bucket": transmitter_url,
+            "access_key": access_key,
+            "secret_key": secret_key
+        }
 
     def recv(self):
         pass
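Both transmitters now describe their transport settings through a parameters dict that FederatedLearningV2 and AggregationServerV2 merge straight into the server config; note that the constructor's `transmitter_url` is exposed under the `s3_bucket` key. A minimal sketch with placeholder endpoint and credentials:

```python
# Minimal sketch; endpoint, credentials and bucket path are placeholders.
from sedna.algorithms.transmitter import S3Transmitter, WSTransmitter

s3 = S3Transmitter(
    s3_endpoint_url="http://minio.example.com:9000",
    access_key="my-access-key",
    secret_key="my-secret-key",
    transmitter_url="s3://my-bucket/fl/aggregation_data")
print(s3.parameters["s3_bucket"])  # s3://my-bucket/fl/aggregation_data

ws = WSTransmitter()  # the WebSocket transport needs no extra settings
print(ws.parameters)  # {}
```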
diff --git a/lib/sedna/core/federated_learning/__init__.py b/lib/sedna/core/federated_learning/__init__.py
index c36fe3b80..e5eaad2d7 100644
--- a/lib/sedna/core/federated_learning/__init__.py
+++ b/lib/sedna/core/federated_learning/__init__.py
@@ -13,3 +13,4 @@
 # limitations under the License.
 
 from .federated_learning import FederatedLearning
+from .federated_learning import FederatedLearningV2
diff --git a/lib/sedna/core/federated_learning/federated_learning.py b/lib/sedna/core/federated_learning/federated_learning.py
index f49825dda..999a153ed 100644
--- a/lib/sedna/core/federated_learning/federated_learning.py
+++ b/lib/sedna/core/federated_learning/federated_learning.py
@@ -17,9 +17,6 @@
 import sys
 import time
 
-from plato.clients import registry as client_registry
-from plato.config import Config
-
 from sedna.algorithms.transmitter import S3Transmitter, WSTransmitter
 from sedna.common.class_factory import ClassFactory, ClassType
 from sedna.common.config import BaseConfig, Context
@@ -28,8 +25,10 @@
 from sedna.core.base import JobBase
 from sedna.service.client import AggregationClient
 
+__all__ = ('FederatedLearning', 'FederatedLearningV2')
+
 
-class FederatedLearningV0(JobBase):
+class FederatedLearning(JobBase):
     """
     Federated learning enables multiple actors to build a common, robust
     machine learning model without sharing data, thus allowing to address
@@ -187,8 +186,12 @@ def train(self, train_data,
                     task_info_res)
 
 
-class FederatedLearning:
-    def __init__(self, data=None, estimator=None, aggregation=None, transmitter=None) -> None:
+class FederatedLearningV2:
+    def __init__(self, data=None, estimator=None,
+                 aggregation=None, transmitter=None) -> None:
+
+        from plato.config import Config
+        from plato.clients import registry as client_registry
         # set parameters
         server = Config.server._asdict()
         clients = Config.clients._asdict()
@@ -196,32 +199,27 @@ def __init__(self, data=None, estimator=None, aggregation=None, transmitter=None
         train = Config.trainer._asdict()
 
         if data is not None:
-            for xkey in data.parameters:
-                datastore[xkey] = data.parameters[xkey]
+            datastore.update(data.parameters)
             Config.data = Config.namedtuple_from_dict(datastore)
 
         self.model = None
         if estimator is not None:
             self.model = estimator.model
-            for xkey in estimator.hyperparameters:
-                train[xkey] = estimator.hyperparameters[xkey]
+            train.update(estimator.hyperparameters)
             Config.trainer = Config.namedtuple_from_dict(train)
 
         if aggregation is not None:
-            Config.algorithm = Config.namedtuple_from_dict(aggregation.parameters)
+            Config.algorithm = Config.namedtuple_from_dict(
+                aggregation.parameters)
             if aggregation.parameters["type"] == "mistnet":
                 clients["type"] = "mistnet"
                 server["type"] = "mistnet"
 
-        if isinstance(transmitter, S3Transmitter):
-            server["address"] = Context.get_parameters("AGG_IP")
-            server["port"] = Context.get_parameters("AGG_PORT")
-            server["s3_endpoint_url"] = transmitter.s3_endpoint_url
-            server["s3_bucket"] = transmitter.s3_bucket
-            server["access_key"] = transmitter.access_key
-            server["secret_key"] = transmitter.secret_key
-        elif isinstance(transmitter, WSTransmitter):
-            pass
+        server["address"] = Context.get_parameters("AGG_IP")
+        server["port"] = Context.get_parameters("AGG_PORT")
+
+        if transmitter is not None:
+            server.update(transmitter.parameters)
 
         Config.server = Config.namedtuple_from_dict(server)
         Config.clients = Config.namedtuple_from_dict(clients)
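Since the transmitter handling above relies only on `transmitter.parameters`, switching transports no longer requires touching FederatedLearningV2. A minimal variation of the shipped train.py example that uses the WebSocket transport instead of S3:

```python
# Minimal sketch; assumes the example's interface.py module is importable.
from interface import mistnet, Dataset, Estimator
from sedna.algorithms.transmitter import WSTransmitter
from sedna.core.federated_learning import FederatedLearningV2

fl = FederatedLearningV2(
    data=Dataset(),            # dataset paths still need filling in, as in train.py
    estimator=Estimator(),
    aggregation=mistnet,
    transmitter=WSTransmitter())  # empty parameters: AGG_IP/AGG_PORT are enough
fl.train()
```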
diff --git a/lib/sedna/service/server/aggregation.py b/lib/sedna/service/server/aggregation.py
index 3529d6b48..717540fea 100644
--- a/lib/sedna/service/server/aggregation.py
+++ b/lib/sedna/service/server/aggregation.py
@@ -13,26 +13,26 @@
 # limitations under the License.
 
 import time
-from typing import List, Optional, Dict, Any
-
 import uuid
-from pydantic import BaseModel
+from typing import Any, Dict, List, Optional
+
 from fastapi import FastAPI, WebSocket
 from fastapi.routing import APIRoute
+from pydantic import BaseModel
+from starlette.endpoints import WebSocketEndpoint
 from starlette.requests import Request
 from starlette.responses import JSONResponse
 from starlette.routing import WebSocketRoute
-from starlette.endpoints import WebSocketEndpoint
 from starlette.types import ASGIApp, Receive, Scope, Send
 
+from sedna.algorithms.aggregation import AggClient
+from sedna.common.config import BaseConfig, Context
+from sedna.common.class_factory import ClassFactory, ClassType
 from sedna.common.log import LOGGER
 from sedna.common.utils import get_host_ip
-from sedna.common.class_factory import ClassFactory, ClassType
-from sedna.algorithms.aggregation import AggClient
-
 from .base import BaseServer
 
-__all__ = ('AggregationServer',)
+__all__ = ('AggregationServer', 'AggregationServerV2')
 
 
 class WSClientInfo(BaseModel):  # pylint: disable=too-few-public-methods
@@ -266,3 +266,56 @@ async def client_info(self, request: Request):
         if client_id:
             return server.get_client(client_id)
         return WSClientInfoList(clients=server.client_list)
+
+
+class AggregationServerV2:
+    def __init__(self, data=None, estimator=None,
+                 aggregation=None, transmitter=None,
+                 chooser=None) -> None:
+        from plato.config import Config
+        from plato.servers import registry as server_registry
+        # set parameters
+        server = Config.server._asdict()
+        clients = Config.clients._asdict()
+        datastore = Config.data._asdict()
+        train = Config.trainer._asdict()
+
+        if data is not None:
+            datastore.update(data.parameters)
+            Config.data = Config.namedtuple_from_dict(datastore)
+
+        self.model = None
+        if estimator is not None:
+            self.model = estimator.model
+            if estimator.pretrained is not None:
+                LOGGER.info(estimator.pretrained)
+                Config.params['model_dir'] = estimator.pretrained
+            train.update(estimator.hyperparameters)
+            Config.trainer = Config.namedtuple_from_dict(train)
+
+        server["address"] = Context.get_parameters("AGG_BIND_IP", "0.0.0.0")
+        server["port"] = Context.get_parameters("AGG_BIND_PORT", 7363)
+        if transmitter is not None:
+            server.update(transmitter.parameters)
+
+        if aggregation is not None:
+            Config.algorithm = Config.namedtuple_from_dict(
+                aggregation.parameters)
+            if aggregation.parameters["type"] == "mistnet":
+                clients["type"] = "mistnet"
+                server["type"] = "mistnet"
+
+        if chooser is not None:
+            clients["per_round"] = chooser.parameters["per_round"]
+
+        LOGGER.info("address %s, port %s", server["address"], server["port"])
+
+        Config.server = Config.namedtuple_from_dict(server)
+        Config.clients = Config.namedtuple_from_dict(clients)
+
+        # Config.store()
+        # create a server
+        self.server = server_registry.get(model=self.model)
+
+    def start(self):
+        self.server.run()
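For completeness, a sketch of how an aggregation worker could wire AggregationServerV2 together. The imports mirror the example's interface.py; passing `data=None` is an assumption for illustration, not a statement about what the shipped aggregate.py does.

```python
# Hypothetical aggregation worker entry point (not the shipped aggregate.py).
from interface import mistnet, s3_transmitter, Estimator
from sedna.algorithms.client_choose.client_choose import SimpleClientChoose
from sedna.service.server.aggregation import AggregationServerV2


def main():
    server = AggregationServerV2(
        data=None,                   # assumption: no Dataset object on the server side
        estimator=Estimator(),
        aggregation=mistnet,
        transmitter=s3_transmitter,
        chooser=SimpleClientChoose(per_round=1))
    server.start()


if __name__ == '__main__':
    main()
```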