diff --git a/sdks/go/pkg/beam/log/log.go b/sdks/go/pkg/beam/log/log.go
index feae77b6c971..4c1f5dddb018 100644
--- a/sdks/go/pkg/beam/log/log.go
+++ b/sdks/go/pkg/beam/log/log.go
@@ -22,6 +22,7 @@ import (
 	"context"
 	"fmt"
 	"os"
+	"sync/atomic"
 )
 
 // Severity is the severity of the log message.
@@ -44,9 +45,17 @@ type Logger interface {
 	Log(ctx context.Context, sev Severity, calldepth int, msg string)
 }
 
-var (
-	logger Logger = &Standard{}
-)
+var logger atomic.Value
+
+// concreteLogger works around atomic.Value's requirement that the type
+// be identical for all callers.
+type concreteLogger struct {
+	Logger
+}
+
+func init() {
+	logger.Store(&concreteLogger{&Standard{}})
+}
 
 // SetLogger sets the global Logger. Intended to be called during initialization
 // only.
@@ -54,13 +63,13 @@ func SetLogger(l Logger) {
 	if l == nil {
 		panic("Logger cannot be nil")
 	}
-	logger = l
+	logger.Store(&concreteLogger{l})
 }
 
 // Output logs the given message to the global logger. Calldepth is the count
 // of the number of frames to skip when computing the file name and line number.
 func Output(ctx context.Context, sev Severity, calldepth int, msg string) {
-	logger.Log(ctx, sev, calldepth+1, msg) // +1 for this frame
+	logger.Load().(Logger).Log(ctx, sev, calldepth+1, msg) // +1 for this frame
 }
 
 // User-facing logging functions.
diff --git a/sdks/go/pkg/beam/runners/universal/extworker/extworker.go b/sdks/go/pkg/beam/runners/universal/extworker/extworker.go
index dc75c7c8ca5b..ffc8f8e47c09 100644
--- a/sdks/go/pkg/beam/runners/universal/extworker/extworker.go
+++ b/sdks/go/pkg/beam/runners/universal/extworker/extworker.go
@@ -65,6 +65,12 @@ func (s *Loopback) StartWorker(ctx context.Context, req *fnpb.StartWorkerRequest
 	log.Infof(ctx, "starting worker %v", req.GetWorkerId())
 	s.mu.Lock()
 	defer s.mu.Unlock()
+	if s.workers == nil {
+		return &fnpb.StartWorkerResponse{
+			Error: "worker pool shutting down",
+		}, nil
+	}
+
 	if _, ok := s.workers[req.GetWorkerId()]; ok {
 		return &fnpb.StartWorkerResponse{
 			Error: fmt.Sprintf("worker with ID %q already exists", req.GetWorkerId()),
@@ -92,6 +98,10 @@ func (s *Loopback) StopWorker(ctx context.Context, req *fnpb.StopWorkerRequest)
 	log.Infof(ctx, "stopping worker %v", req.GetWorkerId())
 	s.mu.Lock()
 	defer s.mu.Unlock()
+	if s.workers == nil {
+		// Worker pool is already shutting down, so no action is needed.
+		return &fnpb.StopWorkerResponse{}, nil
+	}
 	if cancelfn, ok := s.workers[req.GetWorkerId()]; ok {
 		cancelfn()
 		delete(s.workers, req.GetWorkerId())
@@ -106,12 +116,15 @@ func (s *Loopback) StopWorker(ctx context.Context, req *fnpb.StopWorkerRequest)
 // Stop terminates the service and stops all workers.
 func (s *Loopback) Stop(ctx context.Context) error {
 	s.mu.Lock()
-	defer s.mu.Unlock()
 
 	log.Infof(ctx, "stopping Loopback, and %d workers", len(s.workers))
-	s.workers = map[string]context.CancelFunc{}
-	s.lis.Close()
+	s.workers = nil
 	s.rootCancel()
+
+	// There can be a deadlock between the StopWorker RPC and GracefulStop
+	// which waits for all RPCs to finish, so it must be outside the critical section.
+	s.mu.Unlock()
+
 	s.grpcServer.GracefulStop()
 	return nil
 }
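
A note on the atomic.Value wrapper in log.go, with a standalone sketch (the trimmed-down Logger interface and the stdout/JSON loggers below are hypothetical stand-ins, not the Beam API): atomic.Value.Store panics if it is ever handed a value whose concrete type differs from the one stored previously, which is exactly what would happen when SetLogger swaps in a different Logger implementation. Wrapping every value in *concreteLogger keeps the stored type constant.

```go
package main

import (
	"context"
	"fmt"
	"sync/atomic"
)

// Logger is a trimmed-down stand-in for the beam log.Logger interface.
type Logger interface {
	Log(ctx context.Context, msg string)
}

type stdoutLogger struct{}

func (stdoutLogger) Log(_ context.Context, msg string) { fmt.Println("stdout:", msg) }

type jsonLogger struct{}

func (jsonLogger) Log(_ context.Context, msg string) { fmt.Printf("{\"msg\":%q}\n", msg) }

// concreteLogger mirrors the wrapper in the patch: every Store receives the
// same concrete type (*concreteLogger), no matter which Logger it wraps.
type concreteLogger struct {
	Logger
}

func main() {
	var v atomic.Value

	// Storing two different concrete types directly would panic:
	//   v.Store(stdoutLogger{})
	//   v.Store(jsonLogger{}) // panic: store of inconsistently typed value into Value
	v.Store(&concreteLogger{stdoutLogger{}})
	v.Store(&concreteLogger{jsonLogger{}}) // fine: same concrete type both times

	v.Load().(Logger).Log(context.Background(), "hello")
}
```

Because SetLogger always stores a *concreteLogger, swapping implementations at runtime never trips that panic, and Output only pays an atomic load per call.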
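
The Stop() reordering in extworker.go may be easier to see outside gRPC. The sketch below is an analogy only, using net/http's Shutdown in place of grpc.Server.GracefulStop and a hypothetical pool type in place of Loopback: the graceful shutdown waits for in-flight handlers, the handlers wait for the mutex, so holding the mutex across the shutdown call deadlocks, and nil-ing the workers map under the lock first lets late StopWorker calls return cleanly.

```go
package main

import (
	"context"
	"fmt"
	"net"
	"net/http"
	"sync"
	"time"
)

// pool mimics Loopback: a mutex-guarded worker map plus a server whose
// graceful shutdown, like grpc.Server.GracefulStop, waits for in-flight calls.
type pool struct {
	mu      sync.Mutex
	workers map[string]context.CancelFunc
	srv     *http.Server
}

func (p *pool) stopWorker(w http.ResponseWriter, r *http.Request) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.workers == nil {
		fmt.Fprintln(w, "pool already shutting down") // mirrors the nil-map guard in the patch
		return
	}
	id := r.URL.Query().Get("id")
	if cancel, ok := p.workers[id]; ok {
		cancel()
		delete(p.workers, id)
	}
	fmt.Fprintln(w, "stopped", id)
}

func (p *pool) stop(ctx context.Context) error {
	p.mu.Lock()
	p.workers = nil
	// Unlock before the graceful shutdown: Shutdown blocks until in-flight
	// stopWorker handlers return, and those handlers block on p.mu, so
	// holding the lock here would deadlock. This is the same reason the
	// patch moves s.mu.Unlock() ahead of s.grpcServer.GracefulStop().
	p.mu.Unlock()
	return p.srv.Shutdown(ctx)
}

func main() {
	_, cancel := context.WithCancel(context.Background())
	p := &pool{workers: map[string]context.CancelFunc{"w1": cancel}}

	mux := http.NewServeMux()
	mux.HandleFunc("/stop_worker", p.stopWorker)
	p.srv = &http.Server{Handler: mux}

	ln, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		panic(err)
	}
	go p.srv.Serve(ln)

	// Stop one worker over the wire, then shut the pool down.
	if resp, err := http.Get("http://" + ln.Addr().String() + "/stop_worker?id=w1"); err == nil {
		resp.Body.Close()
	}

	shutdownCtx, cancelShutdown := context.WithTimeout(context.Background(), time.Second)
	defer cancelShutdown()
	fmt.Println("shutdown:", p.stop(shutdownCtx))
}
```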