diff --git a/wasi/types.go b/wasi/types.go index d0ec0e78a7..1a0c63ccf2 100644 --- a/wasi/types.go +++ b/wasi/types.go @@ -367,8 +367,19 @@ const ( FunctionProcExit = "proc_exit" FunctionProcRaise = "proc_raise" FunctionSchedYield = "sched_yield" - FunctionRandomGet = "random_get" - FunctionSockRecv = "sock_recv" - FunctionSockSend = "sock_send" - FunctionSockShutdown = "sock_shutdown" + + // FunctionRandomGet write random data in buffer + // + // See ImportRandomGet + // See API.RandomGet + // See: https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#-random_getbuf-pointeru8-buf_len-size---errno + FunctionRandomGet = "random_get" + + // ImportRandomGet is the WebAssembly 1.0 (MVP) Text format import of FunctionRandomGet + ImportRandomGet = `(import "wasi_snapshot_preview1" "random_get" + (func $wasi.random_get (param $buf i32) (param $buf_len i32) (result (;errno;) i32)))` + + FunctionSockRecv = "sock_recv" + FunctionSockSend = "sock_send" + FunctionSockShutdown = "sock_shutdown" ) diff --git a/wasi/wasi.go b/wasi/wasi.go index 13aa2bd232..949c91ca73 100644 --- a/wasi/wasi.go +++ b/wasi/wasi.go @@ -1,11 +1,12 @@ package wasi import ( + crand "crypto/rand" "errors" "io" "io/fs" "math" - "math/rand" + mrand "math/rand" "os" "reflect" "time" @@ -144,7 +145,24 @@ type API interface { // TODO: ProcExit // TODO: ProcRaise // TODO: SchedYield - // TODO: RandomGet + + // RandomGet is a WASI function that write random data in buffer (rand.Read()). + // + // * buf - is a offset to write random values + // * bufLen - size of random data in bytes + // + // For example, if `HostFunctionCallContext.Randomizer` initialized + // with random seed `rand.NewSource(42)`, we expect `ctx.Memory.Buffer` to contain: + // + // bufLen (5) + // +--------------------------+ + // | | + // []byte{?, 0x53, 0x8c, 0x7f, 0x96, 0xb1, ?} + // buf --^ + // + // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#-random_getbuf-pointeru8-bufLen-size---errno + RandomGet(ctx wasm.HostFunctionCallContext, buf, bufLen uint32) Errno + // TODO: SockRecv // TODO: SockSend // TODO: SockShutdown @@ -163,6 +181,7 @@ type wasiAPI struct { opened map[uint32]fileEntry // timeNowUnixNano is mutable for testing timeNowUnixNano func() uint64 + randSource func([]byte) error } func (a *wasiAPI) register(store *wasm.Store) (err error) { @@ -213,7 +232,7 @@ func (a *wasiAPI) register(store *wasm.Store) (err error) { {FunctionProcExit, proc_exit}, // TODO: FunctionProcRaise // TODO: FunctionSchedYield - // TODO: FunctionRandomGet + {FunctionRandomGet, a.RandomGet}, // TODO: FunctionSockRecv // TODO: FunctionSockSend // TODO: FunctionSockShutdown @@ -346,6 +365,10 @@ func newAPI(opts ...Option) *wasiAPI { timeNowUnixNano: func() uint64 { return uint64(time.Now().UnixNano()) }, + randSource: func(p []byte) error { + _, err := crand.Read(p) + return err + }, } // apply functional options @@ -356,7 +379,7 @@ func newAPI(opts ...Option) *wasiAPI { } func (a *wasiAPI) randUnusedFD() uint32 { - fd := uint32(rand.Int31()) + fd := uint32(mrand.Int31()) for { if _, ok := a.opened[fd]; !ok { return fd @@ -538,6 +561,22 @@ func (a *wasiAPI) fd_close(ctx wasm.HostFunctionCallContext, fd uint32) (err Err return ErrnoSuccess } +// RandomGet implements API.RandomGet +func (a *wasiAPI) RandomGet(ctx wasm.HostFunctionCallContext, buf uint32, bufLen uint32) (errno Errno) { + randomBytes := make([]byte, bufLen) + err := a.randSource(randomBytes) + if err != nil { + // TODO: handle different errors that syscal to entropy source can return + return ErrnoIo + } + + if !ctx.Memory().Write(buf, randomBytes) { + return ErrnoFault + } + + return ErrnoSuccess +} + func proc_exit(wasm.HostFunctionCallContext, uint32) { // TODO: implement } diff --git a/wasi/wasi_test.go b/wasi/wasi_test.go index fb7b353e02..5e4831cfcd 100644 --- a/wasi/wasi_test.go +++ b/wasi/wasi_test.go @@ -3,7 +3,9 @@ package wasi import ( "context" _ "embed" + "errors" "fmt" + "math/rand" "testing" "github.com/stretchr/testify/require" @@ -305,7 +307,85 @@ func TestAPI_ClockTimeGet_Errors(t *testing.T) { // TODO: TestAPI_ProcExit TestAPI_ProcExit_Errors // TODO: TestAPI_ProcRaise TestAPI_ProcRaise_Errors // TODO: TestAPI_SchedYield TestAPI_SchedYield_Errors -// TODO: TestAPI_RandomGet TestAPI_RandomGet_Errors + +func TestAPI_RandomGet(t *testing.T) { + store, api := instantiateWasmStore(t, FunctionRandomGet, ImportRandomGet, "test") + maskLength := 7 // number of bytes to write '?' to tell what we've written + expectedMemory := []byte{ + '?', // random bytes in `buf` is after this + 0x53, 0x8c, 0x7f, 0x96, 0xb1, // random data from seed value of 42 + '?', // stopped after encoding + } // tr + + var bufLen = uint32(5) // arbitrary buffer size, + var buf = uint32(1) // offset, + var seed = int64(42) // and seed value + + api.(*wasiAPI).randSource = func(p []byte) error { + s := rand.NewSource(seed) + rng := rand.New(s) + _, err := rng.Read(p) + + return err + } + + t.Run("API.RandomGet", func(t *testing.T) { + maskMemory(store, maskLength) + hContext := wasm.NewHostFunctionCallContext(context.Background(), store.Memories[0]) + + errno := api.RandomGet(hContext, buf, bufLen) + require.Equal(t, ErrnoSuccess, errno) + require.Equal(t, expectedMemory, store.Memories[0].Buffer[0:maskLength]) + }) +} + +func TestAPI_RandomGet_Errors(t *testing.T) { + store, api := instantiateWasmStore(t, FunctionRandomGet, ImportRandomGet, "test") + + memorySize := uint32(len(store.Memories[0].Buffer)) + validAddress := uint32(0) // arbitrary valid address + tests := []struct { + name string + buf uint32 + bufLen uint32 + }{ + { + name: "random buffer out-of-memory", + buf: memorySize, + bufLen: 1, + }, + + { + name: "random buffer size exceeds maximum valid address by 1", + buf: validAddress, + bufLen: memorySize + 1, + }, + } + + for _, tt := range tests { + tc := tt + + t.Run(tc.name, func(t *testing.T) { + ret, _, err := store.CallFunction(context.Background(), "test", FunctionRandomGet, uint64(tc.buf), uint64(tc.bufLen)) + require.NoError(t, err) + require.Equal(t, uint64(ErrnoFault), ret[0]) // ret[0] is returned errno + }) + } + + t.Run("API.RandomGet returns ErrnoIO on random source err", func(t *testing.T) { + hContext := wasm.NewHostFunctionCallContext(context.Background(), store.Memories[0]) + + api.(*wasiAPI).randSource = func(p []byte) error { + return errors.New("random source error") + } + var bufLen = uint32(5) // arbitrary buffer size, + var buf = uint32(1) // and offset + errno := api.RandomGet(hContext, buf, bufLen) + require.Equal(t, ErrnoIo, errno) + }) + +} + // TODO: TestAPI_SockRecv TestAPI_SockRecv_Errors // TODO: TestAPI_SockSend TestAPI_SockSend_Errors // TODO: TestAPI_SockShutdown TestAPI_SockShutdown_Errors