Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detect and handle vici client connection errors #26

Merged
merged 1 commit into from
Mar 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 22 additions & 23 deletions internal/vici/clientConn.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,31 @@ func NewClientConn(conn net.Conn) *ClientConn {
eventHandlers: map[string]func(response map[string]interface{}){},
ReadTimeout: DefaultReadTimeout,
}
go client.readThread()
return client
}

// Listen listens for data on configured net.Conn. This method is blocking until
// ClientConn.Close() is called or an unrecoverable error occours.
func (c *ClientConn) Listen() error {
for {
outMsg, err := readSegment(c.conn)
if err != nil {
return fmt.Errorf("vici: read segment: %w", err)
}
switch outMsg.typ {
case stCMD_RESPONSE, stEVENT_CONFIRM:
c.responseChan <- outMsg
case stEVENT:
handler := c.eventHandlers[outMsg.name]
if handler != nil {
handler(outMsg.msg)
}
default:
return fmt.Errorf("vici: unprocessable message type '%s': raw message: %+v", outMsg.typ, outMsg)
}
}
}

func (c *ClientConn) Request(apiname string, concretePayload interface{}) (map[string]interface{}, error) {
var request map[string]interface{}
if concretePayload != nil {
Expand Down Expand Up @@ -125,25 +146,3 @@ func (c *ClientConn) UnregisterEvent(name string) error {
delete(c.eventHandlers, name)
return nil
}

func (c *ClientConn) readThread() {
for {
outMsg, err := readSegment(c.conn)
if err != nil {
c.lastError = err
return
}
switch outMsg.typ {
case stCMD_RESPONSE, stEVENT_CONFIRM:
c.responseChan <- outMsg
case stEVENT:
handler := c.eventHandlers[outMsg.name]
if handler != nil {
handler(outMsg.msg)
}
default:
c.lastError = fmt.Errorf("[Client.readThread] unknow msg type %d", outMsg.typ)
return
}
}
}
98 changes: 98 additions & 0 deletions internal/vici/clientConn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package vici

import (
"errors"
"io"
"net"
"sync"
"testing"

"github.com/stretchr/testify/require"
)

// TestClientConn_Close_listen tests that Listen returns when Close is called.
func TestClientConn_Close_listen(t *testing.T) {
conn, _ := net.Pipe()
defer conn.Close()

client := NewClientConn(conn)

var wg sync.WaitGroup
defer wg.Wait()

// wg will never be released if Listen does not terminated leading to a test
// timeout.
wg.Add(1)
go func() {
defer wg.Done()
err := client.Listen()
t.Logf("Listen err: %v", err)
}()

err := client.Close()
require.NoError(t, err, "expected a Listen error")
}

// TestClientConn_Listen_connClosed tests that Listen returns an error if the
// net.Conn is closed on.
func TestClientConn_Listen_connClosed(t *testing.T) {
// create a net.Conn that is closed right away to simulate a failed network
// connection.
conn, _ := net.Pipe()
conn.Close()

client := NewClientConn(conn)
defer client.Close()

err := client.Listen()

t.Logf("Err: %v", err)

require.Error(t, err, "expected a Listen error")

if !errors.Is(err, io.ErrClosedPipe) {
t.Fatalf("Error expected to be io.ErrClosedPipe but was: %v", err)
}
}

// TestClientConn_Listen_unhandleableSegmentType tests that Listen returns an
// error if un unprocessable segment type is received.
func TestClientConn_Listen_unprocessableSegmentType(t *testing.T) {
viciConn, clientConn := net.Pipe()
defer clientConn.Close()

var wg sync.WaitGroup
defer wg.Wait()

// write EVENT_UNKNOWN response from vici daemon to fake un unprocessable
// payload
wg.Add(1)
go func() {
defer wg.Done()
defer viciConn.Close()

message := []byte{
0x0, 0x0, 0x0, 0x1, // length 1 (single byte)
byte(stEVENT_UNKNOWN), // unprocessable but valid segment type
}
_, err := viciConn.Write(message)
if err != nil {
t.Logf("Failed to write vici message: %v", err)
return
}
t.Logf("Wrote message to vici conn: %v", message)
}()

client := NewClientConn(clientConn)
defer client.Close()

err := client.Listen()

t.Logf("Listen err: %v", err)

require.Error(t, err, "Listen() didn't return error")

if errors.Is(err, io.EOF) {
t.Fatalf("Listen() returned io.EOF: this happens on net.Conn.Close() so the listener did not stop in time")
}
}
23 changes: 23 additions & 0 deletions internal/vici/msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,29 @@ const (
stEVENT segmentType = 7
)

func (s segmentType) String() string {
switch s {
case stCMD_REQUEST:
return "CMD_REQUEST"
case stCMD_RESPONSE:
return "CMD_RESPONSE"
case stCMD_UNKNOWN:
return "CMD_UNKNOWN"
case stEVENT_REGISTER:
return "EVENT_REGISTER"
case stEVENT_UNREGISTER:
return "EVENT_UNREGISTER"
case stEVENT_CONFIRM:
return "EVENT_CONFIRM"
case stEVENT_UNKNOWN:
return "EVENT_UNKNOWN"
case stEVENT:
return "EVENT"
default:
return fmt.Sprintf("unknown segment type: %d", s)
}
}

func (t segmentType) hasName() bool {
switch t {
case stCMD_REQUEST, stEVENT_REGISTER, stEVENT_UNREGISTER, stEVENT:
Expand Down
64 changes: 48 additions & 16 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,28 +151,14 @@ func main() {
}

if *enableReinitiator {
reinitiatorConn, err := net.Dial("unix", *socket)
if err != nil {
log.Errorf("Failed to establish socket connection to vici: %v", err)
os.Exit(1)
}
defer reinitiatorConn.Close()
reinitiatorClient := vici.NewClientConn(reinitiatorConn)
reinitiatorClient := viciClient(&shutdownWg, shutdown, componentDone, log.With("viciClient", "reinitiator"), *socket)
reinitiatorClient.ReadTimeout = 5 * time.Minute
defer reinitiatorClient.Close()

ikeSAStatusReceivers = append(ikeSAStatusReceivers, strongswan.NewReinitiator(reinitiatorClient, log.Base().With("name", "reinitiator")))
}

conn, err := net.Dial("unix", *socket)
if err != nil {
log.Errorf("Failed to establish socket connection to vici: %v", err)
os.Exit(1)
}
defer conn.Close()
client := vici.NewClientConn(conn)
client := viciClient(&shutdownWg, shutdown, componentDone, log.With("viciClient", "collector"), *socket)
client.ReadTimeout = 60 * time.Second
defer client.Close()

d := daemon.New(daemon.Configuration{
Reporter: prometheusReporter.Daemon(log.Base().With("name", "strongswan"), "strongswan"),
Expand Down Expand Up @@ -206,3 +192,49 @@ func main() {
log.Info("exited due to a component shutting down")
}
}

// viciClient returns a listening vici.ClientConn controlled by provided life
// cycle channels.
func viciClient(shutdownWg *sync.WaitGroup, shutdown chan struct{}, componentDone chan error, log log.Logger, socket string) *vici.ClientConn {
conn, err := net.Dial("unix", socket)
if err != nil {
log.Errorf("Failed to establish socket connection to vici on '%s': %v", socket, err)
os.Exit(1)
}
client := vici.NewClientConn(conn)

shutdownWg.Add(1)
go func() {
defer shutdownWg.Done()
log.Info("vici client shutdown listener started")
defer log.Info("vici client shutdown listener stopped")
<-shutdown

log.Info("Closing vici client listener")
err := client.Close()
if err != nil {
log.Errorf("Controlled close of vici client failed: %v", err)
}
}()

shutdownWg.Add(1)
go func() {
defer shutdownWg.Done()
log.Infof("vici client listening on %s", socket)
defer log.Info("vici client lister Go routine stopped")
err := client.Listen()
if err != nil {
// we don't know if Listen stopped due to a controlled shutdown or due
// to an underlying error. Log the error in the former case or report
// the component done if the shutdown is unexpected
select {
case componentDone <- fmt.Errorf("vici client listener stopped unexpectedly: %w", err):
return
default:
log.Infof("vici client listener stopped: %v", err)
}
}
}()

return client
}