diff --git a/pkg/bdb/migrate/mig/mig.go b/pkg/bdb/migrate/mig/mig.go index 2aa5e85..0b4fc92 100644 --- a/pkg/bdb/migrate/mig/mig.go +++ b/pkg/bdb/migrate/mig/mig.go @@ -85,7 +85,7 @@ func CreateBucket(path bdb.Path) func(*zerolog.Logger, *bolt.DB, *bolt.DB) error func DeleteBucket(path bdb.Path) func(*zerolog.Logger, *bolt.DB, *bolt.DB) error { return func(log *zerolog.Logger, _ *bolt.DB, rwDB *bolt.DB) error { - log.Info().Str("path", strings.Join(path, "/")).Msg("CreateBucket") + log.Info().Str("path", strings.Join(path, "/")).Msg("DeleteBucket") if err := rwDB.Update(func(tx *bolt.Tx) error { if len(path) == 1 { diff --git a/pkg/directory/v2/importer.go b/pkg/directory/v2/importer.go index ea8d425..bbb3721 100644 --- a/pkg/directory/v2/importer.go +++ b/pkg/directory/v2/importer.go @@ -14,6 +14,7 @@ import ( v3 "github.com/aserto-dev/go-edge-ds/pkg/directory/v3" "github.com/aserto-dev/go-edge-ds/pkg/ds" "github.com/aserto-dev/go-edge-ds/pkg/session" + "github.com/bufbuild/protovalidate-go" "github.com/google/uuid" "github.com/rs/zerolog" @@ -24,13 +25,16 @@ type Importer struct { logger *zerolog.Logger store *bdb.BoltDB i3 dsi3.ImporterServer + v *protovalidate.Validator } func NewImporter(logger *zerolog.Logger, store *bdb.BoltDB, i3 *v3.Importer) *Importer { + v, _ := protovalidate.New() return &Importer{ logger: logger, store: store, i3: i3, + v: v, } } @@ -86,16 +90,18 @@ func (s *Importer) objectHandler(ctx context.Context, tx *bolt.Tx, req *dsc2.Obj s.logger.Debug().Interface("object", req).Msg("import_object") req3 := convert.ObjectToV3(req) - - if req3 == nil { - return derr.ErrInvalidObject.Msg("nil") + if err := s.v.Validate(req3); err != nil { + // invalid proto message + return derr.ErrProtoValidate.Msg(err.Error()) } - if ok, err := ds.Object(req3).Validate(s.store.MC()); !ok { + obj := ds.Object(req3) + if err := obj.Validate(s.store.MC()); err != nil { + // The object violates the model. return err } - if _, err := bdb.Set[dsc3.Object](ctx, tx, bdb.ObjectsPath, ds.Object(req3).Key(), req3); err != nil { + if _, err := bdb.Set[dsc3.Object](ctx, tx, bdb.ObjectsPath, obj.Key(), req3); err != nil { return derr.ErrInvalidObject.Msg("set") } @@ -106,16 +112,21 @@ func (s *Importer) relationHandler(ctx context.Context, tx *bolt.Tx, req *dsc2.R s.logger.Debug().Interface("relation", req).Msg("import_relation") req3 := convert.RelationToV3(req) + if err := s.v.Validate(req3); err != nil { + // invalid proto message + return derr.ErrProtoValidate.Msg(err.Error()) + } - if req3 == nil { - return derr.ErrInvalidRelation.Msg("nil") + rel := ds.Relation(req3) + if err := rel.Validate(s.store.MC()); err != nil { + return err } - if _, err := bdb.Set[dsc3.Relation](ctx, tx, bdb.RelationsObjPath, ds.Relation(req3).ObjKey(), req3); err != nil { + if _, err := bdb.Set[dsc3.Relation](ctx, tx, bdb.RelationsObjPath, rel.ObjKey(), req3); err != nil { return derr.ErrInvalidRelation.Msg("set") } - if _, err := bdb.Set[dsc3.Relation](ctx, tx, bdb.RelationsSubPath, ds.Relation(req3).SubKey(), req3); err != nil { + if _, err := bdb.Set[dsc3.Relation](ctx, tx, bdb.RelationsSubPath, rel.SubKey(), req3); err != nil { return derr.ErrInvalidRelation.Msg("set") } diff --git a/pkg/directory/v3/importer.go b/pkg/directory/v3/importer.go index db666ac..d6a0738 100644 --- a/pkg/directory/v3/importer.go +++ b/pkg/directory/v3/importer.go @@ -100,18 +100,25 @@ func (s *Importer) objectSetHandler(ctx context.Context, tx *bolt.Tx, req *dsc3. } if err := s.v.Validate(req); err != nil { + // invalid proto message return derr.ErrProtoValidate.Msg(err.Error()) } - etag := ds.Object(req).Hash() + obj := ds.Object(req) + if err := obj.Validate(s.store.MC()); err != nil { + // The object violates the model. + return err + } + + etag := obj.Hash() - updReq, err := bdb.UpdateMetadata(ctx, tx, bdb.ObjectsPath, ds.Object(req).Key(), req) + updReq, err := bdb.UpdateMetadata(ctx, tx, bdb.ObjectsPath, obj.Key(), req) if err != nil { return err } if etag == updReq.Etag { - s.logger.Trace().Str("key", ds.Object(req).Key()).Str("etag-equal", etag).Msg("ImportObject") + s.logger.Trace().Str("key", obj.Key()).Str("etag-equal", etag).Msg("ImportObject") return nil } @@ -135,7 +142,12 @@ func (s *Importer) objectDeleteHandler(ctx context.Context, tx *bolt.Tx, req *ds return derr.ErrProtoValidate.Msg(err.Error()) } - if err := bdb.Delete(ctx, tx, bdb.ObjectsPath, ds.Object(req).Key()); err != nil { + obj := ds.Object(req) + if err := obj.Validate(s.store.MC()); err != nil { + return err + } + + if err := bdb.Delete(ctx, tx, bdb.ObjectsPath, obj.Key()); err != nil { return derr.ErrInvalidObject.Msg("delete") } @@ -150,23 +162,24 @@ func (s *Importer) relationSetHandler(ctx context.Context, tx *bolt.Tx, req *dsc } if err := s.v.Validate(req); err != nil { + // invalid proto message return derr.ErrProtoValidate.Msg(err.Error()) } - if err := s.store.MC().ValidateRelation(req); err != nil { - // The relation violates the model. + rel := ds.Relation(req) + if err := rel.Validate(s.store.MC()); err != nil { return err } - etag := ds.Relation(req).Hash() + etag := rel.Hash() - updReq, err := bdb.UpdateMetadata(ctx, tx, bdb.RelationsObjPath, ds.Relation(req).ObjKey(), req) + updReq, err := bdb.UpdateMetadata(ctx, tx, bdb.RelationsObjPath, rel.ObjKey(), req) if err != nil { return err } if etag == updReq.Etag { - s.logger.Trace().Str("key", ds.Relation(req).ObjKey()).Str("etag-equal", etag).Msg("ImportRelation") + s.logger.Trace().Str("key", rel.ObjKey()).Str("etag-equal", etag).Msg("ImportRelation") return nil } @@ -194,11 +207,16 @@ func (s *Importer) relationDeleteHandler(ctx context.Context, tx *bolt.Tx, req * return derr.ErrProtoValidate.Msg(err.Error()) } - if err := bdb.Delete(ctx, tx, bdb.RelationsObjPath, ds.Relation(req).ObjKey()); err != nil { + rel := ds.Relation(req) + if err := rel.Validate(s.store.MC()); err != nil { + return err + } + + if err := bdb.Delete(ctx, tx, bdb.RelationsObjPath, rel.ObjKey()); err != nil { return derr.ErrInvalidRelation.Msg("delete") } - if err := bdb.Delete(ctx, tx, bdb.RelationsSubPath, ds.Relation(req).SubKey()); err != nil { + if err := bdb.Delete(ctx, tx, bdb.RelationsSubPath, rel.SubKey()); err != nil { return derr.ErrInvalidRelation.Msg("delete") } diff --git a/pkg/directory/v3/reader.go b/pkg/directory/v3/reader.go index d6b07b7..e777516 100644 --- a/pkg/directory/v3/reader.go +++ b/pkg/directory/v3/reader.go @@ -37,9 +37,13 @@ func (s *Reader) GetObject(ctx context.Context, req *dsr3.GetObjectRequest) (*ds return resp, derr.ErrProtoValidate.Msg(err.Error()) } + objIdent := ds.ObjectIdentifier(&dsc3.ObjectIdentifier{ObjectType: req.ObjectType, ObjectId: req.ObjectId}) + if err := objIdent.Validate(s.store.MC()); err != nil { + return resp, err + } + // TODO handle pagination request. err := s.store.DB().View(func(tx *bolt.Tx) error { - objIdent := ds.ObjectIdentifier(&dsc3.ObjectIdentifier{ObjectType: req.ObjectType, ObjectId: req.ObjectId}) obj, err := bdb.Get[dsc3.Object](ctx, tx, bdb.ObjectsPath, objIdent.Key()) if err != nil { return err @@ -84,7 +88,7 @@ func (s *Reader) GetObjectMany(ctx context.Context, req *dsr3.GetObjectManyReque // validate all object identifiers first. for _, i := range req.Param { - if ok, err := ds.ObjectIdentifier(i).Validate(); !ok { + if err := ds.ObjectIdentifier(i).Validate(s.store.MC()); err != nil { return resp, err } } @@ -152,7 +156,12 @@ func (s *Reader) GetRelation(ctx context.Context, req *dsr3.GetRelationRequest) return resp, derr.ErrProtoValidate.Msg(err.Error()) } - path, filter, err := ds.GetRelation(req).PathAndFilter() + getRelation := ds.GetRelation(req) + if err := getRelation.Validate(s.store.MC()); err != nil { + return resp, err + } + + path, filter, err := getRelation.PathAndFilter() if err != nil { return resp, err } @@ -214,7 +223,12 @@ func (s *Reader) GetRelations(ctx context.Context, req *dsr3.GetRelationsRequest req.Page = &dsc3.PaginationRequest{Size: 100} } - path, keyFilter, valueFilter := ds.GetRelations(req).Filter() + getRelations := ds.GetRelations(req) + if err := getRelations.Validate(s.store.MC()); err != nil { + return resp, err + } + + path, keyFilter, valueFilter := getRelations.Filter() opts := []bdb.ScanOption{ bdb.WithPageToken(req.Page.Token), @@ -277,9 +291,14 @@ func (s *Reader) Check(ctx context.Context, req *dsr3.CheckRequest) (*dsr3.Check return resp, derr.ErrProtoValidate.Msg(err.Error()) } + check := ds.Check(req) + if err := check.Validate(s.store.MC()); err != nil { + return resp, err + } + err := s.store.DB().View(func(tx *bolt.Tx) error { var err error - resp, err = ds.Check(req).Exec(ctx, tx, s.store.MC()) + resp, err = check.Exec(ctx, tx, s.store.MC()) return err }) @@ -294,9 +313,14 @@ func (s *Reader) CheckPermission(ctx context.Context, req *dsr3.CheckPermissionR return resp, derr.ErrProtoValidate.Msg(err.Error()) } + check := ds.CheckPermission(req) + if err := check.Validate(s.store.MC()); err != nil { + return resp, err + } + err := s.store.DB().View(func(tx *bolt.Tx) error { var err error - resp, err = ds.CheckPermission(req).Exec(ctx, tx, s.store.MC()) + resp, err = check.Exec(ctx, tx, s.store.MC()) return err }) @@ -311,9 +335,14 @@ func (s *Reader) CheckRelation(ctx context.Context, req *dsr3.CheckRelationReque return resp, derr.ErrProtoValidate.Msg(err.Error()) } + check := ds.CheckRelation(req) + if err := check.Validate(s.store.MC()); err != nil { + return resp, err + } + err := s.store.DB().View(func(tx *bolt.Tx) error { var err error - resp, err = ds.CheckRelation(req).Exec(ctx, tx, s.store.MC()) + resp, err = check.Exec(ctx, tx, s.store.MC()) return err }) @@ -328,9 +357,14 @@ func (s *Reader) GetGraph(ctx context.Context, req *dsr3.GetGraphRequest) (*dsr3 return &dsr3.GetGraphResponse{}, derr.ErrProtoValidate.Msg(err.Error()) } + getGraph := ds.GetGraph(req) + if err := getGraph.Validate(s.store.MC()); err != nil { + return resp, err + } + err := s.store.DB().View(func(tx *bolt.Tx) error { var err error - results, err := ds.GetGraph(req).Exec(ctx, tx) + results, err := getGraph.Exec(ctx, tx) if err != nil { return err } diff --git a/pkg/directory/v3/writer.go b/pkg/directory/v3/writer.go index bb61642..203d1a5 100644 --- a/pkg/directory/v3/writer.go +++ b/pkg/directory/v3/writer.go @@ -35,26 +35,33 @@ func (s *Writer) SetObject(ctx context.Context, req *dsw3.SetObjectRequest) (*ds resp := &dsw3.SetObjectResponse{} if err := s.v.Validate(req); err != nil { + // invalid proto message. return resp, derr.ErrProtoValidate.Msg(err.Error()) } - etag := ds.Object(req.Object).Hash() + obj := ds.Object(req.Object) + if err := obj.Validate(s.store.MC()); err != nil { + // The object violates the model. + return resp, err + } + + etag := obj.Hash() err := s.store.DB().Update(func(tx *bolt.Tx) error { - updReq, err := bdb.UpdateMetadata(ctx, tx, bdb.ObjectsPath, ds.Object(req.Object).Key(), req.Object) + updReq, err := bdb.UpdateMetadata(ctx, tx, bdb.ObjectsPath, obj.Key(), req.Object) if err != nil { return err } if etag == updReq.Etag { - s.logger.Trace().Str("key", ds.Object(req.Object).Key()).Str("etag-equal", etag).Msg("set_object") + s.logger.Trace().Str("key", obj.Key()).Str("etag-equal", etag).Msg("set_object") resp.Result = updReq return nil } updReq.Etag = etag - objType, err := bdb.Set(ctx, tx, bdb.ObjectsPath, ds.Object(req.Object).Key(), updReq) + objType, err := bdb.Set(ctx, tx, bdb.ObjectsPath, obj.Key(), updReq) if err != nil { return err } @@ -73,16 +80,21 @@ func (s *Writer) DeleteObject(ctx context.Context, req *dsw3.DeleteObjectRequest return resp, derr.ErrProtoValidate.Msg(err.Error()) } + objIdent := ds.ObjectIdentifier(&dsc3.ObjectIdentifier{ObjectType: req.GetObjectType(), ObjectId: req.GetObjectId()}) + + if err := objIdent.Validate(s.store.MC()); err != nil { + return resp, err + } + err := s.store.DB().Update(func(tx *bolt.Tx) error { - objIdent := &dsc3.ObjectIdentifier{ObjectType: req.GetObjectType(), ObjectId: req.GetObjectId()} - if err := bdb.Delete(ctx, tx, bdb.ObjectsPath, ds.ObjectIdentifier(objIdent).Key()); err != nil { + if err := bdb.Delete(ctx, tx, bdb.ObjectsPath, objIdent.Key()); err != nil { return err } if req.GetWithRelations() { { // incoming object relations of object instance (result.type == incoming.subject.type && result.key == incoming.subject.key) - iter, err := bdb.NewScanIterator[dsc3.Relation](ctx, tx, bdb.RelationsSubPath, bdb.WithKeyFilter(ds.ObjectIdentifier(objIdent).Key()+ds.InstanceSeparator)) + iter, err := bdb.NewScanIterator[dsc3.Relation](ctx, tx, bdb.RelationsSubPath, bdb.WithKeyFilter(objIdent.Key()+ds.InstanceSeparator)) if err != nil { return err } @@ -100,7 +112,7 @@ func (s *Writer) DeleteObject(ctx context.Context, req *dsw3.DeleteObjectRequest } { // outgoing object relations of object instance (result.type == outgoing.object.type && result.key == outgoing.object.key) - iter, err := bdb.NewScanIterator[dsc3.Relation](ctx, tx, bdb.RelationsObjPath, bdb.WithKeyFilter(ds.ObjectIdentifier(objIdent).Key()+ds.InstanceSeparator)) + iter, err := bdb.NewScanIterator[dsc3.Relation](ctx, tx, bdb.RelationsObjPath, bdb.WithKeyFilter(objIdent.Key()+ds.InstanceSeparator)) if err != nil { return err } @@ -134,33 +146,33 @@ func (s *Writer) SetRelation(ctx context.Context, req *dsw3.SetRelationRequest) return resp, derr.ErrProtoValidate.Msg(err.Error()) } - if err := s.store.MC().ValidateRelation(req.Relation); err != nil { - // The relation violates the model. + relation := ds.Relation(req.Relation) + if err := relation.Validate(s.store.MC()); err != nil { return resp, err } - etag := ds.Relation(req.Relation).Hash() + etag := relation.Hash() err := s.store.DB().Update(func(tx *bolt.Tx) error { - updReq, err := bdb.UpdateMetadata(ctx, tx, bdb.RelationsObjPath, ds.Relation(req.Relation).ObjKey(), req.Relation) + updReq, err := bdb.UpdateMetadata(ctx, tx, bdb.RelationsObjPath, relation.ObjKey(), req.Relation) if err != nil { return err } if etag == updReq.Etag { - s.logger.Trace().Str("key", ds.Relation(req.Relation).ObjKey()).Str("etag-equal", etag).Msg("set_relation") + s.logger.Trace().Str("key", relation.ObjKey()).Str("etag-equal", etag).Msg("set_relation") resp.Result = updReq return nil } updReq.Etag = etag - objRel, err := bdb.Set(ctx, tx, bdb.RelationsObjPath, ds.Relation(req.Relation).ObjKey(), updReq) + objRel, err := bdb.Set(ctx, tx, bdb.RelationsObjPath, relation.ObjKey(), updReq) if err != nil { return err } - if _, err := bdb.Set(ctx, tx, bdb.RelationsSubPath, ds.Relation(req.Relation).SubKey(), updReq); err != nil { + if _, err := bdb.Set(ctx, tx, bdb.RelationsSubPath, relation.SubKey(), updReq); err != nil { return err } @@ -179,15 +191,19 @@ func (s *Writer) DeleteRelation(ctx context.Context, req *dsw3.DeleteRelationReq return resp, derr.ErrProtoValidate.Msg(err.Error()) } + rel := ds.Relation(&dsc3.Relation{ + ObjectType: req.ObjectType, + ObjectId: req.ObjectId, + Relation: req.Relation, + SubjectType: req.SubjectType, + SubjectId: req.SubjectId, + SubjectRelation: req.SubjectRelation, + }) + if err := rel.Validate(s.store.MC()); err != nil { + return resp, err + } + err := s.store.DB().Update(func(tx *bolt.Tx) error { - rel := ds.Relation(&dsc3.Relation{ - ObjectType: req.ObjectType, - ObjectId: req.ObjectId, - Relation: req.Relation, - SubjectType: req.SubjectType, - SubjectId: req.SubjectId, - SubjectRelation: req.SubjectRelation, - }) if err := bdb.Delete(ctx, tx, bdb.RelationsObjPath, rel.ObjKey()); err != nil { return err diff --git a/pkg/ds/check.go b/pkg/ds/check.go index 02e8fdd..9a2b4ed 100644 --- a/pkg/ds/check.go +++ b/pkg/ds/check.go @@ -36,24 +36,24 @@ func (i *check) Subject() *dsc3.ObjectIdentifier { } } -func (i *check) Validate(mc *cache.Cache) (bool, error) { +func (i *check) Validate(mc *cache.Cache) error { if i == nil || i.CheckRequest == nil { - return false, ErrInvalidRequest.Msg("check") + return ErrInvalidRequest.Msg("check") } if !mc.ObjectExists(model.ObjectName(i.ObjectType)) { - return false, ErrObjectNotFound.Msgf("object_type: %s", i.ObjectType) + return ErrObjectNotFound.Msgf("object_type: %s", i.ObjectType) } if !mc.ObjectExists(model.ObjectName(i.SubjectType)) { - return false, ErrObjectNotFound.Msgf("subject_type: %s", i.SubjectType) + return ErrObjectNotFound.Msgf("subject_type: %s", i.SubjectType) } if !mc.RelationExists(model.ObjectName(i.ObjectType), model.RelationName(i.Relation)) { - return false, ErrRelationNotFound.Msgf("relation: %s%s%s", i.ObjectType, RelationSeparator, i.Relation) + return ErrRelationNotFound.Msgf("relation: %s%s%s", i.ObjectType, RelationSeparator, i.Relation) } - return true, nil + return nil } func (i *check) Exec(ctx context.Context, tx *bolt.Tx, mc *cache.Cache) (*dsr3.CheckResponse, error) { diff --git a/pkg/ds/check_permission.go b/pkg/ds/check_permission.go index 604e4c3..fce4419 100644 --- a/pkg/ds/check_permission.go +++ b/pkg/ds/check_permission.go @@ -36,24 +36,24 @@ func (i *checkPermission) Subject() *dsc3.ObjectIdentifier { } } -func (i *checkPermission) Validate(mc *cache.Cache) (bool, error) { +func (i *checkPermission) Validate(mc *cache.Cache) error { if i == nil || i.CheckPermissionRequest == nil { - return false, ErrInvalidRequest.Msg("check_permission") + return ErrInvalidRequest.Msg("check_permission") } - if ok, err := ObjectIdentifier(i.Object()).Validate(); !ok { - return ok, err + if err := ObjectIdentifier(i.Object()).Validate(mc); err != nil { + return err } - if ok, err := ObjectIdentifier(i.Subject()).Validate(); !ok { - return ok, err + if err := ObjectIdentifier(i.Subject()).Validate(mc); err != nil { + return err } if !mc.PermissionExists(model.ObjectName(i.ObjectType), model.RelationName(i.Permission)) { - return false, ErrPermissionNotFound.Msgf("%s%s%s", i.ObjectType, RelationSeparator, i.Permission) + return ErrPermissionNotFound.Msgf("%s%s%s", i.ObjectType, RelationSeparator, i.Permission) } - return true, nil + return nil } func (i *checkPermission) Exec(ctx context.Context, tx *bolt.Tx, mc *cache.Cache) (*dsr3.CheckPermissionResponse, error) { diff --git a/pkg/ds/check_relation.go b/pkg/ds/check_relation.go index cd18854..7fd35df 100644 --- a/pkg/ds/check_relation.go +++ b/pkg/ds/check_relation.go @@ -42,24 +42,24 @@ func (i *checkRelation) Subject() *dsc3.ObjectIdentifier { } } -func (i *checkRelation) Validate(mc *cache.Cache) (bool, error) { +func (i *checkRelation) Validate(mc *cache.Cache) error { if i == nil || i.CheckRelationRequest == nil { - return false, ErrInvalidRequest.Msg("check_relation") + return ErrInvalidRequest.Msg("check_relation") } - if ok, err := ObjectIdentifier(i.Object()).Validate(); !ok { - return ok, err + if err := ObjectIdentifier(i.Object()).Validate(mc); err != nil { + return err } - if ok, err := ObjectIdentifier(i.Subject()).Validate(); !ok { - return ok, err + if err := ObjectIdentifier(i.Subject()).Validate(mc); err != nil { + return err } if !mc.RelationExists(model.ObjectName(i.ObjectType), model.RelationName(i.Relation)) { - return false, ErrRelationNotFound.Msgf("%s%s%s", i.ObjectType, RelationSeparator, i.Relation) + return ErrRelationNotFound.Msgf("%s%s%s", i.ObjectType, RelationSeparator, i.Relation) } - return true, nil + return nil } func (i *checkRelation) Exec(ctx context.Context, tx *bolt.Tx, mc *cache.Cache) (*dsr3.CheckRelationResponse, error) { diff --git a/pkg/ds/graph.go b/pkg/ds/graph.go index 3822479..3ccfc8e 100644 --- a/pkg/ds/graph.go +++ b/pkg/ds/graph.go @@ -44,31 +44,31 @@ func (i *getGraph) Subject() *dsc3.ObjectIdentifier { } } -func (i *getGraph) Validate(mc *cache.Cache) (bool, error) { +func (i *getGraph) Validate(mc *cache.Cache) error { if i == nil || i.GetGraphRequest == nil { - return false, ErrInvalidRequest.Msg("get_graph") + return ErrInvalidRequest.Msg("get_graph") } // anchor must be defined, hence use an ObjectIdentifier. - if ok, err := ObjectIdentifier(i.Anchor()).Validate(); !ok { - return ok, err + if err := ObjectIdentifier(i.Anchor()).Validate(mc); err != nil { + return err } // Object can be optional, hence the use of an ObjectSelector. - if ok, err := ObjectSelector(i.Object()).Validate(); !ok { - return ok, err + if err := ObjectSelector(i.Object()).Validate(mc); err != nil { + return err } // Relation can be optional, hence the use of a RelationTypeSelector. if i.GetRelation() != "" { if !mc.RelationExists(model.ObjectName(i.ObjectType), model.RelationName(i.Relation)) { - return false, ErrRelationNotFound.Msgf("%s%s%s", i.ObjectType, RelationSeparator, i.Relation) + return ErrRelationNotFound.Msgf("%s%s%s", i.ObjectType, RelationSeparator, i.Relation) } } // Subject can be option, hence the use of an ObjectSelector. - if ok, err := ObjectSelector(i.Subject()).Validate(); !ok { - return ok, err + if err := ObjectSelector(i.Subject()).Validate(mc); err != nil { + return err } // either Object or Subject must be equal to the Anchor to indicate the directionality of the graph walk. @@ -76,10 +76,10 @@ func (i *getGraph) Validate(mc *cache.Cache) (bool, error) { // Anchor == Object ==> object->subject if !ObjectIdentifier(i.Anchor()).Equal(i.Object()) && !ObjectIdentifier(i.Anchor()).Equal(i.Subject()) { - return false, ErrGraphDirectionality + return ErrGraphDirectionality } - return true, nil + return nil } func (i *getGraph) Exec(ctx context.Context, tx *bolt.Tx /*, resolver *cache.Cache*/) ([]*dsc3.ObjectDependency, error) { diff --git a/pkg/ds/object.go b/pkg/ds/object.go index 3936415..22d890b 100644 --- a/pkg/ds/object.go +++ b/pkg/ds/object.go @@ -31,34 +31,16 @@ func (i *object) Key() string { return i.GetType() + TypeIDSeparator + i.GetId() } -func (i *object) Validate(mc *cache.Cache) (bool, error) { - if i.Object == nil { - return false, ErrInvalidArgumentObject.Msg(objectIdentifierNil) - } - - // #1 check is type field is set. - if IsNotSet(i.GetType()) { - return false, ErrInvalidArgumentObject.Msg(objectIdentifierType) - } - - // #2 check if id field is set. - if IsNotSet(i.GetId()) { - return false, ErrInvalidArgumentObject.Msg(objectIdentifierID) - } - +func (i *object) Validate(mc *cache.Cache) error { if i.Properties == nil { i.Properties = pb.NewStruct() } - if mc == nil { - return true, nil + if mc != nil && !mc.ObjectExists(model.ObjectName(i.Object.Type)) { + return derr.ErrObjectTypeNotFound.Msg(i.Object.Type) } - if !mc.ObjectExists(model.ObjectName(i.Object.Type)) { - return false, derr.ErrObjectTypeNotFound.Msg(i.Object.Type) - } - - return true, nil + return nil } func (i *object) Hash() string { @@ -100,25 +82,27 @@ type objectIdentifier struct { func ObjectIdentifier(i *dsc3.ObjectIdentifier) *objectIdentifier { return &objectIdentifier{i} } -func (i *objectIdentifier) Validate() (bool, error) { +func (i *objectIdentifier) Validate(mc *cache.Cache) error { if i.ObjectIdentifier == nil { - return false, ErrInvalidArgumentObjectIdentifier.Msg(objectIdentifierNil) + return ErrInvalidArgumentObjectIdentifier.Msg(objectIdentifierNil) } // #1 check is type field is set. if IsNotSet(i.GetObjectType()) { - return false, ErrInvalidArgumentObjectIdentifier.Msg(objectIdentifierType) + return ErrInvalidArgumentObjectIdentifier.Msg(objectIdentifierType) } // #2 check if id field is set. if IsNotSet(i.GetObjectId()) { - return false, ErrInvalidArgumentObjectIdentifier.Msg(objectIdentifierID) + return ErrInvalidArgumentObjectIdentifier.Msg(objectIdentifierID) } - // #3 validate that type is defined in the type system. - // TODO: validate type existence against TypeSystem model. + // #3 check if type exists. + if mc != nil && !mc.ObjectExists(model.ObjectName(i.ObjectIdentifier.ObjectType)) { + return derr.ErrObjectTypeNotFound.Msg(i.ObjectIdentifier.ObjectType) + } - return true, nil + return nil } func (i *objectIdentifier) Key() string { @@ -144,28 +128,24 @@ func ObjectSelector(i *dsc3.ObjectIdentifier) *objectSelector { return &objectSe // - empty object // - type only // - type + key. -func (i *objectSelector) Validate() (bool, error) { +func (i *objectSelector) Validate(mc *cache.Cache) error { // nil not allowed if i.ObjectIdentifier == nil { - return false, ErrInvalidArgumentObjectTypeSelector.Msg(objectIdentifierNil) - } - - // empty object - if IsNotSet(i.GetObjectType()) && IsNotSet(i.GetObjectId()) { - return true, nil - } - - // type only - if IsSet(i.GetObjectType()) && IsNotSet(i.GetObjectId()) { - return true, nil + return ErrInvalidArgumentObjectTypeSelector.Msg(objectIdentifierNil) } - // type + key - if IsSet(i.GetObjectType()) && IsSet(i.GetObjectId()) { - return true, nil + switch { + case IsSet(i.GetObjectType()): + // check if type exists. + if mc != nil && !mc.ObjectExists(model.ObjectName(i.ObjectIdentifier.ObjectType)) { + return derr.ErrObjectTypeNotFound.Msg(i.ObjectIdentifier.ObjectType) + } + case IsSet(i.GetObjectId()): + // can't have id without type. + return ErrInvalidArgumentObjectTypeSelector.Msg(objectIdentifierType) } - return false, nil + return nil } func (i *objectSelector) IsComplete() bool { diff --git a/pkg/ds/relation.go b/pkg/ds/relation.go index 15bf33a..14c24d7 100644 --- a/pkg/ds/relation.go +++ b/pkg/ds/relation.go @@ -17,11 +17,16 @@ import ( "github.com/rs/zerolog/log" ) -// Relation. +// Relation identifier. type relation struct { *dsc3.Relation } +// Relation selector. +type relations struct { + relation +} + func Relation(i *dsc3.Relation) *relation { return &relation{i} } func GetRelation(i *dsr3.GetRelationRequest) *relation { @@ -35,15 +40,15 @@ func GetRelation(i *dsr3.GetRelationRequest) *relation { }} } -func GetRelations(i *dsr3.GetRelationsRequest) *relation { - return &relation{&dsc3.Relation{ +func GetRelations(i *dsr3.GetRelationsRequest) *relations { + return &relations{relation{&dsc3.Relation{ ObjectType: i.ObjectType, ObjectId: i.ObjectId, Relation: i.Relation, SubjectType: i.SubjectType, SubjectId: i.SubjectId, SubjectRelation: i.SubjectRelation, - }} + }}} } func (i *relation) Key() string { @@ -82,45 +87,64 @@ func (i *relation) SubKey() string { Iff(i.GetSubjectRelation() == "", "", InstanceSeparator+i.GetSubjectRelation()) } -func (i *relation) Validate(mc *cache.Cache) (bool, error) { +func (i *relation) Validate(mc *cache.Cache) error { + if i == nil || i.Relation == nil { + return ErrInvalidArgumentRelation.Msg("relation not set (nil)") + } - if i == nil { - return false, ErrInvalidArgumentRelation.Msg("relation not set (nil)") + if IsNotSet(i.GetRelation()) { + return ErrInvalidArgumentRelation.Msg("relation") } - if i.Relation == nil { - return false, ErrInvalidArgumentRelation.Msg("relation not set (nil)") + if err := ObjectIdentifier(i.Object()).Validate(mc); err != nil { + return err } - if IsNotSet(i.GetRelation()) { - return false, ErrInvalidArgumentRelation.Msg("relation") + if err := ObjectIdentifier(i.Subject()).Validate(mc); err != nil { + return err } - if ok, err := ObjectIdentifier(i.Object()).Validate(); !ok { - return ok, err + if mc == nil { + return nil } - if ok, err := ObjectIdentifier(i.Subject()).Validate(); !ok { - return ok, err + return mc.ValidateRelation(i.Relation) +} + +func (i *relations) Validate(mc *cache.Cache) error { + if i == nil || i.Relation == nil { + return ErrInvalidArgumentRelation.Msg("relation not set (nil)") } - if mc == nil { - return true, nil + if err := ObjectSelector(i.Object()).Validate(mc); err != nil { + return err } - if !mc.ObjectExists(model.ObjectName(i.GetObjectType())) { - return false, derr.ErrObjectNotFound.Msg(i.GetObjectType()) + if err := ObjectSelector(i.Subject()).Validate(mc); err != nil { + return err } - if !mc.ObjectExists(model.ObjectName(i.GetSubjectType())) { - return false, derr.ErrObjectNotFound.Msg(i.GetSubjectType()) + if IsSet(i.GetRelation()) { + if IsNotSet(i.GetObjectType()) { + return ErrInvalidArgumentRelation.Msg("object type not set") + } + + if mc != nil && !mc.RelationExists(model.ObjectName(i.GetObjectType()), model.RelationName(i.GetRelation())) { + return derr.ErrRelationNotFound.Msg(i.GetObjectType() + ":" + i.GetRelation()) + } } - if !mc.RelationExists(model.ObjectName(i.GetObjectType()), model.RelationName(i.GetRelation())) { - return false, derr.ErrRelationNotFound.Msg(i.GetObjectType() + ":" + i.GetRelation()) + if IsSet(i.GetSubjectRelation()) { + if IsNotSet(i.GetSubjectType()) { + return ErrInvalidArgumentRelation.Msg("subject type not set") + } + + if mc != nil && !mc.RelationExists(model.ObjectName(i.GetSubjectType()), model.RelationName(i.GetSubjectRelation())) { + return derr.ErrRelationNotFound.Msg(i.GetSubjectType() + ":" + i.GetSubjectRelation()) + } } - return true, nil + return nil } func (i *relation) Hash() string {