Skip to content

Commit

Permalink
Equal & Len assert better messages
Browse files Browse the repository at this point in the history
  • Loading branch information
lainio committed Sep 1, 2024
1 parent 4fc7ffd commit 4a427d4
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 15 deletions.
38 changes: 25 additions & 13 deletions assert/assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,11 @@ var (
)

const (
assertionMsg = "assertion violation"
assertionMsg = "assertion violation"
assertionEqualMsg = "assert equal"
assertionNotEqualMsg = "assert not equal"
assertionLenMsg = "assert len"

gotWantFmt = ": got '%v', want '%v'"
gotWantLongerFmt = ": got '%v', should be longer than '%v'"
gotWantShorterFmt = ": got '%v', should be shorter than '%v'"
Expand Down Expand Up @@ -438,7 +442,7 @@ func MNotNil[M ~map[T]U, T comparable, U any](m M, a ...any) {
// assert violation message.
func NotEqual[T comparable](val, want T, a ...any) {
if want == val {
doShouldNotBeEqual(val, want, a)
doShouldNotBeEqual(assertionNotEqualMsg, val, want, a)
}
}

Expand All @@ -450,17 +454,17 @@ func NotEqual[T comparable](val, want T, a ...any) {
// are used to override the auto-generated assert violation message.
func Equal[T comparable](val, want T, a ...any) {
if want != val {
doShouldBeEqual(val, want, a)
doShouldBeEqual(assertionEqualMsg, val, want, a)
}
}

func doShouldBeEqual[T comparable](val, want T, a []any) {
defMsg := fmt.Sprintf(assertionMsg+gotWantFmt, val, want)
func doShouldBeEqual[T comparable](aname string, val, want T, a []any) {
defMsg := fmt.Sprintf(aname+gotWantFmt, val, want)
current().reportAssertionFault(1, defMsg, a)
}

func doShouldNotBeEqual[T comparable](val, want T, a []any) {
defMsg := fmt.Sprintf(assertionMsg+": got '%v' want (!= '%v')", val, want)
func doShouldNotBeEqual[T comparable](aname string, val, want T, a []any) {
defMsg := fmt.Sprintf(aname+": got '%v' want (!= '%v')", val, want)
current().reportAssertionFault(1, defMsg, a)
}

Expand Down Expand Up @@ -492,7 +496,11 @@ func DeepEqual(val, want any, a ...any) {
// assert.DeepEqual(pubKey, ed25519.PublicKey(pubKeyBytes))
func NotDeepEqual(val, want any, a ...any) {
if reflect.DeepEqual(val, want) {
defMsg := fmt.Sprintf(assertionMsg+": got '%v', want (!= '%v')", val, want)
defMsg := fmt.Sprintf(
assertionMsg+": got '%v', want (!= '%v')",
val,
want,
)
current().reportAssertionFault(0, defMsg, a)
}
}
Expand All @@ -511,7 +519,7 @@ func Len(obj string, length int, a ...any) {
l := len(obj)

if l != length {
doShouldBeEqual(l, length, a)
doShouldBeEqual(assertionLenMsg, l, length, a)
}
}

Expand Down Expand Up @@ -575,7 +583,7 @@ func SLen[S ~[]T, T any](obj S, length int, a ...any) {
l := len(obj)

if l != length {
doShouldBeEqual(l, length, a)
doShouldBeEqual(assertionLenMsg, l, length, a)
}
}

Expand Down Expand Up @@ -629,7 +637,7 @@ func MLen[M ~map[T]U, T comparable, U any](obj M, length int, a ...any) {
l := len(obj)

if l != length {
doShouldBeEqual(l, length, a)
doShouldBeEqual(assertionLenMsg, l, length, a)
}
}

Expand Down Expand Up @@ -683,7 +691,7 @@ func CLen[C ~chan T, T any](obj C, length int, a ...any) {
l := len(obj)

if l != length {
doShouldBeEqual(l, length, a)
doShouldBeEqual(assertionLenMsg, l, length, a)
}
}

Expand Down Expand Up @@ -728,7 +736,11 @@ func CShorter[C ~chan T, T any](obj C, length int, a ...any) {
//
// Note that when [Plain] asserter is used ([SetDefault]), optional arguments
// are used to override the auto-generated assert violation message.
func MKeyExists[M ~map[T]U, T comparable, U any](obj M, key T, a ...any) (val U) {
func MKeyExists[M ~map[T]U, T comparable, U any](
obj M,
key T,
a ...any,
) (val U) {
var ok bool
val, ok = obj[key]

Expand Down
4 changes: 2 additions & 2 deletions assert/assert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func ExampleEqual() {
}
err := sample([]byte{1, 2})
fmt.Printf("%v", err)
// Output: sample: assert_test.go:80: ExampleEqual.func1(): assertion violation: got '2', want '3'
// Output: sample: assert_test.go:80: ExampleEqual.func1(): assert equal: got '2', want '3'
}

func ExampleSLen() {
Expand All @@ -94,7 +94,7 @@ func ExampleSLen() {
}
err := sample([]byte{1, 2})
fmt.Printf("%v", err)
// Output: sample: assert_test.go:92: ExampleSLen.func1(): assertion violation: got '2', want '3'
// Output: sample: assert_test.go:92: ExampleSLen.func1(): assert len: got '2', want '3'
}

func ExampleSNotEmpty() {
Expand Down

0 comments on commit 4a427d4

Please sign in to comment.