Skip to content

Commit

Permalink
fix(dispatch): Add default source resolver to Attributes
Browse files Browse the repository at this point in the history
Attributes has a field for the default source for the method. This lets us remove the WithSource calls throughout the cmd package, and makes tests more readable.
  • Loading branch information
dustmop committed Apr 16, 2021
1 parent 5bded08 commit e074bf4
Show file tree
Hide file tree
Showing 19 changed files with 211 additions and 81 deletions.
2 changes: 1 addition & 1 deletion cmd/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ dataset version(s). By default pull fetches the latest version of a dataset.
}

cmd.Flags().StringVar(&o.LinkDir, "link", "", "path to directory to link dataset to")
cmd.Flags().StringVar(&o.Source, "source", "network", "location to pull from")
cmd.Flags().StringVar(&o.Source, "source", "", "location to pull from")
cmd.MarkFlagFilename("link")
cmd.Flags().BoolVar(&o.LogsOnly, "logs-only", false, "only fetch logs, skipping HEAD data")

Expand Down
2 changes: 1 addition & 1 deletion cmd/remove.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func (o *RemoveOptions) Run() (err error) {
}

ctx := context.TODO()
res, err := o.inst.WithSource("local").Dataset().Remove(ctx, &params)
res, err := o.inst.Dataset().Remove(ctx, &params)
if err != nil {
// TODO(b5): move this error handling down into lib
if errors.Is(err, dsref.ErrRefNotFound) {
Expand Down
2 changes: 1 addition & 1 deletion cmd/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func (o *ValidateOptions) Run() (err error) {
}

ctx := context.TODO()
res, err := o.inst.WithSource("local").Dataset().Validate(ctx, p)
res, err := o.inst.Dataset().Validate(ctx, p)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion lib/access.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func (m AccessMethods) Name() string {
// Attributes defines attributes for each method
func (m AccessMethods) Attributes() map[string]AttributeSet {
return map[string]AttributeSet{
"createauthtoken": {AECreateAuthToken, "GET"},
"createauthtoken": {AECreateAuthToken, "GET", "local"},
}
}

Expand Down
2 changes: 1 addition & 1 deletion lib/automation.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func (m AutomationMethods) Name() string {
// Attributes defines attributes for each method
func (m AutomationMethods) Attributes() map[string]AttributeSet {
return map[string]AttributeSet{
"apply": {AEApply, "POST"},
"apply": {AEApply, "POST", ""},
}
}

Expand Down
4 changes: 2 additions & 2 deletions lib/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ func (m CollectionMethods) Name() string {
// Attributes defines attributes for each method
func (m CollectionMethods) Attributes() map[string]AttributeSet {
return map[string]AttributeSet{
"list": {AEList, "POST"},
"listrawrefs": {denyRPC, ""},
"list": {AEList, "POST", ""},
"listrawrefs": {denyRPC, "", ""},
}
}

Expand Down
6 changes: 3 additions & 3 deletions lib/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ func (m ConfigMethods) Name() string {
func (m ConfigMethods) Attributes() map[string]AttributeSet {
return map[string]AttributeSet{
// config methods are not allowed over HTTP nor RPC
"getconfig": {denyRPC, ""},
"getconfigkeys": {denyRPC, ""},
"setconfig": {denyRPC, ""},
"getconfig": {denyRPC, "", ""},
"getconfigkeys": {denyRPC, "", ""},
"setconfig": {denyRPC, "", ""},
}
}

Expand Down
25 changes: 11 additions & 14 deletions lib/datasets.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,20 @@ func (m DatasetMethods) Name() string {
// Attributes defines attributes for each method
func (m DatasetMethods) Attributes() map[string]AttributeSet {
return map[string]AttributeSet{
"componentstatus": {AEComponentStatus, "POST"},
"get": {AEGet, "GET"},
"componentstatus": {AEComponentStatus, "POST", ""},
"get": {AEGet, "GET", ""},
// "log": {AELog, "POST"},
"rename": {AERename, "POST"},
"save": {AESave, "POST"},
"pull": {AEPull, "POST"},
"rename": {AERename, "POST", ""},
"save": {AESave, "POST", ""},
"pull": {AEPull, "POST", "network"},
// "push": {AEPush, "POST"},
"render": {AERender, "POST"},
"remove": {AERemove, "POST"},
"validate": {AEValidate, "POST"},
"render": {AERender, "POST", ""},
"remove": {AERemove, "POST", "local"},
"validate": {AEValidate, "POST", ""},
// "unpack": {AEUnpack, "POST"},
"manifest": {AEManifest, "POST"},
"manifestmissing": {AEManifestMissing, "POST"},
"daginfo": {AEDAGInfo, "POST"},
"manifest": {AEManifest, "POST", ""},
"manifestmissing": {AEManifestMissing, "POST", ""},
"daginfo": {AEDAGInfo, "POST", ""},
}
}

Expand Down Expand Up @@ -1311,9 +1311,6 @@ func (datasetImpl) Validate(scope scope, p *ValidateParams) (*ValidateResponse,
if p.Ref == "" && (p.BodyFilename == "" || schemaFlagType == "") {
return nil, qrierr.New(ErrBadArgs, "please provide a dataset name, or a supply the --body and --schema or --structure flags")
}
if scope.SourceName() != "local" {
return nil, fmt.Errorf("validate requires 'local' source")
}

fsiPath := ""
var err error
Expand Down
4 changes: 2 additions & 2 deletions lib/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ func (m DiffMethods) Name() string {
// Attributes defines attributes for each method
func (m DiffMethods) Attributes() map[string]AttributeSet {
return map[string]AttributeSet{
"changes": {AEChanges, "POST"},
"diff": {AEDiff, "POST"},
"changes": {AEChanges, "POST", ""},
"diff": {AEDiff, "POST", ""},
}
}

Expand Down
55 changes: 38 additions & 17 deletions lib/dispatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ type MethodSet interface {
// Each method is required to have associated attributes in order to successfully register
type AttributeSet struct {
endpoint APIEndpoint
verb string
httpVerb string
// the default source used for resolving references
defaultSource string
}

// Dispatch is a system for handling calls to lib. Should only be called by top-level lib methods.
Expand Down Expand Up @@ -144,6 +146,11 @@ func (inst *Instance) dispatchMethodCall(ctx context.Context, method string, par

// Look up the method for the given signifier
if c, ok := inst.regMethods.lookup(method); ok {
// If this method has a default source and no override exists, use that
// default instead
if source == "" {
source = c.Source
}
// Construct the isolated scope for this call
// TODO(dustmop): Add user authentication, profile, identity, etc
// TODO(dustmop): Also determine if the method is read-only vs read-write,
Expand Down Expand Up @@ -237,12 +244,14 @@ type callable struct {
RetCursor bool
Endpoint APIEndpoint
Verb string
Source string
}

// RegisterMethods iterates the methods provided by the lib API, and makes them visible to dispatch
func (inst *Instance) RegisterMethods() {
reg := make(map[string]callable)
inst.registerOne("access", inst.Access(), accessImpl{}, reg)
inst.registerOne("automation", inst.Automation(), automationImpl{}, reg)
inst.registerOne("collection", inst.Collection(), collectionImpl{}, reg)
inst.registerOne("config", inst.Config(), configImpl{}, reg)
inst.registerOne("dataset", inst.Dataset(), datasetImpl{}, reg)
Expand All @@ -255,7 +264,6 @@ func (inst *Instance) RegisterMethods() {
inst.registerOne("remote", inst.Remote(), remoteImpl{}, reg)
inst.registerOne("search", inst.Search(), searchImpl{}, reg)
inst.registerOne("sql", inst.SQL(), sqlImpl{}, reg)
inst.registerOne("automation", inst.Automation(), automationImpl{}, reg)
inst.regMethods = &regMethodSet{reg: reg}
}

Expand Down Expand Up @@ -384,25 +392,13 @@ func (inst *Instance) registerOne(ourName string, methods MethodSet, impl interf
// Remove this method from the methodSetMap now that it has been processed
delete(methodMap, i.Name)

var endpoint APIEndpoint
var httpVerb string
// Additional attributes for the method are found in the Attributes
amap := methods.Attributes()
methodAttrs, ok := amap[lowerName]
if !ok {
regFail("not in Attributes: %s.%s", ourName, lowerName)
}
endpoint = methodAttrs.endpoint
httpVerb = methodAttrs.verb
// If both these are empty string, RPC is not allowed for this method
if endpoint != "" || httpVerb != "" {
if !strings.HasPrefix(string(endpoint), "/") {
regFail("%s: endpoint URL must start with /, got %q", lowerName, endpoint)
}
if httpVerb != http.MethodGet && httpVerb != http.MethodPost && httpVerb != http.MethodPut {
regFail("%s: unknown http verb, got %q", lowerName, httpVerb)
}
}
validateMethodAttrs(lowerName, methodAttrs)

// Save the method to the registration table
reg[funcName] = callable{
Expand All @@ -411,8 +407,9 @@ func (inst *Instance) registerOne(ourName string, methods MethodSet, impl interf
InType: inType,
OutType: outType,
RetCursor: returnsCursor,
Endpoint: endpoint,
Verb: httpVerb,
Endpoint: methodAttrs.endpoint,
Verb: methodAttrs.httpVerb,
Source: methodAttrs.defaultSource,
}
log.Debugf("%d: registered %s(*%s) %v", k, funcName, inType, outType)
}
Expand All @@ -428,6 +425,30 @@ func regFail(fstr string, vals ...interface{}) {
panic(fmt.Sprintf(fstr, vals...))
}

func validateMethodAttrs(methodName string, attrs AttributeSet) {
// If endpoint and verb are not set, then RPC is denied, nothing to validate
// TODO(dustmop): Technically this is denying all HTTP, not just RPC. Consider
// separating HTTP and RPC denial
if attrs.endpoint == "" && attrs.httpVerb == "" {
return
}
if !strings.HasPrefix(string(attrs.endpoint), "/") {
regFail("%s: endpoint URL must start with /, got %q", methodName, attrs.endpoint)
}
if !stringOneOf(attrs.httpVerb, []string{http.MethodGet, http.MethodPost, http.MethodPut}) {
regFail("%s: unknown http verb, got %q", methodName, attrs.httpVerb)
}
}

func stringOneOf(needle string, haystack []string) bool {
for _, each := range haystack {
if needle == each {
return true
}
}
return false
}

func (inst *Instance) buildMethodMap(impl interface{}) map[string]reflect.Method {
result := make(map[string]reflect.Method)
implType := reflect.TypeOf(impl)
Expand Down
126 changes: 119 additions & 7 deletions lib/dispatch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,61 @@ func TestVariadicReturnsWorkOverHTTP(t *testing.T) {
}
}

func TestDefaultSource(t *testing.T) {
ctx := context.Background()

inst, cleanup := NewMemTestInstance(ctx, t)
defer cleanup()
m := &getSrcMethods{d: inst}

reg := make(map[string]callable)
inst.registerOne("getsrc", m, getSrcImpl{}, reg)
inst.regMethods = &regMethodSet{reg: reg}

// Construct another methodSet with a source override
withSource := getSrcMethods{d: &dispatchSourceWrap{source: "registry", inst: inst}}

// Call One with the default source
got, err := m.One(ctx, &getSrcParams{})
if err != nil {
t.Fatalf("m.One call failed, err=%s", err)
}
expect := `one source=""`
if got != expect {
t.Errorf("value mismatch, expect: %s, got: %s", expect, got)
}

// Call One with a source override
got, err = withSource.One(ctx, &getSrcParams{})
if err != nil {
t.Fatalf("m.One call failed, err=%s", err)
}
expect = `one source="registry"`
if got != expect {
t.Errorf("value mismatch, expect: %s, got: %s", expect, got)
}

// Call Two with the default source
got, err = m.Two(ctx, &getSrcParams{})
if err != nil {
t.Fatalf("m.Two call failed, err=%s", err)
}
expect = `two source="network"`
if got != expect {
t.Errorf("value mismatch, expect: %s, got: %s", expect, got)
}

// Call Two with a source override
got, err = withSource.Two(ctx, &getSrcParams{})
if err != nil {
t.Fatalf("m.Two call failed, err=%s", err)
}
expect = `two source="registry"`
if got != expect {
t.Errorf("value mismatch, expect: %s, got: %s", expect, got)
}
}

func serverConnectAndListen(t *testing.T, servInst *Instance, port int) (*HTTPClient, func()) {
address := fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", port)
connection, err := NewHTTPClient(address)
Expand Down Expand Up @@ -232,6 +287,16 @@ func expectToPanic(t *testing.T, regFunc func(), expectMessage string) {
}
}

// A dispatcher that wraps the instance and sets a different source
type dispatchSourceWrap struct {
source string
inst *Instance
}

func (dsw *dispatchSourceWrap) Dispatch(ctx context.Context, method string, param interface{}) (res interface{}, cur Cursor, err error) {
return dsw.inst.dispatchMethodCall(ctx, method, param, dsw.source)
}

// Test data: methodSet and implementation
type animalMethods struct {
d dispatcher
Expand All @@ -243,8 +308,8 @@ func (m *animalMethods) Name() string {

func (m *animalMethods) Attributes() map[string]AttributeSet {
return map[string]AttributeSet{
"cat": {denyRPC, ""},
"dog": {denyRPC, ""},
"cat": {denyRPC, "", ""},
"dog": {denyRPC, "", ""},
}
}

Expand Down Expand Up @@ -341,12 +406,12 @@ func (m *fruitMethods) Name() string {

func (m *fruitMethods) Attributes() map[string]AttributeSet {
return map[string]AttributeSet{
"apple": {"/apple", "GET"},
"banana": {"/banana", "GET"},
"cherry": {"/cherry", "GET"},
"date": {"/date", "GET"},
"apple": {"/apple", "GET", ""},
"banana": {"/banana", "GET", ""},
"cherry": {"/cherry", "GET", ""},
"date": {"/date", "GET", ""},
// entawak cannot be called over RPC
"entawak": {denyRPC, ""},
"entawak": {denyRPC, "", ""},
}
}

Expand Down Expand Up @@ -412,3 +477,50 @@ func (fruitImpl) Date(scp scope, p *fruitParams) (string, Cursor, error) {
func (fruitImpl) Entawak(scp scope, p *fruitParams) (string, Cursor, error) {
return "mentawa", nil, nil
}

// MethodSet for methods that return the source being used for resolution
type getSrcMethods struct {
d dispatcher
}

func (m *getSrcMethods) Name() string {
return "getsrc"
}

func (m *getSrcMethods) Attributes() map[string]AttributeSet {
return map[string]AttributeSet{
"one": {"/one", "GET", ""},
"two": {"/two", "GET", "network"},
}
}

type getSrcParams struct {
Name string
}

func (m *getSrcMethods) One(ctx context.Context, p *getSrcParams) (string, error) {
got, _, err := m.d.Dispatch(ctx, dispatchMethodName(m, "one"), p)
if res, ok := got.(string); ok {
return res, err
}
return "", dispatchReturnError(got, err)
}

func (m *getSrcMethods) Two(ctx context.Context, p *getSrcParams) (string, error) {
got, _, err := m.d.Dispatch(ctx, dispatchMethodName(m, "two"), p)
if res, ok := got.(string); ok {
return res, err
}
return "", dispatchReturnError(got, err)
}

// Implementation for get source methods
type getSrcImpl struct{}

func (getSrcImpl) One(scp scope, p *getSrcParams) (string, error) {
return fmt.Sprintf("one source=%q", scp.SourceName()), nil
}

func (getSrcImpl) Two(scp scope, p *getSrcParams) (string, error) {
return fmt.Sprintf("two source=%q", scp.SourceName()), nil
}
Loading

0 comments on commit e074bf4

Please sign in to comment.