Skip to content

Commit

Permalink
Merge pull request #17 from Mutated1994/master
Browse files Browse the repository at this point in the history
fix: Fix panic when calling PatchGuard.Restore
  • Loading branch information
cch123 authored Aug 1, 2021
2 parents d792ef7 + d203a03 commit 16a9bd9
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 11 deletions.
8 changes: 8 additions & 0 deletions examples/patch_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ func patchFunc() {
fmt.Println("unpatch, then output:")
heyHey()

patchGuard.Restore()
fmt.Println("restore, then output:")
heyHey()

patchGuard.Unpatch()
fmt.Println("unpatch, then output:")
heyHey()

fmt.Println()
}

Expand Down
8 changes: 8 additions & 0 deletions examples/patch_func_symbol.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ func patchFuncSymbol() {
fmt.Println("unpatch, then output:")
heyHeyHey()

patchGuard.Restore()
fmt.Println("restore, then output:")
heyHeyHey()

patchGuard.Unpatch()
fmt.Println("unpatch, then output:")
heyHeyHey()

fmt.Println()
}

Expand Down
8 changes: 8 additions & 0 deletions examples/patch_instance_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,13 @@ func patchInstanceFunc() {
fmt.Println("unpatch, then output:")
p.speak()

patchGuard.Restore()
fmt.Println("restore, then output:")
p.speak()

patchGuard.Unpatch()
fmt.Println("unpatch, then output:")
p.speak()

fmt.Println()
}
8 changes: 8 additions & 0 deletions examples/patch_instance_func_symbol.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,13 @@ func patchInstanceFuncSymbol() {
fmt.Println("unpatch, then output:")
p.speak()

patchGuard.Restore()
fmt.Println("restore, then output:")
p.speak()

patchGuard.Unpatch()
fmt.Println("unpatch, then output:")
p.speak()

fmt.Println()
}
10 changes: 9 additions & 1 deletion examples/patch_struct_method.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,13 @@ func patchStructMethod() {
fmt.Println("unpatch, then output:")
fmt.Println(f.MyFunc(nil))

patchGuard.Restore()
fmt.Println("restore, then output:")
fmt.Println(f.MyFunc(nil))

patchGuard.Unpatch()
fmt.Println("unpatch, then output:")
fmt.Println(f.MyFunc(nil))

fmt.Println()
}
}
9 changes: 8 additions & 1 deletion examples/patch_struct_method_symbol.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ func patchStructMethodSymbol() {
fmt.Println("unpatch, then output:")
fmt.Println(f.MyFunc(nil))

patchGuard.Restore()
fmt.Println("restore, then output:")
fmt.Println(f.MyFunc(nil))

patchGuard.Unpatch()
fmt.Println("unpatch, then output:")
fmt.Println(f.MyFunc(nil))

fmt.Println()
}

19 changes: 10 additions & 9 deletions internal/bouk/monkey.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,31 +32,32 @@ func getPtr(v reflect.Value) unsafe.Pointer {
type PatchGuard struct {
target reflect.Value
replacement reflect.Value
isSymbol bool
}

func (g *PatchGuard) Unpatch() {
unpatchValue(g.target)
}

func (g *PatchGuard) Restore() {
patchValue(g.target, g.replacement)
patchValue(g.target, g.replacement, g.isSymbol)
}

// Patch replaces a function with another
func Patch(target, replacement interface{}) *PatchGuard {
t := reflect.ValueOf(target)
r := reflect.ValueOf(replacement)
patchValue(t, r)
patchValue(t, r, false)

return &PatchGuard{t, r}
return &PatchGuard{t, r, false}
}

func PatchSymbol(target, replacement interface{}) *PatchGuard {
t := reflect.ValueOf(target)
r := reflect.ValueOf(replacement)
patchSymbolValue(t, r)

return &PatchGuard{t, r}
return &PatchGuard{t, r, true}
}

// PatchInstanceMethod replaces an instance method methodName for the type target with replacement
Expand All @@ -67,24 +68,24 @@ func PatchInstanceMethod(target reflect.Type, methodName string, replacement int
panic(fmt.Sprintf("unknown method %s", methodName))
}
r := reflect.ValueOf(replacement)
patchValue(m.Func, r)
patchValue(m.Func, r, false)

return &PatchGuard{m.Func, r}
return &PatchGuard{m.Func, r, false}
}

func patchValue(target, replacement reflect.Value) {
func patchValue(target, replacement reflect.Value, isSymbol bool) {
lock.Lock()
defer lock.Unlock()

if target.Kind() != reflect.Func {
if !isSymbol && target.Kind() != reflect.Func {
panic("target has to be a Func")
}

if replacement.Kind() != reflect.Func {
panic("replacement has to be a Func")
}

if target.Type() != replacement.Type() {
if !isSymbol && target.Type() != replacement.Type() {
panic(fmt.Sprintf("target and replacement have to have the same type %s != %s", target.Type(), replacement.Type()))
}

Expand Down

0 comments on commit 16a9bd9

Please sign in to comment.