コード例 #1
0
ファイル: Evaluator.cpp プロジェクト: 2trill2spill/freebsd
/// Try to evaluate a call to \p F, binding its formal parameters positionally
/// to the constants in \p ActualArgs.
///
/// Returns true on success; in that case, if the reached `ret` has an operand,
/// its evaluated value is stored in \p RetVal (a `ret void` leaves RetVal
/// untouched).  Returns false when F is already being evaluated (recursion),
/// when a block cannot be evaluated, or when control would enter a block that
/// was already executed (a loop).
bool Evaluator::EvaluateFunction(Function *F, Constant *&RetVal,
                                 const SmallVectorImpl<Constant*> &ActualArgs) {
  // Recursion is not supported: refuse to evaluate a function that is already
  // on the call stack.  TODO: we might want to accept limited recursion.
  if (is_contained(CallStack, F))
    return false;
  CallStack.push_back(F);

  // Bind each formal argument to the incoming constant at the same position.
  unsigned Idx = 0;
  for (Argument &Formal : F->args())
    setVal(&Formal, ActualArgs[Idx++]);

  // Only non-looping, non-recursive code is handled, so any given basic block
  // may be evaluated at most once.  Remember the blocks entered so far.
  SmallPtrSet<BasicBlock*, 32> Visited;

  // The block currently being evaluated and the position to resume from.
  BasicBlock *Block = &F->front();
  BasicBlock::iterator Inst = Block->begin();

  for (;;) {
    BasicBlock *Succ = nullptr; // Initialized to avoid compiler warnings.
    DEBUG(dbgs() << "Trying to evaluate BB: " << *Block << "\n");

    if (!EvaluateBlock(Inst, Succ))
      return false;

    if (Succ == nullptr) {
      // Evaluation ran off the end of the block without a successor, so the
      // terminator is the function's return: capture its value and pop F.
      auto *Ret = cast<ReturnInst>(Block->getTerminator());
      if (Ret->getNumOperands() != 0)
        RetVal = getVal(Ret->getOperand(0));
      CallStack.pop_back();
      return true;
    }

    // Entering a block a second time means the function loops, which cannot
    // be evaluated in reasonable time.
    if (!Visited.insert(Succ).second)
      return false; // looped!

    // First time in Succ: resolve its leading PHI nodes using the edge we came
    // in on, leaving Inst at the first non-PHI instruction.
    Inst = Succ->begin();
    while (PHINode *Phi = dyn_cast<PHINode>(Inst)) {
      setVal(Phi, getVal(Phi->getIncomingValueForBlock(Block)));
      ++Inst;
    }

    // Continue with the successor.
    Block = Succ;
  }
}
コード例 #2
0
/// Decide whether every return instruction in the function that contains the
/// tail call \p CI, other than \p IgnoreRI, returns one and the same value, and
/// whether that value is already computable at the start of the initial
/// invocation (as judged by isDynamicConstant).  Returns that shared value, or
/// nullptr if there is none.
static Value *getCommonReturnValue(ReturnInst *IgnoreRI, CallInst *CI) {
  Function *Caller = CI->getParent()->getParent();
  Value *Common = nullptr;

  for (BasicBlock &Block : *Caller) {
    auto *Ret = dyn_cast<ReturnInst>(Block.getTerminator());
    if (!Ret || Ret == IgnoreRI)
      continue;

    // The transformation requires that the returned value be evaluatable when
    // the initial invocation begins, not only at the end of the evaluation.
    Value *Candidate = Ret->getOperand(0);
    if (!isDynamicConstant(Candidate, CI, Ret))
      return nullptr;

    // Two different returned values rule out the transformation.
    if (Common && Candidate != Common)
      return nullptr;
    Common = Candidate;
  }
  return Common;
}
コード例 #3
0
// getCommonReturnValue - Scan the function that owns the tail call CI and
// decide whether all of its return instructions, apart from IgnoreRI, return
// the same value, and whether that value is already computable when the
// initial invocation starts (per isDynamicConstant).  Returns that value, or
// null if there is no such common value.
//
static Value *getCommonReturnValue(ReturnInst *IgnoreRI, CallInst *CI) {
  Function *Caller = CI->getParent()->getParent();
  Value *Common = 0;

  Function::iterator BB = Caller->begin(), BBEnd = Caller->end();
  for (; BB != BBEnd; ++BB) {
    ReturnInst *Ret = dyn_cast<ReturnInst>(BB->getTerminator());
    if (!Ret || Ret == IgnoreRI)
      continue;

    // The returned value must be evaluatable at the start of the initial
    // invocation of the function rather than at the end of the evaluation.
    Value *Candidate = Ret->getOperand(0);
    if (!isDynamicConstant(Candidate, CI, Ret))
      return 0;

    // Differing returned values rule out the transformation.
    if (Common != 0 && Candidate != Common)
      return 0;
    Common = Candidate;
  }
  return Common;
}
コード例 #4
0
ファイル: SimplifyCFGPass.cpp プロジェクト: Xmister/llvm-onex
/// mergeEmptyReturnBlocks - If we have more than one empty (other than phi
/// node) return blocks, merge them together to promote recursive block merging.
///
/// A block qualifies when, ignoring debug-info intrinsics, it contains only its
/// `ret`, or only its `ret` plus a single PHI node that is the block's first
/// instruction and is the operand of that `ret`.  The first qualifying block
/// becomes the canonical return block.  Each later qualifying block is either
/// erased (when it returns nothing or returns the same value as the canonical
/// block) or turned into an unconditional branch to the canonical block, whose
/// returned value is then routed through a "merge" PHI node.
///
/// Returns true if any block was merged.
static bool mergeEmptyReturnBlocks(Function &F) {
  bool Changed = false;

  // The canonical return block that later duplicates are merged into.
  BasicBlock *RetBlock = 0;

  // Scan all the blocks in the function, looking for empty return blocks.
  // BBI is advanced before BB is processed because BB may be erased below.
  for (Function::iterator BBI = F.begin(), E = F.end(); BBI != E; ) {
    BasicBlock &BB = *BBI++;

    // Only look at return blocks.
    ReturnInst *Ret = dyn_cast<ReturnInst>(BB.getTerminator());
    if (Ret == 0) continue;

    // Only look at the block if it is empty or the only other thing in it is a
    // single PHI node that is the operand to the return.
    if (Ret != &BB.front()) {
      // Check for something else in the block.
      BasicBlock::iterator I = Ret;
      --I;
      // Skip over debug info.
      while (isa<DbgInfoIntrinsic>(I) && I != BB.begin())
        --I;
      // If I is still a debug intrinsic, the scan reached the block start, so
      // nothing but debug info precedes the ret and the block qualifies.
      // Otherwise I must be a PHI at the very start of the block that the ret
      // returns; anything else disqualifies the block.
      if (!isa<DbgInfoIntrinsic>(I) &&
          (!isa<PHINode>(I) || I != BB.begin() ||
           Ret->getNumOperands() == 0 ||
           Ret->getOperand(0) != I))
        continue;
    }

    // If this is the first returning block, remember it and keep going.
    if (RetBlock == 0) {
      RetBlock = &BB;
      continue;
    }

    // Otherwise, we found a duplicate return block.  Merge the two.
    Changed = true;

    // Case when there is no input to the return or when the returned values
    // agree is trivial.  Note that they can't agree if there are phis in the
    // blocks.  Every reference to BB is redirected to RetBlock, then BB is
    // deleted.
    if (Ret->getNumOperands() == 0 ||
        Ret->getOperand(0) ==
          cast<ReturnInst>(RetBlock->getTerminator())->getOperand(0)) {
      BB.replaceAllUsesWith(RetBlock);
      BB.eraseFromParent();
      continue;
    }

    // If the canonical return block has no PHI node, create one now.  Its
    // current predecessors all keep receiving the value RetBlock returned so
    // far (InVal), and the ret is rewired to return the new PHI.
    PHINode *RetBlockPHI = dyn_cast<PHINode>(RetBlock->begin());
    if (RetBlockPHI == 0) {
      Value *InVal = cast<ReturnInst>(RetBlock->getTerminator())->getOperand(0);
      pred_iterator PB = pred_begin(RetBlock), PE = pred_end(RetBlock);
      RetBlockPHI = PHINode::Create(Ret->getOperand(0)->getType(),
                                    std::distance(PB, PE), "merge",
                                    &RetBlock->front());

      for (pred_iterator PI = PB; PI != PE; ++PI)
        RetBlockPHI->addIncoming(InVal, *PI);
      RetBlock->getTerminator()->setOperand(0, RetBlockPHI);
    }

    // Turn BB into a block that just unconditionally branches to the return
    // block.  This handles the case when the two return blocks have a common
    // predecessor but that return different things.  BB's returned value
    // becomes the PHI's incoming value for the new BB -> RetBlock edge.
    RetBlockPHI->addIncoming(Ret->getOperand(0), &BB);
    BB.getTerminator()->eraseFromParent();
    BranchInst::Create(RetBlock, &BB);
  }

  return Changed;
}
コード例 #5
0
ファイル: Local.cpp プロジェクト: brills/pfpa
void GraphBuilder::visitReturnInst(ReturnInst &RI) {
  // A `ret void` contributes nothing to the graph.
  if (RI.getNumOperands() == 0)
    return;

  // Only pointer-typed return values are tracked: merge the returned value's
  // node into the function's return node.
  Value *Returned = RI.getOperand(0);
  if (!isa<PointerType>(Returned->getType()))
    return;

  G.getOrCreateReturnNodeFor(*FB).mergeWith(getValueDest(Returned));
}
コード例 #6
0
// Rewrite F so that each call in AsyncCalls can suspend and later be resumed
// through a generated callback function.
//
//  Pass 0 records F's original return instructions.
//  Pass 1 splits the block after each async call, allocates a context right
//         before the call, and replaces the split terminator with a CheckAsync
//         call that branches to a new SaveAsyncCtx block (taken when it returns
//         true) or to the block after the call. It also creates an empty
//         "<F>__async_cb" callback function per call.
//  Pass 2 computes the context variables of each call, builds the context
//         struct (callback pointer first, then the variables), fixes the
//         allocation size, and fills SaveAsyncCtx with the stores, a
//         DoNotUnwind call and a return.
//  Pass 3 clones F into each callback, restores the context in a new entry
//         block, and resumes at the block after the corresponding call.
//  Pass 4 inserts a FreeAsyncCtx call after each async call in F itself.
//
// NOTE(review): the runtime behavior of the helper functions used here
// (AllocAsyncCtxFunction, CheckAsyncFunction, DoNotUnwindFunction, ...) is not
// visible in this function; the descriptions follow their names and the
// comments below.
void LowerEmAsyncify::transformAsyncFunction(Function &F, Instructions const& AsyncCalls) {
  assert(!AsyncCalls.empty());

  // Pass 0
  // collect all the return instructions from the original function
  // will use later
  std::vector<ReturnInst*> OrigReturns;
  for (inst_iterator I = inst_begin(&F), E = inst_end(&F); I != E; ++I) {
    if (ReturnInst *RI = dyn_cast<ReturnInst>(&*I)) {
      OrigReturns.push_back(RI);
    }
  }

  // Pass 1
  // Scan each async call and make the basic structure:
  // All these will be cloned into the callback functions
  // - allocate the async context before calling an async function
  // - check async right after calling an async function, save context & return if async, continue if not
  // - retrieve the async return value and free the async context if the called function turns out to be sync
  std::vector<AsyncCallEntry> AsyncCallEntries;
  AsyncCallEntries.reserve(AsyncCalls.size());
  for (Instructions::const_iterator I = AsyncCalls.begin(), E = AsyncCalls.end(); I != E; ++I) {
    // prepare blocks
    Instruction *CurAsyncCall = *I;

    // The block containing the async call
    BasicBlock *CurBlock = CurAsyncCall->getParent();
    // The block should run after the async call
    BasicBlock *AfterCallBlock = SplitBlock(CurBlock, CurAsyncCall->getNextNode());
    // The block where we store the context and return
    BasicBlock *SaveAsyncCtxBlock = BasicBlock::Create(TheModule->getContext(), "SaveAsyncCtx", &F, AfterCallBlock);
    // terminate the block with a placeholder `unreachable` for now so that it
    // is valid; Pass 2 erases it and fills in the real contents
    new UnreachableInst(TheModule->getContext(), SaveAsyncCtxBlock);

    // allocate the context before making the call
    // we don't know the size yet, will fix it later
    // we cannot insert the instruction later because,
    // we need to make sure that all the instructions and blocks are fixed before we can generate DT and find context variables
    // In CallHandler.h `sp` will be put as the second parameter
    // such that we can take a note of the original sp
    CallInst *AllocAsyncCtxInst = CallInst::Create(AllocAsyncCtxFunction, Constant::getNullValue(I32), "AsyncCtx", CurAsyncCall);

    // Right after the call
    // check async and return if so
    // TODO: we can define truly async functions and partial async functions
    {
      // remove old terminator, which came from SplitBlock
      CurBlock->getTerminator()->eraseFromParent();
      // go to SaveAsyncCtxBlock if the previous call is async
      // otherwise just continue to AfterCallBlock
      CallInst *CheckAsync = CallInst::Create(CheckAsyncFunction, "IsAsync", CurBlock);
      BranchInst::Create(SaveAsyncCtxBlock, AfterCallBlock, CheckAsync, CurBlock);
    }

    // take a note of this async call
    AsyncCallEntry CurAsyncCallEntry;
    CurAsyncCallEntry.AsyncCallInst = CurAsyncCall;
    CurAsyncCallEntry.AfterCallBlock = AfterCallBlock;
    CurAsyncCallEntry.AllocAsyncCtxInst = AllocAsyncCtxInst;
    CurAsyncCallEntry.SaveAsyncCtxBlock = SaveAsyncCtxBlock;
    // create an empty function for the callback, which will be constructed later
    CurAsyncCallEntry.CallbackFunc = Function::Create(CallbackFunctionType, F.getLinkage(), F.getName() + "__async_cb", TheModule);
    AsyncCallEntries.push_back(CurAsyncCallEntry);
  }


  // Pass 2
  // analyze the context variables and construct SaveAsyncCtxBlock for each async call
  // also calculate the size of the context and allocate the async context accordingly
  for (std::vector<AsyncCallEntry>::iterator EI = AsyncCallEntries.begin(), EE = AsyncCallEntries.end();  EI != EE; ++EI) {
    AsyncCallEntry & CurEntry = *EI;

    // Collect everything to be saved
    FindContextVariables(CurEntry);

    // Pack the variables as a struct
    // Layout: element 0 is the callback function pointer, element i + 1 is
    // ContextVariables[i].
    {
      // TODO: sort them from large member to small ones, in order to make the struct compact even when aligned
      SmallVector<Type*, 8> Types;
      Types.push_back(CallbackFunctionType->getPointerTo());
      for (Values::iterator VI = CurEntry.ContextVariables.begin(), VE = CurEntry.ContextVariables.end(); VI != VE; ++VI) {
        Types.push_back((*VI)->getType());
      }
      CurEntry.ContextStructType = StructType::get(TheModule->getContext(), Types);
    }

    // fix the size of allocation
    // (replaces the null placeholder size passed in Pass 1)
    CurEntry.AllocAsyncCtxInst->setOperand(0, 
        ConstantInt::get(I32, DL->getTypeStoreSize(CurEntry.ContextStructType)));

    // construct SaveAsyncCtxBlock
    {
      // fill in SaveAsyncCtxBlock
      // temporarily remove the terminator for convenience
      CurEntry.SaveAsyncCtxBlock->getTerminator()->eraseFromParent();
      assert(CurEntry.SaveAsyncCtxBlock->empty());

      Type *AsyncCtxAddrTy = CurEntry.ContextStructType->getPointerTo();
      BitCastInst *AsyncCtxAddr = new BitCastInst(CurEntry.AllocAsyncCtxInst, AsyncCtxAddrTy, "AsyncCtxAddr", CurEntry.SaveAsyncCtxBlock);
      SmallVector<Value*, 2> Indices;
      // store the callback
      {
        Indices.push_back(ConstantInt::get(I32, 0));
        Indices.push_back(ConstantInt::get(I32, 0));
        GetElementPtrInst *AsyncVarAddr = GetElementPtrInst::Create(AsyncCtxAddrTy, AsyncCtxAddr, Indices, "", CurEntry.SaveAsyncCtxBlock);
        new StoreInst(CurEntry.CallbackFunc, AsyncVarAddr, CurEntry.SaveAsyncCtxBlock);
      }
      // store the context variables
      for (size_t i = 0; i < CurEntry.ContextVariables.size(); ++i) {
        Indices.clear();
        Indices.push_back(ConstantInt::get(I32, 0));
        Indices.push_back(ConstantInt::get(I32, i + 1)); // the 0th element is the callback function
        GetElementPtrInst *AsyncVarAddr = GetElementPtrInst::Create(AsyncCtxAddrTy, AsyncCtxAddr, Indices, "", CurEntry.SaveAsyncCtxBlock);
        new StoreInst(CurEntry.ContextVariables[i], AsyncVarAddr, CurEntry.SaveAsyncCtxBlock);
      }
      // to exit the block, we want to return without unwinding the stack frame
      // (a null value of F's return type is returned, or nothing for void)
      CallInst::Create(DoNotUnwindFunction, "", CurEntry.SaveAsyncCtxBlock);
      ReturnInst::Create(TheModule->getContext(), 
          (F.getReturnType()->isVoidTy() ? 0 : Constant::getNullValue(F.getReturnType())),
          CurEntry.SaveAsyncCtxBlock);
    }
  }

  // Pass 3
  // now all the SaveAsyncCtxBlock's have been constructed
  // we can clone F and construct callback functions
  // we could not construct the callbacks in Pass 2 because we need _all_ those SaveAsyncCtxBlock's appear in _each_ callback
  for (std::vector<AsyncCallEntry>::iterator EI = AsyncCallEntries.begin(), EE = AsyncCallEntries.end();  EI != EE; ++EI) {
    AsyncCallEntry & CurEntry = *EI;

    Function *CurCallbackFunc = CurEntry.CallbackFunc;
    ValueToValueMapTy VMap;

    // Add the entry block
    // load variables from the context
    // also update VMap for CloneFunction
    BasicBlock *EntryBlock = BasicBlock::Create(TheModule->getContext(), "AsyncCallbackEntry", CurCallbackFunc);
    std::vector<LoadInst *> LoadedAsyncVars;
    {
      Type *AsyncCtxAddrTy = CurEntry.ContextStructType->getPointerTo();
      BitCastInst *AsyncCtxAddr = new BitCastInst(CurCallbackFunc->arg_begin(), AsyncCtxAddrTy, "AsyncCtx", EntryBlock);
      SmallVector<Value*, 2> Indices;
      for (size_t i = 0; i < CurEntry.ContextVariables.size(); ++i) {
        Indices.clear();
        Indices.push_back(ConstantInt::get(I32, 0));
        Indices.push_back(ConstantInt::get(I32, i + 1)); // the 0th element of AsyncCtx is the callback function
        GetElementPtrInst *AsyncVarAddr = GetElementPtrInst::Create(AsyncCtxAddrTy, AsyncCtxAddr, Indices, "", EntryBlock);
        LoadedAsyncVars.push_back(new LoadInst(AsyncVarAddr, "", EntryBlock));
        // we want the argument to be replaced by the loaded value
        if (isa<Argument>(CurEntry.ContextVariables[i]))
          VMap[CurEntry.ContextVariables[i]] = LoadedAsyncVars.back();
      }
    }

    // we don't need any argument, just leave dummy entries there to cheat CloneFunctionInto
    for (Function::const_arg_iterator AI = F.arg_begin(), AE = F.arg_end(); AI != AE; ++AI) {
      if (VMap.count(AI) == 0)
        VMap[AI] = Constant::getNullValue(AI->getType());
    }

    // Clone the function
    {
      SmallVector<ReturnInst*, 8> Returns;
      CloneFunctionInto(CurCallbackFunc, &F, VMap, false, Returns);
      
      // return type of the callback functions is always void
      // need to fix the return type
      if (!F.getReturnType()->isVoidTy()) {
        // for those return instructions that are from the original function
        // it means we are 'truly' leaving this function
        // need to store the return value right before return
        for (size_t i = 0; i < OrigReturns.size(); ++i) {
          ReturnInst *RI = cast<ReturnInst>(VMap[OrigReturns[i]]);
          // Need to store the return value into the global area
          CallInst *RawRetValAddr = CallInst::Create(GetAsyncReturnValueAddrFunction, "", RI);
          BitCastInst *RetValAddr = new BitCastInst(RawRetValAddr, F.getReturnType()->getPointerTo(), "AsyncRetValAddr", RI);
          new StoreInst(RI->getOperand(0), RetValAddr, RI);
        }
        // we want to unwind the stack back to where it was before the original function was called
        // but we don't actually need to do this here
        // at this point it must be true that no callback is pended
        // so the scheduler will correct the stack pointer and pop the frame
        // here we just fix the return type
        // (every cloned return, including those from the SaveAsyncCtx blocks,
        // is replaced by a `ret void`)
        for (size_t i = 0; i < Returns.size(); ++i) {
          ReplaceInstWithInst(Returns[i], ReturnInst::Create(TheModule->getContext()));
        }
      }
    }

    // the callback function does not have any return value
    // so clear all the attributes for return
    {
      AttributeSet Attrs = CurCallbackFunc->getAttributes();
      CurCallbackFunc->setAttributes(
        Attrs.removeAttributes(TheModule->getContext(), AttributeSet::ReturnIndex, Attrs.getRetAttributes())
      );
    }

    // in the callback function, we never allocate a new async frame
    // instead we reuse the existing one
    // NOTE: EI/EE here shadow the outer Pass 3 loop variables.
    for (std::vector<AsyncCallEntry>::iterator EI = AsyncCallEntries.begin(), EE = AsyncCallEntries.end();  EI != EE; ++EI) {
      Instruction *I = cast<Instruction>(VMap[EI->AllocAsyncCtxInst]);
      ReplaceInstWithInst(I, CallInst::Create(ReallocAsyncCtxFunction, I->getOperand(0), "ReallocAsyncCtx"));
    }

    // mapped entry point & async call
    BasicBlock *ResumeBlock = cast<BasicBlock>(VMap[CurEntry.AfterCallBlock]);
    Instruction *MappedAsyncCall = cast<Instruction>(VMap[CurEntry.AsyncCallInst]);
   
    // To save space, for each async call in the callback function, we just ignore the sync case, and leave it to the scheduler
    // TODO need an option for this
    // NOTE: EI/EE and CurEntry inside this block shadow the outer Pass 3
    // variables; the outer CurEntry is used again after the block closes.
    {
      for (std::vector<AsyncCallEntry>::iterator EI = AsyncCallEntries.begin(), EE = AsyncCallEntries.end();  EI != EE; ++EI) {
        AsyncCallEntry & CurEntry = *EI;
        Instruction *MappedAsyncCallInst = cast<Instruction>(VMap[CurEntry.AsyncCallInst]);
        BasicBlock *MappedAsyncCallBlock = MappedAsyncCallInst->getParent();
        BasicBlock *MappedAfterCallBlock = cast<BasicBlock>(VMap[CurEntry.AfterCallBlock]);

        // for the sync case of the call, go to NewBlock (instead of MappedAfterCallBlock)
        // (successor 1 is the "not async" edge created in Pass 1)
        BasicBlock *NewBlock = BasicBlock::Create(TheModule->getContext(), "", CurCallbackFunc, MappedAfterCallBlock);
        MappedAsyncCallBlock->getTerminator()->setSuccessor(1, NewBlock);
        // store the return value
        if (!MappedAsyncCallInst->use_empty()) {
          CallInst *RawRetValAddr = CallInst::Create(GetAsyncReturnValueAddrFunction, "", NewBlock);
          BitCastInst *RetValAddr = new BitCastInst(RawRetValAddr, MappedAsyncCallInst->getType()->getPointerTo(), "AsyncRetValAddr", NewBlock);
          new StoreInst(MappedAsyncCallInst, RetValAddr, NewBlock);
        }
        // tell the scheduler that we want to keep the current async stack frame
        CallInst::Create(DoNotUnwindAsyncFunction, "", NewBlock);
        // finally we go to the SaveAsyncCtxBlock, to register the callback, save the local variables and leave
        BasicBlock *MappedSaveAsyncCtxBlock = cast<BasicBlock>(VMap[CurEntry.SaveAsyncCtxBlock]);
        BranchInst::Create(MappedSaveAsyncCtxBlock, NewBlock);
      }
    }

    std::vector<AllocaInst*> ToPromote;
    // applying loaded variables in the entry block
    {
      BasicBlockSet ReachableBlocks = FindReachableBlocksFrom(ResumeBlock);
      for (size_t i = 0; i < CurEntry.ContextVariables.size(); ++i) {
        Value *OrigVar = CurEntry.ContextVariables[i];
        if (isa<Argument>(OrigVar)) continue; // already processed
        Value *CurVar = VMap[OrigVar];
        assert(CurVar != MappedAsyncCall);
        if (Instruction *Inst = dyn_cast<Instruction>(CurVar)) {
          if (ReachableBlocks.count(Inst->getParent())) {
            // Inst could be either defined or loaded from the async context
            // Do the dirty works in memory
            // TODO: might need to check the safety first
            // TODO: can we create phi directly?
            AllocaInst *Addr = DemoteRegToStack(*Inst, false);
            new StoreInst(LoadedAsyncVars[i], Addr, EntryBlock);
            ToPromote.push_back(Addr);
          } else {
            // The parent block is not reachable, which means there is no conflict
            // it's safe to replace Inst with the loaded value
            assert(Inst != LoadedAsyncVars[i]); // this should only happen when OrigVar is an Argument
            Inst->replaceAllUsesWith(LoadedAsyncVars[i]); 
          }
        }
      }
    }

    // resolve the return value of the previous async function
    // it could be the value just loaded from the global area
    // or directly returned by the function (in its sync case)
    if (!CurEntry.AsyncCallInst->use_empty()) {
      // load the async return value
      CallInst *RawRetValAddr = CallInst::Create(GetAsyncReturnValueAddrFunction, "", EntryBlock);
      BitCastInst *RetValAddr = new BitCastInst(RawRetValAddr, MappedAsyncCall->getType()->getPointerTo(), "AsyncRetValAddr", EntryBlock);
      LoadInst *RetVal = new LoadInst(RetValAddr, "AsyncRetVal", EntryBlock);

      AllocaInst *Addr = DemoteRegToStack(*MappedAsyncCall, false);
      new StoreInst(RetVal, Addr, EntryBlock);
      ToPromote.push_back(Addr);
    }

    // TODO remove unreachable blocks before creating phi
   
    // We go right to ResumeBlock from the EntryBlock
    BranchInst::Create(ResumeBlock, EntryBlock);
   
    /*
     * Creating phi's
     * Normal stack frames and async stack frames are interleaving with each other.
     * In a callback function, if we call an async function, we might need to realloc the async ctx.
     * at this point we don't want anything stored after the ctx,
     * such that we can free and extend the ctx by simply update STACKTOP.
     * Therefore we don't want any alloca's in callback functions.
     *
     * The allocas created by DemoteRegToStack above are promoted back to SSA
     * registers here, which requires a dominator tree of the callback.
     */
    if (!ToPromote.empty()) {
      DominatorTreeWrapperPass DTW;
      DTW.runOnFunction(*CurCallbackFunc);
      PromoteMemToReg(ToPromote, DTW.getDomTree());
    }

    removeUnreachableBlocks(*CurCallbackFunc);
  }

  // Pass 4
  // Here are modifications to the original function, which we won't want to be cloned into the callback functions
  for (std::vector<AsyncCallEntry>::iterator EI = AsyncCallEntries.begin(), EE = AsyncCallEntries.end();  EI != EE; ++EI) {
    AsyncCallEntry & CurEntry = *EI;
    // remove the frame if no async function has been called
    CallInst::Create(FreeAsyncCtxFunction, CurEntry.AllocAsyncCtxInst, "", CurEntry.AfterCallBlock->getFirstNonPHI());
  }
}
コード例 #7
0
//
// Method: runOnModule()
//
// Description:
//  Entry point for this LLVM pass.
//  For every defined, non-overridable function whose address is never taken
//  and that returns a struct, create a clone whose parameter list starts with
//  an extra pointer argument ("ret").  The clone keeps the struct return type
//  and additionally writes the returned struct through that pointer before
//  every return.  Direct call sites are rewritten to pass a fresh stack slot
//  as that argument and to load the result back from it.  The original
//  function is erased once it has no remaining uses.
//
// Inputs:
//  M - A reference to the LLVM module to transform
//
// Outputs:
//  M - The transformed LLVM module.
//
// Return value:
//  true  - The module was modified.
//  false - The module was not modified.
//
bool StructRet::runOnModule(Module& M) {
  const llvm::DataLayout targetData(&M);

  std::vector<Function*> worklist;
  for (Module::iterator I = M.begin(); I != M.end(); ++I) {
    if (I->mayBeOverridden() || I->hasAddressTaken())
      continue;
    // Only functions with a body can be cloned.  Rewriting the callers of a
    // mere declaration would redirect them to a new, equally body-less
    // function (under a uniqued name) that nothing defines.
    if (I->isDeclaration())
      continue;
    if (I->getReturnType()->isStructTy())
      worklist.push_back(I);
  }

  // Every function on the worklist gets a new clone inserted into M, so the
  // module changes exactly when the worklist is non-empty.
  const bool Changed = !worklist.empty();

  while(!worklist.empty()) {
    Function *F = worklist.back();
    worklist.pop_back();
    Type *NewArgType = F->getReturnType()->getPointerTo();

    // Construct the new Type: the return-slot pointer followed by the
    // original parameters; the return type itself is unchanged.
    std::vector<Type*>TP;
    TP.push_back(NewArgType);
    for (Function::arg_iterator ii = F->arg_begin(), ee = F->arg_end();
         ii != ee; ++ii) {
      TP.push_back(ii->getType());
    }

    FunctionType *NFTy = FunctionType::get(F->getReturnType(), TP, F->isVarArg());

    // Create the new function body and insert it into the module.
    Function *NF = Function::Create(NFTy,
                                    F->getLinkage(),
                                    F->getName(), &M);
    ValueToValueMapTy ValueMap;
    Function::arg_iterator NI = NF->arg_begin();
    NI->setName("ret");
    ++NI;
    // Map each old argument onto the new one shifted by the return slot, and
    // carry over its parameter attributes (old index = arg number + 1).
    for (Function::arg_iterator II = F->arg_begin(); II != F->arg_end(); ++II, ++NI) {
      ValueMap[II] = NI;
      NI->setName(II->getName());
      AttributeSet attrs = F->getAttributes().getParamAttributes(II->getArgNo() + 1);
      if (!attrs.isEmpty())
        NI->addAttr(attrs);
    }
    // Perform the cloning.  Declarations were filtered out above, so F always
    // has a body here.
    SmallVector<ReturnInst*,100> Returns;
    CloneFunctionInto(NF, F, ValueMap, false, Returns);

    // The caller-provided slot that receives the returned struct.
    Value *RetSlot = &*NF->arg_begin();

    NF->setAttributes(NF->getAttributes().addAttributes(
        M.getContext(), 0, F->getAttributes().getRetAttributes()));
    NF->setAttributes(NF->getAttributes().addAttributes(
        M.getContext(), ~0, F->getAttributes().getFnAttributes()));

    // Before every return, also write the returned struct into *RetSlot.
    for (Function::iterator B = NF->begin(), FE = NF->end(); B != FE; ++B) {
      for (BasicBlock::iterator I = B->begin(), BE = B->end(); I != BE;) {
        ReturnInst * RI = dyn_cast<ReturnInst>(I++);
        if(!RI)
          continue;
        Value *RetVal = RI->getOperand(0);
        IRBuilder<> Builder(RI);
        if (LoadInst *LI = dyn_cast<LoadInst>(RetVal)) {
          // The struct was loaded from memory: copy that memory directly.
          Builder.CreateMemCpy(RetSlot,
              LI->getPointerOperand(),
              targetData.getTypeStoreSize(LI->getType()),
              targetData.getPrefTypeAlignment(LI->getType()));
        } else {
          // Any other struct value (insertvalue, call result, undef, ...) is
          // stored into the slot as a first-class aggregate.  Previously this
          // case only tripped an assert and then dereferenced a null load.
          Builder.CreateStore(RetVal, RetSlot);
        }
      }
    }

    // Rewrite every direct call of F into a call of NF that passes a fresh
    // stack slot, then load the struct back from that slot.
    for(Value::use_iterator ui = F->use_begin(), ue = F->use_end();
        ui != ue; ) {
      CallInst *CI = dyn_cast<CallInst>(*ui++);
      if(!CI)
        continue;
      if(CI->getCalledFunction() != F)
        continue;
      if(CI->hasByValArgument())
        continue;
      AllocaInst *AllocaNew = new AllocaInst(F->getReturnType(), 0, "", CI);
      SmallVector<Value*, 8> Args;

      //this should probably be done in a different manner
      AttributeSet NewCallPAL=AttributeSet();

      // Get the initial attributes of the call
      AttributeSet CallPAL = CI->getAttributes();
      AttributeSet RAttrs = CallPAL.getRetAttributes();
      AttributeSet FnAttrs = CallPAL.getFnAttributes();

      if (!RAttrs.isEmpty())
        NewCallPAL=NewCallPAL.addAttributes(F->getContext(),0, RAttrs);

      Args.push_back(AllocaNew);
      // The callee is the last operand of a CallInst, so operands
      // [0, getNumOperands()-1) are the call arguments.
      for(unsigned j = 0; j < CI->getNumOperands()-1; j++) {
        Args.push_back(CI->getOperand(j));
        // Argument j carries its attributes at call-site index j + 1.  In the
        // new call the leading return slot shifts it to index j + 2, which is
        // Args.size() at this point.  Re-create the attributes at the new
        // index: addAttributes() only merges the slot whose index matches, so
        // handing it the old, differently-indexed set would drop or misplace
        // the attributes.
        AttrBuilder B(CallPAL, j + 1);
        if (B.hasAttributes())
          NewCallPAL = NewCallPAL.addAttributes(
              F->getContext(), Args.size(),
              AttributeSet::get(F->getContext(), Args.size(), B));
      }
      // Create the new attributes vec.
      if (!FnAttrs.isEmpty())
        NewCallPAL=NewCallPAL.addAttributes(F->getContext(),~0, FnAttrs);

      CallInst *CallI = CallInst::Create(NF, Args, "", CI);
      CallI->setCallingConv(CI->getCallingConv());
      CallI->setAttributes(NewCallPAL);
      LoadInst *LI = new LoadInst(AllocaNew, "", CI);
      CI->replaceAllUsesWith(LI);
      CI->eraseFromParent();
    }
    // Calls that were skipped above (byval arguments, invokes) keep F alive.
    if(F->use_empty())
      F->eraseFromParent();
  }
  return Changed;
}