diff --git a/authz/providers/azure/azure.go b/authz/providers/azure/azure.go index 8553fcc09..99d481eaf 100644 --- a/authz/providers/azure/azure.go +++ b/authz/providers/azure/azure.go @@ -16,6 +16,7 @@ limitations under the License. package azure import ( + "net/http" "strings" "sync" @@ -24,6 +25,7 @@ import ( authzOpts "go.kubeguard.dev/guard/authz/providers/azure/options" "go.kubeguard.dev/guard/authz/providers/azure/rbac" azureutils "go.kubeguard.dev/guard/util/azure" + errutils "go.kubeguard.dev/guard/util/error" "github.com/Azure/go-autorest/autorest/azure" "github.com/pkg/errors" @@ -78,7 +80,7 @@ func newAuthzClient(opts authzOpts.Options, authopts auth.Options, operationsMap func (s Authorizer) Check(request *authzv1.SubjectAccessReviewSpec, store authz.Store) (*authzv1.SubjectAccessReviewStatus, error) { if request == nil { - return nil, errors.New("subject access review is nil") + return nil, errutils.WithCode(errors.New("subject access review is nil"), http.StatusBadRequest) } // check if user is system accounts @@ -118,7 +120,9 @@ func (s Authorizer) Check(request *authzv1.SubjectAccessReviewSpec, store authz. } if s.rbacClient.IsTokenExpired() { - _ = s.rbacClient.RefreshToken() + if err := s.rbacClient.RefreshToken(); err != nil { + return nil, errutils.WithCode(err, http.StatusInternalServerError) + } } response, err := s.rbacClient.CheckAccess(request) @@ -126,7 +130,11 @@ func (s Authorizer) Check(request *authzv1.SubjectAccessReviewSpec, store authz. 
klog.V(5).Infof(response.Reason) _ = s.rbacClient.SetResultInCache(request, response.Allowed, store) } else { - err = errors.Errorf(rbac.CheckAccessErrorFormat, err) + code := http.StatusInternalServerError + if v, ok := err.(errutils.HttpStatusCode); ok { + code = v.Code() + } + err = errutils.WithCode(errors.Errorf(rbac.CheckAccessErrorFormat, err), code) } return response, err diff --git a/authz/providers/azure/azure_test.go b/authz/providers/azure/azure_test.go index 906975f4e..24c54936d 100644 --- a/authz/providers/azure/azure_test.go +++ b/authz/providers/azure/azure_test.go @@ -29,6 +29,7 @@ import ( authzOpts "go.kubeguard.dev/guard/authz/providers/azure/options" "go.kubeguard.dev/guard/authz/providers/azure/rbac" azureutils "go.kubeguard.dev/guard/util/azure" + errutils "go.kubeguard.dev/guard/util/error" "github.com/appscode/pat" "github.com/stretchr/testify/assert" @@ -71,7 +72,7 @@ func clientSetup(serverUrl, mode string) (*Authorizer, error) { return c, nil } -func serverSetup(loginResp, checkaccessResp string, loginStatus, checkaccessStatus int) (*httptest.Server, error) { +func serverSetup(loginResp, checkaccessResp string, loginStatus, checkaccessStatus int, sleepFor time.Duration) (*httptest.Server, error) { listener, err := net.Listen("tcp", "127.0.0.1:") if err != nil { return nil, err @@ -85,6 +86,7 @@ func serverSetup(loginResp, checkaccessResp string, loginStatus, checkaccessStat })) m.Post("/arm/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(sleepFor) w.WriteHeader(checkaccessStatus) _, _ = w.Write([]byte(checkaccessResp)) })) @@ -98,8 +100,8 @@ func serverSetup(loginResp, checkaccessResp string, loginStatus, checkaccessStat return srv, nil } -func getServerAndClient(t *testing.T, loginResp, checkaccessResp string, checkaccessStatus int) (*httptest.Server, *Authorizer, authz.Store) { - srv, err := serverSetup(loginResp, checkaccessResp, http.StatusOK, checkaccessStatus) +func getServerAndClient(t *testing.T, 
loginResp, checkaccessResp string, checkaccessStatus int, sleepFor time.Duration) (*httptest.Server, *Authorizer, authz.Store) { + srv, err := serverSetup(loginResp, checkaccessResp, http.StatusOK, checkaccessStatus, sleepFor) if err != nil { t.Fatalf("Error when creating server, reason: %v", err) } @@ -129,7 +131,7 @@ func TestCheck(t *testing.T) { "actionId":"Microsoft.Kubernetes/connectedClusters/pods/delete", "isDataAction":true,"roleAssignment":null,"denyAssignment":null,"timeToLiveInMs":300000}]` - srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusOK) + srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusOK, 1*time.Second) defer srv.Close() defer store.Close() @@ -146,11 +148,14 @@ func TestCheck(t *testing.T) { assert.NotNil(t, resp) assert.Equal(t, resp.Allowed, true) assert.Equal(t, resp.Denied, false) + if v, ok := err.(errutils.HttpStatusCode); ok { + assert.Equal(t, v.Code(), http.StatusOK) + } }) t.Run("unsuccessful request", func(t *testing.T) { validBody := `""` - srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusInternalServerError) + srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusInternalServerError, 1*time.Second) defer srv.Close() defer store.Close() @@ -166,5 +171,31 @@ func TestCheck(t *testing.T) { assert.Nilf(t, resp, "response should be nil") assert.NotNilf(t, err, "should get error") assert.Contains(t, err.Error(), "Error occured during authorization check") + if v, ok := err.(errutils.HttpStatusCode); ok { + assert.Equal(t, v.Code(), http.StatusInternalServerError) + } + }) + + t.Run("context timeout request", func(t *testing.T) { + validBody := `""` + srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusInternalServerError, 25*time.Second) + defer srv.Close() + defer store.Close() + + request := &authzv1.SubjectAccessReviewSpec{ + User: "beta@bing.com", + ResourceAttributes: &authzv1.ResourceAttributes{ + 
Namespace: "dev", Group: "", Resource: "pods", + Subresource: "status", Version: "v1", Name: "test", Verb: "delete", + }, Extra: map[string]authzv1.ExtraValue{"oid": {"00000000-0000-0000-0000-000000000000"}}, + } + + resp, err := client.Check(request, store) + assert.Nilf(t, resp, "response should be nil") + assert.NotNilf(t, err, "should get error") + assert.Contains(t, err.Error(), "Checkaccess requests have timed out") + if v, ok := err.(errutils.HttpStatusCode); ok { + assert.Equal(t, v.Code(), http.StatusInternalServerError) + } }) } diff --git a/authz/providers/azure/rbac/checkaccessreqhelper.go b/authz/providers/azure/rbac/checkaccessreqhelper.go index 55ccc9904..396fbbd79 100644 --- a/authz/providers/azure/rbac/checkaccessreqhelper.go +++ b/authz/providers/azure/rbac/checkaccessreqhelper.go @@ -18,10 +18,12 @@ package rbac import ( "encoding/json" "fmt" + "net/http" "path" "strings" azureutils "go.kubeguard.dev/guard/util/azure" + errutils "go.kubeguard.dev/guard/util/error" "github.com/google/uuid" "github.com/pkg/errors" @@ -494,16 +496,16 @@ func prepareCheckAccessRequestBody(req *authzv1.SubjectAccessReviewSpec, cluster val := oid.String() userOid = val[1 : len(val)-1] if !isValidUUID(userOid) { - return nil, errors.New("oid info sent from authentication module is not valid") + return nil, errutils.WithCode(errors.New("oid info sent from authentication module is not valid"), http.StatusBadRequest) } } else { - return nil, errors.New("oid info not sent from authentication module") + return nil, errutils.WithCode(errors.New("oid info not sent from authentication module"), http.StatusBadRequest) } groups := getValidSecurityGroups(req.Groups) username = req.User actions, err := getDataActions(req, clusterType, operationsMap) if err != nil { - return nil, errors.Wrap(err, "Error while creating list of dataactions for check access call") + return nil, errutils.WithCode(errors.Wrap(err, "Error while creating list of dataactions for check access call"), 
http.StatusInternalServerError) } var checkAccessReqs []*CheckAccessRequest for i := 0; i < len(actions); i += ActionBatchCount { @@ -547,7 +549,7 @@ func ConvertCheckAccessResponse(body []byte) (*authzv1.SubjectAccessReviewStatus err := json.Unmarshal(body, &response) if err != nil { klog.V(10).Infof("Failed to parse checkacccess response. Error:%s", err.Error()) - return nil, errors.Wrap(err, "Error in unmarshalling check access response.") + return nil, errutils.WithCode(errors.Wrap(err, "Error in unmarshalling check access response."), http.StatusInternalServerError) } deniedResultFound := slices.IndexFunc(response, func(a AuthorizationDecision) bool { return strings.ToLower(a.Decision) != Allowed }) diff --git a/authz/providers/azure/rbac/rbac.go b/authz/providers/azure/rbac/rbac.go index be0e05e3b..ef8b03133 100644 --- a/authz/providers/azure/rbac/rbac.go +++ b/authz/providers/azure/rbac/rbac.go @@ -17,6 +17,7 @@ package rbac import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -33,24 +34,29 @@ import ( "go.kubeguard.dev/guard/authz" authzOpts "go.kubeguard.dev/guard/authz/providers/azure/options" azureutils "go.kubeguard.dev/guard/util/azure" + errutils "go.kubeguard.dev/guard/util/error" "go.kubeguard.dev/guard/util/httpclient" + "github.com/google/uuid" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" + "golang.org/x/sync/errgroup" v "gomodules.xyz/x/version" authzv1 "k8s.io/api/authorization/v1" "k8s.io/klog/v2" ) const ( - managedClusters = "Microsoft.ContainerService/managedClusters" - fleets = "Microsoft.ContainerService/fleets" - connectedClusters = "Microsoft.Kubernetes/connectedClusters" - checkAccessPath = "/providers/Microsoft.Authorization/checkaccess" - checkAccessAPIVersion = "2018-09-01-preview" - remainingSubReadARMHeader = "x-ms-ratelimit-remaining-subscription-reads" - expiryDelta = 60 * time.Second + managedClusters = 
"Microsoft.ContainerService/managedClusters" + fleets = "Microsoft.ContainerService/fleets" + connectedClusters = "Microsoft.Kubernetes/connectedClusters" + checkAccessPath = "/providers/Microsoft.Authorization/checkaccess" + checkAccessAPIVersion = "2018-09-01-preview" + remainingSubReadARMHeader = "x-ms-ratelimit-remaining-subscription-reads" + expiryDelta = 60 * time.Second + checkaccessContextTimeout = 23 * time.Second + correlationRequestIDHeader = "x-ms-correlation-request-id" ) type AuthzInfo struct { @@ -58,12 +64,10 @@ type AuthzInfo struct { ARMEndPoint string } -type reviewResult struct { - status *authzv1.SubjectAccessReviewStatus - err error -} - -type void struct{} +type ( + void struct{} + correlationRequestIDKey string +) // AccessInfo allows you to check user access from MS RBAC type AccessInfo struct { @@ -88,24 +92,56 @@ type AccessInfo struct { var ( checkAccessThrottled = promauto.NewCounter(prometheus.CounterOpts{ Name: "guard_azure_checkaccess_throttling_failure_total", - Help: "Azure checkaccess call throttled.", - }) - checkAccessTotal = promauto.NewCounter(prometheus.CounterOpts{ - Name: "guard_azure_check_access_requests_total", - Help: "Azure number of checkaccess request calls.", - }) - checkAccessFailed = promauto.NewCounter(prometheus.CounterOpts{ - Name: "guard_azure_checkaccess_failure_total", - Help: "Azure checkaccess failed calls.", + Help: "No of throttled checkaccess calls.", }) + + checkAccessTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "guard_azure_check_access_requests_total", + Help: "Number of checkaccess request calls.", + }, + []string{"code"}, + ) + + checkAccessFailed = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "guard_azure_checkaccess_failure_total", + Help: "No of checkaccess failures", + }, + []string{"code"}, + ) + checkAccessSucceeded = promauto.NewCounter(prometheus.CounterOpts{ Name: "guard_azure_checkaccess_success_total", - Help: "Azure checkaccess success calls.", + 
Help: "Number of successful checkaccess calls.", }) + checkAccessContextTimedOutCount = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "guard_azure_checkaccess_context_timeout", + Help: "No of checkacces context timeout calls", + }, + []string{"checkAccessBatchCount", "totalActionsCount"}, + ) + + // checkAccessDuration is partitioned by the HTTP status code. It uses custom + // buckets based on the expected request duration. + checkAccessDuration = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "guard_azure_checkaccess_request_duration_seconds", + Help: "A histogram of latencies for requests.", + Buckets: []float64{.25, .5, 1, 2.5, 5, 10, 15, 20}, + }, + []string{"code"}, + ) + CheckAccessErrorFormat = "Error occured during authorization check. Please retry again. Error: %s" ) +func init() { + prometheus.MustRegister(checkAccessDuration, checkAccessTotal, checkAccessFailed, checkAccessContextTimedOutCount) +} + func getClusterType(clsType string) string { switch clsType { case authzOpts.ARCAuthzMode: @@ -272,47 +308,69 @@ func (a *AccessInfo) CheckAccess(request *authzv1.SubjectAccessReviewSpec) (*aut params.Add("api-version", checkAccessAPIVersion) checkAccessURL.RawQuery = params.Encode() - var wg sync.WaitGroup // New wait group + ctx, cancel := context.WithTimeout(context.Background(), checkaccessContextTimeout) + defer cancel() + eg, egCtx := errgroup.WithContext(ctx) - ch := make(chan reviewResult, len(checkAccessBodies)) + ch := make(chan *authzv1.SubjectAccessReviewStatus, len(checkAccessBodies)) if len(checkAccessBodies) > 1 { klog.V(5).Infof("Number of checkaccess requests to make: %d", len(checkAccessBodies)) } + eg.SetLimit(len(checkAccessBodies)) for _, checkAccessBody := range checkAccessBodies { - wg.Add(1) - go a.sendCheckAccessRequest(checkAccessURL, checkAccessBody, &wg, ch) + body := checkAccessBody + eg.Go(func() error { + // create a request id for every checkaccess request + requestUUID := uuid.New() + reqContext
:= context.WithValue(egCtx, correlationRequestIDKey(correlationRequestIDHeader), []string{requestUUID.String()}) + err := a.sendCheckAccessRequest(reqContext, checkAccessURL, body, ch) + if err != nil { + code := http.StatusInternalServerError + if v, ok := err.(errutils.HttpStatusCode); ok { + code = v.Code() + } + err = errutils.WithCode(errors.Errorf("Error: %s. Correlation ID: %s", requestUUID.String(), err), code) + return err + } + return nil + }) } - go func() { - wg.Wait() - close(ch) - }() - - var finalResult *authzv1.SubjectAccessReviewStatus - for result := range ch { - if result.err != nil { - return nil, result.err + if err := eg.Wait(); err != nil { + if ctx.Err() == context.DeadlineExceeded { + klog.V(5).Infof("Checkaccess requests have timed out. Error: %v", ctx.Err()) + actionsCount := 0 + for i := 0; i < len(checkAccessBodies); i += 1 { + actionsCount = actionsCount + len(checkAccessBodies[i].Actions) + } + checkAccessContextTimedOutCount.WithLabelValues(azureutils.ConvertIntToString(len(checkAccessBodies)), azureutils.ConvertIntToString(actionsCount)).Inc() + close(ch) + return nil, errutils.WithCode(errors.Wrap(ctx.Err(), "Checkaccess requests have timed out."), http.StatusInternalServerError) + } else { + close(ch) + // print error we get from sendcheckAccessRequest + klog.Error(err) + return nil, err } + } + close(ch) - if result.status.Denied { - finalResult = result.status + var finalStatus *authzv1.SubjectAccessReviewStatus + for status := range ch { + if status.Denied { + finalStatus = status break } - finalResult = result.status + finalStatus = status } - - return finalResult, nil + return finalStatus, nil } -func (a *AccessInfo) sendCheckAccessRequest(checkAccessURL url.URL, checkAccessBody *CheckAccessRequest, wg *sync.WaitGroup, ch chan reviewResult) { - defer wg.Done() - reviewResult := reviewResult{} +func (a *AccessInfo) sendCheckAccessRequest(ctx context.Context, checkAccessURL url.URL, checkAccessBody *CheckAccessRequest, ch chan 
*authzv1.SubjectAccessReviewStatus) error { buf := new(bytes.Buffer) if err := json.NewEncoder(buf).Encode(checkAccessBody); err != nil { - reviewResult.err = errors.Wrap(err, "error encoding check access request") - ch <- reviewResult - return + return errutils.WithCode(errors.Wrap(err, "error encoding check access request"), http.StatusInternalServerError) } if klog.V(10).Enabled() { @@ -321,36 +379,42 @@ func (a *AccessInfo) sendCheckAccessRequest(checkAccessURL url.URL, checkAccessB klog.V(10).Infof("binary data:%s", binaryData) } - req, err := http.NewRequest(http.MethodPost, checkAccessURL.String(), buf) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, checkAccessURL.String(), buf) if err != nil { - reviewResult.err = errors.Wrap(err, "error creating check access request") - ch <- reviewResult - return + return errutils.WithCode(errors.Wrap(err, "error creating check access request"), http.StatusInternalServerError) } a.setReqHeaders(req) - + // set x-ms-correlation-request-id for the checkaccess request + correlationID := ctx.Value(correlationRequestIDKey(correlationRequestIDHeader)).([]string) + req.Header[correlationRequestIDHeader] = correlationID + internalServerCode := azureutils.ConvertIntToString(http.StatusInternalServerError) + // start time to calculate checkaccess duration + start := time.Now() + klog.V(5).Infof("Sending checkAccess request with correlationID: %s", correlationID[0]) resp, err := a.client.Do(req) + duration := time.Since(start).Seconds() if err != nil { - reviewResult.err = errors.Wrap(err, "error in check access request execution") - ch <- reviewResult - return + checkAccessTotal.WithLabelValues(internalServerCode).Inc() + checkAccessDuration.WithLabelValues(internalServerCode).Observe(duration) + return errutils.WithCode(errors.Wrap(err, "error in check access request execution."), http.StatusInternalServerError) } defer resp.Body.Close() - - checkAccessTotal.Inc() + respStatusCode := 
azureutils.ConvertIntToString(resp.StatusCode) + checkAccessTotal.WithLabelValues(respStatusCode).Inc() + checkAccessDuration.WithLabelValues(respStatusCode).Observe(duration) data, err := io.ReadAll(resp.Body) if err != nil { - reviewResult.err = errors.Wrap(err, "error in reading response body") - ch <- reviewResult - return + checkAccessTotal.WithLabelValues(internalServerCode).Inc() + checkAccessDuration.WithLabelValues(internalServerCode).Observe(duration) + return errutils.WithCode(errors.Wrap(err, "error in reading response body"), http.StatusInternalServerError) } klog.V(7).Infof("checkaccess response: %s, Configured ARM call limit: %d", string(data), a.armCallLimit) if resp.StatusCode != http.StatusOK { - klog.Errorf("error in check access response. error code: %d, response: %s", resp.StatusCode, string(data)) + klog.Errorf("error in check access response. error code: %d, response: %s, correlationID: %s", resp.StatusCode, string(data), correlationID[0]) // metrics for calls with StatusCode >= 300 if resp.StatusCode >= http.StatusMultipleChoices { if resp.StatusCode == http.StatusTooManyRequests { @@ -359,15 +423,13 @@ func (a *AccessInfo) sendCheckAccessRequest(checkAccessURL url.URL, checkAccessB checkAccessThrottled.Inc() } - checkAccessFailed.Inc() + checkAccessFailed.WithLabelValues(respStatusCode).Inc() } - reviewResult.err = errors.Errorf("request %s failed with status code: %d and response: %s", req.URL.Path, resp.StatusCode, string(data)) - ch <- reviewResult - return + return errutils.WithCode(errors.Errorf("request %s failed with status code: %d and response: %s", req.URL.Path, resp.StatusCode, string(data)), resp.StatusCode) } else { remaining := resp.Header.Get(remainingSubReadARMHeader) - klog.Infof("Remaining request count in ARM instance:%s", remaining) + klog.Infof("Checkaccess Request has succeeded, CorrelationID is %s. 
Remaining request count in ARM instance:%s", correlationID[0], remaining) count, _ := strconv.Atoi(remaining) if count < a.armCallLimit { if klog.V(10).Enabled() { @@ -382,6 +444,11 @@ func (a *AccessInfo) sendCheckAccessRequest(checkAccessURL url.URL, checkAccessB } // Decode response and prepare k8s response - reviewResult.status, reviewResult.err = ConvertCheckAccessResponse(data) - ch <- reviewResult + status, err := ConvertCheckAccessResponse(data) + if err != nil { + return err + } + + ch <- status + return nil } diff --git a/go.mod b/go.mod index ea50609cc..9209e9626 100644 --- a/go.mod +++ b/go.mod @@ -28,6 +28,7 @@ require ( golang.org/x/exp v0.0.0-20221026004748-78e5e7837ae6 golang.org/x/net v0.0.0-20220722155237-a158d28d115b golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d + golang.org/x/sync v0.1.0 golang.org/x/text v0.3.7 gomodules.xyz/blobfs v0.1.7 gomodules.xyz/cert v1.4.1 diff --git a/go.sum b/go.sum index d0d762f85..c0584bb76 100644 --- a/go.sum +++ b/go.sum @@ -631,6 +631,8 @@ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/server/authzhandler.go b/server/authzhandler.go index 4d4fce2a0..3488fc89f 100644 --- 
a/server/authzhandler.go +++ b/server/authzhandler.go @@ -22,6 +22,7 @@ import ( "go.kubeguard.dev/guard/authz" "go.kubeguard.dev/guard/authz/providers/azure" azureutils "go.kubeguard.dev/guard/util/azure" + errutils "go.kubeguard.dev/guard/util/error" "github.com/pkg/errors" authzv1 "k8s.io/api/authorization/v1" @@ -38,12 +39,12 @@ type Authzhandler struct { func (s *Authzhandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { klog.Infof("Recieved subject access review request") if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 { - writeAuthzResponse(w, nil, nil, WithCode(errors.New("Missing client certificate"), http.StatusBadRequest)) + writeAuthzResponse(w, nil, nil, errutils.WithCode(errors.New("Missing client certificate"), http.StatusBadRequest)) return } crt := req.TLS.PeerCertificates[0] if len(crt.Subject.Organization) == 0 { - writeAuthzResponse(w, nil, nil, WithCode(errors.New("Client certificate is missing organization"), http.StatusBadRequest)) + writeAuthzResponse(w, nil, nil, errutils.WithCode(errors.New("Client certificate is missing organization"), http.StatusBadRequest)) return } org := crt.Subject.Organization[0] @@ -51,18 +52,18 @@ func (s *Authzhandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { data := authzv1.SubjectAccessReview{} err := json.NewDecoder(req.Body).Decode(&data) if err != nil { - writeAuthzResponse(w, nil, nil, WithCode(errors.Wrap(err, "Failed to parse request"), http.StatusBadRequest)) + writeAuthzResponse(w, nil, nil, errutils.WithCode(errors.Wrap(err, "Failed to parse request"), http.StatusBadRequest)) return } if !s.AuthzRecommendedOptions.AuthzProvider.Has(org) { - writeAuthzResponse(w, &data.Spec, nil, WithCode(errors.Errorf("guard does not provide service for %v", org), http.StatusBadRequest)) + writeAuthzResponse(w, &data.Spec, nil, errutils.WithCode(errors.Errorf("guard does not provide service for %v", org), http.StatusBadRequest)) return } client, err := s.getAuthzProviderClient(org) if 
client == nil || err != nil { - writeAuthzResponse(w, &data.Spec, nil, err) + writeAuthzResponse(w, &data.Spec, nil, errutils.WithCode(err, http.StatusInternalServerError)) return } diff --git a/server/handler.go b/server/handler.go index b41896540..a5c31cac4 100644 --- a/server/handler.go +++ b/server/handler.go @@ -27,6 +27,7 @@ import ( "go.kubeguard.dev/guard/auth/providers/google" "go.kubeguard.dev/guard/auth/providers/ldap" "go.kubeguard.dev/guard/auth/providers/token" + errutils "go.kubeguard.dev/guard/util/error" "github.com/pkg/errors" authv1 "k8s.io/api/authentication/v1" @@ -35,12 +36,12 @@ import ( func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 { - write(w, nil, WithCode(errors.New("Missing client certificate"), http.StatusBadRequest)) + write(w, nil, errutils.WithCode(errors.New("Missing client certificate"), http.StatusBadRequest)) return } crt := req.TLS.PeerCertificates[0] if len(crt.Subject.Organization) == 0 { - write(w, nil, WithCode(errors.New("Client certificate is missing organization"), http.StatusBadRequest)) + write(w, nil, errutils.WithCode(errors.New("Client certificate is missing organization"), http.StatusBadRequest)) return } org := crt.Subject.Organization[0] @@ -49,12 +50,12 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { data := authv1.TokenReview{} err := json.NewDecoder(req.Body).Decode(&data) if err != nil { - write(w, nil, WithCode(errors.Wrap(err, "Failed to parse request"), http.StatusBadRequest)) + write(w, nil, errutils.WithCode(errors.Wrap(err, "Failed to parse request"), http.StatusBadRequest)) return } if !s.AuthRecommendedOptions.AuthProvider.Has(org) { - write(w, nil, WithCode(errors.Errorf("guard does not provide service for %v", org), http.StatusBadRequest)) + write(w, nil, errutils.WithCode(errors.Errorf("guard does not provide service for %v", org), http.StatusBadRequest)) return } diff --git 
a/server/prometheus.go b/server/prometheus.go index f98cdde65..e1e3f485d 100644 --- a/server/prometheus.go +++ b/server/prometheus.go @@ -49,7 +49,7 @@ var ( prometheus.HistogramOpts{ Name: "request_duration_seconds", Help: "A histogram of latencies for requests.", - Buckets: []float64{.25, .5, 1, 2.5, 5, 10}, + Buckets: []float64{.25, .5, 1, 2.5, 5, 10, 15, 20}, }, []string{"handler", "method"}, ) @@ -67,13 +67,13 @@ var ( inFlightGaugeAuthz = prometheus.NewGauge(prometheus.GaugeOpts{ Name: "subjectaccessreviews_handler_requests_in_flight", - Help: "A gauge of requests currently being served by the tokenreviews handler.", + Help: "A gauge of requests currently being served by the subjectaccessreviews handler.", }) counterAuthz = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "subjectaccessreviews_handler_requests_total", - Help: "A counter for requests to the tokenreviews handler.", + Help: "A counter for requests to the subjectaccessreviews handler.", }, []string{"code", "method"}, ) diff --git a/server/server.go b/server/server.go index cdbac84be..fecc06b2f 100644 --- a/server/server.go +++ b/server/server.go @@ -229,11 +229,15 @@ func (s Server) ListenAndServe() { klog.Fatalf("Failed to create settings for discovering resources. Error:%s", err) } + discoverResourcesListStart := time.Now() operationsMap, err := azureutils.DiscoverResources(settings) + discoverResourcesDuration := time.Since(discoverResourcesListStart).Seconds() if err != nil { + azureutils.DiscoverResourcesTotalDuration.Observe(discoverResourcesDuration) klog.Fatalf("Failed to create map of data actions. Error:%s", err) } + azureutils.DiscoverResourcesTotalDuration.Observe(discoverResourcesDuration) authzhandler.operationsMap = operationsMap } } diff --git a/server/utils.go b/server/utils.go index 176af13ce..7d244463f 100644 --- a/server/utils.go +++ b/server/utils.go @@ -17,10 +17,10 @@ limitations under the License. 
package server import ( - "fmt" - "io" "net/http" + errutils "go.kubeguard.dev/guard/util/error" + jsoniter "github.com/json-iterator/go" "github.com/pkg/errors" auth "k8s.io/api/authentication/v1" @@ -47,7 +47,7 @@ func write(w http.ResponseWriter, info *auth.UserInfo, err error) { if err != nil { code := http.StatusUnauthorized - if v, ok := err.(httpStatusCode); ok { + if v, ok := err.(errutils.HttpStatusCode); ok { code = v.Code() } printStackTrace(err) @@ -99,12 +99,16 @@ func writeAuthzResponse(w http.ResponseWriter, spec *authzv1.SubjectAccessReview } resp.Status = accessInfo } - + code := http.StatusOK if err != nil { + if v, ok := err.(errutils.HttpStatusCode); ok { + code = v.Code() + } printStackTrace(err) } - w.WriteHeader(http.StatusOK) + w.WriteHeader(code) + if klog.V(7).Enabled() { if _, ok := spec.Extra["oid"]; ok { data, _ := json.Marshal(resp) @@ -122,10 +126,6 @@ type stackTracer interface { StackTrace() errors.StackTrace } -type httpStatusCode interface { - Code() int -} - func printStackTrace(err error) { klog.Errorln(err) @@ -134,43 +134,3 @@ func printStackTrace(err error) { klog.V(5).Infof("Stacktrace: %+v", st) // top two frames } } - -// WithCode annotates err with a new code. -// If err is nil, WithCode returns nil. 
-func WithCode(err error, code int) error { - if err == nil { - return nil - } - return &withCode{ - cause: err, - code: code, - } -} - -type withCode struct { - cause error - code int -} - -func (w *withCode) Error() string { return w.cause.Error() } -func (w *withCode) Cause() error { return w.cause } -func (w *withCode) Code() int { return w.code } - -func (w *withCode) Format(s fmt.State, verb rune) { - switch verb { - case 'v': - if s.Flag('+') { - _, err := fmt.Fprintf(s, "%+v\n", w.Cause()) - if err != nil { - klog.Fatal(err) - } - return - } - fallthrough - case 's', 'q': - _, err := io.WriteString(s, w.Error()) - if err != nil { - klog.Fatal(err) - } - } -} diff --git a/util/azure/utils.go b/util/azure/utils.go index 37b68bec8..d93a242b4 100644 --- a/util/azure/utils.go +++ b/util/azure/utils.go @@ -21,7 +21,9 @@ import ( "io" "net/http" "path" + "strconv" "strings" + "time" "go.kubeguard.dev/guard/auth/providers/azure/graph" "go.kubeguard.dev/guard/util/httpclient" @@ -29,6 +31,7 @@ import ( "github.com/Azure/go-autorest/autorest/azure" jsoniter "github.com/json-iterator/go" "github.com/pkg/errors" + "github.com/prometheus/client_golang/prometheus" v "gomodules.xyz/x/version" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" @@ -47,6 +50,29 @@ const ( OperationsEndpointFormatAKS = "%s/providers/Microsoft.ContainerService/operations?api-version=2018-10-31" ) +var ( + discoverResourcesApiServerCallDuration = prometheus.NewHistogram( + prometheus.HistogramOpts{ + Name: "guard_apiresources_request_duration_seconds", + Help: "A histogram of latencies for apiserver requests.", + Buckets: []float64{.25, .5, 1, 2.5, 5, 10, 15, 20}, + }) + + discoverResourcesAzureCallDuration = prometheus.NewHistogram( + prometheus.HistogramOpts{ + Name: "guard_azure_get_operations_request_duration_seconds", + Help: "A histogram of latencies for azure get operations requests.", + Buckets: []float64{.25, .5, 1, 2.5, 5, 10, 15, 20}, + }) + + 
DiscoverResourcesTotalDuration = prometheus.NewHistogram( + prometheus.HistogramOpts{ + Name: "guard_discover_resources_request_duration_seconds", + Help: "A histogram of latencies for discover resources requests.", + Buckets: []float64{.25, .5, 1, 2.5, 5, 10, 15, 20}, + }) +) + type TokenResponse struct { AccessToken string `json:"access_token"` ExpiresIn string `json:"expires_in"` @@ -132,6 +158,10 @@ func (o OperationsMap) String() string { return string(opMapString) } +func ConvertIntToString(number int) string { + return strconv.Itoa(number) +} + func NewDiscoverResourcesSettings(clusterType string, environment string, loginURL string, kubeconfigFilePath string, tenantID string, clientID string, clientSecret string) (*DiscoverResourcesSettings, error) { settings := &DiscoverResourcesSettings{ clusterType: clusterType, @@ -176,16 +206,26 @@ func NewDiscoverResourcesSettings(clusterType string, environment string, loginU */ func DiscoverResources(settings *DiscoverResourcesSettings) (OperationsMap, error) { operationsMap := OperationsMap{} + apiResourcesListStart := time.Now() apiResourcesList, err := fetchApiResources(settings) + apiResourcesListDuration := time.Since(apiResourcesListStart).Seconds() + if err != nil { return operationsMap, errors.Wrap(err, "Failed to fetch list of api-resources from apiserver.") } + discoverResourcesApiServerCallDuration.Observe(apiResourcesListDuration) + + getOperationsStart := time.Now() operationsList, err := fetchDataActionsList(settings) + getOperationsDuration := time.Since(getOperationsStart).Seconds() + if err != nil { return operationsMap, errors.Wrap(err, "Failed to fetch operations from Azure.") } + discoverResourcesAzureCallDuration.Observe(getOperationsDuration) + operationsMap = createOperationsMap(apiResourcesList, operationsList, settings.clusterType) klog.V(5).Infof("Operations Map created for resources: %s", operationsMap) @@ -383,3 +423,7 @@ func fetchDataActionsList(settings *DiscoverResourcesSettings) 
([]Operation, err return finalOperations, nil } + +func init() { + prometheus.MustRegister(DiscoverResourcesTotalDuration, discoverResourcesAzureCallDuration, discoverResourcesApiServerCallDuration) +} diff --git a/util/error/utils.go b/util/error/utils.go new file mode 100644 index 000000000..44f70ee92 --- /dev/null +++ b/util/error/utils.go @@ -0,0 +1,68 @@ +/* +Copyright The Guard Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package error + +import ( + "fmt" + "io" + + "k8s.io/klog/v2" +) + +// WithCode annotates err with a new code. +// If err is nil, WithCode returns nil. 
+func WithCode(err error, code int) error { + if err == nil { + return nil + } + return &withCode{ + cause: err, + code: code, + } +} + +type withCode struct { + cause error + code int +} + +func (w *withCode) Error() string { return w.cause.Error() } +func (w *withCode) Cause() error { return w.cause } +func (w *withCode) Code() int { return w.code } + +type HttpStatusCode interface { + Code() int +} + +func (w *withCode) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + _, err := fmt.Fprintf(s, "%+v\n", w.Cause()) + if err != nil { + klog.Fatal(err) + } + return + } + fallthrough + case 's', 'q': + _, err := io.WriteString(s, w.Error()) + if err != nil { + klog.Fatal(err) + } + } +} diff --git a/vendor/golang.org/x/sync/LICENSE b/vendor/golang.org/x/sync/LICENSE new file mode 100644 index 000000000..6a66aea5e --- /dev/null +++ b/vendor/golang.org/x/sync/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/golang.org/x/sync/PATENTS b/vendor/golang.org/x/sync/PATENTS new file mode 100644 index 000000000..733099041 --- /dev/null +++ b/vendor/golang.org/x/sync/PATENTS @@ -0,0 +1,22 @@ +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Google as part of the Go project. + +Google hereby grants to You a perpetual, worldwide, non-exclusive, +no-charge, royalty-free, irrevocable (except as stated in this section) +patent license to make, have made, use, offer to sell, sell, import, +transfer and otherwise run, modify and propagate the contents of this +implementation of Go, where such license applies only to those patent +claims, both currently owned or controlled by Google and acquired in +the future, licensable by Google that are necessarily infringed by this +implementation of Go. This grant does not include claims that would be +infringed only as a consequence of further modification of this +implementation. 
If you or your agent or exclusive licensee institute or +order or agree to the institution of patent litigation against any +entity (including a cross-claim or counterclaim in a lawsuit) alleging +that this implementation of Go or any code incorporated within this +implementation of Go constitutes direct or contributory patent +infringement, or inducement of patent infringement, then any patent +rights granted to you under this License for this implementation of Go +shall terminate as of the date such litigation is filed. diff --git a/vendor/golang.org/x/sync/errgroup/errgroup.go b/vendor/golang.org/x/sync/errgroup/errgroup.go new file mode 100644 index 000000000..cbee7a4e2 --- /dev/null +++ b/vendor/golang.org/x/sync/errgroup/errgroup.go @@ -0,0 +1,132 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package errgroup provides synchronization, error propagation, and Context +// cancelation for groups of goroutines working on subtasks of a common task. +package errgroup + +import ( + "context" + "fmt" + "sync" +) + +type token struct{} + +// A Group is a collection of goroutines working on subtasks that are part of +// the same overall task. +// +// A zero Group is valid, has no limit on the number of active goroutines, +// and does not cancel on error. +type Group struct { + cancel func() + + wg sync.WaitGroup + + sem chan token + + errOnce sync.Once + err error +} + +func (g *Group) done() { + if g.sem != nil { + <-g.sem + } + g.wg.Done() +} + +// WithContext returns a new Group and an associated Context derived from ctx. +// +// The derived Context is canceled the first time a function passed to Go +// returns a non-nil error or the first time Wait returns, whichever occurs +// first. 
+func WithContext(ctx context.Context) (*Group, context.Context) { + ctx, cancel := context.WithCancel(ctx) + return &Group{cancel: cancel}, ctx +} + +// Wait blocks until all function calls from the Go method have returned, then +// returns the first non-nil error (if any) from them. +func (g *Group) Wait() error { + g.wg.Wait() + if g.cancel != nil { + g.cancel() + } + return g.err +} + +// Go calls the given function in a new goroutine. +// It blocks until the new goroutine can be added without the number of +// active goroutines in the group exceeding the configured limit. +// +// The first call to return a non-nil error cancels the group's context, if the +// group was created by calling WithContext. The error will be returned by Wait. +func (g *Group) Go(f func() error) { + if g.sem != nil { + g.sem <- token{} + } + + g.wg.Add(1) + go func() { + defer g.done() + + if err := f(); err != nil { + g.errOnce.Do(func() { + g.err = err + if g.cancel != nil { + g.cancel() + } + }) + } + }() +} + +// TryGo calls the given function in a new goroutine only if the number of +// active goroutines in the group is currently below the configured limit. +// +// The return value reports whether the goroutine was started. +func (g *Group) TryGo(f func() error) bool { + if g.sem != nil { + select { + case g.sem <- token{}: + // Note: this allows barging iff channels in general allow barging. + default: + return false + } + } + + g.wg.Add(1) + go func() { + defer g.done() + + if err := f(); err != nil { + g.errOnce.Do(func() { + g.err = err + if g.cancel != nil { + g.cancel() + } + }) + } + }() + return true +} + +// SetLimit limits the number of active goroutines in this group to at most n. +// A negative value indicates no limit. +// +// Any subsequent call to the Go method will block until it can add an active +// goroutine without exceeding the configured limit. +// +// The limit must not be modified while any goroutines in the group are active. 
+func (g *Group) SetLimit(n int) { + if n < 0 { + g.sem = nil + return + } + if len(g.sem) != 0 { + panic(fmt.Errorf("errgroup: modify limit while %v goroutines in the group are still active", len(g.sem))) + } + g.sem = make(chan token, n) +} diff --git a/vendor/modules.txt b/vendor/modules.txt index ac9c8377f..7040e9033 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -443,6 +443,9 @@ golang.org/x/oauth2/google golang.org/x/oauth2/internal golang.org/x/oauth2/jws golang.org/x/oauth2/jwt +# golang.org/x/sync v0.1.0 +## explicit +golang.org/x/sync/errgroup # golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f ## explicit; go 1.17 golang.org/x/sys/internal/unsafeheader