diff --git a/fcd/ast/ast_context.cpp b/fcd/ast/ast_context.cpp index a95c9f1..fa5effd 100644 --- a/fcd/ast/ast_context.cpp +++ b/fcd/ast/ast_context.cpp @@ -167,13 +167,16 @@ class InstToExpr : public llvm::InstVisitor { if (auto constantInt = dyn_cast(&constant)) { - assert(constantInt->getValue().ule(numeric_limits::max())); - return ctx.numeric(ctx.getIntegerType(false, (unsigned short)constantInt->getBitWidth()), constantInt->getLimitedValue()); + assert(constantInt->getBitWidth() <= numeric_limits::max()); + return ctx.numeric(ctx.getIntegerType(false, (unsigned short)constantInt->getBitWidth()), constantInt->getValue()); } if (auto expression = dyn_cast(&constant)) { - return ctx.uncachedExpressionFor(*expression->getAsInstruction()); + auto inst = expression->getAsInstruction(); + auto res = ctx.uncachedExpressionFor(*inst); + inst->deleteValue(); + return res; } if (auto structure = dyn_cast(&constant)) @@ -311,13 +314,8 @@ class InstToExpr : public llvm::InstVisitor { // special case for a + -const const auto& type = constant->getExpressionType(ctx); - unsigned idleBits = 64 - type.getBits(); - int64_t signedValue = (constant->si64 << idleBits) >> idleBits; - if (signedValue < 0) - { - // I'm pretty sure that we don't need to check for the minimum value for that type - // since a + INT_MIN is the same as a - INT_MIN. - auto positiveRight = ctx.numeric(type, static_cast(-signedValue)); + if (constant->value.isNegative()) { + auto positiveRight = ctx.numeric(type, -constant->value); return ctx.nary(NAryOperatorExpression::Subtract, left, positiveRight); } } diff --git a/fcd/ast/ast_context.h b/fcd/ast/ast_context.h index 146f6ec..308bcc1 100644 --- a/fcd/ast/ast_context.h +++ b/fcd/ast/ast_context.h @@ -162,7 +162,7 @@ class AstContext return allocate(3, cond, ifTrue, ifFalse); } - NumericExpression* numeric(const IntegerExpressionType& type, uint64_t ui) + NumericExpression* numeric(const IntegerExpressionType& type, llvm::APInt ui) { return allocate(0, type, ui); } diff --git a/fcd/ast/expressions.cpp b/fcd/ast/expressions.cpp index ad3d0f2..c572ff7 100644 --- a/fcd/ast/expressions.cpp +++ b/fcd/ast/expressions.cpp @@ -269,7 +269,7 @@ bool NumericExpression::operator==(const Expression& that) const { if (auto token = llvm::dyn_cast(&that)) { - return this->ui64 == token->ui64; + return this->value == token->value; } return false; } diff --git a/fcd/ast/expressions.h b/fcd/ast/expressions.h index 4c8d5b9..ecc7ffa 100644 --- a/fcd/ast/expressions.h +++ b/fcd/ast/expressions.h @@ -17,6 +17,7 @@ #include "not_null.h" #include +#include #include @@ -312,29 +313,20 @@ class TernaryExpression final : public Expression struct NumericExpression final : public Expression { const IntegerExpressionType& expressionType; - union - { - int64_t si64; - uint64_t ui64; - }; + llvm::APInt value; static bool classof(const ExpressionUser* node) { return node->getUserType() == Numeric; } - NumericExpression(AstContext& ctx, unsigned uses, const IntegerExpressionType& type, uint64_t ui) - : Expression(Numeric, ctx, uses), expressionType(type), ui64(ui) - { - assert(uses == 0); - } - - NumericExpression(AstContext& ctx, unsigned uses, const IntegerExpressionType& type, int64_t si) - : Expression(Numeric, ctx, uses), expressionType(type), si64(si) + NumericExpression(AstContext& ctx, unsigned uses, const IntegerExpressionType& type, llvm::APInt val) + : Expression(Numeric, ctx, uses), expressionType(type), value(val) { assert(uses == 0); + assert(value.getBitWidth() == expressionType.getBits()); } - + virtual const IntegerExpressionType& getExpressionType(AstContext&) const override { return expressionType; } virtual bool operator==(const Expression& that) const override; }; diff --git a/fcd/ast/pass_backend.cpp b/fcd/ast/pass_backend.cpp index 192d80b..27968ff 100644 --- a/fcd/ast/pass_backend.cpp +++ b/fcd/ast/pass_backend.cpp @@ -851,7 +851,7 @@ bool AstBackEnd::runOnModule(llvm::Module &m) } // sort outputNodes by virtual address, then by name - sort(outputNodes.begin(), outputNodes.end(), [](unique_ptr& a, unique_ptr& b) + std::sort(outputNodes.begin(), outputNodes.end(), [](unique_ptr& a, unique_ptr& b) { auto virtA = getVirtualAddress(*a); auto virtB = getVirtualAddress(*b); diff --git a/fcd/ast/pass_simplifyexpressions.cpp b/fcd/ast/pass_simplifyexpressions.cpp index 4016479..6813bb9 100644 --- a/fcd/ast/pass_simplifyexpressions.cpp +++ b/fcd/ast/pass_simplifyexpressions.cpp @@ -255,7 +255,7 @@ namespace if (auto addressOf = match(subscript.getPointer(), UnaryOperatorExpression::AddressOf)) if (auto constantIndex = dyn_cast(subscript.getIndex())) - if (constantIndex->ui64 == 0) + if (constantIndex->value == 0) { subscript.replaceAllUsesWith(addressOf->getOperand()); subscript.dropAllReferences(); diff --git a/fcd/ast/pre_ast_cfg.cpp b/fcd/ast/pre_ast_cfg.cpp index 0866bee..afffcd6 100644 --- a/fcd/ast/pre_ast_cfg.cpp +++ b/fcd/ast/pre_ast_cfg.cpp @@ -169,7 +169,7 @@ void PreAstContext::generateBlocks(Function& fn) { auto bits = static_cast(caseValue->getType()->getIntegerBitWidth()); const IntegerExpressionType& type = ctx.getIntegerType(false, bits); - Expression* numericConstant = ctx.numeric(type, caseValue->getLimitedValue()); + Expression* numericConstant = ctx.numeric(type, caseValue->getValue()); caseCondition = ctx.nary(NAryOperatorExpression::Equal, testVariable, numericConstant); } if (dest == &bbRef) @@ -210,7 +210,7 @@ PreAstBasicBlock& PreAstContext::createRedirectorBlock(ArrayRefto); if (iter == caseConditions.end()) { - Expression* numericConstant = ctx.numeric(ctx.getIntegerType(false, 32), caseConditions.size()); + Expression* numericConstant = ctx.numeric(ctx.getIntegerType(false, 32), llvm::APInt(32, caseConditions.size())); auto condition = ctx.nary(NAryOperatorExpression::Equal, sythesizedVariable, numericConstant); iter = caseConditions.insert({edge->to, condition}).first; @@ -231,4 +231,4 @@ void PreAstContext::view() const { ViewGraph(const_cast(this), "Pre-AST Basic Block Graph"); } -#endif \ No newline at end of file +#endif diff --git a/fcd/ast/print.cpp b/fcd/ast/print.cpp index ab3a841..3d3d681 100644 --- a/fcd/ast/print.cpp +++ b/fcd/ast/print.cpp @@ -440,7 +440,7 @@ void StatementPrintVisitor::visitNumeric(const NumericExpression& numeric) // 2- the parent expression is is a bitwise operator and the number is greater than 9. if (auto nary = dyn_cast_or_null(parentExpression)) { - if (numeric.ui64 > 9) + if (numeric.value.ugt(9)) { switch (nary->getType()) { @@ -461,11 +461,11 @@ void StatementPrintVisitor::visitNumeric(const NumericExpression& numeric) if (formatAsHex) { - (os << "0x").write_hex(numeric.ui64); + os << "0x" << numeric.value.toString(16, false); } else { - os << numeric.si64; + os << numeric.value.toString(10, true); } } diff --git a/fcd/ast/print_item.cpp b/fcd/ast/print_item.cpp index 449e3ea..64a76e7 100644 --- a/fcd/ast/print_item.cpp +++ b/fcd/ast/print_item.cpp @@ -28,6 +28,8 @@ namespace } } +PrintableItem::~PrintableItem() {} + void PrintableItem::dump() const { print(errs(), 0); diff --git a/fcd/ast/print_item.h b/fcd/ast/print_item.h index d634414..4806112 100644 --- a/fcd/ast/print_item.h +++ b/fcd/ast/print_item.h @@ -34,6 +34,7 @@ class PrintableItem PrintableScope* parent; public: + virtual ~PrintableItem(); PrintableItem(Type type, PrintableScope* parent) : discriminant(type), parent(parent) { diff --git a/fcd/codegen/translation_context_remill.cpp b/fcd/codegen/translation_context_remill.cpp index e86ebf2..38ffe96 100644 --- a/fcd/codegen/translation_context_remill.cpp +++ b/fcd/codegen/translation_context_remill.cpp @@ -23,6 +23,8 @@ #include #include +#include +#include #include #include @@ -34,10 +36,7 @@ #include "fcd/compat/Scalar.h" #include "fcd/codegen/translation_context_remill.h" -#include "fcd/pass_argrec_remill.h" -#include "fcd/pass_asaa.h" -#include "fcd/pass_intrinsics_remill.h" -#include "fcd/pass_stackrec_remill.h" +#include "fcd/passes.h" namespace fcd { namespace { @@ -278,8 +277,7 @@ RemillTranslationContext::RemillTranslationContext(llvm::LLVMContext &ctx, module = std::unique_ptr(remill::LoadTargetSemantics(&ctx)); target_arch->PrepareModule(module); intrinsics = std::make_unique(module.get()); - lifter = std::make_unique(target_arch, - intrinsics.get()); + lifter = std::make_unique(target_arch, intrinsics.get()); } uint64_t RemillTranslationContext::FindFunctionAddr(llvm::Function *func) { @@ -555,6 +553,7 @@ const StubInfo *RemillTranslationContext::GetStubInfo( if (auto int2ptr = llvm::dyn_cast(inst)) { addr = llvm::dyn_cast(int2ptr->getOperand(0)); } + inst->deleteValue(); } else { addr = llvm::dyn_cast(read_op); } @@ -576,6 +575,7 @@ void RemillTranslationContext::FinalizeModule() { llvm::legacy::PassManager phase_one; phase_one.add(llvm::createAlwaysInlinerLegacyPass()); + phase_one.add(createSignExtPass()); phase_one.add(createRemillArgumentRecoveryPass()); phase_one.add(llvm::createPromoteMemoryToRegisterPass()); phase_one.add(llvm::createReassociatePass()); diff --git a/fcd/pass_argrec_remill.cpp b/fcd/pass_argrec_remill.cpp index baa32fd..3aa3817 100644 --- a/fcd/pass_argrec_remill.cpp +++ b/fcd/pass_argrec_remill.cpp @@ -115,51 +115,25 @@ static std::unordered_set RegisterAliasSet(const char *reg) { return result; } -static std::unordered_set UsersOfVar(llvm::Function *func, - llvm::Value *var) { - std::unordered_set result; - if (var->hasNUsesOrMore(2)) { +static std::unordered_map UsersOfReg(llvm::Function *func, const char *reg) { + std::unordered_map result; + for (auto alias : RegisterAliasSet(reg)) { + auto var = remill::FindVarInFunction(func, alias); for (auto var_user : var->users()) { - if (auto addr = llvm::dyn_cast(var_user)) { - for (auto addr_user : addr->users()) { - result.insert(addr_user); - } - } + result[var_user] = alias; } } return result; } -template -static llvm::User *FirstUserOfVar(llvm::Function *func, llvm::Value *var, - T &instruction_list) { - auto users = UsersOfVar(func, var); - if (!users.empty()) { - for (auto &inst : instruction_list) { - if (users.count(&inst) > 0) { - return &inst; - } - } - } - return nullptr; -} - -template -static std::pair FirstUserOfReg( - llvm::Function *func, const char *reg, T &instruction_list) { - std::unordered_map users; - for (auto alias : RegisterAliasSet(reg)) { - auto var = remill::FindVarInFunction(func, alias); - if (auto user = FirstUserOfVar(func, var, instruction_list)) { - users[user] = alias; - } - } - - if (!users.empty()) { - for (auto &inst : instruction_list) { +template +static std::pair FirstRegUser(llvm::Function *func, const char *reg, T &instruction_list) { + auto users = UsersOfReg(func, reg); + for (auto &inst : instruction_list) { + if (auto typedInst = llvm::dyn_cast(&inst)) { auto it = users.find(&inst); - if (it != users.end()) { - return *it; + if (it != users.end()) { + return std::make_pair(typedInst, it->second); } } } @@ -188,11 +162,9 @@ static llvm::Type *RecoverRetType(llvm::Function *func, CallingConvention &cc) { for (auto block : TerminalBlocksOf(func)) { auto ilist = llvm::make_range(block->rbegin(), block->rend()); for (auto reg : ret_regs) { - auto user = FirstUserOfReg(func, reg, ilist); + auto user = FirstRegUser(func, reg, ilist); if (user.first != nullptr) { - if (auto store = llvm::dyn_cast(user.first)) { - found_types.insert(store->getValueOperand()->getType()); - } + found_types.insert(user.first->getValueOperand()->getType()); } } } @@ -213,7 +185,7 @@ static void LoadReturnRegToRetInsts(llvm::Function *func, for (auto block : TerminalBlocksOf(func)) { auto term = block->getTerminator(); ir.SetInsertPoint(term); - auto val = ir.CreateLoad(ir.CreateLoad(var)); + auto val = ir.CreateLoad(var); ir.CreateRet(val); term->eraseFromParent(); } @@ -234,19 +206,24 @@ static std::string TrimPrefix(std::string str) { } return ref.str(); } + +typedef std::unordered_map FunctionMap; static void UpdateCalls(llvm::Function *old_func, llvm::Function *new_func, - CallingConvention &cc) { + CallingConvention &cc, FunctionMap const & funcs) { llvm::IRBuilder<> ir(new_func->getContext()); + llvm::Value* undef = llvm::UndefValue::get(old_func->getType()); for (auto old_call : remill::CallersOf(old_func)) { auto caller = old_call->getParent()->getParent(); - if (caller->getName().startswith(sPrefix)) { + if (funcs.find(caller) != funcs.end()) { + old_call->setCalledFunction(undef); + } else { ir.SetInsertPoint(old_call); std::vector params; for (auto &arg : new_func->args()) { auto name = TrimPrefix(arg.getName().str()); auto arg_var = remill::FindVarInFunction(caller, name); - params.push_back(ir.CreateLoad(ir.CreateLoad(arg_var))); + params.push_back(ir.CreateLoad(arg_var)); } auto new_call = ir.CreateCall(new_func, params); @@ -255,7 +232,7 @@ static void UpdateCalls(llvm::Function *old_func, llvm::Function *new_func, if (!ret_type->isVoidTy()) { auto ret_reg = cc.ReturnRegForType(ret_type); auto ret_var = remill::FindVarInFunction(caller, ret_reg); - ir.CreateStore(new_call, ir.CreateLoad(ret_var)); + ir.CreateStore(new_call, ret_var); } old_call->replaceAllUsesWith( old_call->getArgOperand(remill::kMemoryPointerArgNum)); @@ -266,32 +243,23 @@ static void UpdateCalls(llvm::Function *old_func, llvm::Function *new_func, static llvm::Function *DeclareParametrizedFunc(llvm::Function *func, CallingConvention &cc) { + std::vector used_regs; + std::vector params; + // Get parameter regs from the callconv. Also add the stack pointer reg, // since it's used to access parameters passed by stack. Also add aliases. auto cc_regs = cc.ParamRegs(); auto ilist = llvm::make_range(llvm::inst_begin(func), llvm::inst_end(func)); cc_regs.insert(cc_regs.begin(), cc.StackPointerVarName()); for (auto reg : cc_regs) { - auto user = FirstUserOfReg(func, reg, ilist); - if (user.first != nullptr) { - if (llvm::isa(user.first)) { - used_regs.push_back(user.second); - } + auto alias = FirstRegUser(func, reg, ilist); + if (alias.first != nullptr) { + used_regs.push_back(alias.second); + params.push_back(alias.first->getType()); } } - // Gather parameter types from register variable alloca's - std::vector params; - for (auto reg : used_regs) { - auto var = remill::FindVarInFunction(func, reg); - auto inst = llvm::dyn_cast(var); - CHECK(inst != nullptr); - auto type = llvm::dyn_cast(inst->getAllocatedType()); - CHECK(type != nullptr); - params.push_back(type->getElementType()); - } - auto ret = RecoverRetType(func, cc); std::stringstream cc_func_name; @@ -307,8 +275,9 @@ static llvm::Function *DeclareParametrizedFunc(llvm::Function *func, std::stringstream cc_arg_name; cc_arg_name << sPrefix << used_regs[arg.getArgNo()]; arg.setName(cc_arg_name.str()); + removeAttr(arg, llvm::Attribute::Dereferenceable); } - + return cc_func; } @@ -317,8 +286,7 @@ static void StoreRegArgsToLocals(llvm::Function *func) { for (auto &arg : func->args()) { auto name = TrimPrefix(arg.getName().str()); auto var = remill::FindVarInFunction(func, name); - auto ptr = ir.CreateLoad(var); - ir.CreateStore(&arg, ptr); + ir.CreateStore(&arg, var); } } @@ -335,12 +303,13 @@ static void ConvertRemillArgsToLocals(llvm::Function *func) { ir.SetInsertPoint(loc_mem); - auto pc_type = remill::AddressType(module); + //auto pc_type = remill::AddressType(module); auto arg_pc = remill::NthArgument(func, remill::kPCArgNum); - auto loc_pc = ir.CreateAlloca(pc_type, nullptr, "loc_pc"); - arg_pc->replaceAllUsesWith(loc_pc); + CHECK(arg_pc->use_empty()); +// auto loc_pc = ir.CreateAlloca(pc_type, nullptr, "loc_pc"); +// arg_pc->replaceAllUsesWith(loc_pc); - ir.SetInsertPoint(loc_pc); +// ir.SetInsertPoint(loc_pc); auto state_type = remill::StatePointerType(module)->getElementType(); auto arg_state = remill::NthArgument(func, remill::kStatePointerArgNum); @@ -374,7 +343,7 @@ void RemillArgumentRecovery::getAnalysisUsage( llvm::AnalysisUsage &usage) const {} bool RemillArgumentRecovery::runOnModule(llvm::Module &module) { - std::vector new_funcs; + FunctionMap funcs; for (auto &func : module) { if (IsLiftedFunction(&func)) { // Recover the argument and return types of `func` by @@ -403,25 +372,25 @@ bool RemillArgumentRecovery::runOnModule(llvm::Module &module) { // convention defines which return register to use. LoadReturnRegToRetInsts(cc_func, cc); - new_funcs.push_back(cc_func); + funcs[&func] = cc_func; } } - for (auto func : new_funcs) { + for (auto func_pair : funcs) { // Replace all uses of the old function with the new // one with the recovered parameter and return types. // Then delete the old function from the module. - auto name = TrimPrefix(func->getName()); - auto old_func = remill::FindFunction(&module, name); - UpdateCalls(old_func, func, cc); - for (auto &arg : func->args()) { - auto arg_name = TrimPrefix(arg.getName()); - auto var = remill::FindVarInFunction(func, arg_name); - arg.takeName(var); - removeAttr(arg, llvm::Attribute::Dereferenceable); - } - func->takeName(old_func); + llvm::Function* old_func = func_pair.first; + llvm::Function* new_func = func_pair.second; + UpdateCalls(old_func, new_func, cc, funcs); +// for (auto &arg : new_func->args()) { +// auto arg_name = TrimPrefix(arg.getName()); +// auto var = remill::FindVarInFunction(new_func, arg_name); +// arg.takeName(var); +// } + CHECK(old_func->use_empty()); old_func->replaceAllUsesWith(llvm::UndefValue::get(old_func->getType())); + new_func->takeName(old_func); old_func->eraseFromParent(); } diff --git a/fcd/pass_intrinsics_remill.cpp b/fcd/pass_intrinsics_remill.cpp index 771d2ba..69796af 100644 --- a/fcd/pass_intrinsics_remill.cpp +++ b/fcd/pass_intrinsics_remill.cpp @@ -71,19 +71,24 @@ bool RemillFixIntrinsics::runOnModule(llvm::Module &module) { // Create "__fcd*" replacements for "__remill*" intrinsics for (auto &func : module) { if (IsRemillIntrinsicWithUse(&func)) { + llvm::AttributeList attrs = func.getAttributes(); // Gather non-lifting arguments from `func`. std::vector arg_types; + std::vector arg_attrs; for (auto &arg : func.args()) { auto arg_type = arg.getType(); if (!IsRemillLiftingArgType(arg_type)) { arg_types.push_back(arg_type); + arg_attrs.push_back(attrs.getAttributes(llvm::AttributeList::FirstArgIndex + arg.getArgNo())); } } // Determine return type of `new_func` // Replace `Memory` with `void`. auto ret_type = func.getReturnType(); + llvm::AttributeSet ret_attrs = attrs.getAttributes(llvm::AttributeList::ReturnIndex); if (IsRemillLiftingArgType(ret_type)) { ret_type = llvm::Type::getVoidTy(module.getContext()); + ret_attrs = llvm::AttributeSet(); } // Create the type for `new_func` auto func_type = llvm::FunctionType::get(ret_type, arg_types, false); @@ -94,6 +99,10 @@ bool RemillFixIntrinsics::runOnModule(llvm::Module &module) { // Create `new_func` auto new_func = llvm::dyn_cast( module.getOrInsertFunction(ss.str(), func_type)); + + auto func_attrs = attrs.getFnAttributes(); + auto attributes = llvm::AttributeList::get(func.getContext(), func_attrs, ret_attrs, arg_attrs); + new_func->setAttributes(attributes); CHECK(new_func != nullptr); // Map `new_func` to the old `func` for later use. funcs[&func] = new_func; diff --git a/fcd/passes.h b/fcd/passes.h index ad3d36a..ff4b1dc 100644 --- a/fcd/passes.h +++ b/fcd/passes.h @@ -16,7 +16,11 @@ #include "fcd/ast/pass_backend.h" #include "fcd/pass_asaa.h" +#include "fcd/pass_argrec_remill.h" +#include "fcd/pass_stackrec_remill.h" +#include "fcd/pass_intrinsics_remill.h" -llvm::FunctionPass* createRegisterPointerPromotionPass(); +llvm::FunctionPass* createRegisterPointerPromotionPass(); +llvm::FunctionPass* createSignExtPass(); #endif /* defined(fcd__passes_h) */