Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ContextV2 protobuf structure #3506

Merged
merged 1 commit into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/docs/ref/proto.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

76 changes: 76 additions & 0 deletions internal/controlplane/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
"regexp"

"github.com/google/uuid"
"google.golang.org/grpc/codes"

"github.com/stacklok/minder/internal/providers/github/clients"
Expand All @@ -36,11 +37,86 @@ const (

var validRepoSlugRe = regexp.MustCompile(`(?i)^[-a-z0-9_\.]+\/[-a-z0-9_\.]+$`)

var (
// ErrNoProjectInContext is returned when no project is found in the context
ErrNoProjectInContext = errors.New("no project found in context")
)

// ProviderGetter is an interface that can be implemented by a context,
// since both the context V1 and V2 have a provider field
type ProviderGetter interface {
GetProvider() string
}

// HasProtoContextV2Compat is an interface that can be implemented by a request.
// It implements the GetContext V1 and V2 methods for backwards compatibility.
type HasProtoContextV2Compat interface {
HasProtoContext
GetContextV2() *pb.ContextV2
}

// HasProtoContextV2 is an interface that can be implemented by a request
type HasProtoContextV2 interface {
GetContextV2() *pb.ContextV2
}

// HasProtoContext is an interface that can be implemented by a request
type HasProtoContext interface {
GetContext() *pb.Context
}

func getProjectFromContextV2Compat(accessor HasProtoContextV2Compat) (uuid.UUID, error) {
// First check if the context is V2
if accessor.GetContextV2() != nil && accessor.GetContextV2().GetProjectId() != "" {
return parseProject(accessor.GetContextV2().GetProjectId())
}

// If the context is not V2, check if it is V1
if accessor.GetContext() != nil && accessor.GetContext().GetProject() != "" {
return parseProject(accessor.GetContext().GetProject())
}

if accessor.GetContextV2() == nil && accessor.GetContext() == nil {
return uuid.Nil, util.UserVisibleError(codes.InvalidArgument, "context cannot be nil")
}

return uuid.Nil, ErrNoProjectInContext
}

func getProjectFromContextV2(accessor HasProtoContextV2) (uuid.UUID, error) {
if accessor.GetContextV2() == nil {
return uuid.Nil, util.UserVisibleError(codes.InvalidArgument, "context cannot be nil")
}

// First check if the context is V2
if accessor.GetContextV2() != nil && accessor.GetContextV2().GetProjectId() != "" {
return parseProject(accessor.GetContextV2().GetProjectId())
}

return uuid.Nil, ErrNoProjectInContext
}

func getProjectFromContext(accessor HasProtoContext) (uuid.UUID, error) {
if accessor.GetContext() == nil {
return uuid.Nil, util.UserVisibleError(codes.InvalidArgument, "context cannot be nil")
}

if accessor.GetContext() != nil && accessor.GetContext().GetProject() != "" {
return parseProject(accessor.GetContext().GetProject())
}

return uuid.Nil, ErrNoProjectInContext
}

func parseProject(project string) (uuid.UUID, error) {
projID, err := uuid.Parse(project)
if err != nil {
return uuid.Nil, util.UserVisibleError(codes.InvalidArgument, "malformed project ID")
}

return projID, nil
}

// providerError wraps an error with a user visible error message
func providerError(err error) error {
if errors.Is(err, sql.ErrNoRows) {
Expand Down
152 changes: 152 additions & 0 deletions internal/controlplane/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@ package controlplane
import (
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/stacklok/minder/internal/util/ptr"
pb "github.com/stacklok/minder/pkg/api/protobuf/go/minder/v1"
)

func TestGetRemediationURLFromMetadata(t *testing.T) {
Expand Down Expand Up @@ -81,3 +86,150 @@ func TestGetAlertURLFromMetadata(t *testing.T) {
})
}
}

func Test_getProjectFromContextV2(t *testing.T) {
t.Parallel()

proj1 := uuid.New()
proj2 := uuid.New()

type args struct {
accessor HasProtoContextV2Compat
}
tests := []struct {
name string
args args
want uuid.UUID
wantErr bool
}{
{
name: "no project",
args: args{
accessor: newMockHasProtoContextV2(),
},
want: uuid.Nil,
wantErr: true,
},
{
name: "v1 project",
args: args{
accessor: newMockHasProtoContextV2().withV1(&pb.Context{
Project: ptr.Ptr(proj1.String()),
}),
},
want: proj1,
wantErr: false,
},
{
name: "v2 project",
args: args{
accessor: newMockHasProtoContextV2().withV2(&pb.ContextV2{
ProjectId: proj1.String(),
}),
},
want: proj1,
wantErr: false,
},
{
name: "v2 project wins",
args: args{
accessor: newMockHasProtoContextV2().withV1(&pb.Context{
Project: ptr.Ptr(proj1.String()),
}).withV2(&pb.ContextV2{
ProjectId: proj2.String(),
}),
},
want: proj2,
wantErr: false,
},
{
name: "v2 project wins with malformed v1",
args: args{
accessor: newMockHasProtoContextV2().withV1(&pb.Context{
Project: ptr.Ptr("malformed"),
}).withV2(&pb.ContextV2{
ProjectId: proj2.String(),
}),
},
want: proj2,
wantErr: false,
},
{
name: "malformed v2 project",
args: args{
accessor: newMockHasProtoContextV2().withV2(&pb.ContextV2{
ProjectId: "malformed",
}),
},
want: uuid.Nil,
wantErr: true,
},
{
name: "malformed v2 project is still an error",
args: args{
accessor: newMockHasProtoContextV2().withV1(&pb.Context{
Project: ptr.Ptr(proj1.String()),
}).withV2(&pb.ContextV2{
ProjectId: "malformed",
}),
},
want: uuid.Nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

got, err := getProjectFromContextV2Compat(tt.args.accessor)
if tt.wantErr {
assert.Error(t, err, "expected error")
return
}

assert.Equal(t, tt.want, got)
})
}
}

type mockHasProtoContextV2 struct {
getV1 func() *pb.Context
getV2 func() *pb.ContextV2
}

func emptyv1() *pb.Context {
return nil
}

func emptyv2() *pb.ContextV2 {
return nil
}

func newMockHasProtoContextV2() *mockHasProtoContextV2 {
return &mockHasProtoContextV2{
getV1: emptyv1,
getV2: emptyv2,
}
}

func (m *mockHasProtoContextV2) withV1(v1 *pb.Context) *mockHasProtoContextV2 {
m.getV1 = func() *pb.Context {
return v1
}
return m
}

func (m *mockHasProtoContextV2) withV2(v2 *pb.ContextV2) *mockHasProtoContextV2 {
m.getV2 = func() *pb.ContextV2 {
return v2
}
return m
}

func (m *mockHasProtoContextV2) GetContext() *pb.Context {
return m.getV1()
}

func (m *mockHasProtoContextV2) GetContextV2() *pb.ContextV2 {
return m.getV2()
}
72 changes: 44 additions & 28 deletions internal/controlplane/handlers_authz.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,12 @@ func EntityContextProjectInterceptor(ctx context.Context, req interface{}, info
return handler(ctx, req)
}

request, ok := req.(HasProtoContext)
server, ok := info.Server.(*Server)
if !ok {
return nil, status.Errorf(codes.Internal, "Error extracting context from request")
return nil, status.Errorf(codes.Internal, "error casting serrver for request handling")
}

server := info.Server.(*Server)

ctx, err := populateEntityContext(ctx, server.store, server.authzClient, request)
ctx, err := populateEntityContext(ctx, server.store, server.authzClient, req)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -119,48 +117,66 @@ func populateEntityContext(
ctx context.Context,
store db.Store,
authzClient authz.Client,
in HasProtoContext,
req any,
) (context.Context, error) {
if in.GetContext() == nil {
return ctx, util.UserVisibleError(codes.InvalidArgument, "context cannot be nil")
}

projectID, err := getProjectFromRequestOrDefault(ctx, store, authzClient, in)
projectID, err := getProjectIDFromContext(req)
if err != nil {
return ctx, err
if errors.Is(err, ErrNoProjectInContext) {
projectID, err = getDefaultProjectID(ctx, store, authzClient)
if err != nil {
return ctx, err
}
} else {
return ctx, err
}
}

// don't look up default provider until user has been authorized
providerName := in.GetContext().GetProvider()

entityCtx := &engine.EntityContext{
Project: engine.Project{
ID: projectID,
},
Provider: engine.Provider{
Name: providerName,
Name: getProviderFromContext(req),
},
}

return engine.WithEntityContext(ctx, entityCtx), nil
}

func getProjectFromRequestOrDefault(
func getProjectIDFromContext(req any) (uuid.UUID, error) {
switch req := req.(type) {
case HasProtoContextV2Compat:
return getProjectFromContextV2Compat(req)
case HasProtoContextV2:
return getProjectFromContextV2(req)
case HasProtoContext:
return getProjectFromContext(req)
default:
return uuid.Nil, status.Errorf(codes.Internal, "Error extracting context from request")
}
}

func getProviderFromContext(req any) string {
switch req := req.(type) {
case HasProtoContextV2Compat:
if req.GetContextV2().GetProvider() != "" {
return req.GetContextV2().GetProvider()
}
return req.GetContext().GetProvider()
case HasProtoContextV2:
return req.GetContextV2().GetProvider()
case HasProtoContext:
return req.GetContext().GetProvider()
default:
return ""
}
}

func getDefaultProjectID(
ctx context.Context,
store db.Store,
authzClient authz.Client,
in HasProtoContext,
) (uuid.UUID, error) {
// Prefer the context message from the protobuf
if in.GetContext().GetProject() != "" {
requestedProject := in.GetContext().GetProject()
parsedProjectID, err := uuid.Parse(requestedProject)
if err != nil {
return uuid.UUID{}, util.UserVisibleError(codes.InvalidArgument, "malformed project ID")
}
return parsedProjectID, nil
}

subject := auth.GetUserSubjectFromContext(ctx)

userInfo, err := store.GetUserBySubject(ctx, subject)
Expand Down
Loading