diff --git a/wasm/jit/jit_amd64.go b/wasm/jit/jit_amd64.go index 308aa8bb32..2998534f79 100644 --- a/wasm/jit/jit_amd64.go +++ b/wasm/jit/jit_amd64.go @@ -4729,8 +4729,8 @@ func (c *amd64Compiler) callFunction(addr wasm.FunctionAddress, addrReg int16, f // 2) Set engine.valueStackContext.stackBasePointer for the next function. { - // At this point, tmpRegister holds the OLD stack base pointer. We could get the new frame's - // stack base pointer by "OLD stack base pointer + OLD stack pointer - # of function params" + // At this point, tmpRegister holds the old stack base pointer. We could get the new frame's + // stack base pointer by "old stack base pointer + old stack pointer - # of function params" // See the comments in engine.pushCallFrame which does exactly the same calculation in Go. calculateNextStackBasePointer := c.newProg() calculateNextStackBasePointer.As = x86.AADDQ @@ -4870,8 +4870,7 @@ func (c *amd64Compiler) returnFunction() error { // Obtain the temporary registers to be used in the followings. regs, found := c.locationStack.takeFreeRegisters(generalPurposeRegisterTypeInt, 3) if !found { - // This in theory never happen as all the registers must be free except addrReg. - return fmt.Errorf("could not find enough free registers") + return fmt.Errorf("BUG: all the registers should be free at this point") } c.locationStack.markRegisterUsed(regs...) diff --git a/wasm/jit/jit_arm64.go b/wasm/jit/jit_arm64.go index 0c1cd0fca0..578209f1bc 100644 --- a/wasm/jit/jit_arm64.go +++ b/wasm/jit/jit_arm64.go @@ -77,6 +77,8 @@ type arm64Compiler struct { labels map[string]*labelInfo // stackPointerCeil is the greatest stack pointer value (from valueLocationStack) seen during compilation. stackPointerCeil uint64 + // afterAssembleCallback hold the callbacks which are called after assembling native code. + afterAssembleCallback []func(code []byte) error } // compile implements compiler.compile for the arm64 architecture. 
@@ -89,10 +91,19 @@ func (c *arm64Compiler) compile() (code []byte, staticData compiledFunctionStati stackPointerCeil = c.locationStack.stackPointerCeil } - code, err = mmapCodeSegment(c.builder.Assemble()) + original := c.builder.Assemble() + + for _, cb := range c.afterAssembleCallback { + if err = cb(original); err != nil { + return + } + } + + code, err = mmapCodeSegment(original) if err != nil { return } + return } @@ -136,9 +147,11 @@ func (c *arm64Compiler) markRegisterUsed(reg int16) { c.locationStack.markRegisterUsed(reg) } -func (c *arm64Compiler) markRegisterUnused(reg int16) { - if !isZeroRegister(reg) { - c.locationStack.markRegisterUnused(reg) +func (c *arm64Compiler) markRegisterUnused(regs ...int16) { + for _, reg := range regs { + if !isZeroRegister(reg) { + c.locationStack.markRegisterUnused(reg) + } } } @@ -158,74 +171,63 @@ func (c *arm64Compiler) applyConstToRegisterInstruction(instruction obj.As, cons // applyMemoryToRegisterInstruction adds an instruction where source operand is a memory location and destination is a register. // baseRegister is the base absolute address in the memory, and offset is the offset from the absolute address in baseRegister. -func (c *arm64Compiler) applyMemoryToRegisterInstruction(instruction obj.As, baseRegister int16, offset int64, destinationRegister int16) (err error) { - if offset > math.MaxInt16 { - // This is a bug in JIT copmiler: caller must check the offset at compilation time, and avoid such a large offset - // by loading the const to the register beforehand and then using applyRegisterOffsetMemoryToRegisterInstruction instead. 
- err = fmt.Errorf("memory offset must be smaller than or equal %d, but got %d", math.MaxInt16, offset) - return +func (c *arm64Compiler) applyMemoryToRegisterInstruction(instruction obj.As, sourceBaseRegister int16, sourceOffsetConst int64, destinationRegister int16) { + if sourceOffsetConst > math.MaxInt16 { + // The assembler can take care of offsets larger than 2^15-1 by emitting additional instructions to load such large offset, + // but it uses "its" temporary register which we cannot track. Therefore, we avoid directly emitting memory load with large offsets, + // but instead load the constant manually to "our" temporary register, then emit the load with it. + c.applyConstToRegisterInstruction(arm64.AMOVD, sourceOffsetConst, reservedRegisterForTemporary) + inst := c.newProg() + inst.As = instruction + inst.From.Type = obj.TYPE_MEM + inst.From.Reg = sourceBaseRegister + inst.From.Index = reservedRegisterForTemporary + inst.From.Scale = 1 + inst.To.Type = obj.TYPE_REG + inst.To.Reg = destinationRegister + c.addInstruction(inst) + } else { + inst := c.newProg() + inst.As = instruction + inst.From.Type = obj.TYPE_MEM + inst.From.Reg = sourceBaseRegister + inst.From.Offset = sourceOffsetConst + inst.To.Type = obj.TYPE_REG + inst.To.Reg = destinationRegister + c.addInstruction(inst) } - inst := c.newProg() - inst.As = instruction - inst.From.Type = obj.TYPE_MEM - inst.From.Reg = baseRegister - inst.From.Offset = offset - inst.To.Type = obj.TYPE_REG - inst.To.Reg = destinationRegister - c.addInstruction(inst) - return -} - -// applyRegisterOffsetMemoryToRegisterInstruction adds an instruction where source operand is a memory location and destination is a register. -// The difference from applyMemoryToRegisterInstruction is that here we specify the offset by a register rather than offset constant. 
-func (c *arm64Compiler) applyRegisterOffsetMemoryToRegisterInstruction(instruction obj.As, baseRegister, offsetRegister, destinationRegister int16) (err error) { - inst := c.newProg() - inst.As = instruction - inst.From.Type = obj.TYPE_MEM - inst.From.Reg = baseRegister - inst.From.Index = offsetRegister - inst.From.Scale = 1 - inst.To.Type = obj.TYPE_REG - inst.To.Reg = destinationRegister - c.addInstruction(inst) - return nil } // applyRegisterToMemoryInstruction adds an instruction where destination operand is a memory location and source is a register. // This is the opposite of applyMemoryToRegisterInstruction. -func (c *arm64Compiler) applyRegisterToMemoryInstruction(instruction obj.As, baseRegister int16, offset int64, source int16) (err error) { - if offset > math.MaxInt16 { - // This is a bug in JIT copmiler: caller must check the offset at compilation time, and avoid such a large offset - // by loading the const to the register beforehand and then using applyRegisterToRegisterOffsetMemoryInstruction instead. - err = fmt.Errorf("memory offset must be smaller than or equal %d, but got %d", math.MaxInt16, offset) - return +func (c *arm64Compiler) applyRegisterToMemoryInstruction(instruction obj.As, sourceRegister int16, destinationBaseRegister int16, destinationOffsetConst int64) { + if destinationOffsetConst > math.MaxInt16 { + // The assembler can take care of offsets larger than 2^15-1 by emitting additional instructions to load such large offset, + // but we cannot track its temporary register. Therefore, we avoid directly emitting memory load with large offsets: + // load the constant manually to "our" temporary register, then emit the load with it. 
+ c.applyConstToRegisterInstruction(arm64.AMOVD, destinationOffsetConst, reservedRegisterForTemporary) + inst := c.newProg() + inst.As = instruction + inst.To.Type = obj.TYPE_MEM + inst.To.Reg = destinationBaseRegister + inst.To.Index = reservedRegisterForTemporary + inst.To.Scale = 1 + inst.From.Type = obj.TYPE_REG + inst.From.Reg = sourceRegister + c.addInstruction(inst) + } else { + inst := c.newProg() + inst.As = instruction + inst.To.Type = obj.TYPE_MEM + inst.To.Reg = destinationBaseRegister + inst.To.Offset = destinationOffsetConst + inst.From.Type = obj.TYPE_REG + inst.From.Reg = sourceRegister + c.addInstruction(inst) } - inst := c.newProg() - inst.As = instruction - inst.To.Type = obj.TYPE_MEM - inst.To.Reg = baseRegister - inst.To.Offset = offset - inst.From.Type = obj.TYPE_REG - inst.From.Reg = source - c.addInstruction(inst) - return } -// applyRegisterToRegisterOffsetMemoryInstruction adds an instruction where destination operand is a memory location and source is a register. -// The difference from applyRegisterToMemoryInstruction is that here we specify the offset by a register rather than offset constant. -func (c *arm64Compiler) applyRegisterToRegisterOffsetMemoryInstruction(instruction obj.As, baseRegister, offsetRegister, source int16) { - inst := c.newProg() - inst.As = instruction - inst.To.Type = obj.TYPE_MEM - inst.To.Reg = baseRegister - inst.To.Index = offsetRegister - inst.To.Scale = 1 - inst.From.Type = obj.TYPE_REG - inst.From.Reg = source - c.addInstruction(inst) -} - -// applyRegisterToRegisterOffsetMemoryInstruction adds an instruction where both destination and source operands are registers. +// applyRegisterToRegisterInstruction adds an instruction where both destination and source operands are registers. 
func (c *arm64Compiler) applyRegisterToRegisterInstruction(instruction obj.As, from, to int16) { inst := c.newProg() inst.As = instruction @@ -236,7 +238,7 @@ func (c *arm64Compiler) applyRegisterToRegisterInstruction(instruction obj.As, f c.addInstruction(inst) } -// applyRegisterToRegisterOffsetMemoryInstruction adds an instruction which takes two source operands on registers and one destination register operand. +// applyTwoRegistersToRegisterInstruction adds an instruction which takes two source operands on registers and one destination register operand. func (c *arm64Compiler) applyTwoRegistersToRegisterInstruction(instruction obj.As, src1, src2, destination int16) { inst := c.newProg() inst.As = instruction @@ -261,14 +263,22 @@ func (c *arm64Compiler) applyTwoRegistersToNoneInstruction(instruction obj.As, s c.addInstruction(inst) } -func (c *arm64Compiler) emitUnconditionalBRInstruction(targetType obj.AddrType) (jmp *obj.Prog) { - jmp = c.newProg() - jmp.As = obj.AJMP - jmp.To.Type = targetType - c.addInstruction(jmp) +func (c *arm64Compiler) emitUnconditionalBranchInstruction() (br *obj.Prog) { + br = c.newProg() + br.As = obj.AJMP + br.To.Type = obj.TYPE_BRANCH + c.addInstruction(br) return } +func (c *arm64Compiler) emitUnconditionalBranchToAddressOnRegister(addressRegister int16) { + br := c.newProg() + br.As = obj.AJMP + br.To.Type = obj.TYPE_MEM + br.To.Reg = addressRegister + c.addInstruction(br) +} + func (c *arm64Compiler) String() (ret string) { return } // pushFunctionParams pushes any function parameters onto the stack, setting appropriate register types. @@ -303,56 +313,114 @@ func (c *arm64Compiler) emitPreamble() error { // returnFunction emits instructions to return from the current function frame. // If the current frame is the bottom, the code goes back to the Go code with jitCallStatusCodeReturned status. -// Otherwise, we branch into the caller's return address (TODO). +// Otherwise, we branch into the caller's return address. 
func (c *arm64Compiler) returnFunction() error { - // TODO: we don't support function calls yet. - // For now the following code just returns to Go code. + // Release all the registers as our calling convention requires the caller-save. + if err := c.releaseAllRegistersToStack(); err != nil { + return err + } // Since we return from the function, we need to decrement the callframe stack pointer, and write it back. - callFramePointerReg, _ := c.locationStack.takeFreeRegister(generalPurposeRegisterTypeInt) - if err := c.applyMemoryToRegisterInstruction(arm64.AMOVD, reservedRegisterForEngine, - engineGlobalContextCallFrameStackPointerOffset, callFramePointerReg); err != nil { - return err + tmpRegs, found := c.locationStack.takeFreeRegisters(generalPurposeRegisterTypeInt, 3) + if !found { + return fmt.Errorf("BUG: all the registers should be free at this point") } + + // Alias for readability. + callFramePointerReg, callFrameStackTopAddressRegister, tmpReg := tmpRegs[0], tmpRegs[1], tmpRegs[2] + + // First we decrement the callframe stack pointer. + c.applyMemoryToRegisterInstruction(arm64.AMOVD, reservedRegisterForEngine, engineGlobalContextCallFrameStackPointerOffset, callFramePointerReg) c.applyConstToRegisterInstruction(arm64.ASUBS, 1, callFramePointerReg) - if err := c.applyRegisterToMemoryInstruction(arm64.AMOVD, reservedRegisterForEngine, - engineGlobalContextCallFrameStackPointerOffset, callFramePointerReg); err != nil { - return err - } + c.applyRegisterToMemoryInstruction(arm64.AMOVD, callFramePointerReg, reservedRegisterForEngine, engineGlobalContextCallFrameStackPointerOffset) + + // Next we compare the decremented call frame stack pointer with the engine.previousCallFrameStackPointer. 
+ c.applyMemoryToRegisterInstruction(arm64.AMOVD, + reservedRegisterForEngine, engineGlobalContextPreviouscallFrameStackPointer, + tmpReg, + ) + c.applyTwoRegistersToNoneInstruction(arm64.ACMP, callFramePointerReg, tmpReg) + + // If the values are identical, we return back to the Go code with returned status. + brIfNotEqual := c.newProg() + brIfNotEqual.As = arm64.ABNE + brIfNotEqual.To.Type = obj.TYPE_BRANCH + c.addInstruction(brIfNotEqual) + c.exit(jitCallStatusCodeReturned) + + // Otherwise, we have to jump to the caller's return address. + c.setBRTargetOnNext(brIfNotEqual) + + // First, we have to calculate the caller callFrame's absolute address to acquire the return address. + // + // "tmpReg = &engine.callFrameStack[0]" + c.applyMemoryToRegisterInstruction(arm64.AMOVD, + reservedRegisterForEngine, engineGlobalContextCallFrameStackElement0AddressOffset, + tmpReg, + ) + // "callFrameStackTopAddressRegister = tmpReg + callFramePointerReg << ${callFrameDataSizeMostSignificantSetBit}" + c.emitAddInstructionWithLeftShiftedRegister( + callFramePointerReg, callFrameDataSizeMostSignificantSetBit, + tmpReg, + callFrameStackTopAddressRegister, + ) - return c.exit(jitCallStatusCodeReturned) + // At this point, we have + // + // [......., ra.caller, rb.caller, rc.caller, _, ra.current, rb.current, rc.current, _, ...] <- call frame stack's data region (somewhere in the memory) + // ^ + // callFrameStackTopAddressRegister + // (absolute address of &callFrameStack[engine.callFrameStackPointer]) + // + // where: + // ra.* = callFrame.returnAddress + // rb.* = callFrame.returnStackBasePointer + // rc.* = callFrame.compiledFunction + // _ = callFrame's padding (see comment on callFrame._ field.) + // + // What we have to do in the following is that + // 1) Set engine.valueStackContext.stackBasePointer to the value on "rb.caller". + // 2) Jump into the address of "ra.caller". + + // 1) Set engine.valueStackContext.stackBasePointer to the value on "rb.caller". 
+ c.applyMemoryToRegisterInstruction(arm64.AMOVD, + // "rb.caller" is below the top address. + callFrameStackTopAddressRegister, -(callFrameDataSize - callFrameReturnStackBasePointerOffset), + tmpReg) + c.applyRegisterToMemoryInstruction(arm64.AMOVD, + tmpReg, + reservedRegisterForEngine, engineValueStackContextStackBasePointerOffset) + + // 2) Branch into the address of "ra.caller". + c.applyMemoryToRegisterInstruction(arm64.AMOVD, + // "rb.caller" is below the top address. + callFrameStackTopAddressRegister, -(callFrameDataSize - callFrameReturnAddressOffset), + tmpReg) + c.emitUnconditionalBranchToAddressOnRegister(tmpReg) + + c.locationStack.markRegisterUnused(tmpRegs...) + return nil } // exit adds instructions to give the control back to engine.exec with the given status code. func (c *arm64Compiler) exit(status jitCallStatusCode) error { // Write the current stack pointer to the engine.stackPointer. c.applyConstToRegisterInstruction(arm64.AMOVW, int64(c.locationStack.sp), reservedRegisterForTemporary) - if err := c.applyRegisterToMemoryInstruction(arm64.AMOVW, reservedRegisterForEngine, - engineValueStackContextStackPointerOffset, reservedRegisterForTemporary); err != nil { - return err - } + c.applyRegisterToMemoryInstruction(arm64.AMOVW, reservedRegisterForTemporary, reservedRegisterForEngine, + engineValueStackContextStackPointerOffset) if status != 0 { c.applyConstToRegisterInstruction(arm64.AMOVW, int64(status), reservedRegisterForTemporary) - if err := c.applyRegisterToMemoryInstruction(arm64.AMOVWU, reservedRegisterForEngine, - engineExitContextJITCallStatusCodeOffset, reservedRegisterForTemporary); err != nil { - return err - } + c.applyRegisterToMemoryInstruction(arm64.AMOVWU, reservedRegisterForTemporary, reservedRegisterForEngine, engineExitContextJITCallStatusCodeOffset) } else { // If the status == 0, we use zero register to store zero. 
- if err := c.applyRegisterToMemoryInstruction(arm64.AMOVWU, reservedRegisterForEngine, - engineExitContextJITCallStatusCodeOffset, zeroRegister); err != nil { - return err - } + c.applyRegisterToMemoryInstruction(arm64.AMOVWU, zeroRegister, reservedRegisterForEngine, engineExitContextJITCallStatusCodeOffset) } // The return address to the Go code is stored in archContext.jitReturnAddress which // is embedded in engine. We load the value to the tmpRegister, and then // invoke RET with that register. - if err := c.applyMemoryToRegisterInstruction(arm64.AMOVD, reservedRegisterForEngine, - engineArchContextJITCallReturnAddressOffset, reservedRegisterForTemporary); err != nil { - return err - } + c.applyMemoryToRegisterInstruction(arm64.AMOVD, reservedRegisterForEngine, engineArchContextJITCallReturnAddressOffset, reservedRegisterForTemporary) ret := c.newProg() ret.As = obj.ARET @@ -530,8 +598,8 @@ func (c *arm64Compiler) branchInto(target *wazeroir.BranchTarget) error { targetLabel.initialStack = c.locationStack.clone() } - jmp := c.emitUnconditionalBRInstruction(obj.TYPE_BRANCH) - c.assignBranchTarget(labelKey, jmp) + br := c.emitUnconditionalBranchInstruction() + c.assignBranchTarget(labelKey, br) return nil } } @@ -554,8 +622,229 @@ func (c *arm64Compiler) compileBrTable(o *wazeroir.OperationBrTable) error { return fmt.Errorf("TODO: unsupported on arm64") } +// compileCall implements compiler.compileCall for the arm64 architecture. func (c *arm64Compiler) compileCall(o *wazeroir.OperationCall) error { - return fmt.Errorf("TODO: unsupported on arm64") + target := c.f.ModuleInstance.Functions[o.FunctionIndex] + + if err := c.callFunction(target.Address, target.FunctionType.Type); err != nil { + return err + } + return nil +} + +// compileCall implements compiler.compileCall and compiler.compileCallIndirect (TODO) for the arm64 architecture. 
+func (c *arm64Compiler) callFunction(addr wasm.FunctionAddress, functype *wasm.FunctionType) error { + // TODO: the following code can be generalized for CallIndirect. + + // Release all the registers as our calling convention requires the caller-save. + if err := c.releaseAllRegistersToStack(); err != nil { + return err + } + + // Obtain the free registers to be used in the followings. + freeRegisters, found := c.locationStack.takeFreeRegisters(generalPurposeRegisterTypeInt, 5) + if !found { + return fmt.Errorf("BUG: all registers except addrReg should be free at this point") + } + c.locationStack.markRegisterUsed(freeRegisters...) + + // Alias for readability. + callFrameStackTopAddressRegister, compiledFunctionAddressRegister, oldStackBasePointer, + tmp := freeRegisters[0], freeRegisters[1], freeRegisters[2], freeRegisters[3] + + // TODO: Check the callframe stack length, and if necessary, grow the call frame stack before jump into the target. + + // "tmp = engine.callFrameStackPointer" + c.applyMemoryToRegisterInstruction(arm64.AMOVD, + reservedRegisterForEngine, engineGlobalContextCallFrameStackPointerOffset, + tmp) + // "callFrameStackTopAddressRegister = &engine.callFrameStack[0]" + c.applyMemoryToRegisterInstruction(arm64.AMOVD, + reservedRegisterForEngine, engineGlobalContextCallFrameStackElement0AddressOffset, + callFrameStackTopAddressRegister) + // "callFrameStackTopAddressRegister += tmp << $callFrameDataSizeMostSignificantSetBit" + c.emitAddInstructionWithLeftShiftedRegister( + tmp, callFrameDataSizeMostSignificantSetBit, + callFrameStackTopAddressRegister, + callFrameStackTopAddressRegister, + ) + + // At this point, we have: + // + // [..., ra.current, rb.current, rc.current, _, ra.next, rb.next, rc.next, ...] 
<- call frame stack's data region (somewhere in the memory) + // ^ + // callFrameStackTopAddressRegister + // (absolute address of &callFrame[engine.callFrameStackPointer]]) + // + // where: + // ra.* = callFrame.returnAddress + // rb.* = callFrame.returnStackBasePointer + // rc.* = callFrame.compiledFunction + // _ = callFrame's padding (see comment on callFrame._ field.) + // + // In the following comment, we use the notations in the above example. + // + // What we have to do in the following is that + // 1) Set rb.current so that we can return back to this function properly. + // 2) Set engine.valueStackContext.stackBasePointer for the next function. + // 3) Set rc.next to specify which function is executed on the current call frame (needs to make Go function calls). + // 4) Set ra.current so that we can return back to this function properly. + + // 1) Set rb.current so that we can return back to this function properly. + c.applyMemoryToRegisterInstruction(arm64.AMOVD, + reservedRegisterForEngine, engineValueStackContextStackBasePointerOffset, + oldStackBasePointer) + c.applyRegisterToMemoryInstruction(arm64.AMOVD, + oldStackBasePointer, + // "rb.current" is BELOW the top address. See the above example for detail. + callFrameStackTopAddressRegister, -(callFrameDataSize - callFrameReturnStackBasePointerOffset)) + + // 2) Set engine.valueStackContext.stackBasePointer for the next function. + // + // At this point, oldStackBasePointer holds the old stack base pointer. We could get the new frame's + // stack base pointer by "old stack base pointer + old stack pointer - # of function params" + // See the comments in engine.pushCallFrame which does exactly the same calculation in Go. 
+ c.applyConstToRegisterInstruction(arm64.AADD, + int64(c.locationStack.sp)-int64(len(functype.Params)), + oldStackBasePointer) + c.applyRegisterToMemoryInstruction(arm64.AMOVD, + oldStackBasePointer, + reservedRegisterForEngine, engineValueStackContextStackBasePointerOffset) + + // 3) Set rc.next to specify which function is executed on the current call frame. + // + // First, we read the address of the first item of engine.compiledFunctions slice (= &engine.compiledFunctions[0]) + // into tmp. + c.applyMemoryToRegisterInstruction(arm64.AMOVD, + reservedRegisterForEngine, engineGlobalContextCompiledFunctionsElement0AddressOffset, + tmp) + + // Next, read the address of the target function (= &engine.compiledFunctions[offset]) + // into compiledFunctionAddressRegister. + c.applyMemoryToRegisterInstruction(arm64.AMOVD, + tmp, int64(addr)*8, // * 8 because the size of *compiledFunction equals 8 bytes. + compiledFunctionAddressRegister) + + // Finally, we are ready to write the address of the target function's *compiledFunction into the new callframe. + c.applyRegisterToMemoryInstruction(arm64.AMOVD, + compiledFunctionAddressRegister, + callFrameStackTopAddressRegister, callFrameCompiledFunctionOffset) + + // 4) Set ra.current so that we can return back to this function properly. + // + // First, get the return address into the tmp. + c.readInstructionAddress(obj.AJMP, tmp) + // Then write the address into the callframe. + c.applyRegisterToMemoryInstruction(arm64.AMOVD, + tmp, + // "ra.current" is BELOW the top address. See the above example for detail. + callFrameStackTopAddressRegister, -(callFrameDataSize - callFrameReturnAddressOffset), + ) + + // Everything is done to make function call now. + // We increment the callframe stack pointer. 
+ c.applyMemoryToRegisterInstruction(arm64.AMOVD, + reservedRegisterForEngine, engineGlobalContextCallFrameStackPointerOffset, + tmp) + c.applyConstToRegisterInstruction(arm64.AADD, 1, tmp) + c.applyRegisterToMemoryInstruction(arm64.AMOVD, + tmp, + reservedRegisterForEngine, engineGlobalContextCallFrameStackPointerOffset) + + // Then, br into the target function's initial address. + c.applyMemoryToRegisterInstruction(arm64.AMOVD, + compiledFunctionAddressRegister, compiledFunctionCodeInitialAddressOffset, + tmp) + c.emitUnconditionalBranchToAddressOnRegister(tmp) + + // All the registers used are temporary so we mark them unused. + c.markRegisterUnused(freeRegisters...) + + // We consumed the function parameters from the stack after call. + for i := 0; i < len(functype.Params); i++ { + c.locationStack.pop() + } + + // Also, the function results were pushed by the call. + for _, t := range functype.Results { + loc := c.locationStack.pushValueLocationOnStack() + switch t { + case wasm.ValueTypeI32, wasm.ValueTypeI64: + loc.setRegisterType(generalPurposeRegisterTypeInt) + case wasm.ValueTypeF32, wasm.ValueTypeF64: + loc.setRegisterType(generalPurposeRegisterTypeFloat) + } + } + + // On the function return, we initialize the state for this function. + c.initializeReservedStackBasePointerRegister() + + // TODO: initialize module context, and memory pointer. + return nil +} + +// readInstructionAddress adds an ADR instruction to set the absolute address of "target instruction" +// into destinationRegister. "target instruction" is specified by beforeTargetInst argument and +// the target is determined by "the instruction right after beforeTargetInst type". +// +// For example, if beforeTargetInst == RET and we have the instruction sequence like +// ADR -> X -> Y -> ... -> RET -> MOV, then the ADR instruction emitted by this function set the absolute +// address of MOV instruction into the destination register. 
+func (c *arm64Compiler) readInstructionAddress(beforeTargetInst obj.As, destinationRegister int16) { + // Emit ADR instruction to read the specified instruction's absolute address. + // Note: we cannot emit the "ADR REG, $(target's offset from here)" due to the + // incapability of the assembler. Instead, we emit "ADR REG, ." meaning that + // "reading the current program counter" = "reading the absolute address of this ADR instruction". + // And then, after compilation phase, we directly edit the native code slice so that + // it can properly read the target instruction's absolute address. + readAddress := c.newProg() + readAddress.As = arm64.AADR + readAddress.From.Type = obj.TYPE_BRANCH + readAddress.To.Type = obj.TYPE_REG + readAddress.To.Reg = destinationRegister + c.addInstruction(readAddress) + + // Setup the callback to modify the instruction bytes after compilation. + // Note: this is the closure over readAddress (*obj.Prog). + c.afterAssembleCallback = append(c.afterAssembleCallback, func(code []byte) error { + // Find the target instruction. + target := readAddress + for target != nil { + if target.As == beforeTargetInst { + // At this point, target is the instruction right before the target instruction. + // Thus, advance one more time to make target the target instruction. + target = target.Link + break + } + target = target.Link + } + + if target == nil { + return fmt.Errorf("BUG: target instruction not found for read instruction address") + } + + offset := target.Pc - readAddress.Pc + if offset > math.MaxUint8 { + // We could support up to 20-bit integer, but byte should be enough for our impl. + // If the necessity comes up, we could fix the below to support larger offsets. + return fmt.Errorf("BUG: too large offset for read") + } + + // Now ready to write an offset byte. + v := byte(offset) + // arm64 has 4-bytes = 32-bit fixed-length instruction. 
+ adrInstructionBytes := code[readAddress.Pc : readAddress.Pc+4] + // According to the binary format of ADR instruction in arm64: + // https://developer.arm.com/documentation/ddi0596/2021-12/Base-Instructions/ADR--Form-PC-relative-address-?lang=en + // + // The 0 to 1 bits live on 29 to 30 bits of the instruction. + adrInstructionBytes[3] |= (v & 0b00000011) << 5 + // The 2 to 4 bits live on 5 to 7 bits of the instruction. + adrInstructionBytes[0] |= (v & 0b00011100) << 3 + // The 5 to 7 bits live on 8 to 10 bits of the instruction. + adrInstructionBytes[1] |= (v & 0b11100000) >> 5 + return nil + }) } func (c *arm64Compiler) compileCallIndirect(o *wazeroir.OperationCallIndirect) error { @@ -1398,12 +1687,18 @@ func (c *arm64Compiler) pushZeroValue() { // popTwoValuesOnRegisters pops two values from the location stacks, ensures // these two values are located on registers, and mark them unused. func (c *arm64Compiler) popTwoValuesOnRegisters() (x1, x2 *valueLocation, err error) { - x2, err = c.popValueOnRegister() - if err != nil { + x2 = c.locationStack.pop() + if err = c.ensureOnGeneralPurposeRegister(x2); err != nil { return } - x1, err = c.popValueOnRegister() + x1 = c.locationStack.pop() + if err = c.ensureOnGeneralPurposeRegister(x1); err != nil { + return + } + + c.markRegisterUnused(x2.register) + c.markRegisterUnused(x1.register) return } @@ -1478,15 +1773,7 @@ func (c *arm64Compiler) loadValueOnStackToRegister(loc *valueLocation) (err erro return } - if offset := int64(loc.stackPointer) * 8; offset > math.MaxInt16 { - // The assembler can take care of offsets larger than 2^15-1 by emitting additional instructions to load such large offset, - // but it uses "its" temporary register which we cannot track. Therefore, we avoid directly emitting memory load with large offsets, - // but instead load the constant manually to "our" temporary register, then emit the load with it. 
- c.applyConstToRegisterInstruction(arm64.AMOVD, offset, reservedRegisterForTemporary) - c.applyRegisterOffsetMemoryToRegisterInstruction(inst, reservedRegisterForStackBasePointerAddress, reservedRegisterForTemporary, reg) - } else { - c.applyMemoryToRegisterInstruction(inst, reservedRegisterForStackBasePointerAddress, offset, reg) - } + c.applyMemoryToRegisterInstruction(inst, reservedRegisterForStackBasePointerAddress, int64(loc.stackPointer)*8, reg) // Record that the value holds the register and the register is marked used. loc.setRegister(reg) @@ -1545,17 +1832,7 @@ func (c *arm64Compiler) releaseRegisterToStack(loc *valueLocation) (err error) { inst = arm64.AFMOVD } - if offset := int64(loc.stackPointer) * 8; offset > math.MaxInt16 { - // The assembler can take care of offsets larger than 2^15-1 by emitting additional instructions to load such large offset, - // but it uses "its" temporary register which we cannot track. Therefore, we avoid directly emitting memory load with large offsets, - // but instead load the constant manually to "our" temporary register, then emit the load with it. - c.applyConstToRegisterInstruction(arm64.AMOVD, offset, reservedRegisterForTemporary) - c.applyRegisterToRegisterOffsetMemoryInstruction(inst, reservedRegisterForStackBasePointerAddress, reservedRegisterForTemporary, loc.register) - } else { - if err = c.applyRegisterToMemoryInstruction(inst, reservedRegisterForStackBasePointerAddress, offset, loc.register); err != nil { - return - } - } + c.applyRegisterToMemoryInstruction(inst, loc.register, reservedRegisterForStackBasePointerAddress, int64(loc.stackPointer)*8) // Mark the register is free. c.locationStack.releaseRegister(loc) @@ -1566,37 +1843,33 @@ func (c *arm64Compiler) releaseRegisterToStack(loc *valueLocation) (err error) { // so that it points to the absolute address of the stack base for this function. 
func (c *arm64Compiler) initializeReservedStackBasePointerRegister() error { // First, load the address of the first element in the value stack into reservedRegisterForStackBasePointerAddress temporarily. - if err := c.applyMemoryToRegisterInstruction(arm64.AMOVD, + c.applyMemoryToRegisterInstruction(arm64.AMOVD, reservedRegisterForEngine, engineGlobalContextValueStackElement0AddressOffset, - reservedRegisterForStackBasePointerAddress); err != nil { - return err - } + reservedRegisterForStackBasePointerAddress) - // Next we move the base pointer (engine.stackBasePointer) to the tmp register. - if err := c.applyMemoryToRegisterInstruction(arm64.AMOVD, + // Next we move the base pointer (engine.stackBasePointer) to reservedRegisterForTemporary. + c.applyMemoryToRegisterInstruction(arm64.AMOVD, reservedRegisterForEngine, engineValueStackContextStackBasePointerOffset, - reservedRegisterForTemporary); err != nil { - return err - } - - // Finally, we calculate "reservedRegisterForStackBasePointerAddress + tmpReg * 8" - // where we multiply tmpReg by 8 because stack pointer is an index in the []uint64 - // so as an bytes we must multiply the size of uint64 = 8 bytes. - calcStackBasePointerAddress := c.newProg() - calcStackBasePointerAddress.As = arm64.AADD - calcStackBasePointerAddress.To.Type = obj.TYPE_REG - calcStackBasePointerAddress.To.Reg = reservedRegisterForStackBasePointerAddress - // We calculate "tmpReg * 8" as "tmpReg << 3". - setLeftShiftedRegister(calcStackBasePointerAddress, reservedRegisterForTemporary, 3) - c.addInstruction(calcStackBasePointerAddress) + reservedRegisterForTemporary) + + // Finally, we calculate "reservedRegisterForStackBasePointerAddress + reservedRegisterForTemporary << 3" + // where we shift tmpReg by 3 because stack pointer is an index in the []uint64 + // so we must multiply the value by the size of uint64 = 8 bytes. 
+ c.emitAddInstructionWithLeftShiftedRegister( + reservedRegisterForTemporary, 3, reservedRegisterForStackBasePointerAddress, + reservedRegisterForStackBasePointerAddress) return nil } -// setShiftedRegister modifies the given *obj.Prog so that .From (source operand) -// becomes the "left shifted register". For example, this is used to emit instruction like -// "add x1, x2, x3, lsl #3" which means "x1 = x2 + (x3 << 3)". -// See https://github.com/twitchyliquid64/golang-asm/blob/v0.15.1/obj/link.go#L120-L131 -func setLeftShiftedRegister(inst *obj.Prog, register int16, shiftNum int64) { +// emitAddInstructionWithLeftShiftedRegister emits an ADD instruction to perform "destinationReg = srcReg + (shiftedSourceReg << shiftNum)". +func (c *arm64Compiler) emitAddInstructionWithLeftShiftedRegister(shiftedSourceReg int16, shiftNum int64, srcReg, destinationReg int16) { + inst := c.newProg() + inst.As = arm64.AADD + inst.To.Type = obj.TYPE_REG + inst.To.Reg = destinationReg + // See https://github.com/twitchyliquid64/golang-asm/blob/v0.15.1/obj/link.go#L120-L131 inst.From.Type = obj.TYPE_SHIFT - inst.From.Offset = (int64(register)&31)<<16 | 0<<22 | (shiftNum&63)<<10 + inst.From.Offset = (int64(shiftedSourceReg)&31)<<16 | 0<<22 | (shiftNum&63)<<10 + inst.Reg = srcReg + c.addInstruction(inst) } diff --git a/wasm/jit/jit_arm64_test.go b/wasm/jit/jit_arm64_test.go index fa291f3e4c..965badc800 100644 --- a/wasm/jit/jit_arm64_test.go +++ b/wasm/jit/jit_arm64_test.go @@ -4,7 +4,6 @@ package jit import ( - "context" "fmt" "math" "math/bits" @@ -50,58 +49,6 @@ func (j *jitEnv) requireNewCompiler(t *testing.T) *arm64Compiler { return ret } -// TODO: delete this as this could be a duplication from other tests especially spectests. -// Use this until we could run spectests on arm64. 
-func TestArm64CompilerEndToEnd(t *testing.T) { - ctx := context.Background() - for _, tc := range []struct { - name string - body []byte - sig *wasm.FunctionType - }{ - {name: "empty", body: []byte{wasm.OpcodeEnd}, sig: &wasm.FunctionType{}}, - {name: "br .return", body: []byte{wasm.OpcodeBr, 0, wasm.OpcodeEnd}, sig: &wasm.FunctionType{}}, - { - name: "consts", - body: []byte{ - wasm.OpcodeI32Const, 1, wasm.OpcodeI64Const, 1, - wasm.OpcodeF32Const, 1, 1, 1, 1, wasm.OpcodeF64Const, 1, 2, 3, 4, 5, 6, 7, 8, - wasm.OpcodeEnd, - }, - // We push four constants. - sig: &wasm.FunctionType{Results: []wasm.ValueType{wasm.ValueTypeI32, wasm.ValueTypeI64, wasm.ValueTypeF32, wasm.ValueTypeF64}}, - }, - { - name: "add", - body: []byte{wasm.OpcodeI32Const, 1, wasm.OpcodeI32Const, 1, wasm.OpcodeI32Add, wasm.OpcodeEnd}, - sig: &wasm.FunctionType{Results: []wasm.ValueType{wasm.ValueTypeI32}}, - }, - { - name: "sub", - body: []byte{wasm.OpcodeI64Const, 1, wasm.OpcodeI64Const, 1, wasm.OpcodeI64Sub, wasm.OpcodeEnd}, - sig: &wasm.FunctionType{Results: []wasm.ValueType{wasm.ValueTypeI64}}, - }, - { - name: "mul", - body: []byte{wasm.OpcodeI64Const, 1, wasm.OpcodeI64Const, 1, wasm.OpcodeI64Mul, wasm.OpcodeEnd}, - sig: &wasm.FunctionType{Results: []wasm.ValueType{wasm.ValueTypeI64}}, - }, - } { - tc := tc - t.Run(tc.name, func(t *testing.T) { - engine := newEngine() - f := &wasm.FunctionInstance{ - FunctionType: &wasm.TypeInstance{Type: tc.sig}, - Body: tc.body, - } - err := engine.Compile(f) - require.NoError(t, err) - _, err = engine.Call(ctx, f) - require.NoError(t, err) - }) - } -} - func TestArchContextOffsetInEngine(t *testing.T) { var eng engine // If this fails, we have to fix jit_arm64.s as well. @@ -109,25 +56,77 @@ func TestArchContextOffsetInEngine(t *testing.T) { } func TestArm64Compiler_returnFunction(t *testing.T) { - env := newJITEnvironment() + t.Run("exit", func(t *testing.T) { + env := newJITEnvironment() - // Build code. 
- compiler := env.requireNewCompiler(t) - err := compiler.emitPreamble() - require.NoError(t, err) - compiler.returnFunction() + // Build code. + compiler := env.requireNewCompiler(t) + err := compiler.emitPreamble() + require.NoError(t, err) + compiler.returnFunction() - // Generate the code under test. - code, _, _, err := compiler.compile() - require.NoError(t, err) + code, _, _, err := compiler.compile() + require.NoError(t, err) + + env.exec(code) + + // JIT status on engine must be returned. + require.Equal(t, jitCallStatusCodeReturned, env.jitStatus()) + // Plus, the call frame stack pointer must be zero after return. + require.Equal(t, uint64(0), env.callFrameStackPointer()) + }) + t.Run("deep call stack", func(t *testing.T) { + env := newJITEnvironment() + engine := env.engine() + + // Push the call frames. + const callFrameNums = 10 + stackPointerToExpectedValue := map[uint64]uint32{} + for funcaddr := wasm.FunctionAddress(0); funcaddr < callFrameNums; funcaddr++ { + // Each function pushes its funcaddr and soon returns. + compiler := env.requireNewCompiler(t) + err := compiler.emitPreamble() + require.NoError(t, err) + + // Push its funcaddr. + expValue := uint32(funcaddr) + err = compiler.compileConstI32(&wazeroir.OperationConstI32{Value: expValue}) + require.NoError(t, err) + + err = compiler.returnFunction() + require.NoError(t, err) + + code, _, _, err := compiler.compile() + require.NoError(t, err) + + // Compiles and adds to the engine. + compiledFunction := &compiledFunction{codeSegment: code, codeInitialAddress: uintptr(unsafe.Pointer(&code[0]))} + engine.addCompiledFunction(funcaddr, compiledFunction) + + // Pushes the frame whose return address equals the beginning of the function just compiled. + frame := callFrame{ + // Set the return address to the beginning of the function so that we can execute the constI32 above. 
+ returnAddress: compiledFunction.codeInitialAddress, + // Note: return stack base pointer is set to funcaddr*10 and this is where the const should be pushed. + returnStackBasePointer: uint64(funcaddr) * 10, + compiledFunction: compiledFunction, + } + engine.callFrameStack[engine.globalContext.callFrameStackPointer] = frame + engine.globalContext.callFrameStackPointer++ + stackPointerToExpectedValue[frame.returnStackBasePointer] = expValue + } + + require.Equal(t, uint64(callFrameNums), env.callFrameStackPointer()) - // Run native code. - env.exec(code) + // Run code from the top frame. + env.exec(engine.callFrameTop().compiledFunction.codeSegment) - // JIT status on engine must be returned. - require.Equal(t, jitCallStatusCodeReturned, env.jitStatus()) - // Plus, the call frame stack pointer must be zero after return. - require.Equal(t, uint64(0), env.callFrameStackPointer()) + // Check the exit status and the values on stack. + require.Equal(t, jitCallStatusCodeReturned, env.jitStatus()) + for pos, exp := range stackPointerToExpectedValue { + require.Equal(t, exp, uint32(env.stack()[pos])) + } + }) } func TestArm64Compiler_exit(t *testing.T) { @@ -273,7 +272,7 @@ func TestArm64Compiler_releaseRegisterToStack(t *testing.T) { // Release the register allocated value to the memory stack so that we can see the value after exiting. compiler.releaseRegisterToStack(compiler.locationStack.peek()) - compiler.returnFunction() + compiler.exit(jitCallStatusCodeReturned) // Generate the code under test. code, _, _, err := compiler.compile() @@ -352,7 +351,7 @@ func TestArm64Compiler_loadValueOnStackToRegister(t *testing.T) { // Release the value to the memory stack so that we can see the value after exiting. compiler.releaseRegisterToStack(loc) - compiler.returnFunction() + compiler.exit(jitCallStatusCodeReturned) // Generate the code under test. 
code, _, _, err := compiler.compile() @@ -1421,8 +1420,6 @@ func TestArm64Compiler_compileBrIf(t *testing.T) { require.NoError(t, err) }, }, - // {name: "EQ"} TODO: after compileEq support - // {name: "NE"} TODO: after compileNe support { name: "HS", setupFunc: func(t *testing.T, compiler *arm64Compiler, shoulGoElse bool) { @@ -1514,6 +1511,30 @@ func TestArm64Compiler_compileBrIf(t *testing.T) { require.NoError(t, err) }, }, + { + name: "EQ", + setupFunc: func(t *testing.T, compiler *arm64Compiler, shoulGoElse bool) { + x1, x2 := uint32(1), uint32(1) + if shoulGoElse { + x2++ + } + requirePushTwoInt32Consts(t, x1, x2, compiler) + err := compiler.compileEq(&wazeroir.OperationEq{Type: wazeroir.UnsignedTypeI32}) + require.NoError(t, err) + }, + }, + { + name: "NE", + setupFunc: func(t *testing.T, compiler *arm64Compiler, shoulGoElse bool) { + x1, x2 := uint32(1), uint32(2) + if shoulGoElse { + x2 = x1 + } + requirePushTwoInt32Consts(t, x1, x2, compiler) + err := compiler.compileNe(&wazeroir.OperationNe{Type: wazeroir.UnsignedTypeI32}) + require.NoError(t, err) + }, + }, } { tc := tc t.Run(tc.name, func(t *testing.T) { @@ -1566,3 +1587,159 @@ func TestArm64Compiler_compileBrIf(t *testing.T) { }) } } + +func TestArm64Compiler_readInstructionAddress(t *testing.T) { + t.Run("target instruction not found", func(t *testing.T) { + env := newJITEnvironment() + compiler := env.requireNewCompiler(t) + + err := compiler.emitPreamble() + require.NoError(t, err) + + // Set the acquisition target instruction to the one after JMP. + compiler.readInstructionAddress(obj.AJMP, reservedRegisterForTemporary) + + compiler.exit(jitCallStatusCodeReturned) + + // If generate the code without JMP after readInstructionAddress, + // the call back added must return error. 
+ _, _, _, err = compiler.compile() + require.Error(t, err) + require.Contains(t, err.Error(), "target instruction not found") + }) + t.Run("too large offset", func(t *testing.T) { + env := newJITEnvironment() + compiler := env.requireNewCompiler(t) + + err := compiler.emitPreamble() + require.NoError(t, err) + + // Set the acquisition target instruction to the one after RET. + compiler.readInstructionAddress(obj.ARET, reservedRegisterForTemporary) + + // Add many instructions between the target and readInstructionAddress. + for i := 0; i < 100; i++ { + compiler.compileConstI32(&wazeroir.OperationConstI32{Value: 10}) + } + + ret := compiler.newProg() + ret.As = obj.ARET + ret.To.Type = obj.TYPE_REG + ret.To.Reg = reservedRegisterForTemporary + compiler.returnFunction() + + // If we generate the code with too many instructions between ADR and + // the target, compile must fail. + _, _, _, err = compiler.compile() + require.Error(t, err) + require.Contains(t, err.Error(), "too large offset") + }) + t.Run("ok", func(t *testing.T) { + env := newJITEnvironment() + compiler := env.requireNewCompiler(t) + + err := compiler.emitPreamble() + require.NoError(t, err) + + // Set the acquisition target instruction to the one after RET, + // and read the absolute address into destinationRegister. + const addressReg = reservedRegisterForTemporary + compiler.readInstructionAddress(obj.ARET, addressReg) + + // Branch to the instruction after RET below via the absolute + // address stored in destinationRegister. + compiler.emitUnconditionalBranchToAddressOnRegister(addressReg) + + // If we fail to branch, we reach here and exit with unreachable status, + // so the assertion would fail. + compiler.exit(jitCallStatusCodeUnreachable) + + // This could be the read instruction target as this is the + // one right after RET. Therefore, the branch instruction above + // must target here.
+ err = compiler.returnFunction() + require.NoError(t, err) + + code, _, _, err := compiler.compile() + require.NoError(t, err) + + env.exec(code) + + require.Equal(t, jitCallStatusCodeReturned, env.jitStatus()) + }) +} + +func TestArm64Compiler_compileCall(t *testing.T) { + t.Run("need to grow call frame stack", func(t *testing.T) { + t.Skip("TODO") + }) + t.Run("callframe stack ok", func(t *testing.T) { + env := newJITEnvironment() + engine := env.engine() + expectedValue := uint32(0) + + // Emit the call target function. + const numCalls = 10 + targetFunctionType := &wasm.FunctionType{ + Params: []wasm.ValueType{wasm.ValueTypeI32}, + Results: []wasm.ValueType{wasm.ValueTypeI32}, + } + for i := 0; i < numCalls; i++ { + // Each function takes one argument, adds 100 + i to it, and returns the result. + addTargetValue := uint32(100 + i) + expectedValue += addTargetValue + + compiler := env.requireNewCompiler(t) + compiler.f = &wasm.FunctionInstance{FunctionType: &wasm.TypeInstance{Type: targetFunctionType}} + + err := compiler.emitPreamble() + require.NoError(t, err) + + err = compiler.compileConstI32(&wazeroir.OperationConstI32{Value: uint32(addTargetValue)}) + require.NoError(t, err) + err = compiler.compileAdd(&wazeroir.OperationAdd{Type: wazeroir.UnsignedTypeI32}) + require.NoError(t, err) + err = compiler.returnFunction() + require.NoError(t, err) + + code, _, _, err := compiler.compile() + require.NoError(t, err) + engine.addCompiledFunction(wasm.FunctionAddress(i), &compiledFunction{ + codeSegment: code, + codeInitialAddress: uintptr(unsafe.Pointer(&code[0])), + }) + } + + // Now we start building the caller's code. + compiler := env.requireNewCompiler(t) + err := compiler.emitPreamble() + require.NoError(t, err) + + const initialValue = 100 + expectedValue += initialValue + err = compiler.compileConstI32(&wazeroir.OperationConstI32{Value: 0}) // Dummy value so the base pointer would be non-trivial for callees.
+ require.NoError(t, err) + err = compiler.compileConstI32(&wazeroir.OperationConstI32{Value: initialValue}) + require.NoError(t, err) + + // Call all the built functions. + for i := 0; i < numCalls; i++ { + err = compiler.callFunction(wasm.FunctionAddress(i), targetFunctionType) + require.NoError(t, err) + } + + err = compiler.returnFunction() + require.NoError(t, err) + + code, _, _, err := compiler.compile() + require.NoError(t, err) + + env.exec(code) + + // Check status and returned values. + require.Equal(t, jitCallStatusCodeReturned, env.jitStatus()) + require.Equal(t, uint64(2), env.stackPointer()) // Must be 2 (dummy value + the calculation results) + require.Equal(t, uint64(0), env.stackBasePointer()) + require.Equal(t, expectedValue, env.stackTopAsUint32()) + }) +}