Skip to content

Commit

Permalink
coap: allow to set inactivity monitor (#267)
Browse files Browse the repository at this point in the history
  • Loading branch information
jkralik authored Jun 9, 2022
1 parent b45ec1b commit e287cc7
Showing 1 changed file with 39 additions and 31 deletions.
70 changes: 39 additions & 31 deletions pkg/net/coap/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import (
codecOcf "github.com/plgd-dev/kit/v2/codec/ocf"
)

var errInactivityTimeout = fmt.Errorf("connection inactivity has reached a fail limit: closing connection")

type Observation = interface {
Cancel(context.Context) error
Canceled() bool
Expand Down Expand Up @@ -392,7 +394,7 @@ func (c *ClientCloseHandler) UnregisterCloseHandler(closeHandlerID int) {
c.onClose.Remove(closeHandlerID)
}

func newClientCloseHandler(conn ClientConn, onClose *OnCloseHandler) *ClientCloseHandler {
func NewClientCloseHandler(conn ClientConn, onClose *OnCloseHandler) *ClientCloseHandler {
return &ClientCloseHandler{Client: NewClient(conn), onClose: onClose}
}

Expand All @@ -406,6 +408,7 @@ type dialOptions struct {
DisableTCPSignalMessageCSM bool
DisablePeerTCPSignalMessageCSMs bool
KeepaliveTimeout time.Duration
InactivityMonitorTimeout time.Duration
errors func(err error)
maxMessageSize uint32
dialer *net.Dialer
Expand Down Expand Up @@ -440,6 +443,14 @@ func WithKeepAlive(connectionTimeout time.Duration) DialOptionFunc {
}
}

// InactiveMonitor if connection is inactive for the given duration, it will be closed.
func WithInactivityMonitor(inactivityTimeout time.Duration) DialOptionFunc {
return func(c dialOptions) dialOptions {
c.InactivityMonitorTimeout = inactivityTimeout
return c
}
}

func WithErrors(errors func(err error)) DialOptionFunc {
return func(c dialOptions) dialOptions {
c.errors = errors
Expand Down Expand Up @@ -472,8 +483,13 @@ func WithBlockwise(enable bool, szx blockwise.SZX, transferTimeout time.Duration
}
}

func keepAliveTimeoutError() error {
return fmt.Errorf("keep alive has reached fail limit: closing connection")
func makeOnInactiveFunc(dialName string, errorsFn func(err error)) func(cc inactivity.ClientConn) {
return func(cc inactivity.ClientConn) {
if err := cc.Close(); err != nil {
errorsFn(fmt.Errorf("%v: %w", dialName, err))
}
errorsFn(errInactivityTimeout)
}
}

func DialUDP(ctx context.Context, addr string, opts ...DialOptionFunc) (*ClientCloseHandler, error) {
Expand All @@ -491,12 +507,10 @@ func DialUDP(ctx context.Context, addr string, opts ...DialOptionFunc) (*ClientC
dopts = append(dopts, udp.WithErrors(cfg.errors))
}
if cfg.KeepaliveTimeout != 0 {
dopts = append(dopts, udp.WithKeepAlive(3, cfg.KeepaliveTimeout/3, func(cc inactivity.ClientConn) {
if err := cc.Close(); err != nil {
errorsFn(fmt.Errorf("DialUDP: %w", err))
}
errorsFn(keepAliveTimeoutError())
}))
dopts = append(dopts, udp.WithKeepAlive(3, cfg.KeepaliveTimeout/3, makeOnInactiveFunc("DialUDP", errorsFn)))
}
if cfg.InactivityMonitorTimeout != 0 {
dopts = append(dopts, udp.WithInactivityMonitor(cfg.InactivityMonitorTimeout, makeOnInactiveFunc("DialUDP", errorsFn)))
}
if cfg.blockwise != nil {
dopts = append(dopts, udp.WithBlockwise(cfg.blockwise.enable, cfg.blockwise.szx, cfg.blockwise.transferTimeout))
Expand All @@ -522,7 +536,7 @@ func DialUDP(ctx context.Context, addr string, opts ...DialOptionFunc) (*ClientC
c.AddOnClose(func() {
h.OnClose(nil)
})
return newClientCloseHandler(c.Client(), h), nil
return NewClientCloseHandler(c.Client(), h), nil
}

func DialTCP(ctx context.Context, addr string, opts ...DialOptionFunc) (*ClientCloseHandler, error) {
Expand All @@ -540,12 +554,10 @@ func DialTCP(ctx context.Context, addr string, opts ...DialOptionFunc) (*ClientC
dopts = append(dopts, tcp.WithErrors(cfg.errors))
}
if cfg.KeepaliveTimeout != 0 {
dopts = append(dopts, tcp.WithKeepAlive(3, cfg.KeepaliveTimeout/3, func(cc inactivity.ClientConn) {
if err := cc.Close(); err != nil {
errorsFn(fmt.Errorf("DialTCP: %w", err))
}
errorsFn(keepAliveTimeoutError())
}))
dopts = append(dopts, tcp.WithKeepAlive(3, cfg.KeepaliveTimeout/3, makeOnInactiveFunc("DialTCP", errorsFn)))
}
if cfg.InactivityMonitorTimeout != 0 {
dopts = append(dopts, tcp.WithInactivityMonitor(cfg.InactivityMonitorTimeout, makeOnInactiveFunc("DialTCP", errorsFn)))
}
if cfg.DisablePeerTCPSignalMessageCSMs {
dopts = append(dopts, tcp.WithDisablePeerTCPSignalMessageCSMs())
Expand Down Expand Up @@ -577,7 +589,7 @@ func DialTCP(ctx context.Context, addr string, opts ...DialOptionFunc) (*ClientC
c.AddOnClose(func() {
h.OnClose(nil)
})
return newClientCloseHandler(c.Client(), h), nil
return NewClientCloseHandler(c.Client(), h), nil
}

func NewVerifyPeerCertificate(rootCAs *x509.CertPool, verifyPeerCertificate func(verifyPeerCertificate *x509.Certificate) error) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
Expand Down Expand Up @@ -632,12 +644,10 @@ func DialTCPSecure(ctx context.Context, addr string, tlsCfg *tls.Config, opts ..
dopts = append(dopts, tcp.WithErrors(cfg.errors))
}
if cfg.KeepaliveTimeout != 0 {
dopts = append(dopts, tcp.WithKeepAlive(3, cfg.KeepaliveTimeout/3, func(cc inactivity.ClientConn) {
if err := cc.Close(); err != nil {
errorsFn(fmt.Errorf("DialTCPSecure: %w", err))
}
errorsFn(keepAliveTimeoutError())
}))
dopts = append(dopts, tcp.WithKeepAlive(3, cfg.KeepaliveTimeout/3, makeOnInactiveFunc("DialTCPSecure", errorsFn)))
}
if cfg.InactivityMonitorTimeout != 0 {
dopts = append(dopts, tcp.WithInactivityMonitor(cfg.InactivityMonitorTimeout, makeOnInactiveFunc("DialTCPSecure", errorsFn)))
}
if cfg.DisablePeerTCPSignalMessageCSMs {
dopts = append(dopts, tcp.WithDisablePeerTCPSignalMessageCSMs())
Expand Down Expand Up @@ -669,7 +679,7 @@ func DialTCPSecure(ctx context.Context, addr string, tlsCfg *tls.Config, opts ..
c.AddOnClose(func() {
h.OnClose(nil)
})
return newClientCloseHandler(c.Client(), h), nil
return NewClientCloseHandler(c.Client(), h), nil
}

func DialUDPSecure(ctx context.Context, addr string, dtlsCfg *piondtls.Config, opts ...DialOptionFunc) (*ClientCloseHandler, error) {
Expand All @@ -694,12 +704,10 @@ func DialUDPSecure(ctx context.Context, addr string, dtlsCfg *piondtls.Config, o
dopts = append(dopts, dtls.WithErrors(cfg.errors))
}
if cfg.KeepaliveTimeout != 0 {
dopts = append(dopts, dtls.WithKeepAlive(3, cfg.KeepaliveTimeout/3, func(cc inactivity.ClientConn) {
if err := cc.Close(); err != nil {
errorsFn(fmt.Errorf("DialUDPSecure: %w", err))
}
errorsFn(keepAliveTimeoutError())
}))
dopts = append(dopts, dtls.WithKeepAlive(3, cfg.KeepaliveTimeout/3, makeOnInactiveFunc("DialUDPSecure", errorsFn)))
}
if cfg.InactivityMonitorTimeout != 0 {
dopts = append(dopts, dtls.WithInactivityMonitor(cfg.InactivityMonitorTimeout, makeOnInactiveFunc("DialUDPSecure", errorsFn)))
}
if cfg.blockwise != nil {
dopts = append(dopts, dtls.WithBlockwise(cfg.blockwise.enable, cfg.blockwise.szx, cfg.blockwise.transferTimeout))
Expand All @@ -725,5 +733,5 @@ func DialUDPSecure(ctx context.Context, addr string, dtlsCfg *piondtls.Config, o
c.AddOnClose(func() {
h.OnClose(nil)
})
return newClientCloseHandler(c.Client(), h), nil
return NewClientCloseHandler(c.Client(), h), nil
}

0 comments on commit e287cc7

Please sign in to comment.