diff --git a/modules/dashboard/graph/schema.resolvers.go b/modules/dashboard/graph/schema.resolvers.go index 33690038..a4b676f6 100644 --- a/modules/dashboard/graph/schema.resolvers.go +++ b/modules/dashboard/graph/schema.resolvers.go @@ -6,13 +6,13 @@ package graph import ( "context" - "fmt" "slices" "time" "github.com/99designs/gqlgen/graphql/handler/transport" "github.com/kubetail-org/kubetail/modules/dashboard/graph/model" "github.com/kubetail-org/kubetail/modules/dashboard/internal/k8shelpers" + "github.com/kubetail-org/kubetail/modules/shared/config" gqlerrors "github.com/kubetail-org/kubetail/modules/shared/graphql/errors" "github.com/kubetail-org/kubetail/modules/shared/helm" sharedk8shelpers "github.com/kubetail-org/kubetail/modules/shared/k8shelpers" @@ -120,6 +120,11 @@ func (r *kubeConfigResolver) Contexts(ctx context.Context, obj *model.KubeConfig // KubetailClusterAPIInstall is the resolver for the kubetailClusterAPIInstall field. func (r *mutationResolver) KubetailClusterAPIInstall(ctx context.Context, kubeContext *string) (*bool, error) { + // Reject requests not in desktop environment + if r.environment != config.EnvironmentDesktop { + return nil, gqlerrors.ErrForbidden + } + // Init client client, err := helm.NewClient() if err != nil { @@ -440,10 +445,16 @@ func (r *queryResolver) KubetailClusterAPIHealthzGet(ctx context.Context, kubeCo // KubeConfigGet is the resolver for the kubeConfigGet field. func (r *queryResolver) KubeConfigGet(ctx context.Context) (*model.KubeConfig, error) { + // Reject requests not in desktop environment + if r.environment != config.EnvironmentDesktop { + return nil, gqlerrors.ErrForbidden + } + cm, ok := r.cm.(*k8shelpers.DesktopConnectionManager) if !ok { - return nil, fmt.Errorf("DesktopConnectionManager not found") + return nil, gqlerrors.ErrInternalServerError } + return &model.KubeConfig{Config: cm.GetKubeConfig()}, nil } @@ -662,9 +673,14 @@ func (r *subscriptionResolver) KubetailClusterAPIHealthzWatch(ctx context.Contex // KubeConfigWatch is the resolver for the kubeConfigWatch field. func (r *subscriptionResolver) KubeConfigWatch(ctx context.Context) (<-chan *model.KubeConfigWatchEvent, error) { + // Reject requests not in desktop environment + if r.environment != config.EnvironmentDesktop { + return nil, gqlerrors.ErrForbidden + } + cm, ok := r.cm.(*k8shelpers.DesktopConnectionManager) if !ok { - return nil, fmt.Errorf("DesktopConnectionManager not found") + return nil, gqlerrors.ErrInternalServerError } // Init output channel diff --git a/modules/dashboard/graph/schema.resolvers_test.go b/modules/dashboard/graph/schema.resolvers_test.go index f0c2c27b..39889b86 100644 --- a/modules/dashboard/graph/schema.resolvers_test.go +++ b/modules/dashboard/graph/schema.resolvers_test.go @@ -22,6 +22,7 @@ import ( "github.com/stretchr/testify/mock" "k8s.io/utils/ptr" + "github.com/kubetail-org/kubetail/modules/shared/config" "github.com/kubetail-org/kubetail/modules/shared/graphql/errors" k8shelpersmock "github.com/kubetail-org/kubetail/modules/dashboard/internal/k8shelpers/mock" @@ -134,3 +135,28 @@ func TestAllowedNamespacesListQueries(t *testing.T) { }) } } + +func TestDesktopOnlyRequests(t *testing.T) { + resolver := &Resolver{environment: config.EnvironmentCluster} + + t.Run("kubeConfigGet", func(t *testing.T) { + r := &queryResolver{resolver} + _, err := r.KubeConfigGet(context.Background()) + assert.NotNil(t, err) + assert.Equal(t, err, errors.ErrForbidden) + }) + + t.Run("kubeConfigWatch", func(t *testing.T) { + r := &subscriptionResolver{resolver} + _, err := r.KubeConfigWatch(context.Background()) + assert.NotNil(t, err) + assert.Equal(t, err, errors.ErrForbidden) + }) + + t.Run("kubetailClusterAPIInstall", func(t *testing.T) { + r := &mutationResolver{resolver} + _, err := r.KubetailClusterAPIInstall(context.Background(), nil) + assert.NotNil(t, err) + assert.Equal(t, err, errors.ErrForbidden) + }) +} diff --git a/modules/shared/graphql/errors/errors.go b/modules/shared/graphql/errors/errors.go index 74c388ae..2a3e9258 100644 --- a/modules/shared/graphql/errors/errors.go +++ b/modules/shared/graphql/errors/errors.go @@ -21,7 +21,7 @@ var ( ErrValidationError = NewError("KUBETAIL_VALIDATION_ERROR", "Validation error") ErrRecordNotFound = NewError("KUBETAIL_RECORD_NOT_FOUND", "Record not found") ErrUnauthenticated = NewError("KUBETAIL_UNAUTHENTICATED", "Authentication required") - ErrForbidden = NewError("KUBETAIL_FORBIDDEN", "Access forbidden") + ErrForbidden = NewError("KUBETAIL_FORBIDDEN", "Forbidden") ErrWatchError = NewError("KUBETAIL_WATCH_ERROR", "Watch error") ErrInternalServerError = NewError("INTERNAL_SERVER_ERROR", "Internal server error") )