From 6b27ffed7cbcfb0dba129ce00c19d49c65f127f6 Mon Sep 17 00:00:00 2001 From: Chris Martin Date: Wed, 25 Jan 2023 10:35:32 +0000 Subject: [PATCH] Initialise New Scheduler (#2050) * Make legacyScheduler generic * linting * linting * wip * wip * wip * wip * wip * remove nats from e2e * remove more nats config * fixed submit test * linting * go mod tidy * increase sleep to 60 * wip * wip * linting * unused context * remove extra queued-jobs-iterator * wip * wip * linting passes * documentation * unit tests * doc * more tests * linting * more doc, revert files modified in error * more linting * more linting * fix tests * merge master * doc improvements * wip * Update internal/scheduler/scheduler.go Co-authored-by: Albin Severinson * fixes following review * wip * renamed publisher_test * wip * linting * fixes after merge * Update internal/scheduler/scheduling_algo.go Co-authored-by: Clif Houck * code review comments * mocks * wip * wip * wip * doc for job repo * doc for job repo * addressed some warnings * restore nodes table * filter out queues with no jobs to schedule * wip * wip * remived nats from main makefiles * removed extra config * renamed file * code review comments * linting * linting * more linting * more linting * another nats reference * another nats reference * linting * fix tests * fix tests * fix tests * job repo test done * wip * merged master * merged master * fix changes following merge * restore package * fixed proto package names * linting * linting * fix null pointer in test * linting * linting * doc * unit tests * moved mock generation into its owen package * moved mock generation into its owen package * wip * go lint * linting * fixed package names * wip * linting * linting * linting * move master check into a separate function * don't modify restapi * doc * Update internal/armada/configuration/types.go Co-authored-by: Albin Severinson * Update internal/scheduler/database/job_repository.go Co-authored-by: Albin Severinson * Update 
internal/scheduler/reports.go Co-authored-by: Albin Severinson * wip * compilation fixes * updated proto * add job ids * wip * tests for api * linting * add grpc mock * fix tests * fix tests * fix tests * formatting * formatting * doc * more doc * wip * remove active_job_ids from streaming lease call * wip * nodes should be pointer not struct * wip * added job run lookup * wip * linting * linting * linting * fix tests * fix tests * executor tests * linting * linting * doc * fix tests * wip * wip * wip * wip * linting * import order * custom pulsar marshallers * config * more changes * commands * linting * added tests for hooks * lots of compilation fixes * go mod tidy * fix failing test * test for queue_repository * remove uneeded file * linting * flush db always * flush db always * review comments * linting Co-authored-by: Chris Martin Co-authored-by: Albin Severinson Co-authored-by: Clif Houck --- cmd/scheduler/cmd/main.go | 24 +++ cmd/scheduler/cmd/migrate_database.go | 51 ++++++ cmd/scheduler/cmd/root.go | 48 ++++++ cmd/scheduler/main.go | 59 +------ config/scheduler/config.yaml | 103 ++++++++++-- e2e/pulsartest_client/app_test.go | 20 --- go.mod | 4 + go.sum | 11 ++ internal/armada/configuration/types.go | 7 +- .../repository/apimessages/conversions.go | 1 - internal/armada/repository/event_store.go | 6 +- internal/armada/server.go | 16 +- internal/armada/server/submit_to_log.go | 2 +- internal/common/config/hooks.go | 98 ++++++++++++ internal/common/config/hooks_test.go | 125 +++++++++++++++ internal/common/config/redis.go | 48 ++++++ internal/common/config/validation.go | 34 ++++ internal/common/database/functions.go | 9 ++ internal/common/eventutil/eventutil.go | 10 +- internal/common/grpc/configuration/types.go | 1 + internal/common/pulsarutils/eventsequence.go | 2 +- internal/common/pulsarutils/pulsarclient.go | 37 ----- .../common/pulsarutils/pulsarclient_test.go | 52 ------- internal/common/startup.go | 27 +--- 
internal/eventingester/convert/conversions.go | 4 +- internal/eventingester/ingester.go | 2 +- internal/pulsartest/app.go | 15 +- internal/scheduler/api.go | 9 +- internal/scheduler/api_test.go | 6 +- internal/scheduler/config.go | 31 ++-- .../scheduler/database/executor_repository.go | 6 +- .../database/executor_repository_test.go | 6 +- internal/scheduler/database/job_repository.go | 4 +- .../scheduler/database/job_repository_test.go | 2 +- .../scheduler/database/queue_repository.go | 33 ++++ .../database/queue_repository_test.go | 67 ++++++++ internal/scheduler/leader.go | 11 ++ internal/scheduler/leader_test.go | 4 +- internal/scheduler/mocks/mock_repositories.go | 2 +- internal/scheduler/publisher.go | 5 +- internal/scheduler/publisher_test.go | 36 ++--- internal/scheduler/scheduler.go | 7 +- internal/scheduler/scheduler_test.go | 4 +- internal/scheduler/schedulerapp.go | 146 ++++++++++++++++++ internal/scheduler/scheduling_algo.go | 2 +- 45 files changed, 884 insertions(+), 313 deletions(-) create mode 100644 cmd/scheduler/cmd/main.go create mode 100644 cmd/scheduler/cmd/migrate_database.go create mode 100644 cmd/scheduler/cmd/root.go create mode 100644 internal/common/config/hooks.go create mode 100644 internal/common/config/hooks_test.go create mode 100644 internal/common/config/redis.go create mode 100644 internal/common/config/validation.go create mode 100644 internal/scheduler/database/queue_repository_test.go create mode 100644 internal/scheduler/schedulerapp.go diff --git a/cmd/scheduler/cmd/main.go b/cmd/scheduler/cmd/main.go new file mode 100644 index 00000000000..f6d96027969 --- /dev/null +++ b/cmd/scheduler/cmd/main.go @@ -0,0 +1,24 @@ +package cmd + +import ( + "github.com/spf13/cobra" + + "github.com/armadaproject/armada/internal/scheduler" +) + +func runCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "run", + Short: "Runs the scheduler", + RunE: runScheduler, + } + return cmd +} + +func runScheduler(_ *cobra.Command, _ []string) error { + 
config, err := loadConfig() + if err != nil { + return err + } + return scheduler.Run(config) +} diff --git a/cmd/scheduler/cmd/migrate_database.go b/cmd/scheduler/cmd/migrate_database.go new file mode 100644 index 00000000000..bed94b3793c --- /dev/null +++ b/cmd/scheduler/cmd/migrate_database.go @@ -0,0 +1,51 @@ +package cmd + +import ( + "context" + "time" + + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "github.com/spf13/viper" + + "github.com/armadaproject/armada/internal/common/database" + schedulerdb "github.com/armadaproject/armada/internal/scheduler/database" +) + +func migrateDbCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "migrateDatabase", + Short: "migrates the scheduler database to the latest version", + RunE: migrateDatabase, + } + cmd.PersistentFlags().Duration( + "timeout", + 5*time.Minute, + "Duration after which the migration will fail if it has not been created") + + return cmd +} + +func migrateDatabase(_ *cobra.Command, _ []string) error { + timeout := viper.GetDuration("timeout") + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + config, err := loadConfig() + if err != nil { + return err + } + start := time.Now() + log.Info("Beginning scheduler database migration") + db, err := database.OpenPgxConn(config.Postgres) + if err != nil { + return errors.WithMessagef(err, "Failed to connect to database") + } + err = schedulerdb.Migrate(ctx, db) + if err != nil { + return errors.WithMessagef(err, "Failed to migrate scheduler database") + } + taken := time.Since(start) + log.Infof("Scheduler database migrated in %s", taken) + return nil +} diff --git a/cmd/scheduler/cmd/root.go b/cmd/scheduler/cmd/root.go new file mode 100644 index 00000000000..77ebe2e2ec7 --- /dev/null +++ b/cmd/scheduler/cmd/root.go @@ -0,0 +1,48 @@ +package cmd + +import ( + "github.com/spf13/cobra" + "github.com/spf13/viper" + + "github.com/armadaproject/armada/internal/common" + commonconfig 
"github.com/armadaproject/armada/internal/common/config" + "github.com/armadaproject/armada/internal/scheduler" +) + +const ( + CustomConfigLocation string = "config" +) + +func RootCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "scheduler", + SilenceUsage: true, + Short: "The main armada scheduler", + } + + cmd.PersistentFlags().StringSlice( + "armadaUrl", + []string{}, + "Fully qualified path to application configuration file (for multiple config files repeat this arg or separate paths with commas)") + + cmd.AddCommand( + runCmd(), + migrateDbCmd(), + ) + + return cmd +} + +func loadConfig() (scheduler.Configuration, error) { + var config scheduler.Configuration + userSpecifiedConfigs := viper.GetStringSlice(CustomConfigLocation) + + common.LoadConfig(&config, "./config/scheduler", userSpecifiedConfigs) + + // TODO: once we're happy with this we can move it to common app startup + err := commonconfig.Validate(config) + if err != nil { + commonconfig.LogValidationErrors(err) + } + return config, err +} diff --git a/cmd/scheduler/main.go b/cmd/scheduler/main.go index f6304e3e28c..4c40c10fc14 100644 --- a/cmd/scheduler/main.go +++ b/cmd/scheduler/main.go @@ -1,67 +1,12 @@ package main import ( - "context" - "os" - "time" - - "github.com/pkg/errors" - log "github.com/sirupsen/logrus" - "github.com/spf13/pflag" - "github.com/spf13/viper" - + "github.com/armadaproject/armada/cmd/scheduler/cmd" "github.com/armadaproject/armada/internal/common" - "github.com/armadaproject/armada/internal/common/database" - "github.com/armadaproject/armada/internal/scheduler" - schedulerdb "github.com/armadaproject/armada/internal/scheduler/database" -) - -const ( - CustomConfigLocation string = "config" - MigrateDatabase string = "migrateDatabase" ) -func init() { - pflag.StringSlice( - CustomConfigLocation, - []string{}, - "Fully qualified path to application configuration file (for multiple config files repeat this arg or separate paths with commas)", - ) - 
pflag.Bool(MigrateDatabase, false, "Migrate database instead of running scheduler") - pflag.Parse() -} - func main() { common.ConfigureLogging() common.BindCommandlineArguments() - - var config scheduler.Configuration - userSpecifiedConfigs := viper.GetStringSlice(CustomConfigLocation) - - common.LoadConfig(&config, "./config/scheduler", userSpecifiedConfigs) - - if viper.GetBool(MigrateDatabase) { - migrateDatabase(&config) - } else { - if err := scheduler.Run(&config); err != nil { - log.Errorf("failed to run scheduler: %s", err) - os.Exit(1) - } - } -} - -func migrateDatabase(config *scheduler.Configuration) { - start := time.Now() - log.Info("Beginning scheduler database migration") - db, err := database.OpenPgxPool(config.Postgres) - if err != nil { - panic(errors.WithMessage(err, "Failed to connect to database")) - } - err = schedulerdb.Migrate(context.Background(), db) - if err != nil { - panic(errors.WithMessage(err, "Failed to migrate scheduler database")) - } - taken := time.Now().Sub(start) - log.Infof("Scheduler database migrated in %dms", taken.Milliseconds()) - os.Exit(0) + _ = cmd.RootCmd().Execute() } diff --git a/config/scheduler/config.yaml b/config/scheduler/config.yaml index 7fe5f8b0738..0ec155db7ce 100644 --- a/config/scheduler/config.yaml +++ b/config/scheduler/config.yaml @@ -1,25 +1,94 @@ +cyclePeriod: 10s +executorTimeout: 1h +databaseFetchSize: 1000 +pulsarSendTimeout: 5s +pulsar: + URL: "pulsar://pulsar:6650" + jobsetEventsTopic: "events" + maxConnectionsPerBroker: 1 + compressionType: zlib + compressionLevel: faster postgres: - maxOpenConns: 20 - maxIdleConns: 5 - connMaxLifetime: 30m connection: - host: localhost + host: postgres port: 5432 user: postgres password: psw dbname: postgres sslmode: disable - -metrics: - port: 9003 - -pulsar: - URL: "pulsar://localhost:6650" - jobsetEventsTopic: "events" - receiveTimeout: 5s - backoffTime: 1s - -subscriptionName: "scheduler" -batchSize: 10000 -batchDuration: 500ms +leader: + mode: standalone 
+grpc: + #port: 50052 + keepaliveParams: + maxConnectionIdle: 5m + time: 120s + timeout: 20s + keepaliveEnforcementPolicy: + minTime: 10s + permitWithoutStream: true +scheduling: + preemption: + enabled: true + priorityClasses: + armada-default: + priority: 1000 + maximalResourceFractionPerQueue: + memory: 0.99 + cpu: 0.99 + armada-preemptible: + priority: 900 + maximalResourceFractionPerQueue: + memory: 0.99 + cpu: 0.99 + defaultPriorityClass: armada-default + queueLeaseBatchSize: 1000 + minimumResourceToSchedule: + memory: 1000000 # 1Mb + cpu: 0.1 + maximalResourceFractionToSchedulePerQueue: + memory: 1.0 + cpu: 1.0 + maximalResourceFractionPerQueue: + memory: 1.0 + cpu: 1.0 + maximalClusterFractionToSchedule: + memory: 1.0 + cpu: 1.0 + maximumJobsToSchedule: 5000 + maxQueueReportsToStore: 1000 + MaxJobReportsToStore: 10000 + defaultJobLimits: + cpu: 1 + memory: 1Gi + ephemeral-storage: 8Gi + defaultJobTolerations: + - key: "armadaproject.io/armada" + operator: "Equal" + value: "true" + effect: "NoSchedule" + defaultJobTolerationsByPriorityClass: + "": + - key: "armadaproject.io/pc-armada-default" + operator: "Equal" + value: "true" + effect: "NoSchedule" + armada-default: + - key: "armadaproject.io/pc-armada-default" + operator: "Equal" + value: "true" + effect: "NoSchedule" + armada-preemptible: + - key: "armadaproject.io/pc-armada-preemptible" + operator: "Equal" + value: "true" + effect: "NoSchedule" + maxRetries: 5 + resourceScarcity: + cpu: 1.0 + indexedResources: + - cpu + - memory + gangIdAnnotation: armadaproject.io/gangId + gangCardinalityAnnotation: armadaproject.io/gangCardinality diff --git a/e2e/pulsartest_client/app_test.go b/e2e/pulsartest_client/app_test.go index 3ae55232f85..6ef204ada48 100644 --- a/e2e/pulsartest_client/app_test.go +++ b/e2e/pulsartest_client/app_test.go @@ -33,26 +33,6 @@ func TestNew(t *testing.T) { assert.Error(t, err) assert.Nil(t, app) - // Invalid compression type - pc = cfg.PulsarConfig{ - URL: "pulsar://localhost:6650", 
- CompressionType: "nocompression", - JobsetEventsTopic: "events", - } - app, err = pt.New(pt.Params{Pulsar: pc}, "submit") - assert.Error(t, err) - assert.Nil(t, app) - - // Invalid compression level - pc = cfg.PulsarConfig{ - URL: "pulsar://localhost:6650", - CompressionLevel: "veryCompressed", - JobsetEventsTopic: "events", - } - app, err = pt.New(pt.Params{Pulsar: pc}, "submit") - assert.Error(t, err) - assert.Nil(t, app) - // Invalid command type pc = cfg.PulsarConfig{ URL: "pulsar://localhost:6650", diff --git a/go.mod b/go.mod index 03f176c637f..60d49b19529 100644 --- a/go.mod +++ b/go.mod @@ -87,6 +87,7 @@ require ( github.com/go-openapi/strfmt v0.21.3 github.com/go-openapi/swag v0.22.3 github.com/go-openapi/validate v0.21.0 + github.com/go-playground/validator/v10 v10.11.1 github.com/golang/mock v1.6.0 github.com/goreleaser/goreleaser v1.11.5 github.com/jessevdk/go-flags v1.5.0 @@ -199,6 +200,8 @@ require ( github.com/go-logr/logr v0.4.0 // indirect github.com/go-openapi/inflect v0.19.0 // indirect github.com/go-openapi/jsonpointer v0.19.5 // indirect + github.com/go-playground/locales v0.14.0 // indirect + github.com/go-playground/universal-translator v0.18.0 // indirect github.com/go-telegram-bot-api/telegram-bot-api v4.6.4+incompatible // indirect github.com/gobwas/glob v0.2.3 // indirect github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 // indirect @@ -256,6 +259,7 @@ require ( github.com/klauspost/pgzip v1.2.5 // indirect github.com/kr/pretty v0.3.0 // indirect github.com/kr/text v0.2.0 // indirect + github.com/leodido/go-urn v1.2.1 // indirect github.com/linkedin/goavro/v2 v2.9.8 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/magiconair/properties v1.8.7 // indirect diff --git a/go.sum b/go.sum index 078a46a17ef..0c97d6e3936 100644 --- a/go.sum +++ b/go.sum @@ -487,11 +487,18 @@ github.com/go-openapi/swag v0.22.3 h1:yMBqmnQ0gyZvEb/+KzuWZOXgllrXT4SADYbvDaXHv/ github.com/go-openapi/swag v0.22.3/go.mod 
h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= github.com/go-openapi/validate v0.21.0 h1:+Wqk39yKOhfpLqNLEC0/eViCkzM5FVXVqrvt526+wcI= github.com/go-openapi/validate v0.21.0/go.mod h1:rjnrwK57VJ7A8xqfpAOEKRH8yQSGUriMu5/zuPSQ1hg= +github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= +github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU= +github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= +github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/jYrnRPArHwAcmLoJZxyho= +github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA= github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= +github.com/go-playground/validator/v10 v10.11.1 h1:prmOlTVv+YjZjmRmNSF3VmspqJIxJWXmqUsHwfTRRkQ= +github.com/go-playground/validator/v10 v10.11.1/go.mod h1:i+3WkQ1FvaUjjxh1kSvIA4dMGDBiPU55YFDl0WbKdWU= github.com/go-redis/redis v6.15.9+incompatible h1:K0pv1D7EQUjfyoMql+r/jZqCLizCGKFlFgcHWWmHQjg= github.com/go-redis/redis v6.15.9+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= @@ -918,6 +925,8 @@ github.com/kyleconroy/sqlc v1.16.0/go.mod h1:m+cX/UyBRnKP58lFfUsq+0gw87UUw9Amxwq github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348/go.mod h1:B69LEHPfb2qLo0BaaOLcbitczOKLWTsrBG9LczfCD4k= github.com/kylelemons/godebug v1.1.0/go.mod 
h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= +github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= +github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= @@ -1141,6 +1150,7 @@ github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= +github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg= github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= @@ -1368,6 +1378,7 @@ golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211115234514-b4de73f9ece8/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto 
v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= diff --git a/internal/armada/configuration/types.go b/internal/armada/configuration/types.go index bffad363962..3e5863613e4 100644 --- a/internal/armada/configuration/types.go +++ b/internal/armada/configuration/types.go @@ -3,6 +3,7 @@ package configuration import ( "time" + "github.com/apache/pulsar-client-go/pulsar" "github.com/go-redis/redis" v1 "k8s.io/api/core/v1" @@ -38,7 +39,7 @@ type ArmadaConfig struct { type PulsarConfig struct { // Pulsar URL - URL string + URL string `validate:"required"` // Path to the trusted TLS certificate file (must exist) TLSTrustCertsFilePath string // Whether Pulsar client accept untrusted TLS certificate from broker @@ -56,9 +57,9 @@ type PulsarConfig struct { JobsetEventsTopic string RedisFromPulsarSubscription string // Compression to use. Valid values are "None", "LZ4", "Zlib", "Zstd". Default is "None" - CompressionType string + CompressionType pulsar.CompressionType // Compression Level to use. Valid values are "Default", "Better", "Faster". Default is "Default" - CompressionLevel string + CompressionLevel pulsar.CompressionLevel // Used to construct an executorconfig.IngressConfiguration, // which is used when converting Armada-specific IngressConfig and ServiceConfig objects into k8s objects. 
HostnameSuffix string diff --git a/internal/armada/repository/apimessages/conversions.go b/internal/armada/repository/apimessages/conversions.go index 075f7332cf7..376b115d037 100644 --- a/internal/armada/repository/apimessages/conversions.go +++ b/internal/armada/repository/apimessages/conversions.go @@ -6,7 +6,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/armadaproject/armada/internal/common/eventutil" - "github.com/armadaproject/armada/pkg/api" "github.com/armadaproject/armada/pkg/armadaevents" ) diff --git a/internal/armada/repository/event_store.go b/internal/armada/repository/event_store.go index 3bce0cc4589..1dc2c6bd80c 100644 --- a/internal/armada/repository/event_store.go +++ b/internal/armada/repository/event_store.go @@ -25,10 +25,10 @@ func (es *TestEventStore) ReportEvents(message []*api.EventMessage) error { type StreamEventStore struct { Producer pulsar.Producer - MaxAllowedMessageSize int + MaxAllowedMessageSize uint } -func NewEventStore(producer pulsar.Producer, maxAllowedMessageSize int) *StreamEventStore { +func NewEventStore(producer pulsar.Producer, maxAllowedMessageSize uint) *StreamEventStore { return &StreamEventStore{ Producer: producer, MaxAllowedMessageSize: maxAllowedMessageSize, } @@ -50,7 +50,7 @@ func (n *StreamEventStore) ReportEvents(apiEvents []*api.EventMessage) error { } sequences = eventutil.CompactEventSequences(sequences) - sequences, err = eventutil.LimitSequencesByteSize(sequences, int(n.MaxAllowedMessageSize), true) + sequences, err = eventutil.LimitSequencesByteSize(sequences, n.MaxAllowedMessageSize, true) if err != nil { return err } diff --git a/internal/armada/server.go b/internal/armada/server.go index 7ac70de3a37..f5e8a294d66 100644 --- a/internal/armada/server.go +++ b/internal/armada/server.go @@ -133,8 +133,6 @@ func Serve(ctx context.Context, config *configuration.ArmadaConfig, healthChecks // If Pulsar is enabled, use the Pulsar submit endpoints. 
// Store a list of all Pulsar components to use during cleanup later. var pulsarClient pulsar.Client - var pulsarCompressionType pulsar.CompressionType - var pulsarCompressionLevel pulsar.CompressionLevel submitChecker := scheduler.NewSubmitChecker( 10*time.Minute, config.Scheduling.Preemption.PriorityClasses, @@ -150,19 +148,11 @@ func Serve(ctx context.Context, config *configuration.ArmadaConfig, healthChecks } defer pulsarClient.Close() - pulsarCompressionType, err = pulsarutils.ParsePulsarCompressionType(config.Pulsar.CompressionType) - if err != nil { - return err - } - pulsarCompressionLevel, err = pulsarutils.ParsePulsarCompressionLevel(config.Pulsar.CompressionLevel) - if err != nil { - return err - } serverPulsarProducerName := fmt.Sprintf("armada-server-%s", serverId) producer, err := pulsarClient.CreateProducer(pulsar.ProducerOptions{ Name: serverPulsarProducerName, - CompressionType: pulsarCompressionType, - CompressionLevel: pulsarCompressionLevel, + CompressionType: config.Pulsar.CompressionType, + CompressionLevel: config.Pulsar.CompressionLevel, BatchingMaxSize: config.Pulsar.MaxAllowedMessageSize, Topic: config.Pulsar.JobsetEventsTopic, }) @@ -171,7 +161,7 @@ func Serve(ctx context.Context, config *configuration.ArmadaConfig, healthChecks } defer producer.Close() - eventStore := repository.NewEventStore(producer, int(config.Pulsar.MaxAllowedMessageSize)) + eventStore := repository.NewEventStore(producer, config.Pulsar.MaxAllowedMessageSize) submitServer := server.NewSubmitServer( permissions, diff --git a/internal/armada/server/submit_to_log.go b/internal/armada/server/submit_to_log.go index 3f303d72b97..16ff9aeebe7 100644 --- a/internal/armada/server/submit_to_log.go +++ b/internal/armada/server/submit_to_log.go @@ -732,7 +732,7 @@ func (srv *PulsarSubmitServer) publishToPulsar(ctx context.Context, sequences [] // Reduce the number of sequences to send to the minimum possible, // and then break up any sequences larger than 
srv.MaxAllowedMessageSize. sequences = eventutil.CompactEventSequences(sequences) - sequences, err := eventutil.LimitSequencesByteSize(sequences, int(srv.MaxAllowedMessageSize), true) + sequences, err := eventutil.LimitSequencesByteSize(sequences, srv.MaxAllowedMessageSize, true) if err != nil { return err } diff --git a/internal/common/config/hooks.go b/internal/common/config/hooks.go new file mode 100644 index 00000000000..dc3d7278bcf --- /dev/null +++ b/internal/common/config/hooks.go @@ -0,0 +1,98 @@ +package config + +import ( + "fmt" + "reflect" + "strings" + + "github.com/apache/pulsar-client-go/pulsar" + "github.com/mitchellh/mapstructure" + "github.com/pkg/errors" + "github.com/spf13/viper" + "k8s.io/apimachinery/pkg/api/resource" + + "github.com/armadaproject/armada/internal/common/armadaerrors" +) + +var CustomHooks = []viper.DecoderConfigOption{ + addDecodeHook(PulsarCompressionTypeHookFunc()), + addDecodeHook(PulsarCompressionLevelHookFunc()), + addDecodeHook(QuantityDecodeHook()), +} + +func PulsarCompressionTypeHookFunc() mapstructure.DecodeHookFuncType { + return func( + f reflect.Type, + t reflect.Type, + data interface{}, + ) (interface{}, error) { + // check that src and target types are valid + if f.Kind() != reflect.String || t != reflect.TypeOf(pulsar.NoCompression) { + return data, nil + } + switch strings.ToLower(data.(string)) { + case "", "none": + return pulsar.NoCompression, nil + case "lz4": + return pulsar.LZ4, nil + case "zlib": + return pulsar.ZLib, nil + case "zstd": + return pulsar.ZSTD, nil + default: + return nil, errors.WithStack(&armadaerrors.ErrInvalidArgument{ + Name: "pulsar.CompressionType", + Value: data, + Message: fmt.Sprintf("Unknown Pulsar compression type %s", data), + }) + } + } +} + +func PulsarCompressionLevelHookFunc() mapstructure.DecodeHookFuncType { + return func( + f reflect.Type, + t reflect.Type, + data interface{}, + ) (interface{}, error) { + // check that src and target types are valid + if f.Kind() != 
reflect.String || t != reflect.TypeOf(pulsar.Default) { + return data, nil + } + switch strings.ToLower(data.(string)) { + case "", "default": + return pulsar.Default, nil + case "faster": + return pulsar.Faster, nil + case "better": + return pulsar.Better, nil + default: + return pulsar.Default, errors.WithStack(&armadaerrors.ErrInvalidArgument{ + Name: "pulsar.CompressionLevel", + Value: data, + Message: fmt.Sprintf("Unknown Pulsar compression level %s", data), + }) + } + } +} + +func QuantityDecodeHook() mapstructure.DecodeHookFuncType { + return func( + f reflect.Type, + t reflect.Type, + data interface{}, + ) (interface{}, error) { + if t != reflect.TypeOf(resource.Quantity{}) { + return data, nil + } + return resource.ParseQuantity(fmt.Sprintf("%v", data)) + } +} + +func addDecodeHook(hook mapstructure.DecodeHookFuncType) viper.DecoderConfigOption { + return func(c *mapstructure.DecoderConfig) { + c.DecodeHook = mapstructure.ComposeDecodeHookFunc( + c.DecodeHook, + hook) + } +} diff --git a/internal/common/config/hooks_test.go b/internal/common/config/hooks_test.go new file mode 100644 index 00000000000..4a3bf1984d3 --- /dev/null +++ b/internal/common/config/hooks_test.go @@ -0,0 +1,125 @@ +package config + +import ( + "reflect" + "testing" + + "github.com/apache/pulsar-client-go/pulsar" + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/api/resource" +) + +type HookTest struct { + value interface{} + expected interface{} + expectError bool +} + +func TestPulsarCompressionTypeHookFunc(t *testing.T) { + tests := map[string]HookTest{ + "empty string": { + value: "", + expected: pulsar.NoCompression, + }, + "zlib": { + value: "zlib", + expected: pulsar.ZLib, + }, + "zstd": { + value: "zstd", + expected: pulsar.ZSTD, + }, + "lz4": { + value: "lz4", + expected: pulsar.LZ4, + }, + "case insensitive": { + value: "zLiB", + expected: pulsar.ZLib, + }, + "unknown": { + value: "not a valid compression", + expectError: 
true, + }, + "not string input": { + value: 1, + expected: 1, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + runHookTest(t, tc, reflect.TypeOf(pulsar.ZLib), PulsarCompressionTypeHookFunc()) + }) + } +} + +func TestPulsarCompressionLevelHookFunc(t *testing.T) { + tests := map[string]HookTest{ + "empty string": { + value: "", + expected: pulsar.Default, + }, + "faster": { + value: "faster", + expected: pulsar.Faster, + }, + "better": { + value: "better", + expected: pulsar.Better, + }, + "default": { + value: "default", + expected: pulsar.Default, + }, + "case insensitive": { + value: "FaSTer", + expected: pulsar.Faster, + }, + "unknown": { + value: "not a valid compression type", + expectError: true, + }, + "not string input": { + value: 1, + expected: 1, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + runHookTest(t, tc, reflect.TypeOf(pulsar.Better), PulsarCompressionLevelHookFunc()) + }) + } +} + +func TestQuantityDecodeHook(t *testing.T) { + tests := map[string]HookTest{ + "1": { + value: "1", + expected: resource.MustParse("1"), + }, + "100m": { + value: "100m", + expected: resource.MustParse("100m"), + }, + "100Mi": { + value: "100Mi", + expected: resource.MustParse("100Mi"), + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + runHookTest(t, tc, reflect.TypeOf(resource.Quantity{}), QuantityDecodeHook()) + }) + } +} + +func runHookTest(t *testing.T, tc HookTest, convertTo reflect.Type, hookFunc mapstructure.DecodeHookFuncType) { + parsed, err := hookFunc(reflect.TypeOf(tc.value), convertTo, tc.value) + if tc.expectError { + assert.NotNil(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, parsed) + } +} diff --git a/internal/common/config/redis.go b/internal/common/config/redis.go new file mode 100644 index 00000000000..a6daa196655 --- /dev/null +++ b/internal/common/config/redis.go @@ -0,0 +1,48 @@ +package config + +import ( + "time" + + 
"github.com/go-redis/redis" +) + +type RedisConfig struct { + // Either a single address or a seed list of host:port addresses + Addrs []string `validate:"required"` + DB int `validate:"gte=0,lte=16"` + Password string + MaxRetries int + MinRetryBackoff time.Duration + MaxRetryBackoff time.Duration + DialTimeout time.Duration + ReadTimeout time.Duration + WriteTimeout time.Duration + PoolSize int `validate:"required"` + MinIdleConns int + MaxConnAge time.Duration + PoolTimeout time.Duration + IdleTimeout time.Duration + IdleCheckFrequency time.Duration + MasterName string +} + +func (rc RedisConfig) AsUniversalOptions() *redis.UniversalOptions { + return &redis.UniversalOptions{ + Addrs: rc.Addrs, + DB: rc.DB, + Password: rc.Password, + MaxRetries: rc.MaxRetries, + MinRetryBackoff: rc.MaxRetryBackoff, + MaxRetryBackoff: rc.MinRetryBackoff, + DialTimeout: rc.DialTimeout, + ReadTimeout: rc.ReadTimeout, + WriteTimeout: rc.WriteTimeout, + PoolSize: rc.PoolSize, + MinIdleConns: rc.MinIdleConns, + MaxConnAge: rc.MaxConnAge, + PoolTimeout: rc.PoolTimeout, + IdleTimeout: rc.IdleTimeout, + IdleCheckFrequency: rc.IdleCheckFrequency, + MasterName: rc.MasterName, + } +} diff --git a/internal/common/config/validation.go b/internal/common/config/validation.go new file mode 100644 index 00000000000..3a3dc5c517d --- /dev/null +++ b/internal/common/config/validation.go @@ -0,0 +1,34 @@ +package config + +import ( + "strings" + + "github.com/go-playground/validator/v10" + log "github.com/sirupsen/logrus" +) + +func Validate(config interface{}) error { + validate := validator.New() + return validate.Struct(config) +} + +func LogValidationErrors(err error) { + if err != nil { + for _, err := range err.(validator.ValidationErrors) { + fieldName := stripPrefix(err.Namespace()) + switch err.Tag() { + case "required": + log.Errorf("ConfigError: Field %s is required but was not found", fieldName) + default: + log.Errorf("ConfigError: %s is not a valid value for %s", err.Value(), fieldName) 
+ } + } + } +} + +func stripPrefix(s string) string { + if idx := strings.Index(s, "."); idx != -1 { + return s[idx+1:] + } + return s +} diff --git a/internal/common/database/functions.go b/internal/common/database/functions.go index 635e856ffb9..23f182b40c1 100644 --- a/internal/common/database/functions.go +++ b/internal/common/database/functions.go @@ -25,6 +25,15 @@ func CreateConnectionString(values map[string]string) string { return result } +func OpenPgxConn(config configuration.PostgresConfig) (*pgx.Conn, error) { + db, err := pgx.Connect(context.Background(), CreateConnectionString(config.Connection)) + if err != nil { + return nil, err + } + err = db.Ping(context.Background()) + return db, err +} + func OpenPgxPool(config configuration.PostgresConfig) (*pgxpool.Pool, error) { db, err := pgxpool.Connect(context.Background(), CreateConnectionString(config.Connection)) if err != nil { diff --git a/internal/common/eventutil/eventutil.go b/internal/common/eventutil/eventutil.go index d9363951444..fc3b6b91265 100644 --- a/internal/common/eventutil/eventutil.go +++ b/internal/common/eventutil/eventutil.go @@ -462,7 +462,7 @@ func groupsEqual(g1, g2 []string) bool { // LimitSequencesByteSize calls LimitSequenceByteSize for each of the provided sequences // and returns all resulting sequences. 
-func LimitSequencesByteSize(sequences []*armadaevents.EventSequence, sizeInBytes int, strict bool) ([]*armadaevents.EventSequence, error) { +func LimitSequencesByteSize(sequences []*armadaevents.EventSequence, sizeInBytes uint, strict bool) ([]*armadaevents.EventSequence, error) { rv := make([]*armadaevents.EventSequence, 0, len(sequences)) for _, sequence := range sequences { limitedSequences, err := LimitSequenceByteSize(sequence, sizeInBytes, strict) @@ -476,18 +476,18 @@ func LimitSequencesByteSize(sequences []*armadaevents.EventSequence, sizeInBytes // LimitSequenceByteSize returns a slice of sequences produced by breaking up sequence.Events // into separate sequences, each of which is at most MAX_SEQUENCE_SIZE_IN_BYTES bytes in size. -func LimitSequenceByteSize(sequence *armadaevents.EventSequence, sizeInBytes int, strict bool) ([]*armadaevents.EventSequence, error) { +func LimitSequenceByteSize(sequence *armadaevents.EventSequence, sizeInBytes uint, strict bool) ([]*armadaevents.EventSequence, error) { // Compute the size of the sequence without events. 
events := sequence.Events sequence.Events = make([]*armadaevents.EventSequence_Event, 0) - headerSize := proto.Size(sequence) + headerSize := uint(proto.Size(sequence)) sequence.Events = events // var currentSequence *armadaevents.EventSequence sequences := make([]*armadaevents.EventSequence, 0, 1) - lastSequenceEventSize := 0 + lastSequenceEventSize := uint(0) for _, event := range sequence.Events { - eventSize := proto.Size(event) + eventSize := uint(proto.Size(event)) if eventSize+headerSize > sizeInBytes && strict { return nil, errors.WithStack(&armadaerrors.ErrInvalidArgument{ Name: "sequence", diff --git a/internal/common/grpc/configuration/types.go b/internal/common/grpc/configuration/types.go index e3b182dc989..689878e89cd 100644 --- a/internal/common/grpc/configuration/types.go +++ b/internal/common/grpc/configuration/types.go @@ -3,6 +3,7 @@ package configuration import "google.golang.org/grpc/keepalive" type GrpcConfig struct { + Port int `validate:"required"` KeepaliveParams keepalive.ServerParameters KeepaliveEnforcementPolicy keepalive.EnforcementPolicy } diff --git a/internal/common/pulsarutils/eventsequence.go b/internal/common/pulsarutils/eventsequence.go index 2fd2ef87a1c..dc6316d9306 100644 --- a/internal/common/pulsarutils/eventsequence.go +++ b/internal/common/pulsarutils/eventsequence.go @@ -16,7 +16,7 @@ import ( // CompactAndPublishSequences reduces the number of sequences to the smallest possible, // while respecting per-job set ordering and max Pulsar message size, and then publishes to Pulsar. -func CompactAndPublishSequences(ctx context.Context, sequences []*armadaevents.EventSequence, producer pulsar.Producer, maxMessageSizeInBytes int) error { +func CompactAndPublishSequences(ctx context.Context, sequences []*armadaevents.EventSequence, producer pulsar.Producer, maxMessageSizeInBytes uint) error { // Reduce the number of sequences to send to the minimum possible, // and then break up any sequences larger than maxMessageSizeInBytes. 
sequences = eventutil.CompactEventSequences(sequences) diff --git a/internal/common/pulsarutils/pulsarclient.go b/internal/common/pulsarutils/pulsarclient.go index 169625a31cb..8f6ba0ea1e9 100644 --- a/internal/common/pulsarutils/pulsarclient.go +++ b/internal/common/pulsarutils/pulsarclient.go @@ -1,7 +1,6 @@ package pulsarutils import ( - "fmt" "strings" "github.com/apache/pulsar-client-go/pulsar" @@ -42,39 +41,3 @@ func NewPulsarClient(config *configuration.PulsarConfig) (pulsar.Client, error) Authentication: authentication, }) } - -func ParsePulsarCompressionType(compressionTypeStr string) (pulsar.CompressionType, error) { - switch strings.ToLower(compressionTypeStr) { - case "", "none": - return pulsar.NoCompression, nil - case "lz4": - return pulsar.LZ4, nil - case "zlib": - return pulsar.ZLib, nil - case "zstd": - return pulsar.ZSTD, nil - default: - return pulsar.NoCompression, errors.WithStack(&armadaerrors.ErrInvalidArgument{ - Name: "pulsar.CompressionType", - Value: compressionTypeStr, - Message: fmt.Sprintf("Unknown Pulsar compression type %s", compressionTypeStr), - }) - } -} - -func ParsePulsarCompressionLevel(compressionLevelStr string) (pulsar.CompressionLevel, error) { - switch strings.ToLower(compressionLevelStr) { - case "", "default": - return pulsar.Default, nil - case "faster": - return pulsar.Faster, nil - case "better": - return pulsar.Better, nil - default: - return pulsar.Default, errors.WithStack(&armadaerrors.ErrInvalidArgument{ - Name: "pulsar.CompressionLevel", - Value: compressionLevelStr, - Message: fmt.Sprintf("Unknown Pulsar compression level %s", compressionLevelStr), - }) - } -} diff --git a/internal/common/pulsarutils/pulsarclient_test.go b/internal/common/pulsarutils/pulsarclient_test.go index 1073ddf44c2..bc5b1598eda 100644 --- a/internal/common/pulsarutils/pulsarclient_test.go +++ b/internal/common/pulsarutils/pulsarclient_test.go @@ -4,63 +4,11 @@ import ( "os" "testing" - "github.com/apache/pulsar-client-go/pulsar" 
"github.com/stretchr/testify/assert" "github.com/armadaproject/armada/internal/armada/configuration" ) -func TestParsePulsarCompressionType(t *testing.T) { - // No compression - comp, err := ParsePulsarCompressionType("") - assert.NoError(t, err) - assert.Equal(t, pulsar.NoCompression, comp) - - // Zlib - comp, err = ParsePulsarCompressionType("ZliB") - assert.NoError(t, err) - assert.Equal(t, pulsar.ZLib, comp) - - // Zstd - comp, err = ParsePulsarCompressionType("zstd") - assert.NoError(t, err) - assert.Equal(t, pulsar.ZSTD, comp) - - // Lz4 - comp, err = ParsePulsarCompressionType("LZ4") - assert.NoError(t, err) - assert.Equal(t, pulsar.LZ4, comp) - - // unknown - _, err = ParsePulsarCompressionType("not a valid compression") - assert.Error(t, err) -} - -func TestParsePulsarCompressionLevel(t *testing.T) { - // No compression - comp, err := ParsePulsarCompressionLevel("") - assert.NoError(t, err) - assert.Equal(t, pulsar.Default, comp) - - comp, err = ParsePulsarCompressionLevel("Default") - assert.NoError(t, err) - assert.Equal(t, pulsar.Default, comp) - - // Faster - comp, err = ParsePulsarCompressionLevel("FASTER") - assert.NoError(t, err) - assert.Equal(t, pulsar.Faster, comp) - - // Better - comp, err = ParsePulsarCompressionLevel("Better") - assert.NoError(t, err) - assert.Equal(t, pulsar.Better, comp) - - // unknown - _, err = ParsePulsarCompressionLevel("not a valid compression type") - assert.Error(t, err) -} - func TestCreatePulsarClientHappyPath(t *testing.T) { cwd, _ := os.Executable() // Need a valid directory for tokens and certs diff --git a/internal/common/startup.go b/internal/common/startup.go index 4837163bcd4..2e80220eda9 100644 --- a/internal/common/startup.go +++ b/internal/common/startup.go @@ -5,19 +5,17 @@ import ( "fmt" "net/http" "os" - "reflect" "strings" "time" - "github.com/mitchellh/mapstructure" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" log 
"github.com/sirupsen/logrus" "github.com/spf13/pflag" "github.com/spf13/viper" "github.com/weaveworks/promrus" - "k8s.io/apimachinery/pkg/api/resource" + commonconfig "github.com/armadaproject/armada/internal/common/config" "github.com/armadaproject/armada/internal/common/logging" ) @@ -56,7 +54,7 @@ func LoadConfig(config interface{}, defaultPath string, overrideConfigs []string v.SetEnvPrefix("ARMADA") v.AutomaticEnv() - err := v.Unmarshal(config, addDecodeHook(quantityDecodeHook)) + err := v.Unmarshal(config, commonconfig.CustomHooks...) if err != nil { log.Error(err) os.Exit(-1) @@ -66,26 +64,7 @@ func LoadConfig(config interface{}, defaultPath string, overrideConfigs []string } func UnmarshalKey(v *viper.Viper, key string, item interface{}) error { - return v.UnmarshalKey(key, item, addDecodeHook(quantityDecodeHook)) -} - -func addDecodeHook(hook mapstructure.DecodeHookFuncType) viper.DecoderConfigOption { - return func(c *mapstructure.DecoderConfig) { - c.DecodeHook = mapstructure.ComposeDecodeHookFunc( - c.DecodeHook, - hook) - } -} - -func quantityDecodeHook( - from reflect.Type, - to reflect.Type, - data interface{}, -) (interface{}, error) { - if to != reflect.TypeOf(resource.Quantity{}) { - return data, nil - } - return resource.ParseQuantity(fmt.Sprintf("%v", data)) + return v.UnmarshalKey(key, item, commonconfig.CustomHooks...) 
} // TODO Move logging-related code out of common into a new package internal/logging diff --git a/internal/eventingester/convert/conversions.go b/internal/eventingester/convert/conversions.go index 0ff251cfeb4..c48df1b56e7 100644 --- a/internal/eventingester/convert/conversions.go +++ b/internal/eventingester/convert/conversions.go @@ -17,11 +17,11 @@ import ( // EventConverter converts event sequences into events that we can store in Redis type EventConverter struct { Compressor compress.Compressor - MaxMessageBatchSize int + MaxMessageBatchSize uint metrics *metrics.Metrics } -func NewEventConverter(compressor compress.Compressor, maxMessageBatchSize int, metrics *metrics.Metrics) ingest.InstructionConverter[*model.BatchUpdate] { +func NewEventConverter(compressor compress.Compressor, maxMessageBatchSize uint, metrics *metrics.Metrics) ingest.InstructionConverter[*model.BatchUpdate] { return &EventConverter{ Compressor: compressor, MaxMessageBatchSize: maxMessageBatchSize, diff --git a/internal/eventingester/ingester.go b/internal/eventingester/ingester.go index 92ae25c6b88..4b771d5ddeb 100644 --- a/internal/eventingester/ingester.go +++ b/internal/eventingester/ingester.go @@ -50,7 +50,7 @@ func Run(config *configuration.EventIngesterConfiguration) { log.Errorf("Error creating compressor for consumer") panic(err) } - converter := convert.NewEventConverter(compressor, config.BatchSize, metrics) + converter := convert.NewEventConverter(compressor, uint(config.BatchSize), metrics) ingester := ingest. 
NewIngestionPipeline(config.Pulsar, config.SubscriptionName, config.BatchSize, config.BatchDuration, converter, eventDb, config.Metrics, metrics) diff --git a/internal/pulsartest/app.go b/internal/pulsartest/app.go index 7471eb984e8..93b79259bfd 100644 --- a/internal/pulsartest/app.go +++ b/internal/pulsartest/app.go @@ -32,21 +32,10 @@ func New(params Params, cmdType string) (*App, error) { var reader pulsar.Reader if cmdType == "submit" { - compressionType, err := pulsarutils.ParsePulsarCompressionType(params.Pulsar.CompressionType) - if err != nil { - return nil, err - } - compressionLevel, err := pulsarutils.ParsePulsarCompressionLevel(params.Pulsar.CompressionLevel) - if err != nil { - return nil, err - } - producerName := fmt.Sprintf("pulsartest-%s", serverId) producer, err = pulsarClient.CreateProducer(pulsar.ProducerOptions{ - Name: producerName, - CompressionType: compressionType, - CompressionLevel: compressionLevel, - Topic: params.Pulsar.JobsetEventsTopic, + Name: producerName, + Topic: params.Pulsar.JobsetEventsTopic, }) if err != nil { diff --git a/internal/scheduler/api.go b/internal/scheduler/api.go index 615c476a0f0..c963885c08a 100644 --- a/internal/scheduler/api.go +++ b/internal/scheduler/api.go @@ -28,8 +28,8 @@ type ExecutorApi struct { jobRepository database.JobRepository executorRepository database.ExecutorRepository allowedPriorities []int32 // allowed priority classes - maxJobsPerCall int // maximum number of jobs that will be leased in a single call - maxPulsarMessageSize int // maximum sizer of pulsar messages produced + maxJobsPerCall uint // maximum number of jobs that will be leased in a single call + maxPulsarMessageSize uint // maximum size of pulsar messages produced + clock clock.Clock } @@ -37,8 +37,7 @@ func NewExecutorApi(producer pulsar.Producer, jobRepository database.JobRepository, executorRepository database.ExecutorRepository, allowedPriorities []int32, - maxJobsPerCall int, - maxPulsarMessageSize int, + maxJobsPerCall 
uint, ) *ExecutorApi { return &ExecutorApi{ producer: producer, @@ -46,7 +45,7 @@ func NewExecutorApi(producer pulsar.Producer, executorRepository: executorRepository, allowedPriorities: allowedPriorities, maxJobsPerCall: maxJobsPerCall, - maxPulsarMessageSize: maxPulsarMessageSize, + maxPulsarMessageSize: 1024 * 1024 * 2, clock: clock.RealClock{}, } } diff --git a/internal/scheduler/api_test.go b/internal/scheduler/api_test.go index ad3211b579e..5a925f4d67e 100644 --- a/internal/scheduler/api_test.go +++ b/internal/scheduler/api_test.go @@ -24,7 +24,7 @@ import ( ) func TestExecutorApi_LeaseJobRuns(t *testing.T) { - const maxJobsPerCall = 100 + const maxJobsPerCall = uint(100) testClock := clock.NewFakeClock(time.Now()) runId1 := uuid.New() runId2 := uuid.New() @@ -145,7 +145,6 @@ func TestExecutorApi_LeaseJobRuns(t *testing.T) { mockExecutorRepository, []int32{}, maxJobsPerCall, - 1024, ) server.clock = testClock @@ -213,8 +212,7 @@ func TestExecutorApi_Publish(t *testing.T) { mockJobRepository, mockExecutorRepository, []int32{}, - 100, - 1024) + 100) empty, err := server.ReportEvents(ctx, &executorapi.EventList{Events: tc.sequences}) require.NoError(t, err) diff --git a/internal/scheduler/config.go b/internal/scheduler/config.go index 57189f42a84..0f2976d051f 100644 --- a/internal/scheduler/config.go +++ b/internal/scheduler/config.go @@ -4,26 +4,37 @@ import ( "time" "github.com/armadaproject/armada/internal/armada/configuration" + authconfig "github.com/armadaproject/armada/internal/common/auth/configuration" + "github.com/armadaproject/armada/internal/common/config" + grpcconfig "github.com/armadaproject/armada/internal/common/grpc/configuration" ) type Configuration struct { // Database configuration Postgres configuration.PostgresConfig - // Metrics configuration - Metrics configuration.MetricsConfig + // Redis Config + Redis config.RedisConfig // General Pulsar configuration Pulsar configuration.PulsarConfig - // Pulsar subscription name - 
SubscriptionName string - // Maximum time since the last batch before a batch will be inserted into the database - BatchDuration time.Duration - // Time for which the pulsar consumer will wait for a new message before retrying - PulsarReceiveTimeout time.Duration - // Time for which the pulsar consumer will back off after receiving an error on trying to receive a message - PulsarBackoffTime time.Duration + // Configuration controlling leader election + Leader LeaderConfig + // Scheduler configuration (this is shared with the old scheduler) + Scheduling configuration.SchedulingConfig + Auth authconfig.AuthConfig + Grpc grpcconfig.GrpcConfig + // How often the scheduling cycle should run + CyclePeriod time.Duration `validate:"required"` + // How long after a heartbeat an executor will be considered lost + ExecutorTimeout time.Duration `validate:"required"` + // Maximum number of rows to fetch in a given query + DatabaseFetchSize int `validate:"required"` + // Timeout to use when sending messages to pulsar + PulsarSendTimeout time.Duration `validate:"required"` } type LeaderConfig struct { + // Valid modes are "standalone" or "cluster" + Mode string `validate:"required"` // Name of the K8s Lock Object LeaseLockName string // Namespace of the K8s Lock Object diff --git a/internal/scheduler/database/executor_repository.go b/internal/scheduler/database/executor_repository.go index e93b5762a86..00b9a9c0f1f 100644 --- a/internal/scheduler/database/executor_repository.go +++ b/internal/scheduler/database/executor_repository.go @@ -31,11 +31,11 @@ type PostgresExecutorRepository struct { decompressor compress.Decompressor } -func NewPostgresExecutorRepository(db *pgxpool.Pool, compressor compress.Compressor, decompressor compress.Decompressor) *PostgresExecutorRepository { +func NewPostgresExecutorRepository(db *pgxpool.Pool) *PostgresExecutorRepository { return &PostgresExecutorRepository{ db: db, - compressor: compressor, - decompressor: decompressor, + compressor: 
compress.NewThreadSafeZlibCompressor(1024), + decompressor: compress.NewThreadSafeZlibDecompressor(), } } diff --git a/internal/scheduler/database/executor_repository_test.go b/internal/scheduler/database/executor_repository_test.go index 8504f157a02..d1de56769a3 100644 --- a/internal/scheduler/database/executor_repository_test.go +++ b/internal/scheduler/database/executor_repository_test.go @@ -10,7 +10,6 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/exp/slices" - "github.com/armadaproject/armada/internal/common/compress" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" ) @@ -125,10 +124,7 @@ func TestExecutorRepository_GetLastUpdateTimes(t *testing.T) { func withExecutorRepository(action func(repository *PostgresExecutorRepository) error) error { return WithTestDb(func(_ *Queries, db *pgxpool.Pool) error { - repo := NewPostgresExecutorRepository( - db, - compress.NewThreadSafeZlibCompressor(1024), - compress.NewThreadSafeZlibDecompressor()) + repo := NewPostgresExecutorRepository(db) return action(repo) }) } diff --git a/internal/scheduler/database/job_repository.go b/internal/scheduler/database/job_repository.go index 3d8bb7318b1..181b4d4c8a4 100644 --- a/internal/scheduler/database/job_repository.go +++ b/internal/scheduler/database/job_repository.go @@ -50,7 +50,7 @@ type JobRepository interface { // FetchJobRunLeases fetches new job runs for a given executor. 
A maximum of maxResults rows will be returned, while run // in excludedRunIds will be excluded - FetchJobRunLeases(ctx context.Context, executor string, maxResults int, excludedRunIds []uuid.UUID) ([]*JobRunLease, error) + FetchJobRunLeases(ctx context.Context, executor string, maxResults uint, excludedRunIds []uuid.UUID) ([]*JobRunLease, error) } // PostgresJobRepository is an implementation of JobRepository that stores its state in postgres @@ -193,7 +193,7 @@ func (r *PostgresJobRepository) FindInactiveRuns(ctx context.Context, runIds []u // FetchJobRunLeases fetches new job runs for a given executor. A maximum of maxResults rows will be returned, while run // in excludedRunIds will be excluded -func (r *PostgresJobRepository) FetchJobRunLeases(ctx context.Context, executor string, maxResults int, excludedRunIds []uuid.UUID) ([]*JobRunLease, error) { +func (r *PostgresJobRepository) FetchJobRunLeases(ctx context.Context, executor string, maxResults uint, excludedRunIds []uuid.UUID) ([]*JobRunLease, error) { var newRuns []*JobRunLease err := r.db.BeginTxFunc(ctx, pgx.TxOptions{ IsoLevel: pgx.ReadCommitted, diff --git a/internal/scheduler/database/job_repository_test.go b/internal/scheduler/database/job_repository_test.go index dd4a7640ff3..d0f27e61992 100644 --- a/internal/scheduler/database/job_repository_test.go +++ b/internal/scheduler/database/job_repository_test.go @@ -333,7 +333,7 @@ func TestFetchJobRunLeases(t *testing.T) { dbRuns []Run dbJobs []Job excludedRuns []uuid.UUID - maxRowsToFetch int + maxRowsToFetch uint executor string expectedLeases []*JobRunLease }{ diff --git a/internal/scheduler/database/queue_repository.go b/internal/scheduler/database/queue_repository.go index 472df0d973a..cb3c34b9839 100644 --- a/internal/scheduler/database/queue_repository.go +++ b/internal/scheduler/database/queue_repository.go @@ -1,5 +1,38 @@ package database +import ( + "github.com/go-redis/redis" + + legacyrepository 
"github.com/armadaproject/armada/internal/armada/repository" +) + +// QueueRepository is an interface to be implemented by structs which provide queue information type QueueRepository interface { GetAllQueues() ([]*Queue, error) } + +// LegacyQueueRepository is a QueueRepository which is backed by Armada's redis store +type LegacyQueueRepository struct { + backingRepo legacyrepository.QueueRepository +} + +func NewLegacyQueueRepository(db redis.UniversalClient) *LegacyQueueRepository { + return &LegacyQueueRepository{ + backingRepo: legacyrepository.NewRedisQueueRepository(db), + } +} + +func (r *LegacyQueueRepository) GetAllQueues() ([]*Queue, error) { + legacyQueues, err := r.backingRepo.GetAllQueues() + if err != nil { + return nil, err + } + queues := make([]*Queue, len(legacyQueues)) + for i, legacyQueue := range legacyQueues { + queues[i] = &Queue{ + Name: legacyQueue.Name, + Weight: float64(legacyQueue.PriorityFactor), + } + } + return queues, nil +} diff --git a/internal/scheduler/database/queue_repository_test.go b/internal/scheduler/database/queue_repository_test.go new file mode 100644 index 00000000000..edd33be3d41 --- /dev/null +++ b/internal/scheduler/database/queue_repository_test.go @@ -0,0 +1,67 @@ +package database + +import ( + "testing" + + "github.com/go-redis/redis" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/exp/slices" + + clientQueue "github.com/armadaproject/armada/pkg/client/queue" +) + +func TestLegacyQueueRepository_GetAllQueues(t *testing.T) { + tests := map[string]struct { + queues []clientQueue.Queue + expectedQueues []*Queue + }{ + "Not empty": { + queues: []clientQueue.Queue{ + { + Name: "test-queue-1", + PriorityFactor: 10, + }, + { + Name: "test-queue-2", + PriorityFactor: 20, + }, + }, + expectedQueues: []*Queue{ + { + Name: "test-queue-1", + Weight: 10, + }, + { + Name: "test-queue-2", + Weight: 20, + }, + }, + }, + "Empty": { + queues: []clientQueue.Queue{}, + expectedQueues: 
[]*Queue{}, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + rc := redis.NewClient(&redis.Options{Addr: "localhost:6379", DB: 10}) + rc.FlushDB() + defer func() { + rc.FlushDB() + _ = rc.Close() + }() + repo := NewLegacyQueueRepository(rc) + for _, queue := range tc.queues { + err := repo.backingRepo.CreateQueue(queue) + require.NoError(t, err) + } + retrievedQueues, err := repo.GetAllQueues() + require.NoError(t, err) + sortFunc := func(a, b *Queue) bool { return a.Name > b.Name } + slices.SortFunc(tc.expectedQueues, sortFunc) + slices.SortFunc(retrievedQueues, sortFunc) + assert.Equal(t, tc.expectedQueues, retrievedQueues) + }) + } +} diff --git a/internal/scheduler/leader.go b/internal/scheduler/leader.go index 566aae860b4..134ca03e359 100644 --- a/internal/scheduler/leader.go +++ b/internal/scheduler/leader.go @@ -19,6 +19,8 @@ type LeaderController interface { // ValidateToken allows a caller to determine whether a previously obtained token is still valid. // Returns true if the token is a leader and false otherwise ValidateToken(tok LeaderToken) bool + // Run starts the controller. 
This is a blocking call which will return when the provided context is cancelled + Run(ctx context.Context) error } // StandaloneLeaderController returns a token that always indicates you are leader @@ -44,6 +46,15 @@ func (lc *StandaloneLeaderController) ValidateToken(tok LeaderToken) bool { return false } +func (lc *StandaloneLeaderController) Run(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return nil + } + } +} + // LeaseListener allows clients to listen for lease events type LeaseListener interface { // Called when the client has started leading diff --git a/internal/scheduler/leader_test.go b/internal/scheduler/leader_test.go index e0be540d2d9..752916ae7e6 100644 --- a/internal/scheduler/leader_test.go +++ b/internal/scheduler/leader_test.go @@ -104,7 +104,7 @@ func TestK8sLeaderController_BecomingLeader(t *testing.T) { }).AnyTimes() // Run the test - controller := NewKubernetesLeaderController(config(), client) + controller := NewKubernetesLeaderController(testLeaderConfig(), client) testListener := NewTestLeaseListener(controller) controller.listener = testListener ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -151,7 +151,7 @@ func TestK8sLeaderController_BecomingLeader(t *testing.T) { } } -func config() LeaderConfig { +func testLeaderConfig() LeaderConfig { return LeaderConfig{ LeaseLockName: lockName, LeaseLockNamespace: lockNamespace, diff --git a/internal/scheduler/mocks/mock_repositories.go b/internal/scheduler/mocks/mock_repositories.go index 365f98dc5c2..adb23219ef0 100644 --- a/internal/scheduler/mocks/mock_repositories.go +++ b/internal/scheduler/mocks/mock_repositories.go @@ -175,7 +175,7 @@ func (mr *MockJobRepositoryMockRecorder) FetchJobRunErrors(arg0, arg1 interface{ } // FetchJobRunLeases mocks base method. 
-func (m *MockJobRepository) FetchJobRunLeases(arg0 context.Context, arg1 string, arg2 int, arg3 []uuid.UUID) ([]*database.JobRunLease, error) { +func (m *MockJobRepository) FetchJobRunLeases(arg0 context.Context, arg1 string, arg2 uint, arg3 []uuid.UUID) ([]*database.JobRunLease, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "FetchJobRunLeases", arg0, arg1, arg2, arg3) ret0, _ := ret[0].([]*database.JobRunLease) diff --git a/internal/scheduler/publisher.go b/internal/scheduler/publisher.go index 2f0c83f7bb1..38d00d5fb69 100644 --- a/internal/scheduler/publisher.go +++ b/internal/scheduler/publisher.go @@ -41,14 +41,13 @@ type PulsarPublisher struct { pulsarSendTimeout time.Duration // Maximum size (in bytes) of produced pulsar messages. // This must be below 4MB which is the pulsar message size limit - maxMessageBatchSize int + maxMessageBatchSize uint } func NewPulsarPublisher( pulsarClient pulsar.Client, producerOptions pulsar.ProducerOptions, pulsarSendTimeout time.Duration, - maxMessageBatchSize int, ) (*PulsarPublisher, error) { partitions, err := pulsarClient.TopicPartitions(producerOptions.Topic) if err != nil { @@ -62,7 +61,7 @@ func NewPulsarPublisher( return &PulsarPublisher{ producer: producer, pulsarSendTimeout: pulsarSendTimeout, - maxMessageBatchSize: maxMessageBatchSize, + maxMessageBatchSize: 2 * 1024 * 1024, // max pulsar message size is 4MB, so we use 2MB here to be safe numPartitions: len(partitions), }, nil } diff --git a/internal/scheduler/publisher_test.go b/internal/scheduler/publisher_test.go index 6c98f225bcf..f588efec52f 100644 --- a/internal/scheduler/publisher_test.go +++ b/internal/scheduler/publisher_test.go @@ -24,7 +24,6 @@ import ( const ( topic = "testTopic" numPartitions = 100 - messageSize = 1024 // 1kb ) func TestPulsarPublisher_TestPublish(t *testing.T) { @@ -91,6 +90,8 @@ func TestPulsarPublisher_TestPublish(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { + ctx, cancel := 
context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() ctrl := gomock.NewController(t) mockPulsarClient := schedulermocks.NewMockClient(ctrl) mockPulsarProducer := schedulermocks.NewMockProducer(ctrl) @@ -120,8 +121,7 @@ func TestPulsarPublisher_TestPublish(t *testing.T) { }).AnyTimes() options := pulsar.ProducerOptions{Topic: topic} - ctx := context.TODO() - publisher, err := NewPulsarPublisher(mockPulsarClient, options, 5*time.Second, messageSize) + publisher, err := NewPulsarPublisher(mockPulsarClient, options, 5*time.Second) require.NoError(t, err) err = publisher.PublishMessages(ctx, tc.eventSequences, func() bool { return tc.amLeader }) @@ -147,22 +147,22 @@ func TestPulsarPublisher_TestPublishMarkers(t *testing.T) { allPartitions[fmt.Sprintf("%d", i)] = true } tests := map[string]struct { - numSucessfulPublishes int - expectedError bool - expectedPartitons map[string]bool + numSuccessfulPublishes int + expectedError bool + expectedPartitions map[string]bool }{ "Publish successful": { - numSucessfulPublishes: math.MaxInt, - expectedError: false, - expectedPartitons: allPartitions, + numSuccessfulPublishes: math.MaxInt, + expectedError: false, + expectedPartitions: allPartitions, }, "All Publishes fail": { - numSucessfulPublishes: 0, - expectedError: true, + numSuccessfulPublishes: 0, + expectedError: true, }, "Some Publishes fail": { - numSucessfulPublishes: 10, - expectedError: true, + numSuccessfulPublishes: 10, + expectedError: true, }, } for name, tc := range tests { @@ -173,7 +173,7 @@ func TestPulsarPublisher_TestPublishMarkers(t *testing.T) { mockPulsarClient.EXPECT().CreateProducer(gomock.Any()).Return(mockPulsarProducer, nil).Times(1) mockPulsarClient.EXPECT().TopicPartitions(topic).Return(make([]string, numPartitions), nil) numPublished := 0 - capturedPartitons := make(map[string]bool) + capturedPartitions := make(map[string]bool) mockPulsarProducer. EXPECT(). 
@@ -182,9 +182,9 @@ func TestPulsarPublisher_TestPublishMarkers(t *testing.T) { numPublished++ key, ok := msg.Properties[explicitPartitionKey] if ok { - capturedPartitons[key] = true + capturedPartitions[key] = true } - if numPublished > tc.numSucessfulPublishes { + if numPublished > tc.numSuccessfulPublishes { log.Info("returning error") return pulsarutils.NewMessageId(numPublished), errors.New("error from mock pulsar producer") } @@ -193,7 +193,7 @@ func TestPulsarPublisher_TestPublishMarkers(t *testing.T) { options := pulsar.ProducerOptions{Topic: topic} ctx := context.TODO() - publisher, err := NewPulsarPublisher(mockPulsarClient, options, 5*time.Second, messageSize) + publisher, err := NewPulsarPublisher(mockPulsarClient, options, 5*time.Second) require.NoError(t, err) published, err := publisher.PublishMarkers(ctx, uuid.New()) @@ -207,7 +207,7 @@ func TestPulsarPublisher_TestPublishMarkers(t *testing.T) { if !tc.expectedError { assert.Equal(t, uint32(numPartitions), published) - assert.Equal(t, tc.expectedPartitons, capturedPartitons) + assert.Equal(t, tc.expectedPartitions, capturedPartitions) } }) } diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go index 4a055e7ac88..8e1e85f3e46 100644 --- a/internal/scheduler/scheduler.go +++ b/internal/scheduler/scheduler.go @@ -17,11 +17,6 @@ import ( "github.com/armadaproject/armada/pkg/armadaevents" ) -func Run(_ *Configuration) error { - // TODO: instantiate scheduler and start cycling - return nil -} - // Scheduler is the main armada Scheduler. It runs a periodic scheduling cycle during which the following actions are // performed: // * Determine if we are leader @@ -193,7 +188,7 @@ func (s *Scheduler) cycle(ctx context.Context, updateAll bool, leaderToken Leade events = append(events, expirationEvents...) 
// Schedule Jobs - scheduledJobs, err := s.schedulingAlgo.Schedule(txn, s.jobDb) + scheduledJobs, err := s.schedulingAlgo.Schedule(ctx, txn, s.jobDb) if err != nil { return err } diff --git a/internal/scheduler/scheduler_test.go b/internal/scheduler/scheduler_test.go index 242eaee9b0b..05919a637df 100644 --- a/internal/scheduler/scheduler_test.go +++ b/internal/scheduler/scheduler_test.go @@ -401,7 +401,7 @@ func (t *testJobRepository) FindInactiveRuns(ctx context.Context, runIds []uuid. panic("implement me") } -func (t *testJobRepository) FetchJobRunLeases(ctx context.Context, executor string, maxResults int, excludedRunIds []uuid.UUID) ([]*database.JobRunLease, error) { +func (t *testJobRepository) FetchJobRunLeases(ctx context.Context, executor string, maxResults uint, excludedRunIds []uuid.UUID) ([]*database.JobRunLease, error) { // TODO implement me panic("implement me") } @@ -452,7 +452,7 @@ type testSchedulingAlgo struct { shouldError bool } -func (t *testSchedulingAlgo) Schedule(txn *memdb.Txn, jobDb *JobDb) ([]*SchedulerJob, error) { +func (t *testSchedulingAlgo) Schedule(ctx context.Context, txn *memdb.Txn, jobDb *JobDb) ([]*SchedulerJob, error) { if t.shouldError { return nil, errors.New("error scheduling jobs") } diff --git a/internal/scheduler/schedulerapp.go b/internal/scheduler/schedulerapp.go new file mode 100644 index 00000000000..9fd77d1c067 --- /dev/null +++ b/internal/scheduler/schedulerapp.go @@ -0,0 +1,146 @@ +package scheduler + +import ( + "fmt" + "net" + "strings" + + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + + "github.com/apache/pulsar-client-go/pulsar" + "github.com/go-redis/redis" + "github.com/google/uuid" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + "golang.org/x/sync/errgroup" + + "github.com/armadaproject/armada/internal/common/app" + "github.com/armadaproject/armada/internal/common/auth" + dbcommon "github.com/armadaproject/armada/internal/common/database" + grpcCommon 
"github.com/armadaproject/armada/internal/common/grpc" + "github.com/armadaproject/armada/internal/common/pulsarutils" + "github.com/armadaproject/armada/internal/scheduler/database" + "github.com/armadaproject/armada/pkg/executorapi" +) + +// Run sets up a Scheduler application and runs it until a SIGTERM is received +func Run(config Configuration) error { + g, ctx := errgroup.WithContext(app.CreateContextWithShutdown()) + + ////////////////////////////////////////////////////////////////////////// + // Database setup (postgres and redis) + ////////////////////////////////////////////////////////////////////////// + log.Infof("Setting up database connections") + db, err := dbcommon.OpenPgxPool(config.Postgres) + if err != nil { + return errors.WithMessage(err, "Error opening connection to postgres") + } + jobRepository := database.NewPostgresJobRepository(db, int32(config.DatabaseFetchSize)) + executorRepository := database.NewPostgresExecutorRepository(db) + + redisClient := redis.NewUniversalClient(config.Redis.AsUniversalOptions()) + defer func() { + err := redisClient.Close() + if err != nil { + log.WithError(errors.WithStack(err)).Warnf("Redis client didn't close down cleanly") + } + }() + queueRepository := database.NewLegacyQueueRepository(redisClient) + + ////////////////////////////////////////////////////////////////////////// + // Pulsar + ////////////////////////////////////////////////////////////////////////// + log.Infof("Setting up Pulsar connectivity") + pulsarClient, err := pulsarutils.NewPulsarClient(&config.Pulsar) + defer pulsarClient.Close() + if err != nil { + return errors.WithMessage(err, "Error creating pulsar client") + } + pulsarPublisher, err := NewPulsarPublisher(pulsarClient, pulsar.ProducerOptions{ + Name: fmt.Sprintf("armada-scheduler-%s", uuid.NewString()), + CompressionType: config.Pulsar.CompressionType, + CompressionLevel: config.Pulsar.CompressionLevel, + BatchingMaxSize: config.Pulsar.MaxAllowedMessageSize, + Topic: 
config.Pulsar.JobsetEventsTopic, + }, config.PulsarSendTimeout) + if err != nil { + return errors.WithMessage(err, "error creating pulsar publisher") + } + + ////////////////////////////////////////////////////////////////////////// + // Leader Election + ////////////////////////////////////////////////////////////////////////// + leaderController, err := createLeaderController(config.Leader) + if err != nil { + return errors.WithMessage(err, "error creating leader controller") + } + g.Go(func() error { return leaderController.Run(ctx) }) + + ////////////////////////////////////////////////////////////////////////// + // Executor Api + ////////////////////////////////////////////////////////////////////////// + log.Infof("Setting up executor api") + apiProducer, err := pulsarClient.CreateProducer(pulsar.ProducerOptions{ + Name: fmt.Sprintf("armada-executor-api-%s", uuid.NewString()), + CompressionType: config.Pulsar.CompressionType, + CompressionLevel: config.Pulsar.CompressionLevel, + BatchingMaxSize: config.Pulsar.MaxAllowedMessageSize, + Topic: config.Pulsar.JobsetEventsTopic, + }) + if err != nil { + return errors.Wrapf(err, "error creating pulsar producer for executor api") + } + authServices := auth.ConfigureAuth(config.Auth) + grpcServer := grpcCommon.CreateGrpcServer(config.Grpc.KeepaliveParams, config.Grpc.KeepaliveEnforcementPolicy, authServices) + defer grpcServer.GracefulStop() + lis, err := net.Listen("tcp", fmt.Sprintf(":%d", config.Grpc.Port)) + if err != nil { + return errors.WithMessage(err, "error setting up grpc server") + } + executorServer := NewExecutorApi(apiProducer, jobRepository, executorRepository, []int32{}, config.Scheduling.MaximumJobsToSchedule) + executorapi.RegisterExecutorApiServer(grpcServer, executorServer) + g.Go(func() error { return grpcServer.Serve(lis) }) + log.Infof("Executor api listening on %s", lis.Addr()) + + ////////////////////////////////////////////////////////////////////////// + // Scheduling + 
////////////////////////////////////////////////////////////////////////// + log.Infof("Starting up scheduling loop") + schedulingAlgo := NewLegacySchedulingAlgo(config.Scheduling, executorRepository, queueRepository) + scheduler, err := NewScheduler(jobRepository, + executorRepository, + schedulingAlgo, + leaderController, + pulsarPublisher, + config.CyclePeriod, + config.ExecutorTimeout, + config.Scheduling.MaxRetries) + if err != nil { + return errors.WithMessage(err, "error creating scheduler") + } + g.Go(func() error { return scheduler.Run(ctx) }) + + return g.Wait() +} + +func createLeaderController(config LeaderConfig) (LeaderController, error) { + switch mode := strings.ToLower(config.Mode); mode { + case "standalone": + log.Infof("Scheduler will run in standalone mode") + return NewStandaloneLeaderController(), nil + case "cluster": + log.Infof("Scheduler will run in cluster mode") + clusterConfig, err := rest.InClusterConfig() + if err != nil { + return nil, errors.Wrapf(err, "Error creating kubernetes client") + } + clientSet, err := kubernetes.NewForConfig(clusterConfig) + if err != nil { + return nil, errors.Wrapf(err, "Error creating kubernetes client") + } + return NewKubernetesLeaderController(LeaderConfig{}, clientSet.CoordinationV1()), nil + default: + return nil, errors.Errorf("%s is not a valid leader mode", config.Mode) + } +} diff --git a/internal/scheduler/scheduling_algo.go b/internal/scheduler/scheduling_algo.go index 191f17bdc8c..93ecc7718c7 100644 --- a/internal/scheduler/scheduling_algo.go +++ b/internal/scheduler/scheduling_algo.go @@ -23,7 +23,7 @@ type SchedulingAlgo interface { // Schedule should assign jobs to nodes // Any jobs that are scheduled should be marked as such in the JobDb using the transaction provided // It should return a slice containing all scheduled jobs.
- Schedule(txn *memdb.Txn, jobDb *JobDb) ([]*SchedulerJob, error) + Schedule(ctx context.Context, txn *memdb.Txn, jobDb *JobDb) ([]*SchedulerJob, error) } // LegacySchedulingAlgo is a SchedulingAlgo that schedules jobs in the same way as the old lease call