Skip to content

Commit

Permalink
Merge pull request #650 from thedadams/sdk-server-log-port
Browse files Browse the repository at this point in the history
feat: improve SDK server start up
  • Loading branch information
ibuildthecloud authored Jul 22, 2024
2 parents 3c29ebe + 3c53e2a commit e519494
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 24 deletions.
4 changes: 2 additions & 2 deletions pkg/cli/sdk_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ func (c *SDKServer) Run(cmd *cobra.Command, _ []string) error {
// Don't use cmd.Context() as we don't want to die on ctrl+c
ctx := context.Background()
if term.IsTerminal(int(os.Stdin.Fd())) {
// Only support CTRL+C if stdin is the terminal. When ran as a SDK it will be a pipe
// Only support CTRL+C if stdin is the terminal. When ran as an SDK it will be a pipe
ctx = cmd.Context()
}

return sdkserver.Start(ctx, sdkserver.Options{
return sdkserver.Run(ctx, sdkserver.Options{
Options: opts,
ListenAddress: c.ListenAddress,
Debug: c.Debug,
Expand Down
74 changes: 52 additions & 22 deletions pkg/sdkserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"os"
Expand All @@ -29,7 +28,18 @@ type Options struct {
Debug bool
}

func Start(ctx context.Context, opts Options) error {
// Run will start the server and block until the server is shut down.
func Run(ctx context.Context, opts Options) error {
listener, err := newListener(opts)
if err != nil {
return err
}

_, err = io.WriteString(os.Stderr, listener.Addr().String()+"\n")
if err != nil {
return fmt.Errorf("failed to write to address to stderr: %w", err)
}

sigCtx, cancel := signal.NotifyContext(ctx, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGKILL)
defer cancel()
go func() {
Expand All @@ -40,6 +50,34 @@ func Start(ctx context.Context, opts Options) error {
cancel()
}()

return run(sigCtx, listener, opts)
}

// EmbeddedStart allows running the server as an embedded process that may use Stdin for input.
// It returns the address the server is listening on.
func EmbeddedStart(ctx context.Context, opts Options) (string, error) {
listener, err := newListener(opts)
if err != nil {
return "", err
}

go func() {
_ = run(ctx, listener, opts)
}()

return listener.Addr().String(), nil
}

func (s *server) close() {
s.client.Close(true)
s.events.Close()
}

func newListener(opts Options) (net.Listener, error) {
return net.Listen("tcp", opts.ListenAddress)
}

func run(ctx context.Context, listener net.Listener, opts Options) error {
if opts.Debug {
mvl.SetDebug()
}
Expand All @@ -58,11 +96,6 @@ func Start(ctx context.Context, opts Options) error {
return err
}

listener, err := net.Listen("tcp", opts.ListenAddress)
if err != nil {
return fmt.Errorf("failed to listen on %s: %w", opts.ListenAddress, err)
}

s := &server{
gptscriptOpts: opts.Options,
address: listener.Addr().String(),
Expand All @@ -72,11 +105,11 @@ func Start(ctx context.Context, opts Options) error {
waitingToConfirm: make(map[string]chan runner.AuthorizerResponse),
waitingToPrompt: make(map[string]chan map[string]string),
}
defer s.Close()
defer s.close()

s.addRoutes(http.DefaultServeMux)

server := http.Server{
httpServer := &http.Server{
Handler: apply(http.DefaultServeMux,
contentType("application/json"),
addRequestID,
Expand All @@ -86,25 +119,22 @@ func Start(ctx context.Context, opts Options) error {
),
}

slog.Info("Starting server", "addr", s.address)

context.AfterFunc(sigCtx, func() {
ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
logger := mvl.Package()
done := make(chan struct{})
context.AfterFunc(ctx, func() {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()

slog.Info("Shutting down server")
_ = server.Shutdown(ctx)
slog.Info("Server stopped")
logger.Infof("Shutting down server")
_ = httpServer.Shutdown(ctx)
logger.Infof("Server stopped")
close(done)
})

if err := server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
if err = httpServer.Serve(listener); !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("server error: %w", err)
}

<-done
return nil
}

func (s *server) Close() {
s.client.Close(true)
s.events.Close()
}

0 comments on commit e519494

Please sign in to comment.