Skip to content

Commit

Permalink
feat(logic): implement virtual FS white/black list
Browse files Browse the repository at this point in the history
  • Loading branch information
ccamel committed Apr 13, 2023
1 parent d35673d commit 81c5c3e
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 126 deletions.
28 changes: 13 additions & 15 deletions x/logic/keeper/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
sdkmath "cosmossdk.io/math"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/ichiban/prolog"
"github.com/okp4/okp4d/x/logic/fs"
"github.com/okp4/okp4d/x/logic/interpreter"
"github.com/okp4/okp4d/x/logic/interpreter/bootstrap"
"github.com/okp4/okp4d/x/logic/meter"
Expand Down Expand Up @@ -113,16 +114,15 @@ func (k Keeper) newInterpreter(ctx goctx.Context) (*prolog.Interpreter, *util.Bo
interpreterParams := params.GetInterpreter()
gasPolicy := params.GetGasPolicy()
limits := params.GetLimits()

whitelist := util.NonZeroOrDefault(interpreterParams.PredicatesFilter.Whitelist, interpreter.RegistryNames)
blacklist := interpreterParams.PredicatesFilter.Blacklist
gasMeter := meter.WithWeightedMeter(sdkctx.GasMeter(), nonNilNorZeroOrDefaultUint64(gasPolicy.WeightingFactor, defaultWeightFactor))

whitelistPredicates := util.NonZeroOrDefault(interpreterParams.PredicatesFilter.Whitelist, interpreter.RegistryNames)
blacklistPredicates := interpreterParams.PredicatesFilter.Blacklist
predicates := lo.Reduce(
lo.Map(
lo.Filter(
interpreter.RegistryNames,
filterPredicates(whitelist, blacklist)),
util.Indexed(util.WhitelistBlacklistMatches(whitelistPredicates, blacklistPredicates, util.PredicateMatches))),
toPredicate(
nonNilNorZeroOrDefaultUint64(gasPolicy.DefaultPredicateCost, defaultPredicateCost),
gasPolicy.GetPredicateCosts())),
Expand All @@ -132,10 +132,17 @@ func (k Keeper) newInterpreter(ctx goctx.Context) (*prolog.Interpreter, *util.Bo
},
interpreter.Predicates{})

whitelistUrls := lo.Map(
util.NonZeroOrDefault(interpreterParams.VirtualFilesFilter.Whitelist, []string{}),
util.Indexed(util.ParseUrlMust))
blacklistUrls := lo.Map(
util.NonZeroOrDefault(interpreterParams.VirtualFilesFilter.Whitelist, []string{}),
util.Indexed(util.ParseUrlMust))

options := []interpreter.Option{
interpreter.WithPredicates(ctx, predicates, gasMeter),
interpreter.WithBootstrap(ctx, util.NonZeroOrDefault(interpreterParams.GetBootstrap(), bootstrap.Bootstrap())),
interpreter.WithFS(k.fsProvider(ctx)),
interpreter.WithFS(fs.NewFilteredFS(whitelistUrls, blacklistUrls, k.fsProvider(ctx))),
}

var userOutputBuffer *util.BoundedBuffer
Expand All @@ -149,21 +156,12 @@ func (k Keeper) newInterpreter(ctx goctx.Context) (*prolog.Interpreter, *util.Bo
return i, userOutputBuffer, err
}

// filterPredicates filters the given predicate (with arity) according to the given whitelist and blacklist.
// The whitelist and blacklist are applied to the registry to determine the final predicate list.
// The whitelist and blacklist can contain predicates with or without arity, e.g. "foo/0", "foo", "bar/1".
func filterPredicates(whitelist []string, blacklist []string) func(string, int) bool {
return func(predicate string, _ int) bool {
return lo.ContainsBy(whitelist, util.PredicateEq(predicate)) && !lo.ContainsBy(blacklist, util.PredicateEq(predicate))
}
}

// toPredicate converts the given predicate costs to a function that returns the cost for the given predicate as
// a pair of predicate name and cost.
func toPredicate(defaultCost uint64, predicateCosts []types.PredicateCost) func(string, int) lo.Tuple2[string, uint64] {
return func(predicate string, _ int) lo.Tuple2[string, uint64] {
for _, c := range predicateCosts {
if util.PredicateEq(predicate)(c.Predicate) {
if util.PredicateMatches(predicate)(c.Predicate) {
return lo.T2(predicate, nonNilNorZeroOrDefaultUint64(c.Cost, defaultCost))
}
}
Expand Down
95 changes: 0 additions & 95 deletions x/logic/keeper/interpreter_test.go

This file was deleted.

15 changes: 13 additions & 2 deletions x/logic/types/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package types

import (
"fmt"
"net/url"

"cosmossdk.io/math"
)
Expand Down Expand Up @@ -96,12 +97,22 @@ func WithBootstrap(bootstrap string) InterpreterOption {
}

func validateInterpreter(i interface{}) error {
_, ok := i.(Interpreter)
interpreter, ok := i.(Interpreter)
if !ok {
return fmt.Errorf("invalid parameter type: %T", i)
}

// TODO: Validate interpreter params.
for _, file := range interpreter.VirtualFilesFilter.Whitelist {
if _, err := url.Parse(file); err != nil {
return fmt.Errorf("invalid virtual file in whitelist: %s", file)
}
}
for _, file := range interpreter.VirtualFilesFilter.Blacklist {
if _, err := url.Parse(file); err != nil {
return fmt.Errorf("invalid virtual file in blacklist: %s", file)
}
}

return nil
}

Expand Down
4 changes: 2 additions & 2 deletions x/logic/util/pointer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ package util

import "reflect"

// DerefOrDefault returns the value of the pointer if it is not nil, otherwise returns the default value.
// DerefOrDefault returns the values of the pointer if it is not nil, otherwise returns the default values.
func DerefOrDefault[T any](ptr *T, defaultValue T) T {
if ptr != nil {
return *ptr
}
return defaultValue
}

// NonZeroOrDefault returns the value of the argument if it is not nil and not zero, otherwise returns the default value.
// NonZeroOrDefault returns the values of the argument if it is not nil and not zero, otherwise returns the default values.
func NonZeroOrDefault[T any](v, defaultValue T) T {
v1 := reflect.ValueOf(v)
if v1.IsValid() && !v1.IsZero() {
Expand Down
12 changes: 6 additions & 6 deletions x/logic/util/pointer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,31 @@ import (
)

func TestDerefOrDefault(t *testing.T) {
Convey("Given a pointer to an int and a default int value", t, func() {
Convey("Given a pointer to an int and a default int values", t, func() {
x := 5
ptr := &x
defaultValue := 10

Convey("When the pointer is not nil", func() {
result := DerefOrDefault(ptr, defaultValue)

Convey("The result should be the value pointed to by the pointer", func() {
Convey("The result should be the values pointed to by the pointer", func() {
So(result, ShouldEqual, x)
})
})

Convey("When the pointer is nil", func() {
result := DerefOrDefault(nil, defaultValue)

Convey("The result should be the default value", func() {
Convey("The result should be the default values", func() {
So(result, ShouldEqual, defaultValue)
})
})
})
}

func TestNonZeroOrDefault(t *testing.T) {
Convey("Given a value", t, func() {
Convey("Given a values", t, func() {
cases := []struct {
v any
defaultValue any
Expand All @@ -45,8 +45,8 @@ func TestNonZeroOrDefault(t *testing.T) {
{"hello", "default", "hello"},
}
for _, tc := range cases {
Convey(fmt.Sprintf("When the value is %v", tc.v), func() {
Convey(fmt.Sprintf("Then the default value %v is returned", tc.defaultValue), func() {
Convey(fmt.Sprintf("When the values is %v", tc.v), func() {
Convey(fmt.Sprintf("Then the default values %v is returned", tc.defaultValue), func() {
So(NonZeroOrDefault(tc.v, tc.defaultValue), ShouldEqual, tc.expected)
})
})
Expand Down
14 changes: 8 additions & 6 deletions x/logic/util/prolog.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,21 @@ func Resolve(env *engine.Env, t engine.Term) (engine.Atom, bool) {
}
}

// PredicateEq returns a function that matches the given predicate against the given other predicate.
// PredicateMatches returns a function that matches the given predicate against the given other predicate.
// If the other predicate contains a slash, it is matched as is. Otherwise, the other predicate is matched against the
// first part of the given predicate.
// For example:
// - matchPredicate("foo/0")("foo/0") -> true
// - matchPredicate("foo/0")("foo/1") -> false
// - matchPredicate("foo/0")("foo") -> true
// - matchPredicate("foo/0")("bar") -> false
func PredicateEq(predicate string) func(b string) bool {
return func(other string) bool {
if strings.Contains(other, "/") {
return predicate == other
//
// The function is curried, and is a binary relation that is reflexive, associative (but not commutative).
func PredicateMatches(this string) func(string) bool {
return func(that string) bool {
if strings.Contains(that, "/") {
return this == that
}
return strings.Split(predicate, "/")[0] == other
return strings.Split(this, "/")[0] == that
}
}

0 comments on commit 81c5c3e

Please sign in to comment.