Skip to content

Commit

Permalink
refactor: change host function API and hide GuestRuntime (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
mhmd-azeez authored Aug 29, 2023
1 parent a2a20c6 commit afbb83e
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 70 deletions.
8 changes: 4 additions & 4 deletions extism.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ type Plugin struct {
LastStatusCode int
log func(LogLevel, string)
logLevel LogLevel
guestRuntime GuestRuntime
guestRuntime guestRuntime
}

func logStd(level LogLevel, message string) {
Expand Down Expand Up @@ -363,7 +363,7 @@ func NewPlugin(
log: logStd,
logLevel: logLevel}

p.guestRuntime = guestRuntime(p)
p.guestRuntime = detectGuestRuntime(p)
return p, nil
}

Expand Down Expand Up @@ -470,8 +470,8 @@ func (plugin *Plugin) Call(name string, data []byte) (uint32, []byte, error) {
}

var isStart = name == "_start"
if plugin.guestRuntime.Init != nil && !isStart && !plugin.guestRuntime.initialized {
err := plugin.guestRuntime.Init()
if plugin.guestRuntime.init != nil && !isStart && !plugin.guestRuntime.initialized {
err := plugin.guestRuntime.init()
if err != nil {
return 1, []byte{}, errors.New(fmt.Sprintf("failed to initialize runtime: %v", err))
}
Expand Down
56 changes: 28 additions & 28 deletions extism_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,18 +224,18 @@ func TestExit(t *testing.T) {
func TestHost_simple(t *testing.T) {
manifest := manifest("host.wasm")

mult := HostFunction{
Name: "mult",
Namespace: "env",
Callback: func(ctx context.Context, plugin *CurrentPlugin, userData interface{}, stack []uint64) {
mult := NewHostFunctionWithStack(
"mult",
"env",
func(ctx context.Context, plugin *CurrentPlugin, stack []uint64) {
a := api.DecodeI32(stack[0])
b := api.DecodeI32(stack[1])

stack[0] = api.EncodeI32(a * b)
},
Params: []api.ValueType{api.ValueTypeI64, api.ValueTypeI64},
Results: []api.ValueType{api.ValueTypeI64},
}
[]api.ValueType{api.ValueTypeI64, api.ValueTypeI64},
api.ValueTypeI64,
)

if plugin, ok := plugin(t, manifest, mult); ok {
defer plugin.Close()
Expand All @@ -254,10 +254,10 @@ func TestHost_simple(t *testing.T) {
func TestHost_memory(t *testing.T) {
manifest := manifest("host_memory.wasm")

mult := HostFunction{
Name: "to_upper",
Namespace: "host",
Callback: func(ctx context.Context, plugin *CurrentPlugin, userData interface{}, stack []uint64) {
mult := NewHostFunctionWithStack(
"to_upper",
"host",
func(ctx context.Context, plugin *CurrentPlugin, stack []uint64) {
offset := stack[0]
buffer, err := plugin.ReadBytes(offset)
if err != nil {
Expand All @@ -276,9 +276,9 @@ func TestHost_memory(t *testing.T) {

stack[0] = offset
},
Params: []api.ValueType{api.ValueTypeI64},
Results: []api.ValueType{api.ValueTypeI64},
}
[]api.ValueType{api.ValueTypeI64},
api.ValueTypeI64,
)

if plugin, ok := plugin(t, manifest, mult); ok {
defer plugin.Close()
Expand All @@ -302,10 +302,10 @@ func TestHost_multiple(t *testing.T) {
EnableWasi: true,
}

green_message := HostFunction{
Name: "hostGreenMessage",
Namespace: "env",
Callback: func(ctx context.Context, plugin *CurrentPlugin, userData interface{}, stack []uint64) {
green_message := NewHostFunctionWithStack(
"hostGreenMessage",
"env",
func(ctx context.Context, plugin *CurrentPlugin, stack []uint64) {
offset := stack[0]
input, err := plugin.ReadString(offset)

Expand All @@ -324,14 +324,14 @@ func TestHost_multiple(t *testing.T) {

stack[0] = offset
},
Params: []api.ValueType{api.ValueTypeI64},
Results: []api.ValueType{api.ValueTypeI64},
}
[]api.ValueType{api.ValueTypeI64},
api.ValueTypeI64,
)

purple_message := HostFunction{
Name: "hostPurpleMessage",
Namespace: "env",
Callback: func(ctx context.Context, plugin *CurrentPlugin, userData interface{}, stack []uint64) {
purple_message := NewHostFunctionWithStack(
"hostPurpleMessage",
"env",
func(ctx context.Context, plugin *CurrentPlugin, stack []uint64) {
offset := stack[0]
input, err := plugin.ReadString(offset)

Expand All @@ -350,9 +350,9 @@ func TestHost_multiple(t *testing.T) {

stack[0] = offset
},
Params: []api.ValueType{api.ValueTypeI64},
Results: []api.ValueType{api.ValueTypeI64},
}
[]api.ValueType{api.ValueTypeI64},
api.ValueTypeI64,
)

hostFunctions := []HostFunction{
purple_message,
Expand Down
56 changes: 36 additions & 20 deletions host.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type ValType = api.ValueType
const I32 = api.ValueTypeI32
const I64 = api.ValueTypeI64

// HostFunctionCallback is a Function implemented in Go instead of a wasm binary.
// HostFunctionStackCallback is a Function implemented in Go instead of a wasm binary.
// The plugin parameter is the calling plugin, used to access memory or
// exported functions and logging.
//
Expand All @@ -45,30 +45,47 @@ const I64 = api.ValueTypeI64
//
// To safely decode/encode values from/to the uint64 inputs/ouputs, users are encouraged to use
// Wazero's api.EncodeXXX or api.DecodeXXX functions.
type HostFunctionCallback func(ctx context.Context, p *CurrentPlugin, userData interface{}, stack []uint64)
type HostFunctionStackCallback func(ctx context.Context, p *CurrentPlugin, stack []uint64)

// HostFunction represents a custom function defined by the host.
type HostFunction struct {
stackCallback HostFunctionStackCallback
Name string
Namespace string
Params []api.ValueType
Returns []api.ValueType
}

// NewHostFunctionWithStack creates a new instance of a HostFunction, which is designed
// to provide custom functionality in a given host environment.
// Here's an example multiplication function that loads operands from memory:
//
// mult := HostFunction{
// Name: "mult",
// Namespace: "env",
// Callback: func(ctx context.Context, plugin *CurrentPlugin, userData interface{}, stack []uint64) {
// mult := NewHostFunctionWithStack(
// "mult",
// "env",
// func(ctx context.Context, plugin *CurrentPlugin, stack []uint64) {
// a := api.DecodeI32(stack[0])
// b := api.DecodeI32(stack[1])
//
// stack[0] = api.EncodeI32(a * b)
// },
// Params: []api.ValueType{api.ValueTypeI64, api.ValueTypeI64},
// Results: []api.ValueType{api.ValueTypeI64},
// }
type HostFunction struct {
Callback HostFunctionCallback
Name string
Namespace string
Params []api.ValueType
Results []api.ValueType
UserData interface{}
// []api.ValueType{api.ValueTypeI64, api.ValueTypeI64},
// api.ValueTypeI64
// )
func NewHostFunctionWithStack(
name string,
namespace string,
callback HostFunctionStackCallback,
params []api.ValueType,
returnType api.ValueType) HostFunction {

return HostFunction{
stackCallback: callback,
Name: name,
Namespace: namespace,
Params: params,
Returns: []api.ValueType{returnType},
}
}

type CurrentPlugin struct {
Expand Down Expand Up @@ -187,17 +204,16 @@ func defineCustomHostFunctions(builder wazero.HostModuleBuilder, funcs []HostFun
// a separate variable (closure) and assigning the value of f to it, you might run into unexpected behavior.
// All the closures created in the loop would end up referencing the same f, which could lead to incorrect or unintended results.
// See: https://github.com/extism/go-sdk/issues/5#issuecomment-1666774486
closure := f.Callback
userData := f.UserData
closure := f.stackCallback

builder.NewFunctionBuilder().WithGoFunction(api.GoFunc(func(ctx context.Context, stack []uint64) {
if plugin, ok := ctx.Value("plugin").(*Plugin); ok {
closure(ctx, &CurrentPlugin{plugin}, userData, stack)
closure(ctx, &CurrentPlugin{plugin}, stack)
return
}

panic("Invalid context, `plugin` key not found")
}), f.Params, f.Results).Export(f.Name)
}), f.Params, f.Returns).Export(f.Name)
}
}

Expand Down
36 changes: 18 additions & 18 deletions runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,21 @@ import (

// TODO: test runtime initialization for WASI and Haskell

type RuntimeType uint8
type runtimeType uint8

const (
None RuntimeType = iota
None runtimeType = iota
Haskell
Wasi
)

type GuestRuntime struct {
Type RuntimeType
Init func() error
type guestRuntime struct {
runtimeType runtimeType
init func() error
initialized bool
}

func guestRuntime(p *Plugin) GuestRuntime {
func detectGuestRuntime(p *Plugin) guestRuntime {
m := p.Main

runtime, ok := haskellRuntime(p, m)
Expand All @@ -34,16 +34,16 @@ func guestRuntime(p *Plugin) GuestRuntime {
}

p.Log(Trace, "No runtime detected")
return GuestRuntime{Type: None, Init: func() error { return nil }, initialized: true}
return guestRuntime{runtimeType: None, init: func() error { return nil }, initialized: true}
}

// Check for Haskell runtime initialization functions
// Initialize Haskell runtime if `hs_init` and `hs_exit` are present,
// by calling the `hs_init` export
func haskellRuntime(p *Plugin, m api.Module) (GuestRuntime, bool) {
func haskellRuntime(p *Plugin, m api.Module) (guestRuntime, bool) {
initFunc := m.ExportedFunction("hs_init")
if initFunc == nil {
return GuestRuntime{}, false
return guestRuntime{}, false
}

params := initFunc.Definition().ParamTypes()
Expand All @@ -70,13 +70,13 @@ func haskellRuntime(p *Plugin, m api.Module) (GuestRuntime, bool) {
}

p.Log(Trace, "Haskell runtime detected")
return GuestRuntime{Type: Haskell, Init: init}, true
return guestRuntime{runtimeType: Haskell, init: init}, true
}

// Check for initialization functions defined by the WASI standard
func wasiRuntime(p *Plugin, m api.Module) (GuestRuntime, bool) {
func wasiRuntime(p *Plugin, m api.Module) (guestRuntime, bool) {
if !p.Runtime.hasWasi {
return GuestRuntime{}, false
return guestRuntime{}, false
}

// WASI supports two modules: Reactors and Commands
Expand All @@ -90,30 +90,30 @@ func wasiRuntime(p *Plugin, m api.Module) (GuestRuntime, bool) {
}

// Check for `_initialize` this is used by WASI to initialize certain interfaces.
func reactorModule(m api.Module, p *Plugin) (GuestRuntime, bool) {
func reactorModule(m api.Module, p *Plugin) (guestRuntime, bool) {
init := findFunc(m, p, "_initialize")
if init == nil {
return GuestRuntime{}, false
return guestRuntime{}, false
}

p.Logf(Trace, "WASI runtime detected")
p.Logf(Trace, "Reactor module detected")

return GuestRuntime{Type: Wasi, Init: init}, true
return guestRuntime{runtimeType: Wasi, init: init}, true
}

// Check for `__wasm__call_ctors`, this is used by WASI to
// initialize certain interfaces.
func commandModule(m api.Module, p *Plugin) (GuestRuntime, bool) {
func commandModule(m api.Module, p *Plugin) (guestRuntime, bool) {
init := findFunc(m, p, "__wasm_call_ctors")
if init == nil {
return GuestRuntime{}, false
return guestRuntime{}, false
}

p.Logf(Trace, "WASI runtime detected")
p.Logf(Trace, "Command module detected")

return GuestRuntime{Type: Wasi, Init: init}, true
return guestRuntime{runtimeType: Wasi, init: init}, true
}

func findFunc(m api.Module, p *Plugin, name string) func() error {
Expand Down

0 comments on commit afbb83e

Please sign in to comment.