diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000000..f999431de5 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,7 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000..0499de9f68 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,70 @@ +# Contributing + +We welcome contributions from the community. Please read the following guidelines carefully to maximize the chances of your PR being merged. + +## Coding Style + +- To ensure your change passes format checks, use run `make check`. To format your files, you can run `make format`. +- We follow standard Go table-driven tests and use the [`testify/require`](https://github.com/stretchr/testify#require-package) library to assert correctness. To verify all tests pass, you can run `make test`. + +## DCO + +We require DCO signoff line in every commit to this repo. + +The sign-off is a simple line at the end of the explanation for the +patch, which certifies that you wrote it or otherwise have the right to +pass it on as an open-source patch. The rules are pretty simple: if you +can certify the below (from +[developercertificate.org](https://developercertificate.org/)): + +``` +Developer Certificate of Origin +Version 1.1 +Copyright (C) 2004, 2006 The Linux Foundation and its contributors. +660 York Street, Suite 102, +San Francisco, CA 94110 USA +Everyone is permitted to copy and distribute verbatim copies of this +license document, but changing it is not allowed. +Developer's Certificate of Origin 1.1 +By making a contribution to this project, I certify that: +(a) The contribution was created in whole or in part by me and I + have the right to submit it under the open source license + indicated in the file; or +(b) The contribution is based upon previous work that, to the best + of my knowledge, is covered under an appropriate open source + license and I have the right under that license to submit that + work with modifications, whether created in whole or in part + by me, under the same open source license (unless I am + permitted to submit under a different license), as indicated + in the file; or +(c) The contribution was provided directly to me by some other + person who certified (a), (b) or (c) and I have not modified + it. +(d) I understand and agree that this project and the contribution + are public and that a record of the contribution (including all + personal information I submit with it, including my sign-off) is + maintained indefinitely and may be redistributed consistent with + this project or the open source license(s) involved. +``` + +then you just add a line to every git commit message: + + Signed-off-by: Joe Smith + +using your real name (sorry, no pseudonyms or anonymous contributions.) + +You can add the sign off when creating the git commit via `git commit -s`. + +## Code Reviews + +* Indicate the priority of each comment, following this +[feedback ladder](https://www.netlify.com/blog/2020/03/05/feedback-ladders-how-we-encode-code-reviews-at-netlify/). +If none was indicated it will be treated as `[dust]`. +* A single approval is sufficient to merge, except when the change cuts +across several components; then it should be approved by at least one owner +of each component. If a reviewer asks for changes in a PR they should be +addressed before the PR is merged, even if another reviewer has already +approved the PR. +* During the review, address the comments and commit the changes _without_ squashing the commits. +This facilitates incremental reviews since the reviewer does not go through all the code again to +find out what has changed since the last review. diff --git a/README.md b/README.md index 5201fe28ba..469e169b16 100644 --- a/README.md +++ b/README.md @@ -5,23 +5,40 @@ portability features like cross compilation. Import wazero and extend your Go ap language! ## Example + Here's an example of using wazero to invoke a Fibonacci function included in a Wasm binary. -While our [source for this](examples/testdata/fibonacci.go) is [TinyGo](https://tinygo.org/), it could have been written in -another language that targets Wasm, such as Rust. +While our [source for this](examples/testdata/fibonacci.go) is [TinyGo](https://tinygo.org/), it could have been written in another language that targets Wasm, such as AssemblyScript/C/C++/Rust/Zig. ```golang +package main + +import ( + "context" + "fmt" + "os" + + "github.com/tetratelabs/wazero/wasi" + "github.com/tetratelabs/wazero/wasm" + "github.com/tetratelabs/wazero/wasm/binary" + "github.com/tetratelabs/wazero/wasm/interpreter" +) + func main() { + // Default context impl. by Go + ctx := context.Background() // Read WebAssembly binary. - source, _ := os.ReadFile("fibonacci.wasm") + source, _ := os.ReadFile("examples/testdata/fibonacci.wasm") // Decode the binary as WebAssembly module. mod, _ := binary.DecodeModule(source) // Initialize the execution environment called "store" with Interpreter-based engine. store := wasm.NewStore(interpreter.NewEngine()) + // To resolve WASI specific methods, such as `fd_write` + wasi.RegisterAPI(store) // Instantiate the decoded module. store.Instantiate(mod, "test") // Execute the exported "fibonacci" function from the instantiated module. - ret, _, err := store.CallFunction("test", "fibonacci", 20) + ret, _, _ := store.CallFunction(ctx, "test", "fibonacci", 20) // Give us the fibonacci number for 20, namely 6765! fmt.Println(ret[0]) } @@ -58,4 +75,4 @@ Currently any performance optimization hasn't been done to this runtime yet, and However _theoretically speaking_, this project have the potential to compete with these state-of-the-art JIT-style runtimes. The rationale for that is it is well-know that [CGO is slow](https://github.com/golang/go/issues/19574). More specifically, if you make large amount of CGO calls which cross the boundary between Go and C (stack) space, then the usage of CGO could be a bottleneck. -Since we can do JIT compilation purely in Go, this runtime could be the fastest one for some use cases where we have to make large amount of CGO calls (e.g. Proxy-Wasm host environment, or request-based plugin systems). +Since we can do JIT compilation purely in Go, this runtime could be the fastest one for some use cases where we have to make large amount of CGO calls (e.g. Proxy-Wasm host environment, or request-based plugin systems). \ No newline at end of file diff --git a/tests/README.md b/tests/README.md index a7b3756b8b..8208c8a523 100644 --- a/tests/README.md +++ b/tests/README.md @@ -1,6 +1,9 @@ This directory contains tests which use multiple packages. For example: -- `bench` contains benchmark tests. -- `codec` contains a test and benchmark on text and binary decoders. -- `engine` contains variety of e2e tests, mainly to ensure the consistency in the behavior between engines. -- `spectest` contains end-to-end tests with the [WebAssembly specification tests](https://github.com/WebAssembly/spec/tree/wg-1.0/test/core). +* `bench` contains benchmark tests. +* `codec` contains a test and benchmark on text and binary decoders. +* `engine` contains variety of e2e tests, mainly to ensure the consistency in the behavior between engines. +* `spectest` contains end-to-end tests with the [WebAssembly specification tests](https://github.com/WebAssembly/spec/tree/wg-1.0/test/core). + +*Note*: this doesn't contain WASI tests, as there's not yet an [official testsuite](https://github.com/WebAssembly/WASI/issues/9). +Meanwhile, WASI functions are unit tested including via Text Format imports [here](../wasi/wasi_test.go) diff --git a/wasi/testdata/random.wat b/wasi/testdata/random.wat new file mode 100644 index 0000000000..29c61230ac --- /dev/null +++ b/wasi/testdata/random.wat @@ -0,0 +1,24 @@ +;; This is a wat file to just export clock WASI API to the host environment for testing the APIs. +;; This is currently separated as a wat file and pre-compiled because our text parser doesn't +;; implement 'memory' yet. After it supports 'memory', we can remove this file and embed this +;; wat file in the Go test code. +;; +;; Note: Although this is a raw wat file which should be moved under /tests/wasi in principle, +;; this file is put here for now, because this is a temporary file until the parser supports +;; the enough syntax, and this file will be embedded in unit test codes after that. +(module + (import "wasi_snapshot_preview1" "random_get" + (func $wasi.random_get (param $buf i32) (param $buf_len i32) (result (;errno;) i32))) + (memory 1) ;; just an arbitrary size big enough for tests + (export "memory" (memory 0)) + ;; Define wrapper functions instead of just exporting the imported WASI APIS for now + ;; because wazero's interpreter has a bug that it crashes when an imported-and-exported host function + ;; is called from the host environment, which will be fixed soon. + ;; After it's fixed, these wrapper functions are no longer necessary. + (func $random_get (param i32 i32) (result i32) + local.get 0 + local.get 1 + call $wasi.random_get + ) + (export "random_get" (func $random_get)) + ) diff --git a/wasi/wasi.go b/wasi/wasi.go index bcf4227f17..42d0d36cd2 100644 --- a/wasi/wasi.go +++ b/wasi/wasi.go @@ -1,12 +1,12 @@ package wasi import ( + crand "crypto/rand" "encoding/binary" "errors" "io" "io/fs" "math" - "math/rand" "os" "reflect" "time" @@ -143,7 +143,24 @@ type API interface { // TODO: ProcExit // TODO: ProcRaise // TODO: SchedYield - // TODO: RandomGet + + // RandomGet is a WASI function that write random data in buffer (rand.Read()). + // + // * buf - is a offset to write random values + // * bufLen - size of random data in bytes + // + // For example, if `HostFunctionCallContext.Randomizer` initialized + // with random seed `rand.NewSource(42)`, we expect `ctx.Memory.Buffer` to contain: + // + // bufLen (5) + // +--------------------------+ + // | | + // []byte{?, 0x53, 0x8c, 0x7f, 0x96, 0xb1, ?} + // buf --^ + // + // See https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#-random_getbuf-pointeru8-bufLen-size---errno + RandomGet(ctx *wasm.HostFunctionCallContext, buf, bufLen uint32) Errno + // TODO: SockRecv // TODO: SockSend // TODO: SockShutdown @@ -154,6 +171,24 @@ const ( wasiSnapshotPreview1Name = "wasi_snapshot_preview1" ) +type RandomSource interface { + Read([]byte) (int, error) + Int31() (int32, error) +} + +// Non-deterministic random source using crypto/rand +type CryptoRandomSource struct{} + +func (c *CryptoRandomSource) Read(p []byte) (n int, err error) { + return crand.Read(p) +} + +func (c *CryptoRandomSource) Int31() (v int32, err error) { + err = binary.Read(crand.Reader, binary.BigEndian, &v) + + return v, err +} + type api struct { args *nullTerminatedStrings stdin io.Reader @@ -162,6 +197,7 @@ type api struct { opened map[uint32]fileEntry // timeNowUnixNano is mutable for testing timeNowUnixNano func() uint64 + randSource RandomSource } func (a *api) register(store *wasm.Store) (err error) { @@ -212,7 +248,7 @@ func (a *api) register(store *wasm.Store) (err error) { {FunctionProcExit, proc_exit}, // TODO: FunctionProcRaise // TODO: FunctionSchedYield - // TODO: FunctionRandomGet + {FunctionRandomGet, a.RandomGet}, // TODO: FunctionSockRecv // TODO: FunctionSockSend // TODO: FunctionSockShutdown @@ -345,6 +381,7 @@ func newAPI(opts ...Option) *api { timeNowUnixNano: func() uint64 { return uint64(time.Now().UnixNano()) }, + randSource: &CryptoRandomSource{}, } // apply functional options @@ -354,11 +391,16 @@ func newAPI(opts ...Option) *api { return ret } -func (a *api) randUnusedFD() uint32 { - fd := uint32(rand.Int31()) +func (a *api) randUnusedFD() (uint32, error) { + v, err := a.randSource.Int31() + if err != nil { + return 0, err + } + + fd := uint32(v) for { if _, ok := a.opened[fd]; !ok { - return fd + return fd, nil } fd = (fd + 1) % (1 << 31) } @@ -412,7 +454,10 @@ func (a *api) path_open(ctx *wasm.HostFunctionCallContext, fd, dirFlags, pathPtr } } - newFD := a.randUnusedFD() + newFD, err := a.randUnusedFD() + if err != nil { + return ErrnoInval + } a.opened[newFD] = fileEntry{ file: f, @@ -503,6 +548,23 @@ func (a *api) fd_close(ctx *wasm.HostFunctionCallContext, fd uint32) (err Errno) return ErrnoSuccess } +// RandomGet implements API.RandomGet +func (a *api) RandomGet(ctx *wasm.HostFunctionCallContext, buf uint32, bufLen uint32) (errno Errno) { + if !ctx.Memory.ValidateAddrRange(buf, uint64(bufLen)) { + return ErrnoInval + } + + random_bytes := make([]byte, bufLen) + _, err := a.randSource.Read(random_bytes) + if err != nil { + return ErrnoInval + } + + copy(ctx.Memory.Buffer[buf:buf+bufLen], random_bytes) + + return ErrnoSuccess +} + func proc_exit(*wasm.HostFunctionCallContext, uint32) { // TODO: implement } diff --git a/wasi/wasi_test.go b/wasi/wasi_test.go index 1cbbef32a0..f0b9b9af6c 100644 --- a/wasi/wasi_test.go +++ b/wasi/wasi_test.go @@ -3,6 +3,7 @@ package wasi import ( "context" _ "embed" + mrand "math/rand" "testing" "github.com/stretchr/testify/require" @@ -312,7 +313,93 @@ func TestAPI_ClockTimeGet_Errors(t *testing.T) { // TODO: TestAPI_ProcExit TestAPI_ProcExit_Errors // TODO: TestAPI_ProcRaise TestAPI_ProcRaise_Errors // TODO: TestAPI_SchedYield TestAPI_SchedYield_Errors -// TODO: TestAPI_RandomGet TestAPI_RandomGet_Errors + +// randomWat is a wasm module to call random_get. +//go:embed testdata/random.wat +var randomWat []byte + +// Non-deterministic random rource using crypto/rand +type DummyRandomSource struct { + rng *mrand.Rand +} + +func (d *DummyRandomSource) Read(p []byte) (n int, err error) { + return d.rng.Read(p) +} + +func (d *DummyRandomSource) Int31() (v int32, err error) { + return d.rng.Int31(), nil + +} + +func NewDummyRandomSource(seed int64) RandomSource { + s := mrand.NewSource(seed) + + return &DummyRandomSource{ + rng: mrand.New(s), + } +} + +func TestAPI_RandomGet(t *testing.T) { + store, wasiAPI := instantiateWasmStore(t, randomWat, "test") + maskLength := 7 // number of bytes to write '?' to tell what we've written + expectedMemory := []byte{ + '?', // random bytes in `buf` is after this + 0x53, 0x8c, 0x7f, 0x96, 0xb1, // random data from seed value of 42 + '?', // stopped after encoding + } // tr + + var bufLen = uint32(5) // arbitrary buffer size, + var buf = uint32(1) // offset, + var seed = int64(42) // and seed value + + wasiAPI.(*api).randSource = NewDummyRandomSource(seed) + + t.Run("API.RandomGet", func(t *testing.T) { + maskMemory(store, maskLength) + // provide a host context with a seed value for random generator + hContext := wasm.NewHostFunctionCallContext(context.Background(), store.Memories[0]) + + errno := wasiAPI.RandomGet(hContext, buf, bufLen) + require.Equal(t, ErrnoSuccess, errno) + require.Equal(t, expectedMemory, store.Memories[0].Buffer[0:maskLength]) + }) +} + +func TestAPI_RandomGet_Errors(t *testing.T) { + store, _ := instantiateWasmStore(t, randomWat, "test") + + memorySize := uint32(len(store.Memories[0].Buffer)) + validAddress := uint32(0) // arbitrary valid address as arguments to args_sizes_get. We chose 0 here. + tests := []struct { + name string + buf uint32 + bufLen uint32 + }{ + { + name: "random buffer out-of-memory", + buf: memorySize, + bufLen: 1, + }, + + { + name: "random buffer size exceeds the maximum valid address by 1", + buf: validAddress, + bufLen: memorySize + 1, + }, + } + + for _, tt := range tests { + tc := tt + + t.Run(tc.name, func(t *testing.T) { + ret, _, err := store.CallFunction(context.Background(), "test", FunctionRandomGet, uint64(tc.buf), uint64(tc.bufLen)) + require.NoError(t, err) + require.Equal(t, uint64(ErrnoInval), ret[0]) // ret[0] is returned errno + }) + } +} + // TODO: TestAPI_SockRecv TestAPI_SockRecv_Errors // TODO: TestAPI_SockSend TestAPI_SockSend_Errors // TODO: TestAPI_SockShutdown TestAPI_SockShutdown_Errors diff --git a/wasm/binary/decoder.go b/wasm/binary/decoder.go index bc7aee5a08..d1af58a1c7 100644 --- a/wasm/binary/decoder.go +++ b/wasm/binary/decoder.go @@ -27,8 +27,10 @@ func DecodeModule(binary []byte) (*wasm.Module, error) { m := &wasm.Module{} for { - sectionID := make([]byte, 1) - if _, err := io.ReadFull(r, sectionID); err == io.EOF { + // TODO: except custom sections, all others are required to be in order, but we aren't checking yet. + // See https://www.w3.org/TR/wasm-core-1/#modules%E2%91%A0%E2%93%AA + sectionID, err := r.ReadByte() + if err == io.EOF { break } else if err != nil { return nil, fmt.Errorf("read section id: %w", err) @@ -36,11 +38,11 @@ func DecodeModule(binary []byte) (*wasm.Module, error) { sectionSize, _, err := leb128.DecodeUint32(r) if err != nil { - return nil, fmt.Errorf("get size of section for id=%d: %v", sectionID[0], err) + return nil, fmt.Errorf("get size of section %s: %v", wasm.SectionIDName(sectionID), err) } sectionContentStart := r.Len() - switch sectionID[0] { + switch sectionID { case wasm.SectionIDCustom: // First, validate the section and determine if the section for this name has already been set name, nameSize, decodeErr := decodeUTF8(r, "custom section name") @@ -53,25 +55,19 @@ func DecodeModule(binary []byte) (*wasm.Module, error) { } else if name == "name" && m.NameSection != nil { err = fmt.Errorf("redundant custom section %s", name) break - } else if _, ok := m.CustomSections[name]; ok { - err = fmt.Errorf("redundant custom section %s", name) - break } - // Now, either decode the NameSection or store an unsupported one - // TODO: Do we care to store something we don't use? We could also skip it! - data, dataErr := readCustomSectionData(r, sectionSize-nameSize) - if dataErr != nil { - err = dataErr - } else if name == "name" { - m.NameSection, err = decodeNameSection(data) + // Now, either decode the NameSection or skip an unsupported one + limit := sectionSize - nameSize + if name == "name" { + m.NameSection, err = decodeNameSection(r, uint64(limit)) } else { - if m.CustomSections == nil { - m.CustomSections = map[string][]byte{name: data} - } else { - m.CustomSections[name] = data + // Note: Not Seek because it doesn't err when given an offset past EOF. Rather, it leads to undefined state. + if _, err = io.CopyN(io.Discard, r, int64(limit)); err != nil { + return nil, fmt.Errorf("failed to skip name[%s]: %w", name, err) } } + case wasm.SectionIDType: m.TypeSection, err = decodeTypeSection(r) case wasm.SectionIDImport: @@ -104,11 +100,11 @@ func DecodeModule(binary []byte) (*wasm.Module, error) { } if err != nil { - return nil, fmt.Errorf("section ID %d: %v", sectionID[0], err) + return nil, fmt.Errorf("section %s: %v", wasm.SectionIDName(sectionID), err) } } - functionCount, codeCount := len(m.FunctionSection), len(m.CodeSection) + functionCount, codeCount := m.SectionElementCount(wasm.SectionIDFunction), m.SectionElementCount(wasm.SectionIDCode) if functionCount != codeCount { return nil, fmt.Errorf("function and code section have inconsistent lengths: %d != %d", functionCount, codeCount) } diff --git a/wasm/binary/decoder_test.go b/wasm/binary/decoder_test.go index c7ab44ec5a..9ab47ce875 100644 --- a/wasm/binary/decoder_test.go +++ b/wasm/binary/decoder_test.go @@ -26,21 +26,6 @@ func TestDecodeModule(t *testing.T) { name: "only name section", input: &wasm.Module{NameSection: &wasm.NameSection{ModuleName: "simple"}}, }, - { - name: "only custom section", - input: &wasm.Module{CustomSections: map[string][]byte{ - "meme": {1, 2, 3, 4, 5, 6, 7, 8, 9, 0}, - }}, - }, - { - name: "name section and a custom section", - input: &wasm.Module{ - NameSection: &wasm.NameSection{ModuleName: "simple"}, - CustomSections: map[string][]byte{ - "meme": {1, 2, 3, 4, 5, 6, 7, 8, 9, 0}, - }, - }, - }, { name: "type section", input: &wasm.Module{ @@ -107,6 +92,29 @@ func TestDecodeModule(t *testing.T) { require.Equal(t, tc.input, m) }) } + t.Run("skips custom section", func(t *testing.T) { + input := append(append(magic, version...), + wasm.SectionIDCustom, 0xf, // 15 bytes in this section + 0x04, 'm', 'e', 'm', 'e', + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0) + m, e := DecodeModule(input) + require.NoError(t, e) + require.Equal(t, &wasm.Module{}, m) + }) + t.Run("skips custom section, but not name", func(t *testing.T) { + input := append(append(magic, version...), + wasm.SectionIDCustom, 0xf, // 15 bytes in this section + 0x04, 'm', 'e', 'm', 'e', + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, + wasm.SectionIDCustom, 0x0e, // 14 bytes in this section + 0x04, 'n', 'a', 'm', 'e', + subsectionIDModuleName, 0x07, // 7 bytes in this subsection + 0x06, // the Module name simple is 6 bytes long + 's', 'i', 'm', 'p', 'l', 'e') + m, e := DecodeModule(input) + require.NoError(t, e) + require.Equal(t, &wasm.Module{NameSection: &wasm.NameSection{ModuleName: "simple"}}, m) + }) } func TestDecodeModule_Errors(t *testing.T) { @@ -125,27 +133,16 @@ func TestDecodeModule_Errors(t *testing.T) { input: []byte("\x00asm\x01\x00\x00\x01"), expectedErr: "invalid version header", }, - { - name: "redundant custom section", - input: append(append(magic, version...), - wasm.SectionIDCustom, 0x09, // 9 bytes in this section - 0x04, 'm', 'e', 'm', 'e', - subsectionIDModuleName, 0x03, 0x01, 'x', - wasm.SectionIDCustom, 0x09, // 9 bytes in this section - 0x04, 'm', 'e', 'm', 'e', - subsectionIDModuleName, 0x03, 0x01, 'y'), - expectedErr: "section ID 0: redundant custom section meme", - }, { name: "redundant name section", input: append(append(magic, version...), wasm.SectionIDCustom, 0x09, // 9 bytes in this section 0x04, 'n', 'a', 'm', 'e', - subsectionIDModuleName, 0x03, 0x01, 'x', + subsectionIDModuleName, 0x02, 0x01, 'x', wasm.SectionIDCustom, 0x09, // 9 bytes in this section 0x04, 'n', 'a', 'm', 'e', - subsectionIDModuleName, 0x03, 0x01, 'x'), - expectedErr: "section ID 0: redundant custom section name", + subsectionIDModuleName, 0x02, 0x01, 'x'), + expectedErr: "section custom: redundant custom section name", }, } diff --git a/wasm/binary/encoder.go b/wasm/binary/encoder.go index 7f755ba365..3db533cb67 100644 --- a/wasm/binary/encoder.go +++ b/wasm/binary/encoder.go @@ -11,47 +11,46 @@ var sizePrefixedName = []byte{4, 'n', 'a', 'm', 'e'} // See https://www.w3.org/TR/wasm-core-1/#binary-format%E2%91%A0 func EncodeModule(m *wasm.Module) (bytes []byte) { bytes = append(magic, version...) - for name, data := range m.CustomSections { - bytes = append(bytes, encodeCustomSection(name, data)...) - } - if len(m.TypeSection) > 0 { + if m.SectionElementCount(wasm.SectionIDType) > 0 { bytes = append(bytes, encodeTypeSection(m.TypeSection)...) } - if len(m.ImportSection) > 0 { + if m.SectionElementCount(wasm.SectionIDImport) > 0 { bytes = append(bytes, encodeImportSection(m.ImportSection)...) } - if len(m.FunctionSection) > 0 { + if m.SectionElementCount(wasm.SectionIDFunction) > 0 { bytes = append(bytes, encodeFunctionSection(m.FunctionSection)...) } - if len(m.TableSection) > 0 { + if m.SectionElementCount(wasm.SectionIDTable) > 0 { panic("TODO: TableSection") } - if len(m.MemorySection) > 0 { + if m.SectionElementCount(wasm.SectionIDMemory) > 0 { bytes = append(bytes, encodeMemorySection(m.MemorySection)...) } - if len(m.GlobalSection) > 0 { + if m.SectionElementCount(wasm.SectionIDGlobal) > 0 { panic("TODO: GlobalSection") } - if len(m.ExportSection) > 0 { + if m.SectionElementCount(wasm.SectionIDExport) > 0 { bytes = append(bytes, encodeExportSection(m.ExportSection)...) } - if m.StartSection != nil { + if m.SectionElementCount(wasm.SectionIDStart) > 0 { bytes = append(bytes, encodeStartSection(*m.StartSection)...) } - if len(m.ElementSection) > 0 { + if m.SectionElementCount(wasm.SectionIDElement) > 0 { panic("TODO: ElementSection") } - if len(m.CodeSection) > 0 { + if m.SectionElementCount(wasm.SectionIDCode) > 0 { bytes = append(bytes, encodeCodeSection(m.CodeSection)...) } - if len(m.DataSection) > 0 { + if m.SectionElementCount(wasm.SectionIDData) > 0 { panic("TODO: DataSection") } - // >> The name section should appear only once in a module, and only after the data section. - // See https://www.w3.org/TR/wasm-core-1/#binary-namesec - if m.NameSection != nil { - nameSection := append(sizePrefixedName, encodeNameSectionData(m.NameSection)...) - bytes = append(bytes, encodeSection(wasm.SectionIDCustom, nameSection)...) + if m.SectionElementCount(wasm.SectionIDCustom) > 0 { + // >> The name section should appear only once in a module, and only after the data section. + // See https://www.w3.org/TR/wasm-core-1/#binary-namesec + if m.NameSection != nil { + nameSection := append(sizePrefixedName, encodeNameSectionData(m.NameSection)...) + bytes = append(bytes, encodeSection(wasm.SectionIDCustom, nameSection)...) + } } return } diff --git a/wasm/binary/encoder_test.go b/wasm/binary/encoder_test.go index 903466b93b..6dc938b539 100644 --- a/wasm/binary/encoder_test.go +++ b/wasm/binary/encoder_test.go @@ -32,34 +32,6 @@ func TestModule_Encode(t *testing.T) { 0x06, // the Module name simple is 6 bytes long 's', 'i', 'm', 'p', 'l', 'e'), }, - { - name: "only custom section", - input: &wasm.Module{CustomSections: map[string][]byte{ - "meme": {1, 2, 3, 4, 5, 6, 7, 8, 9, 0}, - }}, - expected: append(append(magic, version...), - wasm.SectionIDCustom, 0xf, // 15 bytes in this section - 0x04, 'm', 'e', 'm', 'e', - 1, 2, 3, 4, 5, 6, 7, 8, 9, 0), - }, - { - name: "name section and a custom section", // name should encode last - input: &wasm.Module{ - NameSection: &wasm.NameSection{ModuleName: "simple"}, - CustomSections: map[string][]byte{ - "meme": {1, 2, 3, 4, 5, 6, 7, 8, 9, 0}, - }, - }, - expected: append(append(magic, version...), - wasm.SectionIDCustom, 0xf, // 15 bytes in this section - 0x04, 'm', 'e', 'm', 'e', - 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, - wasm.SectionIDCustom, 0x0e, // 14 bytes in this section - 0x04, 'n', 'a', 'm', 'e', - subsectionIDModuleName, 0x07, // 7 bytes in this subsection - 0x06, // the Module name simple is 6 bytes long - 's', 'i', 'm', 'p', 'l', 'e'), - }, { name: "type section", input: &wasm.Module{ diff --git a/wasm/binary/names.go b/wasm/binary/names.go index b9590522e6..dc64cdfd3e 100644 --- a/wasm/binary/names.go +++ b/wasm/binary/names.go @@ -27,17 +27,17 @@ const ( // * LocalNames decode from subsection 2 // // See https://www.w3.org/TR/wasm-core-1/#binary-namesec -func decodeNameSection(data []byte) (result *wasm.NameSection, err error) { +func decodeNameSection(r *bytes.Reader, limit uint64) (result *wasm.NameSection, err error) { // TODO: add leb128 functions that work on []byte and offset. While using a reader allows us to reuse reader-based // leb128 functions, it is less efficient, causes untestable code and in some cases more complex vs plain []byte. - r := bytes.NewReader(data) result = &wasm.NameSection{} // subsectionID is decoded if known, and skipped if not var subsectionID uint8 // subsectionSize is the length to skip when the subsectionID is unknown var subsectionSize uint32 - for { + var bytesRead uint64 + for limit > 0 { if subsectionID, err = r.ReadByte(); err != nil { if err == io.EOF { return result, nil @@ -45,11 +45,12 @@ func decodeNameSection(data []byte) (result *wasm.NameSection, err error) { // TODO: untestable as this can't fail for a reason beside EOF reading a byte from a buffer return nil, fmt.Errorf("failed to read a subsection ID: %w", err) } + limit-- - // TODO: unused except when skipping. This means we can pass on a corrupt length of a known subsection - if subsectionSize, _, err = leb128.DecodeUint32(r); err != nil { + if subsectionSize, bytesRead, err = leb128.DecodeUint32(r); err != nil { return nil, fmt.Errorf("failed to read the size of subsection[%d]: %w", subsectionID, err) } + limit -= bytesRead switch subsectionID { case subsectionIDModuleName: @@ -66,11 +67,13 @@ func decodeNameSection(data []byte) (result *wasm.NameSection, err error) { } default: // Skip other subsections. // Note: Not Seek because it doesn't err when given an offset past EOF. Rather, it leads to undefined state. - if _, err := io.CopyN(io.Discard, r, int64(subsectionSize)); err != nil { + if _, err = io.CopyN(io.Discard, r, int64(subsectionSize)); err != nil { return nil, fmt.Errorf("failed to skip subsection[%d]: %w", subsectionID, err) } } + limit -= uint64(subsectionSize) } + return } func decodeFunctionNames(r *bytes.Reader) (wasm.NameMap, error) { diff --git a/wasm/binary/names_test.go b/wasm/binary/names_test.go index d148b7cb47..0cf3450524 100644 --- a/wasm/binary/names_test.go +++ b/wasm/binary/names_test.go @@ -1,6 +1,7 @@ package binary import ( + "bytes" "testing" "github.com/stretchr/testify/require" @@ -201,7 +202,8 @@ func TestDecodeNameSection(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - ns, err := decodeNameSection(encodeNameSectionData(tc.input)) + data := encodeNameSectionData(tc.input) + ns, err := decodeNameSection(bytes.NewReader(data), uint64(len(data))) require.NoError(t, err) require.Equal(t, tc.input, ns) }) @@ -302,7 +304,7 @@ func TestDecodeNameSection_Errors(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - _, err := decodeNameSection(tc.input) + _, err := decodeNameSection(bytes.NewReader(tc.input), uint64(len(tc.input))) require.EqualError(t, err, tc.expectedErr) }) } diff --git a/wasm/binary/section.go b/wasm/binary/section.go index ddc00f3fae..7c587f977b 100644 --- a/wasm/binary/section.go +++ b/wasm/binary/section.go @@ -9,14 +9,6 @@ import ( "github.com/tetratelabs/wazero/wasm/internal/leb128" ) -func readCustomSectionData(r *bytes.Reader, dataSize uint32) ([]byte, error) { - data := make([]byte, dataSize) - if _, err := io.ReadFull(r, data); err != nil { - return nil, fmt.Errorf("cannot read custom section data: %w", err) - } - return data, nil -} - func decodeTypeSection(r io.Reader) ([]*wasm.FunctionType, error) { vs, _, err := leb128.DecodeUint32(r) if err != nil { @@ -165,7 +157,7 @@ func decodeExportSection(r *bytes.Reader) (map[string]*wasm.Export, error) { return exportSection, nil } -func decodeStartSection(r *bytes.Reader) (*uint32, error) { +func decodeStartSection(r *bytes.Reader) (*wasm.Index, error) { vs, _, err := leb128.DecodeUint32(r) if err != nil { return nil, fmt.Errorf("get size of vector: %w", err) @@ -218,14 +210,6 @@ func decodeDataSection(r *bytes.Reader) ([]*wasm.DataSegment, error) { return result, nil } -// encodeCustomSection encodes the opaque bytes for the given name as a SectionIDCustom -// See https://www.w3.org/TR/wasm-core-1/#binary-customsec -func encodeCustomSection(name string, data []byte) []byte { - // The contents of a custom section is the non-empty name followed by potentially empty opaque data - contents := append(encodeSizePrefixed([]byte(name)), data...) - return encodeSection(wasm.SectionIDCustom, contents) -} - // encodeSection encodes the sectionID, the size of its contents in bytes, followed by the contents. // See https://www.w3.org/TR/wasm-core-1/#sections%E2%91%A0 func encodeSection(sectionID wasm.SectionID, contents []byte) []byte { @@ -260,10 +244,10 @@ func encodeImportSection(imports []*wasm.Import) []byte { // WebAssembly 1.0 (MVP) Binary Format. // // See https://www.w3.org/TR/wasm-core-1/#function-section%E2%91%A0 -func encodeFunctionSection(functions []wasm.Index) []byte { - contents := leb128.EncodeUint32(uint32(len(functions))) - for _, typeIndex := range functions { - contents = append(contents, leb128.EncodeUint32(typeIndex)...) +func encodeFunctionSection(typeIndices []wasm.Index) []byte { + contents := leb128.EncodeUint32(uint32(len(typeIndices))) + for _, index := range typeIndices { + contents = append(contents, leb128.EncodeUint32(index)...) } return encodeSection(wasm.SectionIDFunction, contents) } @@ -307,6 +291,6 @@ func encodeExportSection(exports map[string]*wasm.Export) []byte { // encodeStartSection encodes a SectionIDStart for the given function index in WebAssembly 1.0 (MVP) Binary Format. // // See https://www.w3.org/TR/wasm-core-1/#start-section%E2%91%A0 -func encodeStartSection(funcidx uint32) []byte { +func encodeStartSection(funcidx wasm.Index) []byte { return encodeSection(wasm.SectionIDStart, leb128.EncodeUint32(funcidx)) } diff --git a/wasm/binary/section_test.go b/wasm/binary/section_test.go index ee77a53dce..78b480e439 100644 --- a/wasm/binary/section_test.go +++ b/wasm/binary/section_test.go @@ -111,6 +111,11 @@ func TestDecodeExportSection_Errors(t *testing.T) { } } +func TestEncodeFunctionSection(t *testing.T) { + require.Equal(t, []byte{wasm.SectionIDFunction, 0x2, 0x01, 0x05}, encodeFunctionSection([]wasm.Index{5})) +} + +// TestEncodeStartSection uses the same index as TestEncodeFunctionSection to highlight the encoding is different. func TestEncodeStartSection(t *testing.T) { require.Equal(t, []byte{wasm.SectionIDStart, 0x01, 0x05}, encodeStartSection(5)) } diff --git a/wasm/jit/compiler.go b/wasm/jit/compiler.go index 710bccbc4f..8503f43a72 100644 --- a/wasm/jit/compiler.go +++ b/wasm/jit/compiler.go @@ -10,9 +10,9 @@ import ( type compiler interface { // String is for debugging purpose. String() string - // emitPreamble is called before compiling any wazeroir operation. + // compilePreamble is called before compiling any wazeroir operation. // This is used, for example, to initilize the reserved registers, etc. - emitPreamble() error + compilePreamble() error // compile generates the byte slice of native code. // stackPointerCeil is the max stack pointer that the target function would reach. // staticData is compiledFunctionStaticData for the resutling native code. diff --git a/wasm/jit/engine.go b/wasm/jit/engine.go index 10475fde81..7d1197f63d 100644 --- a/wasm/jit/engine.go +++ b/wasm/jit/engine.go @@ -703,7 +703,7 @@ func compileWasmFunction(f *wasm.FunctionInstance) (*compiledFunction, error) { return nil, fmt.Errorf("failed to initialize assembly builder: %w", err) } - if err := compiler.emitPreamble(); err != nil { + if err := compiler.compilePreamble(); err != nil { return nil, fmt.Errorf("failed to emit preamble: %w", err) } diff --git a/wasm/jit/jit_amd64.go b/wasm/jit/jit_amd64.go index 308aa8bb32..e9884cdda4 100644 --- a/wasm/jit/jit_amd64.go +++ b/wasm/jit/jit_amd64.go @@ -4729,8 +4729,8 @@ func (c *amd64Compiler) callFunction(addr wasm.FunctionAddress, addrReg int16, f // 2) Set engine.valueStackContext.stackBasePointer for the next function. { - // At this point, tmpRegister holds the OLD stack base pointer. We could get the new frame's - // stack base pointer by "OLD stack base pointer + OLD stack pointer - # of function params" + // At this point, tmpRegister holds the old stack base pointer. We could get the new frame's + // stack base pointer by "old stack base pointer + old stack pointer - # of function params" // See the comments in engine.pushCallFrame which does exactly the same calculation in Go. calculateNextStackBasePointer := c.newProg() calculateNextStackBasePointer.As = x86.AADDQ @@ -4870,8 +4870,7 @@ func (c *amd64Compiler) returnFunction() error { // Obtain the temporary registers to be used in the followings. regs, found := c.locationStack.takeFreeRegisters(generalPurposeRegisterTypeInt, 3) if !found { - // This in theory never happen as all the registers must be free except addrReg. - return fmt.Errorf("could not find enough free registers") + return fmt.Errorf("BUG: all the registers should be free at this point") } c.locationStack.markRegisterUsed(regs...) @@ -5220,7 +5219,7 @@ func (c *amd64Compiler) exit(status jitCallStatusCode) { c.addInstruction(ret) } -func (c *amd64Compiler) emitPreamble() (err error) { +func (c *amd64Compiler) compilePreamble() (err error) { // We assume all function parameters are already pushed onto the stack by // the caller. c.pushFunctionParams() diff --git a/wasm/jit/jit_amd64_test.go b/wasm/jit/jit_amd64_test.go index 7542fd5cc9..60e7e99fa5 100644 --- a/wasm/jit/jit_amd64_test.go +++ b/wasm/jit/jit_amd64_test.go @@ -117,7 +117,7 @@ func TestAmd64Compiler_returnFunction(t *testing.T) { t.Run("last return", func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) const expectedValue uint32 = 100 @@ -159,7 +159,7 @@ func TestAmd64Compiler_returnFunction(t *testing.T) { for funcaddr := wasm.FunctionAddress(0); funcaddr < callFrameNums; funcaddr++ { // Each function pushes its funcaddr and soon returns. compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Push its funcaddr. @@ -479,7 +479,7 @@ func TestAmd64Compiler_compileBrTable(t *testing.T) { compiler := env.requireNewCompiler(t) compiler.ir = &wazeroir.CompilationResult{LabelCallers: map[string]uint32{}} - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) for _, r := range unreservedGeneralPurposeIntRegisters { @@ -530,7 +530,7 @@ func Test_setJITStatus(t *testing.T) { // Build codes. compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) compiler.exit(s) @@ -550,7 +550,7 @@ func Test_setJITStatus(t *testing.T) { func TestAmd64Compiler_initializeReservedRegisters(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) compiler.exit(jitCallStatusCodeReturned) @@ -576,7 +576,7 @@ func TestAmd64Compiler_allocateRegister(t *testing.T) { const stealTarget = x86.REG_AX env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Use up all the Int regs. for _, r := range unreservedGeneralPurposeIntRegisters { @@ -615,7 +615,7 @@ func TestAmd64Compiler_allocateRegister(t *testing.T) { func TestAmd64Compiler_compileLabel(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) label := &wazeroir.Label{FrameID: 100, Kind: wazeroir.LabelKindContinuation} labelKey := label.String() @@ -644,7 +644,7 @@ func TestAmd64Compiler_compilePick(t *testing.T) { t.Run("on reg", func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Set up the pick target original value. @@ -689,7 +689,7 @@ func TestAmd64Compiler_compilePick(t *testing.T) { t.Run("on stack", func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Setup the original value. @@ -733,7 +733,7 @@ func TestAmd64Compiler_compileConstI32(t *testing.T) { t.Run(fmt.Sprintf("%d", v), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Now emit the const instruction. @@ -773,7 +773,7 @@ func TestAmd64Compiler_compileConstI64(t *testing.T) { t.Run(fmt.Sprintf("%d", v), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Now emit the const instruction. @@ -813,7 +813,7 @@ func TestAmd64Compiler_compileConstF32(t *testing.T) { t.Run(fmt.Sprintf("%f", v), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Now emit the const instruction. @@ -855,7 +855,7 @@ func TestAmd64Compiler_compileConstF64(t *testing.T) { t.Run(fmt.Sprintf("%f", v), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Now emit the const instruction. @@ -899,7 +899,7 @@ func TestAmd64Compiler_compileAdd(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) err = compiler.compileConstI32(&wazeroir.OperationConstI32{Value: x1Value}) require.NoError(t, err) @@ -935,7 +935,7 @@ func TestAmd64Compiler_compileAdd(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) err = compiler.compileConstI64(&wazeroir.OperationConstI64{Value: x1Value}) require.NoError(t, err) @@ -980,7 +980,7 @@ func TestAmd64Compiler_compileAdd(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) err = compiler.compileConstF32(&wazeroir.OperationConstF32{Value: tc.v1}) require.NoError(t, err) @@ -1027,7 +1027,7 @@ func TestAmd64Compiler_compileAdd(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) err = compiler.compileConstF64(&wazeroir.OperationConstF64{Value: tc.v1}) @@ -1086,7 +1086,7 @@ func TestAmd64Compiler_emitEqOrNe(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Push the cmp target values. @@ -1151,7 +1151,7 @@ func TestAmd64Compiler_emitEqOrNe(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Push the cmp target values. @@ -1240,7 +1240,7 @@ func TestAmd64Compiler_emitEqOrNe(t *testing.T) { t.Run(fmt.Sprintf("x1=%f,x2=%f", tc.x1, tc.x2), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) useUpIntRegistersFunc(compiler) @@ -1318,7 +1318,7 @@ func TestAmd64Compiler_emitEqOrNe(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) useUpIntRegistersFunc(compiler) @@ -1382,7 +1382,7 @@ func TestAmd64Compiler_compileEqz(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Push the cmp target value. @@ -1428,7 +1428,7 @@ func TestAmd64Compiler_compileEqz(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Push the cmp target values. @@ -1493,7 +1493,7 @@ func TestAmd64Compiler_compileLe_or_Lt(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Push the target values. @@ -1573,7 +1573,7 @@ func TestAmd64Compiler_compileLe_or_Lt(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Push the target values. @@ -1661,7 +1661,7 @@ func TestAmd64Compiler_compileLe_or_Lt(t *testing.T) { t.Run(fmt.Sprintf("x1=%f,x2=%f", tc.x1, tc.x2), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() // Prepare operands. require.NoError(t, err) @@ -1741,7 +1741,7 @@ func TestAmd64Compiler_compileLe_or_Lt(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Prepare operands. @@ -1822,7 +1822,7 @@ func TestAmd64Compiler_compileGe_or_Gt(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Push the target values. @@ -1904,7 +1904,7 @@ func TestAmd64Compiler_compileGe_or_Gt(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Push the target values. @@ -1993,7 +1993,7 @@ func TestAmd64Compiler_compileGe_or_Gt(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Prepare operands. @@ -2074,7 +2074,7 @@ func TestAmd64Compiler_compileGe_or_Gt(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Prepare operands. @@ -2137,7 +2137,7 @@ func TestAmd64Compiler_compileSub(t *testing.T) { const x2Value uint32 = 51 env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) err = compiler.compileConstI32(&wazeroir.OperationConstI32{Value: x1Value}) require.NoError(t, err) @@ -2173,7 +2173,7 @@ func TestAmd64Compiler_compileSub(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) err = compiler.compileConstI64(&wazeroir.OperationConstI64{Value: x1Value}) @@ -2219,7 +2219,7 @@ func TestAmd64Compiler_compileSub(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) err = compiler.compileConstF32(&wazeroir.OperationConstF32{Value: tc.v1}) @@ -2267,7 +2267,7 @@ func TestAmd64Compiler_compileSub(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) err = compiler.compileConstF64(&wazeroir.OperationConstF64{Value: tc.v1}) @@ -2359,7 +2359,7 @@ func TestAmd64Compiler_compileMul(t *testing.T) { const dxValue uint64 = 111111 compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Pretend there was an existing value on the DX register. We expect compileMul to save this to the stack. @@ -2471,7 +2471,7 @@ func TestAmd64Compiler_compileMul(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Pretend there was an existing value on the DX register. We expect compileMul to save this to the stack. @@ -2547,7 +2547,7 @@ func TestAmd64Compiler_compileMul(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) err = compiler.compileConstF32(&wazeroir.OperationConstF32{Value: tc.x1}) @@ -2599,7 +2599,7 @@ func TestAmd64Compiler_compileMul(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) err = compiler.compileConstF64(&wazeroir.OperationConstF64{Value: tc.x1}) @@ -2646,7 +2646,7 @@ func TestAmd64Compiler_compilClz(t *testing.T) { t.Run(fmt.Sprintf("%032b", tc.input), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Setup the target value. err = compiler.compileConstI32(&wazeroir.OperationConstI32{Value: tc.input}) @@ -2694,7 +2694,7 @@ func TestAmd64Compiler_compilClz(t *testing.T) { t.Run(fmt.Sprintf("%064b", tc.expectedLeadingZeros), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Setup the target value. @@ -2744,7 +2744,7 @@ func TestAmd64Compiler_compilCtz(t *testing.T) { t.Run(fmt.Sprintf("%032b", tc.input), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Setup the target value. @@ -2793,7 +2793,7 @@ func TestAmd64Compiler_compilCtz(t *testing.T) { t.Run(fmt.Sprintf("%064b", tc.input), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Setup the target value. err = compiler.compileConstI64(&wazeroir.OperationConstI64{Value: tc.input}) @@ -2844,7 +2844,7 @@ func TestAmd64Compiler_compilPopcnt(t *testing.T) { t.Run(fmt.Sprintf("%032b", tc.input), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Setup the target value. err = compiler.compileConstI32(&wazeroir.OperationConstI32{Value: tc.input}) @@ -2893,7 +2893,7 @@ func TestAmd64Compiler_compilPopcnt(t *testing.T) { t.Run(fmt.Sprintf("%064b", tc.in), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Setup the target value. @@ -3017,7 +3017,7 @@ func TestAmd64Compiler_compile_and_or_xor_shl_shr_rotl_rotr(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) var is32Bit bool @@ -3241,7 +3241,7 @@ func TestAmd64Compiler_compileDiv(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Pretend there was an existing value on the DX register. We expect compileDivForInts to save this to the stack. @@ -3398,7 +3398,7 @@ func TestAmd64Compiler_compileDiv(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Pretend there was an existing value on the DX register. We expect compileDivForInts to save this to the stack. @@ -3509,7 +3509,7 @@ func TestAmd64Compiler_compileDiv(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) err = compiler.compileConstF32(&wazeroir.OperationConstF32{Value: tc.x1}) require.NoError(t, err) @@ -3583,7 +3583,7 @@ func TestAmd64Compiler_compileDiv(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) err = compiler.compileConstF64(&wazeroir.OperationConstF64{Value: tc.x1}) @@ -3702,7 +3702,7 @@ func TestAmd64Compiler_compileRem(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Pretend there was an existing value on the DX register. We expect compileDivForInts to save this to the stack. @@ -3858,7 +3858,7 @@ func TestAmd64Compiler_compileRem(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Pretend there was an existing value on the DX register. We expect compileDivForInts to save this to the stack. @@ -3951,7 +3951,7 @@ func TestAmd64Compiler_compileF32DemoteFromF64(t *testing.T) { t.Run(fmt.Sprintf("%f", v), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Setup the demote target. @@ -3998,7 +3998,7 @@ func TestAmd64Compiler_compileF64PromoteFromF32(t *testing.T) { t.Run(fmt.Sprintf("%f", v), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Setup the promote target. @@ -4052,7 +4052,7 @@ func TestAmd64Compiler_compileReinterpret(t *testing.T) { t.Run(fmt.Sprintf("%d", v), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) if originOnStack { @@ -4133,7 +4133,7 @@ func TestAmd64Compiler_compileExtend(t *testing.T) { t.Run(fmt.Sprintf("%v", v), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Setup the promote target. @@ -4224,7 +4224,7 @@ func TestAmd64Compiler_compileITruncFromF(t *testing.T) { t.Run(fmt.Sprintf("%f", v), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Setup the conversion target. @@ -4357,7 +4357,7 @@ func TestAmd64Compiler_compileFConvertFromI(t *testing.T) { t.Run(fmt.Sprintf("%d", v), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Setup the conversion target. @@ -4472,7 +4472,7 @@ func TestAmd64Compiler_compile_abs_neg_ceil_floor(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) var is32Bit bool @@ -4681,7 +4681,7 @@ func TestAmd64Compiler_compile_min_max_copysign(t *testing.T) { t.Run(fmt.Sprintf("x1=%f_x2=%f", vs.x1, vs.x2), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) var is32Bit bool @@ -4792,7 +4792,7 @@ func TestAmd64Compiler_setupMemoryOffset(t *testing.T) { compiler := env.requireNewCompiler(t) compiler.f.ModuleInstance = env.moduleInstance - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) err = compiler.compileConstI32(&wazeroir.OperationConstI32{Value: base}) @@ -4841,7 +4841,7 @@ func TestAmd64Compiler_compileLoad(t *testing.T) { compiler := env.requireNewCompiler(t) compiler.f.ModuleInstance = env.moduleInstance - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Before load operations, we must push the base offset value. @@ -4940,7 +4940,7 @@ func TestAmd64Compiler_compileLoad8(t *testing.T) { compiler := env.requireNewCompiler(t) compiler.f.ModuleInstance = env.moduleInstance - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Before load operations, we must push the base offset value. @@ -5002,7 +5002,7 @@ func TestAmd64Compiler_compileLoad16(t *testing.T) { compiler := env.requireNewCompiler(t) compiler.f.ModuleInstance = env.moduleInstance - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Before load operations, we must push the base offset value. @@ -5056,7 +5056,7 @@ func TestAmd64Compiler_compileLoad32(t *testing.T) { compiler := env.requireNewCompiler(t) compiler.f.ModuleInstance = env.moduleInstance - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Before load operations, we must push the base offset value. @@ -5116,7 +5116,7 @@ func TestAmd64Compiler_compileStore(t *testing.T) { compiler := env.requireNewCompiler(t) compiler.f.ModuleInstance = env.moduleInstance - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Before store operations, we must push the base offset, and the store target values. @@ -5176,7 +5176,7 @@ func TestAmd64Compiler_compileStore8(t *testing.T) { compiler := env.requireNewCompiler(t) compiler.f.ModuleInstance = env.moduleInstance - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Before store operations, we must push the base offset, and the store target values. @@ -5221,7 +5221,7 @@ func TestAmd64Compiler_compileStore16(t *testing.T) { compiler := env.requireNewCompiler(t) compiler.f.ModuleInstance = env.moduleInstance - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Before store operations, we must push the base offset, and the store target values. @@ -5266,7 +5266,7 @@ func TestAmd64Compiler_compileStore32(t *testing.T) { compiler := env.requireNewCompiler(t) compiler.f.ModuleInstance = env.moduleInstance - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Before store operations, we must push the base offset, and the store target values. @@ -5313,7 +5313,7 @@ func TestAmd64Compiler_compileMemoryGrow(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Emit memory.grow instructions. err = compiler.compileMemoryGrow() @@ -5321,7 +5321,7 @@ func TestAmd64Compiler_compileMemoryGrow(t *testing.T) { // Emit arbitrary code after memory.grow returned. const expValue uint32 = 100 - err = compiler.emitPreamble() + err = compiler.compilePreamble() require.NoError(t, err) err = compiler.compileConstI32(&wazeroir.OperationConstI32{Value: expValue}) require.NoError(t, err) @@ -5355,7 +5355,7 @@ func TestAmd64Compiler_compileMemorySize(t *testing.T) { compiler := env.requireNewCompiler(t) compiler.f.ModuleInstance = env.moduleInstance - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Emit memory.size instructions. @@ -5488,7 +5488,7 @@ func TestAmd64Compiler_compileDrop(t *testing.T) { t.Run("real", func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) bottom := compiler.locationStack.pushValueLocationOnRegister(x86.REG_R10) @@ -5526,7 +5526,7 @@ func TestAmd64Compiler_compileDrop(t *testing.T) { func TestAmd64Compiler_releaseAllRegistersToStack(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) x1Reg := int16(x86.REG_AX) @@ -5624,7 +5624,7 @@ func TestAmd64Compiler_generate(t *testing.T) { func TestAmd64Compiler_compileUnreachable(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) x1Reg := int16(x86.REG_AX) @@ -5700,7 +5700,7 @@ func TestAmd64Compiler_compileSelect(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) var x1, x2, c *valueLocation @@ -5794,7 +5794,7 @@ func TestAmd64Compiler_compileSwap(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) if tc.x2OnRegister { @@ -5866,7 +5866,7 @@ func TestAmd64Compiler_compileGlobalGet(t *testing.T) { compiler.f = &wasm.FunctionInstance{ModuleInstance: &wasm.ModuleInstance{Globals: globals}} // Emit the code. - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) op := &wazeroir.OperationGlobalGet{Index: 1} err = compiler.compileGlobalGet(op) @@ -5920,7 +5920,7 @@ func TestAmd64Compiler_compileGlobalSet(t *testing.T) { env.stack()[loc.stackPointer] = valueToSet // Now emit the code. - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) op := &wazeroir.OperationGlobalSet{Index: 1} err = compiler.compileGlobalSet(op) @@ -5952,7 +5952,7 @@ func TestAmd64Compiler_callFunction(t *testing.T) { env.setCallFrameStackPointer(engine.globalContext.callFrameStackLen - 1) compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) require.Empty(t, compiler.locationStack.usedRegisters) @@ -6004,7 +6004,7 @@ func TestAmd64Compiler_callFunction(t *testing.T) { compiler := env.requireNewCompiler(t) compiler.f = &wasm.FunctionInstance{FunctionType: &wasm.TypeInstance{Type: targetFunctionType}, ModuleInstance: moduleInstance} - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) expectedValue += addTargetValue @@ -6047,7 +6047,7 @@ func TestAmd64Compiler_callFunction(t *testing.T) { // Now we start building the caller's code. compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) const initialValue = 100 @@ -6108,7 +6108,7 @@ func TestAmd64Compiler_compileCall(t *testing.T) { // Call target function takes three i32 arguments and does ADD 2 times. compiler := env.requireNewCompiler(t) compiler.f = &wasm.FunctionInstance{FunctionType: &wasm.TypeInstance{Type: targetFunctionType}, ModuleInstance: &wasm.ModuleInstance{}} - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) for i := 0; i < 2; i++ { err = compiler.compileAdd(&wazeroir.OperationAdd{Type: wazeroir.UnsignedTypeI32}) @@ -6134,7 +6134,7 @@ func TestAmd64Compiler_compileCall(t *testing.T) { }, }} - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) var expectedValue uint32 @@ -6184,7 +6184,7 @@ func TestAmd64Compiler_compileCallIndirect(t *testing.T) { env.stack()[loc.stackPointer] = 10 // Now emit the code. - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) require.NoError(t, compiler.compileCallIndirect(targetOperation)) compiler.exit(jitCallStatusCodeReturned) @@ -6218,7 +6218,7 @@ func TestAmd64Compiler_compileCallIndirect(t *testing.T) { require.NoError(t, err) // Now emit the code. - err = compiler.emitPreamble() + err = compiler.compilePreamble() require.NoError(t, err) require.NoError(t, compiler.compileCallIndirect(targetOperation)) @@ -6246,7 +6246,7 @@ func TestAmd64Compiler_compileCallIndirect(t *testing.T) { // and the typeID doesn't match the table[targetOffset]'s type ID. table[0] = wasm.TableElement{FunctionTypeID: 50} - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Place the offfset value. @@ -6292,7 +6292,7 @@ func TestAmd64Compiler_compileCallIndirect(t *testing.T) { expectedReturnValue := uint32(i * 1000) { compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) err = compiler.compileConstI32(&wazeroir.OperationConstI32{Value: expectedReturnValue}) require.NoError(t, err) @@ -6309,7 +6309,7 @@ func TestAmd64Compiler_compileCallIndirect(t *testing.T) { } compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Ensure that the module instance has the type information for targetOperation.TypeIndex, @@ -6354,7 +6354,7 @@ func TestAmd64Compiler_readInstructionAddress(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Set the acquisition target instruction to the one after JMP. @@ -6370,7 +6370,7 @@ func TestAmd64Compiler_readInstructionAddress(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) const destinationRegister = x86.REG_AX diff --git a/wasm/jit/jit_arm64.go b/wasm/jit/jit_arm64.go index 2985d3f629..0f4b0a9a2e 100644 --- a/wasm/jit/jit_arm64.go +++ b/wasm/jit/jit_arm64.go @@ -67,9 +67,9 @@ type arm64Compiler struct { builder *asm.Builder f *wasm.FunctionInstance ir *wazeroir.CompilationResult - // setBRTargetOnNextInstructions holds branch kind instructions (BR, conditional BR, etc) + // setBranchTargetOnNextInstructions holds branch kind instructions (BR, conditional BR, etc) // where we want to set the next coming instruction as the destination of these BR instructions. - setBRTargetOnNextInstructions []*obj.Prog + setBranchTargetOnNextInstructions []*obj.Prog // locationStack holds the state of wazeroir virtual stack. // and each item is either placed in register or the actual memory stack. locationStack *valueLocationStack @@ -77,6 +77,8 @@ type arm64Compiler struct { labels map[string]*labelInfo // stackPointerCeil is the greatest stack pointer value (from valueLocationStack) seen during compilation. stackPointerCeil uint64 + // afterAssembleCallback hold the callbacks which are called after assembling native code. + afterAssembleCallback []func(code []byte) error } // compile implements compiler.compile for the arm64 architecture. @@ -89,10 +91,19 @@ func (c *arm64Compiler) compile() (code []byte, staticData compiledFunctionStati stackPointerCeil = c.locationStack.stackPointerCeil } - code, err = mmapCodeSegment(c.builder.Assemble()) + original := c.builder.Assemble() + + for _, cb := range c.afterAssembleCallback { + if err = cb(original); err != nil { + return + } + } + + code, err = mmapCodeSegment(original) if err != nil { return } + return } @@ -117,10 +128,10 @@ func (c *arm64Compiler) label(labelKey string) *labelInfo { func (c *arm64Compiler) newProg() (inst *obj.Prog) { inst = c.builder.NewProg() - for _, origin := range c.setBRTargetOnNextInstructions { + for _, origin := range c.setBranchTargetOnNextInstructions { origin.To.SetTarget(inst) } - c.setBRTargetOnNextInstructions = nil + c.setBranchTargetOnNextInstructions = nil return } @@ -128,22 +139,24 @@ func (c *arm64Compiler) addInstruction(inst *obj.Prog) { c.builder.AddInstruction(inst) } -func (c *arm64Compiler) setBRTargetOnNext(progs ...*obj.Prog) { - c.setBRTargetOnNextInstructions = append(c.setBRTargetOnNextInstructions, progs...) +func (c *arm64Compiler) setBranchTargetOnNext(progs ...*obj.Prog) { + c.setBranchTargetOnNextInstructions = append(c.setBranchTargetOnNextInstructions, progs...) } func (c *arm64Compiler) markRegisterUsed(reg int16) { c.locationStack.markRegisterUsed(reg) } -func (c *arm64Compiler) markRegisterUnused(reg int16) { - if !isZeroRegister(reg) { - c.locationStack.markRegisterUnused(reg) +func (c *arm64Compiler) markRegisterUnused(regs ...int16) { + for _, reg := range regs { + if !isZeroRegister(reg) { + c.locationStack.markRegisterUnused(reg) + } } } -// applyConstToRegisterInstruction adds an instruction where source operand is a constant and destination is a register. -func (c *arm64Compiler) applyConstToRegisterInstruction(instruction obj.As, constValue int64, destinationRegister int16) { +// compileConstToRegisterInstruction adds an instruction where source operand is a constant and destination is a register. +func (c *arm64Compiler) compileConstToRegisterInstruction(instruction obj.As, constValue int64, destinationRegister int16) { applyConst := c.newProg() applyConst.As = instruction applyConst.From.Type = obj.TYPE_CONST @@ -156,77 +169,66 @@ func (c *arm64Compiler) applyConstToRegisterInstruction(instruction obj.As, cons c.addInstruction(applyConst) } -// applyMemoryToRegisterInstruction adds an instruction where source operand is a memory location and destination is a register. +// compileMemoryToRegisterInstruction adds an instruction where source operand is a memory location and destination is a register. // baseRegister is the base absolute address in the memory, and offset is the offset from the absolute address in baseRegister. -func (c *arm64Compiler) applyMemoryToRegisterInstruction(instruction obj.As, baseRegister int16, offset int64, destinationRegister int16) (err error) { - if offset > math.MaxInt16 { - // This is a bug in JIT copmiler: caller must check the offset at compilation time, and avoid such a large offset - // by loading the const to the register beforehand and then using applyRegisterOffsetMemoryToRegisterInstruction instead. - err = fmt.Errorf("memory offset must be smaller than or equal %d, but got %d", math.MaxInt16, offset) - return +func (c *arm64Compiler) compileMemoryToRegisterInstruction(instruction obj.As, sourceBaseRegister int16, sourceOffsetConst int64, destinationRegister int16) { + if sourceOffsetConst > math.MaxInt16 { + // The assembler can take care of offsets larger than 2^15-1 by emitting additional instructions to load such large offset, + // but it uses "its" temporary register which we cannot track. Therefore, we avoid directly emitting memory load with large offsets, + // but instead load the constant manually to "our" temporary register, then emit the load with it. + c.compileConstToRegisterInstruction(arm64.AMOVD, sourceOffsetConst, reservedRegisterForTemporary) + inst := c.newProg() + inst.As = instruction + inst.From.Type = obj.TYPE_MEM + inst.From.Reg = sourceBaseRegister + inst.From.Index = reservedRegisterForTemporary + inst.From.Scale = 1 + inst.To.Type = obj.TYPE_REG + inst.To.Reg = destinationRegister + c.addInstruction(inst) + } else { + inst := c.newProg() + inst.As = instruction + inst.From.Type = obj.TYPE_MEM + inst.From.Reg = sourceBaseRegister + inst.From.Offset = sourceOffsetConst + inst.To.Type = obj.TYPE_REG + inst.To.Reg = destinationRegister + c.addInstruction(inst) } - inst := c.newProg() - inst.As = instruction - inst.From.Type = obj.TYPE_MEM - inst.From.Reg = baseRegister - inst.From.Offset = offset - inst.To.Type = obj.TYPE_REG - inst.To.Reg = destinationRegister - c.addInstruction(inst) - return } -// applyRegisterOffsetMemoryToRegisterInstruction adds an instruction where source operand is a memory location and destination is a register. -// The difference from applyMemoryToRegisterInstruction is that here we specify the offset by a register rather than offset constant. -func (c *arm64Compiler) applyRegisterOffsetMemoryToRegisterInstruction(instruction obj.As, baseRegister, offsetRegister, destinationRegister int16) (err error) { - inst := c.newProg() - inst.As = instruction - inst.From.Type = obj.TYPE_MEM - inst.From.Reg = baseRegister - inst.From.Index = offsetRegister - inst.From.Scale = 1 - inst.To.Type = obj.TYPE_REG - inst.To.Reg = destinationRegister - c.addInstruction(inst) - return nil -} - -// applyRegisterToMemoryInstruction adds an instruction where destination operand is a memory location and source is a register. -// This is the opposite of applyMemoryToRegisterInstruction. -func (c *arm64Compiler) applyRegisterToMemoryInstruction(instruction obj.As, baseRegister int16, offset int64, source int16) (err error) { - if offset > math.MaxInt16 { - // This is a bug in JIT copmiler: caller must check the offset at compilation time, and avoid such a large offset - // by loading the const to the register beforehand and then using applyRegisterToRegisterOffsetMemoryInstruction instead. - err = fmt.Errorf("memory offset must be smaller than or equal %d, but got %d", math.MaxInt16, offset) - return +// compileRegisterToMemoryInstruction adds an instruction where destination operand is a memory location and source is a register. +// This is the opposite of compileMemoryToRegisterInstruction. +func (c *arm64Compiler) compileRegisterToMemoryInstruction(instruction obj.As, sourceRegister int16, destinationBaseRegister int16, destinationOffsetConst int64) { + if destinationOffsetConst > math.MaxInt16 { + // The assembler can take care of offsets larger than 2^15-1 by emitting additional instructions to load such large offset, + // but we cannot track its temporary register. Therefore, we avoid directly emitting memory load with large offsets: + // load the constant manually to "our" temporary register, then emit the load with it. + c.compileConstToRegisterInstruction(arm64.AMOVD, destinationOffsetConst, reservedRegisterForTemporary) + inst := c.newProg() + inst.As = instruction + inst.To.Type = obj.TYPE_MEM + inst.To.Reg = destinationBaseRegister + inst.To.Index = reservedRegisterForTemporary + inst.To.Scale = 1 + inst.From.Type = obj.TYPE_REG + inst.From.Reg = sourceRegister + c.addInstruction(inst) + } else { + inst := c.newProg() + inst.As = instruction + inst.To.Type = obj.TYPE_MEM + inst.To.Reg = destinationBaseRegister + inst.To.Offset = destinationOffsetConst + inst.From.Type = obj.TYPE_REG + inst.From.Reg = sourceRegister + c.addInstruction(inst) } - inst := c.newProg() - inst.As = instruction - inst.To.Type = obj.TYPE_MEM - inst.To.Reg = baseRegister - inst.To.Offset = offset - inst.From.Type = obj.TYPE_REG - inst.From.Reg = source - c.addInstruction(inst) - return } -// applyRegisterToRegisterOffsetMemoryInstruction adds an instruction where destination operand is a memory location and source is a register. -// The difference from applyRegisterToMemoryInstruction is that here we specify the offset by a register rather than offset constant. -func (c *arm64Compiler) applyRegisterToRegisterOffsetMemoryInstruction(instruction obj.As, baseRegister, offsetRegister, source int16) { - inst := c.newProg() - inst.As = instruction - inst.To.Type = obj.TYPE_MEM - inst.To.Reg = baseRegister - inst.To.Index = offsetRegister - inst.To.Scale = 1 - inst.From.Type = obj.TYPE_REG - inst.From.Reg = source - c.addInstruction(inst) -} - -// applyRegisterToRegisterOffsetMemoryInstruction adds an instruction where both destination and source operands are registers. -func (c *arm64Compiler) applyRegisterToRegisterInstruction(instruction obj.As, from, to int16) { +// compileRegisterToRegisterInstruction adds an instruction where both destination and source operands are registers. +func (c *arm64Compiler) compileRegisterToRegisterInstruction(instruction obj.As, from, to int16) { inst := c.newProg() inst.As = instruction inst.To.Type = obj.TYPE_REG @@ -236,8 +238,8 @@ func (c *arm64Compiler) applyRegisterToRegisterInstruction(instruction obj.As, f c.addInstruction(inst) } -// applyRegisterToRegisterOffsetMemoryInstruction adds an instruction which takes two source operands on registers and one destination register operand. -func (c *arm64Compiler) applyTwoRegistersToRegisterInstruction(instruction obj.As, src1, src2, destination int16) { +// compileTwoRegistersToRegisterInstruction adds an instruction which takes two source operands on registers and one destination register operand. +func (c *arm64Compiler) compileTwoRegistersToRegisterInstruction(instruction obj.As, src1, src2, destination int16) { inst := c.newProg() inst.As = instruction inst.To.Type = obj.TYPE_REG @@ -248,8 +250,8 @@ func (c *arm64Compiler) applyTwoRegistersToRegisterInstruction(instruction obj.A c.addInstruction(inst) } -// applyTwoRegistersToNoneInstruction adds an instruction which takes two source operands on registers. -func (c *arm64Compiler) applyTwoRegistersToNoneInstruction(instruction obj.As, src1, src2 int16) { +// compileTwoRegistersToNoneInstruction adds an instruction which takes two source operands on registers. +func (c *arm64Compiler) compileTwoRegistersToNoneInstruction(instruction obj.As, src1, src2 int16) { inst := c.newProg() inst.As = instruction // TYPE_NONE indicates that this instruction doesn't have a destination. @@ -261,14 +263,22 @@ func (c *arm64Compiler) applyTwoRegistersToNoneInstruction(instruction obj.As, s c.addInstruction(inst) } -func (c *arm64Compiler) emitUnconditionalBRInstruction(targetType obj.AddrType) (jmp *obj.Prog) { - jmp = c.newProg() - jmp.As = obj.AJMP - jmp.To.Type = targetType - c.addInstruction(jmp) +func (c *arm64Compiler) compileUnconditionalBranchInstruction() (br *obj.Prog) { + br = c.newProg() + br.As = obj.AJMP + br.To.Type = obj.TYPE_BRANCH + c.addInstruction(br) return } +func (c *arm64Compiler) compileUnconditionalBranchToAddressOnRegister(addressRegister int16) { + br := c.newProg() + br.As = obj.AJMP + br.To.Type = obj.TYPE_MEM + br.To.Reg = addressRegister + c.addInstruction(br) +} + func (c *arm64Compiler) String() (ret string) { return } // pushFunctionParams pushes any function parameters onto the stack, setting appropriate register types. @@ -287,8 +297,8 @@ func (c *arm64Compiler) pushFunctionParams() { } } -// emitPreamble implements compiler.emitPreamble for the arm64 architecture. -func (c *arm64Compiler) emitPreamble() error { +// compilePreamble implements compiler.compilePreamble for the arm64 architecture. +func (c *arm64Compiler) compilePreamble() error { // The assembler skips the first instruction so we intentionally add NOP here. nop := c.newProg() nop.As = obj.ANOP @@ -298,61 +308,119 @@ func (c *arm64Compiler) emitPreamble() error { // Before excuting function body, we must initialize the stack base pointer register // so that we can manipulate the memory stack properly. - return c.initializeReservedStackBasePointerRegister() + return c.compileInitializeReservedStackBasePointerRegister() } // returnFunction emits instructions to return from the current function frame. // If the current frame is the bottom, the code goes back to the Go code with jitCallStatusCodeReturned status. -// Otherwise, we branch into the caller's return address (TODO). -func (c *arm64Compiler) returnFunction() error { - // TODO: we don't support function calls yet. - // For now the following code just returns to Go code. - - // Since we return from the function, we need to decrement the callframe stack pointer, and write it back. - callFramePointerReg, _ := c.locationStack.takeFreeRegister(generalPurposeRegisterTypeInt) - if err := c.applyMemoryToRegisterInstruction(arm64.AMOVD, reservedRegisterForEngine, - engineGlobalContextCallFrameStackPointerOffset, callFramePointerReg); err != nil { - return err - } - c.applyConstToRegisterInstruction(arm64.ASUBS, 1, callFramePointerReg) - if err := c.applyRegisterToMemoryInstruction(arm64.AMOVD, reservedRegisterForEngine, - engineGlobalContextCallFrameStackPointerOffset, callFramePointerReg); err != nil { +// Otherwise, we branch into the caller's return address. +func (c *arm64Compiler) compileReturnFunction() error { + // Release all the registers as our calling convention requires the caller-save. + if err := c.compileReleaseAllRegistersToStack(); err != nil { return err } - return c.exit(jitCallStatusCodeReturned) + // Since we return from the function, we need to decrement the callframe stack pointer, and write it back. + tmpRegs, found := c.locationStack.takeFreeRegisters(generalPurposeRegisterTypeInt, 3) + if !found { + return fmt.Errorf("BUG: all the registers should be free at this point") + } + + // Alias for readability. + callFramePointerReg, callFrameStackTopAddressRegister, tmpReg := tmpRegs[0], tmpRegs[1], tmpRegs[2] + + // First we decrement the callframe stack pointer. + c.compileMemoryToRegisterInstruction(arm64.AMOVD, reservedRegisterForEngine, engineGlobalContextCallFrameStackPointerOffset, callFramePointerReg) + c.compileConstToRegisterInstruction(arm64.ASUBS, 1, callFramePointerReg) + c.compileRegisterToMemoryInstruction(arm64.AMOVD, callFramePointerReg, reservedRegisterForEngine, engineGlobalContextCallFrameStackPointerOffset) + + // Next we compare the decremented call frame stack pointer with the engine.precviousCallFrameStackPointer. + c.compileMemoryToRegisterInstruction(arm64.AMOVD, + reservedRegisterForEngine, engineGlobalContextPreviouscallFrameStackPointer, + tmpReg, + ) + c.compileTwoRegistersToNoneInstruction(arm64.ACMP, callFramePointerReg, tmpReg) + + // If the values are identical, we return back to the Go code with returned status. + brIfNotEqual := c.newProg() + brIfNotEqual.As = arm64.ABNE + brIfNotEqual.To.Type = obj.TYPE_BRANCH + c.addInstruction(brIfNotEqual) + c.exit(jitCallStatusCodeReturned) + + // Otherwise, we have to jump to the caller's return address. + c.setBranchTargetOnNext(brIfNotEqual) + + // First, we have to calculate the caller callFrame's absolute address to aquire the return address. + // + // "tmpReg = &engine.callFrameStack[0]" + c.compileMemoryToRegisterInstruction(arm64.AMOVD, + reservedRegisterForEngine, engineGlobalContextCallFrameStackElement0AddressOffset, + tmpReg, + ) + // "callFrameStackTopAddressRegister = tmpReg + callFramePointerReg << ${callFrameDataSizeMostSignificantSetBit}" + c.compileAddInstructionWithLeftShiftedRegister( + callFramePointerReg, callFrameDataSizeMostSignificantSetBit, + tmpReg, + callFrameStackTopAddressRegister, + ) + + // At this point, we have + // + // [......., ra.caller, rb.caller, rc.caller, _, ra.current, rb.current, rc.current, _, ...] <- call frame stack's data region (somewhere in the memory) + // ^ + // callFrameStackTopAddressRegister + // (absolute address of &callFrameStack[engine.callFrameStackPointer]) + // + // where: + // ra.* = callFrame.returnAddress + // rb.* = callFrame.returnStackBasePointer + // rc.* = callFrame.compiledFunction + // _ = callFrame's padding (see comment on callFrame._ field.) + // + // What we have to do in the following is that + // 1) Set engine.valueStackContext.stackBasePointer to the value on "rb.caller". + // 2) Jump into the address of "ra.caller". + + // 1) Set engine.valueStackContext.stackBasePointer to the value on "rb.caller". + c.compileMemoryToRegisterInstruction(arm64.AMOVD, + // "rb.caller" is below the top address. + callFrameStackTopAddressRegister, -(callFrameDataSize - callFrameReturnStackBasePointerOffset), + tmpReg) + c.compileRegisterToMemoryInstruction(arm64.AMOVD, + tmpReg, + reservedRegisterForEngine, engineValueStackContextStackBasePointerOffset) + + // 2) Branch into the address of "ra.caller". + c.compileMemoryToRegisterInstruction(arm64.AMOVD, + // "rb.caller" is below the top address. + callFrameStackTopAddressRegister, -(callFrameDataSize - callFrameReturnAddressOffset), + tmpReg) + c.compileUnconditionalBranchToAddressOnRegister(tmpReg) + + c.locationStack.markRegisterUnused(tmpRegs...) + return nil } // exit adds instructions to give the control back to engine.exec with the given status code. func (c *arm64Compiler) exit(status jitCallStatusCode) error { // Write the current stack pointer to the engine.stackPointer. - c.applyConstToRegisterInstruction(arm64.AMOVW, int64(c.locationStack.sp), reservedRegisterForTemporary) - if err := c.applyRegisterToMemoryInstruction(arm64.AMOVW, reservedRegisterForEngine, - engineValueStackContextStackPointerOffset, reservedRegisterForTemporary); err != nil { - return err - } + c.compileConstToRegisterInstruction(arm64.AMOVW, int64(c.locationStack.sp), reservedRegisterForTemporary) + c.compileRegisterToMemoryInstruction(arm64.AMOVW, reservedRegisterForTemporary, reservedRegisterForEngine, + engineValueStackContextStackPointerOffset) if status != 0 { - c.applyConstToRegisterInstruction(arm64.AMOVW, int64(status), reservedRegisterForTemporary) - if err := c.applyRegisterToMemoryInstruction(arm64.AMOVWU, reservedRegisterForEngine, - engineExitContextJITCallStatusCodeOffset, reservedRegisterForTemporary); err != nil { - return err - } + c.compileConstToRegisterInstruction(arm64.AMOVW, int64(status), reservedRegisterForTemporary) + c.compileRegisterToMemoryInstruction(arm64.AMOVWU, reservedRegisterForTemporary, reservedRegisterForEngine, engineExitContextJITCallStatusCodeOffset) } else { // If the status == 0, we use zero register to store zero. - if err := c.applyRegisterToMemoryInstruction(arm64.AMOVWU, reservedRegisterForEngine, - engineExitContextJITCallStatusCodeOffset, zeroRegister); err != nil { - return err - } + c.compileRegisterToMemoryInstruction(arm64.AMOVWU, zeroRegister, reservedRegisterForEngine, engineExitContextJITCallStatusCodeOffset) } // The return address to the Go code is stored in archContext.jitReturnAddress which // is embedded in engine. We load the value to the tmpRegister, and then // invoke RET with that register. - if err := c.applyMemoryToRegisterInstruction(arm64.AMOVD, reservedRegisterForEngine, - engineArchContextJITCallReturnAddressOffset, reservedRegisterForTemporary); err != nil { - return err - } + c.compileMemoryToRegisterInstruction(arm64.AMOVD, reservedRegisterForEngine, engineArchContextJITCallReturnAddressOffset, reservedRegisterForTemporary) ret := c.newProg() ret.As = obj.ARET @@ -426,8 +494,8 @@ func (c *arm64Compiler) compileGlobalSet(o *wazeroir.OperationGlobalSet) error { // compileBr implements compiler.compileBr for the arm64 architecture. func (c *arm64Compiler) compileBr(o *wazeroir.OperationBr) error { - c.maybeMoveTopConditionalToFreeGeneralPurposeRegister() - return c.branchInto(o.Target) + c.maybeCompileMoveTopConditionalToFreeGeneralPurposeRegister() + return c.compileBranchInto(o.Target) } // compileBrIf implements compiler.compileBrIf for the arm64 architecture. @@ -474,12 +542,12 @@ func (c *arm64Compiler) compileBrIf(o *wazeroir.OperationBrIf) error { } else { // If the value is not on the conditional register, we compare the value with the zero register, // and then do the conditional BR if the value does't equal zero. - if err := c.ensureOnGeneralPurposeRegister(cond); err != nil { + if err := c.maybeCompileEnsureOnGeneralPurposeRegister(cond); err != nil { return err } // Compare the value with zero register. Note that the value is ensured to be i32 by function validation phase, // so we use CMPW (32-bit compare) here. - c.applyTwoRegistersToNoneInstruction(arm64.ACMPW, cond.register, zeroRegister) + c.compileTwoRegistersToNoneInstruction(arm64.ACMPW, cond.register, zeroRegister) conditionalBR.As = arm64.ABNE c.markRegisterUnused(cond.register) @@ -492,33 +560,33 @@ func (c *arm64Compiler) compileBrIf(o *wazeroir.OperationBrIf) error { // and we have to avoid affecting the code generation for Then branch afterwards. saved := c.locationStack c.setLocationStack(saved.clone()) - if err := c.emitDropRange(o.Else.ToDrop); err != nil { + if err := c.compileDropRange(o.Else.ToDrop); err != nil { return err } - c.branchInto(o.Else.Target) + c.compileBranchInto(o.Else.Target) // Now ready to emit the code for branching into then branch. // Retrieve the original value location stack so that the code below wont'be affected by the Else branch ^^. c.setLocationStack(saved) // We branch into here from the original conditional BR (conditionalBR). - c.setBRTargetOnNext(conditionalBR) - if err := c.emitDropRange(o.Then.ToDrop); err != nil { + c.setBranchTargetOnNext(conditionalBR) + if err := c.compileDropRange(o.Then.ToDrop); err != nil { return err } - c.branchInto(o.Then.Target) + c.compileBranchInto(o.Then.Target) return nil } -func (c *arm64Compiler) branchInto(target *wazeroir.BranchTarget) error { +func (c *arm64Compiler) compileBranchInto(target *wazeroir.BranchTarget) error { if target.IsReturnTarget() { - return c.returnFunction() + return c.compileReturnFunction() } else { labelKey := target.String() if c.ir.LabelCallers[labelKey] > 1 { // We can only re-use register state if when there's a single call-site. // Release existing values on registers to the stack if there's multiple ones to have // the consistent value location state at the beginning of label. - if err := c.releaseAllRegistersToStack(); err != nil { + if err := c.compileReleaseAllRegistersToStack(); err != nil { return err } } @@ -530,8 +598,8 @@ func (c *arm64Compiler) branchInto(target *wazeroir.BranchTarget) error { targetLabel.initialStack = c.locationStack.clone() } - jmp := c.emitUnconditionalBRInstruction(obj.TYPE_BRANCH) - c.assignBranchTarget(labelKey, jmp) + br := c.compileUnconditionalBranchInstruction() + c.assignBranchTarget(labelKey, br) return nil } } @@ -554,8 +622,229 @@ func (c *arm64Compiler) compileBrTable(o *wazeroir.OperationBrTable) error { return fmt.Errorf("TODO: unsupported on arm64") } +// compileCall implements compiler.compileCall for the arm64 architecture. func (c *arm64Compiler) compileCall(o *wazeroir.OperationCall) error { - return fmt.Errorf("TODO: unsupported on arm64") + target := c.f.ModuleInstance.Functions[o.FunctionIndex] + + if err := c.compileCallFunction(target.Address, target.FunctionType.Type); err != nil { + return err + } + return nil +} + +// compileCall implements compiler.compileCall and compiler.compileCallIndirect (TODO) for the arm64 architecture. +func (c *arm64Compiler) compileCallFunction(addr wasm.FunctionAddress, functype *wasm.FunctionType) error { + // TODO: the following code can be generalized for CallIndirect. + + // Release all the registers as our calling convention requires the caller-save. + if err := c.compileReleaseAllRegistersToStack(); err != nil { + return err + } + + // Obtain the free registers to be used in the followings. + freeRegisters, found := c.locationStack.takeFreeRegisters(generalPurposeRegisterTypeInt, 5) + if !found { + return fmt.Errorf("BUG: all registers except addrReg should be free at this point") + } + c.locationStack.markRegisterUsed(freeRegisters...) + + // Alias for readability. + callFrameStackTopAddressRegister, compiledFunctionAddressRegister, oldStackBasePointer, + tmp := freeRegisters[0], freeRegisters[1], freeRegisters[2], freeRegisters[3] + + // TODO: Check the callframe stack length, and if necessary, grow the call frame stack before jump into the target. + + // "tmp = engine.callFrameStackPointer" + c.compileMemoryToRegisterInstruction(arm64.AMOVD, + reservedRegisterForEngine, engineGlobalContextCallFrameStackPointerOffset, + tmp) + // "callFrameStackTopAddressRegister = &engine.callFrameStack[0]" + c.compileMemoryToRegisterInstruction(arm64.AMOVD, + reservedRegisterForEngine, engineGlobalContextCallFrameStackElement0AddressOffset, + callFrameStackTopAddressRegister) + // "callFrameStackTopAddressRegister += tmp << $callFrameDataSizeMostSignificantSetBit" + c.compileAddInstructionWithLeftShiftedRegister( + tmp, callFrameDataSizeMostSignificantSetBit, + callFrameStackTopAddressRegister, + callFrameStackTopAddressRegister, + ) + + // At this point, we have: + // + // [..., ra.current, rb.current, rc.current, _, ra.next, rb.next, rc.next, ...] <- call frame stack's data region (somewhere in the memory) + // ^ + // callFrameStackTopAddressRegister + // (absolute address of &callFrame[engine.callFrameStackPointer]]) + // + // where: + // ra.* = callFrame.returnAddress + // rb.* = callFrame.returnStackBasePointer + // rc.* = callFrame.compiledFunction + // _ = callFrame's padding (see comment on callFrame._ field.) + // + // In the following comment, we use the notations in the above example. + // + // What we have to do in the following is that + // 1) Set rb.current so that we can return back to this function properly. + // 2) Set engine.valueStackContext.stackBasePointer for the next function. + // 3) Set rc.next to specify which function is executed on the current call frame (needs to make Go function calls). + // 4) Set ra.current so that we can return back to this function properly. + + // 1) Set rb.current so that we can return back to this function properly. + c.compileMemoryToRegisterInstruction(arm64.AMOVD, + reservedRegisterForEngine, engineValueStackContextStackBasePointerOffset, + oldStackBasePointer) + c.compileRegisterToMemoryInstruction(arm64.AMOVD, + oldStackBasePointer, + // "rb.current" is BELOW the top address. See the above example for detail. + callFrameStackTopAddressRegister, -(callFrameDataSize - callFrameReturnStackBasePointerOffset)) + + // 2) Set engine.valueStackContext.stackBasePointer for the next function. + // + // At this point, oldStackBasePointer holds the old stack base pointer. We could get the new frame's + // stack base pointer by "old stack base pointer + old stack pointer - # of function params" + // See the comments in engine.pushCallFrame which does exactly the same calculation in Go. + c.compileConstToRegisterInstruction(arm64.AADD, + int64(c.locationStack.sp)-int64(len(functype.Params)), + oldStackBasePointer) + c.compileRegisterToMemoryInstruction(arm64.AMOVD, + oldStackBasePointer, + reservedRegisterForEngine, engineValueStackContextStackBasePointerOffset) + + // 3) Set rc.next to specify which function is executed on the current call frame. + // + // First, we read the address of the first item of engine.compiledFunctions slice (= &engine.compiledFunctions[0]) + // into tmp. + c.compileMemoryToRegisterInstruction(arm64.AMOVD, + reservedRegisterForEngine, engineGlobalContextCompiledFunctionsElement0AddressOffset, + tmp) + + // Next, read the address of the target function (= &engine.compiledFunctions[offset]) + // into compiledFunctionAddressRegister. + c.compileMemoryToRegisterInstruction(arm64.AMOVD, + tmp, int64(addr)*8, // * 8 because the size of *compiledFunction equals 8 bytes. + compiledFunctionAddressRegister) + + // Finally, we are ready to write the address of the target function's *compiledFunction into the new callframe. + c.compileRegisterToMemoryInstruction(arm64.AMOVD, + compiledFunctionAddressRegister, + callFrameStackTopAddressRegister, callFrameCompiledFunctionOffset) + + // 4) Set ra.current so that we can return back to this function properly. + // + // First, Get the return address into the tmp. + c.readInstructionAddress(obj.AJMP, tmp) + // Then write the address into the callframe. + c.compileRegisterToMemoryInstruction(arm64.AMOVD, + tmp, + // "ra.current" is BELOW the top address. See the above example for detail. + callFrameStackTopAddressRegister, -(callFrameDataSize - callFrameReturnAddressOffset), + ) + + // Everthing is done to make function call now. + // We increment the callframe stack pointer. + c.compileMemoryToRegisterInstruction(arm64.AMOVD, + reservedRegisterForEngine, engineGlobalContextCallFrameStackPointerOffset, + tmp) + c.compileConstToRegisterInstruction(arm64.AADD, 1, tmp) + c.compileRegisterToMemoryInstruction(arm64.AMOVD, + tmp, + reservedRegisterForEngine, engineGlobalContextCallFrameStackPointerOffset) + + // Then, br into the target function's initial address. + c.compileMemoryToRegisterInstruction(arm64.AMOVD, + compiledFunctionAddressRegister, compiledFunctionCodeInitialAddressOffset, + tmp) + c.compileUnconditionalBranchToAddressOnRegister(tmp) + + // All the registers used are temporary so we mark them unused. + c.markRegisterUnused(freeRegisters...) + + // We consumed the function parameters from the stack after call. + for i := 0; i < len(functype.Params); i++ { + c.locationStack.pop() + } + + // Also, the function results were pushed by the call. + for _, t := range functype.Results { + loc := c.locationStack.pushValueLocationOnStack() + switch t { + case wasm.ValueTypeI32, wasm.ValueTypeI64: + loc.setRegisterType(generalPurposeRegisterTypeInt) + case wasm.ValueTypeF32, wasm.ValueTypeF64: + loc.setRegisterType(generalPurposeRegisterTypeFloat) + } + } + + // On the function return, we initialize the state for this function. + c.compileInitializeReservedStackBasePointerRegister() + + // TODO: initialize module context, and memory pointer. + return nil +} + +// readInstructionAddress adds an ADR instruction to set the absolute address of "target instruction" +// into destinationRegister. "target instruction" is specified by beforeTargetInst argument and +// the target is determined by "the instruction right after beforeTargetInst type". +// +// For example, if beforeTargetInst == RET and we have the instruction sequence like +// ADR -> X -> Y -> ... -> RET -> MOV, then the ADR instruction emitted by this function set the absolute +// address of MOV instruction into the destination register. +func (c *arm64Compiler) readInstructionAddress(beforeTargetInst obj.As, destinationRegister int16) { + // Emit ADR instruction to read the specified instruction's absolute address. + // Note: we cannot emit the "ADR REG, $(target's offset from here)" due to the + // incapability of the assembler. Instead, we emit "ADR REG, ." meaning that + // "reading the current program counter" = "reading the absolute address of this ADR instruction". + // And then, after compilation phase, we directly edit the native code slice so that + // it can properly read the target instruction's absolute address. + readAddress := c.newProg() + readAddress.As = arm64.AADR + readAddress.From.Type = obj.TYPE_BRANCH + readAddress.To.Type = obj.TYPE_REG + readAddress.To.Reg = destinationRegister + c.addInstruction(readAddress) + + // Setup the callback to modify the instruction bytes after compilation. + // Note: this is the closure over readAddress (*obj.Prog). + c.afterAssembleCallback = append(c.afterAssembleCallback, func(code []byte) error { + // Find the target instruction. + target := readAddress + for target != nil { + if target.As == beforeTargetInst { + // At this point, target is the instruction right before the target instruction. + // Thus, advance one more time to make target the target instruction. + target = target.Link + break + } + target = target.Link + } + + if target == nil { + return fmt.Errorf("BUG: target instruction not found for read instruction address") + } + + offset := target.Pc - readAddress.Pc + if offset > math.MaxUint8 { + // We could support up to 20-bit integer, but byte should be enough for our impl. + // If the necessity comes up, we could fix the below to support larger offsets. + return fmt.Errorf("BUG: too large offset for read") + } + + // Now ready to write an offset byte. + v := byte(offset) + // arm64 has 4-bytes = 32-bit fixed-length instruction. + adrInstructionBytes := code[readAddress.Pc : readAddress.Pc+4] + // According to the binary format of ADR instruction in arm64: + // https://developer.arm.com/documentation/ddi0596/2021-12/Base-Instructions/ADR--Form-PC-relative-address-?lang=en + // + // The 0 to 1 bits live on 29 to 30 bits of the instruction. + adrInstructionBytes[3] |= (v & 0b00000011) << 5 + // The 2 to 4 bits live on 5 to 7 bits of the instruction. + adrInstructionBytes[0] |= (v & 0b00011100) << 3 + // The 5 to 7 bits live on 8 to 10 bits of the instruction. + adrInstructionBytes[1] |= (v & 0b11100000) >> 5 + return nil + }) } func (c *arm64Compiler) compileCallIndirect(o *wazeroir.OperationCallIndirect) error { @@ -564,11 +853,11 @@ func (c *arm64Compiler) compileCallIndirect(o *wazeroir.OperationCallIndirect) e // compileDrop implements compiler.compileDrop for the arm64 architecture. func (c *arm64Compiler) compileDrop(o *wazeroir.OperationDrop) error { - return c.emitDropRange(o.Range) + return c.compileDropRange(o.Range) } -// emitDropRange is the implementation of compileDrop. See compiler.compileDrop. -func (c *arm64Compiler) emitDropRange(r *wazeroir.InclusiveRange) error { +// compileDropRange is the implementation of compileDrop. See compiler.compileDrop. +func (c *arm64Compiler) compileDropRange(r *wazeroir.InclusiveRange) error { if r == nil { return nil } else if r.Start == 0 { @@ -583,7 +872,7 @@ func (c *arm64Compiler) emitDropRange(r *wazeroir.InclusiveRange) error { // Below, we might end up moving a non-top value first which // might result in changing the flag value. - c.maybeMoveTopConditionalToFreeGeneralPurposeRegister() + c.maybeCompileMoveTopConditionalToFreeGeneralPurposeRegister() // Save the live values because we pop and release values in drop range below. liveValues := c.locationStack.stack[c.locationStack.sp-uint64(r.Start):] @@ -603,7 +892,7 @@ func (c *arm64Compiler) emitDropRange(r *wazeroir.InclusiveRange) error { // If the value is on a memory, we have to move it to a register, // otherwise the memory location is overriden by other values // after this drop instructin. - if err := c.ensureOnGeneralPurposeRegister(live); err != nil { + if err := c.maybeCompileEnsureOnGeneralPurposeRegister(live); err != nil { return err } // Update the runtime memory stack location by pushing onto the location stack. @@ -618,7 +907,7 @@ func (c *arm64Compiler) compileSelect() error { // compilePick implements compiler.compilePick for the arm64 architecture. func (c *arm64Compiler) compilePick(o *wazeroir.OperationPick) error { - c.maybeMoveTopConditionalToFreeGeneralPurposeRegister() + c.maybeCompileMoveTopConditionalToFreeGeneralPurposeRegister() pickTarget := c.locationStack.stack[c.locationStack.sp-1-uint64(o.Depth)] pickedRegister, err := c.allocateRegister(pickTarget.registerType()) @@ -634,11 +923,11 @@ func (c *arm64Compiler) compilePick(o *wazeroir.OperationPick) error { case generalPurposeRegisterTypeFloat: inst = arm64.AFMOVD } - c.applyRegisterToRegisterInstruction(inst, pickTarget.register, pickedRegister) + c.compileRegisterToRegisterInstruction(inst, pickTarget.register, pickedRegister) } else if pickTarget.onStack() { // Temporarily assign a register to the pick target, and then load the value. pickTarget.setRegister(pickedRegister) - c.loadValueOnStackToRegister(pickTarget) + c.compileLoadValueOnStackToRegister(pickTarget) // After the load, we revert the register assignment to the pick target. pickTarget.setRegister(nilRegister) } @@ -677,7 +966,7 @@ func (c *arm64Compiler) compileAdd(o *wazeroir.OperationAdd) error { inst = arm64.AFADDD } - c.applyRegisterToRegisterInstruction(inst, x2.register, x1.register) + c.compileRegisterToRegisterInstruction(inst, x2.register, x1.register) // The result is placed on a register for x1, so record it. c.locationStack.pushValueLocationOnRegister(x1.register) return nil @@ -715,7 +1004,7 @@ func (c *arm64Compiler) compileSub(o *wazeroir.OperationSub) error { inst = arm64.AFSUBD } - c.applyTwoRegistersToRegisterInstruction(inst, x2.register, x1.register, destinationReg) + c.compileTwoRegistersToRegisterInstruction(inst, x2.register, x1.register, destinationReg) c.locationStack.pushValueLocationOnRegister(destinationReg) return nil } @@ -745,7 +1034,7 @@ func (c *arm64Compiler) compileMul(o *wazeroir.OperationMul) error { inst = arm64.AFMULD } - c.applyRegisterToRegisterInstruction(inst, x2.register, x1.register) + c.compileRegisterToRegisterInstruction(inst, x2.register, x1.register) // The result is placed on a register for x1, so record it. c.locationStack.pushValueLocationOnRegister(x1.register) return nil @@ -771,32 +1060,208 @@ func (c *arm64Compiler) compileRem(o *wazeroir.OperationRem) error { return fmt.Errorf("TODO: unsupported on arm64") } +// compileAnd implements compiler.compileAnd for the arm64 architecture. func (c *arm64Compiler) compileAnd(o *wazeroir.OperationAnd) error { - return fmt.Errorf("TODO: unsupported on arm64") + x1, x2, err := c.popTwoValuesOnRegisters() + if err != nil { + return err + } + + // If either of the registers x1 or x2 is zero, + // the result will always be zero. + if isZeroRegister(x1.register) || isZeroRegister(x2.register) { + c.locationStack.pushValueLocationOnRegister(zeroRegister) + return nil + } + + // At this point, at least one of x1 or x2 registers is non zero. + // Choose the non-zero register as destination. + var destinationReg int16 = x1.register + if isZeroRegister(x1.register) { + destinationReg = x2.register + } + + var inst obj.As + switch o.Type { + case wazeroir.UnsignedInt32: + inst = arm64.AANDW + case wazeroir.UnsignedInt64: + inst = arm64.AAND + } + + c.compileTwoRegistersToRegisterInstruction(inst, x2.register, x1.register, destinationReg) + c.locationStack.pushValueLocationOnRegister(x1.register) + return nil } +// compileOr implements compiler.compileOr for the arm64 architecture. func (c *arm64Compiler) compileOr(o *wazeroir.OperationOr) error { - return fmt.Errorf("TODO: unsupported on arm64") + x1, x2, err := c.popTwoValuesOnRegisters() + if err != nil { + return err + } + + if isZeroRegister(x1.register) { + c.locationStack.pushValueLocationOnRegister(x2.register) + return nil + } + if isZeroRegister(x2.register) { + c.locationStack.pushValueLocationOnRegister(x1.register) + return nil + } + + var inst obj.As + switch o.Type { + case wazeroir.UnsignedInt32: + inst = arm64.AORRW + case wazeroir.UnsignedInt64: + inst = arm64.AORR + } + + c.compileTwoRegistersToRegisterInstruction(inst, x2.register, x1.register, x1.register) + c.locationStack.pushValueLocationOnRegister(x1.register) + return nil } +// compileXor implements compiler.compileXor for the arm64 architecture. func (c *arm64Compiler) compileXor(o *wazeroir.OperationXor) error { - return fmt.Errorf("TODO: unsupported on arm64") + x1, x2, err := c.popTwoValuesOnRegisters() + if err != nil { + return err + } + + // At this point, at least one of x1 or x2 registers is non zero. + // Choose the non-zero register as destination. + var destinationReg int16 = x1.register + if isZeroRegister(x1.register) { + destinationReg = x2.register + } + + var inst obj.As + switch o.Type { + case wazeroir.UnsignedInt32: + inst = arm64.AEORW + case wazeroir.UnsignedInt64: + inst = arm64.AEOR + } + + c.compileTwoRegistersToRegisterInstruction(inst, x2.register, x1.register, destinationReg) + c.locationStack.pushValueLocationOnRegister(destinationReg) + return nil } +// compileShl implements compiler.compileShl for the arm64 architecture. func (c *arm64Compiler) compileShl(o *wazeroir.OperationShl) error { - return fmt.Errorf("TODO: unsupported on arm64") + x1, x2, err := c.popTwoValuesOnRegisters() + if err != nil { + return err + } + + if isZeroRegister(x1.register) || isZeroRegister(x2.register) { + c.locationStack.pushValueLocationOnRegister(x1.register) + return nil + } + + var inst obj.As + switch o.Type { + case wazeroir.UnsignedInt32: + inst = arm64.ALSLW + case wazeroir.UnsignedInt64: + inst = arm64.ALSL + } + + c.compileTwoRegistersToRegisterInstruction(inst, x2.register, x1.register, x1.register) + c.locationStack.pushValueLocationOnRegister(x1.register) + return nil } +// compileShr implements compiler.compileShr for the arm64 architecture. func (c *arm64Compiler) compileShr(o *wazeroir.OperationShr) error { - return fmt.Errorf("TODO: unsupported on arm64") + x1, x2, err := c.popTwoValuesOnRegisters() + if err != nil { + return err + } + + if isZeroRegister(x1.register) || isZeroRegister(x2.register) { + c.locationStack.pushValueLocationOnRegister(x1.register) + return nil + } + + var inst obj.As + switch o.Type { + case wazeroir.SignedInt32: + inst = arm64.AASRW + case wazeroir.SignedInt64: + inst = arm64.AASR + case wazeroir.SignedUint32: + inst = arm64.ALSRW + case wazeroir.SignedUint64: + inst = arm64.ALSR + } + + c.compileTwoRegistersToRegisterInstruction(inst, x2.register, x1.register, x1.register) + c.locationStack.pushValueLocationOnRegister(x1.register) + return nil } +// compileRotl implements compiler.compileRotl for the arm64 architecture. func (c *arm64Compiler) compileRotl(o *wazeroir.OperationRotl) error { - return fmt.Errorf("TODO: unsupported on arm64") + x1, x2, err := c.popTwoValuesOnRegisters() + if err != nil { + return err + } + + if isZeroRegister(x1.register) || isZeroRegister(x2.register) { + c.locationStack.pushValueLocationOnRegister(x1.register) + return nil + } + + var ( + inst obj.As + neginst obj.As + ) + + switch o.Type { + case wazeroir.UnsignedInt32: + inst = arm64.ARORW + neginst = arm64.ANEGW + case wazeroir.UnsignedInt64: + inst = arm64.AROR + neginst = arm64.ANEG + } + + // Arm64 doesn't have rotate left instruction. + // The shift amount needs to be converted to a negative number, similar to assembly output of bits.RotateLeft. + c.compileRegisterToRegisterInstruction(neginst, x2.register, x2.register) + + c.compileTwoRegistersToRegisterInstruction(inst, x2.register, x1.register, x1.register) + c.locationStack.pushValueLocationOnRegister(x1.register) + return nil } +// compileRotr implements compiler.compileRotr for the arm64 architecture. func (c *arm64Compiler) compileRotr(o *wazeroir.OperationRotr) error { - return fmt.Errorf("TODO: unsupported on arm64") + x1, x2, err := c.popTwoValuesOnRegisters() + if err != nil { + return err + } + + if isZeroRegister(x1.register) || isZeroRegister(x2.register) { + c.locationStack.pushValueLocationOnRegister(x1.register) + return nil + } + + var inst obj.As + switch o.Type { + case wazeroir.UnsignedInt32: + inst = arm64.ARORW + case wazeroir.UnsignedInt64: + inst = arm64.AROR + } + + c.compileTwoRegistersToRegisterInstruction(inst, x2.register, x1.register, x1.register) + c.locationStack.pushValueLocationOnRegister(x1.register) + return nil } func (c *arm64Compiler) compileAbs(o *wazeroir.OperationAbs) error { @@ -908,7 +1373,7 @@ func (c *arm64Compiler) emitEqOrNe(isEq bool, unsignedType wazeroir.UnsignedType inst = arm64.AFCMPD } - c.applyTwoRegistersToNoneInstruction(inst, x2.register, x1.register) + c.compileTwoRegistersToNoneInstruction(inst, x2.register, x1.register) // Push the comparison result as a conditional register value. cond := conditionalRegisterState(arm64.COND_NE) @@ -934,7 +1399,7 @@ func (c *arm64Compiler) compileEqz(o *wazeroir.OperationEqz) error { inst = arm64.ACMP } - c.applyTwoRegistersToNoneInstruction(inst, zeroRegister, x1.register) + c.compileTwoRegistersToNoneInstruction(inst, zeroRegister, x1.register) // Push the comparison result as a conditional register value. cond := conditionalRegisterState(arm64.COND_EQ) @@ -972,7 +1437,7 @@ func (c *arm64Compiler) compileLt(o *wazeroir.OperationLt) error { conditionalRegister = arm64.COND_MI } - c.applyTwoRegistersToNoneInstruction(inst, x2.register, x1.register) + c.compileTwoRegistersToNoneInstruction(inst, x2.register, x1.register) // Push the comparison result as a conditional register value. c.locationStack.pushValueLocationOnConditionalRegister(conditionalRegister) @@ -1009,7 +1474,7 @@ func (c *arm64Compiler) compileGt(o *wazeroir.OperationGt) error { conditionalRegister = arm64.COND_GT } - c.applyTwoRegistersToNoneInstruction(inst, x2.register, x1.register) + c.compileTwoRegistersToNoneInstruction(inst, x2.register, x1.register) // Push the comparison result as a conditional register value. c.locationStack.pushValueLocationOnConditionalRegister(conditionalRegister) @@ -1046,7 +1511,7 @@ func (c *arm64Compiler) compileLe(o *wazeroir.OperationLe) error { conditionalRegister = arm64.COND_LS } - c.applyTwoRegistersToNoneInstruction(inst, x2.register, x1.register) + c.compileTwoRegistersToNoneInstruction(inst, x2.register, x1.register) // Push the comparison result as a conditional register value. c.locationStack.pushValueLocationOnConditionalRegister(conditionalRegister) @@ -1083,7 +1548,7 @@ func (c *arm64Compiler) compileGe(o *wazeroir.OperationGe) error { conditionalRegister = arm64.COND_GE } - c.applyTwoRegistersToNoneInstruction(inst, x2.register, x1.register) + c.compileTwoRegistersToNoneInstruction(inst, x2.register, x1.register) // Push the comparison result as a conditional register value. c.locationStack.pushValueLocationOnConditionalRegister(conditionalRegister) @@ -1132,19 +1597,19 @@ func (c *arm64Compiler) compileMemorySize() error { // compileConstI32 implements compiler.compileConstI32 for the arm64 architecture. func (c *arm64Compiler) compileConstI32(o *wazeroir.OperationConstI32) error { - return c.emitIntConstant(true, uint64(o.Value)) + return c.compileIntConstant(true, uint64(o.Value)) } // compileConstI64 implements compiler.compileConstI64 for the arm64 architecture. func (c *arm64Compiler) compileConstI64(o *wazeroir.OperationConstI64) error { - return c.emitIntConstant(false, o.Value) + return c.compileIntConstant(false, o.Value) } -// emitIntConstant adds instructions to load an integer constant. +// compileIntConstant adds instructions to load an integer constant. // is32bit is true if the target value is originally 32-bit const, false otherwise. // value holds the (zero-extended for 32-bit case) load target constant. -func (c *arm64Compiler) emitIntConstant(is32bit bool, value uint64) error { - c.maybeMoveTopConditionalToFreeGeneralPurposeRegister() +func (c *arm64Compiler) compileIntConstant(is32bit bool, value uint64) error { + c.maybeCompileMoveTopConditionalToFreeGeneralPurposeRegister() if value == 0 { c.pushZeroValue() @@ -1161,7 +1626,7 @@ func (c *arm64Compiler) emitIntConstant(is32bit bool, value uint64) error { } else { inst = arm64.AMOVD } - c.applyConstToRegisterInstruction(inst, int64(value), reg) + c.compileConstToRegisterInstruction(inst, int64(value), reg) c.locationStack.pushValueLocationOnRegister(reg) } @@ -1170,19 +1635,19 @@ func (c *arm64Compiler) emitIntConstant(is32bit bool, value uint64) error { // compileConstF32 implements compiler.compileConstF32 for the arm64 architecture. func (c *arm64Compiler) compileConstF32(o *wazeroir.OperationConstF32) error { - return c.emitFloatConstant(true, uint64(math.Float32bits(o.Value))) + return c.compileFloatConstant(true, uint64(math.Float32bits(o.Value))) } // compileConstF64 implements compiler.compileConstF64 for the arm64 architecture. func (c *arm64Compiler) compileConstF64(o *wazeroir.OperationConstF64) error { - return c.emitFloatConstant(false, math.Float64bits(o.Value)) + return c.compileFloatConstant(false, math.Float64bits(o.Value)) } -// emitFloatConstant adds instructions to load a float constant. +// compileFloatConstant adds instructions to load a float constant. // is32bit is true if the target value is originally 32-bit const, false otherwise. // value holds the (zero-extended for 32-bit case) bit representation of load target float constant. -func (c *arm64Compiler) emitFloatConstant(is32bit bool, value uint64) error { - c.maybeMoveTopConditionalToFreeGeneralPurposeRegister() +func (c *arm64Compiler) compileFloatConstant(is32bit bool, value uint64) error { + c.maybeCompileMoveTopConditionalToFreeGeneralPurposeRegister() // Take a register to load the value. reg, err := c.allocateRegister(generalPurposeRegisterTypeFloat) @@ -1199,7 +1664,7 @@ func (c *arm64Compiler) emitFloatConstant(is32bit bool, value uint64) error { } else { inst = arm64.AMOVD } - c.applyConstToRegisterInstruction(inst, int64(value), tmpReg) + c.compileConstToRegisterInstruction(inst, int64(value), tmpReg) } // Use FMOV instruction to move the value on integer register into the float one. @@ -1209,7 +1674,7 @@ func (c *arm64Compiler) emitFloatConstant(is32bit bool, value uint64) error { } else { inst = arm64.AFMOVD } - c.applyRegisterToRegisterInstruction(inst, tmpReg, reg) + c.compileRegisterToRegisterInstruction(inst, tmpReg, reg) c.locationStack.pushValueLocationOnRegister(reg) return nil @@ -1221,21 +1686,33 @@ func (c *arm64Compiler) pushZeroValue() { // popTwoValuesOnRegisters pops two values from the location stacks, ensures // these two values are located on registers, and mark them unused. +// +// TODO: we’d usually prefix this with compileXXX as this might end up emitting instructions, +// but the name seems awkward. func (c *arm64Compiler) popTwoValuesOnRegisters() (x1, x2 *valueLocation, err error) { - x2, err = c.popValueOnRegister() - if err != nil { + x2 = c.locationStack.pop() + if err = c.maybeCompileEnsureOnGeneralPurposeRegister(x2); err != nil { + return + } + + x1 = c.locationStack.pop() + if err = c.maybeCompileEnsureOnGeneralPurposeRegister(x1); err != nil { return } - x1, err = c.popValueOnRegister() + c.markRegisterUnused(x2.register) + c.markRegisterUnused(x1.register) return } // popValueOnRegister pops one value from the location stack, ensures // that it is located on a register, and mark it unused. +// +// TODO: we’d usually prefix this with compileXXX as this might end up emitting instructions, +// but the name seems awkward. func (c *arm64Compiler) popValueOnRegister() (v *valueLocation, err error) { v = c.locationStack.pop() - if err = c.ensureOnGeneralPurposeRegister(v); err != nil { + if err = c.maybeCompileEnsureOnGeneralPurposeRegister(v); err != nil { return } @@ -1243,27 +1720,27 @@ func (c *arm64Compiler) popValueOnRegister() (v *valueLocation, err error) { return } -// ensureOnGeneralPurposeRegister emits instructions to ensure that a value is located on a register. -func (c *arm64Compiler) ensureOnGeneralPurposeRegister(loc *valueLocation) (err error) { +// maybeCompileEnsureOnGeneralPurposeRegister emits instructions to ensure that a value is located on a register. +func (c *arm64Compiler) maybeCompileEnsureOnGeneralPurposeRegister(loc *valueLocation) (err error) { if loc.onStack() { - err = c.loadValueOnStackToRegister(loc) + err = c.compileLoadValueOnStackToRegister(loc) } else if loc.onConditionalRegister() { - c.loadConditionalRegisterToGeneralPurposeRegister(loc) + c.compileLoadConditionalRegisterToGeneralPurposeRegister(loc) } return } -// maybeMoveTopConditionalToFreeGeneralPurposeRegister moves the top value on the stack +// maybeCompileMoveTopConditionalToFreeGeneralPurposeRegister moves the top value on the stack // if the value is located on a conditional register. // -// This is usually called at the beginning of arm64Compiler.compile* functions where we possibly -// emit istructions without saving the conditional register value. +// This is usually called at the beginning of methods on compiler interface where we possibly +// compile istructions without saving the conditional register value. // The compile* functions without calling this function is saving the conditional // value to the stack or register by invoking ensureOnGeneralPurposeRegister for the top. -func (c *arm64Compiler) maybeMoveTopConditionalToFreeGeneralPurposeRegister() { +func (c *arm64Compiler) maybeCompileMoveTopConditionalToFreeGeneralPurposeRegister() { if c.locationStack.sp > 0 { if loc := c.locationStack.peek(); loc.onConditionalRegister() { - c.loadConditionalRegisterToGeneralPurposeRegister(loc) + c.compileLoadConditionalRegisterToGeneralPurposeRegister(loc) } } } @@ -1273,20 +1750,20 @@ func (c *arm64Compiler) maybeMoveTopConditionalToFreeGeneralPurposeRegister() { // // We use CSET instruction to set 1 on the register if the condition satisfies: // https://developer.arm.com/documentation/100076/0100/a64-instruction-set-reference/a64-general-instructions/cset -func (c *arm64Compiler) loadConditionalRegisterToGeneralPurposeRegister(loc *valueLocation) { +func (c *arm64Compiler) compileLoadConditionalRegisterToGeneralPurposeRegister(loc *valueLocation) { // There must be always at least one free register at this point, as the conditional register located value // is always pushed after consuming at least one value (eqz) or two values for most cases (gt, ge, etc.). reg, _ := c.locationStack.takeFreeRegister(generalPurposeRegisterTypeInt) c.markRegisterUsed(reg) - c.applyRegisterToRegisterInstruction(arm64.ACSET, int16(loc.conditionalRegister), reg) + c.compileRegisterToRegisterInstruction(arm64.ACSET, int16(loc.conditionalRegister), reg) // Record that now the value is located on a general purpose register. loc.setRegister(reg) } -// loadValueOnStackToRegister emits instructions to load the value located on the stack to a register. -func (c *arm64Compiler) loadValueOnStackToRegister(loc *valueLocation) (err error) { +// compileLoadValueOnStackToRegister emits instructions to load the value located on the stack to a register. +func (c *arm64Compiler) compileLoadValueOnStackToRegister(loc *valueLocation) (err error) { var inst obj.As var reg int16 switch loc.regType { @@ -1302,15 +1779,7 @@ func (c *arm64Compiler) loadValueOnStackToRegister(loc *valueLocation) (err erro return } - if offset := int64(loc.stackPointer) * 8; offset > math.MaxInt16 { - // The assembler can take care of offsets larger than 2^15-1 by emitting additional instructions to load such large offset, - // but it uses "its" temporary register which we cannot track. Therefore, we avoid directly emitting memory load with large offsets, - // but instead load the constant manually to "our" temporary register, then emit the load with it. - c.applyConstToRegisterInstruction(arm64.AMOVD, offset, reservedRegisterForTemporary) - c.applyRegisterOffsetMemoryToRegisterInstruction(inst, reservedRegisterForStackBasePointerAddress, reservedRegisterForTemporary, reg) - } else { - c.applyMemoryToRegisterInstruction(inst, reservedRegisterForStackBasePointerAddress, offset, reg) - } + c.compileMemoryToRegisterInstruction(inst, reservedRegisterForStackBasePointerAddress, int64(loc.stackPointer)*8, reg) // Record that the value holds the register and the register is marked used. loc.setRegister(reg) @@ -1322,6 +1791,9 @@ func (c *arm64Compiler) loadValueOnStackToRegister(loc *valueLocation) (err erro // either from the free register pool or by spilling an used register. If we allocate an used register, // this adds an instruction to write the value on a register back to memory stack region. // Note: resulting registers are NOT marked as used so the call site should mark it used if necessary. +// +// TODO: we’d usually prefix this with compileXXX as this might end up emitting instructions, +// but the name seems awkward. func (c *arm64Compiler) allocateRegister(t generalPurposeRegisterType) (reg int16, err error) { var ok bool // Try to get the unused register. @@ -1339,22 +1811,22 @@ func (c *arm64Compiler) allocateRegister(t generalPurposeRegisterType) (reg int1 // Release the steal target register value onto stack location. reg = stealTarget.register - err = c.releaseRegisterToStack(stealTarget) + err = c.compileReleaseRegisterToStack(stealTarget) return } -// releaseAllRegistersToStack adds instructions to store all the values located on +// compileReleaseAllRegistersToStack adds instructions to store all the values located on // either general purpuse or conditional registers onto the memory stack. // See releaseRegisterToStack. -func (c *arm64Compiler) releaseAllRegistersToStack() error { +func (c *arm64Compiler) compileReleaseAllRegistersToStack() error { for i := uint64(0); i < c.locationStack.sp; i++ { if loc := c.locationStack.stack[i]; loc.onRegister() { - if err := c.releaseRegisterToStack(loc); err != nil { + if err := c.compileReleaseRegisterToStack(loc); err != nil { return err } } else if loc.onConditionalRegister() { - c.loadConditionalRegisterToGeneralPurposeRegister(loc) - if err := c.releaseRegisterToStack(loc); err != nil { + c.compileLoadConditionalRegisterToGeneralPurposeRegister(loc) + if err := c.compileReleaseRegisterToStack(loc); err != nil { return err } } @@ -1363,64 +1835,50 @@ func (c *arm64Compiler) releaseAllRegistersToStack() error { } // releaseRegisterToStack adds an instruction to write the value on a register back to memory stack region. -func (c *arm64Compiler) releaseRegisterToStack(loc *valueLocation) (err error) { +func (c *arm64Compiler) compileReleaseRegisterToStack(loc *valueLocation) (err error) { var inst obj.As = arm64.AMOVD if loc.regType == generalPurposeRegisterTypeFloat { inst = arm64.AFMOVD } - if offset := int64(loc.stackPointer) * 8; offset > math.MaxInt16 { - // The assembler can take care of offsets larger than 2^15-1 by emitting additional instructions to load such large offset, - // but it uses "its" temporary register which we cannot track. Therefore, we avoid directly emitting memory load with large offsets, - // but instead load the constant manually to "our" temporary register, then emit the load with it. - c.applyConstToRegisterInstruction(arm64.AMOVD, offset, reservedRegisterForTemporary) - c.applyRegisterToRegisterOffsetMemoryInstruction(inst, reservedRegisterForStackBasePointerAddress, reservedRegisterForTemporary, loc.register) - } else { - if err = c.applyRegisterToMemoryInstruction(inst, reservedRegisterForStackBasePointerAddress, offset, loc.register); err != nil { - return - } - } + c.compileRegisterToMemoryInstruction(inst, loc.register, reservedRegisterForStackBasePointerAddress, int64(loc.stackPointer)*8) // Mark the register is free. c.locationStack.releaseRegister(loc) return } -// initializeReservedStackBasePointerRegister adds intructions to initialize reservedRegisterForStackBasePointerAddress +// compilaleInitializeReservedStackBasePointerRegister adds intructions to initialize reservedRegisterForStackBasePointerAddress // so that it points to the absolute address of the stack base for this function. -func (c *arm64Compiler) initializeReservedStackBasePointerRegister() error { +func (c *arm64Compiler) compileInitializeReservedStackBasePointerRegister() error { // First, load the address of the first element in the value stack into reservedRegisterForStackBasePointerAddress temporarily. - if err := c.applyMemoryToRegisterInstruction(arm64.AMOVD, + c.compileMemoryToRegisterInstruction(arm64.AMOVD, reservedRegisterForEngine, engineGlobalContextValueStackElement0AddressOffset, - reservedRegisterForStackBasePointerAddress); err != nil { - return err - } + reservedRegisterForStackBasePointerAddress) - // Next we move the base pointer (engine.stackBasePointer) to the tmp register. - if err := c.applyMemoryToRegisterInstruction(arm64.AMOVD, + // Next we move the base pointer (engine.stackBasePointer) to reservedRegisterForTemporary. + c.compileMemoryToRegisterInstruction(arm64.AMOVD, reservedRegisterForEngine, engineValueStackContextStackBasePointerOffset, - reservedRegisterForTemporary); err != nil { - return err - } - - // Finally, we calculate "reservedRegisterForStackBasePointerAddress + tmpReg * 8" - // where we multiply tmpReg by 8 because stack pointer is an index in the []uint64 - // so as an bytes we must multiply the size of uint64 = 8 bytes. - calcStackBasePointerAddress := c.newProg() - calcStackBasePointerAddress.As = arm64.AADD - calcStackBasePointerAddress.To.Type = obj.TYPE_REG - calcStackBasePointerAddress.To.Reg = reservedRegisterForStackBasePointerAddress - // We calculate "tmpReg * 8" as "tmpReg << 3". - setLeftShiftedRegister(calcStackBasePointerAddress, reservedRegisterForTemporary, 3) - c.addInstruction(calcStackBasePointerAddress) + reservedRegisterForTemporary) + + // Finally, we calculate "reservedRegisterForStackBasePointerAddress + reservedRegisterForTemporary << 3" + // where we shift tmpReg by 3 because stack pointer is an index in the []uint64 + // so we must multiply the value by the size of uint64 = 8 bytes. + c.compileAddInstructionWithLeftShiftedRegister( + reservedRegisterForTemporary, 3, reservedRegisterForStackBasePointerAddress, + reservedRegisterForStackBasePointerAddress) return nil } -// setShiftedRegister modifies the given *obj.Prog so that .From (source operand) -// becomes the "left shifted register". For example, this is used to emit instruction like -// "add x1, x2, x3, lsl #3" which means "x1 = x2 + (x3 << 3)". -// See https://github.com/twitchyliquid64/golang-asm/blob/v0.15.1/obj/link.go#L120-L131 -func setLeftShiftedRegister(inst *obj.Prog, register int16, shiftNum int64) { +// emitAddInstructionWithLeftShiftedRegister emits an ADD instruction to perform "destinationReg = srcReg + (shiftedSourceReg << shiftNum)". +func (c *arm64Compiler) compileAddInstructionWithLeftShiftedRegister(shiftedSourceReg int16, shiftNum int64, srcReg, destinationReg int16) { + inst := c.newProg() + inst.As = arm64.AADD + inst.To.Type = obj.TYPE_REG + inst.To.Reg = destinationReg + // See https://github.com/twitchyliquid64/golang-asm/blob/v0.15.1/obj/link.go#L120-L131 inst.From.Type = obj.TYPE_SHIFT - inst.From.Offset = (int64(register)&31)<<16 | 0<<22 | (shiftNum&63)<<10 + inst.From.Offset = (int64(shiftedSourceReg)&31)<<16 | 0<<22 | (shiftNum&63)<<10 + inst.Reg = srcReg + c.addInstruction(inst) } diff --git a/wasm/jit/jit_arm64_test.go b/wasm/jit/jit_arm64_test.go index 51ff7159fd..1dcf3f3519 100644 --- a/wasm/jit/jit_arm64_test.go +++ b/wasm/jit/jit_arm64_test.go @@ -4,9 +4,9 @@ package jit import ( - "context" "fmt" "math" + "math/bits" "testing" "unsafe" @@ -49,58 +49,6 @@ func (j *jitEnv) requireNewCompiler(t *testing.T) *arm64Compiler { return ret } -// TODO: delete this as this could be a duplication from other tests especially spectests. -// Use this until we could run spectests on arm64. -func TestArm64CompilerEndToEnd(t *testing.T) { - ctx := context.Background() - for _, tc := range []struct { - name string - body []byte - sig *wasm.FunctionType - }{ - {name: "empty", body: []byte{wasm.OpcodeEnd}, sig: &wasm.FunctionType{}}, - {name: "br .return", body: []byte{wasm.OpcodeBr, 0, wasm.OpcodeEnd}, sig: &wasm.FunctionType{}}, - { - name: "consts", - body: []byte{ - wasm.OpcodeI32Const, 1, wasm.OpcodeI64Const, 1, - wasm.OpcodeF32Const, 1, 1, 1, 1, wasm.OpcodeF64Const, 1, 2, 3, 4, 5, 6, 7, 8, - wasm.OpcodeEnd, - }, - // We push four constants. - sig: &wasm.FunctionType{Results: []wasm.ValueType{wasm.ValueTypeI32, wasm.ValueTypeI64, wasm.ValueTypeF32, wasm.ValueTypeF64}}, - }, - { - name: "add", - body: []byte{wasm.OpcodeI32Const, 1, wasm.OpcodeI32Const, 1, wasm.OpcodeI32Add, wasm.OpcodeEnd}, - sig: &wasm.FunctionType{Results: []wasm.ValueType{wasm.ValueTypeI32}}, - }, - { - name: "sub", - body: []byte{wasm.OpcodeI64Const, 1, wasm.OpcodeI64Const, 1, wasm.OpcodeI64Sub, wasm.OpcodeEnd}, - sig: &wasm.FunctionType{Results: []wasm.ValueType{wasm.ValueTypeI64}}, - }, - { - name: "mul", - body: []byte{wasm.OpcodeI64Const, 1, wasm.OpcodeI64Const, 1, wasm.OpcodeI64Mul, wasm.OpcodeEnd}, - sig: &wasm.FunctionType{Results: []wasm.ValueType{wasm.ValueTypeI64}}, - }, - } { - tc := tc - t.Run(tc.name, func(t *testing.T) { - engine := newEngine() - f := &wasm.FunctionInstance{ - FunctionType: &wasm.TypeInstance{Type: tc.sig}, - Body: tc.body, - } - err := engine.Compile(f) - require.NoError(t, err) - _, err = engine.Call(ctx, f) - require.NoError(t, err) - }) - } -} - func TestArchContextOffsetInEngine(t *testing.T) { var eng engine // If this fails, we have to fix jit_arm64.s as well. @@ -108,25 +56,77 @@ func TestArchContextOffsetInEngine(t *testing.T) { } func TestArm64Compiler_returnFunction(t *testing.T) { - env := newJITEnvironment() + t.Run("exit", func(t *testing.T) { + env := newJITEnvironment() - // Build code. - compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() - require.NoError(t, err) - compiler.returnFunction() + // Build code. + compiler := env.requireNewCompiler(t) + err := compiler.compilePreamble() + require.NoError(t, err) + compiler.compileReturnFunction() - // Generate the code under test. - code, _, _, err := compiler.compile() - require.NoError(t, err) + code, _, _, err := compiler.compile() + require.NoError(t, err) + + env.exec(code) + + // JIT status on engine must be returned. + require.Equal(t, jitCallStatusCodeReturned, env.jitStatus()) + // Plus, the call frame stack pointer must be zero after return. + require.Equal(t, uint64(0), env.callFrameStackPointer()) + }) + t.Run("deep call stack", func(t *testing.T) { + env := newJITEnvironment() + engine := env.engine() - // Run native code. - env.exec(code) + // Push the call frames. + const callFrameNums = 10 + stackPointerToExpectedValue := map[uint64]uint32{} + for funcaddr := wasm.FunctionAddress(0); funcaddr < callFrameNums; funcaddr++ { + // Each function pushes its funcaddr and soon returns. + compiler := env.requireNewCompiler(t) + err := compiler.compilePreamble() + require.NoError(t, err) + + // Push its funcaddr. + expValue := uint32(funcaddr) + err = compiler.compileConstI32(&wazeroir.OperationConstI32{Value: expValue}) + require.NoError(t, err) - // JIT status on engine must be returned. - require.Equal(t, jitCallStatusCodeReturned, env.jitStatus()) - // Plus, the call frame stack pointer must be zero after return. - require.Equal(t, uint64(0), env.callFrameStackPointer()) + err = compiler.compileReturnFunction() + require.NoError(t, err) + + code, _, _, err := compiler.compile() + require.NoError(t, err) + + // Compiles and adds to the engine. + compiledFunction := &compiledFunction{codeSegment: code, codeInitialAddress: uintptr(unsafe.Pointer(&code[0]))} + engine.addCompiledFunction(funcaddr, compiledFunction) + + // Pushes the frame whose return address equals the beginning of the function just compiled. + frame := callFrame{ + // Set the return address to the beginning of the function so that we can execute the constI32 above. + returnAddress: compiledFunction.codeInitialAddress, + // Note: return stack base pointer is set to funcaddr*10 and this is where the const should be pushed. + returnStackBasePointer: uint64(funcaddr) * 10, + compiledFunction: compiledFunction, + } + engine.callFrameStack[engine.globalContext.callFrameStackPointer] = frame + engine.globalContext.callFrameStackPointer++ + stackPointerToExpectedValue[frame.returnStackBasePointer] = expValue + } + + require.Equal(t, uint64(callFrameNums), env.callFrameStackPointer()) + + // Run code from the top frame. + env.exec(engine.callFrameTop().compiledFunction.codeSegment) + + // Check the exit status and the values on stack. + require.Equal(t, jitCallStatusCodeReturned, env.jitStatus()) + for pos, exp := range stackPointerToExpectedValue { + require.Equal(t, exp, uint32(env.stack()[pos])) + } + }) } func TestArm64Compiler_exit(t *testing.T) { @@ -142,7 +142,7 @@ func TestArm64Compiler_exit(t *testing.T) { // Build code. compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() expStackPointer := uint64(100) compiler.locationStack.sp = expStackPointer @@ -190,7 +190,7 @@ func TestArm64Compiler_compileConsts(t *testing.T) { // Build code. compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) switch op { @@ -210,8 +210,8 @@ func TestArm64Compiler_compileConsts(t *testing.T) { require.True(t, loc.onRegister()) // Release the register allocated value to the memory stack so that we can see the value after exiting. - compiler.releaseRegisterToStack(loc) - compiler.returnFunction() + compiler.compileReleaseRegisterToStack(loc) + compiler.compileReturnFunction() // Generate the code under test. code, _, _, err := compiler.compile() @@ -254,7 +254,7 @@ func TestArm64Compiler_releaseRegisterToStack(t *testing.T) { // Build code. compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Setup the location stack so that we push the const on the specified height. @@ -271,8 +271,8 @@ func TestArm64Compiler_releaseRegisterToStack(t *testing.T) { require.NoError(t, err) // Release the register allocated value to the memory stack so that we can see the value after exiting. - compiler.releaseRegisterToStack(compiler.locationStack.peek()) - compiler.returnFunction() + compiler.compileReleaseRegisterToStack(compiler.locationStack.peek()) + compiler.exit(jitCallStatusCodeReturned) // Generate the code under test. code, _, _, err := compiler.compile() @@ -295,7 +295,7 @@ func TestArm64Compiler_releaseRegisterToStack(t *testing.T) { } } -func TestArm64Compiler_loadValueOnStackToRegister(t *testing.T) { +func TestArm64Compiler_compileLoadValueOnStackToRegister(t *testing.T) { const val = 123 for _, tc := range []struct { name string @@ -313,7 +313,7 @@ func TestArm64Compiler_loadValueOnStackToRegister(t *testing.T) { // Build code. compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Setup the location stack so that we push the const on the specified height. @@ -332,26 +332,26 @@ func TestArm64Compiler_loadValueOnStackToRegister(t *testing.T) { require.True(t, loc.onStack()) // Release the stack-allocated value to register. - compiler.loadValueOnStackToRegister(loc) + compiler.compileLoadValueOnStackToRegister(loc) require.Len(t, compiler.locationStack.usedRegisters, 1) require.True(t, loc.onRegister()) // To verify the behavior, increment the value on the register. if tc.isFloat { // For float, we cannot add consts, so load the constant first. - err = compiler.emitFloatConstant(false, math.Float64bits(1)) + err = compiler.compileFloatConstant(false, math.Float64bits(1)) require.NoError(t, err) // Then, do the increment. - compiler.applyRegisterToRegisterInstruction(arm64.AFADDD, compiler.locationStack.peek().register, loc.register) + compiler.compileRegisterToRegisterInstruction(arm64.AFADDD, compiler.locationStack.peek().register, loc.register) // Delete the loaded const. compiler.locationStack.pop() } else { - compiler.applyConstToRegisterInstruction(arm64.AADD, 1, loc.register) + compiler.compileConstToRegisterInstruction(arm64.AADD, 1, loc.register) } // Release the value to the memory stack so that we can see the value after exiting. - compiler.releaseRegisterToStack(loc) - compiler.returnFunction() + compiler.compileReleaseRegisterToStack(loc) + compiler.exit(jitCallStatusCodeReturned) // Generate the code under test. code, _, _, err := compiler.compile() @@ -433,7 +433,7 @@ func TestArm64Compiler_compile_Le_Lt_Gt_Ge_Eq_Eqz_Ne(t *testing.T) { t.Run(fmt.Sprintf("x1=0x%x,x2=0x%x", x1, x2), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Emit consts operands. @@ -514,12 +514,12 @@ func TestArm64Compiler_compile_Le_Lt_Gt_Ge_Eq_Eqz_Ne(t *testing.T) { require.True(t, resultLocation.onConditionalRegister()) // Move the conditional register value to a general purpose register to verify the value. - compiler.loadConditionalRegisterToGeneralPurposeRegister(resultLocation) + compiler.compileLoadConditionalRegisterToGeneralPurposeRegister(resultLocation) require.True(t, resultLocation.onRegister()) // Release the value to the memory stack again to verify the operation. - compiler.releaseRegisterToStack(resultLocation) - compiler.returnFunction() + compiler.compileReleaseRegisterToStack(resultLocation) + compiler.compileReturnFunction() // Compile and execute the code under test. code, _, _, err := compiler.compile() @@ -668,7 +668,7 @@ func TestArm64Compiler_compile_Add_Sub_Mul(t *testing.T) { t.Run(fmt.Sprintf("x1=0x%x,x2=0x%x", x1, x2), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Emit consts operands. @@ -713,8 +713,8 @@ func TestArm64Compiler_compile_Add_Sub_Mul(t *testing.T) { } // Release the value to the memory stack again to verify the operation. - compiler.releaseRegisterToStack(resultLocation) - compiler.returnFunction() + compiler.compileReleaseRegisterToStack(resultLocation) + compiler.compileReturnFunction() // Compile and execute the code under test. code, _, _, err := compiler.compile() @@ -797,6 +797,220 @@ func TestArm64Compiler_compile_Add_Sub_Mul(t *testing.T) { } } +func TestArm64Compiler_compile_And_Or_Xor_Shl_Rotr(t *testing.T) { + for _, kind := range []wazeroir.OperationKind{ + wazeroir.OperationKindAnd, + wazeroir.OperationKindOr, + wazeroir.OperationKindXor, + wazeroir.OperationKindShl, + wazeroir.OperationKindRotl, + wazeroir.OperationKindRotr, + } { + kind := kind + t.Run(kind.String(), func(t *testing.T) { + for _, unsignedInt := range []wazeroir.UnsignedInt{ + wazeroir.UnsignedInt32, + wazeroir.UnsignedInt64, + } { + unsignedInt := unsignedInt + t.Run(unsignedInt.String(), func(t *testing.T) { + for _, values := range [][2]uint64{ + {0, 0}, {0, 1}, {1, 0}, {1, 1}, + {1 << 31, 1}, {1, 1 << 31}, {1 << 31, 1 << 31}, + {1 << 63, 1}, {1, 1 << 63}, {1 << 63, 1 << 63}, + } { + x1, x2 := values[0], values[1] + t.Run(fmt.Sprintf("x1=0x%x,x2=0x%x", x1, x2), func(t *testing.T) { + env := newJITEnvironment() + compiler := env.requireNewCompiler(t) + err := compiler.compilePreamble() + require.NoError(t, err) + + // Emit consts operands. + for _, v := range []uint64{x1, x2} { + switch unsignedInt { + case wazeroir.UnsignedInt32: + err = compiler.compileConstI32(&wazeroir.OperationConstI32{Value: uint32(v)}) + case wazeroir.UnsignedInt64: + err = compiler.compileConstI64(&wazeroir.OperationConstI64{Value: v}) + } + require.NoError(t, err) + } + + // At this point, two values exist. + require.Equal(t, uint64(2), compiler.locationStack.sp) + + // Emit the operation. + switch kind { + case wazeroir.OperationKindAnd: + err = compiler.compileAnd(&wazeroir.OperationAnd{Type: unsignedInt}) + case wazeroir.OperationKindOr: + err = compiler.compileOr(&wazeroir.OperationOr{Type: unsignedInt}) + case wazeroir.OperationKindXor: + err = compiler.compileXor(&wazeroir.OperationXor{Type: unsignedInt}) + case wazeroir.OperationKindShl: + err = compiler.compileShl(&wazeroir.OperationShl{Type: unsignedInt}) + case wazeroir.OperationKindRotl: + err = compiler.compileRotl(&wazeroir.OperationRotl{Type: unsignedInt}) + case wazeroir.OperationKindRotr: + err = compiler.compileRotr(&wazeroir.OperationRotr{Type: unsignedInt}) + } + require.NoError(t, err) + + // We consumed two values, but push the result back. + require.Equal(t, uint64(1), compiler.locationStack.sp) + resultLocation := compiler.locationStack.peek() + // Plus the result must be located on a register. + require.True(t, resultLocation.onRegister()) + // Also, the result must have an appropriate register type. + require.Equal(t, generalPurposeRegisterTypeInt, resultLocation.regType) + + // Release the value to the memory stack again to verify the operation. + compiler.compileReleaseRegisterToStack(resultLocation) + compiler.compileReturnFunction() + + // Compile and execute the code under test. + code, _, _, err := compiler.compile() + require.NoError(t, err) + env.exec(code) + + // Check the stack. + require.Equal(t, uint64(1), env.stackPointer()) + + switch kind { + case wazeroir.OperationKindAnd: + switch unsignedInt { + case wazeroir.UnsignedInt32: + require.Equal(t, uint32(x1)&uint32(x2), env.stackTopAsUint32()) + case wazeroir.UnsignedInt64: + require.Equal(t, x1&x2, env.stackTopAsUint64()) + } + case wazeroir.OperationKindOr: + switch unsignedInt { + case wazeroir.UnsignedInt32: + require.Equal(t, uint32(x1)|uint32(x2), env.stackTopAsUint32()) + case wazeroir.UnsignedInt64: + require.Equal(t, x1|x2, env.stackTopAsUint64()) + } + case wazeroir.OperationKindXor: + switch unsignedInt { + case wazeroir.UnsignedInt32: + require.Equal(t, uint32(x1)^uint32(x2), env.stackTopAsUint32()) + case wazeroir.UnsignedInt64: + require.Equal(t, x1^x2, env.stackTopAsUint64()) + } + case wazeroir.OperationKindShl: + switch unsignedInt { + case wazeroir.UnsignedInt32: + require.Equal(t, uint32(x1)<>(uint32(x2)%32), env.stackTopAsInt32()) + case wazeroir.SignedInt64: + require.Equal(t, int64(x1)>>(x2%64), env.stackTopAsInt64()) + case wazeroir.SignedUint32: + require.Equal(t, uint32(x1)>>(uint32(x2)%32), env.stackTopAsUint32()) + case wazeroir.SignedUint64: + require.Equal(t, x1>>(x2%64), env.stackTopAsUint64()) + } + }) + } + }) + } + }) +} + func TestArm64Compiler_compielePick(t *testing.T) { const pickTargetValue uint64 = 12345 op := &wazeroir.OperationPick{Depth: 1} @@ -849,7 +1063,7 @@ func TestArm64Compiler_compielePick(t *testing.T) { t.Run(tc.name, func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Set up the stack before picking. @@ -871,11 +1085,11 @@ func TestArm64Compiler_compielePick(t *testing.T) { require.Equal(t, pickTargetLocation.registerType(), pickedLocation.registerType()) // Release the value to the memory stack again to verify the operation, and then return. - compiler.releaseRegisterToStack(pickedLocation) + compiler.compileReleaseRegisterToStack(pickedLocation) if tc.isPickTargetOnRegister { - compiler.releaseRegisterToStack(pickTargetLocation) + compiler.compileReleaseRegisterToStack(pickTargetLocation) } - compiler.returnFunction() + compiler.compileReturnFunction() // Compile and execute the code under test. code, _, _, err := compiler.compile() @@ -903,7 +1117,7 @@ func TestArm64Compiler_compieleDrop(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Put existing contents on stack. @@ -919,7 +1133,7 @@ func TestArm64Compiler_compieleDrop(t *testing.T) { // After the nil range drop, the stack must remain the same. require.Equal(t, uint64(liveNum), compiler.locationStack.sp) - compiler.returnFunction() + compiler.compileReturnFunction() code, _, _, err := compiler.compile() require.NoError(t, err) @@ -935,7 +1149,7 @@ func TestArm64Compiler_compieleDrop(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Put existing contents on stack. @@ -959,9 +1173,9 @@ func TestArm64Compiler_compieleDrop(t *testing.T) { top := compiler.locationStack.peek() require.True(t, top.onRegister()) // Release the top value after drop so that we can verify the cpu itself is not mainpulated. - compiler.releaseRegisterToStack(top) + compiler.compileReleaseRegisterToStack(top) - compiler.returnFunction() + compiler.compileReturnFunction() code, _, _, err := compiler.compile() require.NoError(t, err) @@ -984,7 +1198,7 @@ func TestArm64Compiler_compieleDrop(t *testing.T) { eng := env.engine() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Put existing contents except the top on stack @@ -1009,9 +1223,9 @@ func TestArm64Compiler_compieleDrop(t *testing.T) { require.True(t, compiler.locationStack.peek().onRegister()) // Release all register values so that we can verify the register allocated values. - err = compiler.releaseAllRegistersToStack() + err = compiler.compileReleaseAllRegistersToStack() require.NoError(t, err) - compiler.returnFunction() + compiler.compileReturnFunction() code, _, _, err := compiler.compile() require.NoError(t, err) @@ -1070,7 +1284,7 @@ func TestArm64Compiler_compileBr(t *testing.T) { t.Run("return", func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Branch into nil label is interpreted as return. See BranchTarget.IsReturnTarget @@ -1099,7 +1313,7 @@ func TestArm64Compiler_compileBr(t *testing.T) { nop.As = obj.ANOP compiler.addInstruction(nop) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) err = compiler.compileBr(&wazeroir.OperationBr{Target: &wazeroir.BranchTarget{Label: backwardLabel}}) @@ -1116,7 +1330,7 @@ func TestArm64Compiler_compileBr(t *testing.T) { // .backwardLabel: // exit jitCallStatusCodeReturned // nop - // ... code from emitPreamble() + // ... code from compilePreamble() // br .backwardLabel // exit jitCallStatusCodeUnreachable // @@ -1127,7 +1341,7 @@ func TestArm64Compiler_compileBr(t *testing.T) { t.Run("forward br", func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) // Emit the forward br, meaning that handle Br instruction where the target label hasn't been compiled yet. @@ -1147,7 +1361,7 @@ func TestArm64Compiler_compileBr(t *testing.T) { // The generated code looks like this: // - // ... code from emitPreamble() + // ... code from compilePreamble() // br .forwardLabel // exit jitCallStatusCodeUnreachable // .forwardLabel: @@ -1206,8 +1420,6 @@ func TestArm64Compiler_compileBrIf(t *testing.T) { require.NoError(t, err) }, }, - // {name: "EQ"} TODO: after compileEq support - // {name: "NE"} TODO: after compileNe support { name: "HS", setupFunc: func(t *testing.T, compiler *arm64Compiler, shoulGoElse bool) { @@ -1299,6 +1511,30 @@ func TestArm64Compiler_compileBrIf(t *testing.T) { require.NoError(t, err) }, }, + { + name: "EQ", + setupFunc: func(t *testing.T, compiler *arm64Compiler, shoulGoElse bool) { + x1, x2 := uint32(1), uint32(1) + if shoulGoElse { + x2++ + } + requirePushTwoInt32Consts(t, x1, x2, compiler) + err := compiler.compileEq(&wazeroir.OperationEq{Type: wazeroir.UnsignedTypeI32}) + require.NoError(t, err) + }, + }, + { + name: "NE", + setupFunc: func(t *testing.T, compiler *arm64Compiler, shoulGoElse bool) { + x1, x2 := uint32(1), uint32(2) + if shoulGoElse { + x2 = x1 + } + requirePushTwoInt32Consts(t, x1, x2, compiler) + err := compiler.compileNe(&wazeroir.OperationNe{Type: wazeroir.UnsignedTypeI32}) + require.NoError(t, err) + }, + }, } { tc := tc t.Run(tc.name, func(t *testing.T) { @@ -1307,7 +1543,7 @@ func TestArm64Compiler_compileBrIf(t *testing.T) { t.Run(fmt.Sprintf("should_goto_else=%v", shouldGoToElse), func(t *testing.T) { env := newJITEnvironment() compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() + err := compiler.compilePreamble() require.NoError(t, err) tc.setupFunc(t, compiler, shouldGoToElse) @@ -1330,7 +1566,7 @@ func TestArm64Compiler_compileBrIf(t *testing.T) { // The generated code looks like this: // - // ... code from emitPreamble() + // ... code from compilePreamble() // br_if .then, .else // exit $unreachableStatus // .then: @@ -1351,3 +1587,159 @@ func TestArm64Compiler_compileBrIf(t *testing.T) { }) } } + +func TestArm64Compiler_readInstructionAddress(t *testing.T) { + t.Run("target instruction not found", func(t *testing.T) { + env := newJITEnvironment() + compiler := env.requireNewCompiler(t) + + err := compiler.compilePreamble() + require.NoError(t, err) + + // Set the acquisition target instruction to the one after JMP. + compiler.readInstructionAddress(obj.AJMP, reservedRegisterForTemporary) + + compiler.exit(jitCallStatusCodeReturned) + + // If generate the code without JMP after readInstructionAddress, + // the call back added must return error. + _, _, _, err = compiler.compile() + require.Error(t, err) + require.Contains(t, err.Error(), "target instruction not found") + }) + t.Run("too large offset", func(t *testing.T) { + env := newJITEnvironment() + compiler := env.requireNewCompiler(t) + + err := compiler.compilePreamble() + require.NoError(t, err) + + // Set the acquisition target instruction to the one after RET. + compiler.readInstructionAddress(obj.ARET, reservedRegisterForTemporary) + + // Add many instruction between the target and readInstructionAddress. + for i := 0; i < 100; i++ { + compiler.compileConstI32(&wazeroir.OperationConstI32{Value: 10}) + } + + ret := compiler.newProg() + ret.As = obj.ARET + ret.To.Type = obj.TYPE_REG + ret.To.Reg = reservedRegisterForTemporary + compiler.compileReturnFunction() + + // If generate the code with too many instruction between ADR and + // the target, compile must fail. + _, _, _, err = compiler.compile() + require.Error(t, err) + require.Contains(t, err.Error(), "too large offset") + }) + t.Run("ok", func(t *testing.T) { + env := newJITEnvironment() + compiler := env.requireNewCompiler(t) + + err := compiler.compilePreamble() + require.NoError(t, err) + + // Set the acquisition target instruction to the one after RET, + // and read the absolute address into destinationRegister. + const addressReg = reservedRegisterForTemporary + compiler.readInstructionAddress(obj.ARET, addressReg) + + // Branch to the instruction after RET below via the absolute + // address stored in destinationRegister. + compiler.compileUnconditionalBranchToAddressOnRegister(addressReg) + + // If we fail to branch, we reach here and exit with unreachable status, + // so the assertion would fail. + compiler.exit(jitCallStatusCodeUnreachable) + + // This could be the read instruction target as this is the + // right after RET. Therefore, the branch instruction above + // must target here. + err = compiler.compileReturnFunction() + require.NoError(t, err) + + code, _, _, err := compiler.compile() + require.NoError(t, err) + + env.exec(code) + + require.Equal(t, jitCallStatusCodeReturned, env.jitStatus()) + }) +} + +func TestArm64Compiler_compieleCall(t *testing.T) { + t.Run("need to grow call frame stack", func(t *testing.T) { + t.Skip("TODO") + }) + t.Run("callframe stack ok", func(t *testing.T) { + env := newJITEnvironment() + engine := env.engine() + expectedValue := uint32(0) + + // Emit the call target function. + const numCalls = 10 + targetFunctionType := &wasm.FunctionType{ + Params: []wasm.ValueType{wasm.ValueTypeI32}, + Results: []wasm.ValueType{wasm.ValueTypeI32}, + } + for i := 0; i < numCalls; i++ { + // Each function takes one arguments, adds the value with 100 + i and returns the result. + addTargetValue := uint32(100 + i) + expectedValue += addTargetValue + + compiler := env.requireNewCompiler(t) + compiler.f = &wasm.FunctionInstance{FunctionType: &wasm.TypeInstance{Type: targetFunctionType}} + + err := compiler.compilePreamble() + require.NoError(t, err) + + err = compiler.compileConstI32(&wazeroir.OperationConstI32{Value: uint32(addTargetValue)}) + require.NoError(t, err) + err = compiler.compileAdd(&wazeroir.OperationAdd{Type: wazeroir.UnsignedTypeI32}) + require.NoError(t, err) + err = compiler.compileReturnFunction() + require.NoError(t, err) + + code, _, _, err := compiler.compile() + require.NoError(t, err) + engine.addCompiledFunction(wasm.FunctionAddress(i), &compiledFunction{ + codeSegment: code, + codeInitialAddress: uintptr(unsafe.Pointer(&code[0])), + }) + } + + // Now we start building the caller's code. + compiler := env.requireNewCompiler(t) + err := compiler.compilePreamble() + require.NoError(t, err) + + const initialValue = 100 + expectedValue += initialValue + err = compiler.compileConstI32(&wazeroir.OperationConstI32{Value: 0}) // Dummy value so the base pointer would be non-trivial for callees. + require.NoError(t, err) + err = compiler.compileConstI32(&wazeroir.OperationConstI32{Value: initialValue}) + require.NoError(t, err) + + // Call all the built functions. + for i := 0; i < numCalls; i++ { + err = compiler.compileCallFunction(wasm.FunctionAddress(i), targetFunctionType) + require.NoError(t, err) + } + + err = compiler.compileReturnFunction() + require.NoError(t, err) + + code, _, _, err := compiler.compile() + require.NoError(t, err) + + env.exec(code) + + // Check status and returned values. + require.Equal(t, jitCallStatusCodeReturned, env.jitStatus()) + require.Equal(t, uint64(2), env.stackPointer()) // Must be 2 (dummy value + the calculation results) + require.Equal(t, uint64(0), env.stackBasePointer()) + require.Equal(t, expectedValue, env.stackTopAsUint32()) + }) +} diff --git a/wasm/module.go b/wasm/module.go index 6009dfe38f..2965fe93c0 100644 --- a/wasm/module.go +++ b/wasm/module.go @@ -1,6 +1,8 @@ package wasm -// DecodeModule parses the configured source into a wasm.Module. This function returns when the source is exhausted or +import "fmt" + +// DecodeModule parses the configured source into a Module. This function returns when the source is exhausted or // an error occurs. The result can be initialized for use via Store.Instantiate. // // Here's a description of the return values: @@ -90,6 +92,7 @@ type Module struct { // (Store.Instantiate). // // Note: there are no unique constraints relating to the two-level namespace of Import.Module and Import.Name. + // // Note: In the Binary Format, this is SectionIDImport. // // See https://www.w3.org/TR/wasm-core-1/#import-section%E2%91%A0 @@ -103,6 +106,7 @@ type Module struct { // // Note: FunctionSection is index correlated with the CodeSection. If given the same position, ex. 2, a function // type is at TypeSection[FunctionSection[2]], while its locals and body are at CodeSection[2]. + // // Note: In the Binary Format, this is SectionIDFunction. // // See https://www.w3.org/TR/wasm-core-1/#function-section%E2%91%A0 @@ -115,7 +119,8 @@ type Module struct { // this module at TableSection[0]. // // Note: Version 1.0 (MVP) of the WebAssembly spec allows at most one table definition per module, so the length of - // the TableSection can be zero or one. + // the TableSection can be zero or one, and can only be one if there is no ImportKindTable. + // // Note: In the Binary Format, this is SectionIDTable. // // See https://www.w3.org/TR/wasm-core-1/#table-section%E2%91%A0 @@ -128,7 +133,8 @@ type Module struct { // this module at TableSection[0]. // // Note: Version 1.0 (MVP) of the WebAssembly spec allows at most one memory definition per module, so the length of - // the MemorySection can be zero or one. + // the MemorySection can be zero or one, and can only be one if there is no ImportKindMemory. + // // Note: In the Binary Format, this is SectionIDMemory. // // See https://www.w3.org/TR/wasm-core-1/#memory-section%E2%91%A0 @@ -147,6 +153,8 @@ type Module struct { // ExportSection contains each export defined in this module. // + // Note: In the Binary Format, this is SectionIDExport. + // // See https://www.w3.org/TR/wasm-core-1/#exports%E2%91%A0 ExportSection map[string]*Export @@ -154,6 +162,7 @@ type Module struct { // // Note: The index here is not the position in the FunctionSection, rather in the function index namespace, which // begins with imported functions. + // // Note: In the Binary Format, this is SectionIDStart. // // See https://www.w3.org/TR/wasm-core-1/#start-section%E2%91%A0 @@ -174,20 +183,12 @@ type Module struct { // NameSection is set when the SectionIDCustom "name" was successfully decoded from the binary format. // - // Note: This is the only SectionIDCustom defined in the WebAssembly 1.0 (MVP) Binary Format. Others are in - // CustomSections + // Note: This is the only SectionIDCustom defined in the WebAssembly 1.0 (MVP) Binary Format. + // Others are skipped as they are not used in wazero. // // See https://www.w3.org/TR/wasm-core-1/#name-section%E2%91%A0 - NameSection *NameSection - - // CustomSections is set when at least one non-standard, or otherwise unsupported custom section was found in the - // binary format. - // - // Note: This never contains a "name" because that is standard and parsed into the NameSection. - // Note: In the Binary Format, this is SectionIDCode. - // // See https://www.w3.org/TR/wasm-core-1/#custom-section%E2%91%A0 - CustomSections map[string][]byte + NameSection *NameSection } // Index is the offset in an index namespace, not necessarily an absolute position in a Module section. This is because @@ -226,7 +227,7 @@ const ( ValueTypeF64 ValueType = 0x7c ) -// ValuTypeName returns the type name of the given ValueType as a string. +// ValueTypeName returns the type name of the given ValueType as a string. // These type names match the names used in the WebAssembly text format. // Note that ValueTypeName returns "unknown", if an undefined ValueType value is passed. func ValueTypeName(t ValueType) string { @@ -461,3 +462,46 @@ func (m *Module) allDeclarations() (functions []Index, globals []*GlobalType, me tables = append(tables, m.TableSection...) return } + +// SectionElementCount returns the count of elements in a given section ID +// +// For example... +// * SectionIDType returns the count of FunctionType +// * SectionIDCustom returns one if the NameSection is present +// * SectionIDExport returns the count of unique export names +func (m *Module) SectionElementCount(sectionID SectionID) uint32 { // element as in vector elements! + switch sectionID { + case SectionIDCustom: + if m.NameSection != nil { + return 1 + } + return 0 + case SectionIDType: + return uint32(len(m.TypeSection)) + case SectionIDImport: + return uint32(len(m.ImportSection)) + case SectionIDFunction: + return uint32(len(m.FunctionSection)) + case SectionIDTable: + return uint32(len(m.TableSection)) + case SectionIDMemory: + return uint32(len(m.MemorySection)) + case SectionIDGlobal: + return uint32(len(m.GlobalSection)) + case SectionIDExport: + return uint32(len(m.ExportSection)) + case SectionIDStart: + if m.StartSection != nil { + return 1 + } + return 0 + case SectionIDElement: + return uint32(len(m.ElementSection)) + case SectionIDCode: + return uint32(len(m.CodeSection)) + case SectionIDData: + return uint32(len(m.DataSection)) + default: + panic(fmt.Errorf("BUG: unknown section: %d", sectionID)) + } +} diff --git a/wasm/module_test.go b/wasm/module_test.go index 7789a25eec..13a463adc2 100644 --- a/wasm/module_test.go +++ b/wasm/module_test.go @@ -203,3 +203,103 @@ func TestModule_allDeclarations(t *testing.T) { }) } } + +func TestModule_SectionSize(t *testing.T) { + i32, f32 := ValueTypeI32, ValueTypeF32 + zero := uint32(0) + empty := &ConstantExpression{Opcode: OpcodeI32Const, Data: []byte{0x00}} + + tests := []struct { + name string + input *Module + expected map[string]uint32 + }{ + { + name: "empty", + input: &Module{}, + expected: map[string]uint32{}, + }, + { + name: "only name section", + input: &Module{NameSection: &NameSection{ModuleName: "simple"}}, + expected: map[string]uint32{"custom": 1}, + }, + { + name: "type section", + input: &Module{ + TypeSection: []*FunctionType{ + {}, + {Params: []ValueType{i32, i32}, Results: []ValueType{i32}}, + {Params: []ValueType{i32, i32, i32, i32}, Results: []ValueType{i32}}, + }, + }, + expected: map[string]uint32{"type": 3}, + }, + { + name: "type and import section", + input: &Module{ + TypeSection: []*FunctionType{ + {Params: []ValueType{i32, i32}, Results: []ValueType{i32}}, + {Params: []ValueType{f32, f32}, Results: []ValueType{f32}}, + }, + ImportSection: []*Import{ + { + Module: "Math", Name: "Mul", + Kind: ImportKindFunc, + DescFunc: 1, + }, { + Module: "Math", Name: "Add", + Kind: ImportKindFunc, + DescFunc: 0, + }, + }, + }, + expected: map[string]uint32{"import": 2, "type": 2}, + }, + { + name: "type function and start section", + input: &Module{ + TypeSection: []*FunctionType{{}}, + FunctionSection: []Index{0}, + CodeSection: []*Code{ + {Body: []byte{OpcodeLocalGet, 0, OpcodeLocalGet, 1, OpcodeI32Add, OpcodeEnd}}, + }, + ExportSection: map[string]*Export{ + "AddInt": {Name: "AddInt", Kind: ExportKindFunc, Index: Index(0)}, + }, + StartSection: &zero, + }, + expected: map[string]uint32{"code": 1, "export": 1, "function": 1, "start": 1, "type": 1}, + }, + { + name: "memory and data", + input: &Module{ + MemorySection: []*MemoryType{{Min: 1}}, + DataSection: []*DataSegment{{MemoryIndex: 0, OffsetExpression: empty}}, + }, + expected: map[string]uint32{"data": 1, "memory": 1}, + }, + { + name: "table and element", + input: &Module{ + TableSection: []*TableType{{ElemType: 0x70, Limit: &LimitsType{Min: 1}}}, + ElementSection: []*ElementSegment{{TableIndex: 0, OffsetExpr: empty}}, + }, + expected: map[string]uint32{"element": 1, "table": 1}, + }, + } + + for _, tt := range tests { + tc := tt + + t.Run(tc.name, func(t *testing.T) { + actual := map[string]uint32{} + for i := SectionID(0); i <= SectionIDData; i++ { + if size := tc.input.SectionElementCount(i); size > 0 { + actual[SectionIDName(i)] = size + } + } + require.Equal(t, tc.expected, actual) + }) + } +} diff --git a/wasm/store.go b/wasm/store.go index d23a7cde34..3771303c0a 100644 --- a/wasm/store.go +++ b/wasm/store.go @@ -288,17 +288,21 @@ func (s *Store) Instantiate(module *Module, name string) error { for i, f := range instance.Functions { if err := s.engine.Compile(f); err != nil { - return fmt.Errorf("compilation failed at index %d/%d: %v", i, len(module.FunctionSection)-1, err) + idx := module.SectionElementCount(SectionIDFunction) - 1 + return fmt.Errorf("compilation failed at index %d/%d: %v", i, idx, err) } } // Check the start function is valid. - if startIndex := module.StartSection; startIndex != nil { - index := *startIndex - if int(index) >= len(instance.Functions) { - return fmt.Errorf("invalid start function index: %d", index) + // TODO: this should be verified during decode so that errors have the correct source positions + var startFunction *FunctionInstance + if module.StartSection != nil { + startIndex := *module.StartSection + if startIndex >= uint32(len(instance.Functions)) { + return fmt.Errorf("invalid start function index: %d", startIndex) } - ft := instance.Functions[index].FunctionType + startFunction = instance.Functions[startIndex] + ft := startFunction.FunctionType if len(ft.Type.Params) != 0 || len(ft.Type.Results) != 0 { return fmt.Errorf("start function must have the empty function type") } @@ -308,10 +312,8 @@ func (s *Store) Instantiate(module *Module, name string) error { rollbackFuncs = nil // Execute the start function. - ctx := context.Background() - if startIndex := module.StartSection; startIndex != nil { - f := instance.Functions[*startIndex] - if _, err := s.engine.Call(ctx, f); err != nil { + if startFunction != nil { + if _, err = s.engine.Call(context.Background(), startFunction); err != nil { return fmt.Errorf("calling start function failed: %v", err) } } @@ -596,9 +598,9 @@ func (s *Store) buildFunctionInstances(module *Module, target *ModuleInstance) ( n, nLen := 0, len(functionNames) for codeIndex, typeIndex := range module.FunctionSection { - if typeIndex >= uint32(len(module.TypeSection)) { + if typeIndex >= module.SectionElementCount(SectionIDType) { return rollbackFuncs, fmt.Errorf("function type index out of range") - } else if codeIndex >= len(module.CodeSection) { + } else if uint32(codeIndex) >= module.SectionElementCount(SectionIDCode) { return rollbackFuncs, fmt.Errorf("code index out of range") } @@ -629,7 +631,8 @@ func (s *Store) buildFunctionInstances(module *Module, target *ModuleInstance) ( } if err := validateFunctionInstance(f, funcs, globals, mems, tables, module.TypeSection, maximumValuesOnStack); err != nil { - return rollbackFuncs, fmt.Errorf("invalid function '%s' (%d/%d): %v", f.Name, codeIndex, len(module.FunctionSection)-1, err) + idx := module.SectionElementCount(SectionIDFunction) - 1 + return rollbackFuncs, fmt.Errorf("invalid function '%s' (%d/%d): %v", f.Name, codeIndex, idx, err) } err = s.addFunctionInstance(f) @@ -680,8 +683,8 @@ func (s *Store) buildMemoryInstances(module *Module, target *ModuleInstance) (ro } size := uint64(offset) + uint64(len(d.Init)) - maxPage := uint32(MemoryMaxPages) - if int(d.MemoryIndex) < len(module.MemorySection) && module.MemorySection[d.MemoryIndex].Max != nil { + maxPage := MemoryMaxPages + if d.MemoryIndex < module.SectionElementCount(SectionIDMemory) && module.MemorySection[d.MemoryIndex].Max != nil { maxPage = *module.MemorySection[d.MemoryIndex].Max } if size > memoryPagesToBytesNum(maxPage) { @@ -734,7 +737,7 @@ func (s *Store) buildTableInstances(module *Module, target *ModuleInstance) (rol size := offset + len(elem.Init) max := uint32(math.MaxUint32) - if int(elem.TableIndex) < len(module.TableSection) && module.TableSection[elem.TableIndex].Limit.Max != nil { + if elem.TableIndex < module.SectionElementCount(SectionIDTable) && module.TableSection[elem.TableIndex].Limit.Max != nil { max = *module.TableSection[elem.TableIndex].Limit.Max } @@ -772,7 +775,7 @@ func (s *Store) buildTableInstances(module *Module, target *ModuleInstance) (rol } func (s *Store) buildExportInstances(module *Module, target *ModuleInstance) (rollbackFuncs []func(), err error) { - target.Exports = make(map[string]*ExportInstance, len(module.ExportSection)) + target.Exports = make(map[string]*ExportInstance, module.SectionElementCount(SectionIDExport)) for name, exp := range module.ExportSection { index := exp.Index var ei *ExportInstance diff --git a/wasm/text/decoder.go b/wasm/text/decoder.go index 521f0fff5d..1d5f89d7a8 100644 --- a/wasm/text/decoder.go +++ b/wasm/text/decoder.go @@ -18,17 +18,32 @@ const ( positionParam positionResult positionModule - positionModuleType - positionModuleImport - positionModuleImportFunc - positionModuleFunc - positionModuleMemory - positionModuleExport - positionModuleExportFunc - positionModuleExportMemory - positionModuleStart + positionImport + positionImportFunc + positionMemory + positionExport + positionExportFunc + positionExportMemory + positionStart ) +type callbackPosition byte + +const ( + // callbackPositionUnhandledToken is set on a token besides a paren. + callbackPositionUnhandledToken callbackPosition = iota + // callbackPositionUnhandledField is at the field name (tokenKeyword) which isn't "type", "param" or "result" + callbackPositionUnhandledField + // callbackPositionEndField is at the end (tokenRParen) of the field enclosing the type use. + callbackPositionEndField +) + +// moduleParser parses a single wasm.Module from WebAssembly 1.0 (MVP) Text format. +// +// Note: The indexNamespace of wasm.SectionIDMemory and wasm.SectionIDTable allow up-to-one item. For example, you +// cannot define both one import and one module-defined memory, rather one or the other (or none). Even if these rules +// are also enforced in module instantiation, they are also enforced here, to allow relevant source line/col in errors. +// See https://www.w3.org/TR/wasm-core-1/#modules%E2%91%A3 type moduleParser struct { // source is the entire WebAssembly text format source code being parsed. source []byte @@ -153,27 +168,38 @@ func (p *moduleParser) beginModuleField(tok tokenType, tokenBytes []byte, _, _ u if tok == tokenKeyword { switch string(tokenBytes) { case "type": - p.pos = positionModuleType + p.pos = positionType return p.typeParser.begin, nil case "import": - p.pos = positionModuleImport + p.pos = positionImport return p.parseImportModule, nil case "func": - p.pos = positionModuleFunc - p.funcParser.currentIdx = wasm.Index(len(p.module.FunctionSection)) + p.pos = positionFunc + p.funcParser.currentIdx = p.module.SectionElementCount(wasm.SectionIDFunction) return p.funcParser.begin, nil + case "table": + return nil, errors.New("TODO: table") case "memory": - p.pos = positionModuleMemory + if p.memoryNamespace.count > 0 { + return nil, moreThanOneInvalidInSection(wasm.SectionIDMemory) + } + p.pos = positionMemory return p.memoryParser.begin, nil + case "global": + return nil, errors.New("TODO: global") case "export": - p.pos = positionModuleExport + p.pos = positionExport return p.parseExportName, nil case "start": - if p.module.StartSection != nil { - return nil, errors.New("redundant start") + if p.module.SectionElementCount(wasm.SectionIDStart) > 0 { + return nil, moreThanOneInvalid("start") } - p.pos = positionModuleStart + p.pos = positionStart return p.parseStart, nil + case "elem": + return nil, errors.New("TODO: elem") + case "data": + return nil, errors.New("TODO: data") default: return nil, unexpectedFieldName(tokenBytes) } @@ -243,9 +269,9 @@ func (p *moduleParser) parseImportModule(tok tokenType, tokenBytes []byte, _, _ // parseImportName returns parseImport after recording the import name, or errs if it couldn't be read. // // Ex. Import name is present `(import "Math" "PI" (func (result f32)))` -// starts here --^ ^ -// records PI --^ | -// parseImport resumes here --+ +// starts here --^^ ^ +// records PI --+ | +// parseImport resumes here --+ // // Ex. Imported function name is absent `(import "Math" (func (result f32)))` // errs here --+ @@ -284,7 +310,10 @@ func (p *moduleParser) beginImportDesc(tok tokenType, tokenBytes []byte, _, _ ui switch string(tokenBytes) { case "func": - p.pos = positionModuleImportFunc + if p.module.SectionElementCount(wasm.SectionIDFunction) > 0 { + return nil, importAfterModuleDefined(wasm.SectionIDFunction) + } + p.pos = positionImportFunc return p.parseImportFuncID, nil case "table", "memory", "global": return nil, fmt.Errorf("TODO: %s", tokenBytes) @@ -303,17 +332,18 @@ func (p *moduleParser) beginImportDesc(tok tokenType, tokenBytes []byte, _, _ ui // calls parseImportFunc here --^ func (p *moduleParser) parseImportFuncID(tok tokenType, tokenBytes []byte, line, col uint32) (tokenParser, error) { if tok == tokenID { // Ex. $main - id, err := p.funcNamespace.setID(tokenBytes) - if err != nil { + if name, err := p.funcNamespace.setID(tokenBytes); err != nil { return nil, err + } else { + p.addFunctionName(name) } - p.addFunctionName(id) return p.parseImportFunc, nil } return p.parseImportFunc(tok, tokenBytes, line, col) } -// addFunctionName appends the current imported or module-defined function name to the wasm.NameSection +// addFunctionName appends the current imported or module-defined function name to the wasm.NameSection iff it is not +// empty. func (p *moduleParser) addFunctionName(name string) { if name == "" { return // there's no value in an empty name @@ -331,7 +361,7 @@ func (p *moduleParser) addFunctionName(name string) { // Ex. If there is no signature `(import "" "main" (func))` // calls onImportFunc here ---^ func (p *moduleParser) parseImportFunc(tok tokenType, tokenBytes []byte, line, col uint32) (tokenParser, error) { - idx := wasm.Index(len(p.module.ImportSection)) + idx := p.module.SectionElementCount(wasm.SectionIDImport) if tok == tokenID { // Ex. (func $main $main) return nil, fmt.Errorf("redundant ID %s", tokenBytes) } @@ -342,7 +372,7 @@ func (p *moduleParser) parseImportFunc(tok tokenType, tokenBytes []byte, line, c // onImportFunc records the type index and local names of the current imported function, and increments // funcNamespace as it is shared across imported and module-defined functions. Finally, this returns parseImportEnd to // the current import into the ImportSection. -func (p *moduleParser) onImportFunc(typeIdx wasm.Index, paramNames wasm.NameMap, pos onTypeUsePosition, tok tokenType, tokenBytes []byte, _, _ uint32) (tokenParser, error) { +func (p *moduleParser) onImportFunc(typeIdx wasm.Index, paramNames wasm.NameMap, pos callbackPosition, tok tokenType, tokenBytes []byte, _, _ uint32) (tokenParser, error) { i := p.currentModuleField.(*wasm.Import) i.Kind = wasm.ImportKindFunc i.DescFunc = typeIdx @@ -351,12 +381,12 @@ func (p *moduleParser) onImportFunc(typeIdx wasm.Index, paramNames wasm.NameMap, p.funcNamespace.count++ switch pos { - case onTypeUseUnhandledToken: + case callbackPositionUnhandledToken: return nil, unexpectedToken(tok, tokenBytes) - case onTypeUseUnhandledField: + case callbackPositionUnhandledField: return nil, unexpectedFieldName(tokenBytes) - case onTypeUseEndField: - p.pos = positionModuleImport + case callbackPositionEndField: + p.pos = positionImport return p.parseImportEnd, nil } return p.parseImportFuncEnd, nil @@ -368,7 +398,7 @@ func (p *moduleParser) parseImportFuncEnd(tok tokenType, tokenBytes []byte, _, _ return nil, unexpectedToken(tok, tokenBytes) } - p.pos = positionModuleImport + p.pos = positionImport return p.parseImportEnd, nil } @@ -431,7 +461,7 @@ func (p *moduleParser) parseExportName(tok tokenType, tokenBytes []byte, _, _ ui case tokenString: // Ex. "" or "PI" name := string(tokenBytes[1 : len(tokenBytes)-1]) // strip quotes if _, ok := p.module.ExportSection[name]; ok { - return nil, fmt.Errorf("duplicate name %q", name) + return nil, fmt.Errorf("%q already exported", name) } p.currentModuleField = &wasm.Export{Name: name} return p.parseExport, nil @@ -464,10 +494,10 @@ func (p *moduleParser) beginExportDesc(tok tokenType, tokenBytes []byte, _, _ ui switch string(tokenBytes) { case "func": - p.pos = positionModuleExportFunc + p.pos = positionExportFunc return p.parseExportDesc, nil case "memory": - p.pos = positionModuleExportMemory + p.pos = positionExportMemory return p.parseExportDesc, nil case "table", "global": return nil, fmt.Errorf("TODO: %s", tokenBytes) @@ -481,16 +511,16 @@ func (p *moduleParser) parseExportDesc(tok tokenType, tokenBytes []byte, line, c var namespace *indexNamespace e := p.currentModuleField.(*wasm.Export) switch p.pos { - case positionModuleExportFunc: + case positionExportFunc: e.Kind = wasm.ExportKindFunc namespace = p.funcNamespace - case positionModuleExportMemory: + case positionExportMemory: e.Kind = wasm.ExportKindMemory namespace = p.memoryNamespace default: panic(fmt.Errorf("BUG: unhandled parsing state on parseExportDesc: %v", p.pos)) } - eIdx := wasm.Index(len(p.module.ExportSection)) + eIdx := p.module.SectionElementCount(wasm.SectionIDExport) typeIdx, resolved, err := namespace.parseIndex(wasm.SectionIDExport, eIdx, 0, tok, tokenBytes, line, col) if err != nil { return nil, err @@ -514,7 +544,7 @@ func (p *moduleParser) parseExportDescEnd(tok tokenType, tokenBytes []byte, _, _ case tokenUN, tokenID: return nil, errors.New("redundant index") case tokenRParen: - p.pos = positionModuleExport + p.pos = positionExport return p.parseExportEnd, nil default: return nil, unexpectedToken(tok, tokenBytes) @@ -612,10 +642,10 @@ func (p *moduleParser) resolveFunctionIndices(module *wasm.Module) error { // This errs if any type index is unresolved, out of range or mismatches an inlined type use signature. func (p *moduleParser) resolveTypeUses(module *wasm.Module) error { inlinedToRealIdx := p.addInlinedTypes() - return p.resolveInlined(module, inlinedToRealIdx) + return p.resolveInlinedTypes(module, inlinedToRealIdx) } -func (p *moduleParser) resolveInlined(module *wasm.Module, inlinedToRealIdx map[wasm.Index]wasm.Index) error { +func (p *moduleParser) resolveInlinedTypes(module *wasm.Module, inlinedToRealIdx map[wasm.Index]wasm.Index) error { // Now look for all the uses of the inlined types and apply the mapping above for _, i := range p.typeUseParser.inlinedTypeIndices { switch i.section { @@ -699,28 +729,28 @@ func (p *moduleParser) errorContext() string { return "" case positionModule: return "module" - case positionModuleType: - idx := wasm.Index(len(p.module.TypeSection)) + case positionType: + idx := p.module.SectionElementCount(wasm.SectionIDType) return fmt.Sprintf("module.type[%d]%s", idx, p.typeParser.errorContext()) - case positionModuleImport, positionModuleImportFunc: // TODO: table, memory or global - idx := wasm.Index(len(p.module.ImportSection)) - if p.pos == positionModuleImport { + case positionImport, positionImportFunc: // TODO: table, memory or global + idx := p.module.SectionElementCount(wasm.SectionIDImport) + if p.pos == positionImport { return fmt.Sprintf("module.import[%d]", idx) } return fmt.Sprintf("module.import[%d].func%s", idx, p.typeUseParser.errorContext()) - case positionModuleFunc: - idx := wasm.Index(len(p.module.FunctionSection)) + case positionFunc: + idx := p.module.SectionElementCount(wasm.SectionIDFunction) return fmt.Sprintf("module.func[%d]%s", idx, p.typeUseParser.errorContext()) - case positionModuleMemory: - idx := wasm.Index(len(p.module.MemorySection)) + case positionMemory: + idx := p.module.SectionElementCount(wasm.SectionIDMemory) return fmt.Sprintf("module.memory[%d]", idx) - case positionModuleExport, positionModuleExportFunc: // TODO: table, memory or global - idx := wasm.Index(len(p.module.ExportSection)) - if p.pos == positionModuleExport { + case positionExport, positionExportFunc: // TODO: table, memory or global + idx := p.module.SectionElementCount(wasm.SectionIDExport) + if p.pos == positionExport { return fmt.Sprintf("module.export[%d]", idx) } return fmt.Sprintf("module.export[%d].func", idx) - case positionModuleStart: + case positionStart: return "module.start" default: // parserPosition is an enum, we expect to have handled all cases above. panic if we didn't panic(fmt.Errorf("BUG: unhandled parsing state on errorContext: %v", p.pos)) diff --git a/wasm/text/decoder_test.go b/wasm/text/decoder_test.go index 43ab766e6b..5b08370367 100644 --- a/wasm/text/decoder_test.go +++ b/wasm/text/decoder_test.go @@ -1343,34 +1343,28 @@ func TestDecodeModule(t *testing.T) { }, }, { - name: "export different memory - numeric", + name: "export memory - numeric", input: `(module (memory 0) - (memory 1) (export "foo" (memory 0)) - (export "bar" (memory 1)) )`, expected: &wasm.Module{ - MemorySection: []*wasm.MemoryType{{Min: 0}, {Min: 1}}, + MemorySection: []*wasm.MemoryType{{Min: 0}}, ExportSection: map[string]*wasm.Export{ "foo": {Name: "foo", Kind: wasm.ExportKindMemory, Index: 0}, - "bar": {Name: "bar", Kind: wasm.ExportKindMemory, Index: 1}, }, }, }, { - name: "export different memory - numeric - late", + name: "export memory - numeric - late", input: `(module (export "foo" (memory 0)) - (export "bar" (memory 1)) (memory 0) - (memory 1) )`, expected: &wasm.Module{ - MemorySection: []*wasm.MemoryType{{Min: 0}, {Min: 1}}, + MemorySection: []*wasm.MemoryType{{Min: 0}}, ExportSection: map[string]*wasm.Export{ "foo": {Name: "foo", Kind: wasm.ExportKindMemory, Index: 0}, - "bar": {Name: "bar", Kind: wasm.ExportKindMemory, Index: 1}, }, }, }, @@ -1396,15 +1390,13 @@ func TestDecodeModule(t *testing.T) { { name: "export memory - ID", input: `(module - (memory 1) - (memory $mem 2) - (memory $mam 3) + (memory $mem 1) (export "memory" (memory $mem)) )`, expected: &wasm.Module{ - MemorySection: []*wasm.MemoryType{{Min: 1}, {Min: 2}, {Min: 3}}, + MemorySection: []*wasm.MemoryType{{Min: 1}}, ExportSection: map[string]*wasm.Export{ - "memory": {Name: "memory", Kind: wasm.ExportKindMemory, Index: 2}, + "memory": {Name: "memory", Kind: wasm.ExportKindMemory, Index: 0}, }, }, }, @@ -1780,6 +1772,11 @@ func TestParseModule_Errors(t *testing.T) { input: "(module (import \"foo\" \"bar\" (func (type $v_v))))", expectedErr: "1:41: unknown ID $v_v", }, + { + name: "import func after func", + input: "(module (func) (import \"\" \"\" (func)))", + expectedErr: "1:31: import after module-defined function in module.import[0]", + }, { name: "func invalid name", input: "(module (func baz)))", @@ -1833,7 +1830,7 @@ func TestParseModule_Errors(t *testing.T) { { name: "func duplicate result", input: "(module (func (param i32) (result i32) (result i32)))", - expectedErr: "1:41: duplicate result in module.func[0]", + expectedErr: "1:41: at most one result allowed in module.func[0]", }, { name: "func double result type", @@ -1885,11 +1882,6 @@ func TestParseModule_Errors(t *testing.T) { input: "(module (import \"\" \"\" (func $main)) (func $main)))", expectedErr: "1:43: duplicate ID $main in module.func[0]", }, - { - name: "import func ID clashes with func ID", - input: "(module (func $main) (import \"\" \"\" (func $main)))", - expectedErr: "1:42: duplicate ID $main in module.import[0].func", - }, { name: "func ID after result", input: "(module (func (result i32) $main)))", @@ -1946,6 +1938,11 @@ func TestParseModule_Errors(t *testing.T) { )`, expectedErr: "3:15: unknown ID $mein in module.code[1].body[1]", }, + { + name: "second memory", + input: "(module (memory 1) (memory 1))", + expectedErr: "1:21: at most one memory allowed in module", + }, { name: "export duplicates empty name", input: `(module @@ -1954,7 +1951,7 @@ func TestParseModule_Errors(t *testing.T) { (export "" (func 0)) (export "" (memory 1)) )`, - expectedErr: "5:13: duplicate name \"\" in module.export[1]", + expectedErr: `5:13: "" already exported in module.export[1]`, }, { name: "export duplicates name", @@ -1964,7 +1961,7 @@ func TestParseModule_Errors(t *testing.T) { (export "a" (func 0)) (export "a" (memory 1)) )`, - expectedErr: "5:13: duplicate name \"a\" in module.export[1]", + expectedErr: `5:13: "a" already exported in module.export[1]`, }, { name: "export double name", @@ -2057,7 +2054,7 @@ func TestParseModule_Errors(t *testing.T) { { name: "double start", input: "(module (start $main) (start $main))", - expectedErr: "1:24: redundant start in module", + expectedErr: "1:24: at most one start allowed in module", }, { name: "wrong start", @@ -2098,13 +2095,13 @@ func TestModuleParser_ErrorContext(t *testing.T) { }{ {input: "initial", pos: positionInitial, expected: ""}, {input: "module", pos: positionModule, expected: "module"}, - {input: "module import", pos: positionModuleImport, expected: "module.import[0]"}, - {input: "module import func", pos: positionModuleImportFunc, expected: "module.import[0].func"}, - {input: "module func", pos: positionModuleFunc, expected: "module.func[0]"}, - {input: "module memory", pos: positionModuleMemory, expected: "module.memory[0]"}, - {input: "module export", pos: positionModuleExport, expected: "module.export[0]"}, - {input: "module export func", pos: positionModuleExportFunc, expected: "module.export[0].func"}, - {input: "start", pos: positionModuleStart, expected: "module.start"}, + {input: "module import", pos: positionImport, expected: "module.import[0]"}, + {input: "module import func", pos: positionImportFunc, expected: "module.import[0].func"}, + {input: "module func", pos: positionFunc, expected: "module.func[0]"}, + {input: "module memory", pos: positionMemory, expected: "module.memory[0]"}, + {input: "module export", pos: positionExport, expected: "module.export[0]"}, + {input: "module export func", pos: positionExportFunc, expected: "module.export[0].func"}, + {input: "start", pos: positionStart, expected: "module.start"}, } for _, tt := range tests { diff --git a/wasm/text/errors.go b/wasm/text/errors.go index a5c34df024..987fc80740 100644 --- a/wasm/text/errors.go +++ b/wasm/text/errors.go @@ -45,6 +45,36 @@ func unexpectedToken(tok tokenType, tokenBytes []byte) error { } } +// importAfterModuleDefined is the failure for the condition "all imports must occur before any regular definition", +// which applies regardless of abbreviation. +// +// Ex. Both of these fail because an import can only be declared when SectionIDFunction is empty. +// `(func) (import "" "" (func))` which is the same as `(func) (import "" "" (func))` +// +// See https://www.w3.org/TR/wasm-core-1/#modules%E2%91%A0%E2%91%A2 +func importAfterModuleDefined(section wasm.SectionID) error { + return fmt.Errorf("import after module-defined %s", wasm.SectionIDName(section)) +} + +// moreThanOneInvalidInSection allows enforcement of section size limits. +// +// Ex. All of these fail because they result in two memories. +// * `(module (memory 1) (memory 1))` +// * `(module (memory 1) (import "" "" (memory 1)))` +// * Note the latter expands to the same as the former: `(import "" "" (memory 1))` +// * `(module (import "" "" (memory 1)) (import "" "" (memory 1)))` +// +// See https://www.w3.org/TR/wasm-core-1/#tables%E2%91%A0 +// See https://www.w3.org/TR/wasm-core-1/#memories%E2%91%A0 +func moreThanOneInvalidInSection(section wasm.SectionID) error { + return moreThanOneInvalid(wasm.SectionIDName(section)) +} + +// moreThanOneInvalid is the failure when a declaration that can result in more than one item. +func moreThanOneInvalid(context string) error { + return fmt.Errorf("at most one %s allowed", context) +} + func unhandledSection(section wasm.SectionID) error { return fmt.Errorf("BUG: unhandled %s", wasm.SectionIDName(section)) } diff --git a/wasm/text/func_parser.go b/wasm/text/func_parser.go index efac06edfb..9e907e5c52 100644 --- a/wasm/text/func_parser.go +++ b/wasm/text/func_parser.go @@ -17,8 +17,8 @@ type onFunc func(typeIdx wasm.Index, code *wasm.Code, name string, localNames wa // funcParser parses any instructions and dispatches to onFunc. // // Ex. `(module (func (nop)))` -// starts here --^ ^ -// calls onFunc here --+ +// begin here --^ ^ +// end calls onFunc here --+ // // Note: funcParser is reusable. The caller resets via begin. type funcParser struct { @@ -73,11 +73,11 @@ func (p *funcParser) begin(tok tokenType, tokenBytes []byte, line, col uint32) ( // are read. Finally, this finishes via endFunc. // // Ex. `(module (func $math.pi (result f32))` -// starts here --^ ^ +// begin here --^ ^ // endFunc resumes here --+ // // Ex. `(module (func $math.pi (result f32) (local i32) )` -// starts here --^ ^ ^ +// begin here --^ ^ ^ // funcParser.afterTypeUse resumes here --+ | // endFunc resumes here --+ // @@ -98,11 +98,11 @@ func (p *funcParser) parseFunc(tok tokenType, tokenBytes []byte, line, col uint3 // Ex. Given the source `(module (func nop))` // afterTypeUse starts here --^ ^ // calls onFunc here --+ -func (p *funcParser) afterTypeUse(typeIdx wasm.Index, paramNames wasm.NameMap, pos onTypeUsePosition, tok tokenType, tokenBytes []byte, line, col uint32) (tokenParser, error) { +func (p *funcParser) afterTypeUse(typeIdx wasm.Index, paramNames wasm.NameMap, pos callbackPosition, tok tokenType, tokenBytes []byte, line, col uint32) (tokenParser, error) { switch pos { - case onTypeUseEndField: + case callbackPositionEndField: return p.onFunc(typeIdx, codeEnd, p.currentName, paramNames) - case onTypeUseUnhandledField: + case callbackPositionUnhandledField: return sExpressionsUnsupported(tok, tokenBytes, line, col) } @@ -120,7 +120,7 @@ func sExpressionsUnsupported(tok tokenType, tokenBytes []byte, _, _ uint32) (tok case "param": return nil, errors.New("param after result") case "result": - return nil, errors.New("duplicate result") + return nil, moreThanOneInvalid("result") case "local": return nil, errors.New("TODO: local") } diff --git a/wasm/text/func_parser_test.go b/wasm/text/func_parser_test.go index 0c2e2e7cc5..2294f05aa9 100644 --- a/wasm/text/func_parser_test.go +++ b/wasm/text/func_parser_test.go @@ -213,7 +213,7 @@ func TestFuncParser_Errors(t *testing.T) { { name: "duplicate result", source: "(func (result i32) (result i32))", - expectedErr: "1:21: duplicate result", + expectedErr: "1:21: at most one result allowed", }, } diff --git a/wasm/text/index_namespace.go b/wasm/text/index_namespace.go index 3bc406ffac..f6db68d621 100644 --- a/wasm/text/index_namespace.go +++ b/wasm/text/index_namespace.go @@ -38,12 +38,22 @@ type indexNamespace struct { // setID ensures the given tokenID is unique within this context and raises an error if not. The resulting mapping is // stripped of the leading '$' to match other tools, as described in stripDollar. func (i *indexNamespace) setID(idToken []byte) (string, error) { - id := string(stripDollar(idToken)) - if _, ok := i.idToIdx[id]; ok { - return id, fmt.Errorf("duplicate ID %s", idToken) + name, err := i.requireNoID(idToken) + if err != nil { + return name, err + } + i.idToIdx[name] = i.count + return name, nil +} + +// hasID checks to see if this tokenID is unique within this context and returns an error. The result string is +// stripped of the leading '$' to match other tools, as described in stripDollar. +func (i *indexNamespace) requireNoID(idToken []byte) (string, error) { + name := string(stripDollar(idToken)) + if _, ok := i.idToIdx[name]; ok { + return name, fmt.Errorf("duplicate ID %s", idToken) } - i.idToIdx[id] = i.count - return id, nil + return name, nil } // parseIndex is a tokenParser called in a field that can only contain a symbolic identifier or raw numeric index. diff --git a/wasm/text/memory_parser_test.go b/wasm/text/memory_parser_test.go index 534bcbb9d8..9498eb1484 100644 --- a/wasm/text/memory_parser_test.go +++ b/wasm/text/memory_parser_test.go @@ -60,8 +60,8 @@ func TestMemoryParser(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - namespace := newIndexNamespace() - parsed, tp, err := parseMemoryType(namespace, tc.input) + memoryNamespace := newIndexNamespace() + parsed, tp, err := parseMemoryType(memoryNamespace, tc.input) require.NoError(t, err) require.Equal(t, tc.expected, parsed) require.Equal(t, uint32(1), tp.memoryNamespace.count) @@ -151,13 +151,13 @@ func TestMemoryParser_Errors(t *testing.T) { }) } -func parseMemoryType(namespace *indexNamespace, input string) (*wasm.MemoryType, *memoryParser, error) { +func parseMemoryType(memoryNamespace *indexNamespace, input string) (*wasm.MemoryType, *memoryParser, error) { var parsed *wasm.MemoryType var setFunc onMemory = func(min uint32, max *uint32) tokenParser { parsed = &wasm.MemoryType{Min: min, Max: max} return parseErr } - tp := newMemoryParser(namespace, setFunc) + tp := newMemoryParser(memoryNamespace, setFunc) // memoryParser starts after the '(memory', so we need to eat it first! _, _, err := lex(skipTokens(2, tp.begin), []byte(input)) return parsed, tp, err diff --git a/wasm/text/type_parser_test.go b/wasm/text/type_parser_test.go index ec07c33c48..76f72c2d59 100644 --- a/wasm/text/type_parser_test.go +++ b/wasm/text/type_parser_test.go @@ -100,8 +100,8 @@ func TestTypeParser(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { - namespace := newIndexNamespace() - parsed, tp, err := parseFunctionType(namespace, tc.input) + typeNamespace := newIndexNamespace() + parsed, tp, err := parseFunctionType(typeNamespace, tc.input) require.NoError(t, err) require.Equal(t, tc.expected, parsed) require.Equal(t, uint32(1), tp.typeNamespace.count) @@ -246,13 +246,13 @@ func TestTypeParser_Errors(t *testing.T) { }) } -func parseFunctionType(namespace *indexNamespace, input string) (*wasm.FunctionType, *typeParser, error) { +func parseFunctionType(typeNamespace *indexNamespace, input string) (*wasm.FunctionType, *typeParser, error) { var parsed *wasm.FunctionType var setFunc onType = func(ft *wasm.FunctionType) tokenParser { parsed = ft return parseErr } - tp := newTypeParser(namespace, setFunc) + tp := newTypeParser(typeNamespace, setFunc) // typeParser starts after the '(type', so we need to eat it first! _, _, err := lex(skipTokens(2, tp.begin), []byte(input)) return parsed, tp, err diff --git a/wasm/text/typeuse_parser.go b/wasm/text/typeuse_parser.go index 8c2a6838ca..87eb4dfae0 100644 --- a/wasm/text/typeuse_parser.go +++ b/wasm/text/typeuse_parser.go @@ -12,32 +12,21 @@ func newTypeUseParser(module *wasm.Module, typeNamespace *indexNamespace) *typeU return &typeUseParser{module: module, typeNamespace: typeNamespace} } -type onTypeUsePosition byte - -const ( - // onTypeUseUnhandledToken is set on a token besides a paren. - onTypeUseUnhandledToken onTypeUsePosition = iota - // onTypeUseUnhandledField is at the field name (tokenKeyword) which isn't "type", "param" or "result" - onTypeUseUnhandledField - // onTypeUseEndField is at the end (tokenRParen) of the field enclosing the type use. - onTypeUseEndField -) - // onTypeUse is invoked when the grammar "(param)* (result)?" completes. // -// * typeIdx if unresolved, this is replaced in typeUseParser.resolveTypeUses +// * typeIdx if unresolved, this is replaced in moduleParser.resolveTypeUses // * paramNames is nil unless IDs existed on at least one "param" field. // * pos is the context used to determine which tokenParser to return // // Note: this is called when neither a "param" nor a "result" field are parsed, or on any field following a "param" // that is not a "result": pos clarifies this. -type onTypeUse func(typeIdx wasm.Index, paramNames wasm.NameMap, pos onTypeUsePosition, tok tokenType, tokenBytes []byte, line, col uint32) (tokenParser, error) +type onTypeUse func(typeIdx wasm.Index, paramNames wasm.NameMap, pos callbackPosition, tok tokenType, tokenBytes []byte, line, col uint32) (tokenParser, error) -// typeUseParser parses an inlined type from a field such "func" and calls to onTypeUse or onUnknownField. +// typeUseParser parses an inlined type from a field such "func" and calls onTypeUse. // -// Ex. `(import "Math" "PI" (func $math.pi (result f32))` +// Ex. `(import "Math" "PI" (func $math.pi (result f32)))` // starts here --^ ^ -// onTypeUse resumes here --+ +// onTypeUse resumes here --+ // // Note: Unlike normal parsers, this is not used for an entire field (enclosed by parens). Rather, this only handles // "type", "param" and "result" inner fields in the correct order. @@ -113,7 +102,7 @@ type typeUseParser struct { // onTypeUse resumes here --+ // func (p *typeUseParser) begin(section wasm.SectionID, idx wasm.Index, onTypeUse onTypeUse, tok tokenType, tokenBytes []byte, line, col uint32) (tokenParser, error) { - pos := onTypeUseUnhandledToken + pos := callbackPositionUnhandledToken p.pos = positionInitial // to ensure errorContext reports properly switch tok { case tokenLParen: @@ -122,7 +111,7 @@ func (p *typeUseParser) begin(section wasm.SectionID, idx wasm.Index, onTypeUse p.onTypeUse = onTypeUse return p.beginTypeParamOrResult, nil case tokenRParen: - pos = onTypeUseEndField + pos = callbackPositionEndField } return onTypeUse(p.emptyTypeIndex(section, idx), nil, pos, tok, tokenBytes, line, col) } @@ -222,7 +211,7 @@ func (p *typeUseParser) beginParamOrResult(tok tokenType, tokenBytes []byte, lin case "type": return nil, errors.New("redundant type") default: - return p.end(onTypeUseUnhandledField, tok, tokenBytes, line, col) + return p.end(callbackPositionUnhandledField, tok, tokenBytes, line, col) } } @@ -282,9 +271,10 @@ func (p *typeUseParser) setParamID(idToken []byte) error { // records i32 --^ ^ // parseMoreParamsOrResult resumes here --+ // -// Ex. One param type is present `(param i32)` -// records i32 --^ ^ -// parseMoreParamsOrResult resumes here --+ +// Ex. Multiple param types are present `(param i32 i64)` +// records i32 --^ ^ ^ +// records i32 --+ | +// parseMoreParamsOrResult resumes here --+ // // Ex. type is missing `(param)` // errs here --^ @@ -351,9 +341,9 @@ func (p *typeUseParser) parseResult(tok tokenType, tokenBytes []byte, _, _ uint3 } func (p *typeUseParser) parseEnd(tok tokenType, tokenBytes []byte, line, col uint32) (tokenParser, error) { - pos := onTypeUseUnhandledToken + pos := callbackPositionUnhandledToken if tok == tokenRParen { - pos = onTypeUseEndField + pos = callbackPositionEndField } return p.end(pos, tok, tokenBytes, line, col) } @@ -379,7 +369,7 @@ type lineCol struct { } // end invokes onTypeUse to continue parsing -func (p *typeUseParser) end(pos onTypeUsePosition, tok tokenType, tokenBytes []byte, line, col uint32) (parser tokenParser, err error) { +func (p *typeUseParser) end(pos callbackPosition, tok tokenType, tokenBytes []byte, line, col uint32) (parser tokenParser, err error) { // Record the potentially inlined type if needed and invoke onTypeUse with the parsed index var typeIdx wasm.Index if p.parsedTypeField { diff --git a/wasm/text/typeuse_parser_test.go b/wasm/text/typeuse_parser_test.go index 74825ce752..73b3e41de2 100644 --- a/wasm/text/typeuse_parser_test.go +++ b/wasm/text/typeuse_parser_test.go @@ -18,7 +18,7 @@ type typeUseParserTest struct { expectedTypeIdx wasm.Index expectedParamNames wasm.NameMap - expectedOnTypeUsePosition onTypeUsePosition + expectedOnTypeUsePosition callbackPosition expectedOnTypeUseToken tokenType expectedTrailingTokens []tokenType } @@ -325,14 +325,14 @@ type typeUseTestFunc func(*typeUseParserTest) (*typeUseParser, func(t *testing.T func runTypeUseParserTests(t *testing.T, tests []*typeUseParserTest, tf typeUseTestFunc) { moreTests := make([]*typeUseParserTest, 0, len(tests)*2) for _, tt := range tests { - tt.expectedOnTypeUsePosition = onTypeUseEndField + tt.expectedOnTypeUsePosition = callbackPositionEndField tt.expectedOnTypeUseToken = tokenRParen // at the end of the field ')' tt.expectedTrailingTokens = nil kt := *tt // copy kt.name = fmt.Sprintf("%s - trailing keyword", tt.name) kt.source = fmt.Sprintf("%s nop)", tt.source[:len(tt.source)-1]) - kt.expectedOnTypeUsePosition = onTypeUseUnhandledToken + kt.expectedOnTypeUsePosition = callbackPositionUnhandledToken kt.expectedOnTypeUseToken = tokenKeyword // at 'nop' and ')' remains kt.expectedTrailingTokens = []tokenType{tokenRParen} moreTests = append(moreTests, &kt) @@ -342,11 +342,11 @@ func runTypeUseParserTests(t *testing.T, tests []*typeUseParserTest, tf typeUseT ft.source = fmt.Sprintf("%s (nop))", tt.source[:len(tt.source)-1]) // Two outcomes, we've reached a field not named "type", "param" or "result" or we completed "result" if strings.Contains(tt.source, "result") { - ft.expectedOnTypeUsePosition = onTypeUseUnhandledToken + ft.expectedOnTypeUsePosition = callbackPositionUnhandledToken ft.expectedOnTypeUseToken = tokenLParen // at '(' and 'nop))' remain ft.expectedTrailingTokens = []tokenType{tokenKeyword, tokenRParen, tokenRParen} } else { - ft.expectedOnTypeUsePosition = onTypeUseUnhandledField + ft.expectedOnTypeUsePosition = callbackPositionUnhandledField ft.expectedOnTypeUseToken = tokenKeyword // at 'nop' and '))' remain ft.expectedTrailingTokens = []tokenType{tokenRParen, tokenRParen} } @@ -360,7 +360,7 @@ func runTypeUseParserTests(t *testing.T, tests []*typeUseParserTest, tf typeUseT var parsedTypeIdx wasm.Index var parsedParamNames wasm.NameMap p := &collectTokenTypeParser{} - var setTypeUse onTypeUse = func(typeIdx wasm.Index, paramNames wasm.NameMap, pos onTypeUsePosition, tok tokenType, tokenBytes []byte, line, col uint32) (tokenParser, error) { + var setTypeUse onTypeUse = func(typeIdx wasm.Index, paramNames wasm.NameMap, pos callbackPosition, tok tokenType, tokenBytes []byte, line, col uint32) (tokenParser, error) { parsedTypeIdx = typeIdx parsedParamNames = paramNames require.Equal(t, tc.expectedOnTypeUsePosition, pos) @@ -534,11 +534,11 @@ func TestTypeUseParser_FailsMatch(t *testing.T) { } } -var ignoreTypeUse onTypeUse = func(typeIdx wasm.Index, paramNames wasm.NameMap, pos onTypeUsePosition, tok tokenType, tokenBytes []byte, line, col uint32) (tokenParser, error) { +var ignoreTypeUse onTypeUse = func(typeIdx wasm.Index, paramNames wasm.NameMap, pos callbackPosition, tok tokenType, tokenBytes []byte, line, col uint32) (tokenParser, error) { return parseNoop, nil } -var failOnTypeUse onTypeUse = func(typeIdx wasm.Index, paramNames wasm.NameMap, pos onTypeUsePosition, tok tokenType, tokenBytes []byte, line, col uint32) (tokenParser, error) { +var failOnTypeUse onTypeUse = func(typeIdx wasm.Index, paramNames wasm.NameMap, pos callbackPosition, tok tokenType, tokenBytes []byte, line, col uint32) (tokenParser, error) { return nil, errors.New("unexpected to call onTypeUse on error") }