Skip to content

Commit

Permalink
Merge pull request #85 from inexio/pre-release
Browse files Browse the repository at this point in the history
Pre release
  • Loading branch information
babos77 authored Oct 19, 2021
2 parents 5b26088 + 1b6bf96 commit 5a95c9a
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 49 deletions.
64 changes: 37 additions & 27 deletions api/request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/inexio/thola/internal/tholaerr"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
"github.com/spf13/viper"
"net/http"
Expand All @@ -19,10 +20,10 @@ import (
"time"
)

var deviceLocks struct {
var deviceChannels struct {
sync.RWMutex

locks map[string]*sync.Mutex
channels map[string]chan struct{}
}

// StartAPI starts the API.
Expand All @@ -36,7 +37,7 @@ func StartAPI() {
log.Fatal().Err(err).Msg("starting the server failed")
}

deviceLocks.locks = make(map[string]*sync.Mutex)
deviceChannels.channels = make(map[string]chan struct{})
e := echo.New()

e.HideBanner = true
Expand Down Expand Up @@ -988,35 +989,44 @@ func returnInFormat(ctx echo.Context, statusCode int, resp interface{}) error {
return ctx.String(http.StatusInternalServerError, "Invalid output format set")
}

func getDeviceLock(ip string) *sync.Mutex {
deviceLocks.RLock()
lock, ok := deviceLocks.locks[ip]
deviceLocks.RUnlock()
if !ok {
deviceLocks.Lock()
if lock, ok = deviceLocks.locks[ip]; !ok {
lock = &sync.Mutex{}
deviceLocks.locks[ip] = lock
}
deviceLocks.Unlock()
}
return lock
}

func handleAPIRequest(echoCTX echo.Context, r request.Request, ip *string) (request.Response, error) {
logger := log.With().Str("request_id", echoCTX.Request().Header.Get(echo.HeaderXRequestID)).Logger()
ctx := logger.WithContext(context.Background())
log.Ctx(ctx).Debug().Msg("incoming request")

if ip != nil && !viper.GetBool("request.no-ip-lock") {
lock := getDeviceLock(*ip)
lock.Lock()
defer func() {
lock.Unlock()
log.Ctx(ctx).Debug().Msg("unlocked IP " + *ip)
}()

log.Ctx(ctx).Debug().Msg("locked IP " + *ip)
ctx, cancel := request.CheckForTimeout(ctx, r)
defer cancel()

ch := getDeviceChannel(*ip)
select {
case <-ctx.Done():
return r.HandlePreProcessError(errors.New("request timed out while waiting on the IP lock"))
case <-ch:
log.Ctx(ctx).Debug().Msgf("locked IP '%s'", *ip)
defer func() {
ch <- struct{}{}
log.Ctx(ctx).Debug().Msgf("unlocked IP '%s'", *ip)
}()
return request.ProcessRequest(ctx, r)
}
} else {
return request.ProcessRequest(ctx, r)
}
}

return request.ProcessRequest(ctx, r)
func getDeviceChannel(ip string) chan struct{} {
deviceChannels.RLock()
ch, ok := deviceChannels.channels[ip]
deviceChannels.RUnlock()
if !ok {
deviceChannels.Lock()
if ch, ok = deviceChannels.channels[ip]; !ok {
ch = make(chan struct{}, 1)
ch <- struct{}{}
deviceChannels.channels[ip] = ch
}
deviceChannels.Unlock()
}
return ch
}
2 changes: 1 addition & 1 deletion internal/request/base_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ func (r *BaseRequest) getTimeout() *int {
return r.Timeout
}

func (r *BaseRequest) handlePreProcessError(err error) (Response, error) {
func (r *BaseRequest) HandlePreProcessError(err error) (Response, error) {
return nil, err
}

Expand Down
4 changes: 2 additions & 2 deletions internal/request/check_device_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ type CheckDeviceRequest struct {
CheckRequest
}

func (r *CheckDeviceRequest) handlePreProcessError(err error) (Response, error) {
return r.CheckRequest.handlePreProcessError(err)
func (r *CheckDeviceRequest) HandlePreProcessError(err error) (Response, error) {
return r.CheckRequest.HandlePreProcessError(err)
}
10 changes: 0 additions & 10 deletions internal/request/check_identify_request_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,6 @@ func (r *CheckIdentifyRequest) process(ctx context.Context) (Response, error) {
}, nil
}

func (r *CheckIdentifyRequest) handlePreProcessError(err error) (Response, error) {
r.init()
r.mon.UpdateStatusOnError(err, monitoringplugin.UNKNOWN, err.Error(), false)
return &CheckIdentifyResponse{
CheckResponse: CheckResponse{r.mon.GetInfo()},
IdentifyResult: nil,
FailedExpectations: nil,
}, nil
}

func (r *CheckIdentifyRequest) validate(ctx context.Context) error {
err := r.BaseRequest.validate(ctx)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/request/check_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func (r *CheckRequest) init() {
r.mon.SetPerformanceDataJSONLabel(r.JSONMetrics)
}

func (r *CheckRequest) handlePreProcessError(err error) (Response, error) {
func (r *CheckRequest) HandlePreProcessError(err error) (Response, error) {
r.init()
r.mon.UpdateStatusOnError(err, monitoringplugin.UNKNOWN, err.Error(), false)
return &CheckResponse{r.mon.GetInfo()}, nil
Expand Down
16 changes: 9 additions & 7 deletions internal/request/process_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,25 @@ type response struct {

// ProcessRequest is called by every request Thola receives
func ProcessRequest(ctx context.Context, request Request) (Response, error) {
ctx, cancel := CheckForTimeout(ctx, request)
defer cancel()

err := request.validate(ctx)
if err != nil {
return request.handlePreProcessError(errors.Wrap(err, "invalid request"))
return request.HandlePreProcessError(errors.Wrap(err, "invalid request"))
}
ctx, cancel := checkForTimeout(ctx, request)
defer cancel()

responseChannel := make(chan response)
go processRequest(ctx, request, responseChannel)
select {
case res := <-responseChannel:
return res.res, res.err
case <-ctx.Done():
return request.handlePreProcessError(errors.New("request timed out"))
return request.HandlePreProcessError(errors.New("request timed out"))
}
}

func checkForTimeout(ctx context.Context, request Request) (context.Context, context.CancelFunc) {
func CheckForTimeout(ctx context.Context, request Request) (context.Context, context.CancelFunc) {
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
if timeout := request.getTimeout(); timeout != nil && *timeout != 0 {
Expand All @@ -47,7 +49,7 @@ func checkForTimeout(ctx context.Context, request Request) (context.Context, con
func processRequest(ctx context.Context, request Request, responseChan chan response) {
defer func() {
if r := recover(); r != nil {
res, err := request.handlePreProcessError(errors.New("thola paniced: " + fmt.Sprint(r)))
res, err := request.HandlePreProcessError(errors.New("thola paniced: " + fmt.Sprint(r)))
responseChan <- response{
res: res,
err: err,
Expand All @@ -56,7 +58,7 @@ func processRequest(ctx context.Context, request Request, responseChan chan resp
}()
con, err := request.setupConnection(ctx)
if err != nil {
res, err := request.handlePreProcessError(err)
res, err := request.HandlePreProcessError(err)
responseChan <- response{
res: res,
err: err,
Expand Down
7 changes: 6 additions & 1 deletion internal/request/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,16 @@ import (

// Request is the interface which all requests must implement.
type Request interface {
// HandlePreProcessError implements request specific error handling (e.g. sets state to UNKNOWN and exit code to 3 in case
// that the request is a check request).
// Always call HandlePreProcessError if you want to correctly exit with an error before you call the process function
// on a request.
HandlePreProcessError(error) (Response, error)

validate(ctx context.Context) error
getTimeout() *int
setupConnection(ctx context.Context) (*network.RequestDeviceConnection, error)
process(ctx context.Context) (Response, error)
handlePreProcessError(error) (Response, error)
}

// Response is a generic interface that is returned by any Request.
Expand Down

0 comments on commit 5a95c9a

Please sign in to comment.