Skip to content

Commit

Permalink
Merge pull request #724 from KeyboardNerd/ref
Browse files Browse the repository at this point in the history
database: move db logic to dbutil
  • Loading branch information
KeyboardNerd authored Mar 6, 2019
2 parents 1f0bc1e + 1b9ed99 commit b3fe95e
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 146 deletions.
47 changes: 12 additions & 35 deletions api/v3/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
package v3

import (
"fmt"

"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand All @@ -25,7 +23,6 @@ import (
pb "github.com/coreos/clair/api/v3/clairpb"
"github.com/coreos/clair/database"
"github.com/coreos/clair/ext/imagefmt"
"github.com/coreos/clair/pkg/commonerr"
"github.com/coreos/clair/pkg/pagination"
)

Expand Down Expand Up @@ -128,20 +125,13 @@ func (s *AncestryServer) GetAncestry(ctx context.Context, req *pb.GetAncestryReq
return nil, status.Errorf(codes.InvalidArgument, "ancestry name should not be empty")
}

tx, err := s.Store.Begin()
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}

defer tx.Rollback()

ancestry, ok, err := tx.FindAncestry(name)
ancestry, ok, err := database.FindAncestryAndRollback(s.Store, name)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
return nil, newRPCErrorWithClairError(codes.Internal, err)
}

if !ok {
return nil, status.Error(codes.NotFound, fmt.Sprintf("requested ancestry '%s' is not found", req.GetAncestryName()))
return nil, status.Errorf(codes.NotFound, "requested ancestry '%s' is not found", req.GetAncestryName())
}

pbAncestry := &pb.GetAncestryResponse_Ancestry{
Expand All @@ -150,7 +140,7 @@ func (s *AncestryServer) GetAncestry(ctx context.Context, req *pb.GetAncestryReq
}

for _, layer := range ancestry.Layers {
pbLayer, err := GetPbAncestryLayer(tx, layer)
pbLayer, err := s.GetPbAncestryLayer(layer)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -180,25 +170,20 @@ func (s *NotificationServer) GetNotification(ctx context.Context, req *pb.GetNot
return nil, status.Error(codes.InvalidArgument, "notification page limit should not be empty or less than 1")
}

tx, err := s.Store.Begin()
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
defer tx.Rollback()

dbNotification, ok, err := tx.FindVulnerabilityNotification(
dbNotification, ok, err := database.FindVulnerabilityNotificationAndRollback(
s.Store,
req.GetName(),
int(req.GetLimit()),
pagination.Token(req.GetOldVulnerabilityPage()),
pagination.Token(req.GetNewVulnerabilityPage()),
)

if err != nil {
return nil, status.Error(codes.Internal, err.Error())
return nil, newRPCErrorWithClairError(codes.Internal, err)
}

if !ok {
return nil, status.Error(codes.NotFound, fmt.Sprintf("requested notification '%s' is not found", req.GetName()))
return nil, status.Errorf(codes.NotFound, "requested notification '%s' is not found", req.GetName())
}

notification, err := pb.NotificationFromDatabaseModel(dbNotification)
Expand All @@ -216,21 +201,13 @@ func (s *NotificationServer) MarkNotificationAsRead(ctx context.Context, req *pb
return nil, status.Error(codes.InvalidArgument, "notification name should not be empty")
}

tx, err := s.Store.Begin()
found, err := database.MarkNotificationAsReadAndCommit(s.Store, req.GetName())
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}

defer tx.Rollback()
err = tx.DeleteNotification(req.GetName())
if err == commonerr.ErrNotFound {
return nil, status.Error(codes.NotFound, "requested notification \""+req.GetName()+"\" is not found")
} else if err != nil {
return nil, status.Error(codes.Internal, err.Error())
return nil, newRPCErrorWithClairError(codes.Internal, err)
}

if err := tx.Commit(); err != nil {
return nil, status.Error(codes.Internal, err.Error())
if !found {
return nil, status.Errorf(codes.NotFound, "requested notification '%s' is not found", req.GetName())
}

return &pb.MarkNotificationAsReadResponse{}, nil
Expand Down
12 changes: 4 additions & 8 deletions api/v3/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,26 +47,22 @@ func GetClairStatus(store database.Datastore) (*pb.ClairStatus, error) {

// GetPbAncestryLayer retrieves an ancestry layer with vulnerabilities and
// features in an ancestry based on the provided database layer.
func GetPbAncestryLayer(tx database.Session, layer database.AncestryLayer) (*pb.GetAncestryResponse_AncestryLayer, error) {
func (s *AncestryServer) GetPbAncestryLayer(layer database.AncestryLayer) (*pb.GetAncestryResponse_AncestryLayer, error) {
pbLayer := &pb.GetAncestryResponse_AncestryLayer{
Layer: &pb.Layer{
Hash: layer.Hash,
},
}

features := layer.GetFeatures()
affectedFeatures, err := tx.FindAffectedNamespacedFeatures(features)
affectedFeatures, err := database.FindAffectedNamespacedFeaturesAndRollback(s.Store, features)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
return nil, newRPCErrorWithClairError(codes.Internal, err)
}

// NOTE(sidac): It's quite inefficient, but the easiest way to implement
// this feature for now, we should refactor the implementation if there's
// any performance issue. It's expected that the number of features is less
// than 1000.
for _, feature := range affectedFeatures {
if !feature.Valid {
return nil, status.Error(codes.Internal, "ancestry feature is not found")
panic("feature is missing in the database, it indicates the database is corrupted.")
}

for _, detectedFeature := range layer.Features {
Expand Down
124 changes: 124 additions & 0 deletions database/dbutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (

log "github.com/sirupsen/logrus"

"github.com/coreos/clair/pkg/commonerr"
"github.com/coreos/clair/pkg/pagination"
"github.com/deckarep/golang-set"
)

Expand Down Expand Up @@ -400,3 +402,125 @@ func PersistDetectorsAndCommit(store Datastore, detectors []Detector) error {

return nil
}

// MarkNotificationAsReadAndCommit marks a notification as read.
func MarkNotificationAsReadAndCommit(store Datastore, name string) (bool, error) {
tx, err := store.Begin()
if err != nil {
return false, err
}

defer tx.Rollback()
err = tx.DeleteNotification(name)
if err == commonerr.ErrNotFound {
return false, nil
} else if err != nil {
return false, err
}

if err := tx.Commit(); err != nil {
return false, err
}

return true, nil
}

// FindAffectedNamespacedFeaturesAndRollback finds the vulnerabilities on each
// feature.
func FindAffectedNamespacedFeaturesAndRollback(store Datastore, features []NamespacedFeature) ([]NullableAffectedNamespacedFeature, error) {
tx, err := store.Begin()
if err != nil {
return nil, err
}

defer tx.Rollback()
nullableFeatures, err := tx.FindAffectedNamespacedFeatures(features)
if err != nil {
return nil, err
}

return nullableFeatures, nil
}

// FindVulnerabilityNotificationAndRollback finds the vulnerability notification
// and rollback.
func FindVulnerabilityNotificationAndRollback(store Datastore, name string, limit int, oldVulnerabilityPage pagination.Token, newVulnerabilityPage pagination.Token) (VulnerabilityNotificationWithVulnerable, bool, error) {
tx, err := store.Begin()
if err != nil {
return VulnerabilityNotificationWithVulnerable{}, false, err
}

defer tx.Rollback()
return tx.FindVulnerabilityNotification(name, limit, oldVulnerabilityPage, newVulnerabilityPage)
}

// FindNewNotification finds notifications either never notified or notified
// before the given time.
func FindNewNotification(store Datastore, notifiedBefore time.Time) (NotificationHook, bool, error) {
tx, err := store.Begin()
if err != nil {
return NotificationHook{}, false, err
}

defer tx.Rollback()
return tx.FindNewNotification(notifiedBefore)
}

// UpdateKeyValueAndCommit stores the key value to storage.
func UpdateKeyValueAndCommit(store Datastore, key, value string) error {
tx, err := store.Begin()
if err != nil {
return err
}

defer tx.Rollback()
if err = tx.UpdateKeyValue(key, value); err != nil {
return err
}

return tx.Commit()
}

// InsertVulnerabilityNotificationsAndCommit inserts the notifications into db
// and commit.
func InsertVulnerabilityNotificationsAndCommit(store Datastore, notifications []VulnerabilityNotification) error {
tx, err := store.Begin()
if err != nil {
return err
}
defer tx.Rollback()

if err := tx.InsertVulnerabilityNotifications(notifications); err != nil {
return err
}

return tx.Commit()
}

// FindVulnerabilitiesAndRollback finds the vulnerabilities based on given ids.
func FindVulnerabilitiesAndRollback(store Datastore, ids []VulnerabilityID) ([]NullableVulnerability, error) {
tx, err := store.Begin()
if err != nil {
return nil, err
}

defer tx.Rollback()
return tx.FindVulnerabilities(ids)
}

func UpdateVulnerabilitiesAndCommit(store Datastore, toRemove []VulnerabilityID, toAdd []VulnerabilityWithAffected) error {
tx, err := store.Begin()
if err != nil {
return err
}

if err := tx.DeleteVulnerabilities(toRemove); err != nil {
return err
}

if err := tx.InsertVulnerabilities(toAdd); err != nil {
return err
}

return tx.Commit()
}
8 changes: 1 addition & 7 deletions ext/vulnsrc/suse/suse.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,10 @@ func init() {
func (u *updater) Update(datastore database.Datastore) (resp vulnsrc.UpdateResponse, err error) {
log.WithField("package", u.Name).Info("Start fetching vulnerabilities")

tx, err := datastore.Begin()
if err != nil {
return resp, err
}
defer tx.Rollback()

// openSUSE and SUSE have one single xml file for all the products, there are no incremental
// xml files. We store into the database the value of the generation timestamp
// of the latest file we parsed.
flagValue, ok, err := tx.FindKeyValue(u.UpdaterFlag)
flagValue, ok, err := database.FindKeyValueAndRollback(datastore, u.UpdaterFlag)
if err != nil {
return resp, err
}
Expand Down
7 changes: 0 additions & 7 deletions ext/vulnsrc/ubuntu/ubuntu.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,6 @@ func (u *updater) Update(db database.Datastore) (resp vulnsrc.UpdateResponse, er
return resp, err
}

// Open a database transaction.
tx, err := db.Begin()
if err != nil {
return resp, err
}
defer tx.Rollback()

// Ask the database for the latest commit we successfully applied.
dbCommit, ok, err := database.FindKeyValueAndRollback(db, updaterFlag)
if err != nil {
Expand Down
26 changes: 2 additions & 24 deletions notifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func RunNotifier(config *notification.Config, datastore database.Datastore, stop
go func() {
success, interrupted := handleTask(*notification, stopper, config.Attempts)
if success {
err := markNotificationAsRead(datastore, notification.Name)
_, err := database.MarkNotificationAsReadAndCommit(datastore, notification.Name)
if err != nil {
log.WithError(err).Error("Failed to mark notification notified")
}
Expand Down Expand Up @@ -126,7 +126,7 @@ func RunNotifier(config *notification.Config, datastore database.Datastore, stop

func findTask(datastore database.Datastore, renotifyInterval time.Duration, whoAmI string, stopper *stopper.Stopper) *database.NotificationHook {
for {
notification, ok, err := findNewNotification(datastore, renotifyInterval)
notification, ok, err := database.FindNewNotification(datastore, time.Now().Add(-renotifyInterval))
if err != nil || !ok {
if !ok {
log.WithError(err).Warning("could not get notification to send")
Expand Down Expand Up @@ -186,25 +186,3 @@ func handleTask(n database.NotificationHook, st *stopper.Stopper, maxAttempts in
log.WithField(logNotiName, n.Name).Info("successfully sent notification")
return true, false
}

func findNewNotification(datastore database.Datastore, renotifyInterval time.Duration) (database.NotificationHook, bool, error) {
tx, err := datastore.Begin()
if err != nil {
return database.NotificationHook{}, false, err
}
defer tx.Rollback()
return tx.FindNewNotification(time.Now().Add(-renotifyInterval))
}

func markNotificationAsRead(datastore database.Datastore, name string) error {
tx, err := datastore.Begin()
if err != nil {
log.WithError(err).Error("an error happens when beginning database transaction")
}
defer tx.Rollback()

if err := tx.MarkNotificationAsRead(name); err != nil {
return err
}
return tx.Commit()
}
Loading

0 comments on commit b3fe95e

Please sign in to comment.