Skip to content

Commit

Permalink
Fix nil pointer for Query
Browse files Browse the repository at this point in the history
  • Loading branch information
otiai10 committed Jun 24, 2023
1 parent 823aebc commit 8d16417
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 6 deletions.
2 changes: 2 additions & 0 deletions all_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,4 +168,6 @@ func TestTestee_Query(t *testing.T) {
mint.Expect(t, "foo").Query("foo").ToBe("foo")
mint.Expect(t, "foo").Query("bar").Not().ToBe("bar")
mint.Expect(t, v).Query("foo.name").ToBe("otiai10")
mint.Expect(t, v).Query("foo.age").ToBe(30)
mint.Expect(t, v).Query("foo.baa").ToBe(nil)
}
18 changes: 14 additions & 4 deletions mquery/mquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,27 @@ func queryMap(m any, t reflect.Type, qs []string) any {
if len(qs) == 0 {
return m
}
val := reflect.ValueOf(m)
if val.IsZero() {
return nil
}
switch t.Key().Kind() {
case reflect.String:
next := reflect.ValueOf(m).MapIndex(reflect.ValueOf(qs[0])).Interface()
return query(next, qs[1:])
val := reflect.ValueOf(m).MapIndex(reflect.ValueOf(qs[0]))
if !val.IsValid() {
return nil
}
return query(val.Interface(), qs[1:])
case reflect.Int:
i, err := strconv.Atoi(qs[0])
if err != nil {
return fmt.Errorf("cannot access map with keyword: %s: %v", qs[0], err)
}
next := reflect.ValueOf(m).MapIndex(reflect.ValueOf(i)).Interface()
return query(next, qs[1:])
val := reflect.ValueOf(m).MapIndex(reflect.ValueOf(i))
if !val.IsValid() {
return nil
}
return query(val.Interface(), qs[1:])
}
return nil
}
Expand Down
7 changes: 5 additions & 2 deletions testee.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ type Testee struct {
required bool
verbose bool

queriedFrom string // Only used when querying
// origin string // Only used when querying
}

// Query queries the actual value with given query string.
func (testee *Testee) Query(query string) *Testee {
testee.queriedFrom = fmt.Sprintf("queried from %T", testee.actual)
// testee.origin = fmt.Sprintf("%T", testee.actual)
testee.actual = mquery.Query(testee.actual, query)
return testee
}
Expand Down Expand Up @@ -122,6 +122,9 @@ func (testee *Testee) toText(fail int) string {
not = "NOT "
}
_, file, line, _ := runtime.Caller(3)
// if testee.origin != "" {
// testee.origin = fmt.Sprintf("(queried from %s)", testee.origin)
// }
return fmt.Sprintf(
scolds[fail],
filepath.Base(file), line,
Expand Down

0 comments on commit 8d16417

Please sign in to comment.