Skip to content

Commit

Permalink
Improve error handling for GetCredentialStateForProvider (#3373)
Browse files Browse the repository at this point in the history
The `GetCredentialStateForProvider` swallows an error since it was
originally written for a codepath where the error was not important.
Since then, it is used in a second place where the error is important.
I'm not sure if it's possible for the error to get triggered in
practice, but I managed to trigger the error path while writing unit
tests, which led to a strange and hard to debug error happening after
the error was skipped over.

Refactor this function to return an error. Ignore it in the call site
where it should be ignored, and react to it where it is significant. As
a nice side effect of this change, the control flow of
`GetCredentialStateForProvider` is significantly simplified.

Some unrelated unit test improvements from an ongoing development branch
have also been added to this branch.
  • Loading branch information
dmjb authored May 20, 2024
1 parent 12c5e04 commit c886c5d
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 26 deletions.
8 changes: 7 additions & 1 deletion internal/controlplane/handlers_providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,14 +199,20 @@ func protobufProviderFromDB(
}
}

state, err := providers.GetCredentialStateForProvider(ctx, *p, store, cryptoEngine, pc)
if err != nil {
// This is non-fatal
zerolog.Ctx(ctx).Error().Err(err).Str("provider", p.Name).Msg("error getting credential")
}

return &minderv1.Provider{
Name: p.Name,
Project: p.ProjectID.String(),
Version: p.Version,
Implements: protobufProviderImplementsFromDB(ctx, *p),
AuthFlows: protobufProviderAuthFlowFromDB(ctx, *p),
Config: cfg,
CredentialsState: providers.GetCredentialStateForProvider(ctx, *p, store, cryptoEngine, pc),
CredentialsState: state,
Class: string(p.Class),
}, nil
}
Expand Down
12 changes: 8 additions & 4 deletions internal/db/provider_access_tokens_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,21 @@ func TestUpsertProviderAccessToken(t *testing.T) {
require.Equal(t, secret, deserializeSecret(t, tok.EncryptedAccessToken))
require.Equal(t, sql.NullString{}, tok.OwnerFilter)

newSecret := createSecret(t, "def")
newSerialized := serializeSecret(t, newSecret)
tokUpdate, err := testQueries.UpsertAccessToken(context.Background(), UpsertAccessTokenParams{
ProjectID: project.ID,
Provider: prov.Name,
EncryptedToken: "def",
OwnerFilter: sql.NullString{},
ProjectID: project.ID,
Provider: prov.Name,
EncryptedToken: "def",
EncryptedAccessToken: newSerialized,
OwnerFilter: sql.NullString{},
})

require.NoError(t, err)
require.Equal(t, project.ID, tokUpdate.ProjectID)
require.Equal(t, prov.Name, tokUpdate.Provider)
require.Equal(t, "def", tokUpdate.EncryptedToken)
require.Equal(t, newSecret, deserializeSecret(t, tokUpdate.EncryptedAccessToken))
require.Equal(t, sql.NullString{}, tokUpdate.OwnerFilter)
require.Equal(t, tok.ID, tokUpdate.ID)
require.Equal(t, tok.CreatedAt, tokUpdate.CreatedAt)
Expand Down
5 changes: 4 additions & 1 deletion internal/providers/github/manager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,10 @@ func (g *githubProviderManager) Build(ctx context.Context, config *db.Provider)
}

func (g *githubProviderManager) Delete(ctx context.Context, config *db.Provider) error {
state := providers.GetCredentialStateForProvider(ctx, *config, g.store, g.crypteng, g.config)
state, err := providers.GetCredentialStateForProvider(ctx, *config, g.store, g.crypteng, g.config)
if err != nil {
return fmt.Errorf("unable to get credential state for provider %s: %w", config.ID, err)
}
if state == v1.CredentialStateSet {
provider, err := g.Build(ctx, config)
if err != nil {
Expand Down
8 changes: 8 additions & 0 deletions internal/providers/github/service/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,10 @@ func TestProviderService_CreateGitHubOAuthProvider(t *testing.T) {
dbToken, err := mocks.fakeStore.GetAccessTokenByProvider(context.Background(), dbProv.Name)
require.NoError(t, err)
require.Len(t, dbToken, 1)
require.True(t, dbToken[0].EncryptedAccessToken.Valid)
deserialized, err := crypto.DeserializeEncryptedData(dbToken[0].EncryptedAccessToken.RawMessage)
require.NoError(t, err)
require.Equal(t, deserialized, encryptedToken)
require.Equal(t, dbToken[0].EncryptedToken, encryptedToken.EncodedData)
require.Equal(t, dbToken[0].OwnerFilter, sql.NullString{String: "testorg", Valid: true})
require.Equal(t, dbToken[0].EnrollmentNonce, sql.NullString{String: stateNonce, Valid: true})
Expand Down Expand Up @@ -229,6 +233,10 @@ func TestProviderService_CreateGitHubOAuthProvider(t *testing.T) {
dbTokenUpdate, err := mocks.fakeStore.GetAccessTokenByProvider(context.Background(), dbProv.Name)
require.NoError(t, err)
require.Len(t, dbTokenUpdate, 1)
require.True(t, dbToken[0].EncryptedAccessToken.Valid)
deserialized, err = crypto.DeserializeEncryptedData(dbToken[0].EncryptedAccessToken.RawMessage)
require.NoError(t, err)
require.Equal(t, deserialized, encryptedToken)
require.Equal(t, dbTokenUpdate[0].EncryptedToken, encryptedToken.EncodedData)
require.Equal(t, dbTokenUpdate[0].OwnerFilter, sql.NullString{String: "testorg", Valid: true})
require.Equal(t, dbTokenUpdate[0].EnrollmentNonce, sql.NullString{String: stateNonceUpdate, Valid: true})
Expand Down
2 changes: 1 addition & 1 deletion internal/providers/manager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func (p *providerManager) deleteByRecord(ctx context.Context, config *db.Provide
}

// carry out provider-specific cleanup
if err := manager.Delete(ctx, config); err != nil {
if err = manager.Delete(ctx, config); err != nil {
return fmt.Errorf("error while cleaning up provider: %w", err)
}

Expand Down
35 changes: 16 additions & 19 deletions internal/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,30 +78,27 @@ func GetCredentialStateForProvider(
s db.Store,
cryptoEngine crypto.Engine,
provCfg *serverconfig.ProviderConfig,
) string {
var credState string
) (string, error) {
// if the provider doesn't support any auth flow
// credentials state is not applicable
if slices.Equal(prov.AuthFlows, []db.AuthorizationFlow{db.AuthorizationFlowNone}) {
credState = provinfv1.CredentialStateNotApplicable
} else {
credState = provinfv1.CredentialStateUnset
cred, err := getCredentialForProvider(ctx, prov, cryptoEngine, s, provCfg)
if err != nil {
// This is non-fatal
zerolog.Ctx(ctx).Error().Err(err).Str("provider", prov.Name).Msg("error getting credential")
} else {
// check if the credential is EmptyCredential
// if it is, then the state is not applicable
if _, ok := cred.(*credentials.EmptyCredential); ok {
credState = provinfv1.CredentialStateUnset
} else {
credState = provinfv1.CredentialStateSet
}
}
return provinfv1.CredentialStateNotApplicable, nil
}

cred, err := getCredentialForProvider(ctx, prov, cryptoEngine, s, provCfg)
if err != nil {
// One of the callers of this function treats this error as non-fatal
// and uses the credState value.
return provinfv1.CredentialStateUnset, err
}

// check if the credential is EmptyCredential
// if it is, then the state is not applicable
if _, ok := cred.(*credentials.EmptyCredential); ok {
return provinfv1.CredentialStateUnset, nil
}

return credState
return provinfv1.CredentialStateSet, nil
}

func getCredentialForProvider(
Expand Down

0 comments on commit c886c5d

Please sign in to comment.