From 4173f54504fbc95dd1e85c897f99b4333e7590ec Mon Sep 17 00:00:00 2001 From: Robert Burke Date: Thu, 22 Aug 2024 10:08:51 -0700 Subject: [PATCH] Cherrypick #32276 to 2.59.0 release branch. (#32280) * Add an idle shutdown timout to prism binary. * Correct flag text. Co-authored-by: Jack McCluskey <34928439+jrmccluskey@users.noreply.github.com> --------- Co-authored-by: lostluck <13907733+lostluck@users.noreply.github.com> Co-authored-by: Jack McCluskey <34928439+jrmccluskey@users.noreply.github.com> --- sdks/go/cmd/prism/prism.go | 25 +++++---- .../prism/internal/jobservices/management.go | 22 ++++++-- .../prism/internal/jobservices/server.go | 53 ++++++++++++++++++- .../beam/runners/prism/internal/web/web.go | 8 +-- sdks/go/pkg/beam/runners/prism/prism.go | 12 +++++ 5 files changed, 102 insertions(+), 18 deletions(-) diff --git a/sdks/go/cmd/prism/prism.go b/sdks/go/cmd/prism/prism.go index 804ae0c2ab2d0..39c19df00dc39 100644 --- a/sdks/go/cmd/prism/prism.go +++ b/sdks/go/cmd/prism/prism.go @@ -30,16 +30,24 @@ import ( ) var ( - jobPort = flag.Int("job_port", 8073, "specify the job management service port") - webPort = flag.Int("web_port", 8074, "specify the web ui port") - jobManagerEndpoint = flag.String("jm_override", "", "set to only stand up a web ui that refers to a seperate JobManagement endpoint") - serveHTTP = flag.Bool("serve_http", true, "enable or disable the web ui") + jobPort = flag.Int("job_port", 8073, "specify the job management service port") + webPort = flag.Int("web_port", 8074, "specify the web ui port") + jobManagerEndpoint = flag.String("jm_override", "", "set to only stand up a web ui that refers to a seperate JobManagement endpoint") + serveHTTP = flag.Bool("serve_http", true, "enable or disable the web ui") + idleShutdownTimeout = flag.Duration("idle_shutdown_timeout", -1, "duration that prism will wait for a new job before shutting itself down. Negative durations disable auto shutdown. Defaults to never shutting down.") ) func main() { flag.Parse() - ctx := context.Background() - cli, err := makeJobClient(ctx, prism.Options{Port: *jobPort}, *jobManagerEndpoint) + ctx, cancel := context.WithCancelCause(context.Background()) + + cli, err := makeJobClient(ctx, + prism.Options{ + Port: *jobPort, + IdleShutdownTimeout: *idleShutdownTimeout, + CancelFn: cancel, + }, + *jobManagerEndpoint) if err != nil { log.Fatalf("error creating job server: %v", err) } @@ -47,10 +55,9 @@ func main() { if err := prism.CreateWebServer(ctx, cli, prism.Options{Port: *webPort}); err != nil { log.Fatalf("error creating web server: %v", err) } - } else { - // Block main thread forever to keep main from exiting. - <-(chan struct{})(nil) // receives on nil channels block. } + // Block main thread forever to keep main from exiting. + <-ctx.Done() } func makeJobClient(ctx context.Context, opts prism.Options, endpoint string) (jobpb.JobServiceClient, error) { diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go index 2b03eddff05d7..b957b99ca63d2 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go @@ -73,12 +73,14 @@ func (e *joinError) Error() string { return string(b) } -func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (*jobpb.PrepareJobResponse, error) { +func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (_ *jobpb.PrepareJobResponse, err error) { s.mu.Lock() defer s.mu.Unlock() // Since jobs execute in the background, they should not be tied to a request's context. rootCtx, cancelFn := context.WithCancelCause(context.Background()) + // Wrap in a Once so it will only be invoked a single time for the job. + terminalOnceWrap := sync.OnceFunc(s.jobTerminated) job := &Job{ key: s.nextId(), Pipeline: req.GetPipeline(), @@ -86,10 +88,16 @@ func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (*jo options: req.GetPipelineOptions(), streamCond: sync.NewCond(&sync.Mutex{}), RootCtx: rootCtx, - CancelFn: cancelFn, - + CancelFn: func(err error) { + cancelFn(err) + terminalOnceWrap() + }, artifactEndpoint: s.Endpoint(), } + // Stop the idle timer when a new job appears. + if idleTimer := s.idleTimer.Load(); idleTimer != nil { + idleTimer.Stop() + } // Queue initial state of the job. job.state.Store(jobpb.JobState_STOPPED) @@ -155,7 +163,9 @@ func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (*jo case urns.TransformParDo: var pardo pipepb.ParDoPayload if err := proto.Unmarshal(t.GetSpec().GetPayload(), &pardo); err != nil { - return nil, fmt.Errorf("unable to unmarshal ParDoPayload for %v - %q: %w", tid, t.GetUniqueName(), err) + wrapped := fmt.Errorf("unable to unmarshal ParDoPayload for %v - %q: %w", tid, t.GetUniqueName(), err) + job.Failed(wrapped) + return nil, wrapped } isStateful := false @@ -181,7 +191,9 @@ func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (*jo case urns.TransformTestStream: var testStream pipepb.TestStreamPayload if err := proto.Unmarshal(t.GetSpec().GetPayload(), &testStream); err != nil { - return nil, fmt.Errorf("unable to unmarshal TestStreamPayload for %v - %q: %w", tid, t.GetUniqueName(), err) + wrapped := fmt.Errorf("unable to unmarshal TestStreamPayload for %v - %q: %w", tid, t.GetUniqueName(), err) + job.Failed(wrapped) + return nil, wrapped } t.EnvironmentId = "" // Unset the environment, to ensure it's handled prism side. diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go index 647e9ad962830..320159f54c063 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go @@ -16,10 +16,14 @@ package jobservices import ( + "context" "fmt" "math" "net" + "os" "sync" + "sync/atomic" + "time" fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" @@ -39,9 +43,17 @@ type Server struct { // Job Management mu sync.Mutex - index uint32 + index uint32 // Use with atomics. jobs map[string]*Job + // IdleShutdown management. Needs to use atomics, since they + // may be both while already holding the lock, or when not + // (eg via job state). + idleTimer atomic.Pointer[time.Timer] + terminatedJobCount uint32 // Use with atomics. + idleTimeout time.Duration + cancelFn context.CancelCauseFunc + // execute defines how a job is executed. execute func(*Job) @@ -91,3 +103,42 @@ func (s *Server) Serve() { func (s *Server) Stop() { s.server.GracefulStop() } + +// IdleShutdown allows the server to call the cancelFn if there have been no active jobs +// for at least the given timeout. +func (s *Server) IdleShutdown(timeout time.Duration, cancelFn context.CancelCauseFunc) { + s.mu.Lock() + defer s.mu.Unlock() + s.idleTimeout = timeout + s.cancelFn = cancelFn + + // Stop gap to kill the process less gracefully. + if s.cancelFn == nil { + s.cancelFn = func(cause error) { + os.Exit(1) + } + } + + s.idleTimer.Store(time.AfterFunc(timeout, s.idleShutdownCallback)) +} + +// idleShutdownCallback is called by the AfterFunc timer for idle shutdown. +func (s *Server) idleShutdownCallback() { + index := atomic.LoadUint32(&s.index) + terminated := atomic.LoadUint32(&s.terminatedJobCount) + if index == terminated { + slog.Info("shutting down after being idle", "idleTimeout", s.idleTimeout) + s.cancelFn(nil) + } +} + +// jobTerminated marks that the job has been terminated, and if there are no active jobs, starts the idle timer. +func (s *Server) jobTerminated() { + if s.idleTimer.Load() != nil { + terminated := atomic.AddUint32(&s.terminatedJobCount, 1) + total := atomic.LoadUint32(&s.index) + if total == terminated { + s.idleTimer.Store(time.AfterFunc(s.idleTimeout, s.idleShutdownCallback)) + } + } +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/web/web.go b/sdks/go/pkg/beam/runners/prism/internal/web/web.go index 18a3140e033dc..9fabe22cee3a8 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/web/web.go +++ b/sdks/go/pkg/beam/runners/prism/internal/web/web.go @@ -24,8 +24,6 @@ import ( "embed" "encoding/json" "fmt" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "html/template" "io" "net/http" @@ -33,6 +31,9 @@ import ( "strings" "sync" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/metrics" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/metricsx" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/pipelinex" @@ -435,5 +436,6 @@ func Initialize(ctx context.Context, port int, jobcli jobpb.JobServiceClient) er endpoint := fmt.Sprintf("localhost:%d", port) slog.Info("Serving WebUI", slog.String("endpoint", "http://"+endpoint)) - return http.ListenAndServe(endpoint, mux) + go http.ListenAndServe(endpoint, mux) + return nil } diff --git a/sdks/go/pkg/beam/runners/prism/prism.go b/sdks/go/pkg/beam/runners/prism/prism.go index bcb7a3fb689f8..e260a7bb7ecd5 100644 --- a/sdks/go/pkg/beam/runners/prism/prism.go +++ b/sdks/go/pkg/beam/runners/prism/prism.go @@ -19,6 +19,7 @@ package prism import ( "context" + "time" "github.com/apache/beam/sdks/v2/go/pkg/beam" jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" @@ -58,13 +59,24 @@ func Execute(ctx context.Context, p *beam.Pipeline) (beam.PipelineResult, error) // Options for in process server creation. type Options struct { + // Port the Job Management Server should start on. Port int + + // The time prism will wait for new jobs before shuting itself down. + IdleShutdownTimeout time.Duration + // CancelFn allows Prism to terminate the program due to it's internal state, such as via the idle shutdown timeout. + // If unset, os.Exit(1) will be called instead. + CancelFn context.CancelCauseFunc } // CreateJobServer returns a Beam JobServicesClient connected to an in memory JobServer. // This call is non-blocking. func CreateJobServer(ctx context.Context, opts Options) (jobpb.JobServiceClient, error) { s := jobservices.NewServer(opts.Port, internal.RunPipeline) + + if opts.IdleShutdownTimeout > 0 { + s.IdleShutdown(opts.IdleShutdownTimeout, opts.CancelFn) + } go s.Serve() clientConn, err := grpc.DialContext(ctx, s.Endpoint(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) if err != nil {