Skip to content

Commit

Permalink
[WebAssembly] Handle block and polymorphic stack in AsmTypeCheck
Browse files Browse the repository at this point in the history
This makes the type checker handle blocks with input parameters and
return types, branches, and polymorphic stacks correctly.

We maintain a stack of "block info" entries, each of which records the
block's signature (input parameter types and return types), the position
in the value stack where the block starts, and whether the block is a
loop. This stack is used when checking the validity of the value stack at
the `end` marker and at all branches targeting the block.

`StackType` now supports a new variant `Polymorphic`, which indicates
the stack is in the polymorphic state. `Polymorphic`s are not popped
even when `popType` is executed; they are only popped when the current
block ends.

When popping from the value stack, we ensure we don't pop more than we
are allowed to at the given block level and print appropriate error
messages instead. Also after a block ends, the value stack is guaranteed
to have the right types based on the block return type. For example,
```wast
block i32
  unreachable
end_block
;; You can expect to have an i32 on the stack here
```

This also adds handling for `br_if`. Previously only `br`s were checked.

`checkEnd` and `checkBr` were removed and their contents have been
inlined to the main `typeCheck` function, because they are called only
from a single callsite.

This also fixes two existing bugs in AsmParser; these fixes were required
to make the tests pass.

This modifies several existing invalid tests that (incorrectly) passed
before but no longer pass with the new type checker.

Fixes llvm#107524.
  • Loading branch information
aheejin committed Oct 2, 2024
1 parent 7b23468 commit e95ddaf
Show file tree
Hide file tree
Showing 7 changed files with 342 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,9 @@ class WebAssemblyAsmParser final : public MCTargetAsmParser {

void addBlockTypeOperand(OperandVector &Operands, SMLoc NameLoc,
WebAssembly::BlockType BT) {
if (BT != WebAssembly::BlockType::Void) {
if (BT == WebAssembly::BlockType::Void) {
TC.setLastSig(wasm::WasmSignature{});
} else {
wasm::WasmSignature Sig({static_cast<wasm::ValType>(BT)}, {});
TC.setLastSig(Sig);
NestingStack.back().Sig = Sig;
Expand Down Expand Up @@ -1002,7 +1004,8 @@ class WebAssemblyAsmParser final : public MCTargetAsmParser {
auto *Signature = Ctx.createWasmSignature();
if (parseSignature(Signature))
return ParseStatus::Failure;
TC.funcDecl(*Signature);
if (CurrentState == FunctionStart)
TC.funcDecl(*Signature);
WasmSym->setSignature(Signature);
WasmSym->setType(wasm::WASM_SYMBOL_TYPE_FUNCTION);
TOut.emitFunctionType(WasmSym);
Expand Down
176 changes: 115 additions & 61 deletions llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ WebAssemblyAsmTypeCheck::WebAssemblyAsmTypeCheck(MCAsmParser &Parser,

void WebAssemblyAsmTypeCheck::funcDecl(const wasm::WasmSignature &Sig) {
LocalTypes.assign(Sig.Params.begin(), Sig.Params.end());
ReturnTypes.assign(Sig.Returns.begin(), Sig.Returns.end());
BrStack.emplace_back(Sig.Returns.begin(), Sig.Returns.end());
BlockInfoStack.push_back({Sig, 0, false});
}

void WebAssemblyAsmTypeCheck::localDecl(
Expand All @@ -64,14 +63,15 @@ void WebAssemblyAsmTypeCheck::dumpTypeStack(Twine Msg) {
}

bool WebAssemblyAsmTypeCheck::typeError(SMLoc ErrorLoc, const Twine &Msg) {
// If we're currently in unreachable code, we suppress errors completely.
if (Unreachable)
return false;
dumpTypeStack("current stack: ");
return Parser.Error(ErrorLoc, Msg);
}

bool WebAssemblyAsmTypeCheck::match(StackType TypeA, StackType TypeB) {
// These should have been filtered out in checkTypes()
assert(!std::get_if<Polymorphic>(&TypeA) &&
!std::get_if<Polymorphic>(&TypeB));

if (TypeA == TypeB)
return false;
if (std::get_if<Any>(&TypeA) || std::get_if<Any>(&TypeB))
Expand All @@ -90,6 +90,10 @@ std::string WebAssemblyAsmTypeCheck::getTypesString(ArrayRef<StackType> Types,
size_t StartPos) {
SmallVector<std::string, 4> TypeStrs;
for (auto I = Types.size(); I > StartPos; I--) {
if (std::get_if<Polymorphic>(&Types[I - 1])) {
TypeStrs.push_back("...");
break;
}
if (std::get_if<Any>(&Types[I - 1]))
TypeStrs.push_back("any");
else if (std::get_if<Ref>(&Types[I - 1]))
Expand Down Expand Up @@ -131,29 +135,48 @@ bool WebAssemblyAsmTypeCheck::checkTypes(SMLoc ErrorLoc,
bool ExactMatch) {
auto StackI = Stack.size();
auto TypeI = Types.size();
assert(!BlockInfoStack.empty());
auto BlockStackStartPos = BlockInfoStack.back().StackStartPos;
bool Error = false;
bool PolymorphicStack = false;
// Compare elements one by one from the stack top
for (; StackI > 0 && TypeI > 0; StackI--, TypeI--) {
for (;StackI > BlockStackStartPos && TypeI > 0; StackI--, TypeI--) {
// If the stack is polymorphic, we assume all types in 'Types' have been
// compared and matched
if (std::get_if<Polymorphic>(&Stack[StackI - 1])) {
TypeI = 0;
break;
}
if (match(Stack[StackI - 1], Types[TypeI - 1])) {
Error = true;
break;
}
}

// If the stack top is polymorphic, the stack is in the polymorphic state.
if (StackI > BlockStackStartPos &&
std::get_if<Polymorphic>(&Stack[StackI - 1]))
PolymorphicStack = true;

// Even if no match failure has happened in the loop above, if not all
// elements of Types has been matched, that means we don't have enough
// elements on the stack.
//
// Also, if not all elements of the Stack has been matched and when
// 'ExactMatch' is true, that means we have superfluous elements remaining on
// the stack (e.g. at the end of a function).
if (TypeI > 0 || (ExactMatch && StackI > 0))
// 'ExactMatch' is true and the current stack is not polymorphic, that means
// we have superfluous elements remaining on the stack (e.g. at the end of a
// function).
if (TypeI > 0 ||
(ExactMatch && !PolymorphicStack && StackI > BlockStackStartPos))
Error = true;

if (!Error)
return false;

auto StackStartPos =
ExactMatch ? 0 : std::max(0, (int)Stack.size() - (int)Types.size());
auto StackStartPos = ExactMatch
? BlockStackStartPos
: std::max((int)BlockStackStartPos,
(int)Stack.size() - (int)Types.size());
return typeError(ErrorLoc, "type mismatch, expected " +
getTypesString(Types, 0) + " but got " +
getTypesString(Stack, StackStartPos));
Expand All @@ -169,9 +192,13 @@ bool WebAssemblyAsmTypeCheck::popTypes(SMLoc ErrorLoc,
ArrayRef<StackType> Types,
bool ExactMatch) {
bool Error = checkTypes(ErrorLoc, Types, ExactMatch);
auto NumPops = std::min(Stack.size(), Types.size());
for (size_t I = 0, E = NumPops; I != E; I++)
auto NumPops = std::min(Stack.size() - BlockInfoStack.back().StackStartPos,
Types.size());
for (size_t I = 0, E = NumPops; I != E; I++) {
if (std::get_if<Polymorphic>(&Stack.back()))
break;
Stack.pop_back();
}
return Error;
}

Expand Down Expand Up @@ -201,25 +228,6 @@ bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCOperand &LocalOp,
return false;
}

bool WebAssemblyAsmTypeCheck::checkBr(SMLoc ErrorLoc, size_t Level) {
if (Level >= BrStack.size())
return typeError(ErrorLoc,
StringRef("br: invalid depth ") + std::to_string(Level));
const SmallVector<wasm::ValType, 4> &Expected =
BrStack[BrStack.size() - Level - 1];
return checkTypes(ErrorLoc, Expected);
return false;
}

bool WebAssemblyAsmTypeCheck::checkEnd(SMLoc ErrorLoc, bool PopVals) {
if (!PopVals)
BrStack.pop_back();

if (PopVals)
return popTypes(ErrorLoc, LastSig.Returns);
return checkTypes(ErrorLoc, LastSig.Returns);
}

bool WebAssemblyAsmTypeCheck::checkSig(SMLoc ErrorLoc,
const wasm::WasmSignature &Sig) {
bool Error = popTypes(ErrorLoc, Sig.Params);
Expand Down Expand Up @@ -308,10 +316,11 @@ bool WebAssemblyAsmTypeCheck::getSignature(SMLoc ErrorLoc,
return false;
}

bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc, bool ExactMatch) {
bool Error = popTypes(ErrorLoc, ReturnTypes, ExactMatch);
Unreachable = true;
return Error;
bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc,
bool ExactMatch) {
assert(!BlockInfoStack.empty());
const auto &FuncInfo = BlockInfoStack[0];
return checkTypes(ErrorLoc, FuncInfo.Sig.Returns, ExactMatch);
}

bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
Expand Down Expand Up @@ -453,51 +462,90 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
}

if (Name == "try" || Name == "block" || Name == "loop" || Name == "if") {
if (Name == "loop")
BrStack.emplace_back(LastSig.Params.begin(), LastSig.Params.end());
else
BrStack.emplace_back(LastSig.Returns.begin(), LastSig.Returns.end());
if (Name == "if" && popType(ErrorLoc, wasm::ValType::I32))
return true;
return false;
bool Error = Name == "if" && popType(ErrorLoc, wasm::ValType::I32);
// Pop block input parameters and check their types are correct
Error |= popTypes(ErrorLoc, LastSig.Params);
// Push a new block info
BlockInfoStack.push_back({LastSig, Stack.size(), Name == "loop"});
// Push back block input parameters
pushTypes(LastSig.Params);
return Error;
}

if (Name == "end_block" || Name == "end_loop" || Name == "end_if" ||
Name == "else" || Name == "end_try" || Name == "catch" ||
Name == "catch_all" || Name == "delegate") {
bool Error = checkEnd(ErrorLoc, Name == "else" || Name == "catch" ||
Name == "catch_all");
Unreachable = false;
if (Name == "catch") {
assert(!BlockInfoStack.empty());
// Check if the types on the stack match with the block return type
const auto &LastBlockInfo = BlockInfoStack.back();
bool Error = checkTypes(ErrorLoc, LastBlockInfo.Sig.Returns, true);
// Pop all types added to the stack for the current block level
Stack.truncate(LastBlockInfo.StackStartPos);
if (Name == "else") {
// 'else' expects the block input parameters to be on the stack, in the
// same way we entered 'if'
pushTypes(LastBlockInfo.Sig.Params);
} else if (Name == "catch") {
// 'catch' instruction pushes values whose types are specified in the
// tag's 'params' part
const wasm::WasmSignature *Sig = nullptr;
if (!getSignature(Operands[1]->getStartLoc(), Inst.getOperand(0),
wasm::WASM_SYMBOL_TYPE_TAG, Sig))
// catch instruction pushes values whose types are specified in the
// tag's "params" part
pushTypes(Sig->Params);
else
Error = true;
} else if (Name == "catch_all") {
// 'catch_all' does not push anything onto the stack
} else {
// For normal end markers, push block return value types onto the stack
// and pop the block info
pushTypes(LastBlockInfo.Sig.Returns);
BlockInfoStack.pop_back();
}
return Error;
}

if (Name == "br") {
if (Name == "br" || Name == "br_if") {
bool Error = false;
if (Name == "br_if")
Error |= popType(ErrorLoc, wasm::ValType::I32); // cond
const MCOperand &Operand = Inst.getOperand(0);
if (!Operand.isImm())
return true;
return checkBr(ErrorLoc, static_cast<size_t>(Operand.getImm()));
if (Operand.isImm()) {
unsigned Level = Operand.getImm();
if (Level < BlockInfoStack.size()) {
const auto &DestBlockInfo =
BlockInfoStack[BlockInfoStack.size() - Level - 1];
if (DestBlockInfo.IsLoop)
Error |= checkTypes(ErrorLoc, DestBlockInfo.Sig.Params, false);
else
Error |= checkTypes(ErrorLoc, DestBlockInfo.Sig.Returns, false);
} else {
Error = typeError(ErrorLoc, StringRef("br: invalid depth ") +
std::to_string(Level));
}
} else {
Error =
typeError(Operands[1]->getStartLoc(), "depth should be an integer");
}
if (Name == "br")
pushType(Polymorphic{});
return Error;
}

if (Name == "return") {
return endOfFunction(ErrorLoc, false);
bool Error = endOfFunction(ErrorLoc, false);
pushType(Polymorphic{});
return Error;
}

if (Name == "call_indirect" || Name == "return_call_indirect") {
// Function value.
bool Error = popType(ErrorLoc, wasm::ValType::I32);
Error |= checkSig(ErrorLoc, LastSig);
if (Name == "return_call_indirect" && endOfFunction(ErrorLoc, false))
return true;
if (Name == "return_call_indirect") {
Error |= endOfFunction(ErrorLoc, false);
pushType(Polymorphic{});
}
return Error;
}

Expand All @@ -509,13 +557,15 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
Error |= checkSig(ErrorLoc, *Sig);
else
Error = true;
if (Name == "return_call" && endOfFunction(ErrorLoc, false))
return true;
if (Name == "return_call") {
Error |= endOfFunction(ErrorLoc, false);
pushType(Polymorphic{});
}
return Error;
}

if (Name == "unreachable") {
Unreachable = true;
pushType(Polymorphic{});
return false;
}

Expand All @@ -526,11 +576,15 @@ bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
}

if (Name == "throw") {
bool Error = false;
const wasm::WasmSignature *Sig = nullptr;
if (!getSignature(Operands[1]->getStartLoc(), Inst.getOperand(0),
wasm::WASM_SYMBOL_TYPE_TAG, Sig))
return checkSig(ErrorLoc, *Sig);
return true;
Error |= checkSig(ErrorLoc, *Sig);
else
Error = true;
pushType(Polymorphic{});
return Error;
}

// The current instruction is a stack instruction which doesn't have
Expand Down
18 changes: 9 additions & 9 deletions llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,17 @@ class WebAssemblyAsmTypeCheck final {

struct Ref : public std::monostate {};
struct Any : public std::monostate {};
using StackType = std::variant<wasm::ValType, Ref, Any>;
struct Polymorphic : public std::monostate {};
using StackType = std::variant<wasm::ValType, Ref, Any, Polymorphic>;
SmallVector<StackType, 16> Stack;
SmallVector<SmallVector<wasm::ValType, 4>, 8> BrStack;
// Bookkeeping for one entry of BlockInfoStack. An entry is pushed for the
// function itself (in funcDecl) and for every block-like instruction
// ('block', 'loop', 'if', 'try') that is entered.
struct BlockInfo {
// Signature of the block: its input parameter types and return types.
wasm::WasmSignature Sig;
// Size of the value stack at the point this block was entered (0 for the
// function-level entry). Type checks and pops for this block do not go
// below this position.
size_t StackStartPos;
// True if the block was opened by 'loop'. Branches targeting a loop are
// checked against Sig.Params; branches to other blocks against Sig.Returns.
bool IsLoop;
};
SmallVector<BlockInfo, 8> BlockInfoStack;
SmallVector<wasm::ValType, 16> LocalTypes;
SmallVector<wasm::ValType, 4> ReturnTypes;
wasm::WasmSignature LastSig;
bool Unreachable = false;
bool Is64;

// checkTypes checks 'Types' against the value stack. popTypes checks 'Types'
Expand Down Expand Up @@ -68,8 +72,6 @@ class WebAssemblyAsmTypeCheck final {
void dumpTypeStack(Twine Msg);
bool typeError(SMLoc ErrorLoc, const Twine &Msg);
bool getLocal(SMLoc ErrorLoc, const MCOperand &LocalOp, wasm::ValType &Type);
bool checkEnd(SMLoc ErrorLoc, bool PopVals = false);
bool checkBr(SMLoc ErrorLoc, size_t Level);
bool checkSig(SMLoc ErrorLoc, const wasm::WasmSignature &Sig);
bool getSymRef(SMLoc ErrorLoc, const MCOperand &SymOp,
const MCSymbolRefExpr *&SymRef);
Expand All @@ -91,10 +93,8 @@ class WebAssemblyAsmTypeCheck final {

void clear() {
Stack.clear();
BrStack.clear();
BlockInfoStack.clear();
LocalTypes.clear();
ReturnTypes.clear();
Unreachable = false;
}
};

Expand Down
8 changes: 6 additions & 2 deletions llvm/test/MC/WebAssembly/basic-assembly.s
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,13 @@ test0:
i32.const 3
end_block # "switch" exit.
if # void
i32.const 0
if i32
i32.const 0
end_if
drop
else
end_if
drop
block void
i32.const 2
return
Expand Down Expand Up @@ -222,11 +224,13 @@ empty_exnref_table:
# CHECK-NEXT: i32.const 3
# CHECK-NEXT: end_block # label2:
# CHECK-NEXT: if
# CHECK-NEXT: i32.const 0
# CHECK-NEXT: if i32
# CHECK-NEXT: i32.const 0
# CHECK-NEXT: end_if
# CHECK-NEXT: drop
# CHECK-NEXT: else
# CHECK-NEXT: end_if
# CHECK-NEXT: drop
# CHECK-NEXT: block
# CHECK-NEXT: i32.const 2
# CHECK-NEXT: return
Expand Down
Loading

0 comments on commit e95ddaf

Please sign in to comment.