diff --git a/internal/wasi/wasi_test.go b/internal/wasi/wasi_test.go index b8d021aff2..41df2b8a90 100644 --- a/internal/wasi/wasi_test.go +++ b/internal/wasi/wasi_test.go @@ -1494,7 +1494,7 @@ func TestSnapshotPreview1_FdWrite_Errors(t *testing.T) { } func createFile(t *testing.T, path string, contents []byte) (*memFile, *MemFS) { - memFS := &MemFS{} + memFS := &MemFS{Files: map[string][]byte{}} f, err := memFS.OpenWASI(0, path, wasi.O_CREATE|wasi.O_TRUNC, wasi.R_FD_WRITE, 0, 0) require.NoError(t, err) diff --git a/internal/wasm/module_context.go b/internal/wasm/module_context.go index 4a22edfe55..ae247bd482 100644 --- a/internal/wasm/module_context.go +++ b/internal/wasm/module_context.go @@ -23,15 +23,14 @@ type ModuleContext struct { memory publicwasm.Memory store *Store - // Sys is not exposed publicly. This is currently only used by internalwasi. + // sys is not exposed publicly. This is currently only used by internalwasi. // Note: This is a part of ModuleContext so that scope is correct and Close is coherent. sys *SysContext } // WithMemory allows overriding memory without re-allocation when the result would be the same. func (m *ModuleContext) WithMemory(memory *MemoryInstance) *ModuleContext { - // only re-allocate if it will change the effective memory - if m.memory == nil || (memory != nil && memory.Max != nil && *memory.Max > 0 && memory != m.memory) { + if memory != nil && memory != m.memory { // only re-allocate if it will change the effective memory return &ModuleContext{module: m.module, memory: memory, ctx: m.ctx, sys: m.sys} } return m @@ -54,8 +53,7 @@ func (m *ModuleContext) Sys() *SysContext { // WithContext implements wasm.Module WithContext func (m *ModuleContext) WithContext(ctx context.Context) publicwasm.Module { - // only re-allocate if it will change the effective context - if ctx != nil && ctx != m.ctx { + if ctx != nil && ctx != m.ctx { // only re-allocate if it will change the effective context return &ModuleContext{module: m.module, memory: m.memory, ctx: ctx, sys: m.sys} } return m diff --git a/internal/wasm/module_context_test.go b/internal/wasm/module_context_test.go new file mode 100644 index 0000000000..eda6725c18 --- /dev/null +++ b/internal/wasm/module_context_test.go @@ -0,0 +1,199 @@ +package internalwasm + +import ( + "context" + "path" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestModuleContext_WithContext(t *testing.T) { + type key string + tests := []struct { + name string + mod *ModuleContext + ctx context.Context + expectSame bool + }{ + { + name: "nil->nil: same", + mod: &ModuleContext{}, + ctx: nil, + expectSame: true, + }, + { + name: "nil->ctx: not same", + mod: &ModuleContext{}, + ctx: context.WithValue(context.Background(), key("a"), "b"), + expectSame: false, + }, + { + name: "ctx->nil: same", + mod: &ModuleContext{ctx: context.Background()}, + ctx: nil, + expectSame: true, + }, + { + name: "ctx1->ctx2: not same", + mod: &ModuleContext{ctx: context.Background()}, + ctx: context.WithValue(context.Background(), key("a"), "b"), + expectSame: false, + }, + } + + for _, tt := range tests { + tc := tt + + t.Run(tc.name, func(t *testing.T) { + mod2 := tc.mod.WithContext(tc.ctx) + if tc.expectSame { + require.Same(t, tc.mod, mod2) + } else { + require.NotSame(t, tc.mod, mod2) + require.Equal(t, tc.ctx, mod2.Context()) + } + }) + } +} + +func TestModuleContext_WithMemory(t *testing.T) { + tests := []struct { + name string + mod *ModuleContext + mem *MemoryInstance + expectSame bool + }{ + { + name: "nil->nil: same", + mod: &ModuleContext{}, + mem: nil, + expectSame: true, + }, + { + name: "nil->mem: not same", + mod: &ModuleContext{}, + mem: &MemoryInstance{}, + expectSame: false, + }, + { + name: "mem->nil: same", + mod: &ModuleContext{memory: &MemoryInstance{}}, + mem: nil, + expectSame: true, + }, + { + name: "mem1->mem2: not same", + mod: &ModuleContext{memory: &MemoryInstance{}}, + mem: &MemoryInstance{}, + expectSame: false, + }, + } + + for _, tt := range tests { + tc := tt + + t.Run(tc.name, func(t *testing.T) { + mod2 := tc.mod.WithMemory(tc.mem) + if tc.expectSame { + require.Same(t, tc.mod, mod2) + } else { + require.NotSame(t, tc.mod, mod2) + require.Equal(t, tc.mem, mod2.memory) + } + }) + } +} + +func TestModuleContext_String(t *testing.T) { + s := newStore() + + tests := []struct { + name, moduleName, expected string + }{ + { + name: "empty", + moduleName: "", + expected: "Module[]", + }, + { + name: "not empty", + moduleName: "math", + expected: "Module[math]", + }, + } + + for _, tt := range tests { + tc := tt + + t.Run(tc.name, func(t *testing.T) { + // Ensure paths that can create the host module can see the name. + m, err := s.Instantiate(context.Background(), &Module{}, tc.moduleName, nil) + defer m.Close() //nolint + + require.NoError(t, err) + require.Equal(t, tc.expected, m.String()) + require.Equal(t, tc.expected, s.Module(m.module.Name).String()) + }) + } +} + +func TestModuleContext_Close(t *testing.T) { + s := newStore() + + t.Run("calls store.CloseModule(module.name)", func(t *testing.T) { + moduleName := t.Name() + m, err := s.Instantiate(context.Background(), &Module{}, moduleName, nil) + require.NoError(t, err) + + // We use side effects to determine if Close in fact called store.CloseModule (without repeating store_test.go). + // One side effect of store.CloseModule is that the moduleName can no longer be looked up. Verify our base case. + require.Equal(t, s.Module(moduleName), m) + + // Closing should not err. + require.NoError(t, m.Close()) + + // Verify our intended side-effect + require.Nil(t, s.Module(moduleName)) + + // Verify no error closing again. + require.NoError(t, m.Close()) + }) + + t.Run("calls SysContext.Close()", func(t *testing.T) { + tempDir := t.TempDir() + pathName := "test" + file, _ := createWriteableFile(t, tempDir, pathName, make([]byte, 0)) + + sys, err := NewSysContext( + 0, // max + nil, // args + nil, // environ + nil, // stdin + nil, // stdout + nil, // stderr + map[uint32]*FileEntry{ // openedFiles + 3: {Path: "."}, + 4: {Path: path.Join(".", pathName), File: file}, + }, + ) + require.NoError(t, err) + + moduleName := t.Name() + m, err := s.Instantiate(context.Background(), &Module{}, moduleName, sys) + require.NoError(t, err) + + // We use side effects to determine if Close in fact called SysContext.Close (without repeating sys_test.go). + // One side effect of SysContext.Close is that it clears the openedFiles map. Verify our base case. + require.NotEmpty(t, sys.openedFiles) + + // Closing should not err. + require.NoError(t, m.Close()) + + // Verify our intended side-effect + require.Empty(t, sys.openedFiles) + + // Verify no error closing again. + require.NoError(t, m.Close()) + }) +} diff --git a/internal/wasm/store_test.go b/internal/wasm/store_test.go index 4297e7b265..d3d7f6eaca 100644 --- a/internal/wasm/store_test.go +++ b/internal/wasm/store_test.go @@ -87,16 +87,6 @@ func TestModuleInstance_Memory(t *testing.T) { } } -func TestModuleContext_String(t *testing.T) { - s := newStore() - - // Ensure paths that can create the host module can see the name. - m, err := s.Instantiate(context.Background(), &Module{}, "module", nil) - require.NoError(t, err) - require.Equal(t, "Module[module]", m.String()) - require.Equal(t, "Module[module]", s.Module(m.module.Name).String()) -} - func TestStore_Instantiate(t *testing.T) { s := newStore() m, err := NewHostModule("", map[string]interface{}{"fn": func(wasm.Module) {}}) diff --git a/internal/wasm/sys.go b/internal/wasm/sys.go index edd59966b3..d0185c54ee 100644 --- a/internal/wasm/sys.go +++ b/internal/wasm/sys.go @@ -5,7 +5,6 @@ import ( "fmt" "io" "math" - "os" "sync/atomic" "github.com/tetratelabs/wazero/wasi" @@ -197,11 +196,15 @@ func nullTerminatedByteCount(max uint32, elements []string) (uint32, error) { // Close implements io.Closer func (c *SysContext) Close() (err error) { - // stdin, stdout and stderr are only closed if we opened them. The only case we open is when stdin -> /dev/null - if f, ok := c.stdin.(*os.File); ok && f.Name() == os.DevNull { - _ = f.Close() // ignore error closing reader of /dev/null + // Close any files opened in this context + for fd, entry := range c.openedFiles { + delete(c.openedFiles, fd) + if entry.File != nil { // File is nil for a mount like "." or "/" + if e := entry.File.Close(); e != nil { + err = e // This means the err returned == the last non-nil error. + } + } } - // TODO: close openedFiles in #394 return } diff --git a/internal/wasm/sys_test.go b/internal/wasm/sys_test.go index 1f9710e4f9..5bd958adfb 100644 --- a/internal/wasm/sys_test.go +++ b/internal/wasm/sys_test.go @@ -3,6 +3,9 @@ package internalwasm import ( "bytes" "io" + "io/fs" + "os" + "path" "testing" "github.com/stretchr/testify/require" @@ -147,3 +150,57 @@ func TestNewSysContext_Environ(t *testing.T) { }) } } + +func TestSysContext_Close(t *testing.T) { + t.Run("no files", func(t *testing.T) { + sys := DefaultSysContext() + require.NoError(t, sys.Close()) + }) + + t.Run("open files", func(t *testing.T) { + tempDir := t.TempDir() + pathName := "test" + file, _ := createWriteableFile(t, tempDir, pathName, make([]byte, 0)) + + sys, err := NewSysContext( + 0, // max + nil, // args + nil, // environ + nil, // stdin + nil, // stdout + nil, // stderr + map[uint32]*FileEntry{ // openedFiles + 3: {Path: "."}, + 4: {Path: path.Join(".", pathName), File: file}, + }, + ) + require.NoError(t, err) + + // Closing should delete the file descriptors after closing the files. + require.NoError(t, sys.Close()) + require.Empty(t, sys.openedFiles) + + // Verify it was actually closed, by trying to close it again. + err = file.Close() + require.Contains(t, err.Error(), "file already closed") + + // No problem closing config again because the descriptors were removed, so they won't be called again. + require.NoError(t, sys.Close()) + }) + + // TODO: fs but never used (ex file == nil) + // TODO: externally closed +} + +// createWriteableFile uses real files when io.Writer tests are needed. +// TODO: temporarily *os.File until #394 +func createWriteableFile(t *testing.T, tmpDir string, pathName string, data []byte) (*os.File, fs.FS) { + require.NotNil(t, data) + absolutePath := path.Join(tmpDir, pathName) + require.NoError(t, os.WriteFile(absolutePath, data, 0o600)) + + // open the file for writing in a custom way until #390 + f, err := os.OpenFile(absolutePath, os.O_RDWR, 0o600) + require.NoError(t, err) + return f, os.DirFS(tmpDir) +}