// Promote arguments of every function in the SCC, sweeping repeatedly until a
// whole sweep promotes nothing. Each promoted function replaces the original
// in both the call graph and the SCC being iterated.
// Returns true if at least one function was promoted.
bool ArgPromotion::runOnSCC(CallGraphSCC &SCC) {
  if (skipSCC(SCC))
    return false;

  // Call graph that must be kept in sync with every rewrite below.
  CallGraph &Graph = getAnalysis<CallGraphWrapperPass>().getCallGraph();
  LegacyAARGetter AARGetter(*this);

  // Handed to promoteArguments and invoked for each call site it rewrites:
  // the caller's edge to the old callee is re-pointed at the new callee.
  auto RedirectEdge = [&Graph](CallSite FromCS, CallSite ToCS) {
    Function *CallerFn = FromCS.getInstruction()->getParent()->getParent();
    CallGraphNode *TargetNode =
        Graph.getOrInsertFunction(ToCS.getCalledFunction());
    Graph[CallerFn]->replaceCallEdge(FromCS, ToCS, TargetNode);
  };

  bool Changed = false;
  bool SweepChanged;
  do {
    SweepChanged = false;

    for (CallGraphNode *Node : SCC) {
      Function *Original = Node->getFunction();
      if (Original == nullptr)
        continue; // Nothing to promote on a node without a function.

      Function *Promoted =
          promoteArguments(Original, AARGetter, MaxElements, {RedirectEdge});
      if (Promoted == nullptr)
        continue;

      SweepChanged = true;

      // The promoted function takes over the original's outgoing call edges.
      CallGraphNode *PromotedNode = Graph.getOrInsertFunction(Promoted);
      PromotedNode->stealCalledFunctionsFrom(Node);

      // Still referenced: keep the original around with external linkage.
      // Otherwise remove it from the module and free it.
      if (Node->getNumReferences() != 0)
        Original->setLinkage(Function::ExternalLinkage);
      else
        delete Graph.removeFunctionFromModule(Node);

      // Keep the SCC we are iterating consistent with the call graph.
      SCC.ReplaceNode(Node, PromotedNode);
    }

    // Accumulate whether any sweep changed something.
    Changed = Changed || SweepChanged;
  } while (SweepChanged);

  return Changed;
}
// Replaces the function of `node` with a new function that takes individual
// registers as parameters instead of the single struct argument, and rewrites
// every caller to call the new function.
//
// The new function:
//   * returns void;
//   * has one parameter per register whose RegisterUse mod/ref info is not
//     NoModRef, ordered by register name (strcmp);
//   * takes an i64* for registers whose info includes Mod, an i64 otherwise;
//   * takes over the old function's name, linkage and attributes, and (when
//     the old function has a body) its basic blocks.
// The old function is renamed "__hollow_husk__<name>" and left in the module.
//
// Returns the call graph node of the new function. Returns nullptr early when
// the node has no function, the function does not have exactly one argument,
// or isStructType rejects that argument.
CallGraphNode* ArgumentRecovery::recoverArguments(llvm::CallGraphNode *node)
{
	Function* fn = node->getFunction();
	if (fn == nullptr)
	{
		// "theoretical nodes", whatever that is
		// (call graph nodes that carry no Function cannot be rewritten)
		return nullptr;
	}
	
	// quick exit if there isn't exactly one argument
	if (fn->arg_size() != 1)
	{
		return nullptr;
	}
	
	// The single argument must pass isStructType (presumably the register struct).
	Argument* fnArg = fn->arg_begin();
	if (!isStructType(fnArg))
	{
		return nullptr;
	}
	
	// This is a nasty NASTY hack that relies on the AA pass being RegisterUse.
	// The data should be moved to a separate helper pass that can be queried from both the AA pass and this one.
	RegisterUse& regUse = getAnalysis<RegisterUse>();
	CallGraph& cg = getAnalysis<CallGraphWrapperPass>().getCallGraph();
	// Per-register mod/ref summary of fn; it drives the new parameter list.
	const auto* modRefInfo = regUse.getModRefInfo(fn);
	assert(modRefInfo != nullptr);
	
	// At this point we pretty much know that we're going to modify the function, so start doing that.
	// Get register offsets from the old function before we start mutilating it.
	// registerMap associates a register name with one or more values (it is
	// queried with equal_range below).
	// NOTE(review): registerMap is a reference that is kept across later
	// exposeAllRegisters(caller) calls and registerAddresses[newFunc]; if it
	// refers into a container that invalidates references on insertion, it
	// could dangle — confirm.
	auto& registerMap = exposeAllRegisters(fn);
	
	// Create a new function prototype, asking RegisterUse for which registers should be passed in, and how.
	LLVMContext& ctx = fn->getContext();
	SmallVector<pair<const char*, Type*>, 16> parameters;
	Type* int64 = Type::getInt64Ty(ctx);
	Type* int64ptr = Type::getInt64PtrTy(ctx);
	for (const auto& pair : *modRefInfo)
	{
		if (pair.second != RegisterUse::NoModRef)
		{
			// Registers that may be modified are passed by address, the others by value.
			Type* paramType = (pair.second & RegisterUse::Mod) == RegisterUse::Mod ? int64ptr : int64;
			parameters.push_back({pair.first, paramType});
		}
	}
	
	// Order parameters.
	// FIXME: This could use an ABI-specific sort routine. For now, use a lexicographical sort.
	sort(parameters.begin(), parameters.end(), [](const pair<const char*, Type*>& a, const pair<const char*, Type*>& b)
	{
		return strcmp(a.first, b.first) < 0;
	});
	
	// Extract parameter types.
	SmallVector<Type*, 16> parameterTypes;
	for (const auto& pair : parameters)
	{
		parameterTypes.push_back(pair.second);
	}
	
	// Ideally, we would also do caller analysis here to figure out which output registers are never read, such that
	// we can either eliminate them from the parameter list or pass them by value instead of by address.
	// We would also pick a return value.
	FunctionType* newFunctionType = FunctionType::get(Type::getVoidTy(ctx), parameterTypes, false);
	Function* newFunc = Function::Create(newFunctionType, fn->getLinkage());
	newFunc->copyAttributesFrom(fn);
	// Insert the new function right before the old one, give it the old name,
	// and rename the old one so both can coexist in the module.
	fn->getParent()->getFunctionList().insert(fn, newFunc);
	newFunc->takeName(fn);
	fn->setName("__hollow_husk__" + newFunc->getName());
	
	// Set argument names
	// (each parameter is named after its register, in `parameters` order)
	size_t i = 0;
	for (Argument& arg : newFunc->args())
	{
		arg.setName(parameters[i].first);
		i++;
	}
	
	// update call graph
	CallGraphNode* newFuncNode = cg.getOrInsertFunction(newFunc);
	CallGraphNode* oldFuncNode = cg[fn];
	
	// loop over callers and transform call sites.
	// Each iteration erases one use of fn (the old call instruction), so the
	// loop ends once every caller has been redirected to newFunc.
	// NOTE(review): this assumes every use of fn is a CallInst that calls fn
	// directly. A non-call use (address taken, invoke, ...) makes
	// cast<CallInst> fail, and a call that merely passes fn as an argument
	// would itself be replaced by a call to newFunc — confirm such uses
	// cannot occur here.
	while (!fn->use_empty())
	{
		CallSite cs(fn->user_back());
		Instruction* call = cast<CallInst>(cs.getInstruction());
		Function* caller = call->getParent()->getParent();
		auto& registerPositions = exposeAllRegisters(caller);
		
		// Build the new call's arguments from the caller's register values.
		SmallVector<Value*, 16> callParameters;
		for (const auto& pair : parameters)
		{
			// HACKHACK: find a pointer to a 64-bit int in the set.
			Value* registerPointer = nullptr;
			auto range = registerPositions.equal_range(pair.first);
			for (auto iter = range.first; iter != range.second; iter++)
			{
				if (auto gep = dyn_cast<GetElementPtrInst>(iter->second))
					if (gep->getResultElementType() == int64)
					{
						registerPointer = gep;
						break;
					}
			}
			assert(registerPointer != nullptr);
			
			if (isa<PointerType>(pair.second))
			{
				// By-address parameter: pass the register's address itself.
				callParameters.push_back(registerPointer);
			}
			else
			{
				// Create a load instruction. GVN will get rid of it if it's unnecessary.
				LoadInst* load = new LoadInst(registerPointer, pair.first, call);
				callParameters.push_back(load);
			}
		}
		
		CallInst* newCall = CallInst::Create(newFunc, callParameters, "", call);
		
		// Update AA
		regUse.replaceWithNewValue(call, newCall);
		
		// Update call graph
		// (despite its name, calleeNode is the node of the *caller*)
		CallGraphNode* calleeNode = cg[caller];
		calleeNode->replaceCallEdge(cs, CallSite(newCall), newFuncNode);
		
		// Finish replacing
		// NOTE(review): newFunc returns void, so if the old call's result is
		// used, RAUW with (and takeName onto) a void call looks invalid —
		// confirm the old functions' return values are never used.
		if (!call->use_empty())
		{
			call->replaceAllUsesWith(newCall);
			newCall->takeName(call);
		}
		call->eraseFromParent();
	}
	
	// Do not fix functions without a body.
	if (!fn->isDeclaration())
	{
		// Fix up function code. Start by moving everything into the new function.
		newFunc->getBasicBlockList().splice(newFunc->begin(), fn->getBasicBlockList());
		newFuncNode->stealCalledFunctionsFrom(oldFuncNode);
		
		// Change register uses
		// (newFunc's arguments are walked in parallel with `parameters`)
		size_t argIndex = 0;
		auto& argList = newFunc->getArgumentList();
		
		// Create a temporary insertion point. We don't want an existing instruction since chances are that we'll remove it.
		// It is a dummy `add 0, 0` placed first in the entry block and erased at the end (unless the body is deleted).
		Instruction* insertionPoint = BinaryOperator::CreateAdd(ConstantInt::get(int64, 0), ConstantInt::get(int64, 0), "noop", newFunc->begin()->begin());
		for (auto iter = argList.begin(); iter != argList.end(); iter++, argIndex++)
		{
			Value* replaceWith = iter;
			const auto& paramTuple = parameters[argIndex];
			if (!isa<PointerType>(paramTuple.second))
			{
				// Create an alloca, copy value from parameter, replace GEP with alloca.
				// This is ugly code gen, but it will optimize easily, and still work if
				// we need a pointer reference to the register.
				auto alloca = new AllocaInst(paramTuple.second, paramTuple.first, insertionPoint);
				new StoreInst(iter, alloca, insertionPoint);
				replaceWith = alloca;
			}
			
			// Replace all uses with new instance.
			// Every value recorded for this register in the old body is
			// replaced and erased. The map entry is then pointed at the
			// replacement so registerMap describes newFunc afterwards.
			// NOTE(review): every value for a register is replaced with an i64*
			// (argument or alloca); if registerMap can hold values of other
			// types under the same name, RAUW would see a type mismatch — confirm.
			auto iterPair = registerMap.equal_range(paramTuple.first);
			for (auto registerMapIter = iterPair.first; registerMapIter != iterPair.second; registerMapIter++)
			{
				auto& registerValue = registerMapIter->second;
				registerValue->replaceAllUsesWith(replaceWith);
				cast<Instruction>(registerValue)->eraseFromParent();
				registerValue = replaceWith;
			}
		}
		
		// At this point, the uses of the argument struct left should be:
		// * preserved registers
		// * indirect jumps
		const auto& target = getAnalysis<TargetInfo>();
		while (!fnArg->use_empty())
		{
			auto lastUser = fnArg->user_back();
			if (auto user = dyn_cast<GetElementPtrInst>(lastUser))
			{
				// Promote register to alloca.
				// The local is named after the largest register overlapping the accessed one.
				const char* maybeName = target.registerName(*user);
				const char* regName = target.largestOverlappingRegister(maybeName);
				assert(regName != nullptr);
				auto alloca = new AllocaInst(user->getResultElementType(), regName, insertionPoint);
				user->replaceAllUsesWith(alloca);
				user->eraseFromParent();
			}
			else
			{
				// Any other user is expected to be a call to x86_jump_intrin or
				// x86_call_intrin; anything else makes the function undecompilable.
				auto call = cast<CallInst>(lastUser);
				Function* intrin = nullptr;
				StringRef intrinName = call->getCalledFunction()->getName();
				if (intrinName == "x86_jump_intrin")
				{
					intrin = indirectJump;
				}
				else if (intrinName == "x86_call_intrin")
				{
					intrin = indirectCall;
				}
				else
				{
					assert(false);
					// Can't decompile this function. Delete its body.
					// insertionPoint lived in the deleted body, so it is cleared to
					// skip the cleanup below.
					// NOTE(review): registerMap entries may now point at allocas from
					// the deleted body, and they are still stored into
					// registerAddresses[newFunc] below — confirm that is safe.
					newFunc->deleteBody();
					insertionPoint = nullptr;
					break;
				}
				
				// Replace intrinsic with another intrinsic.
				// New arguments: the jump target (operand 2 of the old call),
				// followed by all of newFunc's arguments.
				Value* jumpTarget = call->getOperand(2);
				SmallVector<Value*, 16> callArgs;
				callArgs.push_back(jumpTarget);
				for (Argument& arg : argList)
				{
					callArgs.push_back(&arg);
				}
				CallInst* varargCall = CallInst::Create(intrin, callArgs, "", call);
				// NOTE(review): cg[intrin] presumably requires intrin to already
				// have a call graph node — confirm it is registered beforehand.
				newFuncNode->replaceCallEdge(CallSite(call), CallSite(varargCall), cg[intrin]);
				regUse.replaceWithNewValue(call, varargCall);
				varargCall->takeName(call);
				call->eraseFromParent();
			}
		}
		
		if (insertionPoint != nullptr)
		{
			// no longer needed
			insertionPoint->eraseFromParent();
		}
	}
	
	// At this point nothing should be using the old register argument anymore. (Pray!)
	// Leave the hollow husk of the old function in place to be erased by global DCE.
	// Move the register map entry from the husk over to the new function.
	registerAddresses[newFunc] = move(registerMap);
	registerAddresses.erase(fn);
	
	// Should be all.
	return newFuncNode;
}