Skip to content
This repository has been archived by the owner on Oct 14, 2024. It is now read-only.

Commit

Permalink
Refactor revision management to reduce duplication
Browse files Browse the repository at this point in the history
Adds new helper functions which handle checking the revision against
IfMatch and also bumps the revision or sets the initial value if its not
set.
  • Loading branch information
Tehsmash committed May 25, 2023
1 parent fa6e4ea commit 2063c38
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 92 deletions.
19 changes: 19 additions & 0 deletions backend/pkg/database/gorm/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
jsonpatch "github.com/evanphx/json-patch"

"github.com/openclarity/vmclarity/backend/pkg/database/types"
"github.com/openclarity/vmclarity/shared/pkg/utils"
)

func getExistingObjByID(db *gorm.DB, schema, objID string, obj interface{}) error {
Expand Down Expand Up @@ -66,3 +67,21 @@ func patchObject(original []byte, newobject interface{}) ([]byte, error) {

return updated, nil
}

func checkRevisionEtag(ifMatch *int, revision *int) error {
if (ifMatch != nil && revision != nil && *ifMatch != *revision) || (ifMatch != nil && revision == nil) {
return &types.PreconditionFailedError{
Reason: fmt.Sprintf(
"Revision %d does not match %d. The object may have been modified since you started the request.",
*revision, *ifMatch),
}
}
return nil
}

func bumpRevision(oldrevision *int) *int {
if oldrevision != nil {
return utils.PointerTo(*oldrevision + 1)
}
return utils.PointerTo(1)
}
6 changes: 5 additions & 1 deletion backend/pkg/database/gorm/odata.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ var schemaMetas = map[string]odatasql.SchemaMeta{
targetScanResultsSchemaName: {
Table: "scan_results",
Fields: odatasql.Schema{
"id": odatasql.FieldMeta{FieldType: odatasql.PrimitiveFieldType},
"id": odatasql.FieldMeta{FieldType: odatasql.PrimitiveFieldType},
"revision": odatasql.FieldMeta{FieldType: odatasql.PrimitiveFieldType},
"target": odatasql.FieldMeta{
FieldType: odatasql.RelationshipFieldType,
RelationshipSchema: targetSchemaName,
Expand Down Expand Up @@ -265,6 +266,7 @@ var schemaMetas = map[string]odatasql.SchemaMeta{
Table: "scans",
Fields: odatasql.Schema{
"id": odatasql.FieldMeta{FieldType: odatasql.PrimitiveFieldType},
"revision": odatasql.FieldMeta{FieldType: odatasql.PrimitiveFieldType},
"startTime": odatasql.FieldMeta{FieldType: odatasql.PrimitiveFieldType},
"endTime": odatasql.FieldMeta{FieldType: odatasql.PrimitiveFieldType},
"scanConfig": odatasql.FieldMeta{
Expand Down Expand Up @@ -317,6 +319,7 @@ var schemaMetas = map[string]odatasql.SchemaMeta{
Table: "targets",
Fields: odatasql.Schema{
"id": odatasql.FieldMeta{FieldType: odatasql.PrimitiveFieldType},
"revision": odatasql.FieldMeta{FieldType: odatasql.PrimitiveFieldType},
"scansCount": odatasql.FieldMeta{FieldType: odatasql.PrimitiveFieldType},
"targetInfo": odatasql.FieldMeta{
FieldType: odatasql.ComplexFieldType,
Expand Down Expand Up @@ -387,6 +390,7 @@ var schemaMetas = map[string]odatasql.SchemaMeta{
Table: "scan_configs",
Fields: odatasql.Schema{
"id": odatasql.FieldMeta{FieldType: odatasql.PrimitiveFieldType},
"revision": odatasql.FieldMeta{FieldType: odatasql.PrimitiveFieldType},
"name": odatasql.FieldMeta{FieldType: odatasql.PrimitiveFieldType},
"disabled": odatasql.FieldMeta{FieldType: odatasql.PrimitiveFieldType},
"scanFamiliesConfig": odatasql.FieldMeta{
Expand Down
28 changes: 6 additions & 22 deletions backend/pkg/database/gorm/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,8 @@ func (s *ScansTableHandler) SaveScan(scan models.Scan, params models.PutScansSca
return models.Scan{}, fmt.Errorf("failed to convert DB object to API model: %w", err)
}

if (params.IfMatch != nil && dbScan.Revision != nil && *params.IfMatch != *dbScan.Revision) || (params.IfMatch != nil && dbScan.Revision == nil) {
return models.Scan{}, &types.PreconditionFailedError{
Reason: fmt.Sprintf(
"Revision %d does not match %d. The object may have been modified since you started the request.",
*dbScan.Revision, *params.IfMatch),
}
if err := checkRevisionEtag(params.IfMatch, dbScan.Revision); err != nil {
return models.Scan{}, err
}

if err := validateScanConfigID(scan, dbScan); err != nil {
Expand All @@ -189,11 +185,7 @@ func (s *ScansTableHandler) SaveScan(scan models.Scan, params models.PutScansSca
}
}

if dbScan.Revision != nil {
scan.Revision = utils.PointerTo(*dbScan.Revision + 1)
} else {
scan.Revision = utils.PointerTo(1)
}
scan.Revision = bumpRevision(dbScan.Revision)

marshaled, err := json.Marshal(scan)
if err != nil {
Expand Down Expand Up @@ -232,12 +224,8 @@ func (s *ScansTableHandler) UpdateScan(scan models.Scan, params models.PatchScan
return models.Scan{}, fmt.Errorf("failed to convert DB object to API model: %w", err)
}

if (params.IfMatch != nil && dbScan.Revision != nil && *params.IfMatch != *dbScan.Revision) || (params.IfMatch != nil && dbScan.Revision == nil) {
return models.Scan{}, &types.PreconditionFailedError{
Reason: fmt.Sprintf(
"Revision %d does not match %d. The object may have been modified since you started the request.",
*dbScan.Revision, *params.IfMatch),
}
if err := checkRevisionEtag(params.IfMatch, dbScan.Revision); err != nil {
return models.Scan{}, err
}

if err := validateScanConfigID(scan, dbScan); err != nil {
Expand All @@ -248,11 +236,7 @@ func (s *ScansTableHandler) UpdateScan(scan models.Scan, params models.PatchScan
return models.Scan{}, fmt.Errorf("scan config id validation failed: %w", err)
}

if dbScan.Revision != nil {
scan.Revision = utils.PointerTo(*dbScan.Revision + 1)
} else {
scan.Revision = utils.PointerTo(1)
}
scan.Revision = bumpRevision(dbScan.Revision)

var err error
dbObj.Data, err = patchObject(dbObj.Data, scan)
Expand Down
30 changes: 7 additions & 23 deletions backend/pkg/database/gorm/scan_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,12 +232,8 @@ func (s *ScanConfigsTableHandler) SaveScanConfig(scanConfig models.ScanConfig, p
return models.ScanConfig{}, fmt.Errorf("failed to convert DB model to API model: %w", err)
}

if (params.IfMatch != nil && dbScanConfig.Revision != nil && *params.IfMatch != *dbScanConfig.Revision) || (params.IfMatch != nil && dbScanConfig.Revision == nil) {
return models.ScanConfig{}, &types.PreconditionFailedError{
Reason: fmt.Sprintf(
"Revision %d does not match %d. The object may have been modified since you started the request.",
*dbScanConfig.Revision, *params.IfMatch),
}
if err := checkRevisionEtag(params.IfMatch, dbScanConfig.Revision); err != nil {
return models.ScanConfig{}, err
}

// Check the existing DB entries to ensure that the name field is unique
Expand All @@ -250,11 +246,7 @@ func (s *ScanConfigsTableHandler) SaveScanConfig(scanConfig models.ScanConfig, p
return models.ScanConfig{}, fmt.Errorf("failed to check existing scan config: %w", err)
}

if dbScanConfig.Revision != nil {
scanConfig.Revision = utils.PointerTo(*dbScanConfig.Revision + 1)
} else {
scanConfig.Revision = utils.PointerTo(1)
}
scanConfig.Revision = bumpRevision(dbScanConfig.Revision)

marshaled, err := json.Marshal(scanConfig)
if err != nil {
Expand All @@ -263,7 +255,7 @@ func (s *ScanConfigsTableHandler) SaveScanConfig(scanConfig models.ScanConfig, p

dbObj.Data = marshaled

if err := s.DB.Save(&dbScanConfig).Error; err != nil {
if err := s.DB.Save(&dbObj).Error; err != nil {
return models.ScanConfig{}, fmt.Errorf("failed to save scan config in db: %w", err)
}

Expand Down Expand Up @@ -307,19 +299,11 @@ func (s *ScanConfigsTableHandler) UpdateScanConfig(scanConfig models.ScanConfig,
return models.ScanConfig{}, fmt.Errorf("failed to convert DB model to API model: %w", err)
}

if (params.IfMatch != nil && dbScanConfig.Revision != nil && *params.IfMatch != *dbScanConfig.Revision) || (params.IfMatch != nil && dbScanConfig.Revision == nil) {
return models.ScanConfig{}, &types.PreconditionFailedError{
Reason: fmt.Sprintf(
"Revision %d does not match %d. The object may have been modified since you started the request.",
*dbScanConfig.Revision, *params.IfMatch),
}
if err := checkRevisionEtag(params.IfMatch, dbScanConfig.Revision); err != nil {
return models.ScanConfig{}, err
}

if dbScanConfig.Revision != nil {
scanConfig.Revision = utils.PointerTo(*dbScanConfig.Revision + 1)
} else {
scanConfig.Revision = utils.PointerTo(1)
}
scanConfig.Revision = bumpRevision(dbScanConfig.Revision)

dbObj.Data, err = patchObject(dbObj.Data, scanConfig)
if err != nil {
Expand Down
30 changes: 7 additions & 23 deletions backend/pkg/database/gorm/scan_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,19 +210,11 @@ func (s *ScanResultsTableHandler) SaveScanResult(scanResult models.TargetScanRes
return models.TargetScanResult{}, fmt.Errorf("failed to convert DB model to API model: %w", err)
}

if (params.IfMatch != nil && dbScanResult.Revision != nil && *params.IfMatch != *dbScanResult.Revision) || (params.IfMatch != nil && dbScanResult.Revision == nil) {
return models.TargetScanResult{}, &types.PreconditionFailedError{
Reason: fmt.Sprintf(
"Revision %d does not match %d. The object may have been modified since you started the request.",
*dbScanResult.Revision, *params.IfMatch),
}
if err := checkRevisionEtag(params.IfMatch, dbScanResult.Revision); err != nil {
return models.TargetScanResult{}, err
}

if dbScanResult.Revision != nil {
scanResult.Revision = utils.PointerTo(*dbScanResult.Revision + 1)
} else {
scanResult.Revision = utils.PointerTo(1)
}
scanResult.Revision = bumpRevision(dbScanResult.Revision)

marshaled, err := json.Marshal(scanResult)
if err != nil {
Expand All @@ -231,7 +223,7 @@ func (s *ScanResultsTableHandler) SaveScanResult(scanResult models.TargetScanRes

dbObj.Data = marshaled

if err := s.DB.Save(&dbScanResult).Error; err != nil {
if err := s.DB.Save(&dbObj).Error; err != nil {
return models.TargetScanResult{}, fmt.Errorf("failed to save scan result in db: %w", err)
}

Expand Down Expand Up @@ -267,19 +259,11 @@ func (s *ScanResultsTableHandler) UpdateScanResult(scanResult models.TargetScanR
return models.TargetScanResult{}, fmt.Errorf("failed to convert DB model to API model: %w", err)
}

if (params.IfMatch != nil && dbScanResult.Revision != nil && *params.IfMatch != *dbScanResult.Revision) || (params.IfMatch != nil && dbScanResult.Revision == nil) {
return models.TargetScanResult{}, &types.PreconditionFailedError{
Reason: fmt.Sprintf(
"Revision %d does not match %d. The object may have been modified since you started the request.",
*dbScanResult.Revision, *params.IfMatch),
}
if err := checkRevisionEtag(params.IfMatch, dbScanResult.Revision); err != nil {
return models.TargetScanResult{}, err
}

if dbScanResult.Revision != nil {
scanResult.Revision = utils.PointerTo(*dbScanResult.Revision + 1)
} else {
scanResult.Revision = utils.PointerTo(1)
}
scanResult.Revision = bumpRevision(dbScanResult.Revision)

dbObj.Data, err = patchObject(dbObj.Data, scanResult)
if err != nil {
Expand Down
28 changes: 6 additions & 22 deletions backend/pkg/database/gorm/target.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,19 +189,11 @@ func (t *TargetsTableHandler) SaveTarget(target models.Target, params models.Put
return models.Target{}, fmt.Errorf("failed to convert DB model to API model: %w", err)
}

if (params.IfMatch != nil && dbTarget.Revision != nil && *params.IfMatch != *dbTarget.Revision) || (params.IfMatch != nil && dbTarget.Revision == nil) {
return models.Target{}, &types.PreconditionFailedError{
Reason: fmt.Sprintf(
"Revision %d does not match %d. The object may have been modified since you started the request.",
*dbTarget.Revision, *params.IfMatch),
}
if err := checkRevisionEtag(params.IfMatch, dbTarget.Revision); err != nil {
return models.Target{}, err
}

if dbTarget.Revision != nil {
target.Revision = utils.PointerTo(*dbTarget.Revision + 1)
} else {
target.Revision = utils.PointerTo(1)
}
target.Revision = bumpRevision(dbTarget.Revision)

existingTarget, err := t.checkUniqueness(target)
if err != nil {
Expand Down Expand Up @@ -252,19 +244,11 @@ func (t *TargetsTableHandler) UpdateTarget(target models.Target, params models.P
return models.Target{}, fmt.Errorf("failed to convert DB model to API model: %w", err)
}

if (params.IfMatch != nil && dbTarget.Revision != nil && *params.IfMatch != *dbTarget.Revision) || (params.IfMatch != nil && dbTarget.Revision == nil) {
return models.Target{}, &types.PreconditionFailedError{
Reason: fmt.Sprintf(
"Revision %d does not match %d. The object may have been modified since you started the request.",
*dbTarget.Revision, *params.IfMatch),
}
if err := checkRevisionEtag(params.IfMatch, dbTarget.Revision); err != nil {
return models.Target{}, err
}

if dbTarget.Revision != nil {
target.Revision = utils.PointerTo(*dbTarget.Revision + 1)
} else {
target.Revision = utils.PointerTo(1)
}
target.Revision = bumpRevision(dbTarget.Revision)

dbObj.Data, err = patchObject(dbObj.Data, target)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion backend/pkg/database/types/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type PreconditionFailedError struct {
}

func (e *PreconditionFailedError) Error() string {
return fmt.Sprintf("Precondition Failed: %s", e.Reason)
return fmt.Sprintf("Precondition failed: %s", e.Reason)
}

type DBConfig struct {
Expand Down

0 comments on commit 2063c38

Please sign in to comment.