Skip to content

Commit

Permalink
Add ContextV2 protobuf structure
Browse files Browse the repository at this point in the history
This takes a `project_id` by default. It also provides the functionality
to prefer a `ContextV2` from a request if it's present and available.

Signed-off-by: Juan Antonio Osorio <ozz@stacklok.com>
  • Loading branch information
JAORMX committed Jun 3, 2024
1 parent 3afa50e commit db0f5a4
Show file tree
Hide file tree
Showing 6 changed files with 2,329 additions and 1,986 deletions.
12 changes: 12 additions & 0 deletions docs/docs/ref/proto.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

76 changes: 76 additions & 0 deletions internal/controlplane/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
"regexp"

"github.com/google/uuid"
"google.golang.org/grpc/codes"

"github.com/stacklok/minder/internal/providers/github/clients"
Expand All @@ -36,11 +37,86 @@ const (

var validRepoSlugRe = regexp.MustCompile(`(?i)^[-a-z0-9_\.]+\/[-a-z0-9_\.]+$`)

var (
// ErrNoProjectInContext is returned when no project is found in the context
ErrNoProjectInContext = errors.New("no project found in context")
)

// ProviderGetter is an interface that can be implemented by a context,
// since both the context V1 and V2 have a provider field
type ProviderGetter interface {
GetProvider() string
}

// HasProtoContextV2Compat is an interface that can be implemented by a request.
// It implements the GetContext V1 and V2 methods for backwards compatibility.
type HasProtoContextV2Compat interface {
HasProtoContext
GetContextV2() *pb.ContextV2
}

// HasProtoContextV2 is an interface that can be implemented by a request
type HasProtoContextV2 interface {
GetContextV2() *pb.ContextV2
}

// HasProtoContext is an interface that can be implemented by a request
type HasProtoContext interface {
GetContext() *pb.Context
}

func getProjectFromContextV2Compat(accessor HasProtoContextV2Compat) (uuid.UUID, error) {
// First check if the context is V2
if accessor.GetContextV2() != nil && accessor.GetContextV2().GetProjectId() != "" {
return parseProject(accessor.GetContextV2().GetProjectId())
}

// If the context is not V2, check if it is V1
if accessor.GetContext() != nil && accessor.GetContext().GetProject() != "" {
return parseProject(accessor.GetContext().GetProject())
}

if accessor.GetContextV2() == nil && accessor.GetContext() == nil {
return uuid.Nil, util.UserVisibleError(codes.InvalidArgument, "context cannot be nil")
}

return uuid.Nil, ErrNoProjectInContext
}

func getProjectFromContextV2(accessor HasProtoContextV2) (uuid.UUID, error) {
if accessor.GetContextV2() == nil {
return uuid.Nil, util.UserVisibleError(codes.InvalidArgument, "context cannot be nil")
}

// First check if the context is V2
if accessor.GetContextV2() != nil && accessor.GetContextV2().GetProjectId() != "" {
return parseProject(accessor.GetContextV2().GetProjectId())
}

return uuid.Nil, ErrNoProjectInContext
}

func getProjectFromContext(accessor HasProtoContext) (uuid.UUID, error) {
if accessor.GetContext() == nil {
return uuid.Nil, util.UserVisibleError(codes.InvalidArgument, "context cannot be nil")
}

if accessor.GetContext() != nil && accessor.GetContext().GetProject() != "" {
return parseProject(accessor.GetContext().GetProject())
}

return uuid.Nil, ErrNoProjectInContext
}

func parseProject(project string) (uuid.UUID, error) {
projID, err := uuid.Parse(project)
if err != nil {
return uuid.Nil, util.UserVisibleError(codes.InvalidArgument, "malformed project ID")
}

return projID, nil
}

// providerError wraps an error with a user visible error message
func providerError(err error) error {
if errors.Is(err, sql.ErrNoRows) {
Expand Down
152 changes: 152 additions & 0 deletions internal/controlplane/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@ package controlplane
import (
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/stacklok/minder/internal/util/ptr"
pb "github.com/stacklok/minder/pkg/api/protobuf/go/minder/v1"
)

func TestGetRemediationURLFromMetadata(t *testing.T) {
Expand Down Expand Up @@ -81,3 +86,150 @@ func TestGetAlertURLFromMetadata(t *testing.T) {
})
}
}

func Test_getProjectFromContextV2(t *testing.T) {
t.Parallel()

proj1 := uuid.New()
proj2 := uuid.New()

type args struct {
accessor HasProtoContextV2Compat
}
tests := []struct {
name string
args args
want uuid.UUID
wantErr bool
}{
{
name: "no project",
args: args{
accessor: newMockHasProtoContextV2(),
},
want: uuid.Nil,
wantErr: true,
},
{
name: "v1 project",
args: args{
accessor: newMockHasProtoContextV2().withV1(&pb.Context{
Project: ptr.Ptr(proj1.String()),
}),
},
want: proj1,
wantErr: false,
},
{
name: "v2 project",
args: args{
accessor: newMockHasProtoContextV2().withV2(&pb.ContextV2{
ProjectId: proj1.String(),
}),
},
want: proj1,
wantErr: false,
},
{
name: "v2 project wins",
args: args{
accessor: newMockHasProtoContextV2().withV1(&pb.Context{
Project: ptr.Ptr(proj1.String()),
}).withV2(&pb.ContextV2{
ProjectId: proj2.String(),
}),
},
want: proj2,
wantErr: false,
},
{
name: "v2 project wins with malformed v1",
args: args{
accessor: newMockHasProtoContextV2().withV1(&pb.Context{
Project: ptr.Ptr("malformed"),
}).withV2(&pb.ContextV2{
ProjectId: proj2.String(),
}),
},
want: proj2,
wantErr: false,
},
{
name: "malformed v2 project",
args: args{
accessor: newMockHasProtoContextV2().withV2(&pb.ContextV2{
ProjectId: "malformed",
}),
},
want: uuid.Nil,
wantErr: true,
},
{
name: "malformed v2 project is still an error",
args: args{
accessor: newMockHasProtoContextV2().withV1(&pb.Context{
Project: ptr.Ptr(proj1.String()),
}).withV2(&pb.ContextV2{
ProjectId: "malformed",
}),
},
want: uuid.Nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

got, err := getProjectFromContextV2Compat(tt.args.accessor)
if tt.wantErr {
assert.Error(t, err, "expected error")
return
}

assert.Equal(t, tt.want, got)
})
}
}

type mockHasProtoContextV2 struct {
getV1 func() *pb.Context
getV2 func() *pb.ContextV2
}

func emptyv1() *pb.Context {
return nil
}

func emptyv2() *pb.ContextV2 {
return nil
}

func newMockHasProtoContextV2() *mockHasProtoContextV2 {
return &mockHasProtoContextV2{
getV1: emptyv1,
getV2: emptyv2,
}
}

func (m *mockHasProtoContextV2) withV1(v1 *pb.Context) *mockHasProtoContextV2 {
m.getV1 = func() *pb.Context {
return v1
}
return m
}

func (m *mockHasProtoContextV2) withV2(v2 *pb.ContextV2) *mockHasProtoContextV2 {
m.getV2 = func() *pb.ContextV2 {
return v2
}
return m
}

func (m *mockHasProtoContextV2) GetContext() *pb.Context {
return m.getV1()
}

func (m *mockHasProtoContextV2) GetContextV2() *pb.ContextV2 {
return m.getV2()
}
72 changes: 44 additions & 28 deletions internal/controlplane/handlers_authz.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,12 @@ func EntityContextProjectInterceptor(ctx context.Context, req interface{}, info
return handler(ctx, req)
}

request, ok := req.(HasProtoContext)
server, ok := info.Server.(*Server)
if !ok {
return nil, status.Errorf(codes.Internal, "Error extracting context from request")
return nil, status.Errorf(codes.Internal, "error casting serrver for request handling")
}

server := info.Server.(*Server)

ctx, err := populateEntityContext(ctx, server.store, server.authzClient, request)
ctx, err := populateEntityContext(ctx, server.store, server.authzClient, req)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -119,48 +117,66 @@ func populateEntityContext(
ctx context.Context,
store db.Store,
authzClient authz.Client,
in HasProtoContext,
req any,
) (context.Context, error) {
if in.GetContext() == nil {
return ctx, util.UserVisibleError(codes.InvalidArgument, "context cannot be nil")
}

projectID, err := getProjectFromRequestOrDefault(ctx, store, authzClient, in)
projectID, err := getProjectIDFromContext(req)
if err != nil {
return ctx, err
if errors.Is(err, ErrNoProjectInContext) {
projectID, err = getDefaultProjectID(ctx, store, authzClient)
if err != nil {
return ctx, err
}
} else {
return ctx, err
}
}

// don't look up default provider until user has been authorized
providerName := in.GetContext().GetProvider()

entityCtx := &engine.EntityContext{
Project: engine.Project{
ID: projectID,
},
Provider: engine.Provider{
Name: providerName,
Name: getProviderFromContext(req),
},
}

return engine.WithEntityContext(ctx, entityCtx), nil
}

func getProjectFromRequestOrDefault(
func getProjectIDFromContext(req any) (uuid.UUID, error) {
switch req := req.(type) {
case HasProtoContextV2Compat:
return getProjectFromContextV2Compat(req)
case HasProtoContextV2:
return getProjectFromContextV2(req)
case HasProtoContext:
return getProjectFromContext(req)
default:
return uuid.Nil, status.Errorf(codes.Internal, "Error extracting context from request")
}
}

func getProviderFromContext(req any) string {
switch req := req.(type) {
case HasProtoContextV2Compat:
if req.GetContextV2().GetProvider() != "" {
return req.GetContextV2().GetProvider()
}
return req.GetContext().GetProvider()
case HasProtoContextV2:
return req.GetContextV2().GetProvider()
case HasProtoContext:
return req.GetContext().GetProvider()
default:
return ""
}
}

func getDefaultProjectID(
ctx context.Context,
store db.Store,
authzClient authz.Client,
in HasProtoContext,
) (uuid.UUID, error) {
// Prefer the context message from the protobuf
if in.GetContext().GetProject() != "" {
requestedProject := in.GetContext().GetProject()
parsedProjectID, err := uuid.Parse(requestedProject)
if err != nil {
return uuid.UUID{}, util.UserVisibleError(codes.InvalidArgument, "malformed project ID")
}
return parsedProjectID, nil
}

subject := auth.GetUserSubjectFromContext(ctx)

userInfo, err := store.GetUserBySubject(ctx, subject)
Expand Down
Loading

0 comments on commit db0f5a4

Please sign in to comment.