diff --git a/deploy/piv-agent.service b/deploy/piv-agent.service index 4a644d3..2ebe98b 100644 --- a/deploy/piv-agent.service +++ b/deploy/piv-agent.service @@ -2,4 +2,4 @@ Description=piv-agent service [Service] -ExecStart=piv-agent serve --debug --agent-types=ssh=0;gpg=1 +ExecStart=piv-agent serve --agent-types=ssh=0;gpg=1 diff --git a/internal/assuan/assuan_test.go b/internal/assuan/assuan_test.go index 31aee80..8c9e633 100644 --- a/internal/assuan/assuan_test.go +++ b/internal/assuan/assuan_test.go @@ -2,6 +2,7 @@ package assuan_test import ( "bytes" + "context" "crypto" "crypto/ecdsa" "encoding/hex" @@ -153,7 +154,7 @@ func TestSign(t *testing.T) { } } // start the state machine - if err := a.Run(); err != nil { + if err := a.Run(context.Background()); err != nil { tt.Fatal(err) } // check the responses @@ -226,7 +227,7 @@ func TestKeyinfo(t *testing.T) { } } // start the state machine - if err := a.Run(); err != nil { + if err := a.Run(context.Background()); err != nil { tt.Fatal(err) } // check the responses @@ -349,7 +350,7 @@ func TestDecryptRSAKeyfile(t *testing.T) { } } // start the state machine - if err := a.Run(); err != nil { + if err := a.Run(context.Background()); err != nil { tt.Fatal(err) } // check the responses @@ -447,7 +448,7 @@ func TestSignRSAKeyfile(t *testing.T) { } } // start the state machine - if err := a.Run(); err != nil { + if err := a.Run(context.Background()); err != nil { tt.Fatal(err) } // check the responses @@ -533,7 +534,7 @@ func TestReadKey(t *testing.T) { } } // start the state machine - if err := a.Run(); err != nil { + if err := a.Run(context.Background()); err != nil { tt.Fatal(err) } // check the responses diff --git a/internal/assuan/run.go b/internal/assuan/run.go index dbdfe65..40c00c8 100644 --- a/internal/assuan/run.go +++ b/internal/assuan/run.go @@ -2,18 +2,24 @@ package assuan import ( "bytes" + "context" "fmt" "io" ) // Run the event machine loop -func (a *Assuan) Run() error { +func (a *Assuan) Run(ctx context.Context) error { // register connection if err := a.Occur(connect); err != nil { return fmt.Errorf("error handling connect: %w", err) } var e Event for { + // check for cancellation + if err := ctx.Err(); err != nil { + return err + } + // get the next command. returns at latest after conn deadline expiry. line, err := a.reader.ReadBytes(byte('\n')) if err != nil { if err == io.EOF { diff --git a/internal/server/gpg.go b/internal/server/gpg.go index 9cde7ca..cf7912b 100644 --- a/internal/server/gpg.go +++ b/internal/server/gpg.go @@ -12,6 +12,8 @@ import ( "go.uber.org/zap" ) +const connTimeout = 4 * time.Minute + // GPG represents a gpg-agent server. type GPG struct { log *zap.Logger @@ -39,26 +41,31 @@ func (g *GPG) Serve(ctx context.Context, l net.Listener, exit *time.Ticker, timeout time.Duration) error { // start serving connections conns := accept(g.log, l) - g.log.Debug("accepted gpg-agent connection") for { select { case conn, ok := <-conns: if !ok { return fmt.Errorf("listen socket closed") } + g.log.Debug("accepted gpg-agent connection") // reset the exit timer exit.Reset(timeout) - // if the client stops responding for 300 seconds, give up. - if err := conn.SetDeadline(time.Now().Add(300 * time.Second)); err != nil { + // if the client takes too long, give up + if err := conn.SetDeadline(time.Now().Add(connTimeout)); err != nil { return fmt.Errorf("couldn't set deadline: %v", err) } // init protocol state machine a := assuan.New(conn, g.log, g.pivKeyService, g.gpgKeyService) - // run the protocol state machine to completion - // (client severs connection) - if err := a.Run(); err != nil { - return err - } + // this goroutine will exit by either: + // * client severs connection (the usual case) + // * conn deadline reached (client stopped responding) + // err will be non-nil in this case. + go func() { + // run the protocol state machine to completion + if err := a.Run(ctx); err != nil { + g.log.Error("gpg-agent error", zap.Error(err)) + } + }() case <-ctx.Done(): return nil } diff --git a/internal/server/ssh.go b/internal/server/ssh.go index eb7af47..38360ab 100644 --- a/internal/server/ssh.go +++ b/internal/server/ssh.go @@ -42,7 +42,7 @@ func (s *SSH) Serve(ctx context.Context, a *ssh.Agent, l net.Listener, s.log.Debug("start serving SSH connection") if err := agent.ServeAgent(a, conn); err != nil { if errors.Is(err, io.EOF) { - s.log.Debug("finish serving connection") + s.log.Debug("finish serving SSH connection") continue } return fmt.Errorf("ssh Serve error: %w", err)