Skip to content

Commit

Permalink
refactor: reuse the lua vm and pr-compile the lua code (#753)
Browse files Browse the repository at this point in the history
* refactor: reuse the lua vm and pr-compile the lua code

* fix the benchmark code

* update the code
  • Loading branch information
Zheaoli authored Jan 2, 2024
1 parent 0f3a45f commit 10068d3
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 44 deletions.
10 changes: 5 additions & 5 deletions cmd/redis-shake/main.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package main

import (
"context"
_ "net/http/pprof"
"os"
"os/signal"
"syscall"
"context"
_ "net/http/pprof"

"RedisShake/internal/config"
"RedisShake/internal/entry"
Expand All @@ -26,7 +26,7 @@ func main() {
utils.ChdirAndAcquireFileLock()
utils.SetNcpu()
utils.SetPprofPort()
function.Init()
luaRuntime := function.New(config.Opt.Function)

// create reader
var theReader reader.Reader
Expand Down Expand Up @@ -125,7 +125,7 @@ func main() {

// filter
log.Debugf("function before: %v", e)
entries := function.RunFunction(e)
entries := luaRuntime.RunFunction(e)
log.Debugf("function after: %v", entries)

for _, entry := range entries {
Expand All @@ -146,4 +146,4 @@ func waitShutdown(cancel context.CancelFunc) {
sig := <-quitCh
log.Infof("Got signal: %s to exit.", sig)
cancel()
}
}
81 changes: 49 additions & 32 deletions internal/function/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,40 @@ package function

import (
"strings"
"sync"

"RedisShake/internal/config"
"RedisShake/internal/entry"
"RedisShake/internal/log"

lua "github.com/yuin/gopher-lua"
"github.com/yuin/gopher-lua/parse"
)

var luaString string
type Runtime struct {
luaVMPool *sync.Pool
compiledFunction *lua.FunctionProto
}

func Init() {
luaString = config.Opt.Function
luaString = strings.TrimSpace(luaString)
if len(luaString) == 0 {
log.Infof("no function script")
return
func New(luaCode string) *Runtime {
if len(luaCode) == 0 {
return nil
}
luaCode = strings.TrimSpace(luaCode)
chunk, err := parse.Parse(strings.NewReader(luaCode), "<string>")
if err != nil {
log.Panicf("parse lua code failed: %v", err)
}
codeObject, err := lua.Compile(chunk, "<string>")
if err != nil {
log.Panicf("compile lua code failed: %v", err)
}
return &Runtime{
luaVMPool: &sync.Pool{
New: func() interface{} {
return lua.NewState()
},
},
compiledFunction: codeObject,
}
}

Expand All @@ -32,41 +50,40 @@ func Init() {
// shake.call(DB, ARGV)
// shake.log()

func RunFunction(e *entry.Entry) []*entry.Entry {
entries := make([]*entry.Entry, 0)
if len(luaString) == 0 {
entries = append(entries, e)
return entries
func (runtime *Runtime) RunFunction(e *entry.Entry) []*entry.Entry {
if runtime == nil {
return []*entry.Entry{e}
}

L := lua.NewState()
L.SetGlobal("DB", lua.LNumber(e.DbId))
L.SetGlobal("GROUP", lua.LString(e.Group))
L.SetGlobal("CMD", lua.LString(e.CmdName))
keys := L.NewTable()
entries := make([]*entry.Entry, 0)
luaState := runtime.luaVMPool.Get().(*lua.LState)
defer runtime.luaVMPool.Put(luaState)
luaState.SetGlobal("DB", lua.LNumber(e.DbId))
luaState.SetGlobal("GROUP", lua.LString(e.Group))
luaState.SetGlobal("CMD", lua.LString(e.CmdName))
keys := luaState.NewTable()
for _, key := range e.Keys {
keys.Append(lua.LString(key))
}
L.SetGlobal("KEYS", keys)
slots := L.NewTable()
luaState.SetGlobal("KEYS", keys)
slots := luaState.NewTable()
for _, slot := range e.Slots {
slots.Append(lua.LNumber(slot))
}
keyIndexes := L.NewTable()
keyIndexes := luaState.NewTable()
for _, keyIndex := range e.KeyIndexes {
keyIndexes.Append(lua.LNumber(keyIndex))
}
L.SetGlobal("KEY_INDEXES", keyIndexes)
L.SetGlobal("SLOTS", slots)
argv := L.NewTable()
luaState.SetGlobal("KEY_INDEXES", keyIndexes)
luaState.SetGlobal("SLOTS", slots)
argv := luaState.NewTable()
for _, arg := range e.Argv {
argv.Append(lua.LString(arg))
}
L.SetGlobal("ARGV", argv)
shake := L.NewTypeMetatable("shake")
L.SetGlobal("shake", shake)
luaState.SetGlobal("ARGV", argv)
shake := luaState.NewTypeMetatable("shake")
luaState.SetGlobal("shake", shake)

L.SetField(shake, "call", L.NewFunction(func(ls *lua.LState) int {
luaState.SetField(shake, "call", luaState.NewFunction(func(ls *lua.LState) int {
db := ls.ToInt(1)
argv := ls.ToTable(2)
var argvStrings []string
Expand All @@ -79,12 +96,12 @@ func RunFunction(e *entry.Entry) []*entry.Entry {
})
return 0
}))
L.SetField(shake, "log", L.NewFunction(func(ls *lua.LState) int {
luaState.SetField(shake, "log", luaState.NewFunction(func(ls *lua.LState) int {
log.Infof("lua log: %v", ls.ToString(1))
return 0
}))
err := L.DoString(luaString)
if err != nil {
luaState.Push(luaState.NewFunctionFromProto(runtime.compiledFunction))
if err := luaState.PCall(0, lua.MultRet, nil); err != nil {
log.Panicf("load function script failed: %v", err)
}

Expand Down
15 changes: 8 additions & 7 deletions internal/function/function_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ import (
// Command is `go test -benchmem -bench="RunFunction$" -count=5 RedisShake/internal/function`
// Output is:
//
// BenchmarkRunFunction-16 6741 182470 ns/op 234715 B/op 1079 allocs/op
// BenchmarkRunFunction-16 7443 174567 ns/op 234710 B/op 1079 allocs/op
// BenchmarkRunFunction-16 7101 178651 ns/op 234711 B/op 1079 allocs/op
// BenchmarkRunFunction-16 6856 164739 ns/op 234722 B/op 1079 allocs/op
// BenchmarkRunFunction-16 6804 174768 ns/op 234713 B/op 1079 allocs/op
// cpu: Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz
// BenchmarkRunFunction-16 152046 8494 ns/op 15283 B/op 42 allocs/op
// BenchmarkRunFunction-16 150916 7630 ns/op 15274 B/op 42 allocs/op
// BenchmarkRunFunction-16 149980 8467 ns/op 15292 B/op 42 allocs/op
// BenchmarkRunFunction-16 158834 7722 ns/op 15278 B/op 42 allocs/op
// BenchmarkRunFunction-16 118228 8482 ns/op 15292 B/op 42 allocs/op
func BenchmarkRunFunction(b *testing.B) {
config.Opt = config.ShakeOptions{
Function: `
Expand All @@ -33,7 +34,7 @@ end
shake.call(DB, ARGV)
`,
}
Init()
luaRuntime := New(config.Opt.Function)
e := &entry.Entry{
DbId: 0,
Argv: []string{"set", "mlpSummary:1", "1"},
Expand All @@ -47,6 +48,6 @@ shake.call(DB, ARGV)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
RunFunction(e)
luaRuntime.RunFunction(e)
}
}

0 comments on commit 10068d3

Please sign in to comment.