Skip to content

Commit

Permalink
Merge pull request #71 from aserto-dev/validate-objects
Browse files Browse the repository at this point in the history
Validate objects
  • Loading branch information
ronenh authored Dec 22, 2023
2 parents e98c03a + 7c862bf commit ad69282
Show file tree
Hide file tree
Showing 11 changed files with 237 additions and 154 deletions.
2 changes: 1 addition & 1 deletion pkg/bdb/migrate/mig/mig.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func CreateBucket(path bdb.Path) func(*zerolog.Logger, *bolt.DB, *bolt.DB) error

func DeleteBucket(path bdb.Path) func(*zerolog.Logger, *bolt.DB, *bolt.DB) error {
return func(log *zerolog.Logger, _ *bolt.DB, rwDB *bolt.DB) error {
log.Info().Str("path", strings.Join(path, "/")).Msg("CreateBucket")
log.Info().Str("path", strings.Join(path, "/")).Msg("DeleteBucket")

if err := rwDB.Update(func(tx *bolt.Tx) error {
if len(path) == 1 {
Expand Down
29 changes: 20 additions & 9 deletions pkg/directory/v2/importer.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
v3 "github.com/aserto-dev/go-edge-ds/pkg/directory/v3"
"github.com/aserto-dev/go-edge-ds/pkg/ds"
"github.com/aserto-dev/go-edge-ds/pkg/session"
"github.com/bufbuild/protovalidate-go"

"github.com/google/uuid"
"github.com/rs/zerolog"
Expand All @@ -24,13 +25,16 @@ type Importer struct {
logger *zerolog.Logger
store *bdb.BoltDB
i3 dsi3.ImporterServer
v *protovalidate.Validator
}

func NewImporter(logger *zerolog.Logger, store *bdb.BoltDB, i3 *v3.Importer) *Importer {
v, _ := protovalidate.New()
return &Importer{
logger: logger,
store: store,
i3: i3,
v: v,
}
}

Expand Down Expand Up @@ -86,16 +90,18 @@ func (s *Importer) objectHandler(ctx context.Context, tx *bolt.Tx, req *dsc2.Obj
s.logger.Debug().Interface("object", req).Msg("import_object")

req3 := convert.ObjectToV3(req)

if req3 == nil {
return derr.ErrInvalidObject.Msg("nil")
if err := s.v.Validate(req3); err != nil {
// invalid proto message
return derr.ErrProtoValidate.Msg(err.Error())
}

if ok, err := ds.Object(req3).Validate(s.store.MC()); !ok {
obj := ds.Object(req3)
if err := obj.Validate(s.store.MC()); err != nil {
// The object violates the model.
return err
}

if _, err := bdb.Set[dsc3.Object](ctx, tx, bdb.ObjectsPath, ds.Object(req3).Key(), req3); err != nil {
if _, err := bdb.Set[dsc3.Object](ctx, tx, bdb.ObjectsPath, obj.Key(), req3); err != nil {
return derr.ErrInvalidObject.Msg("set")
}

Expand All @@ -106,16 +112,21 @@ func (s *Importer) relationHandler(ctx context.Context, tx *bolt.Tx, req *dsc2.R
s.logger.Debug().Interface("relation", req).Msg("import_relation")

req3 := convert.RelationToV3(req)
if err := s.v.Validate(req3); err != nil {
// invalid proto message
return derr.ErrProtoValidate.Msg(err.Error())
}

if req3 == nil {
return derr.ErrInvalidRelation.Msg("nil")
rel := ds.Relation(req3)
if err := rel.Validate(s.store.MC()); err != nil {
return err
}

if _, err := bdb.Set[dsc3.Relation](ctx, tx, bdb.RelationsObjPath, ds.Relation(req3).ObjKey(), req3); err != nil {
if _, err := bdb.Set[dsc3.Relation](ctx, tx, bdb.RelationsObjPath, rel.ObjKey(), req3); err != nil {
return derr.ErrInvalidRelation.Msg("set")
}

if _, err := bdb.Set[dsc3.Relation](ctx, tx, bdb.RelationsSubPath, ds.Relation(req3).SubKey(), req3); err != nil {
if _, err := bdb.Set[dsc3.Relation](ctx, tx, bdb.RelationsSubPath, rel.SubKey(), req3); err != nil {
return derr.ErrInvalidRelation.Msg("set")
}

Expand Down
40 changes: 29 additions & 11 deletions pkg/directory/v3/importer.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,18 +100,25 @@ func (s *Importer) objectSetHandler(ctx context.Context, tx *bolt.Tx, req *dsc3.
}

if err := s.v.Validate(req); err != nil {
// invalid proto message
return derr.ErrProtoValidate.Msg(err.Error())
}

etag := ds.Object(req).Hash()
obj := ds.Object(req)
if err := obj.Validate(s.store.MC()); err != nil {
// The object violates the model.
return err
}

etag := obj.Hash()

updReq, err := bdb.UpdateMetadata(ctx, tx, bdb.ObjectsPath, ds.Object(req).Key(), req)
updReq, err := bdb.UpdateMetadata(ctx, tx, bdb.ObjectsPath, obj.Key(), req)
if err != nil {
return err
}

if etag == updReq.Etag {
s.logger.Trace().Str("key", ds.Object(req).Key()).Str("etag-equal", etag).Msg("ImportObject")
s.logger.Trace().Str("key", obj.Key()).Str("etag-equal", etag).Msg("ImportObject")
return nil
}

Expand All @@ -135,7 +142,12 @@ func (s *Importer) objectDeleteHandler(ctx context.Context, tx *bolt.Tx, req *ds
return derr.ErrProtoValidate.Msg(err.Error())
}

if err := bdb.Delete(ctx, tx, bdb.ObjectsPath, ds.Object(req).Key()); err != nil {
obj := ds.Object(req)
if err := obj.Validate(s.store.MC()); err != nil {
return err
}

if err := bdb.Delete(ctx, tx, bdb.ObjectsPath, obj.Key()); err != nil {
return derr.ErrInvalidObject.Msg("delete")
}

Expand All @@ -150,23 +162,24 @@ func (s *Importer) relationSetHandler(ctx context.Context, tx *bolt.Tx, req *dsc
}

if err := s.v.Validate(req); err != nil {
// invalid proto message
return derr.ErrProtoValidate.Msg(err.Error())
}

if err := s.store.MC().ValidateRelation(req); err != nil {
// The relation violates the model.
rel := ds.Relation(req)
if err := rel.Validate(s.store.MC()); err != nil {
return err
}

etag := ds.Relation(req).Hash()
etag := rel.Hash()

updReq, err := bdb.UpdateMetadata(ctx, tx, bdb.RelationsObjPath, ds.Relation(req).ObjKey(), req)
updReq, err := bdb.UpdateMetadata(ctx, tx, bdb.RelationsObjPath, rel.ObjKey(), req)
if err != nil {
return err
}

if etag == updReq.Etag {
s.logger.Trace().Str("key", ds.Relation(req).ObjKey()).Str("etag-equal", etag).Msg("ImportRelation")
s.logger.Trace().Str("key", rel.ObjKey()).Str("etag-equal", etag).Msg("ImportRelation")
return nil
}

Expand Down Expand Up @@ -194,11 +207,16 @@ func (s *Importer) relationDeleteHandler(ctx context.Context, tx *bolt.Tx, req *
return derr.ErrProtoValidate.Msg(err.Error())
}

if err := bdb.Delete(ctx, tx, bdb.RelationsObjPath, ds.Relation(req).ObjKey()); err != nil {
rel := ds.Relation(req)
if err := rel.Validate(s.store.MC()); err != nil {
return err
}

if err := bdb.Delete(ctx, tx, bdb.RelationsObjPath, rel.ObjKey()); err != nil {
return derr.ErrInvalidRelation.Msg("delete")
}

if err := bdb.Delete(ctx, tx, bdb.RelationsSubPath, ds.Relation(req).SubKey()); err != nil {
if err := bdb.Delete(ctx, tx, bdb.RelationsSubPath, rel.SubKey()); err != nil {
return derr.ErrInvalidRelation.Msg("delete")
}

Expand Down
50 changes: 42 additions & 8 deletions pkg/directory/v3/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,13 @@ func (s *Reader) GetObject(ctx context.Context, req *dsr3.GetObjectRequest) (*ds
return resp, derr.ErrProtoValidate.Msg(err.Error())
}

objIdent := ds.ObjectIdentifier(&dsc3.ObjectIdentifier{ObjectType: req.ObjectType, ObjectId: req.ObjectId})
if err := objIdent.Validate(s.store.MC()); err != nil {
return resp, err
}

// TODO handle pagination request.
err := s.store.DB().View(func(tx *bolt.Tx) error {
objIdent := ds.ObjectIdentifier(&dsc3.ObjectIdentifier{ObjectType: req.ObjectType, ObjectId: req.ObjectId})
obj, err := bdb.Get[dsc3.Object](ctx, tx, bdb.ObjectsPath, objIdent.Key())
if err != nil {
return err
Expand Down Expand Up @@ -84,7 +88,7 @@ func (s *Reader) GetObjectMany(ctx context.Context, req *dsr3.GetObjectManyReque

// validate all object identifiers first.
for _, i := range req.Param {
if ok, err := ds.ObjectIdentifier(i).Validate(); !ok {
if err := ds.ObjectIdentifier(i).Validate(s.store.MC()); err != nil {
return resp, err
}
}
Expand Down Expand Up @@ -152,7 +156,12 @@ func (s *Reader) GetRelation(ctx context.Context, req *dsr3.GetRelationRequest)
return resp, derr.ErrProtoValidate.Msg(err.Error())
}

path, filter, err := ds.GetRelation(req).PathAndFilter()
getRelation := ds.GetRelation(req)
if err := getRelation.Validate(s.store.MC()); err != nil {
return resp, err
}

path, filter, err := getRelation.PathAndFilter()
if err != nil {
return resp, err
}
Expand Down Expand Up @@ -214,7 +223,12 @@ func (s *Reader) GetRelations(ctx context.Context, req *dsr3.GetRelationsRequest
req.Page = &dsc3.PaginationRequest{Size: 100}
}

path, keyFilter, valueFilter := ds.GetRelations(req).Filter()
getRelations := ds.GetRelations(req)
if err := getRelations.Validate(s.store.MC()); err != nil {
return resp, err
}

path, keyFilter, valueFilter := getRelations.Filter()

opts := []bdb.ScanOption{
bdb.WithPageToken(req.Page.Token),
Expand Down Expand Up @@ -277,9 +291,14 @@ func (s *Reader) Check(ctx context.Context, req *dsr3.CheckRequest) (*dsr3.Check
return resp, derr.ErrProtoValidate.Msg(err.Error())
}

check := ds.Check(req)
if err := check.Validate(s.store.MC()); err != nil {
return resp, err
}

err := s.store.DB().View(func(tx *bolt.Tx) error {
var err error
resp, err = ds.Check(req).Exec(ctx, tx, s.store.MC())
resp, err = check.Exec(ctx, tx, s.store.MC())
return err
})

Expand All @@ -294,9 +313,14 @@ func (s *Reader) CheckPermission(ctx context.Context, req *dsr3.CheckPermissionR
return resp, derr.ErrProtoValidate.Msg(err.Error())
}

check := ds.CheckPermission(req)
if err := check.Validate(s.store.MC()); err != nil {
return resp, err
}

err := s.store.DB().View(func(tx *bolt.Tx) error {
var err error
resp, err = ds.CheckPermission(req).Exec(ctx, tx, s.store.MC())
resp, err = check.Exec(ctx, tx, s.store.MC())
return err
})

Expand All @@ -311,9 +335,14 @@ func (s *Reader) CheckRelation(ctx context.Context, req *dsr3.CheckRelationReque
return resp, derr.ErrProtoValidate.Msg(err.Error())
}

check := ds.CheckRelation(req)
if err := check.Validate(s.store.MC()); err != nil {
return resp, err
}

err := s.store.DB().View(func(tx *bolt.Tx) error {
var err error
resp, err = ds.CheckRelation(req).Exec(ctx, tx, s.store.MC())
resp, err = check.Exec(ctx, tx, s.store.MC())
return err
})

Expand All @@ -328,9 +357,14 @@ func (s *Reader) GetGraph(ctx context.Context, req *dsr3.GetGraphRequest) (*dsr3
return &dsr3.GetGraphResponse{}, derr.ErrProtoValidate.Msg(err.Error())
}

getGraph := ds.GetGraph(req)
if err := getGraph.Validate(s.store.MC()); err != nil {
return resp, err
}

err := s.store.DB().View(func(tx *bolt.Tx) error {
var err error
results, err := ds.GetGraph(req).Exec(ctx, tx)
results, err := getGraph.Exec(ctx, tx)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit ad69282

Please sign in to comment.