From e0e96543d81a0a032aec9c4fee07b3b1f01b6721 Mon Sep 17 00:00:00 2001 From: Gaius Date: Tue, 27 Aug 2024 14:03:04 +0800 Subject: [PATCH] feat: remove trainer and model Signed-off-by: Gaius --- .github/workflows/docker.yml | 4 +- .goreleaser.yml | 20 - Makefile | 42 +- build/images/trainer/Dockerfile | 40 - cmd/trainer/cmd/root.go | 125 --- cmd/trainer/main.go | 23 - go.mod | 2 +- go.sum | 4 +- hack/build.sh | 17 - hack/docker-build.sh | 3 - hack/docker-push.sh | 3 - hack/install.sh | 19 - hack/kind-load.sh | 3 - internal/dflog/loginit.go | 20 - manager/config/config.go | 20 - manager/config/config_test.go | 26 - manager/config/constants.go | 5 - manager/config/testdata/manager.yaml | 4 - manager/database/database.go | 1 - manager/handlers/model.go | 125 --- manager/handlers/model_test.go | 268 ----- manager/models/model.go | 46 - manager/models/scheduler.go | 1 - manager/router/router.go | 7 - manager/rpcserver/manager_server_v1.go | 162 --- manager/service/model.go | 190 ---- manager/service/service.go | 5 - pkg/rpc/inference/client/client_v1.go | 127 --- .../inference/client/mocks/client_v1_mock.go | 116 --- pkg/rpc/manager/client/client_v1.go | 12 - pkg/rpc/trainer/client/client_v1.go | 99 -- .../trainer/client/mocks/client_v1_mock.go | 76 -- pkg/rpc/trainer/server/server.go | 97 -- pkg/types/constants.go | 6 - scheduler/announcer/announcer.go | 130 --- scheduler/announcer/announcer_test.go | 962 +----------------- scheduler/config/config.go | 37 - scheduler/config/config_test.go | 51 - scheduler/config/constants.go | 11 - scheduler/config/testdata/scheduler.yaml | 6 - scheduler/scheduler.go | 43 +- trainer/config/config.go | 232 ----- trainer/config/config_test.go | 267 ----- trainer/config/constants.go | 64 -- trainer/config/testdata/ca.crt | 1 - trainer/config/testdata/trainer.yaml | 33 - trainer/metrics/metrics.go | 68 -- trainer/metrics/metrics_test.go | 42 - trainer/rpcserver/rpcserver.go | 38 - trainer/rpcserver/rpcserver_test.go | 56 - trainer/rpcserver/trainer_server_v1.go | 51 - trainer/service/service_v1.go | 162 --- trainer/service/service_v1_test.go | 505 --------- trainer/storage/mocks/storage_mock.go | 143 --- trainer/storage/storage.go | 148 --- trainer/storage/storage_test.go | 554 ---------- trainer/storage/testdata/download.csv | 1 - trainer/storage/testdata/networktopology.csv | 1 - trainer/trainer.go | 187 ---- trainer/training/mocks/training_mock.go | 54 - trainer/training/training.go | 98 -- 61 files changed, 26 insertions(+), 5637 deletions(-) delete mode 100644 build/images/trainer/Dockerfile delete mode 100644 cmd/trainer/cmd/root.go delete mode 100644 cmd/trainer/main.go delete mode 100644 manager/handlers/model.go delete mode 100644 manager/handlers/model_test.go delete mode 100644 manager/models/model.go delete mode 100644 manager/service/model.go delete mode 100644 pkg/rpc/inference/client/client_v1.go delete mode 100644 pkg/rpc/inference/client/mocks/client_v1_mock.go delete mode 100644 pkg/rpc/trainer/client/client_v1.go delete mode 100644 pkg/rpc/trainer/client/mocks/client_v1_mock.go delete mode 100644 pkg/rpc/trainer/server/server.go delete mode 100644 trainer/config/config.go delete mode 100644 trainer/config/config_test.go delete mode 100644 trainer/config/constants.go delete mode 100644 trainer/config/testdata/ca.crt delete mode 100644 trainer/config/testdata/trainer.yaml delete mode 100644 trainer/metrics/metrics.go delete mode 100644 trainer/metrics/metrics_test.go delete mode 100644 trainer/rpcserver/rpcserver.go delete mode 100644 trainer/rpcserver/rpcserver_test.go delete mode 100644 trainer/rpcserver/trainer_server_v1.go delete mode 100644 trainer/service/service_v1.go delete mode 100644 trainer/service/service_v1_test.go delete mode 100644 trainer/storage/mocks/storage_mock.go delete mode 100644 trainer/storage/storage.go delete mode 100644 trainer/storage/storage_test.go delete mode 100644 trainer/storage/testdata/download.csv delete mode 100644 trainer/storage/testdata/networktopology.csv delete mode 100644 trainer/trainer.go delete mode 100644 trainer/training/mocks/training_mock.go delete mode 100644 trainer/training/training.go diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 5fe1763151e..dc4632b4c89 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - module: ["manager", "scheduler", "dfdaemon", "trainer"] + module: ["manager", "scheduler", "dfdaemon"] include: - module: manager platforms: linux/amd64,linux/arm64 @@ -21,8 +21,6 @@ jobs: platforms: linux/amd64,linux/arm64 - module: dfdaemon platforms: linux/amd64,linux/arm64 - - module: trainer - platforms: linux/amd64,linux/arm64 timeout-minutes: 120 steps: - name: Check out code diff --git a/.goreleaser.yml b/.goreleaser.yml index 6e12952d784..e38fdaa4986 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -109,26 +109,6 @@ builds: env: - CGO_ENABLED=0 - - main: ./cmd/trainer - id: trainer - binary: trainer - goos: - - linux - - darwin - goarch: - - amd64 - - arm64 - ldflags: - - -X d7y.io/dragonfly/v2/version.Major={{ .Major }} - - -X d7y.io/dragonfly/v2/version.Minor={{ .Minor }} - - -X d7y.io/dragonfly/v2/version.GitVersion={{ .Tag }} - - -X d7y.io/dragonfly/v2/version.GitCommit={{ .ShortCommit }} - - -X d7y.io/dragonfly/v2/version.BuildTime={{ .Date }} - - -X "d7y.io/dragonfly/v2/version.Gotags=none" - - -X "d7y.io/dragonfly/v2/version.Gogcflags=none" - env: - - CGO_ENABLED=0 - archives: - name_template: "dragonfly-{{ .Version }}-{{ .Os }}-{{ .Arch }}" format: tar.gz diff --git a/Makefile b/Makefile index 0c46f41f658..b13935f2137 100644 --- a/Makefile +++ b/Makefile @@ -34,12 +34,12 @@ build-dirs: .PHONY: build-dirs # Build dragonfly. -docker-build: docker-build-dfdaemon docker-build-scheduler docker-build-manager docker-build-trainer +docker-build: docker-build-dfdaemon docker-build-scheduler docker-build-manager @echo "Build image done." .PHONY: docker-build # Push dragonfly images. -docker-push: docker-push-dfdaemon docker-push-scheduler docker-push-manager docker-build-trainer +docker-push: docker-push-dfdaemon docker-push-scheduler docker-push-manager @echo "Push image done." .PHONY: docker-push @@ -61,12 +61,6 @@ docker-build-manager: ./hack/docker-build.sh manager .PHONY: docker-build-manager -# Build trainer image. -docker-build-trainer: - @echo "Begin to use docker build trainer image." - ./hack/docker-build.sh trainer -.PHONY: docker-build-trainer - # Build testing tools image. docker-build-testing-tools: build-dirs @echo "Begin to testing tools image." @@ -91,14 +85,8 @@ docker-push-manager: docker-build-manager ./hack/docker-push.sh manager .PHONY: docker-push-manager -# Push trainer image. -docker-push-trainer: docker-build-trainer - @echo "Begin to push trainer docker image." - ./hack/docker-push.sh trainer -.PHONY: docker-push-trainer - # Build dragonfly. -build: build-manager build-scheduler build-trainer build-dfget build-dfcache build-dfstore +build: build-manager build-scheduler build-dfget build-dfcache build-dfstore .PHONY: build # Build dfget. @@ -161,12 +149,6 @@ build-manager-console: build-dirs ./hack/build.sh manager-console .PHONY: build-manager-console -# Build trainer. -build-trainer: build-dirs - @echo "Begin to build trainer." - ./hack/build.sh trainer -.PHONY: build-trainer - # Install dfget. install-dfget: @echo "Begin to install dfget." @@ -185,12 +167,6 @@ install-manager: ./hack/install.sh install manager .PHONY: install-manager -# Install trainer. -install-trainer: - @echo "Begin to install trainer." - ./hack/install.sh install trainer -.PHONY: install-trainer - # Build rpm dfget. build-rpm-dfget: build-linux-dfget @echo "Begin to build rpm dfget." @@ -373,7 +349,7 @@ clean-e2e-test: .PHONY: clean-e2e-test # Kind load dragonfly. -kind-load: kind-load-scheduler kind-load-dfdaemon kind-load-manager kind-load-trainer kind-load-testing-tools +kind-load: kind-load-scheduler kind-load-dfdaemon kind-load-manager kind-load-testing-tools @echo "Kind load image done." .PHONY: kind-load @@ -392,11 +368,6 @@ kind-load-manager: @./hack/kind-load.sh manager .PHONY: kind-load-manager -# Run kind load docker trainer. -kind-load-trainer: - @./hack/kind-load.sh trainer -.PHONY: kind-load-trainer - # Run kind load docker testing tools. kind-load-testing-tools: @./hack/kind-load.sh no-content-length @@ -441,11 +412,9 @@ help: @echo "make docker-build-dfdaemon build dfdaemon image" @echo "make docker-build-scheduler build scheduler image" @echo "make docker-build-manager build manager image" - @echo "make docker-build-trainer build trainer image" @echo "make docker-push-dfdaemon push dfdaemon image" @echo "make docker-push-scheduler push scheduler image" @echo "make docker-push-manager push manager image" - @echo "make docker-push-trainer push trainer image" @echo "make build build dragonfly" @echo "make build-dfget build dfget" @echo "make build-linux-dfget build linux dfget" @@ -457,13 +426,11 @@ help: @echo "make build-manager build manager" @echo "make build-manager-server build manager server" @echo "make build-manager-console build manager console" - @echo "make build-trainer build trainer" @echo "make build-e2e-sha256sum build sha256sum test tool" @echo "make build-e2e-download-grpc-test build download grpc test tool" @echo "make install-dfget install dfget" @echo "make install-scheduler install scheduler" @echo "make install-manager install manager" - @echo "make install-trainer install trainer" @echo "make build-rpm-dfget build rpm dfget" @echo "make build-rpm-dfcache build rpm dfcache" @echo "make build-rpm-dfstore build rpm dfstore" @@ -485,7 +452,6 @@ help: @echo "make kind-load-scheduler kind load scheduler docker image" @echo "make kind-load-dfdaemon kind load dfdaemon docker image" @echo "make kind-load-manager kind load manager docker image" - @echo "make kind-load-trainer kind load trainer docker image" @echo "make kind-load-testing-tools kind load testing tools docker image" @echo "make lint run code lint" @echo "make markdownlint run markdown lint" diff --git a/build/images/trainer/Dockerfile b/build/images/trainer/Dockerfile deleted file mode 100644 index df2a9bc1933..00000000000 --- a/build/images/trainer/Dockerfile +++ /dev/null @@ -1,40 +0,0 @@ -ARG BASE_IMAGE=alpine:3.17 - -FROM golang:1.21.1-alpine3.17 AS builder - -ARG GOPROXY -ARG GOTAGS -ARG GOGCFLAGS - -WORKDIR /go/src/d7y.io/dragonfly/v2 - -RUN apk --no-cache add bash make gcc libc-dev git - -COPY . /go/src/d7y.io/dragonfly/v2 - -RUN make build-trainer && make install-trainer - -FROM ${BASE_IMAGE} AS health - -ENV GRPC_HEALTH_PROBE_VERSION v0.4.24 - -RUN if [ "$(uname -m)" = "ppc64le" ]; then \ - wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-ppc64le; \ - elif [ "$(uname -m)" = "aarch64" ]; then \ - wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-arm64; \ - else \ - wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-amd64; \ - fi && \ - chmod +x /bin/grpc_health_probe - -FROM ${BASE_IMAGE} - -ENV PATH=/opt/dragonfly/bin:$PATH -RUN echo "hosts: files dns" > /etc/nsswitch.conf - -COPY --from=builder /opt/dragonfly/bin/trainer /opt/dragonfly/bin/trainer -COPY --from=health /bin/grpc_health_probe /bin/grpc_health_probe - -EXPOSE 9090 - -ENTRYPOINT ["/opt/dragonfly/bin/trainer"] diff --git a/cmd/trainer/cmd/root.go b/cmd/trainer/cmd/root.go deleted file mode 100644 index 666f6a6bbc4..00000000000 --- a/cmd/trainer/cmd/root.go +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Copyright 2023 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package cmd - -import ( - "context" - "fmt" - "os" - "path" - - "github.com/spf13/cobra" - - "d7y.io/dragonfly/v2/cmd/dependency" - logger "d7y.io/dragonfly/v2/internal/dflog" - "d7y.io/dragonfly/v2/pkg/dfpath" - "d7y.io/dragonfly/v2/pkg/types" - "d7y.io/dragonfly/v2/trainer" - "d7y.io/dragonfly/v2/trainer/config" - "d7y.io/dragonfly/v2/version" -) - -var ( - cfg *config.Config -) - -// rootCmd represents the commonv1 command when called without any subcommands. -var rootCmd = &cobra.Command{ - Use: "trainer", - Short: "the trainer of dragonfly", - Long: `Trainer is a long-running process and is mainly responsible for receiving historical download and network topology records, -preprocessing original record data, establing datasets and training machine learning and AI models that support scheduler peer-scheduling decisions.`, - Args: cobra.NoArgs, - DisableAutoGenTag: true, - SilenceUsage: true, - RunE: func(cmd *cobra.Command, args []string) error { - // Convert config. - if err := cfg.Convert(); err != nil { - return err - } - - // Validate config. - if err := cfg.Validate(); err != nil { - return err - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Initialize dfpath. - d, err := initDfpath(&cfg.Server) - if err != nil { - return err - } - rotateConfig := logger.LogRotateConfig{ - MaxSize: cfg.Server.LogMaxSize, - MaxAge: cfg.Server.LogMaxAge, - MaxBackups: cfg.Server.LogMaxBackups} - - // Initialize logger. - if err := logger.InitTrainer(cfg.Verbose, cfg.Console, d.LogDir(), rotateConfig); err != nil { - return fmt.Errorf("init trainer logger: %w", err) - } - logger.RedirectStdoutAndStderr(cfg.Console, path.Join(d.LogDir(), types.TrainerName)) - - return runTrainer(ctx, d) - }, -} - -// Execute adds all child commands to the root command and sets flags appropriately. -// This is called by main.main(). It only needs to happen once to the rootCmd. -func Execute() { - if err := rootCmd.Execute(); err != nil { - logger.Error(err) - os.Exit(1) - } -} - -func init() { - // Initialize default scheduler config. - cfg = config.New() - // Initialize command and config. - dependency.InitCommandAndConfig(rootCmd, true, cfg) -} - -func initDfpath(cfg *config.ServerConfig) (dfpath.Dfpath, error) { - var options []dfpath.Option - if cfg.LogDir != "" { - options = append(options, dfpath.WithLogDir(cfg.LogDir)) - } - - if cfg.DataDir != "" { - options = append(options, dfpath.WithDataDir(cfg.DataDir)) - } - - return dfpath.New(options...) -} - -func runTrainer(ctx context.Context, d dfpath.Dfpath) error { - logger.Infof("version:\n%s", version.Version()) - - ff := dependency.InitMonitor(cfg.PProfPort, cfg.Telemetry) - defer ff() - - svr, err := trainer.New(ctx, cfg, d) - if err != nil { - return err - } - - dependency.SetupQuitSignalHandler(func() { svr.Stop() }) - return svr.Serve() -} diff --git a/cmd/trainer/main.go b/cmd/trainer/main.go deleted file mode 100644 index 5478f597a28..00000000000 --- a/cmd/trainer/main.go +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright 2023 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package main - -import "d7y.io/dragonfly/v2/cmd/trainer/cmd" - -func main() { - cmd.Execute() -} diff --git a/go.mod b/go.mod index 44f213647c5..c3750d6649d 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module d7y.io/dragonfly/v2 go 1.21 require ( - d7y.io/api/v2 v2.0.148 + d7y.io/api/v2 v2.0.154 github.com/MysteriousPotato/go-lockable v1.0.0 github.com/RichardKnop/machinery v1.10.8 github.com/Showmax/go-fqdn v1.0.0 diff --git a/go.sum b/go.sum index bbd3d266732..e2a40f1c63e 100644 --- a/go.sum +++ b/go.sum @@ -53,8 +53,8 @@ cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0Zeo cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk= cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= -d7y.io/api/v2 v2.0.148 h1:11waj+EuaHdx95Fkr3hXJJckNGw9Hu5U0ohtCbpIirw= -d7y.io/api/v2 v2.0.148/go.mod h1:hyEaaIglThVWRHODk2yHN/tpa1L+/nPgdQFwPsL6fNc= +d7y.io/api/v2 v2.0.154 h1:IBCV+c1PYFIWyE/Otj5AsFGi5+s7TWcujOpzDTt1P5c= +d7y.io/api/v2 v2.0.154/go.mod h1:hyEaaIglThVWRHODk2yHN/tpa1L+/nPgdQFwPsL6fNc= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/Azure/azure-sdk-for-go v16.2.1+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.0.0/go.mod h1:uGG2W01BaETf0Ozp+QxxKJdMBNRWPdstHG0Fmdwn1/U= diff --git a/hack/build.sh b/hack/build.sh index a3ff7a1ebd4..ea946a08c30 100755 --- a/hack/build.sh +++ b/hack/build.sh @@ -9,7 +9,6 @@ DFCACHE_BINARY_NAME=dfcache DFSTORE_BINARY_NAME=dfstore SCHEDULER_BINARY_NAME=scheduler MANAGER_BINARY_NAME=manager -TRAINER_BINARY_NAME=trainer PKG=d7y.io/dragonfly/v2 BUILD_IMAGE=golang:1.21.1-alpine3.17 @@ -70,10 +69,6 @@ build-manager-local() { build-local ${MANAGER_BINARY_NAME} manager } -build-trainer-local() { - build-local ${TRAINER_BINARY_NAME} trainer -} - build-docker() { cd "${BUILD_SOURCE_HOME}" || return docker run \ @@ -129,10 +124,6 @@ build-manager-console() { cp -r $CONSOLE_ASSETS $MANAGER_ASSETS_DIR } -build-trainer-docker() { - build-docker ${TRAINER_BINARY_NAME} trainer -} - main() { create-dirs if [[ "1" == "${USE_DOCKER}" ]]; then @@ -150,9 +141,6 @@ main() { scheduler) build-scheduler-docker ;; - trainer) - build-trainer-docker - ;; manager) build-manager-docker ;; @@ -165,7 +153,6 @@ main() { build-dfstore-docker build-scheduler-docker build-manager-docker - build-trainer-docker ;; esac else @@ -183,9 +170,6 @@ main() { scheduler) build-scheduler-local ;; - trainer) - build-trainer-local - ;; manager) build-manager-local ;; @@ -198,7 +182,6 @@ main() { build-dfstore-local build-scheduler-local build-manager-local - build-trainer-local ;; esac fi diff --git a/hack/docker-build.sh b/hack/docker-build.sh index 68408bf6c35..bd6f413c962 100755 --- a/hack/docker-build.sh +++ b/hack/docker-build.sh @@ -55,9 +55,6 @@ main() { manager) git-submodule docker-build manager - ;; - trainer) - docker-build trainer esac } diff --git a/hack/docker-push.sh b/hack/docker-push.sh index 7f0fed5ae01..e2b9e7d504f 100755 --- a/hack/docker-push.sh +++ b/hack/docker-push.sh @@ -23,9 +23,6 @@ main() { ;; manager) docker-push manager - ;; - trainer) - docker-push trainer esac } diff --git a/hack/install.sh b/hack/install.sh index 09856ca3c7c..21a3f913cda 100755 --- a/hack/install.sh +++ b/hack/install.sh @@ -8,7 +8,6 @@ BIN_DIR="../bin" DFGET_BINARY_NAME=dfget SCHEDULER_BINARY_NAME=scheduler MANAGER_BINARY_NAME=manager -TRAINER_BINARY_NAME=trainer curDir=$(cd "$(dirname "$0")" && pwd) cd "${curDir}" || return @@ -25,9 +24,6 @@ install() { ;; manager) install-manager - ;; - trainer) - install-trainer esac } @@ -76,21 +72,6 @@ uninstall-manager() { test -e /usr/local/bin/manager && unlink /usr/local/bin/manager } -install-trainer() { - local bin="${INSTALL_HOME}/${INSTALL_BIN_PATH}" - echo "install: ${bin}" - mkdir -p "${bin}" - - cp "${BIN_DIR}/${GOOS}_${GOARCH}/${TRAINER_BINARY_NAME}" "${bin}" - - createLink "${bin}/${TRAINER_BINARY_NAME}" /usr/local/bin/trainer -} - -uninstall-trainer() { - echo "unlink /usr/local/bin/trainer" - test -e /usr/local/bin/trainer && unlink /usr/local/bin/trainer -} - createLink() { srcPath="$1" linkPath="$2" diff --git a/hack/kind-load.sh b/hack/kind-load.sh index 777b1289825..044ee1c08b9 100755 --- a/hack/kind-load.sh +++ b/hack/kind-load.sh @@ -22,9 +22,6 @@ main() { manager) kind-load manager ;; - trainer) - kind-load trainer - ;; no-content-length) kind-load no-content-length esac diff --git a/internal/dflog/loginit.go b/internal/dflog/loginit.go index 740cb606736..c3dc16489c7 100644 --- a/internal/dflog/loginit.go +++ b/internal/dflog/loginit.go @@ -157,26 +157,6 @@ func InitDfcache(console bool, dir string, rotateConfig LogRotateConfig) error { return createFileLogger(console, meta, logDir, rotateConfig) } -func InitTrainer(verbose, console bool, dir string, rotateConfig LogRotateConfig) error { - if console { - return createConsoleLogger(verbose) - } - - logDir := filepath.Join(dir, types.TrainerName) - var meta = []logInitMeta{ - { - fileName: CoreLogFileName, - setSugaredLoggerFunc: SetCoreLogger, - }, - { - fileName: GrpcLogFileName, - setSugaredLoggerFunc: SetGrpcLogger, - }, - } - - return createFileLogger(console, meta, logDir, rotateConfig) -} - func createConsoleLogger(verbose bool) error { levels = nil config := zap.NewDevelopmentConfig() diff --git a/manager/config/config.go b/manager/config/config.go index 5c5e300d48b..7085076976e 100644 --- a/manager/config/config.go +++ b/manager/config/config.go @@ -60,9 +60,6 @@ type Config struct { // Network configuration. Network NetworkConfig `yaml:"network" mapstructure:"network"` - - // Trainer configuration. - Trainer TrainerConfig `yaml:"trainer" mapstructure:"trainer"` } type ServerConfig struct { @@ -384,14 +381,6 @@ type NetworkConfig struct { EnableIPv6 bool `mapstructure:"enableIPv6" yaml:"enableIPv6"` } -type TrainerConfig struct { - // Enable trainer service. - Enable bool `yaml:"enable" mapstructure:"enable"` - - // BucketName is the object storage bucket name of model. - BucketName string `yaml:"bucketName" mapstructure:"bucketName"` -} - // New config instance. func New() *Config { return &Config{ @@ -476,10 +465,6 @@ func New() *Config { Network: NetworkConfig{ EnableIPv6: DefaultNetworkEnableIPv6, }, - Trainer: TrainerConfig{ - Enable: false, - BucketName: DefaultTrainerBucketName, - }, } } @@ -689,11 +674,6 @@ func (cfg *Config) Validate() error { } } - if cfg.Trainer.Enable { - if cfg.Trainer.BucketName == "" { - return errors.New("trainer requires parameter bucketName") - } - } return nil } diff --git a/manager/config/config_test.go b/manager/config/config_test.go index 9dc260e7983..dd80e914091 100644 --- a/manager/config/config_test.go +++ b/manager/config/config_test.go @@ -103,11 +103,6 @@ var ( ValidityPeriod: DefaultCertValidityPeriod, }, } - - mockTrainerConfig = TrainerConfig{ - Enable: true, - BucketName: DefaultTrainerBucketName, - } ) func TestConfig_Load(t *testing.T) { @@ -230,10 +225,6 @@ func TestConfig_Load(t *testing.T) { Network: NetworkConfig{ EnableIPv6: true, }, - Trainer: TrainerConfig{ - Enable: true, - BucketName: "models", - }, } managerConfigYAML := &Config{} @@ -944,23 +935,6 @@ func TestConfig_Validate(t *testing.T) { assert.EqualError(err, "certSpec requires parameter validityPeriod") }, }, - { - name: "trainer requires parameter bucketName", - config: New(), - mock: func(cfg *Config) { - cfg.Auth.JWT = mockJWTConfig - cfg.Database.Type = DatabaseTypeMysql - cfg.Database.Mysql = mockMysqlConfig - cfg.Database.Redis = mockRedisConfig - cfg.Security = mockSecurityConfig - cfg.Trainer = mockTrainerConfig - cfg.Trainer.BucketName = "" - }, - expect: func(t *testing.T, err error) { - assert := assert.New(t) - assert.EqualError(err, "trainer requires parameter bucketName") - }, - }, } for _, tc := range tests { diff --git a/manager/config/constants.go b/manager/config/constants.go index 75e6ad8e0a7..55976d6a4b4 100644 --- a/manager/config/constants.go +++ b/manager/config/constants.go @@ -148,8 +148,3 @@ var ( // DefaultNetworkEnableIPv6 is default value of enableIPv6. DefaultNetworkEnableIPv6 = false ) - -var ( - // DefaultTrainerBucketName is default object storage bucket name of model. - DefaultTrainerBucketName = "models" -) diff --git a/manager/config/testdata/manager.yaml b/manager/config/testdata/manager.yaml index 884fda293df..40a56e454cd 100644 --- a/manager/config/testdata/manager.yaml +++ b/manager/config/testdata/manager.yaml @@ -101,7 +101,3 @@ metrics: network: enableIPv6: true - -trainer: - enable: true - bucketName: models diff --git a/manager/database/database.go b/manager/database/database.go index 90ce45be5c9..38a5be637aa 100644 --- a/manager/database/database.go +++ b/manager/database/database.go @@ -92,7 +92,6 @@ func migrate(db *gorm.DB) error { &models.Oauth{}, &models.Config{}, &models.Application{}, - &models.Model{}, &models.PersonalAccessToken{}, &models.Peer{}, ) diff --git a/manager/handlers/model.go b/manager/handlers/model.go deleted file mode 100644 index c80d28f2ed0..00000000000 --- a/manager/handlers/model.go +++ /dev/null @@ -1,125 +0,0 @@ -package handlers - -import ( - "net/http" - - "github.com/gin-gonic/gin" - - _ "d7y.io/dragonfly/v2/manager/models" // nolint - "d7y.io/dragonfly/v2/manager/types" -) - -// @Summary Destroy Model -// @Description Destroy by id -// @Tags Model -// @Accept json -// @Produce json -// @Param id path string true "id" -// @Success 200 -// @Failure 400 -// @Failure 404 -// @Failure 500 -// @Router /models/{id} [delete] -func (h *Handlers) DestroyModel(ctx *gin.Context) { - var params types.ModelParams - if err := ctx.ShouldBindUri(¶ms); err != nil { - ctx.JSON(http.StatusUnprocessableEntity, gin.H{"errors": err.Error()}) - return - } - - if err := h.service.DestroyModel(ctx.Request.Context(), params.ID); err != nil { - ctx.Error(err) // nolint: errcheck - return - } - - ctx.Status(http.StatusOK) -} - -// @Summary Update Model -// @Description Update by json config -// @Tags Model -// @Accept json -// @Produce json -// @Param id path string true "id" -// @Param Model body types.UpdateModelRequest true "Model" -// @Success 200 {object} models.Model -// @Failure 400 -// @Failure 404 -// @Failure 500 -// @Router /models/{id} [patch] -func (h *Handlers) UpdateModel(ctx *gin.Context) { - var params types.ModelParams - if err := ctx.ShouldBindUri(¶ms); err != nil { - ctx.JSON(http.StatusUnprocessableEntity, gin.H{"errors": err.Error()}) - return - } - - var json types.UpdateModelRequest - if err := ctx.ShouldBindJSON(&json); err != nil { - ctx.JSON(http.StatusUnprocessableEntity, gin.H{"errors": err.Error()}) - return - } - - model, err := h.service.UpdateModel(ctx.Request.Context(), params.ID, json) - if err != nil { - ctx.Error(err) // nolint: errcheck - return - } - - ctx.JSON(http.StatusOK, model) -} - -// @Summary Get Model -// @Description Get Model by id -// @Tags Model -// @Accept json -// @Produce json -// @Param id path string true "id" -// @Success 200 {object} models.Model -// @Failure 400 -// @Failure 404 -// @Failure 500 -// @Router /models/{id} [get] -func (h *Handlers) GetModel(ctx *gin.Context) { - var params types.ModelParams - if err := ctx.ShouldBindUri(¶ms); err != nil { - ctx.JSON(http.StatusUnprocessableEntity, gin.H{"errors": err.Error()}) - return - } - - model, err := h.service.GetModel(ctx.Request.Context(), params.ID) - if err != nil { - ctx.Error(err) // nolint: errcheck - return - } - - ctx.JSON(http.StatusOK, model) -} - -// @Summary Get Models -// @Description Get Models -// @Tags Model -// @Accept json -// @Produce json -// @Success 200 {object} []models.Model -// @Failure 400 -// @Failure 404 -// @Failure 500 -// @Router /models [get] -func (h *Handlers) GetModels(ctx *gin.Context) { - var query types.GetModelsQuery - if err := ctx.ShouldBindQuery(&query); err != nil { - ctx.JSON(http.StatusUnprocessableEntity, gin.H{"errors": err.Error()}) - return - } - - h.setPaginationDefault(&query.Page, &query.PerPage) - models, count, err := h.service.GetModels(ctx.Request.Context(), query) - if err != nil { - ctx.Error(err) // nolint: errcheck - return - } - - h.setPaginationLinkHeader(ctx, query.Page, query.PerPage, int(count)) - ctx.JSON(http.StatusOK, models) -} diff --git a/manager/handlers/model_test.go b/manager/handlers/model_test.go deleted file mode 100644 index baafc56dff3..00000000000 --- a/manager/handlers/model_test.go +++ /dev/null @@ -1,268 +0,0 @@ -/* - * Copyright 2024 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package handlers - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" - "go.uber.org/mock/gomock" - - "d7y.io/dragonfly/v2/manager/models" - "d7y.io/dragonfly/v2/manager/service/mocks" - "d7y.io/dragonfly/v2/manager/types" -) - -var ( - mockModelReqBody = ` - { - "bio": "bio", - "state": "active" - }` - mockUpdateModelRequest = types.UpdateModelRequest{ - BIO: "bio", - State: "active", - } - mockModel = &models.Model{ - BaseModel: mockBaseModel, - Name: "name", - Type: "type", - BIO: "bio", - Version: "version", - State: "state", - Evaluation: nil, - SchedulerID: 8, - Scheduler: models.Scheduler{}, - } -) - -func mockModelRouter(h *Handlers) *gin.Engine { - r := gin.Default() - apiv1 := r.Group("/api/v1") - model := apiv1.Group("/models") - model.DELETE(":id", h.DestroyModel) - model.PATCH(":id", h.UpdateModel) - model.GET(":id", h.GetModel) - model.GET("", h.GetModels) - return r -} - -func TestHandlers_DestroyModel(t *testing.T) { - tests := []struct { - name string - req *http.Request - mock func(ms *mocks.MockServiceMockRecorder) - expect func(t *testing.T, w *httptest.ResponseRecorder) - }{ - { - name: "unprocessable entity", - req: httptest.NewRequest(http.MethodDelete, "/api/v1/models/test", nil), - mock: func(ms *mocks.MockServiceMockRecorder) {}, - expect: func(t *testing.T, w *httptest.ResponseRecorder) { - assert := assert.New(t) - assert.Equal(http.StatusUnprocessableEntity, w.Code) - }, - }, - { - name: "success", - req: httptest.NewRequest(http.MethodDelete, "/api/v1/models/2", nil), - mock: func(ms *mocks.MockServiceMockRecorder) { - ms.DestroyModel(gomock.Any(), gomock.Eq(uint(2))).Return(nil).Times(1) - }, - expect: func(t *testing.T, w *httptest.ResponseRecorder) { - assert := assert.New(t) - assert.Equal(http.StatusOK, w.Code) - }, - }, - } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ctl := gomock.NewController(t) - defer ctl.Finish() - svc := mocks.NewMockService(ctl) - w := httptest.NewRecorder() - h := New(svc) - mockRouter := mockModelRouter(h) - - tc.mock(svc.EXPECT()) - mockRouter.ServeHTTP(w, tc.req) - tc.expect(t, w) - }) - } -} - -func TestHandlers_UpdateModel(t *testing.T) { - tests := []struct { - name string - req *http.Request - mock func(ms *mocks.MockServiceMockRecorder) - expect func(t *testing.T, w *httptest.ResponseRecorder) - }{ - { - name: "unprocessable entity caused by uri", - req: httptest.NewRequest(http.MethodPatch, "/api/v1/models/test", nil), - mock: func(ms *mocks.MockServiceMockRecorder) {}, - expect: func(t *testing.T, w *httptest.ResponseRecorder) { - assert := assert.New(t) - assert.Equal(http.StatusUnprocessableEntity, w.Code) - }, - }, - { - name: "unprocessable entity caused by body", - req: httptest.NewRequest(http.MethodPatch, "/api/v1/models/2", nil), - mock: func(ms *mocks.MockServiceMockRecorder) {}, - expect: func(t *testing.T, w *httptest.ResponseRecorder) { - assert := assert.New(t) - assert.Equal(http.StatusUnprocessableEntity, w.Code) - }, - }, - { - name: "success", - req: httptest.NewRequest(http.MethodPatch, "/api/v1/models/2", strings.NewReader(mockModelReqBody)), - mock: func(ms *mocks.MockServiceMockRecorder) { - ms.UpdateModel(gomock.Any(), gomock.Eq(uint(2)), gomock.Eq(mockUpdateModelRequest)).Return(mockModel, nil).Times(1) - }, - expect: func(t *testing.T, w *httptest.ResponseRecorder) { - assert := assert.New(t) - assert.Equal(http.StatusOK, w.Code) - model := models.Model{} - err := json.Unmarshal(w.Body.Bytes(), &model) - assert.NoError(err) - assert.Equal(mockModel, &model) - }, - }, - } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ctl := gomock.NewController(t) - defer ctl.Finish() - svc := mocks.NewMockService(ctl) - w := httptest.NewRecorder() - h := New(svc) - mockRouter := mockModelRouter(h) - - tc.mock(svc.EXPECT()) - mockRouter.ServeHTTP(w, tc.req) - tc.expect(t, w) - }) - } -} - -func TestHandlers_GetModel(t *testing.T) { - tests := []struct { - name string - req *http.Request - mock func(ms *mocks.MockServiceMockRecorder) - expect func(t *testing.T, w *httptest.ResponseRecorder) - }{ - { - name: "unprocessable entity", - req: httptest.NewRequest(http.MethodGet, "/api/v1/models/test", nil), - mock: func(ms *mocks.MockServiceMockRecorder) {}, - expect: func(t *testing.T, w *httptest.ResponseRecorder) { - assert := assert.New(t) - assert.Equal(http.StatusUnprocessableEntity, w.Code) - }, - }, - { - name: "success", - req: httptest.NewRequest(http.MethodGet, "/api/v1/models/2", nil), - mock: func(ms *mocks.MockServiceMockRecorder) { - ms.GetModel(gomock.Any(), gomock.Eq(uint(2))).Return(mockModel, nil).Times(1) - }, - expect: func(t *testing.T, w *httptest.ResponseRecorder) { - assert := assert.New(t) - assert.Equal(http.StatusOK, w.Code) - model := models.Model{} - err := json.Unmarshal(w.Body.Bytes(), &model) - assert.NoError(err) - assert.Equal(mockModel, &model) - }, - }, - } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ctl := gomock.NewController(t) - defer ctl.Finish() - svc := mocks.NewMockService(ctl) - w := httptest.NewRecorder() - h := New(svc) - mockRouter := mockModelRouter(h) - - tc.mock(svc.EXPECT()) - mockRouter.ServeHTTP(w, tc.req) - tc.expect(t, w) - }) - } -} - -func TestHandlers_GetModels(t *testing.T) { - tests := []struct { - name string - req *http.Request - mock func(ms *mocks.MockServiceMockRecorder) - expect func(t *testing.T, w *httptest.ResponseRecorder) - }{ - { - name: "unprocessable entity", - req: httptest.NewRequest(http.MethodGet, "/api/v1/models?page=-1", nil), - mock: func(ms *mocks.MockServiceMockRecorder) {}, - expect: func(t *testing.T, w *httptest.ResponseRecorder) { - assert := assert.New(t) - assert.Equal(http.StatusUnprocessableEntity, w.Code) - }, - }, - { - name: "success", - req: httptest.NewRequest(http.MethodGet, "/api/v1/models?name=foo", nil), - mock: func(ms *mocks.MockServiceMockRecorder) { - ms.GetModels(gomock.Any(), gomock.Eq(types.GetModelsQuery{ - Name: "", - Page: 1, - PerPage: 10, - })).Return([]models.Model{*mockModel}, int64(1), nil).Times(1) - }, - expect: func(t *testing.T, w *httptest.ResponseRecorder) { - assert := assert.New(t) - assert.Equal(http.StatusOK, w.Code) - model := models.Model{} - err := json.Unmarshal(w.Body.Bytes()[1:w.Body.Len()-1], &model) - assert.NoError(err) - assert.Equal(mockModel, &model) - }, - }, - } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ctl := gomock.NewController(t) - defer ctl.Finish() - svc := mocks.NewMockService(ctl) - w := httptest.NewRecorder() - h := New(svc) - mockRouter := mockModelRouter(h) - - tc.mock(svc.EXPECT()) - mockRouter.ServeHTTP(w, tc.req) - tc.expect(t, w) - }) - } -} diff --git a/manager/models/model.go b/manager/models/model.go deleted file mode 100644 index 898a7df84ee..00000000000 --- a/manager/models/model.go +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright 2023 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package models - -const ( - // ModelVersionStateActive represents the model version - // whose state is active and the model version currently being used. - ModelVersionStateActive = "active" - - // ModelVersionStateInactive represents the model version - // whose state is inactive and the model version currently being not used. - ModelVersionStateInactive = "inactive" - - // ModelTypeGNN represents the model type is GNN. - ModelTypeGNN = "gnn" - - // ModelTypeMLP represents the model type is MLP. - ModelTypeMLP = "mlp" -) - -// TODO(Gaius) Add regression analysis parameters. -type Model struct { - BaseModel - Name string `gorm:"column:name;type:varchar(256);not null;comment:name" json:"name"` - Type string `gorm:"column:type;type:varchar(256);index:uk_model,unique;not null;comment:type" json:"type"` - BIO string `gorm:"column:bio;type:varchar(1024);comment:biography" json:"bio"` - Version string `gorm:"column:version;type:varchar(256);index:uk_model,unique;not null;comment:model version" json:"version"` - State string `gorm:"column:state;type:varchar(256);default:'inactive';comment:model state" json:"state"` - Evaluation JSONMap `gorm:"column:evaluation;comment:evaluation metrics" json:"evaluation"` - SchedulerID uint `gorm:"index:uk_model,unique;not null;comment:scheduler id" json:"scheduler_id"` - Scheduler Scheduler `json:"scheduler"` -} diff --git a/manager/models/scheduler.go b/manager/models/scheduler.go index c51b5bed4cf..91784e607ce 100644 --- a/manager/models/scheduler.go +++ b/manager/models/scheduler.go @@ -35,5 +35,4 @@ type Scheduler struct { Features Array `gorm:"column:features;comment:feature flags" json:"features"` SchedulerClusterID uint `gorm:"index:uk_scheduler,unique;not null;comment:scheduler cluster id" json:"scheduler_cluster_id"` SchedulerCluster SchedulerCluster `json:"scheduler_cluster"` - Models []Model `json:"models"` } diff --git a/manager/router/router.go b/manager/router/router.go index 6792f18a93c..1bf14a5f2e3 100644 --- a/manager/router/router.go +++ b/manager/router/router.go @@ -216,13 +216,6 @@ func Init(cfg *config.Config, logDir string, service service.Service, database * cs.GET(":id", h.GetApplication) cs.GET("", h.GetApplications) - // Model. - model := apiv1.Group("/models", jwt.MiddlewareFunc(), rbac) - model.DELETE(":id", h.DestroyModel) - model.PATCH(":id", h.UpdateModel) - model.GET(":id", h.GetModel) - model.GET("", h.GetModels) - // Personal Access Token. pat := apiv1.Group("/personal-access-tokens", jwt.MiddlewareFunc(), rbac) pat.POST("", h.CreatePersonalAccessToken) diff --git a/manager/rpcserver/manager_server_v1.go b/manager/rpcserver/manager_server_v1.go index abab6775176..c562dd10936 100644 --- a/manager/rpcserver/manager_server_v1.go +++ b/manager/rpcserver/manager_server_v1.go @@ -17,24 +17,18 @@ package rpcserver import ( - "bytes" "context" "encoding/json" "errors" - "fmt" "io" - "strings" - "time" cachev9 "github.com/go-redis/cache/v9" "github.com/redis/go-redis/v9" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/emptypb" "gorm.io/gorm" commonv1 "d7y.io/api/v2/pkg/apis/common/v1" - inference "d7y.io/api/v2/pkg/apis/inference" managerv1 "d7y.io/api/v2/pkg/apis/manager/v1" logger "d7y.io/dragonfly/v2/internal/dflog" @@ -45,12 +39,9 @@ import ( "d7y.io/dragonfly/v2/manager/models" "d7y.io/dragonfly/v2/manager/searcher" "d7y.io/dragonfly/v2/manager/types" - "d7y.io/dragonfly/v2/pkg/digest" - "d7y.io/dragonfly/v2/pkg/idgen" "d7y.io/dragonfly/v2/pkg/objectstorage" pkgredis "d7y.io/dragonfly/v2/pkg/redis" "d7y.io/dragonfly/v2/pkg/slices" - "d7y.io/dragonfly/v2/pkg/structure" ) // managerServerV1 is v1 version of the manager grpc server. @@ -798,159 +789,6 @@ func (s *managerServerV1) ListApplications(ctx context.Context, req *managerv1.L return &pbListApplicationsResponse, nil } -// CreateModel creates model and update data of model to object storage. -func (s *managerServerV1) CreateModel(ctx context.Context, req *managerv1.CreateModelRequest) (*emptypb.Empty, error) { - log := logger.WithHostnameAndIP(req.GetHostname(), req.GetIp()) - - if !s.config.ObjectStorage.Enable { - log.Warn("object storage is disabled") - return nil, status.Error(codes.Internal, "object storage is disabled") - } - - // Create model bucket, if not exist. - if err := s.createModelBucket(ctx); err != nil { - log.Error(err) - return nil, status.Error(codes.Internal, err.Error()) - } - - var ( - name string - typ string - evaluation types.ModelEvaluation - version = time.Now().Nanosecond() - ) - switch createModelRequest := req.GetRequest().(type) { - case *managerv1.CreateModelRequest_CreateGnnRequest: - name = idgen.GNNModelIDV1(req.GetIp(), req.GetHostname()) - typ = models.ModelTypeGNN - evaluation = types.ModelEvaluation{ - Precision: createModelRequest.CreateGnnRequest.GetPrecision(), - Recall: createModelRequest.CreateGnnRequest.GetRecall(), - F1Score: createModelRequest.CreateGnnRequest.GetF1Score(), - } - - // Update GNN model config to object storage. - if err := s.createModelConfig(ctx, name); err != nil { - log.Error(err) - return nil, status.Error(codes.Internal, err.Error()) - } - - // Upload GNN model file to object storage. - data := createModelRequest.CreateGnnRequest.GetData() - dgst := digest.New(digest.AlgorithmSHA256, digest.SHA256FromBytes(data)) - if err := s.objectStorage.PutObject(ctx, s.config.Trainer.BucketName, - types.MakeObjectKeyOfModelFile(name, version), dgst.String(), bytes.NewReader(data)); err != nil { - log.Error(err) - return nil, status.Error(codes.Internal, err.Error()) - } - case *managerv1.CreateModelRequest_CreateMlpRequest: - name = idgen.MLPModelIDV1(req.GetHostname(), req.GetIp()) - typ = models.ModelTypeMLP - evaluation = types.ModelEvaluation{ - MSE: createModelRequest.CreateMlpRequest.GetMse(), - MAE: createModelRequest.CreateMlpRequest.GetMae(), - } - - // Update MLP model config to object storage. - if err := s.createModelConfig(ctx, name); err != nil { - log.Error(err) - return nil, status.Error(codes.Internal, err.Error()) - } - - // Upload MLP model file to object storage. - data := createModelRequest.CreateMlpRequest.GetData() - dgst := digest.New(digest.AlgorithmSHA256, digest.SHA256FromBytes(data)) - if err := s.objectStorage.PutObject(ctx, s.config.Trainer.BucketName, - types.MakeObjectKeyOfModelFile(name, version), dgst.String(), bytes.NewReader(data)); err != nil { - log.Error(err) - return nil, status.Error(codes.Internal, err.Error()) - } - default: - msg := fmt.Sprintf("receive unknow request: %#v", createModelRequest) - log.Error(msg) - return nil, status.Error(codes.FailedPrecondition, msg) - } - - scheduler := models.Scheduler{} - if err := s.db.WithContext(ctx).First(&scheduler, &models.Scheduler{ - Hostname: req.Hostname, - IP: req.Ip, - }).Error; err != nil { - log.Error(err) - return nil, status.Error(codes.Internal, err.Error()) - } - - rawEvaluation, err := structure.StructToMap(evaluation) - if err != nil { - return nil, status.Error(codes.Internal, err.Error()) - } - - // Create model in database. - if err := s.db.WithContext(ctx).Model(&scheduler).Association("Models").Append(&models.Model{ - Name: name, - Type: typ, - Version: fmt.Sprint(version), - State: models.ModelVersionStateInactive, - Evaluation: rawEvaluation, - }); err != nil { - log.Error(err) - return nil, status.Error(codes.Internal, err.Error()) - } - - return new(emptypb.Empty), nil -} - -// createModelBucket creates model bucket if not exist. -func (s *managerServerV1) createModelBucket(ctx context.Context) error { - // Check bucket exist. - isExist, err := s.objectStorage.IsBucketExist(ctx, s.config.Trainer.BucketName) - if err != nil { - return err - } - - // Create bucket if not exist. - if !isExist { - if err := s.objectStorage.CreateBucket(ctx, s.config.Trainer.BucketName); err != nil { - return err - } - } - - return nil -} - -// createModelConfig creates model config to object storage. -func (s *managerServerV1) createModelConfig(ctx context.Context, name string) error { - objectKey := types.MakeObjectKeyOfModelConfigFile(name) - isExist, err := s.objectStorage.IsObjectExist(ctx, s.config.Trainer.BucketName, objectKey) - if err != nil { - return err - } - - // If the model config already exists, skip it. - if isExist { - return nil - } - - // If the model config does not exist, create a new model config. - pbModelConfig := inference.ModelConfig{ - Name: name, - Platform: types.DefaultTritonPlatform, - VersionPolicy: &inference.ModelVersionPolicy{ - PolicyChoice: &inference.ModelVersionPolicy_Specific_{ - Specific: &inference.ModelVersionPolicy_Specific{}, - }, - }, - } - - dgst := digest.New(digest.AlgorithmSHA256, digest.SHA256FromStrings(pbModelConfig.String())) - if err := s.objectStorage.PutObject(ctx, s.config.Trainer.BucketName, - types.MakeObjectKeyOfModelConfigFile(name), dgst.String(), strings.NewReader(pbModelConfig.String())); err != nil { - return err - } - - return nil -} - // KeepAlive with manager. func (s *managerServerV1) KeepAlive(stream managerv1.Manager_KeepAliveServer) error { req, err := stream.Recv() diff --git a/manager/service/model.go b/manager/service/model.go deleted file mode 100644 index b3d63171bdb..00000000000 --- a/manager/service/model.go +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Copyright 2023 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package service - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "strconv" - "strings" - - inference "d7y.io/api/v2/pkg/apis/inference" - - "d7y.io/dragonfly/v2/manager/models" - "d7y.io/dragonfly/v2/manager/types" - "d7y.io/dragonfly/v2/pkg/digest" -) - -func (s *service) DestroyModel(ctx context.Context, id uint) error { - model := models.Model{} - if err := s.db.WithContext(ctx).First(&model, id).Error; err != nil { - return err - } - - // If the model is active, return an error. - if model.State == models.ModelVersionStateActive { - return errors.New("cannot delete an active model") - } - - version, err := strconv.Atoi(model.Version) - if err != nil { - return err - } - - if err := s.objectStorage.DeleteObject(ctx, s.config.Trainer.BucketName, types.MakeObjectKeyOfModelFile(model.Name, version)); err != nil { - return err - } - - if err := s.db.WithContext(ctx).Unscoped().Delete(&models.Model{}, id).Error; err != nil { - return err - } - - return nil -} - -func (s *service) UpdateModel(ctx context.Context, id uint, json types.UpdateModelRequest) (*models.Model, error) { - model := models.Model{} - if err := s.db.WithContext(ctx).First(&model, id).Error; err != nil { - return nil, err - } - - // If the model is active, update the model config and - // update the model state. - if json.State == models.ModelVersionStateActive { - if err := s.updateModelStateToActive(ctx, &model); err != nil { - return nil, err - } - } - - // Update the model. - if err := s.db.WithContext(ctx).Model(&model).Updates(models.Model{ - BIO: json.BIO, - }).Error; err != nil { - return nil, err - } - - return &model, nil -} - -func (s *service) GetModel(ctx context.Context, id uint) (*models.Model, error) { - model := models.Model{} - if err := s.db.WithContext(ctx).First(&model, id).Error; err != nil { - return nil, err - } - - return &model, nil -} - -func (s *service) GetModels(ctx context.Context, q types.GetModelsQuery) ([]models.Model, int64, error) { - var count int64 - var model []models.Model - if err := s.db.WithContext(ctx).Scopes(models.Paginate(q.Page, q.PerPage)).Where(&models.Model{ - Type: q.Type, - Version: q.Version, - SchedulerID: q.SchedulerID, - }).Find(&model).Limit(-1).Offset(-1).Count(&count).Error; err != nil { - return nil, 0, err - } - - return model, count, nil -} - -func (s *service) updateModelStateToActive(ctx context.Context, model *models.Model) error { - version, err := strconv.ParseInt(model.Version, 10, 64) - if err != nil { - return err - } - - // Update the model config to object storage. - if err := s.updateModelConfig(ctx, model.Name, version); err != nil { - return err - } - - // Create a transaction to ensure that only one - // version is active at a time. - tx := s.db.WithContext(ctx).Begin() - defer func() { - if r := recover(); r != nil { - tx.Rollback() - } - }() - - if err := tx.Error; err != nil { - return err - } - - if err := tx.Model(&models.Model{}).Where(&models.Model{ - SchedulerID: model.SchedulerID, - State: models.ModelVersionStateActive, - }).Updates(&models.Model{State: models.ModelVersionStateInactive}).Error; err != nil { - tx.Rollback() - return err - } - - if err := tx.Model(model).Updates(&models.Model{State: models.ModelVersionStateActive}).Error; err != nil { - tx.Rollback() - return err - } - - if tx.Commit().Error != nil { - return err - } - - return nil -} - -func (s *service) updateModelConfig(ctx context.Context, name string, version int64) error { - if !s.config.ObjectStorage.Enable { - return errors.New("object storage is disabled") - } - - objectKey := types.MakeObjectKeyOfModelConfigFile(name) - var pbModelConfig inference.ModelConfig - reader, err := s.objectStorage.GetObject(ctx, s.config.Trainer.BucketName, objectKey) - if err != nil { - return err - } - defer reader.Close() - - data, err := io.ReadAll(reader) - if err != nil { - return err - } - - if err := json.Unmarshal(data, &pbModelConfig); err != nil { - return err - } - - switch policyChoice := pbModelConfig.VersionPolicy.PolicyChoice.(type) { - case *inference.ModelVersionPolicy_Specific_: - // If the version already exists, add the version to the existing version list. - policyChoice.Specific.Versions = []int64{version} - default: - return fmt.Errorf("unknown policy choice: %#v", policyChoice) - } - - dgst := digest.New(digest.AlgorithmSHA256, digest.SHA256FromStrings(pbModelConfig.String())) - if err := s.objectStorage.PutObject(ctx, s.config.Trainer.BucketName, - types.MakeObjectKeyOfModelConfigFile(name), dgst.String(), strings.NewReader(pbModelConfig.String())); err != nil { - return err - } - - return nil -} diff --git a/manager/service/service.go b/manager/service/service.go index fcbf6d63687..9035b5f5912 100644 --- a/manager/service/service.go +++ b/manager/service/service.go @@ -130,11 +130,6 @@ type Service interface { GetApplication(context.Context, uint) (*models.Application, error) GetApplications(context.Context, types.GetApplicationsQuery) ([]models.Application, int64, error) - DestroyModel(context.Context, uint) error - UpdateModel(context.Context, uint, types.UpdateModelRequest) (*models.Model, error) - GetModel(context.Context, uint) (*models.Model, error) - GetModels(context.Context, types.GetModelsQuery) ([]models.Model, int64, error) - CreatePersonalAccessToken(context.Context, types.CreatePersonalAccessTokenRequest) (*models.PersonalAccessToken, error) DestroyPersonalAccessToken(context.Context, uint) error UpdatePersonalAccessToken(context.Context, uint, types.UpdatePersonalAccessTokenRequest) (*models.PersonalAccessToken, error) diff --git a/pkg/rpc/inference/client/client_v1.go b/pkg/rpc/inference/client/client_v1.go deleted file mode 100644 index 073e4765205..00000000000 --- a/pkg/rpc/inference/client/client_v1.go +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Copyright 2023 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -//go:generate mockgen -destination mocks/client_v1_mock.go -source client_v1.go -package mocks - -package client - -import ( - "context" - "math" - "time" - - grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" - grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap" - grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" - grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" - "google.golang.org/grpc" - - inference "d7y.io/api/v2/pkg/apis/inference" - - logger "d7y.io/dragonfly/v2/internal/dflog" -) - -const ( - // contextTimeout is timeout of grpc invoke. - contextTimeout = 2 * time.Minute - - // maxRetries is maximum number of retries. - maxRetries = 3 - - // backoffWaitBetween is waiting for a fixed period of - // time between calls in backoff linear. - backoffWaitBetween = 500 * time.Millisecond -) - -// GetV1 returns v1 version of the prediction client. -func GetV1(ctx context.Context, target string, opts ...grpc.DialOption) (V1, error) { - conn, err := grpc.DialContext( - ctx, - target, - append([]grpc.DialOption{ - grpc.WithIdleTimeout(0), - grpc.WithDefaultCallOptions( - grpc.MaxCallRecvMsgSize(math.MaxInt32), - grpc.MaxCallSendMsgSize(math.MaxInt32), - ), - grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient( - grpc_prometheus.UnaryClientInterceptor, - grpc_zap.UnaryClientInterceptor(logger.GrpcLogger.Desugar()), - grpc_retry.UnaryClientInterceptor( - grpc_retry.WithMax(maxRetries), - grpc_retry.WithBackoff(grpc_retry.BackoffLinear(backoffWaitBetween)), - ), - )), - grpc.WithStreamInterceptor(grpc_middleware.ChainStreamClient( - grpc_prometheus.StreamClientInterceptor, - grpc_zap.StreamClientInterceptor(logger.GrpcLogger.Desugar()), - )), - }, opts...)..., - ) - if err != nil { - return nil, err - } - - return &v1{ - GRPCInferenceServiceClient: inference.NewGRPCInferenceServiceClient(conn), - ClientConn: conn, - }, nil -} - -// ClientV1 is the interface for v1 version of the grpc client. -type V1 interface { - // ModelInfer performs inference using a specific model. - ModelInfer(context.Context, *inference.ModelInferRequest, ...grpc.CallOption) (*inference.ModelInferResponse, error) - - // ModelReady checks readiness of a model in the inference server.. - ModelReady(context.Context, *inference.ModelReadyRequest, ...grpc.CallOption) (*inference.ModelReadyResponse, error) - - // ServerReady checks readiness of the inference server. - ServerReady(context.Context, *inference.ServerReadyRequest, ...grpc.CallOption) (*inference.ServerReadyResponse, error) - - // Close tears down the ClientConn and all underlying connections. - Close() error -} - -// clientV1 provides v1 version of the prediction grpc function. -type v1 struct { - inference.GRPCInferenceServiceClient - *grpc.ClientConn -} - -// ModelInfer performs inference using a specific model. -func (v *v1) ModelInfer(ctx context.Context, req *inference.ModelInferRequest, opts ...grpc.CallOption) (*inference.ModelInferResponse, error) { - ctx, cancel := context.WithTimeout(ctx, contextTimeout) - defer cancel() - - return v.GRPCInferenceServiceClient.ModelInfer(ctx, req, opts...) -} - -// ModelReady checks readiness of a model in the inference server. -func (v *v1) ModelReady(ctx context.Context, req *inference.ModelReadyRequest, opts ...grpc.CallOption) (*inference.ModelReadyResponse, error) { - ctx, cancel := context.WithTimeout(ctx, contextTimeout) - defer cancel() - - return v.GRPCInferenceServiceClient.ModelReady(ctx, req, opts...) -} - -// ServerReady checks readiness of the inference server. -func (v *v1) ServerReady(ctx context.Context, req *inference.ServerReadyRequest, opts ...grpc.CallOption) (*inference.ServerReadyResponse, error) { - ctx, cancel := context.WithTimeout(ctx, contextTimeout) - defer cancel() - - return v.GRPCInferenceServiceClient.ServerReady(ctx, req, opts...) -} diff --git a/pkg/rpc/inference/client/mocks/client_v1_mock.go b/pkg/rpc/inference/client/mocks/client_v1_mock.go deleted file mode 100644 index fbc375a5244..00000000000 --- a/pkg/rpc/inference/client/mocks/client_v1_mock.go +++ /dev/null @@ -1,116 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: client_v1.go -// -// Generated by this command: -// -// mockgen -destination mocks/client_v1_mock.go -source client_v1.go -package mocks -// - -// Package mocks is a generated GoMock package. -package mocks - -import ( - context "context" - reflect "reflect" - - inference "d7y.io/api/v2/pkg/apis/inference" - gomock "go.uber.org/mock/gomock" - grpc "google.golang.org/grpc" -) - -// MockV1 is a mock of V1 interface. -type MockV1 struct { - ctrl *gomock.Controller - recorder *MockV1MockRecorder -} - -// MockV1MockRecorder is the mock recorder for MockV1. -type MockV1MockRecorder struct { - mock *MockV1 -} - -// NewMockV1 creates a new mock instance. -func NewMockV1(ctrl *gomock.Controller) *MockV1 { - mock := &MockV1{ctrl: ctrl} - mock.recorder = &MockV1MockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockV1) EXPECT() *MockV1MockRecorder { - return m.recorder -} - -// Close mocks base method. -func (m *MockV1) Close() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") - ret0, _ := ret[0].(error) - return ret0 -} - -// Close indicates an expected call of Close. -func (mr *MockV1MockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockV1)(nil).Close)) -} - -// ModelInfer mocks base method. -func (m *MockV1) ModelInfer(arg0 context.Context, arg1 *inference.ModelInferRequest, arg2 ...grpc.CallOption) (*inference.ModelInferResponse, error) { - m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "ModelInfer", varargs...) - ret0, _ := ret[0].(*inference.ModelInferResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ModelInfer indicates an expected call of ModelInfer. -func (mr *MockV1MockRecorder) ModelInfer(arg0, arg1 any, arg2 ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ModelInfer", reflect.TypeOf((*MockV1)(nil).ModelInfer), varargs...) -} - -// ModelReady mocks base method. -func (m *MockV1) ModelReady(arg0 context.Context, arg1 *inference.ModelReadyRequest, arg2 ...grpc.CallOption) (*inference.ModelReadyResponse, error) { - m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "ModelReady", varargs...) - ret0, _ := ret[0].(*inference.ModelReadyResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ModelReady indicates an expected call of ModelReady. -func (mr *MockV1MockRecorder) ModelReady(arg0, arg1 any, arg2 ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ModelReady", reflect.TypeOf((*MockV1)(nil).ModelReady), varargs...) -} - -// ServerReady mocks base method. -func (m *MockV1) ServerReady(arg0 context.Context, arg1 *inference.ServerReadyRequest, arg2 ...grpc.CallOption) (*inference.ServerReadyResponse, error) { - m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "ServerReady", varargs...) - ret0, _ := ret[0].(*inference.ServerReadyResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ServerReady indicates an expected call of ServerReady. -func (mr *MockV1MockRecorder) ServerReady(arg0, arg1 any, arg2 ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServerReady", reflect.TypeOf((*MockV1)(nil).ServerReady), varargs...) -} diff --git a/pkg/rpc/manager/client/client_v1.go b/pkg/rpc/manager/client/client_v1.go index 30d064b124a..e9ca43f5d63 100644 --- a/pkg/rpc/manager/client/client_v1.go +++ b/pkg/rpc/manager/client/client_v1.go @@ -115,9 +115,6 @@ type V1 interface { // List applications configuration. ListApplications(context.Context, *managerv1.ListApplicationsRequest, ...grpc.CallOption) (*managerv1.ListApplicationsResponse, error) - // Create model and update data of model to object storage. - CreateModel(context.Context, *managerv1.CreateModelRequest, ...grpc.CallOption) error - // KeepAlive with manager. KeepAlive(time.Duration, *managerv1.KeepAliveRequest, <-chan struct{}, ...grpc.CallOption) @@ -196,15 +193,6 @@ func (v *v1) ListApplications(ctx context.Context, req *managerv1.ListApplicatio return v.ManagerClient.ListApplications(ctx, req, opts...) } -// Create model and update data of model to object storage. -func (v *v1) CreateModel(ctx context.Context, req *managerv1.CreateModelRequest, opts ...grpc.CallOption) error { - ctx, cancel := context.WithTimeout(ctx, createModelContextTimeout) - defer cancel() - - _, err := v.ManagerClient.CreateModel(ctx, req, opts...) - return err -} - // List active schedulers configuration. func (v *v1) KeepAlive(interval time.Duration, keepalive *managerv1.KeepAliveRequest, done <-chan struct{}, opts ...grpc.CallOption) { log := logger.WithKeepAlive(keepalive.Hostname, keepalive.Ip, keepalive.SourceType.Enum().String(), keepalive.ClusterId) diff --git a/pkg/rpc/trainer/client/client_v1.go b/pkg/rpc/trainer/client/client_v1.go deleted file mode 100644 index bfb83d8ab9a..00000000000 --- a/pkg/rpc/trainer/client/client_v1.go +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright 2023 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -//go:generate mockgen -destination mocks/client_v1_mock.go -source client_v1.go -package mocks - -package client - -import ( - "context" - "math" - "time" - - grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" - grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap" - grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" - grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" - "google.golang.org/grpc" - - trainerv1 "d7y.io/api/v2/pkg/apis/trainer/v1" - - logger "d7y.io/dragonfly/v2/internal/dflog" -) - -const ( - // maxRetries is maximum number of retries. - maxRetries = 3 - - // backoffWaitBetween is waiting for a fixed period of - // time between calls in backoff linear. - backoffWaitBetween = 500 * time.Millisecond -) - -// GetV1ByAddr returns v1 version of the trainer client by address. -func GetV1ByAddr(ctx context.Context, target string, opts ...grpc.DialOption) (V1, error) { - conn, err := grpc.DialContext( - ctx, - target, - append([]grpc.DialOption{ - grpc.WithIdleTimeout(0), - grpc.WithDefaultCallOptions( - grpc.MaxCallRecvMsgSize(math.MaxInt32), - grpc.MaxCallSendMsgSize(math.MaxInt32), - ), - grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient( - grpc_prometheus.UnaryClientInterceptor, - grpc_zap.UnaryClientInterceptor(logger.GrpcLogger.Desugar()), - grpc_retry.UnaryClientInterceptor( - grpc_retry.WithMax(maxRetries), - grpc_retry.WithBackoff(grpc_retry.BackoffLinear(backoffWaitBetween)), - ), - )), - grpc.WithStreamInterceptor(grpc_middleware.ChainStreamClient( - grpc_prometheus.StreamClientInterceptor, - grpc_zap.StreamClientInterceptor(logger.GrpcLogger.Desugar()), - )), - }, opts...)..., - ) - if err != nil { - return nil, err - } - - return &v1{ - TrainerClient: trainerv1.NewTrainerClient(conn), - ClientConn: conn, - }, nil -} - -// V1 is the interface for v1 version of the grpc client. -type V1 interface { - // Train models of scheduler using dataset. - Train(context.Context, ...grpc.CallOption) (trainerv1.Trainer_TrainClient, error) - - // Close tears down the ClientConn and all underlying connections. - Close() error -} - -// v1 provides v1 version of the trainer grpc function. -type v1 struct { - trainerv1.TrainerClient - *grpc.ClientConn -} - -// Train models of scheduler using dataset. -func (v *v1) Train(ctx context.Context, opts ...grpc.CallOption) (trainerv1.Trainer_TrainClient, error) { - return v.TrainerClient.Train(ctx, opts...) -} diff --git a/pkg/rpc/trainer/client/mocks/client_v1_mock.go b/pkg/rpc/trainer/client/mocks/client_v1_mock.go deleted file mode 100644 index ba67f4b21e5..00000000000 --- a/pkg/rpc/trainer/client/mocks/client_v1_mock.go +++ /dev/null @@ -1,76 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: client_v1.go -// -// Generated by this command: -// -// mockgen -destination mocks/client_v1_mock.go -source client_v1.go -package mocks -// - -// Package mocks is a generated GoMock package. -package mocks - -import ( - context "context" - reflect "reflect" - - trainer "d7y.io/api/v2/pkg/apis/trainer/v1" - gomock "go.uber.org/mock/gomock" - grpc "google.golang.org/grpc" -) - -// MockV1 is a mock of V1 interface. -type MockV1 struct { - ctrl *gomock.Controller - recorder *MockV1MockRecorder -} - -// MockV1MockRecorder is the mock recorder for MockV1. -type MockV1MockRecorder struct { - mock *MockV1 -} - -// NewMockV1 creates a new mock instance. -func NewMockV1(ctrl *gomock.Controller) *MockV1 { - mock := &MockV1{ctrl: ctrl} - mock.recorder = &MockV1MockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockV1) EXPECT() *MockV1MockRecorder { - return m.recorder -} - -// Close mocks base method. -func (m *MockV1) Close() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") - ret0, _ := ret[0].(error) - return ret0 -} - -// Close indicates an expected call of Close. -func (mr *MockV1MockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockV1)(nil).Close)) -} - -// Train mocks base method. -func (m *MockV1) Train(arg0 context.Context, arg1 ...grpc.CallOption) (trainer.Trainer_TrainClient, error) { - m.ctrl.T.Helper() - varargs := []any{arg0} - for _, a := range arg1 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "Train", varargs...) - ret0, _ := ret[0].(trainer.Trainer_TrainClient) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Train indicates an expected call of Train. -func (mr *MockV1MockRecorder) Train(arg0 any, arg1 ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0}, arg1...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Train", reflect.TypeOf((*MockV1)(nil).Train), varargs...) -} diff --git a/pkg/rpc/trainer/server/server.go b/pkg/rpc/trainer/server/server.go deleted file mode 100644 index b8fce463628..00000000000 --- a/pkg/rpc/trainer/server/server.go +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Copyright 2023 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package server - -import ( - "math" - "time" - - grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" - grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap" - grpc_ratelimit "github.com/grpc-ecosystem/go-grpc-middleware/ratelimit" - grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery" - grpc_validator "github.com/grpc-ecosystem/go-grpc-middleware/validator" - grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" - "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" - "google.golang.org/grpc" - "google.golang.org/grpc/health" - healthpb "google.golang.org/grpc/health/grpc_health_v1" - "google.golang.org/grpc/keepalive" - "google.golang.org/grpc/reflection" - - trainerv1 "d7y.io/api/v2/pkg/apis/trainer/v1" - - logger "d7y.io/dragonfly/v2/internal/dflog" - "d7y.io/dragonfly/v2/pkg/rpc" -) - -const ( - // DefaultQPS is default qps of grpc server. - DefaultQPS = 20 * 1000 - - // DefaultBurst is default burst of grpc server. - DefaultBurst = 30 * 1000 - - // DefaultMaxConnectionIdle is default max connection idle of grpc keepalive. - DefaultMaxConnectionIdle = 10 * time.Minute - - // DefaultMaxConnectionAge is default max connection age of grpc keepalive. - DefaultMaxConnectionAge = 12 * time.Hour - - // DefaultMaxConnectionAgeGrace is default max connection age grace of grpc keepalive. - DefaultMaxConnectionAgeGrace = 5 * time.Minute -) - -// New returns grpc server instance and register service on grpc server. -func New(trainerServerV1 trainerv1.TrainerServer, opts ...grpc.ServerOption) *grpc.Server { - limiter := rpc.NewRateLimiterInterceptor(DefaultQPS, DefaultBurst) - - grpcServer := grpc.NewServer(append([]grpc.ServerOption{ - grpc.MaxRecvMsgSize(math.MaxInt32), - grpc.MaxSendMsgSize(math.MaxInt32), - grpc.StatsHandler(otelgrpc.NewServerHandler()), - grpc.KeepaliveParams(keepalive.ServerParameters{ - MaxConnectionIdle: DefaultMaxConnectionIdle, - MaxConnectionAge: DefaultMaxConnectionAge, - MaxConnectionAgeGrace: DefaultMaxConnectionAgeGrace, - }), - grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( - grpc_ratelimit.UnaryServerInterceptor(limiter), - grpc_prometheus.UnaryServerInterceptor, - grpc_zap.UnaryServerInterceptor(logger.GrpcLogger.Desugar()), - grpc_validator.UnaryServerInterceptor(), - grpc_recovery.UnaryServerInterceptor(), - )), - grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( - grpc_ratelimit.StreamServerInterceptor(limiter), - grpc_prometheus.StreamServerInterceptor, - grpc_zap.StreamServerInterceptor(logger.GrpcLogger.Desugar()), - grpc_validator.StreamServerInterceptor(), - grpc_recovery.StreamServerInterceptor(), - )), - }, opts...)...) - - // Register servers on v1 version of the grpc server. - trainerv1.RegisterTrainerServer(grpcServer, trainerServerV1) - - // Register health on grpc server. - healthpb.RegisterHealthServer(grpcServer, health.NewServer()) - - // Register reflection on grpc server. - reflection.Register(grpcServer) - return grpcServer -} diff --git a/pkg/types/constants.go b/pkg/types/constants.go index 1f79bc74c0c..64567ad31de 100644 --- a/pkg/types/constants.go +++ b/pkg/types/constants.go @@ -37,9 +37,6 @@ const ( // DfstoreName is dfstore name of dfdaemon. DfstoreName = "dfstore" - - // TrainerName is name of trainer. - TrainerName = "trainer" ) const ( @@ -54,9 +51,6 @@ const ( // DfdaemonMetricsName is name of dfdaemon metrics. DfdaemonMetricsName = "dfdaemon" - - // TrainerMetricsName is name of trainer metrics. - TrainerMetricsName = "trainer" ) const ( diff --git a/scheduler/announcer/announcer.go b/scheduler/announcer/announcer.go index 73078b68f64..bbec011e5fb 100644 --- a/scheduler/announcer/announcer.go +++ b/scheduler/announcer/announcer.go @@ -20,17 +20,11 @@ package announcer import ( "context" - "io" - "time" - - "golang.org/x/sync/errgroup" managerv2 "d7y.io/api/v2/pkg/apis/manager/v2" - trainerv1 "d7y.io/api/v2/pkg/apis/trainer/v1" logger "d7y.io/dragonfly/v2/internal/dflog" managerclient "d7y.io/dragonfly/v2/pkg/rpc/manager/client" - trainerclient "d7y.io/dragonfly/v2/pkg/rpc/trainer/client" "d7y.io/dragonfly/v2/scheduler/config" "d7y.io/dragonfly/v2/scheduler/storage" ) @@ -53,18 +47,10 @@ type Announcer interface { type announcer struct { config *config.Config managerClient managerclient.V2 - trainerClient trainerclient.V1 storage storage.Storage done chan struct{} } -// WithTrainerClient sets the grpc client of trainer. -func WithTrainerClient(client trainerclient.V1) Option { - return func(a *announcer) { - a.trainerClient = client - } -} - // Option is a functional option for configuring the announcer. type Option func(s *announcer) @@ -101,11 +87,6 @@ func New(cfg *config.Config, managerClient managerclient.V2, storage storage.Sto func (a *announcer) Serve() { logger.Info("announce scheduler to manager") go a.announceToManager() - - if a.trainerClient != nil { - logger.Info("announce scheduler to trainer") - a.announceToTrainer() - } } // Stop announcer server. @@ -122,114 +103,3 @@ func (a *announcer) announceToManager() { ClusterId: uint64(a.config.Manager.SchedulerClusterID), }, a.done) } - -// announceSeedPeer announces dataset to trainer. -func (a *announcer) announceToTrainer() { - tick := time.NewTicker(a.config.Trainer.Interval) - for { - select { - case <-tick.C: - if err := a.train(); err != nil { - logger.Error(err) - } - case <-a.done: - return - } - } -} - -// train uploads dataset to trainer and trigger training. -func (a *announcer) train() error { - ctx, cancel := context.WithTimeout(context.Background(), a.config.Trainer.UploadTimeout) - defer cancel() - - stream, err := a.trainerClient.Train(ctx) - if err != nil { - return err - } - - eg := errgroup.Group{} - eg.Go(func() error { - return a.uploadDownloadToTrainer(stream) - }) - - eg.Go(func() error { - return a.uploadNetworkTopologyToTrainer(stream) - }) - - if err := eg.Wait(); err != nil { - return err - } - - if _, err := stream.CloseAndRecv(); err != nil { - return err - } - - return nil -} - -// uploadDownloadToTrainer uploads download information to trainer. -func (a *announcer) uploadDownloadToTrainer(stream trainerv1.Trainer_TrainClient) error { - readCloser, err := a.storage.OpenDownload() - if err != nil { - return err - } - defer readCloser.Close() - - buf := make([]byte, defaultUploadBufferSize) - for { - n, err := readCloser.Read(buf) - if err != nil { - if err == io.EOF { - return nil - } - - return err - } - - if err := stream.Send(&trainerv1.TrainRequest{ - Hostname: a.config.Server.Host, - Ip: a.config.Server.AdvertiseIP.String(), - Request: &trainerv1.TrainRequest_TrainMlpRequest{ - TrainMlpRequest: &trainerv1.TrainMLPRequest{ - Dataset: buf[:n], - }, - }, - }); err != nil { - return err - } - } -} - -// uploadNetworkTopologyToTrainer uploads network topology to trainer. -func (a *announcer) uploadNetworkTopologyToTrainer(stream trainerv1.Trainer_TrainClient) error { - readCloser, err := a.storage.OpenNetworkTopology() - if err != nil { - return err - } - defer readCloser.Close() - - buf := make([]byte, defaultUploadBufferSize) - for { - n, err := readCloser.Read(buf) - if err != nil { - if err == io.EOF { - return nil - } - - return err - } - - if err := stream.Send(&trainerv1.TrainRequest{ - Hostname: a.config.Server.Host, - Ip: a.config.Server.AdvertiseIP.String(), - Request: &trainerv1.TrainRequest_TrainGnnRequest{ - TrainGnnRequest: &trainerv1.TrainGNNRequest{ - Dataset: buf[:n], - }, - }, - }); err != nil { - return err - } - } -} diff --git a/scheduler/announcer/announcer_test.go b/scheduler/announcer/announcer_test.go index bd7b019328d..347bd107380 100644 --- a/scheduler/announcer/announcer_test.go +++ b/scheduler/announcer/announcer_test.go @@ -17,11 +17,8 @@ package announcer import ( - "bytes" "errors" - "io" "net" - "sync" "testing" "time" @@ -29,11 +26,8 @@ import ( "go.uber.org/mock/gomock" managerv2 "d7y.io/api/v2/pkg/apis/manager/v2" - trainerv1 "d7y.io/api/v2/pkg/apis/trainer/v1" - trainerv1mocks "d7y.io/api/v2/pkg/apis/trainer/v1/mocks" managerclientmocks "d7y.io/dragonfly/v2/pkg/rpc/manager/client/mocks" - trainerclientmocks "d7y.io/dragonfly/v2/pkg/rpc/trainer/client/mocks" "d7y.io/dragonfly/v2/scheduler/config" storagemocks "d7y.io/dragonfly/v2/scheduler/storage/mocks" ) @@ -56,14 +50,12 @@ func (m *mockReadCloserWithReadError) Close() error { func TestAnnouncer_New(t *testing.T) { ctl := gomock.NewController(t) defer ctl.Finish() - mockTrainerClient := trainerclientmocks.NewMockV1(ctl) tests := []struct { - name string - config *config.Config - options []Option - mock func(m *managerclientmocks.MockV2MockRecorder) - expect func(t *testing.T, announcer Announcer, err error) + name string + config *config.Config + mock func(m *managerclientmocks.MockV2MockRecorder) + expect func(t *testing.T, announcer Announcer, err error) }{ { name: "new announcer", @@ -82,36 +74,6 @@ func TestAnnouncer_New(t *testing.T) { SchedulerClusterID: 1, }, }, - options: []Option{}, - mock: func(m *managerclientmocks.MockV2MockRecorder) { - m.UpdateScheduler(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) - }, - expect: func(t *testing.T, a Announcer, err error) { - assert := assert.New(t) - instance := a.(*announcer) - assert.NoError(err) - assert.NotNil(instance.config) - assert.NotNil(instance.managerClient) - }, - }, - { - name: "new announcer with trainer client", - config: &config.Config{ - Server: config.ServerConfig{ - Host: "localhost", - AdvertiseIP: net.ParseIP("127.0.0.1"), - AdvertisePort: 8004, - Port: 8080, - }, - Host: config.HostConfig{ - IDC: mockIDC, - Location: mockLocation, - }, - Manager: config.ManagerConfig{ - SchedulerClusterID: 1, - }, - }, - options: []Option{WithTrainerClient(mockTrainerClient)}, mock: func(m *managerclientmocks.MockV2MockRecorder) { m.UpdateScheduler(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) }, @@ -121,7 +83,6 @@ func TestAnnouncer_New(t *testing.T) { assert.NoError(err) assert.NotNil(instance.config) assert.NotNil(instance.managerClient) - assert.NotNil(instance.trainerClient) }, }, { @@ -141,7 +102,6 @@ func TestAnnouncer_New(t *testing.T) { SchedulerClusterID: 1, }, }, - options: []Option{}, mock: func(m *managerclientmocks.MockV2MockRecorder) { m.UpdateScheduler(gomock.Any(), gomock.Any()).Return(nil, errors.New("foo")).Times(1) }, @@ -158,7 +118,7 @@ func TestAnnouncer_New(t *testing.T) { mockStorage := storagemocks.NewMockStorage(ctl) tc.mock(mockManagerClient.EXPECT()) - a, err := New(tc.config, mockManagerClient, mockStorage, tc.options...) + a, err := New(tc.config, mockManagerClient, mockStorage) tc.expect(t, a, err) }) } @@ -167,16 +127,14 @@ func TestAnnouncer_New(t *testing.T) { func TestAnnouncer_Serve(t *testing.T) { ctl := gomock.NewController(t) defer ctl.Finish() - mockTrainerClient := trainerclientmocks.NewMockV1(ctl) tests := []struct { - name string - config *config.Config - data []byte - options []Option - sleep func() - mock func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) - except func(t *testing.T, a Announcer) + name string + config *config.Config + data []byte + sleep func() + mock func(data []byte, m *managerclientmocks.MockV2MockRecorder, ms *storagemocks.MockStorageMockRecorder) + except func(t *testing.T, a Announcer) }{ { name: "started announcer server success", @@ -197,88 +155,12 @@ func TestAnnouncer_Serve(t *testing.T) { }, SchedulerClusterID: 1, }, - Trainer: config.TrainerConfig{ - Interval: 2 * time.Second, - UploadTimeout: 10 * time.Second, - }, }, - data: []byte("bar"), - options: []Option{WithTrainerClient(mockTrainerClient)}, + data: []byte("bar"), sleep: func() { time.Sleep(3 * time.Second) }, - mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) { - var wg sync.WaitGroup - wg.Add(2) - - gomock.InOrder( - m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{ - SourceType: managerv2.SourceType_SCHEDULER_SOURCE, - Hostname: "localhost", - Ip: "127.0.0.1", - Port: int32(8004), - Idc: &mockIDC, - Location: &mockLocation, - SchedulerClusterId: uint64(1), - })).Times(1), - m.KeepAlive(gomock.Eq(50*time.Millisecond), gomock.Eq(&managerv2.KeepAliveRequest{ - SourceType: managerv2.SourceType_SCHEDULER_SOURCE, - Hostname: "localhost", - Ip: "127.0.0.1", - ClusterId: uint64(1), - }), gomock.Any()).Times(1), - mtc.Train(gomock.Any()).Return(stream, nil).Times(1), - mt.CloseAndRecv().Do(func() { wg.Wait() }).Return(nil, nil).Times(1), - ) - - gomock.InOrder( - ms.OpenNetworkTopology().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1), - mt.Send(gomock.Any()).DoAndReturn( - func(t *trainerv1.TrainRequest) error { - wg.Done() - return nil - }).Times(1), - ) - - gomock.InOrder( - ms.OpenDownload().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1), - mt.Send(gomock.Any()).DoAndReturn( - func(t *trainerv1.TrainRequest) error { - wg.Done() - return nil - }).Times(1), - ) - }, - except: func(t *testing.T, a Announcer) { - go a.Serve() - }, - }, - { - name: "started announcer server success without trainer client", - config: &config.Config{ - Server: config.ServerConfig{ - Host: "localhost", - AdvertiseIP: net.ParseIP("127.0.0.1"), - AdvertisePort: 8004, - Port: 8080, - }, - Host: config.HostConfig{ - IDC: mockIDC, - Location: mockLocation, - }, - Manager: config.ManagerConfig{ - KeepAlive: config.KeepAliveConfig{ - Interval: 50 * time.Millisecond, - }, - SchedulerClusterID: 1, - }, - }, - data: []byte("bar"), - options: []Option{}, - sleep: func() { - time.Sleep(100 * time.Millisecond) - }, - mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) { + mock: func(data []byte, m *managerclientmocks.MockV2MockRecorder, ms *storagemocks.MockStorageMockRecorder) { gomock.InOrder( m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{ SourceType: managerv2.SourceType_SCHEDULER_SOURCE, @@ -305,15 +187,15 @@ func TestAnnouncer_Serve(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - stream := trainerv1mocks.NewMockTrainer_TrainClient(ctl) mockManagerClient := managerclientmocks.NewMockV2(ctl) mockStorage := storagemocks.NewMockStorage(ctl) - tc.mock(stream, tc.data, mockManagerClient.EXPECT(), mockTrainerClient.EXPECT(), mockStorage.EXPECT(), stream.EXPECT()) - a, err := New(tc.config, mockManagerClient, mockStorage, tc.options...) + tc.mock(tc.data, mockManagerClient.EXPECT(), mockStorage.EXPECT()) + a, err := New(tc.config, mockManagerClient, mockStorage) if err != nil { t.Fatal(err) } + tc.except(t, a) tc.sleep() a.Stop() @@ -382,826 +264,16 @@ func TestAnnouncer_announceToManager(t *testing.T) { ctl := gomock.NewController(t) defer ctl.Finish() mockManagerClient := managerclientmocks.NewMockV2(ctl) - mockTrainerClient := trainerclientmocks.NewMockV1(ctl) mockStorage := storagemocks.NewMockStorage(ctl) tc.mock(mockManagerClient.EXPECT()) - a, err := New(tc.config, mockManagerClient, mockStorage, WithTrainerClient(mockTrainerClient)) + a, err := New(tc.config, mockManagerClient, mockStorage) if err != nil { t.Fatal(err) } - tc.except(a) - tc.sleep() - }) - } -} - -func TestAnnouncer_announceToTrainer(t *testing.T) { - tests := []struct { - name string - config *config.Config - data []byte - sleep func() - mock func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) - except func(a Announcer) - }{ - { - name: "announce to trainer failed", - config: &config.Config{ - Server: config.ServerConfig{ - Host: "localhost", - AdvertiseIP: net.ParseIP("127.0.0.1"), - AdvertisePort: 8004, - Port: 8080, - }, - Host: config.HostConfig{ - IDC: mockIDC, - Location: mockLocation, - }, - Manager: config.ManagerConfig{ - SchedulerClusterID: 1, - }, - Trainer: config.TrainerConfig{ - Interval: 2 * time.Second, - UploadTimeout: 10 * time.Second, - }, - }, - data: []byte("bar"), - sleep: func() { - time.Sleep(3 * time.Second) - }, - mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) { - gomock.InOrder( - m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{ - SourceType: managerv2.SourceType_SCHEDULER_SOURCE, - Hostname: "localhost", - Ip: "127.0.0.1", - Port: int32(8004), - Idc: &mockIDC, - Location: &mockLocation, - SchedulerClusterId: uint64(1), - })).Times(1), - mtc.Train(gomock.Any()).Return(nil, errors.New("foo")).Times(1), - ) - }, - except: func(a Announcer) { - go a.(*announcer).announceToTrainer() - }, - }, - { - name: "announce to trainer success", - config: &config.Config{ - Server: config.ServerConfig{ - Host: "localhost", - AdvertiseIP: net.ParseIP("127.0.0.1"), - AdvertisePort: 8004, - Port: 8080, - }, - Host: config.HostConfig{ - IDC: mockIDC, - Location: mockLocation, - }, - Manager: config.ManagerConfig{ - SchedulerClusterID: 1, - }, - Trainer: config.TrainerConfig{ - Interval: 2 * time.Second, - UploadTimeout: 10 * time.Second, - }, - }, - data: []byte("bar"), - sleep: func() { - time.Sleep(3 * time.Second) - }, - mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) { - var wg sync.WaitGroup - wg.Add(2) - - gomock.InOrder( - m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{ - SourceType: managerv2.SourceType_SCHEDULER_SOURCE, - Hostname: "localhost", - Ip: "127.0.0.1", - Port: int32(8004), - Idc: &mockIDC, - Location: &mockLocation, - SchedulerClusterId: uint64(1), - })).Times(1), - mtc.Train(gomock.Any()).Return(stream, nil).Times(1), - mt.CloseAndRecv().Do(func() { wg.Wait() }).Return(nil, nil).Times(1), - ) - - gomock.InOrder( - ms.OpenNetworkTopology().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1), - mt.Send(gomock.Any()).DoAndReturn( - func(t *trainerv1.TrainRequest) error { - wg.Done() - return nil - }).Times(1), - ) - - gomock.InOrder( - ms.OpenDownload().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1), - mt.Send(gomock.Any()).DoAndReturn( - func(t *trainerv1.TrainRequest) error { - wg.Done() - return nil - }).Times(1), - ) - }, - except: func(a Announcer) { - go a.(*announcer).announceToTrainer() - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ctl := gomock.NewController(t) - defer ctl.Finish() - stream := trainerv1mocks.NewMockTrainer_TrainClient(ctl) - mockManagerClient := managerclientmocks.NewMockV2(ctl) - mockTrainerClient := trainerclientmocks.NewMockV1(ctl) - mockStorage := storagemocks.NewMockStorage(ctl) - tc.mock(stream, tc.data, mockManagerClient.EXPECT(), mockTrainerClient.EXPECT(), mockStorage.EXPECT(), stream.EXPECT()) - a, err := New(tc.config, mockManagerClient, mockStorage, WithTrainerClient(mockTrainerClient)) - if err != nil { - t.Fatal(err) - } tc.except(a) tc.sleep() - a.Stop() - }) - } -} - -func TestAnnouncer_train(t *testing.T) { - tests := []struct { - name string - config *config.Config - data []byte - mock func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) - except func(t *testing.T, announcer Announcer, err error) - }{ - { - name: "get stream failed", - config: &config.Config{ - Server: config.ServerConfig{ - Host: "localhost", - AdvertiseIP: net.ParseIP("127.0.0.1"), - AdvertisePort: 8004, - Port: 8080, - }, - Host: config.HostConfig{ - IDC: mockIDC, - Location: mockLocation, - }, - Manager: config.ManagerConfig{ - SchedulerClusterID: 1, - }, - }, - data: []byte("bar"), - mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) { - gomock.InOrder( - m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{ - SourceType: managerv2.SourceType_SCHEDULER_SOURCE, - Hostname: "localhost", - Ip: "127.0.0.1", - Port: int32(8004), - Idc: &mockIDC, - Location: &mockLocation, - SchedulerClusterId: uint64(1), - })).Times(1), - mtc.Train(gomock.Any()).Return(nil, errors.New("foo")).Times(1), - ) - }, - except: func(t *testing.T, a Announcer, err error) { - assert := assert.New(t) - assert.EqualError(err, "foo") - }, - }, - { - name: "upload download failed", - config: &config.Config{ - Server: config.ServerConfig{ - Host: "localhost", - AdvertiseIP: net.ParseIP("127.0.0.1"), - AdvertisePort: 8004, - Port: 8080, - }, - Host: config.HostConfig{ - IDC: mockIDC, - Location: mockLocation, - }, - Manager: config.ManagerConfig{ - SchedulerClusterID: 1, - }, - }, - data: []byte("bar"), - mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) { - var wg sync.WaitGroup - wg.Add(2) - - gomock.InOrder( - m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{ - SourceType: managerv2.SourceType_SCHEDULER_SOURCE, - Hostname: "localhost", - Ip: "127.0.0.1", - Port: int32(8004), - Idc: &mockIDC, - Location: &mockLocation, - SchedulerClusterId: uint64(1), - })).Times(1), - mtc.Train(gomock.Any()).Return(stream, nil).Times(1), - ) - - gomock.InOrder( - ms.OpenNetworkTopology().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1), - mt.Send(gomock.Any()).DoAndReturn( - func(t *trainerv1.TrainRequest) error { - wg.Done() - return nil - }).Times(1), - ) - - gomock.InOrder( - ms.OpenDownload().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1), - mt.Send(gomock.Any()).DoAndReturn( - func(t *trainerv1.TrainRequest) error { - wg.Done() - return errors.New("foo") - }).Times(1), - ) - }, - except: func(t *testing.T, a Announcer, err error) { - assert := assert.New(t) - assert.EqualError(err, "foo") - }, - }, - { - name: "upload network topology failed", - config: &config.Config{ - Server: config.ServerConfig{ - Host: "localhost", - AdvertiseIP: net.ParseIP("127.0.0.1"), - AdvertisePort: 8004, - Port: 8080, - }, - Host: config.HostConfig{ - IDC: mockIDC, - Location: mockLocation, - }, - Manager: config.ManagerConfig{ - SchedulerClusterID: 1, - }, - }, - data: []byte("bar"), - mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) { - var wg sync.WaitGroup - wg.Add(2) - - gomock.InOrder( - m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{ - SourceType: managerv2.SourceType_SCHEDULER_SOURCE, - Hostname: "localhost", - Ip: "127.0.0.1", - Port: int32(8004), - Idc: &mockIDC, - Location: &mockLocation, - SchedulerClusterId: uint64(1), - })).Times(1), - mtc.Train(gomock.Any()).Return(stream, nil).Times(1), - ) - - gomock.InOrder( - ms.OpenNetworkTopology().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1), - mt.Send(gomock.Any()).DoAndReturn( - func(t *trainerv1.TrainRequest) error { - wg.Done() - return errors.New("foo") - }).Times(1), - ) - - gomock.InOrder( - ms.OpenDownload().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1), - mt.Send(gomock.Any()).DoAndReturn( - func(t *trainerv1.TrainRequest) error { - wg.Done() - return nil - }).Times(1), - ) - }, - except: func(t *testing.T, a Announcer, err error) { - assert := assert.New(t) - assert.EqualError(err, "foo") - }, - }, - { - name: "close stream failed", - config: &config.Config{ - Server: config.ServerConfig{ - Host: "localhost", - AdvertiseIP: net.ParseIP("127.0.0.1"), - AdvertisePort: 8004, - Port: 8080, - }, - Host: config.HostConfig{ - IDC: mockIDC, - Location: mockLocation, - }, - Manager: config.ManagerConfig{ - SchedulerClusterID: 1, - }, - }, - data: []byte("bar"), - mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) { - var wg sync.WaitGroup - wg.Add(2) - - gomock.InOrder( - m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{ - SourceType: managerv2.SourceType_SCHEDULER_SOURCE, - Hostname: "localhost", - Ip: "127.0.0.1", - Port: int32(8004), - Idc: &mockIDC, - Location: &mockLocation, - SchedulerClusterId: uint64(1), - })).Times(1), - mtc.Train(gomock.Any()).Return(stream, nil).Times(1), - mt.CloseAndRecv().Return(nil, errors.New("foo")).Do(func() { wg.Wait() }).Times(1), - ) - - gomock.InOrder( - ms.OpenNetworkTopology().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1), - mt.Send(gomock.Any()).DoAndReturn( - func(t *trainerv1.TrainRequest) error { - wg.Done() - return nil - }).Times(1), - ) - - gomock.InOrder( - ms.OpenDownload().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1), - mt.Send(gomock.Any()).DoAndReturn( - func(t *trainerv1.TrainRequest) error { - wg.Done() - return nil - }).Times(1), - ) - }, - except: func(t *testing.T, a Announcer, err error) { - assert := assert.New(t) - assert.EqualError(err, "foo") - }, - }, - { - name: "train success", - config: &config.Config{ - Server: config.ServerConfig{ - Host: "localhost", - AdvertiseIP: net.ParseIP("127.0.0.1"), - AdvertisePort: 8004, - Port: 8080, - }, - Host: config.HostConfig{ - IDC: mockIDC, - Location: mockLocation, - }, - Manager: config.ManagerConfig{ - SchedulerClusterID: 1, - }, - }, - data: []byte("bar"), - mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) { - var wg sync.WaitGroup - wg.Add(2) - - gomock.InOrder( - m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{ - SourceType: managerv2.SourceType_SCHEDULER_SOURCE, - Hostname: "localhost", - Ip: "127.0.0.1", - Port: int32(8004), - Idc: &mockIDC, - Location: &mockLocation, - SchedulerClusterId: uint64(1), - })).Times(1), - mtc.Train(gomock.Any()).Return(stream, nil).Times(1), - mt.CloseAndRecv().Return(nil, nil).Do(func() { wg.Wait() }).Times(1), - ) - - gomock.InOrder( - ms.OpenNetworkTopology().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1), - mt.Send(gomock.Any()).DoAndReturn( - func(t *trainerv1.TrainRequest) error { - wg.Done() - return nil - }).Times(1), - ) - - gomock.InOrder( - ms.OpenDownload().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1), - mt.Send(gomock.Any()).DoAndReturn( - func(t *trainerv1.TrainRequest) error { - wg.Done() - return nil - }).Times(1), - ) - }, - except: func(t *testing.T, a Announcer, err error) { - assert := assert.New(t) - assert.NoError(err) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ctl := gomock.NewController(t) - defer ctl.Finish() - stream := trainerv1mocks.NewMockTrainer_TrainClient(ctl) - mockManagerClient := managerclientmocks.NewMockV2(ctl) - mockTrainerClient := trainerclientmocks.NewMockV1(ctl) - mockStorage := storagemocks.NewMockStorage(ctl) - tc.mock(stream, tc.data, mockManagerClient.EXPECT(), mockTrainerClient.EXPECT(), mockStorage.EXPECT(), stream.EXPECT()) - - a, err := New(tc.config, mockManagerClient, mockStorage, WithTrainerClient(mockTrainerClient)) - if err != nil { - t.Fatal(err) - } - tc.except(t, a, a.(*announcer).train()) - }) - } -} - -func TestAnnouncer_uploadDownloadToTrainer(t *testing.T) { - tests := []struct { - name string - config *config.Config - data []byte - mock func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) - except func(t *testing.T, announcer Announcer, err error) - }{ - { - name: "open download failed", - data: []byte{}, - config: &config.Config{ - Server: config.ServerConfig{ - Host: "localhost", - AdvertiseIP: net.ParseIP("127.0.0.1"), - AdvertisePort: 8004, - Port: 8080, - }, - Host: config.HostConfig{ - IDC: mockIDC, - Location: mockLocation, - }, - Manager: config.ManagerConfig{ - SchedulerClusterID: 1, - }, - }, - mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) { - gomock.InOrder( - m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{ - SourceType: managerv2.SourceType_SCHEDULER_SOURCE, - Hostname: "localhost", - Ip: "127.0.0.1", - Port: int32(8004), - Idc: &mockIDC, - Location: &mockLocation, - SchedulerClusterId: uint64(1), - })).Times(1), - ms.OpenDownload().Return(nil, errors.New("foo")).Times(1), - ) - }, - except: func(t *testing.T, a Announcer, err error) { - assert := assert.New(t) - assert.EqualError(err, "foo") - }, - }, - { - name: "read buffer failed", - data: []byte{}, - config: &config.Config{ - Server: config.ServerConfig{ - Host: "localhost", - AdvertiseIP: net.ParseIP("127.0.0.1"), - AdvertisePort: 8004, - Port: 8080, - }, - Host: config.HostConfig{ - IDC: mockIDC, - Location: mockLocation, - }, - Manager: config.ManagerConfig{ - SchedulerClusterID: 1, - }, - }, - mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) { - gomock.InOrder( - m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{ - SourceType: managerv2.SourceType_SCHEDULER_SOURCE, - Hostname: "localhost", - Ip: "127.0.0.1", - Port: int32(8004), - Idc: &mockIDC, - Location: &mockLocation, - SchedulerClusterId: uint64(1), - })).Times(1), - ms.OpenDownload().Return(&mockReadCloserWithReadError{}, nil).Times(1), - ) - }, - except: func(t *testing.T, a Announcer, err error) { - assert := assert.New(t) - assert.EqualError(err, "foo") - }, - }, - { - name: "send download request failed", - data: []byte("bar"), - config: &config.Config{ - Server: config.ServerConfig{ - Host: "localhost", - AdvertiseIP: net.ParseIP("127.0.0.1"), - AdvertisePort: 8004, - Port: 8080, - }, - Host: config.HostConfig{ - IDC: mockIDC, - Location: mockLocation, - }, - Manager: config.ManagerConfig{ - SchedulerClusterID: 1, - }, - }, - mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) { - gomock.InOrder( - m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{ - SourceType: managerv2.SourceType_SCHEDULER_SOURCE, - Hostname: "localhost", - Ip: "127.0.0.1", - Port: int32(8004), - Idc: &mockIDC, - Location: &mockLocation, - SchedulerClusterId: uint64(1), - })).Times(1), - ms.OpenDownload().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1), - mt.Send(gomock.Any()).DoAndReturn( - func(t *trainerv1.TrainRequest) error { - return nil - }).Return(errors.New("foo")), - ) - }, - except: func(t *testing.T, a Announcer, err error) { - assert := assert.New(t) - assert.EqualError(err, "foo") - }, - }, - { - name: "send download request success", - config: &config.Config{ - Server: config.ServerConfig{ - Host: "localhost", - AdvertiseIP: net.ParseIP("127.0.0.1"), - AdvertisePort: 8004, - Port: 8080, - }, - Host: config.HostConfig{ - IDC: mockIDC, - Location: mockLocation, - }, - Manager: config.ManagerConfig{ - SchedulerClusterID: 1, - }, - }, - data: []byte("bar"), - mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) { - gomock.InOrder( - m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{ - SourceType: managerv2.SourceType_SCHEDULER_SOURCE, - Hostname: "localhost", - Ip: "127.0.0.1", - Port: int32(8004), - Idc: &mockIDC, - Location: &mockLocation, - SchedulerClusterId: uint64(1), - })).Times(1), - ms.OpenDownload().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1), - mt.Send(gomock.Any()).DoAndReturn( - func(t *trainerv1.TrainRequest) error { - return nil - }).Times(1), - ) - }, - except: func(t *testing.T, a Announcer, err error) { - assert := assert.New(t) - assert.NoError(err) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ctl := gomock.NewController(t) - defer ctl.Finish() - stream := trainerv1mocks.NewMockTrainer_TrainClient(ctl) - mockManagerClient := managerclientmocks.NewMockV2(ctl) - mockTrainerClient := trainerclientmocks.NewMockV1(ctl) - mockStorage := storagemocks.NewMockStorage(ctl) - tc.mock(stream, tc.data, mockManagerClient.EXPECT(), mockTrainerClient.EXPECT(), mockStorage.EXPECT(), stream.EXPECT()) - - a, err := New(tc.config, mockManagerClient, mockStorage, WithTrainerClient(mockTrainerClient)) - if err != nil { - t.Fatal(err) - } - tc.except(t, a, a.(*announcer).uploadDownloadToTrainer(stream)) - }) - } -} - -func TestAnnouncer_uploadNetworkTopologyToTrainer(t *testing.T) { - tests := []struct { - name string - config *config.Config - data []byte - mock func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) - except func(t *testing.T, announcer Announcer, err error) - }{ - { - name: "open networkTopology failed", - data: []byte{}, - config: &config.Config{ - Server: config.ServerConfig{ - Host: "localhost", - AdvertiseIP: net.ParseIP("127.0.0.1"), - AdvertisePort: 8004, - Port: 8080, - }, - Host: config.HostConfig{ - IDC: mockIDC, - Location: mockLocation, - }, - Manager: config.ManagerConfig{ - SchedulerClusterID: 1, - }, - }, - mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) { - gomock.InOrder( - m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{ - SourceType: managerv2.SourceType_SCHEDULER_SOURCE, - Hostname: "localhost", - Ip: "127.0.0.1", - Port: int32(8004), - Idc: &mockIDC, - Location: &mockLocation, - SchedulerClusterId: uint64(1), - })).Times(1), - ms.OpenNetworkTopology().Return(nil, errors.New("foo")).Times(1), - ) - }, - except: func(t *testing.T, a Announcer, err error) { - assert := assert.New(t) - assert.EqualError(err, "foo") - }, - }, - { - name: "read buffer failed", - data: []byte{}, - config: &config.Config{ - Server: config.ServerConfig{ - Host: "localhost", - AdvertiseIP: net.ParseIP("127.0.0.1"), - AdvertisePort: 8004, - Port: 8080, - }, - Host: config.HostConfig{ - IDC: mockIDC, - Location: mockLocation, - }, - Manager: config.ManagerConfig{ - SchedulerClusterID: 1, - }, - }, - mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) { - gomock.InOrder( - m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{ - SourceType: managerv2.SourceType_SCHEDULER_SOURCE, - Hostname: "localhost", - Ip: "127.0.0.1", - Port: int32(8004), - Idc: &mockIDC, - Location: &mockLocation, - SchedulerClusterId: uint64(1), - })).Times(1), - ms.OpenNetworkTopology().Return(&mockReadCloserWithReadError{}, nil).Times(1), - ) - }, - except: func(t *testing.T, a Announcer, err error) { - assert := assert.New(t) - assert.EqualError(err, "foo") - }, - }, - { - name: "send network topology failed", - data: []byte("bar"), - config: &config.Config{ - Server: config.ServerConfig{ - Host: "localhost", - AdvertiseIP: net.ParseIP("127.0.0.1"), - AdvertisePort: 8004, - Port: 8080, - }, - Host: config.HostConfig{ - IDC: mockIDC, - Location: mockLocation, - }, - Manager: config.ManagerConfig{ - SchedulerClusterID: 1, - }, - }, - mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) { - gomock.InOrder( - m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{ - SourceType: managerv2.SourceType_SCHEDULER_SOURCE, - Hostname: "localhost", - Ip: "127.0.0.1", - Port: int32(8004), - Idc: &mockIDC, - Location: &mockLocation, - SchedulerClusterId: uint64(1), - })).Times(1), - ms.OpenNetworkTopology().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1), - mt.Send(gomock.Any()).DoAndReturn( - func(t *trainerv1.TrainRequest) error { - return nil - }).Return(errors.New("foo")), - ) - }, - except: func(t *testing.T, a Announcer, err error) { - assert := assert.New(t) - assert.EqualError(err, "foo") - }, - }, - { - name: "send network topology success", - config: &config.Config{ - Server: config.ServerConfig{ - Host: "localhost", - AdvertiseIP: net.ParseIP("127.0.0.1"), - AdvertisePort: 8004, - Port: 8080, - }, - Host: config.HostConfig{ - IDC: mockIDC, - Location: mockLocation, - }, - Manager: config.ManagerConfig{ - SchedulerClusterID: 1, - }, - }, - data: []byte("bar"), - mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) { - gomock.InOrder( - m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{ - SourceType: managerv2.SourceType_SCHEDULER_SOURCE, - Hostname: "localhost", - Ip: "127.0.0.1", - Port: int32(8004), - Idc: &mockIDC, - Location: &mockLocation, - SchedulerClusterId: uint64(1), - })).Times(1), - ms.OpenNetworkTopology().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1), - mt.Send(gomock.Any()).DoAndReturn( - func(t *trainerv1.TrainRequest) error { - return nil - }).Times(1), - ) - }, - except: func(t *testing.T, a Announcer, err error) { - assert := assert.New(t) - assert.NoError(err) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ctl := gomock.NewController(t) - defer ctl.Finish() - stream := trainerv1mocks.NewMockTrainer_TrainClient(ctl) - mockManagerClient := managerclientmocks.NewMockV2(ctl) - mockTrainerClient := trainerclientmocks.NewMockV1(ctl) - mockStorage := storagemocks.NewMockStorage(ctl) - tc.mock(stream, tc.data, mockManagerClient.EXPECT(), mockTrainerClient.EXPECT(), mockStorage.EXPECT(), stream.EXPECT()) - - a, err := New(tc.config, mockManagerClient, mockStorage, WithTrainerClient(mockTrainerClient)) - if err != nil { - t.Fatal(err) - } - tc.except(t, a, a.(*announcer).uploadNetworkTopologyToTrainer(stream)) }) } } diff --git a/scheduler/config/config.go b/scheduler/config/config.go index 9136859a81b..5acfe2a304a 100644 --- a/scheduler/config/config.go +++ b/scheduler/config/config.go @@ -72,9 +72,6 @@ type Config struct { // Network configuration. Network NetworkConfig `yaml:"network" mapstructure:"network"` - - // Trainer configuration. - Trainer TrainerConfig `yaml:"trainer" mapstructure:"trainer"` } type ServerConfig struct { @@ -366,20 +363,6 @@ type CacheConfig struct { TTL time.Duration `yaml:"ttl" mapstructure:"ttl"` } -type TrainerConfig struct { - // Enable trainer service. - Enable bool `yaml:"enable" mapstructure:"enable"` - - // Addr is trainer service address. - Addr string `yaml:"addr" mapstructure:"addr"` - - // Interval is the interval of training. - Interval time.Duration `yaml:"interval" mapstructure:"interval"` - - // UploadTimeout is the timeout of uploading dataset to trainer. - UploadTimeout time.Duration `yaml:"uploadTimeout" mapstructure:"uploadTimeout"` -} - // New default configuration. func New() *Config { return &Config{ @@ -478,12 +461,6 @@ func New() *Config { Network: NetworkConfig{ EnableIPv6: DefaultNetworkEnableIPv6, }, - Trainer: TrainerConfig{ - Enable: false, - Addr: DefaultTrainerAddr, - Interval: DefaultTrainerInterval, - UploadTimeout: DefaultTrainerUploadTimeout, - }, } } @@ -669,20 +646,6 @@ func (cfg *Config) Validate() error { } } - if cfg.Trainer.Enable { - if cfg.Trainer.Addr == "" { - return errors.New("trainer requires parameter addr") - } - - if cfg.Trainer.Interval <= 0 { - return errors.New("trainer requires parameter interval") - } - - if cfg.Trainer.UploadTimeout <= 0 { - return errors.New("trainer requires parameter uploadTimeout") - } - } - return nil } diff --git a/scheduler/config/config_test.go b/scheduler/config/config_test.go index 09103ba4aec..9baee730fc7 100644 --- a/scheduler/config/config_test.go +++ b/scheduler/config/config_test.go @@ -186,12 +186,6 @@ func TestConfig_Load(t *testing.T) { Network: NetworkConfig{ EnableIPv6: true, }, - Trainer: TrainerConfig{ - Enable: false, - Addr: "127.0.0.1:9090", - Interval: 10 * time.Minute, - UploadTimeout: 2 * time.Hour, - }, } schedulerConfigYAML := &Config{} @@ -829,51 +823,6 @@ func TestConfig_Validate(t *testing.T) { assert.EqualError(err, "certSpec requires parameter validityPeriod") }, }, - { - name: "trainer requires parameter addr", - config: New(), - mock: func(cfg *Config) { - cfg.Manager = mockManagerConfig - cfg.Database.Redis = mockRedisConfig - cfg.Job = mockJobConfig - cfg.Trainer.Enable = true - cfg.Trainer.Addr = "" - }, - expect: func(t *testing.T, err error) { - assert := assert.New(t) - assert.EqualError(err, "trainer requires parameter addr") - }, - }, - { - name: "trainer requires parameter interval", - config: New(), - mock: func(cfg *Config) { - cfg.Manager = mockManagerConfig - cfg.Database.Redis = mockRedisConfig - cfg.Job = mockJobConfig - cfg.Trainer.Enable = true - cfg.Trainer.Interval = 0 - }, - expect: func(t *testing.T, err error) { - assert := assert.New(t) - assert.EqualError(err, "trainer requires parameter interval") - }, - }, - { - name: "trainer requires parameter interval", - config: New(), - mock: func(cfg *Config) { - cfg.Manager = mockManagerConfig - cfg.Database.Redis = mockRedisConfig - cfg.Job = mockJobConfig - cfg.Trainer.Enable = true - cfg.Trainer.UploadTimeout = 0 - }, - expect: func(t *testing.T, err error) { - assert := assert.New(t) - assert.EqualError(err, "trainer requires parameter uploadTimeout") - }, - }, } for _, tc := range tests { diff --git a/scheduler/config/constants.go b/scheduler/config/constants.go index e3a77685e79..3345db66829 100644 --- a/scheduler/config/constants.go +++ b/scheduler/config/constants.go @@ -190,17 +190,6 @@ const ( DefaultStorageBufferSize = 100 ) -const ( - // DefaultTrainerAddr is the default address of trainer. - DefaultTrainerAddr = "127.0.0.1:9090" - - // DefaultTrainerInterval is the default interval of training. - DefaultTrainerInterval = 7 * 24 * time.Hour - - // DefaultTrainerUploadTimeout is the default timeout of uploading dataset to trainer. - DefaultTrainerUploadTimeout = 1 * time.Hour -) - const ( // DefaultLogRotateMaxSize is the default maximum size in megabytes of log files before rotation. DefaultLogRotateMaxSize = 1024 diff --git a/scheduler/config/testdata/scheduler.yaml b/scheduler/config/testdata/scheduler.yaml index c6a9bc109d8..55b01466214 100644 --- a/scheduler/config/testdata/scheduler.yaml +++ b/scheduler/config/testdata/scheduler.yaml @@ -101,9 +101,3 @@ security: network: enableIPv6: true - -trainer: - enable: false - addr: "127.0.0.1:9090" - interval: 10m - uploadTimeout: 2h diff --git a/scheduler/scheduler.go b/scheduler/scheduler.go index b1cb3e06648..452da61f967 100644 --- a/scheduler/scheduler.go +++ b/scheduler/scheduler.go @@ -45,7 +45,6 @@ import ( "d7y.io/dragonfly/v2/pkg/rpc" managerclient "d7y.io/dragonfly/v2/pkg/rpc/manager/client" securityclient "d7y.io/dragonfly/v2/pkg/rpc/security/client" - trainerclient "d7y.io/dragonfly/v2/pkg/rpc/trainer/client" "d7y.io/dragonfly/v2/pkg/types" "d7y.io/dragonfly/v2/scheduler/announcer" "d7y.io/dragonfly/v2/scheduler/config" @@ -82,9 +81,6 @@ type Server struct { // Security client. securityClient securityclient.V1 - // Trainer client. - trainerClient trainerclient.V1 - // Resource interface. resource resource.Resource @@ -143,36 +139,8 @@ func New(ctx context.Context, cfg *config.Config, d dfpath.Dfpath) (*Server, err } s.managerClient = managerClient - // Initialize dial options of trainer grpc client. - if cfg.Trainer.Enable { - trainerDialOptions := []grpc.DialOption{grpc.WithStatsHandler(otelgrpc.NewClientHandler())} - if cfg.Security.AutoIssueCert { - clientTransportCredentials, err := rpc.NewClientCredentials(cfg.Security.TLSPolicy, nil, []byte(cfg.Security.CACert)) - if err != nil { - return nil, err - } - - trainerDialOptions = append(trainerDialOptions, grpc.WithTransportCredentials(clientTransportCredentials)) - } else { - trainerDialOptions = append(trainerDialOptions, grpc.WithTransportCredentials(insecure.NewCredentials())) - } - - // Initialize trainer client. - trainerClient, err := trainerclient.GetV1ByAddr(ctx, cfg.Trainer.Addr, trainerDialOptions...) - if err != nil { - return nil, err - } - s.trainerClient = trainerClient - } - - // Initialize dial options of announcer. - announcerOptions := []announcer.Option{} - if s.trainerClient != nil { - announcerOptions = append(announcerOptions, announcer.WithTrainerClient(s.trainerClient)) - } - // Initialize announcer. - announcer, err := announcer.New(cfg, s.managerClient, storage, announcerOptions...) + announcer, err := announcer.New(cfg, s.managerClient, storage) if err != nil { return nil, err } @@ -426,15 +394,6 @@ func (s *Server) Stop() { } } - // Stop trainer client. - if s.trainerClient != nil { - if err := s.trainerClient.Close(); err != nil { - logger.Errorf("trainer client failed to stop: %s", err.Error()) - } else { - logger.Info("trainer client closed") - } - } - // Stop security client. if s.securityClient != nil { if err := s.securityClient.Close(); err != nil { diff --git a/trainer/config/config.go b/trainer/config/config.go deleted file mode 100644 index 12411544dda..00000000000 --- a/trainer/config/config.go +++ /dev/null @@ -1,232 +0,0 @@ -/* - * Copyright 2023 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package config - -import ( - "errors" - "net" - "time" - - "d7y.io/dragonfly/v2/cmd/dependency/base" - "d7y.io/dragonfly/v2/pkg/net/ip" - "d7y.io/dragonfly/v2/pkg/rpc" - "d7y.io/dragonfly/v2/pkg/slices" - "d7y.io/dragonfly/v2/pkg/types" -) - -type Config struct { - // Base options. - base.Options `yaml:",inline" mapstructure:",squash"` - - // Network configuration. - Network NetworkConfig `yaml:"network" mapstructure:"network"` - - // Server configuration. - Server ServerConfig `yaml:"server" mapstructure:"server"` - - // Metrics configuration. - Metrics MetricsConfig `yaml:"metrics" mapstructure:"metrics"` - - // Security configuration. - Security SecurityConfig `yaml:"security" mapstructure:"security"` - - // Manager configuration. - Manager ManagerConfig `yaml:"manager" mapstructure:"manager"` -} - -type NetworkConfig struct { - // EnableIPv6 enables ipv6 for server. - EnableIPv6 bool `yaml:"enableIPv6" mapstructure:"enableIPv6"` -} - -type ServerConfig struct { - // AdvertiseIP is advertise ip. - AdvertiseIP net.IP `yaml:"advertiseIP" mapstructure:"advertiseIP"` - - // AdvertisePort is advertise port. - AdvertisePort int `yaml:"advertisePort" mapstructure:"advertisePort"` - - // ListenIP is listen ip, like: 0.0.0.0, 192.168.0.1. - ListenIP net.IP `yaml:"listenIP" mapstructure:"listenIP"` - - // Server port. - Port int `yaml:"port" mapstructure:"port"` - - // Server log directory. - LogDir string `yaml:"logDir" mapstructure:"logDir"` - - // Maximum size in megabytes of log files before rotation (default: 1024) - LogMaxSize int `yaml:"logMaxSize" mapstructure:"logMaxSize"` - - // Maximum number of days to retain old log files (default: 7) - LogMaxAge int `yaml:"logMaxAge" mapstructure:"logMaxAge"` - - // Maximum number of old log files to keep (default: 20) - LogMaxBackups int `yaml:"logMaxBackups" mapstructure:"logMaxBackups"` - - // Server storage data directory. - DataDir string `yaml:"dataDir" mapstructure:"dataDir"` -} - -type MetricsConfig struct { - // Enable metrics service. - Enable bool `yaml:"enable" mapstructure:"enable"` - - // Metrics service address. - Addr string `yaml:"addr" mapstructure:"addr"` -} - -type SecurityConfig struct { - // AutoIssueCert indicates to issue client certificates for all grpc call - // if AutoIssueCert is false, any other option in Security will be ignored. - AutoIssueCert bool `mapstructure:"autoIssueCert" yaml:"autoIssueCert"` - - // CACert is the root CA certificate for all grpc tls handshake, it can be path or PEM format string. - CACert types.PEMContent `mapstructure:"caCert" yaml:"caCert"` - - // TLSVerify indicates to verify client certificates. - TLSVerify bool `mapstructure:"tlsVerify" yaml:"tlsVerify"` - - // TLSPolicy controls the grpc shandshake behaviors: - // force: both ClientHandshake and ServerHandshake are only support tls. - // prefer: ServerHandshake supports tls and insecure (non-tls), ClientHandshake will only support tls. - // default: ServerHandshake supports tls and insecure (non-tls), ClientHandshake will only support insecure (non-tls). - TLSPolicy string `mapstructure:"tlsPolicy" yaml:"tlsPolicy"` - - // CertSpec is the desired state of certificate. - CertSpec CertSpec `mapstructure:"certSpec" yaml:"certSpec"` -} - -type CertSpec struct { - // DNSNames is a list of dns names be set on the certificate. - DNSNames []string `mapstructure:"dnsNames" yaml:"dnsNames"` - - // IPAddresses is a list of ip addresses be set on the certificate. - IPAddresses []net.IP `mapstructure:"ipAddresses" yaml:"ipAddresses"` - - // ValidityPeriod is the validity period of certificate. - ValidityPeriod time.Duration `mapstructure:"validityPeriod" yaml:"validityPeriod"` -} - -type ManagerConfig struct { - // Addr is manager address. - Addr string `yaml:"addr" mapstructure:"addr"` -} - -// New default configuration. -func New() *Config { - return &Config{ - Network: NetworkConfig{ - EnableIPv6: DefaultNetworkEnableIPv6, - }, - Server: ServerConfig{ - AdvertisePort: DefaultServerAdvertisePort, - Port: DefaultServerPort, - LogMaxSize: DefaultLogRotateMaxSize, - LogMaxAge: DefaultLogRotateMaxAge, - LogMaxBackups: DefaultLogRotateMaxBackups, - }, - Metrics: MetricsConfig{ - Enable: false, - Addr: DefaultMetricsAddr, - }, - Security: SecurityConfig{ - AutoIssueCert: false, - TLSVerify: true, - TLSPolicy: rpc.PreferTLSPolicy, - CertSpec: CertSpec{ - DNSNames: DefaultCertDNSNames, - IPAddresses: DefaultCertIPAddresses, - ValidityPeriod: DefaultCertValidityPeriod, - }, - }, - Manager: ManagerConfig{}, - } -} - -// Validate config parameters. -func (cfg *Config) Validate() error { - if cfg.Server.AdvertiseIP == nil { - return errors.New("server requires parameter advertiseIP") - } - - if cfg.Server.AdvertisePort <= 0 { - return errors.New("server requires parameter advertisePort") - } - - if cfg.Server.ListenIP == nil { - return errors.New("server requires parameter listenIP") - } - - if cfg.Server.Port <= 0 { - return errors.New("server requires parameter port") - } - - if cfg.Metrics.Enable { - if cfg.Metrics.Addr == "" { - return errors.New("metrics requires parameter addr") - } - } - - if cfg.Security.AutoIssueCert { - if cfg.Security.CACert == "" { - return errors.New("security requires parameter caCert") - } - - if !slices.Contains([]string{rpc.DefaultTLSPolicy, rpc.ForceTLSPolicy, rpc.PreferTLSPolicy}, cfg.Security.TLSPolicy) { - return errors.New("security requires parameter tlsPolicy") - } - - if len(cfg.Security.CertSpec.IPAddresses) == 0 { - return errors.New("certSpec requires parameter ipAddresses") - } - - if len(cfg.Security.CertSpec.DNSNames) == 0 { - return errors.New("certSpec requires parameter dnsNames") - } - - if cfg.Security.CertSpec.ValidityPeriod <= 0 { - return errors.New("certSpec requires parameter validityPeriod") - } - } - - if cfg.Manager.Addr == "" { - return errors.New("manager requires parameter addr") - } - - return nil -} - -func (cfg *Config) Convert() error { - if cfg.Server.AdvertiseIP == nil { - if cfg.Network.EnableIPv6 { - cfg.Server.AdvertiseIP = ip.IPv6 - } else { - cfg.Server.AdvertiseIP = ip.IPv4 - } - } - - if cfg.Server.ListenIP == nil { - if cfg.Network.EnableIPv6 { - cfg.Server.ListenIP = net.IPv6zero - } else { - cfg.Server.ListenIP = net.IPv4zero - } - } - - return nil -} diff --git a/trainer/config/config_test.go b/trainer/config/config_test.go deleted file mode 100644 index c274a551dbd..00000000000 --- a/trainer/config/config_test.go +++ /dev/null @@ -1,267 +0,0 @@ -/* - * Copyright 2023 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package config - -import ( - "net" - "os" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "gopkg.in/yaml.v3" - - "d7y.io/dragonfly/v2/pkg/rpc" - "d7y.io/dragonfly/v2/pkg/types" -) - -var ( - mockManagerConfig = ManagerConfig{ - Addr: "localhost", - } - - mockMetricsConfig = MetricsConfig{ - Enable: true, - Addr: DefaultMetricsAddr, - } - - mockSecurityConfig = SecurityConfig{ - AutoIssueCert: true, - CACert: types.PEMContent("foo"), - TLSPolicy: rpc.PreferTLSPolicy, - CertSpec: CertSpec{ - DNSNames: DefaultCertDNSNames, - IPAddresses: DefaultCertIPAddresses, - ValidityPeriod: DefaultCertValidityPeriod, - }, - } -) - -func TestConfig_Load(t *testing.T) { - config := &Config{ - Network: NetworkConfig{ - EnableIPv6: true, - }, - Server: ServerConfig{ - AdvertiseIP: net.ParseIP("127.0.0.1"), - AdvertisePort: 9090, - ListenIP: net.ParseIP("0.0.0.0"), - Port: 9090, - LogDir: "foo", - LogMaxSize: 512, - LogMaxAge: 5, - LogMaxBackups: 3, - DataDir: "foo", - }, - Metrics: MetricsConfig{ - Enable: false, - Addr: ":8000", - }, - Security: SecurityConfig{ - AutoIssueCert: true, - CACert: "foo", - TLSVerify: true, - TLSPolicy: "force", - CertSpec: CertSpec{ - DNSNames: []string{"foo"}, - IPAddresses: []net.IP{net.IPv4zero}, - ValidityPeriod: 10 * time.Minute, - }, - }, - Manager: ManagerConfig{ - Addr: "127.0.0.1:65003", - }, - } - - trainerConfigYAML := &Config{} - contentYAML, _ := os.ReadFile("./testdata/trainer.yaml") - if err := yaml.Unmarshal(contentYAML, &trainerConfigYAML); err != nil { - t.Fatal(err) - } - assert := assert.New(t) - assert.EqualValues(config, trainerConfigYAML) -} - -func TestConfig_Validate(t *testing.T) { - tests := []struct { - name string - config *Config - mock func(cfg *Config) - expect func(t *testing.T, err error) - }{ - { - name: "valid config", - config: New(), - mock: func(cfg *Config) { - cfg.Manager = mockManagerConfig - }, - expect: func(t *testing.T, err error) { - assert := assert.New(t) - assert.NoError(err) - }, - }, - { - name: "server requires parameter advertiseIP", - config: New(), - mock: func(cfg *Config) { - cfg.Manager = mockManagerConfig - cfg.Server.AdvertiseIP = nil - }, - expect: func(t *testing.T, err error) { - assert := assert.New(t) - assert.EqualError(err, "server requires parameter advertiseIP") - }, - }, - { - name: "server requires parameter advertisePort", - config: New(), - mock: func(cfg *Config) { - cfg.Manager = mockManagerConfig - cfg.Server.AdvertisePort = 0 - }, - expect: func(t *testing.T, err error) { - assert := assert.New(t) - assert.EqualError(err, "server requires parameter advertisePort") - }, - }, - { - name: "server requires parameter listenIP", - config: New(), - mock: func(cfg *Config) { - cfg.Manager = mockManagerConfig - cfg.Server.ListenIP = nil - }, - expect: func(t *testing.T, err error) { - assert := assert.New(t) - assert.EqualError(err, "server requires parameter listenIP") - }, - }, - { - name: "server requires parameter port", - config: New(), - mock: func(cfg *Config) { - cfg.Manager = mockManagerConfig - cfg.Server.Port = 0 - }, - expect: func(t *testing.T, err error) { - assert := assert.New(t) - assert.EqualError(err, "server requires parameter port") - }, - }, - { - name: "metrics requires parameter addr", - config: New(), - mock: func(cfg *Config) { - cfg.Manager = mockManagerConfig - cfg.Metrics = mockMetricsConfig - cfg.Metrics.Addr = "" - }, - expect: func(t *testing.T, err error) { - assert := assert.New(t) - assert.EqualError(err, "metrics requires parameter addr") - }, - }, - { - name: "security requires parameter caCert", - config: New(), - mock: func(cfg *Config) { - cfg.Manager = mockManagerConfig - cfg.Security = mockSecurityConfig - cfg.Security.CACert = "" - }, - expect: func(t *testing.T, err error) { - assert := assert.New(t) - assert.EqualError(err, "security requires parameter caCert") - }, - }, - { - name: "security requires parameter tlsPolicy", - config: New(), - mock: func(cfg *Config) { - cfg.Manager = mockManagerConfig - cfg.Security = mockSecurityConfig - cfg.Security.TLSPolicy = "" - }, - expect: func(t *testing.T, err error) { - assert := assert.New(t) - assert.EqualError(err, "security requires parameter tlsPolicy") - }, - }, - { - name: "certSpec requires parameter ipAddresses", - config: New(), - mock: func(cfg *Config) { - cfg.Manager = mockManagerConfig - cfg.Security = mockSecurityConfig - cfg.Security.CertSpec.IPAddresses = []net.IP{} - }, - expect: func(t *testing.T, err error) { - assert := assert.New(t) - assert.EqualError(err, "certSpec requires parameter ipAddresses") - }, - }, - { - name: "certSpec requires parameter dnsNames", - config: New(), - mock: func(cfg *Config) { - cfg.Manager = mockManagerConfig - cfg.Security = mockSecurityConfig - cfg.Security.CertSpec.DNSNames = []string{} - }, - expect: func(t *testing.T, err error) { - assert := assert.New(t) - assert.EqualError(err, "certSpec requires parameter dnsNames") - }, - }, - { - name: "certSpec requires parameter validityPeriod", - config: New(), - mock: func(cfg *Config) { - cfg.Manager = mockManagerConfig - cfg.Security = mockSecurityConfig - cfg.Security.CertSpec.ValidityPeriod = 0 - }, - expect: func(t *testing.T, err error) { - assert := assert.New(t) - assert.EqualError(err, "certSpec requires parameter validityPeriod") - }, - }, - { - name: "manager requires parameter addr", - config: New(), - mock: func(cfg *Config) { - cfg.Manager = mockManagerConfig - cfg.Manager.Addr = "" - }, - expect: func(t *testing.T, err error) { - assert := assert.New(t) - assert.EqualError(err, "manager requires parameter addr") - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - if err := tc.config.Convert(); err != nil { - t.Fatal(err) - } - - tc.mock(tc.config) - tc.expect(t, tc.config.Validate()) - }) - } -} diff --git a/trainer/config/constants.go b/trainer/config/constants.go deleted file mode 100644 index b6c51243636..00000000000 --- a/trainer/config/constants.go +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright 2023 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package config - -import ( - "net" - "time" - - "d7y.io/dragonfly/v2/pkg/net/ip" -) - -const ( - // DefaultServerPort is default port for server. - DefaultServerPort = 9090 - - // DefaultServerAdvertisePort is default advertise port for server. - DefaultServerAdvertisePort = 9090 -) - -const ( - // DefaultMetricsAddr is default address for metrics server. - DefaultMetricsAddr = ":8000" -) - -var ( - // DefaultCertIPAddresses is default ip addresses of certificate. - DefaultCertIPAddresses = []net.IP{ip.IPv4, ip.IPv6} - - // DefaultCertDNSNames is default dns names of certificate. - DefaultCertDNSNames = []string{"dragonfly-trainer", "dragonfly-trainer.dragonfly-system.svc", "dragonfly-trainer.dragonfly-system.svc.cluster.local"} - - // DefaultCertValidityPeriod is default validity period of certificate. - DefaultCertValidityPeriod = 180 * 24 * time.Hour -) - -var ( - // DefaultNetworkEnableIPv6 is default value of enableIPv6. - DefaultNetworkEnableIPv6 = false -) - -const ( - // DefaultLogRotateMaxSize is the default maximum size in megabytes of log files before rotation. - DefaultLogRotateMaxSize = 1024 - - // DefaultLogRotateMaxAge is the default number of days to retain old log files. - DefaultLogRotateMaxAge = 7 - - // DefaultLogRotateMaxBackups is the default number of old log files to keep. - DefaultLogRotateMaxBackups = 20 -) diff --git a/trainer/config/testdata/ca.crt b/trainer/config/testdata/ca.crt deleted file mode 100644 index 257cc5642cb..00000000000 --- a/trainer/config/testdata/ca.crt +++ /dev/null @@ -1 +0,0 @@ -foo diff --git a/trainer/config/testdata/trainer.yaml b/trainer/config/testdata/trainer.yaml deleted file mode 100644 index 67b1e0f431e..00000000000 --- a/trainer/config/testdata/trainer.yaml +++ /dev/null @@ -1,33 +0,0 @@ -network: - enableIPv6: true - -server: - advertiseIP: 127.0.0.1 - advertisePort: 9090 - listenIP: 0.0.0.0 - port: 9090 - host: foo - logDir: foo - dataDir: foo - logMaxSize: 512 - logMaxAge: 5 - logMaxBackups: 3 - -metrics: - enable: false - addr: ":8000" - -security: - autoIssueCert: true - caCert: testdata/ca.crt - tlsVerify: true - tlsPolicy: force - certSpec: - dnsNames: - - foo - ipAddresses: - - 0.0.0.0 - validityPeriod: 10m - -manager: - addr: 127.0.0.1:65003 diff --git a/trainer/metrics/metrics.go b/trainer/metrics/metrics.go deleted file mode 100644 index 3bb72247f42..00000000000 --- a/trainer/metrics/metrics.go +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright 2020 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package metrics - -import ( - "net/http" - - grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" - "github.com/prometheus/client_golang/prometheus/promhttp" - "google.golang.org/grpc" - - "d7y.io/dragonfly/v2/pkg/types" - "d7y.io/dragonfly/v2/trainer/config" - "d7y.io/dragonfly/v2/version" -) - -// Variables declared for metrics. -var ( - TrainCount = promauto.NewCounter(prometheus.CounterOpts{ - Namespace: types.MetricsNamespace, - Subsystem: types.TrainerMetricsName, - Name: "training_total", - Help: "Counter of the number of the training.", - }) - - TrainFailureCount = promauto.NewCounter(prometheus.CounterOpts{ - Namespace: types.MetricsNamespace, - Subsystem: types.TrainerMetricsName, - Name: "training_failure_total", - Help: "Counter of the number of failed of the training.", - }) - - VersionGauge = promauto.NewGaugeVec(prometheus.GaugeOpts{ - Namespace: types.MetricsNamespace, - Subsystem: types.TrainerMetricsName, - Name: "version", - Help: "Version info of the service.", - }, []string{"major", "minor", "git_version", "git_commit", "platform", "build_time", "go_version", "go_tags", "go_gcflags"}) -) - -func New(cfg *config.MetricsConfig, svr *grpc.Server) *http.Server { - grpc_prometheus.Register(svr) - - mux := http.NewServeMux() - mux.Handle("/metrics", promhttp.Handler()) - - VersionGauge.WithLabelValues(version.Major, version.Minor, version.GitVersion, version.GitCommit, version.Platform, version.BuildTime, version.GoVersion, version.Gotags, version.Gogcflags).Set(1) - return &http.Server{ - Addr: cfg.Addr, - Handler: mux, - } -} diff --git a/trainer/metrics/metrics_test.go b/trainer/metrics/metrics_test.go deleted file mode 100644 index a25c8dcb741..00000000000 --- a/trainer/metrics/metrics_test.go +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright 2023 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package metrics - -import ( - "net/http" - "testing" - - "google.golang.org/grpc" - - "d7y.io/dragonfly/v2/trainer/config" -) - -func TestNew(t *testing.T) { - cfg := &config.MetricsConfig{ - Addr: "localhost:8080", - } - svr := grpc.NewServer() - server := New(cfg, svr) - - if server.Addr != cfg.Addr { - t.Errorf("expected server.Addr to be %s, but got %s", cfg.Addr, server.Addr) - } - - if _, ok := server.Handler.(*http.ServeMux); !ok { - t.Errorf("expected server.Handler to be a *http.ServeMux, but got %T", server.Handler) - } -} diff --git a/trainer/rpcserver/rpcserver.go b/trainer/rpcserver/rpcserver.go deleted file mode 100644 index bda0a5c6418..00000000000 --- a/trainer/rpcserver/rpcserver.go +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright 2023 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package rpcserver - -import ( - "google.golang.org/grpc" - - "d7y.io/dragonfly/v2/pkg/rpc/trainer/server" - "d7y.io/dragonfly/v2/trainer/config" - "d7y.io/dragonfly/v2/trainer/storage" - "d7y.io/dragonfly/v2/trainer/training" -) - -// New creates a new grpc server. -func New( - cfg *config.Config, - storage storage.Storage, - training training.Training, - opts ...grpc.ServerOption, -) *grpc.Server { - return server.New( - newTrainerServerV1(cfg, storage, training), - opts...) -} diff --git a/trainer/rpcserver/rpcserver_test.go b/trainer/rpcserver/rpcserver_test.go deleted file mode 100644 index 34dfc50682a..00000000000 --- a/trainer/rpcserver/rpcserver_test.go +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright 2023 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package rpcserver - -import ( - "reflect" - "testing" - - "github.com/stretchr/testify/assert" - "go.uber.org/mock/gomock" - - "d7y.io/dragonfly/v2/trainer/config" - storagemocks "d7y.io/dragonfly/v2/trainer/storage/mocks" - trainingmocks "d7y.io/dragonfly/v2/trainer/training/mocks" -) - -func TestRPCServer_New(t *testing.T) { - tests := []struct { - name string - expect func(t *testing.T, s any) - }{ - { - name: "new server", - expect: func(t *testing.T, s any) { - assert := assert.New(t) - assert.Equal(reflect.TypeOf(s).Elem().Name(), "Server") - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ctl := gomock.NewController(t) - defer ctl.Finish() - storage := storagemocks.NewMockStorage(ctl) - training := trainingmocks.NewMockTraining(ctl) - - svr := New(&config.Config{}, storage, training) - tc.expect(t, svr) - }) - } -} diff --git a/trainer/rpcserver/trainer_server_v1.go b/trainer/rpcserver/trainer_server_v1.go deleted file mode 100644 index da30e29f079..00000000000 --- a/trainer/rpcserver/trainer_server_v1.go +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright 2023 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package rpcserver - -import ( - trainerv1 "d7y.io/api/v2/pkg/apis/trainer/v1" - - "d7y.io/dragonfly/v2/trainer/config" - "d7y.io/dragonfly/v2/trainer/metrics" - "d7y.io/dragonfly/v2/trainer/service" - storage "d7y.io/dragonfly/v2/trainer/storage" - "d7y.io/dragonfly/v2/trainer/training" -) - -// trainerServerV1 is v1 version of the trainer grpc server. -type trainerServerV1 struct { - // Service interface. - service *service.V1 -} - -// newTrainerServerV1 returns a new trainerServerV1 instance. -func newTrainerServerV1(cfg *config.Config, storage storage.Storage, training training.Training) trainerv1.TrainerServer { - return &trainerServerV1{service.NewV1(cfg, storage, training)} -} - -// Train handles the training request from scheduler. -func (t *trainerServerV1) Train(stream trainerv1.Trainer_TrainServer) error { - // Collect TrainCount metrics. - metrics.TrainCount.Inc() - if err := t.service.Train(stream); err != nil { - // Collect TrainFailureCount metrics. - metrics.TrainFailureCount.Inc() - return err - } - - return nil -} diff --git a/trainer/service/service_v1.go b/trainer/service/service_v1.go deleted file mode 100644 index af4e828e6ef..00000000000 --- a/trainer/service/service_v1.go +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Copyright 2023 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package service - -import ( - "context" - "fmt" - "io" - - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/emptypb" - - trainerv1 "d7y.io/api/v2/pkg/apis/trainer/v1" - - logger "d7y.io/dragonfly/v2/internal/dflog" - "d7y.io/dragonfly/v2/pkg/idgen" - "d7y.io/dragonfly/v2/trainer/config" - "d7y.io/dragonfly/v2/trainer/storage" - "d7y.io/dragonfly/v2/trainer/training" -) - -// V1 is the interface for v1 version of the service. -type V1 struct { - // Trainer service config. - config *config.Config - - // Storage Interface. - storage storage.Storage - - // Training Interface. - training training.Training -} - -// New v1 version of service instance. -func NewV1( - cfg *config.Config, - storage storage.Storage, - training training.Training, -) *V1 { - return &V1{cfg, storage, training} -} - -// Train implements the Trainer.Train method. -func (v *V1) Train(stream trainerv1.Trainer_TrainServer) error { - var ( - ip string - hostname string - hostID string - networkTopologyFile io.WriteCloser - downloadFile io.WriteCloser - req *trainerv1.TrainRequest - initialized bool - err error - ) - - for { - req, err = stream.Recv() - if err != nil { - if err == io.EOF { - break - } - - logger.Errorf("receive failed: %s", err.Error()) - return err - } - - logger := logger.WithHostnameAndIP(req.Hostname, req.Ip) - if !initialized { - initialized = true - ip = req.Ip - hostname = req.Hostname - hostID = idgen.HostIDV2(req.Ip, req.Hostname) - - // Open network topology file and store received data. - networkTopologyFile, err = v.storage.OpenNetworkTopology(hostID) - if err != nil { - msg := fmt.Sprintf("open network topology failed: %s", err.Error()) - logger.Error(msg) - return status.Error(codes.Internal, msg) - } - defer func() { - networkTopologyFile.Close() - - // If error occurred, clear network topology. - if err != nil && err != io.EOF { - if err := v.storage.ClearNetworkTopology(hostID); err != nil { - logger.Errorf("clear network topology failed: %s", err.Error()) - } - } - }() - - // Open download file and store received data. - downloadFile, err = v.storage.OpenDownload(hostID) - if err != nil { - msg := fmt.Sprintf("open download failed: %s", err.Error()) - logger.Error(msg) - return status.Error(codes.Internal, msg) - } - defer func() { - downloadFile.Close() - - // If error occurred, clear download. - if err != nil && err != io.EOF { - if err := v.storage.ClearDownload(hostID); err != nil { - logger.Errorf("clear download failed: %s", err.Error()) - } - } - }() - } - - switch trainRequest := req.GetRequest().(type) { - case *trainerv1.TrainRequest_TrainGnnRequest: - // Store network topology. - if _, err := networkTopologyFile.Write(trainRequest.TrainGnnRequest.Dataset); err != nil { - msg := fmt.Sprintf("write network topology failed: %s", err.Error()) - logger.Error(msg) - return status.Error(codes.Internal, msg) - } - case *trainerv1.TrainRequest_TrainMlpRequest: - // Store download. - if _, err := downloadFile.Write(trainRequest.TrainMlpRequest.Dataset); err != nil { - msg := fmt.Sprintf("write download failed: %s", err.Error()) - logger.Error(msg) - return status.Error(codes.Internal, msg) - } - default: - msg := fmt.Sprintf("receive unknown request: %#v", trainRequest) - logger.Error(msg) - return status.Error(codes.FailedPrecondition, msg) - } - } - - // Send empty response and close stream. - if err := stream.SendAndClose(&emptypb.Empty{}); err != nil { - logger.Errorf("send and close failed: %s", err.Error()) - return err - } - - // If all dataset received, start training. - go func() { - if err := v.training.Train(context.Background(), ip, hostname); err != nil { - logger.Errorf("train failed: %s", err.Error()) - } - }() - - return nil -} diff --git a/trainer/service/service_v1_test.go b/trainer/service/service_v1_test.go deleted file mode 100644 index c1879e95cfa..00000000000 --- a/trainer/service/service_v1_test.go +++ /dev/null @@ -1,505 +0,0 @@ -/* - * Copyright 2023 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package service - -import ( - "context" - "errors" - "fmt" - "io" - "os" - "path/filepath" - "reflect" - "sync" - "testing" - - "github.com/stretchr/testify/assert" - "go.uber.org/mock/gomock" - "google.golang.org/protobuf/types/known/emptypb" - - trainerv1 "d7y.io/api/v2/pkg/apis/trainer/v1" - trainerv1mocks "d7y.io/api/v2/pkg/apis/trainer/v1/mocks" - - "d7y.io/dragonfly/v2/pkg/idgen" - "d7y.io/dragonfly/v2/trainer/config" - storagemocks "d7y.io/dragonfly/v2/trainer/storage/mocks" - trainingmocks "d7y.io/dragonfly/v2/trainer/training/mocks" -) - -var ( - mockHostName = "localhost" - mockIP = "127.0.0.1" - mockHostID = idgen.HostIDV2(mockIP, mockHostName) - mockDataset = []byte("foo") -) - -func TestService_NewV1(t *testing.T) { - tests := []struct { - name string - run func(t *testing.T, s any) - }{ - { - name: "new service", - run: func(t *testing.T, s any) { - assert := assert.New(t) - assert.Equal(reflect.TypeOf(s).Elem().Name(), "V1") - }, - }, - } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ctl := gomock.NewController(t) - defer ctl.Finish() - storage := storagemocks.NewMockStorage(ctl) - training := trainingmocks.NewMockTraining(ctl) - tc.run(t, NewV1(config.New(), storage, training)) - }) - } -} - -func TestV1_Train(t *testing.T) { - tests := []struct { - name string - run func(t *testing.T, svc *V1, stream trainerv1.Trainer_TrainServer, mtts *trainerv1mocks.MockTrainer_TrainServerMockRecorder, - ms *storagemocks.MockStorageMockRecorder, mt *trainingmocks.MockTrainingMockRecorder) - }{ - { - name: "receive GNN and MLP train requests success", - run: func(t *testing.T, svc *V1, stream trainerv1.Trainer_TrainServer, mtts *trainerv1mocks.MockTrainer_TrainServerMockRecorder, - ms *storagemocks.MockStorageMockRecorder, mt *trainingmocks.MockTrainingMockRecorder) { - networktopologyFile, err := os.OpenFile(filepath.Join(os.TempDir(), - fmt.Sprintf("%s-%s.%s", "networktopology", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600) - if err != nil { - t.Fatal(err) - } - - downloadFile, err := os.OpenFile(filepath.Join(os.TempDir(), - fmt.Sprintf("%s-%s.%s", "download", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600) - if err != nil { - t.Fatal(err) - } - - var wg sync.WaitGroup - wg.Add(1) - defer wg.Wait() - gomock.InOrder( - mtts.Recv().Return(&trainerv1.TrainRequest{ - Hostname: mockHostName, - Ip: mockIP, - Request: &trainerv1.TrainRequest_TrainGnnRequest{ - TrainGnnRequest: &trainerv1.TrainGNNRequest{ - Dataset: mockDataset, - }, - }, - }, nil).Times(1), - - ms.OpenNetworkTopology(mockHostID).Return(networktopologyFile, nil).Times(1), - ms.OpenDownload(mockHostID).Return(downloadFile, nil).Times(1), - mtts.Recv().Return(&trainerv1.TrainRequest{ - Hostname: mockHostName, - Ip: mockIP, - Request: &trainerv1.TrainRequest_TrainMlpRequest{ - TrainMlpRequest: &trainerv1.TrainMLPRequest{ - Dataset: mockDataset, - }, - }, - }, nil).Times(1), - mtts.Recv().Return(nil, io.EOF).Times(1), - mtts.SendAndClose(new(emptypb.Empty)).Return(nil).Times(1), - mt.Train(context.Background(), mockIP, mockHostName).DoAndReturn(func(ctx context.Context, ip, hostName string) error { - wg.Done() - return nil - }).Times(1), - ) - - assert := assert.New(t) - assert.NoError(svc.Train(stream)) - }, - }, - { - name: "receive error", - run: func(t *testing.T, svc *V1, stream trainerv1.Trainer_TrainServer, mtts *trainerv1mocks.MockTrainer_TrainServerMockRecorder, - ms *storagemocks.MockStorageMockRecorder, mt *trainingmocks.MockTrainingMockRecorder) { - gomock.InOrder( - mtts.Recv().Return(nil, errors.New("receive error")).Times(1), - ) - - assert := assert.New(t) - assert.EqualError(svc.Train(stream), "receive error") - }, - }, - { - name: "open network topology file error", - run: func(t *testing.T, svc *V1, stream trainerv1.Trainer_TrainServer, mtts *trainerv1mocks.MockTrainer_TrainServerMockRecorder, - ms *storagemocks.MockStorageMockRecorder, mt *trainingmocks.MockTrainingMockRecorder) { - gomock.InOrder( - mtts.Recv().Return(&trainerv1.TrainRequest{ - Hostname: mockHostName, - Ip: mockIP, - Request: &trainerv1.TrainRequest_TrainGnnRequest{ - TrainGnnRequest: &trainerv1.TrainGNNRequest{ - Dataset: mockDataset, - }, - }, - }, nil).Times(1), - - ms.OpenNetworkTopology(mockHostID).Return(nil, errors.New("open network topology file error")).Times(1), - ) - - assert := assert.New(t) - assert.EqualError(svc.Train(stream), - "rpc error: code = Internal desc = open network topology failed: open network topology file error") - }, - }, - { - name: "open download file error", - run: func(t *testing.T, svc *V1, stream trainerv1.Trainer_TrainServer, mtts *trainerv1mocks.MockTrainer_TrainServerMockRecorder, - ms *storagemocks.MockStorageMockRecorder, mt *trainingmocks.MockTrainingMockRecorder) { - networktopologyFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "networktopology", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600) - if err != nil { - t.Fatal(err) - } - - gomock.InOrder( - mtts.Recv().Return(&trainerv1.TrainRequest{ - Hostname: mockHostName, - Ip: mockIP, - Request: &trainerv1.TrainRequest_TrainGnnRequest{ - TrainGnnRequest: &trainerv1.TrainGNNRequest{ - Dataset: mockDataset, - }, - }, - }, nil).Times(1), - - ms.OpenNetworkTopology(mockHostID).Return(networktopologyFile, nil).Times(1), - ms.OpenDownload(mockHostID).Return(nil, errors.New("open download file error")).Times(1), - ms.ClearNetworkTopology(mockHostID).Do(func(id string) { - networktopologyFile.Close() - if err := os.Remove(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "networktopology", id, "csv"))); err != nil { - t.Fatal(err) - } - }).Return(nil).Times(1), - ) - - assert := assert.New(t) - assert.EqualError(svc.Train(stream), - "rpc error: code = Internal desc = open download failed: open download file error") - }, - }, - { - name: "clear network topology file error", - run: func(t *testing.T, svc *V1, stream trainerv1.Trainer_TrainServer, mtts *trainerv1mocks.MockTrainer_TrainServerMockRecorder, - ms *storagemocks.MockStorageMockRecorder, mt *trainingmocks.MockTrainingMockRecorder) { - networktopologyFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "networktopology", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600) - if err != nil { - t.Fatal(err) - } - - downloadFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "download", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600) - if err != nil { - t.Fatal(err) - } - - gomock.InOrder( - mtts.Recv().Return(&trainerv1.TrainRequest{ - Hostname: mockHostName, - Ip: mockIP, - Request: &trainerv1.TrainRequest_TrainGnnRequest{ - TrainGnnRequest: &trainerv1.TrainGNNRequest{ - Dataset: mockDataset, - }, - }, - }, nil).Times(1), - - ms.OpenNetworkTopology(mockHostID).Return(networktopologyFile, nil).Times(1), - ms.OpenDownload(mockHostID).Return(downloadFile, nil).Times(1), - mtts.Recv().Return(nil, errors.New("receive error")).Times(1), - ms.ClearDownload(mockHostID).Do(func(id string) { - downloadFile.Close() - if err := os.Remove(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "download", id, "csv"))); err != nil { - t.Fatal(err) - } - }).Return(nil).Times(1), - - ms.ClearNetworkTopology(mockHostID).Do(func(id string) { - networktopologyFile.Close() - if err := os.Remove(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "networktopology", id, "csv"))); err != nil { - t.Fatal(err) - } - }).Return(errors.New("clear network topology file error")).Times(1), - ) - - assert := assert.New(t) - assert.EqualError(svc.Train(stream), "receive error") - }, - }, - { - name: "clear download file error", - run: func(t *testing.T, svc *V1, stream trainerv1.Trainer_TrainServer, mtts *trainerv1mocks.MockTrainer_TrainServerMockRecorder, - ms *storagemocks.MockStorageMockRecorder, mt *trainingmocks.MockTrainingMockRecorder) { - networktopologyFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "networktopology", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600) - if err != nil { - t.Fatal(err) - } - - downloadFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "download", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600) - if err != nil { - t.Fatal(err) - } - - gomock.InOrder( - mtts.Recv().Return(&trainerv1.TrainRequest{ - Hostname: mockHostName, - Ip: mockIP, - Request: &trainerv1.TrainRequest_TrainGnnRequest{ - TrainGnnRequest: &trainerv1.TrainGNNRequest{ - Dataset: mockDataset, - }, - }, - }, nil).Times(1), - - ms.OpenNetworkTopology(mockHostID).Return(networktopologyFile, nil).Times(1), - ms.OpenDownload(mockHostID).Return(downloadFile, nil).Times(1), - mtts.Recv().Return(nil, errors.New("receive error")).Times(1), - ms.ClearDownload(mockHostID).Do(func(id string) { - downloadFile.Close() - if err := os.Remove(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "download", id, "csv"))); err != nil { - t.Fatal(err) - } - }).Return(errors.New("clear download file error")).Times(1), - - ms.ClearNetworkTopology(mockHostID).Do(func(id string) { - networktopologyFile.Close() - if err := os.Remove(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "networktopology", id, "csv"))); err != nil { - t.Fatal(err) - } - }).Return(nil).Times(1), - ) - - assert := assert.New(t) - assert.EqualError(svc.Train(stream), "receive error") - }, - }, - { - name: "store network topology error", - run: func(t *testing.T, svc *V1, stream trainerv1.Trainer_TrainServer, mtts *trainerv1mocks.MockTrainer_TrainServerMockRecorder, - ms *storagemocks.MockStorageMockRecorder, mt *trainingmocks.MockTrainingMockRecorder) { - networktopologyFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "networktopology", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600) - if err != nil { - t.Fatal(err) - } - - downloadFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "download", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600) - if err != nil { - t.Fatal(err) - } - - gomock.InOrder( - mtts.Recv().Return(&trainerv1.TrainRequest{ - Hostname: mockHostName, - Ip: mockIP, - Request: &trainerv1.TrainRequest_TrainGnnRequest{ - TrainGnnRequest: &trainerv1.TrainGNNRequest{ - Dataset: mockDataset, - }, - }, - }, nil).Times(1), - - ms.OpenNetworkTopology(mockHostID).Return(networktopologyFile, nil).Times(1), - ms.OpenDownload(mockHostID).Return(downloadFile, nil).Times(1), - ) - - networktopologyFile.Close() - assert := assert.New(t) - assert.EqualError(svc.Train(stream), - "rpc error: code = Internal desc = write network topology failed: write /tmp/networktopology-52fa2eb710c71cc3e6ba7be6ca82453fcfe59e1c5da358ab3df8b72fd4d2a2cf.csv: file already closed") - }, - }, - { - name: "store download error", - run: func(t *testing.T, svc *V1, stream trainerv1.Trainer_TrainServer, mtts *trainerv1mocks.MockTrainer_TrainServerMockRecorder, - ms *storagemocks.MockStorageMockRecorder, mt *trainingmocks.MockTrainingMockRecorder) { - networktopologyFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "networktopology", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600) - if err != nil { - t.Fatal(err) - } - - downloadFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "download", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600) - if err != nil { - t.Fatal(err) - } - - gomock.InOrder( - mtts.Recv().Return(&trainerv1.TrainRequest{ - Hostname: mockHostName, - Ip: mockIP, - Request: &trainerv1.TrainRequest_TrainMlpRequest{ - TrainMlpRequest: &trainerv1.TrainMLPRequest{ - Dataset: mockDataset, - }, - }, - }, nil).Times(1), - - ms.OpenNetworkTopology(mockHostID).Return(networktopologyFile, nil).Times(1), - ms.OpenDownload(mockHostID).Return(downloadFile, nil).Times(1), - ) - - downloadFile.Close() - assert := assert.New(t) - assert.EqualError(svc.Train(stream), - "rpc error: code = Internal desc = write download failed: write /tmp/download-52fa2eb710c71cc3e6ba7be6ca82453fcfe59e1c5da358ab3df8b72fd4d2a2cf.csv: file already closed") - }, - }, - { - name: "receive unknown request", - run: func(t *testing.T, svc *V1, stream trainerv1.Trainer_TrainServer, mtts *trainerv1mocks.MockTrainer_TrainServerMockRecorder, - ms *storagemocks.MockStorageMockRecorder, mt *trainingmocks.MockTrainingMockRecorder) { - networktopologyFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "networktopology", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600) - if err != nil { - t.Fatal(err) - } - - downloadFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "download", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600) - if err != nil { - t.Fatal(err) - } - - gomock.InOrder( - mtts.Recv().Return(&trainerv1.TrainRequest{ - Hostname: mockHostName, - Ip: mockIP, - Request: nil, - }, nil).Times(1), - - ms.OpenNetworkTopology(mockHostID).Return(networktopologyFile, nil).Times(1), - ms.OpenDownload(mockHostID).Return(downloadFile, nil).Times(1), - ) - - assert := assert.New(t) - assert.EqualError(svc.Train(stream), "rpc error: code = FailedPrecondition desc = receive unknown request: ") - }, - }, - { - name: "send and close error", - run: func(t *testing.T, svc *V1, stream trainerv1.Trainer_TrainServer, mtts *trainerv1mocks.MockTrainer_TrainServerMockRecorder, - ms *storagemocks.MockStorageMockRecorder, mt *trainingmocks.MockTrainingMockRecorder) { - networktopologyFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "networktopology", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600) - if err != nil { - t.Fatal(err) - } - - downloadFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "download", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600) - if err != nil { - t.Fatal(err) - } - - gomock.InOrder( - mtts.Recv().Return(&trainerv1.TrainRequest{ - Hostname: mockHostName, - Ip: mockIP, - Request: &trainerv1.TrainRequest_TrainGnnRequest{ - TrainGnnRequest: &trainerv1.TrainGNNRequest{ - Dataset: mockDataset, - }, - }, - }, nil).Times(1), - - ms.OpenNetworkTopology(mockHostID).Return(networktopologyFile, nil).Times(1), - ms.OpenDownload(mockHostID).Return(downloadFile, nil).Times(1), - mtts.Recv().Return(&trainerv1.TrainRequest{ - Hostname: mockHostName, - Ip: mockIP, - Request: &trainerv1.TrainRequest_TrainMlpRequest{ - TrainMlpRequest: &trainerv1.TrainMLPRequest{ - Dataset: mockDataset, - }, - }, - }, nil).Times(1), - mtts.Recv().Return(nil, io.EOF).Times(1), - mtts.SendAndClose(new(emptypb.Empty)).Return(errors.New("send and close error")).Times(1), - ) - - assert := assert.New(t) - assert.EqualError(svc.Train(stream), "send and close error") - }, - }, - { - name: "training error", - run: func(t *testing.T, svc *V1, stream trainerv1.Trainer_TrainServer, mtts *trainerv1mocks.MockTrainer_TrainServerMockRecorder, - ms *storagemocks.MockStorageMockRecorder, mt *trainingmocks.MockTrainingMockRecorder) { - networktopologyFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "networktopology", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600) - if err != nil { - t.Fatal(err) - } - - downloadFile, err := os.OpenFile(filepath.Join(os.TempDir(), fmt.Sprintf("%s-%s.%s", "download", mockHostID, "csv")), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600) - if err != nil { - t.Fatal(err) - } - - var wg sync.WaitGroup - wg.Add(1) - defer wg.Wait() - gomock.InOrder( - mtts.Recv().Return(&trainerv1.TrainRequest{ - Hostname: mockHostName, - Ip: mockIP, - Request: &trainerv1.TrainRequest_TrainGnnRequest{ - TrainGnnRequest: &trainerv1.TrainGNNRequest{ - Dataset: mockDataset, - }, - }, - }, nil).Times(1), - - ms.OpenNetworkTopology(mockHostID).Return(networktopologyFile, nil).Times(1), - ms.OpenDownload(mockHostID).Return(downloadFile, nil).Times(1), - mtts.Recv().Return(&trainerv1.TrainRequest{ - Hostname: mockHostName, - Ip: mockIP, - Request: &trainerv1.TrainRequest_TrainMlpRequest{ - TrainMlpRequest: &trainerv1.TrainMLPRequest{ - Dataset: mockDataset, - }, - }, - }, nil).Times(1), - mtts.Recv().Return(nil, io.EOF).Times(1), - mtts.SendAndClose(new(emptypb.Empty)).Return(nil).Times(1), - mt.Train(context.Background(), mockIP, mockHostName).DoAndReturn(func(ctx context.Context, ip, hostName string) error { - wg.Done() - return errors.New("training error") - }).Times(1), - ) - - assert := assert.New(t) - assert.NoError(svc.Train(stream)) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ctl := gomock.NewController(t) - defer ctl.Finish() - storage := storagemocks.NewMockStorage(ctl) - training := trainingmocks.NewMockTraining(ctl) - stream := trainerv1mocks.NewMockTrainer_TrainServer(ctl) - - svc := NewV1(config.New(), storage, training) - tc.run(t, svc, stream, stream.EXPECT(), storage.EXPECT(), training.EXPECT()) - }) - } -} diff --git a/trainer/storage/mocks/storage_mock.go b/trainer/storage/mocks/storage_mock.go deleted file mode 100644 index e687b41f1de..00000000000 --- a/trainer/storage/mocks/storage_mock.go +++ /dev/null @@ -1,143 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: storage.go -// -// Generated by this command: -// -// mockgen -destination mocks/storage_mock.go -source storage.go -package mocks -// - -// Package mocks is a generated GoMock package. -package mocks - -import ( - os "os" - reflect "reflect" - - storage "d7y.io/dragonfly/v2/scheduler/storage" - gomock "go.uber.org/mock/gomock" -) - -// MockStorage is a mock of Storage interface. -type MockStorage struct { - ctrl *gomock.Controller - recorder *MockStorageMockRecorder -} - -// MockStorageMockRecorder is the mock recorder for MockStorage. -type MockStorageMockRecorder struct { - mock *MockStorage -} - -// NewMockStorage creates a new mock instance. -func NewMockStorage(ctrl *gomock.Controller) *MockStorage { - mock := &MockStorage{ctrl: ctrl} - mock.recorder = &MockStorageMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockStorage) EXPECT() *MockStorageMockRecorder { - return m.recorder -} - -// Clear mocks base method. -func (m *MockStorage) Clear() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Clear") - ret0, _ := ret[0].(error) - return ret0 -} - -// Clear indicates an expected call of Clear. -func (mr *MockStorageMockRecorder) Clear() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clear", reflect.TypeOf((*MockStorage)(nil).Clear)) -} - -// ClearDownload mocks base method. -func (m *MockStorage) ClearDownload(arg0 string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ClearDownload", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// ClearDownload indicates an expected call of ClearDownload. -func (mr *MockStorageMockRecorder) ClearDownload(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearDownload", reflect.TypeOf((*MockStorage)(nil).ClearDownload), arg0) -} - -// ClearNetworkTopology mocks base method. -func (m *MockStorage) ClearNetworkTopology(arg0 string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ClearNetworkTopology", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// ClearNetworkTopology indicates an expected call of ClearNetworkTopology. -func (mr *MockStorageMockRecorder) ClearNetworkTopology(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearNetworkTopology", reflect.TypeOf((*MockStorage)(nil).ClearNetworkTopology), arg0) -} - -// ListDownload mocks base method. -func (m *MockStorage) ListDownload(arg0 string) ([]storage.Download, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListDownload", arg0) - ret0, _ := ret[0].([]storage.Download) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ListDownload indicates an expected call of ListDownload. -func (mr *MockStorageMockRecorder) ListDownload(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListDownload", reflect.TypeOf((*MockStorage)(nil).ListDownload), arg0) -} - -// ListNetworkTopology mocks base method. -func (m *MockStorage) ListNetworkTopology(arg0 string) ([]storage.NetworkTopology, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListNetworkTopology", arg0) - ret0, _ := ret[0].([]storage.NetworkTopology) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ListNetworkTopology indicates an expected call of ListNetworkTopology. -func (mr *MockStorageMockRecorder) ListNetworkTopology(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListNetworkTopology", reflect.TypeOf((*MockStorage)(nil).ListNetworkTopology), arg0) -} - -// OpenDownload mocks base method. -func (m *MockStorage) OpenDownload(arg0 string) (*os.File, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenDownload", arg0) - ret0, _ := ret[0].(*os.File) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenDownload indicates an expected call of OpenDownload. -func (mr *MockStorageMockRecorder) OpenDownload(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenDownload", reflect.TypeOf((*MockStorage)(nil).OpenDownload), arg0) -} - -// OpenNetworkTopology mocks base method. -func (m *MockStorage) OpenNetworkTopology(arg0 string) (*os.File, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OpenNetworkTopology", arg0) - ret0, _ := ret[0].(*os.File) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// OpenNetworkTopology indicates an expected call of OpenNetworkTopology. -func (mr *MockStorageMockRecorder) OpenNetworkTopology(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenNetworkTopology", reflect.TypeOf((*MockStorage)(nil).OpenNetworkTopology), arg0) -} diff --git a/trainer/storage/storage.go b/trainer/storage/storage.go deleted file mode 100644 index 4bd02363070..00000000000 --- a/trainer/storage/storage.go +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Copyright 2023 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -//go:generate mockgen -destination mocks/storage_mock.go -source storage.go -package mocks - -package storage - -import ( - "errors" - "fmt" - "os" - "path/filepath" - - "github.com/gocarina/gocsv" - - schedulerstorage "d7y.io/dragonfly/v2/scheduler/storage" -) - -const ( - // DownloadFilePrefix is prefix of download file name. - DownloadFilePrefix = "download" - - // NetworkTopologyFilePrefix is prefix of network topology file name. - NetworkTopologyFilePrefix = "networktopology" - - // CSVFileExt is extension of file name. - CSVFileExt = "csv" -) - -// Storage is the interface used for storage. -type Storage interface { - // ListDownload returns downloads in csv files based on the given model key. - ListDownload(string) ([]schedulerstorage.Download, error) - - // ListNetworkTopology returns network topologies in csv files based on the given model key. - ListNetworkTopology(string) ([]schedulerstorage.NetworkTopology, error) - - // OpenDownload opens download files for read based on the given model key, it returns io.ReadCloser of download files. - OpenDownload(string) (*os.File, error) - - // OpenNetworkTopology opens network topology files for read based on the given model key, it returns io.ReadCloser of network topology files. - OpenNetworkTopology(string) (*os.File, error) - - // ClearDownload removes all downloads based on the given model key. - ClearDownload(string) error - - // ClearNetworkTopology removes network topologies based on the given model key. - ClearNetworkTopology(string) error - - // Clear removes all files. - Clear() error -} - -// storage provides storage function. -type storage struct { - baseDir string -} - -// New returns a new Storage instance. -func New(baseDir string) Storage { - return &storage{baseDir: baseDir} -} - -// ListDownload returns downloads in csv files based on the given model key. -func (s *storage) ListDownload(key string) (downloads []schedulerstorage.Download, err error) { - file, err := s.OpenDownload(key) - if err != nil { - return nil, err - } - defer func() { - if cerr := file.Close(); cerr != nil { - err = errors.Join(err, cerr) - } - }() - - if err = gocsv.UnmarshalWithoutHeaders(file, &downloads); err != nil { - return nil, err - } - - return downloads, nil -} - -// ListNetworkTopology returns network topologies in csv files based on the given model key. -func (s *storage) ListNetworkTopology(key string) (networkTopologies []schedulerstorage.NetworkTopology, err error) { - file, err := s.OpenNetworkTopology(key) - if err != nil { - return nil, err - } - defer func() { - if cerr := file.Close(); cerr != nil { - err = errors.Join(err, cerr) - } - }() - - if err = gocsv.UnmarshalWithoutHeaders(file, &networkTopologies); err != nil { - return nil, err - } - - return networkTopologies, nil -} - -// OpenDownload opens download files for read based on the given model key, it returns io.ReadCloser of download files. -func (s *storage) OpenDownload(key string) (*os.File, error) { - return os.OpenFile(s.downloadFilename(key), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600) -} - -// OpenNetworkTopology opens network topology files for read based on the given model key, it returns io.ReadCloser of network topology files. -func (s *storage) OpenNetworkTopology(key string) (*os.File, error) { - return os.OpenFile(s.networkTopologyFilename(key), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600) -} - -// ClearDownload removes downloads based on the given model key. -func (s *storage) ClearDownload(key string) error { - return os.Remove(s.downloadFilename(key)) -} - -// ClearNetworkTopology removes network topologies based on the given model key. -func (s *storage) ClearNetworkTopology(key string) error { - return os.Remove(s.networkTopologyFilename(key)) -} - -// Clear removes all files. -func (s *storage) Clear() error { - return os.RemoveAll(s.baseDir) -} - -// downloadFilename generates download file name based on the given model key. -func (s *storage) downloadFilename(key string) string { - return filepath.Join(s.baseDir, fmt.Sprintf("%s_%s.%s", DownloadFilePrefix, key, CSVFileExt)) -} - -// networkTopologyFilename generates network topology file name based on the given model key. -func (s *storage) networkTopologyFilename(key string) string { - return filepath.Join(s.baseDir, fmt.Sprintf("%s_%s.%s", NetworkTopologyFilePrefix, key, CSVFileExt)) -} diff --git a/trainer/storage/storage_test.go b/trainer/storage/storage_test.go deleted file mode 100644 index e26b85a4898..00000000000 --- a/trainer/storage/storage_test.go +++ /dev/null @@ -1,554 +0,0 @@ -/* - * Copyright 2023 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package storage - -import ( - "fmt" - "io/fs" - "os" - "path/filepath" - "reflect" - "regexp" - "testing" - - "github.com/gocarina/gocsv" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - schedulerstorage "d7y.io/dragonfly/v2/scheduler/storage" -) - -var mockModelKey = "bar" - -func TestStorage_New(t *testing.T) { - tests := []struct { - name string - baseDir string - expect func(t *testing.T, s Storage) - }{ - { - name: "new storage", - baseDir: os.TempDir(), - expect: func(t *testing.T, s Storage) { - assert := assert.New(t) - assert.Equal(reflect.TypeOf(s).Elem().Name(), "storage") - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - tc.expect(t, New(tc.baseDir)) - }) - } -} - -func TestStorage_ListDownload(t *testing.T) { - require := require.New(t) - testData, err := os.ReadFile("./testdata/download.csv") - require.Nil(err, "load test file") - - tests := []struct { - name string - baseDir string - mock func(t *testing.T, s Storage, baseDir, modelKey string, download []byte) - expect func(t *testing.T, s Storage, baseDir, modelKey string, download []byte) - }{ - { - name: "empty csv file given", - baseDir: os.TempDir(), - mock: func(t *testing.T, s Storage, baseDir, modelKey string, download []byte) { - file, err := os.OpenFile(filepath.Join(baseDir, "download_bar.csv"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - t.Fatal(err) - } - defer file.Close() - }, - expect: func(t *testing.T, s Storage, baseDir, modelKey string, download []byte) { - assert := assert.New(t) - _, err := s.ListDownload(modelKey) - assert.EqualError(err, "empty csv file given") - }, - }, - { - name: "get file failed", - baseDir: os.TempDir(), - mock: func(t *testing.T, s Storage, baseDir, modelKey string, download []byte) { - file, err := os.OpenFile(filepath.Join(baseDir, "download_bar.csv"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - t.Fatal(err) - } - defer file.Close() - - if _, err = file.Write(download); err != nil { - t.Fatal(err) - } - s.(*storage).baseDir = "bas" - }, - expect: func(t *testing.T, s Storage, baseDir, modelKey string, download []byte) { - assert := assert.New(t) - _, err := s.ListDownload(modelKey) - assert.EqualError(err, "open bas/download_bar.csv: no such file or directory") - s.(*storage).baseDir = baseDir - }, - }, - { - name: "list downloads of a file", - baseDir: os.TempDir(), - mock: func(t *testing.T, s Storage, baseDir, modelKey string, download []byte) { - file, err := os.OpenFile(filepath.Join(baseDir, "download_bar.csv"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - t.Fatal(err) - } - defer file.Close() - - if _, err = file.Write(download); err != nil { - t.Fatal(err) - } - }, - expect: func(t *testing.T, s Storage, baseDir, modelKey string, download []byte) { - assert := assert.New(t) - list, err := s.ListDownload(modelKey) - assert.NoError(err) - assert.Equal(len(list), 1) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - s := New(tc.baseDir) - tc.mock(t, s, tc.baseDir, mockModelKey, testData) - tc.expect(t, s, tc.baseDir, mockModelKey, testData) - if err := s.ClearDownload(mockModelKey); err != nil { - t.Fatal(err) - } - }) - } -} - -func TestStorage_ListNetworkTopology(t *testing.T) { - require := require.New(t) - testData, err := os.ReadFile("./testdata/networktopology.csv") - require.Nil(err, "load test file") - - tests := []struct { - name string - baseDir string - mock func(t *testing.T, s Storage, baseDir, modelKey string, networkTopology []byte) - expect func(t *testing.T, s Storage, baseDir, modelKey string, networkTopology []byte) - }{ - { - name: "empty csv file given", - baseDir: os.TempDir(), - mock: func(t *testing.T, s Storage, baseDir, modelKey string, networkTopology []byte) { - file, err := os.OpenFile(filepath.Join(baseDir, "networktopology_bar.csv"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - t.Fatal(err) - } - defer file.Close() - }, - expect: func(t *testing.T, s Storage, baseDir, modelKey string, networkTopology []byte) { - assert := assert.New(t) - _, err := s.ListNetworkTopology(modelKey) - assert.EqualError(err, "empty csv file given") - }, - }, - { - name: "get file failed", - baseDir: os.TempDir(), - mock: func(t *testing.T, s Storage, baseDir, modelKey string, networkTopology []byte) { - file, err := os.OpenFile(filepath.Join(baseDir, "networktopology_bar.csv"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - t.Fatal(err) - } - defer file.Close() - - if _, err = file.Write(networkTopology); err != nil { - t.Fatal(err) - } - s.(*storage).baseDir = "foo" - }, - expect: func(t *testing.T, s Storage, baseDir, modelKey string, networkTopology []byte) { - assert := assert.New(t) - _, err := s.ListNetworkTopology(modelKey) - assert.EqualError(err, "open foo/networktopology_bar.csv: no such file or directory") - s.(*storage).baseDir = baseDir - }, - }, - { - name: "list network topologies of a file", - baseDir: os.TempDir(), - mock: func(t *testing.T, s Storage, baseDir, modelKey string, networkTopology []byte) { - file, err := os.OpenFile(filepath.Join(baseDir, "networktopology_bar.csv"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - t.Fatal(err) - } - defer file.Close() - - if _, err = file.Write(networkTopology); err != nil { - t.Fatal(err) - } - }, - expect: func(t *testing.T, s Storage, baseDir, modelKey string, networkTopology []byte) { - assert := assert.New(t) - list, err := s.ListNetworkTopology(modelKey) - assert.NoError(err) - assert.Equal(len(list), 1) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - s := New(tc.baseDir) - tc.mock(t, s, tc.baseDir, mockModelKey, testData) - tc.expect(t, s, tc.baseDir, mockModelKey, testData) - if err := s.ClearNetworkTopology(mockModelKey); err != nil { - t.Fatal(err) - } - }) - } -} - -func TestStorage_OpenDownload(t *testing.T) { - require := require.New(t) - testData, err := os.ReadFile("./testdata/download.csv") - require.Nil(err, "load test file") - - tests := []struct { - name string - baseDir string - mock func(t *testing.T, s Storage, baseDir, modelKey string, download []byte) - expect func(t *testing.T, s Storage, baseDir, modelKey string, download []byte) - }{ - { - name: "open file failed", - baseDir: os.TempDir(), - mock: func(t *testing.T, s Storage, baseDir, modelKey string, download []byte) { - file, err := os.OpenFile(filepath.Join(baseDir, "download_bar.csv"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - t.Fatal(err) - } - defer file.Close() - s.(*storage).baseDir = "baw" - }, - expect: func(t *testing.T, s Storage, baseDir, modelKey string, download []byte) { - assert := assert.New(t) - _, err := s.OpenDownload(modelKey) - assert.EqualError(err, "open baw/download_bar.csv: no such file or directory") - s.(*storage).baseDir = baseDir - }, - }, - { - name: "open storage with downloads of a file", - baseDir: os.TempDir(), - mock: func(t *testing.T, s Storage, baseDir, modelKey string, download []byte) { - file, err := os.OpenFile(filepath.Join(baseDir, "download_bar.csv"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - t.Fatal(err) - } - defer file.Close() - - if _, err = file.Write(download); err != nil { - t.Fatal(err) - } - }, - expect: func(t *testing.T, s Storage, baseDir, modelKey string, download []byte) { - assert := assert.New(t) - readCloser, err := s.OpenDownload(modelKey) - assert.NoError(err) - - var downloads []schedulerstorage.Download - assert.NoError(gocsv.UnmarshalWithoutHeaders(readCloser, &downloads)) - assert.Equal(len(downloads), 1) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - s := New(tc.baseDir) - tc.mock(t, s, tc.baseDir, mockModelKey, testData) - tc.expect(t, s, tc.baseDir, mockModelKey, testData) - if err := s.ClearDownload(mockModelKey); err != nil { - t.Fatal(err) - } - }) - } -} - -func TestStorage_OpenNetworkTopology(t *testing.T) { - require := require.New(t) - testData, err := os.ReadFile("./testdata/networktopology.csv") - require.Nil(err, "load test file") - - tests := []struct { - name string - baseDir string - mock func(t *testing.T, s Storage, baseDir, modelKey string, networkTopology []byte) - expect func(t *testing.T, s Storage, baseDir, modelKey string, networkTopology []byte) - }{ - { - name: "open file failed", - baseDir: os.TempDir(), - mock: func(t *testing.T, s Storage, baseDir, modelKey string, networkTopology []byte) { - file, err := os.OpenFile(filepath.Join(baseDir, "networktopology_bar.csv"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - t.Fatal(err) - } - defer file.Close() - s.(*storage).baseDir = "bas" - }, - expect: func(t *testing.T, s Storage, baseDir, modelKey string, networkTopology []byte) { - assert := assert.New(t) - _, err := s.OpenNetworkTopology(modelKey) - assert.EqualError(err, "open bas/networktopology_bar.csv: no such file or directory") - s.(*storage).baseDir = baseDir - }, - }, - { - name: "open storage with network topologies of a file", - baseDir: os.TempDir(), - mock: func(t *testing.T, s Storage, baseDir, modelKey string, networkTopology []byte) { - file, err := os.OpenFile(filepath.Join(baseDir, "networktopology_bar.csv"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - t.Fatal(err) - } - defer file.Close() - - if _, err = file.Write(networkTopology); err != nil { - t.Fatal(err) - } - }, - expect: func(t *testing.T, s Storage, baseDir, modelKey string, networkTopology []byte) { - assert := assert.New(t) - readCloser, err := s.OpenNetworkTopology(modelKey) - assert.NoError(err) - - var networkTopologies []schedulerstorage.NetworkTopology - assert.NoError(gocsv.UnmarshalWithoutHeaders(readCloser, &networkTopologies)) - assert.Equal(len(networkTopologies), 1) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - s := New(tc.baseDir) - tc.mock(t, s, tc.baseDir, mockModelKey, testData) - tc.expect(t, s, tc.baseDir, mockModelKey, testData) - if err := s.ClearNetworkTopology(mockModelKey); err != nil { - t.Fatal(err) - } - }) - } -} - -func TestStorage_ClearDownload(t *testing.T) { - tests := []struct { - name string - baseDir string - mock func(t *testing.T, s Storage, baseDir, modelKey string) - expect func(t *testing.T, s Storage, baseDir, modelKey string) - }{ - { - name: "clear file", - baseDir: os.TempDir(), - mock: func(t *testing.T, s Storage, baseDir, modelKey string) { - file, err := os.OpenFile(filepath.Join(baseDir, "download_bar.csv"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - t.Fatal(err) - } - defer file.Close() - }, - expect: func(t *testing.T, s Storage, baseDir, modelKey string) { - assert := assert.New(t) - assert.NoError(s.ClearDownload(modelKey)) - fileInfos, err := os.ReadDir(filepath.Join(baseDir)) - assert.NoError(err) - - var backups []fs.FileInfo - re := regexp.MustCompile(fmt.Sprintf("%s_%s", DownloadFilePrefix, modelKey)) - - for _, fileInfo := range fileInfos { - if !fileInfo.IsDir() && re.MatchString(fileInfo.Name()) { - info, _ := fileInfo.Info() - backups = append(backups, info) - } - } - assert.Equal(len(backups), 0) - }, - }, - { - name: "open file failed", - baseDir: os.TempDir(), - mock: func(t *testing.T, s Storage, baseDir, modelKey string) { - file, err := os.OpenFile(filepath.Join(baseDir, "download_bar.csv"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - t.Fatal(err) - } - defer file.Close() - s.(*storage).baseDir = "baz" - }, - expect: func(t *testing.T, s Storage, baseDir, modelKey string) { - assert := assert.New(t) - assert.EqualError(s.ClearDownload(modelKey), "remove baz/download_bar.csv: no such file or directory") - - s.(*storage).baseDir = baseDir - assert.NoError(s.ClearDownload(modelKey)) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - s := New(tc.baseDir) - tc.mock(t, s, tc.baseDir, mockModelKey) - tc.expect(t, s, tc.baseDir, mockModelKey) - }) - } -} - -func TestStorage_ClearNetworkTopology(t *testing.T) { - tests := []struct { - name string - baseDir string - mock func(t *testing.T, s Storage, baseDir, modelKey string) - expect func(t *testing.T, s Storage, baseDir, modelKey string) - }{ - { - name: "clear file", - baseDir: os.TempDir(), - mock: func(t *testing.T, s Storage, baseDir, modelKey string) { - file, err := os.OpenFile(filepath.Join(baseDir, "networktopology_bar.csv"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - t.Fatal(err) - } - defer file.Close() - }, - expect: func(t *testing.T, s Storage, baseDir, modelKey string) { - assert := assert.New(t) - assert.NoError(s.ClearNetworkTopology(modelKey)) - fileInfos, err := os.ReadDir(filepath.Join(baseDir)) - assert.NoError(err) - - var backups []fs.FileInfo - re := regexp.MustCompile(fmt.Sprintf("%s_%s", NetworkTopologyFilePrefix, modelKey)) - for _, fileInfo := range fileInfos { - if !fileInfo.IsDir() && re.MatchString(fileInfo.Name()) { - info, _ := fileInfo.Info() - backups = append(backups, info) - } - } - assert.Equal(len(backups), 0) - }, - }, - { - name: "open file failed", - baseDir: os.TempDir(), - mock: func(t *testing.T, s Storage, baseDir, modelKey string) { - file, err := os.OpenFile(filepath.Join(baseDir, "networktopology_bar.csv"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - t.Fatal(err) - } - defer file.Close() - s.(*storage).baseDir = "baz" - }, - expect: func(t *testing.T, s Storage, baseDir, modelKey string) { - assert := assert.New(t) - assert.EqualError(s.ClearNetworkTopology(modelKey), "remove baz/networktopology_bar.csv: no such file or directory") - s.(*storage).baseDir = baseDir - assert.NoError(s.ClearNetworkTopology(modelKey)) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - s := New(tc.baseDir) - tc.mock(t, s, tc.baseDir, mockModelKey) - tc.expect(t, s, tc.baseDir, mockModelKey) - }) - } -} - -func TestStorage_Clear(t *testing.T) { - tests := []struct { - name string - baseDir string - mock func(t *testing.T, s Storage, baseDir string) - expect func(t *testing.T, s Storage, baseDir string) - }{ - { - name: "clear file", - baseDir: os.TempDir(), - mock: func(t *testing.T, s Storage, baseDir string) { - s.(*storage).baseDir = filepath.Join(baseDir, "bae") - if err := os.MkdirAll(s.(*storage).baseDir, fs.FileMode(0700)); err != nil { - t.Fatal(err) - } - - downloadFile, err := os.OpenFile(filepath.Join(s.(*storage).baseDir, "download_bar.csv"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - t.Fatal(err) - } - defer downloadFile.Close() - - networkTopologyFile, err := os.OpenFile(filepath.Join(s.(*storage).baseDir, "networktopology_bar.csv"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - t.Fatal(err) - } - defer networkTopologyFile.Close() - }, - expect: func(t *testing.T, s Storage, baseDir string) { - assert := assert.New(t) - assert.NoError(s.Clear()) - _, err := os.Stat(filepath.Join(baseDir, "bae")) - assert.EqualError(err, fmt.Sprintf("stat %s: no such file or directory", filepath.Join(baseDir, "bae"))) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - s := New(tc.baseDir) - tc.mock(t, s, tc.baseDir) - tc.expect(t, s, tc.baseDir) - }) - } -} - -func TestStorage_downloadFilename(t *testing.T) { - baseDir := os.TempDir() - s := New(baseDir) - - filename := s.(*storage).downloadFilename(mockModelKey) - re := regexp.MustCompile(fmt.Sprintf("%s_%s.%s$", DownloadFilePrefix, mockModelKey, CSVFileExt)) - assert := assert.New(t) - assert.True(re.MatchString(filename)) -} - -func TestStorage_networkTopologyFilename(t *testing.T) { - baseDir := os.TempDir() - s := New(baseDir) - - filename := s.(*storage).networkTopologyFilename(mockModelKey) - re := regexp.MustCompile(fmt.Sprintf("%s_%s.%s$", NetworkTopologyFilePrefix, mockModelKey, CSVFileExt)) - assert := assert.New(t) - assert.True(re.MatchString(filename)) -} diff --git a/trainer/storage/testdata/download.csv b/trainer/storage/testdata/download.csv deleted file mode 100644 index bbc6d023ff7..00000000000 --- a/trainer/storage/testdata/download.csv +++ /dev/null @@ -1 +0,0 @@ -,,,,,,0,0,,,,0,0,0,0,,0,0,,,,,0,0,,,,,,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,,,0,0,0,0,0,0,0,0,,,,,0,0,0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,0,0 diff --git a/trainer/storage/testdata/networktopology.csv b/trainer/storage/testdata/networktopology.csv deleted file mode 100644 index 4fd3b9eab25..00000000000 --- a/trainer/storage/testdata/networktopology.csv +++ /dev/null @@ -1 +0,0 @@ -6,3,super,foo,127.0.0.1,8080,400,200,china,e1,2,normal,localhost,127.0.0.1,8080,400,200,china,e1,10,1686036525367538000,1686036525367538000,2,normal,localhost,127.0.0.1,8080,400,200,china,e1,10,1686036525367538000,1686036525367538000,2,normal,localhost,127.0.0.1,8080,400,200,china,e1,10,1686036525367538000,1686036525367538000,2,normal,localhost,127.0.0.1,8080,400,200,china,e1,10,1686036525367538000,1686036525367538000,2,normal,localhost,127.0.0.1,8080,400,200,china,e1,10,1686036525367538000,1686036525367538000,1686036525367538000 diff --git a/trainer/trainer.go b/trainer/trainer.go deleted file mode 100644 index a14ab972f44..00000000000 --- a/trainer/trainer.go +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Copyright 2023 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package trainer - -import ( - "context" - "errors" - "fmt" - "net" - "net/http" - "time" - - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" - - logger "d7y.io/dragonfly/v2/internal/dflog" - "d7y.io/dragonfly/v2/pkg/dfpath" - "d7y.io/dragonfly/v2/pkg/net/ip" - "d7y.io/dragonfly/v2/pkg/rpc" - managerclient "d7y.io/dragonfly/v2/pkg/rpc/manager/client" - "d7y.io/dragonfly/v2/trainer/config" - "d7y.io/dragonfly/v2/trainer/metrics" - "d7y.io/dragonfly/v2/trainer/rpcserver" - "d7y.io/dragonfly/v2/trainer/storage" - "d7y.io/dragonfly/v2/trainer/training" -) - -const ( - // gracefulStopTimeout specifies a time limit for - // grpc server to complete a graceful shutdown. - gracefulStopTimeout = 10 * time.Minute -) - -// Server is the trainer server. -type Server struct { - // Server configuration. - config *config.Config - - // GRPC server. - grpcServer *grpc.Server - - // Metrics server. - metricsServer *http.Server - - // Manager client. - managerClient managerclient.V2 - - // Storage interface. - storage storage.Storage -} - -// New creates a new Server. -func New(ctx context.Context, cfg *config.Config, d dfpath.Dfpath) (*Server, error) { - s := &Server{config: cfg} - - // Initialize dial options of manager grpc client. - managerDialOptions := []grpc.DialOption{} - if cfg.Security.AutoIssueCert { - clientTransportCredentials, err := rpc.NewClientCredentials(cfg.Security.TLSPolicy, nil, []byte(cfg.Security.CACert)) - if err != nil { - return nil, err - } - - managerDialOptions = append(managerDialOptions, grpc.WithTransportCredentials(clientTransportCredentials)) - } else { - managerDialOptions = append(managerDialOptions, grpc.WithTransportCredentials(insecure.NewCredentials())) - } - - // Initialize manager client. - managerClient, err := managerclient.GetV2ByAddr(ctx, cfg.Manager.Addr, managerDialOptions...) - if err != nil { - return nil, err - } - s.managerClient = managerClient - - // Initialize Storage. - s.storage = storage.New(d.DataDir()) - - // Initialize Training. - training := training.New(cfg, s.managerClient, s.storage) - - // Initialize trainer grpc server. - s.grpcServer = rpcserver.New(cfg, s.storage, training) - - // Initialize metrics. - if cfg.Metrics.Enable { - s.metricsServer = metrics.New(&cfg.Metrics, s.grpcServer) - } - - return s, nil -} - -// Serve starts the trainer server. -func (s *Server) Serve() error { - // Started metrics server. - if s.metricsServer != nil { - go func() { - logger.Infof("started metrics server at %s", s.metricsServer.Addr) - if err := s.metricsServer.ListenAndServe(); err != nil { - if err == http.ErrServerClosed { - return - } - - logger.Fatalf("metrics server closed unexpect: %s", err.Error()) - } - }() - } - - // Generate GRPC limit listener. - ip, ok := ip.FormatIP(s.config.Server.ListenIP.String()) - if !ok { - return errors.New("format ip failed") - } - - listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", ip, s.config.Server.Port)) - if err != nil { - logger.Fatalf("net listener failed to start: %s", err.Error()) - } - defer listener.Close() - - // Started GRPC server. - logger.Infof("started grpc server at %s://%s", listener.Addr().Network(), listener.Addr().String()) - if err := s.grpcServer.Serve(listener); err != nil { - logger.Errorf("stoped grpc server: %s", err.Error()) - return err - } - - return nil -} - -// Stop stops the trainer server. -func (s *Server) Stop() { - // Stop manager client. - if s.managerClient != nil { - if err := s.managerClient.Close(); err != nil { - logger.Errorf("manager client failed to stop: %s", err.Error()) - } else { - logger.Info("manager client closed") - } - } - - // Clean storage file. - if err := s.storage.Clear(); err != nil { - logger.Errorf("clean storage file failed %s", err.Error()) - } else { - logger.Info("clean storage file completed") - } - - // Stop metrics server. - if s.metricsServer != nil { - if err := s.metricsServer.Shutdown(context.Background()); err != nil { - logger.Errorf("metrics server failed to stop: %s", err.Error()) - } else { - logger.Info("metrics server closed under request") - } - } - - // Stop GRPC server. - stopped := make(chan struct{}) - go func() { - s.grpcServer.GracefulStop() - logger.Info("grpc server closed under request") - close(stopped) - }() - - t := time.NewTimer(gracefulStopTimeout) - select { - case <-t.C: - s.grpcServer.Stop() - case <-stopped: - t.Stop() - } -} diff --git a/trainer/training/mocks/training_mock.go b/trainer/training/mocks/training_mock.go deleted file mode 100644 index a4ac62291d2..00000000000 --- a/trainer/training/mocks/training_mock.go +++ /dev/null @@ -1,54 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: training.go -// -// Generated by this command: -// -// mockgen -destination mocks/training_mock.go -source training.go -package mocks -// - -// Package mocks is a generated GoMock package. -package mocks - -import ( - context "context" - reflect "reflect" - - gomock "go.uber.org/mock/gomock" -) - -// MockTraining is a mock of Training interface. -type MockTraining struct { - ctrl *gomock.Controller - recorder *MockTrainingMockRecorder -} - -// MockTrainingMockRecorder is the mock recorder for MockTraining. -type MockTrainingMockRecorder struct { - mock *MockTraining -} - -// NewMockTraining creates a new mock instance. -func NewMockTraining(ctrl *gomock.Controller) *MockTraining { - mock := &MockTraining{ctrl: ctrl} - mock.recorder = &MockTrainingMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockTraining) EXPECT() *MockTrainingMockRecorder { - return m.recorder -} - -// Train mocks base method. -func (m *MockTraining) Train(arg0 context.Context, arg1, arg2 string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Train", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// Train indicates an expected call of Train. -func (mr *MockTrainingMockRecorder) Train(arg0, arg1, arg2 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Train", reflect.TypeOf((*MockTraining)(nil).Train), arg0, arg1, arg2) -} diff --git a/trainer/training/training.go b/trainer/training/training.go deleted file mode 100644 index 181c50a14db..00000000000 --- a/trainer/training/training.go +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Copyright 2023 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package training - -import ( - "context" - - "golang.org/x/sync/errgroup" - - logger "d7y.io/dragonfly/v2/internal/dflog" - managerclient "d7y.io/dragonfly/v2/pkg/rpc/manager/client" - "d7y.io/dragonfly/v2/trainer/config" - "d7y.io/dragonfly/v2/trainer/storage" -) - -//go:generate mockgen -destination mocks/training_mock.go -source training.go -package mocks - -// Training defines the interface to train GNN and MLP model. -type Training interface { - // Train begins training GNN and MLP model. - Train(context.Context, string, string) error -} - -// training implements Training interface. -type training struct { - // Trainer service config. - config *config.Config - - // Storage interface. - storage storage.Storage - - // Manager service client. - managerClient managerclient.V2 -} - -// New returns a new Training. -func New(cfg *config.Config, managerClient managerclient.V2, storage storage.Storage) Training { - return &training{ - config: cfg, - storage: storage, - managerClient: managerClient, - } -} - -// Train begins training GNN and MLP model. -func (t *training) Train(ctx context.Context, ip, hostname string) error { - eg, ctx := errgroup.WithContext(ctx) - eg.Go(func() error { - return t.trainGNN(ctx, ip, hostname) - }) - - eg.Go(func() error { - return t.trainMLP(ctx, ip, hostname) - }) - - // Wait for all train tasks to complete. - if err := eg.Wait(); err != nil { - logger.Errorf("training failed: %v", err) - return err - } - - // TODO Clean up training data. - return nil -} - -// TODO Add training GNN logic. -// trainGNN trains GNN model. -func (t *training) trainGNN(ctx context.Context, ip, hostname string) error { - // 1. Get training data from storage. - // 2. Preprocess training data. - // 2. Train GNN model. - // 3. Upload GNN model to manager service. - return nil -} - -// TODO Add training MLP logic. -// trainMLP trains MLP model. -func (t *training) trainMLP(ctx context.Context, ip, hostname string) error { - // 1. Get training data from storage. - // 2. Preprocess training data. - // 2. Train MLP model. - // 3. Upload MLP model to manager service. - return nil -}