Esempio n. 1
0
/// Repeatedly tries to promote the arguments of every function in \p SCC,
/// sweeping the SCC again and again until a full sweep promotes nothing.
/// The call graph and the SCC's node list are patched after each promotion.
///
/// \returns true if at least one function was replaced by a promoted version.
bool ArgPromotion::runOnSCC(CallGraphSCC &SCC) {
  if (skipSCC(SCC))
    return false;

  // The call graph must be kept in sync with every replacement made below.
  CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();

  LegacyAARGetter AARGetter(*this);

  // Callback handed to promoteArguments: moves the caller's call-graph edge
  // from the old call site over to the new call site and its callee node.
  auto ReplaceCallSite = [&](CallSite OldCS, CallSite NewCS) {
    auto *CallBlock = OldCS.getInstruction()->getParent();
    Function *Caller = CallBlock->getParent();
    CallGraphNode *CalleeNode =
        CG.getOrInsertFunction(NewCS.getCalledFunction());
    CG[Caller]->replaceCallEdge(OldCS, NewCS, CalleeNode);
  };

  bool AnyPromotion = false;

  // Fixed-point loop: stop after the first sweep that promotes nothing.
  for (;;) {
    bool PromotedThisSweep = false;

    for (CallGraphNode *StaleNode : SCC) {
      Function *StaleF = StaleNode->getFunction();
      if (!StaleF)
        continue;

      Function *PromotedF =
          promoteArguments(StaleF, AARGetter, MaxElements, {ReplaceCallSite});
      if (!PromotedF)
        continue;

      PromotedThisSweep = true;

      // Give the promoted function a call-graph node that takes over the
      // outgoing call edges of the function it replaces.
      CallGraphNode *FreshNode = CG.getOrInsertFunction(PromotedF);
      FreshNode->stealCalledFunctionsFrom(StaleNode);

      // A still-referenced old node is kept, with its function made external;
      // an unreferenced one is removed from the module and destroyed.
      if (StaleNode->getNumReferences() != 0)
        StaleF->setLinkage(Function::ExternalLinkage);
      else
        delete CG.removeFunctionFromModule(StaleNode);

      // Keep the SCC we are iterating over pointing at the new node.
      SCC.ReplaceNode(StaleNode, FreshNode);
    }

    if (!PromotedThisSweep)
      break;
    AnyPromotion = true;
  }

  return AnyPromotion;
}
Esempio n. 2
0
// Replaces the function behind `node` (which must take exactly one argument accepted
// by isStructType) with a new void function that takes one parameter per register
// that RegisterUse reports as anything other than NoModRef.
//
// Steps, in order:
//   1. Build the new prototype: a register with the Mod bit set is passed as an
//      int64 pointer, any other used register as an int64 value. Parameters are
//      sorted by register name with strcmp.
//   2. Insert the new function before the old one, hand it the old name, and rename
//      the old function to "__hollow_husk__" + <that name>.
//   3. Rewrite every user of the old function (each is cast to CallInst) into a call
//      to the new function, updating RegisterUse and the call graph as it goes.
//   4. If the old function has a body, splice it into the new function and rewire
//      register accesses to the new arguments, local allocas, or the
//      indirectJump/indirectCall functions.
//
// Returns the call graph node of the new function, or nullptr when the node has no
// function or the function does not have the expected single argument.
CallGraphNode* ArgumentRecovery::recoverArguments(llvm::CallGraphNode *node)
{
	Function* fn = node->getFunction();
	if (fn == nullptr)
	{
		// Call graph node with no associated function; nothing to rewrite.
		return nullptr;
	}
	
	// quick exit if there isn't exactly one argument
	if (fn->arg_size() != 1)
	{
		return nullptr;
	}
	
	// The single argument must pass isStructType (defined elsewhere), otherwise bail out.
	Argument* fnArg = fn->arg_begin();
	if (!isStructType(fnArg))
	{
		return nullptr;
	}
	
	// This is a nasty NASTY hack that relies on the AA pass being RegisterUse.
	// The data should be moved to a separate helper pass that can be queried from both the AA pass and this one.
	RegisterUse& regUse = getAnalysis<RegisterUse>();
	CallGraph& cg = getAnalysis<CallGraphWrapperPass>().getCallGraph();
	
	// Per-register mod/ref information for this function; required to be present.
	const auto* modRefInfo = regUse.getModRefInfo(fn);
	assert(modRefInfo != nullptr);
	
	// At this point we pretty much know that we're going to modify the function, so start doing that.
	// Get register offsets from the old function before we start mutilating it.
	// NOTE(review): this is held by reference for the rest of the function; see the note at the end.
	auto& registerMap = exposeAllRegisters(fn);
	
	// Create a new function prototype, asking RegisterUse for which registers should be passed in, and how.
	
	LLVMContext& ctx = fn->getContext();
	// (register name, parameter type) for every register that becomes a parameter.
	SmallVector<pair<const char*, Type*>, 16> parameters;
	Type* int64 = Type::getInt64Ty(ctx);
	Type* int64ptr = Type::getInt64PtrTy(ctx);
	for (const auto& pair : *modRefInfo)
	{
		if (pair.second != RegisterUse::NoModRef)
		{
			// Registers with the Mod bit set are passed by address; the rest by value.
			Type* paramType = (pair.second & RegisterUse::Mod) == RegisterUse::Mod ? int64ptr : int64;
			parameters.push_back({pair.first, paramType});
		}
	}
	
	// Order parameters.
	// FIXME: This could use an ABI-specific sort routine. For now, use a lexicographical sort.
	sort(parameters.begin(), parameters.end(), [](const pair<const char*, Type*>& a, const pair<const char*, Type*>& b) {
		return strcmp(a.first, b.first) < 0;
	});
	
	// Extract parameter types.
	SmallVector<Type*, 16> parameterTypes;
	for (const auto& pair : parameters)
	{
		parameterTypes.push_back(pair.second);
	}
	
	// Ideally, we would also do caller analysis here to figure out which output registers are never read, such that
	// we can either eliminate them from the parameter list or pass them by value instead of by address.
	// We would also pick a return value.
	FunctionType* newFunctionType = FunctionType::get(Type::getVoidTy(ctx), parameterTypes, false);

	// Create the replacement with the old linkage and attributes, place it right before the old
	// function in the module, and swap names: the new function takes the old name and the old one
	// becomes "__hollow_husk__<name>".
	Function* newFunc = Function::Create(newFunctionType, fn->getLinkage());
	newFunc->copyAttributesFrom(fn);
	fn->getParent()->getFunctionList().insert(fn, newFunc);
	newFunc->takeName(fn);
	fn->setName("__hollow_husk__" + newFunc->getName());
	
	// Set argument names
	// Arguments are in the same order as `parameters`, so index i pairs them up.
	size_t i = 0;
	
	for (Argument& arg : newFunc->args())
	{
		arg.setName(parameters[i].first);
		i++;
	}
	
	// update call graph
	CallGraphNode* newFuncNode = cg.getOrInsertFunction(newFunc);
	CallGraphNode* oldFuncNode = cg[fn];
	
	// loop over callers and transform call sites.
	// Each iteration erases one user of fn, so the loop ends when fn has no users left.
	while (!fn->use_empty())
	{
		// Every user is expected to be a call instruction; the cast<> asserts otherwise.
		CallSite cs(fn->user_back());
		Instruction* call = cast<CallInst>(cs.getInstruction());
		Function* caller = call->getParent()->getParent();
		
		// Register pointers available in the calling function.
		auto& registerPositions = exposeAllRegisters(caller);
		SmallVector<Value*, 16> callParameters;
		for (const auto& pair : parameters)
		{
			// HACKHACK: find a pointer to a 64-bit int in the set.
			Value* registerPointer = nullptr;
			auto range = registerPositions.equal_range(pair.first);
			for (auto iter = range.first; iter != range.second; iter++)
			{
				if (auto gep = dyn_cast<GetElementPtrInst>(iter->second))
				if (gep->getResultElementType() == int64)
				{
					registerPointer = gep;
					break;
				}
			}
			
			// The caller must expose a 64-bit GEP for every register that is passed.
			assert(registerPointer != nullptr);
			
			if (isa<PointerType>(pair.second))
			{
				// By-address parameter: pass the register pointer itself.
				callParameters.push_back(registerPointer);
			}
			else
			{
				// Create a load instruction. GVN will get rid of it if it's unnecessary.
				LoadInst* load = new LoadInst(registerPointer, pair.first, call);
				callParameters.push_back(load);
			}
		}
		
		// The new call is inserted immediately before the old one.
		CallInst* newCall = CallInst::Create(newFunc, callParameters, "", call);
		
		// Update AA
		regUse.replaceWithNewValue(call, newCall);
		
		// Update call graph
		// NOTE: despite its name, `calleeNode` is the call graph node of the *caller*.
		CallGraphNode* calleeNode = cg[caller];
		calleeNode->replaceCallEdge(cs, CallSite(newCall), newFuncNode);
		
		// Finish replacing
		if (!call->use_empty())
		{
			call->replaceAllUsesWith(newCall);
			newCall->takeName(call);
		}
		
		call->eraseFromParent();
	}
	
	// Do not fix functions without a body.
	if (!fn->isDeclaration())
	{
		// Fix up function code. Start by moving everything into the new function.
		newFunc->getBasicBlockList().splice(newFunc->begin(), fn->getBasicBlockList());
		newFuncNode->stealCalledFunctionsFrom(oldFuncNode);
		
		// Change register uses
		size_t argIndex = 0;
		auto& argList = newFunc->getArgumentList();
		
		// Create a temporary insertion point. We don't want an existing instruction since chances are that we'll remove it.
		// It is a dummy "0 + 0" placed at the very start of the first block and erased again at the end.
		Instruction* insertionPoint = BinaryOperator::CreateAdd(ConstantInt::get(int64, 0), ConstantInt::get(int64, 0), "noop", newFunc->begin()->begin());
		for (auto iter = argList.begin(); iter != argList.end(); iter++, argIndex++)
		{
			// By default the register pointer is replaced with the (pointer) argument itself.
			Value* replaceWith = iter;
			const auto& paramTuple = parameters[argIndex];
			if (!isa<PointerType>(paramTuple.second))
			{
				// Create an alloca, copy value from parameter, replace GEP with alloca.
				// This is ugly code gen, but it will optimize easily, and still work if
				// we need a pointer reference to the register.
				auto alloca = new AllocaInst(paramTuple.second, paramTuple.first, insertionPoint);
				new StoreInst(iter, alloca, insertionPoint);
				replaceWith = alloca;
			}
			
			// Replace all uses with new instance.
			// The map entry is overwritten as well so that it refers to the replacement value.
			auto iterPair = registerMap.equal_range(paramTuple.first);
			for (auto registerMapIter = iterPair.first; registerMapIter != iterPair.second; registerMapIter++)
			{
				auto& registerValue = registerMapIter->second;
				registerValue->replaceAllUsesWith(replaceWith);
				cast<Instruction>(registerValue)->eraseFromParent();
				registerValue = replaceWith;
			}
		}
		
		// At this point, the uses of the argument struct left should be:
		// * preserved registers
		// * indirect jumps
		const auto& target = getAnalysis<TargetInfo>();
		// Each iteration removes one user of the old struct argument (or bails out entirely).
		while (!fnArg->use_empty())
		{
			auto lastUser = fnArg->user_back();
			if (auto user = dyn_cast<GetElementPtrInst>(lastUser))
			{
				// Promote register to alloca.
				// The alloca is named after the largest register overlapping the accessed one.
				const char* maybeName = target.registerName(*user);
				const char* regName = target.largestOverlappingRegister(maybeName);
				assert(regName != nullptr);
				
				auto alloca = new AllocaInst(user->getResultElementType(), regName, insertionPoint);
				user->replaceAllUsesWith(alloca);
				user->eraseFromParent();
			}
			else
			{
				// Any non-GEP user must be a call; it is matched by callee name below.
				auto call = cast<CallInst>(lastUser);
				
				Function* intrin = nullptr;
				StringRef intrinName = call->getCalledFunction()->getName();
				if (intrinName == "x86_jump_intrin")
				{
					intrin = indirectJump;
				}
				else if (intrinName == "x86_call_intrin")
				{
					intrin = indirectCall;
				}
				else
				{
					// With assertions enabled this aborts here. The statements below are the
					// fallback that runs only when assert is compiled out (NDEBUG).
					assert(false);
					// Can't decompile this function. Delete its body.
					// The temporary insertion point goes away with the body, so forget it.
					newFunc->deleteBody();
					insertionPoint = nullptr;
					break;
				}
				
				// Replace intrinsic with another intrinsic.
				// Operand 2 of the old call is used as the jump/call target, followed by every
				// argument of the new function.
				// NOTE(review): the operand layout of the x86_*_intrin calls is not visible here - confirm.
				Value* jumpTarget = call->getOperand(2);
				SmallVector<Value*, 16> callArgs;
				callArgs.push_back(jumpTarget);
				for (Argument& arg : argList)
				{
					callArgs.push_back(&arg);
				}
				
				CallInst* varargCall = CallInst::Create(intrin, callArgs, "", call);
				newFuncNode->replaceCallEdge(CallSite(call), CallSite(varargCall), cg[intrin]);
				regUse.replaceWithNewValue(call, varargCall);
				
				varargCall->takeName(call);
				call->eraseFromParent();
			}
		}
		if (insertionPoint != nullptr)
		{
			// no longer needed
			insertionPoint->eraseFromParent();
		}
	}
	
	// At this point nothing should be using the old register argument anymore. (Pray!)
	// Leave the hollow husk of the old function in place to be erased by global DCE.
	// Hand the register map over to the new function and drop the old function's entry.
	// NOTE(review): registerMap is a reference returned by exposeAllRegisters(fn). If it refers to
	// storage inside registerAddresses, confirm that inserting the newFunc key (and the earlier
	// exposeAllRegisters(caller) calls) cannot invalidate it for the container type in use.
	registerAddresses[newFunc] = move(registerMap);
	registerAddresses.erase(fn);
	
	// Should be all.
	return newFuncNode;
}