예제 #1
0
/// lowerUDiv - Given a UDiv expressing a divide by a constant, replace it by
/// a multiply with a magic number followed by shifts.  See:
/// <http://the.wall.riscom.net/books/proc/ppc/cwg/code2.html>
///
/// Only `i32` divisions by a constant are lowered.  The dividend is
/// zero-extended to i64 and multiplied by the 32-bit magic number returned
/// by APInt::magicu(); the quotient is taken from the high half of that
/// product.  When magicu() sets its "add" indicator, the fix-up sequence
/// ((x - hi) >> 1) + hi is applied before the final shift by (s - 1).
///
/// Divisors 0 and 1 are left untouched: magicu() takes a remainder modulo
/// the divisor (so 0 cannot be handled), and for 1 it selects the "add" path
/// with s == 0, which would make the final shift amount (s - 1) wrap to an
/// out-of-range value and yield poison.
///
/// NOTE(review): the opcode of `inst` is not checked here; callers presumably
/// only pass UDiv instructions.
///
/// @param inst  instruction to lower; replaced in its basic block on success.
/// @return      true if `inst` was replaced, false if it was left untouched.
bool lowerUDiv(Instruction *inst) const {
  ConstantInt *Op1 = dyn_cast<ConstantInt>(inst->getOperand(1));
  // Only handle a 32-bit integer division by a constant other than 0 or 1.
  if (!Op1 || !inst->getType()->isIntegerTy(32) || Op1->isZero() ||
      Op1->isOne()) {
    return false;
  }

  BasicBlock::iterator ii(inst);
  Value *Op0 = inst->getOperand(0);

  APInt::mu magics = Op1->getValue().magicu();

  IntegerType *type64 = IntegerType::get(Mod->getContext(), 64);

  // Widen the dividend and multiply by the magic number.  Both factors are
  // below 2^32, so the product never wraps as an *unsigned* 64-bit value, but
  // it can exceed INT64_MAX (e.g. 0xFFFFFFFF * 0xAAAAAAAB when dividing by
  // 3).  The multiply must therefore be NUW; NSW would make it poison.
  Instruction *zext = CastInst::CreateZExtOrBitCast(Op0, type64, "", inst);
  Constant *magicNum =
      ConstantInt::get(type64, APInt(64, magics.m.getZExtValue()));
  Instruction *magInst = BinaryOperator::CreateNUWMul(zext, magicNum, "", inst);

  Instruction *result;
  if (!magics.a) {
    // q = (x * m) >> (32 + s)
    Constant *magicShift = ConstantInt::get(type64, APInt(64, 32 + magics.s));
    Instruction *bigResult = BinaryOperator::Create(Instruction::LShr, magInst,
      magicShift, "", inst);
    result = CastInst::CreateTruncOrBitCast(bigResult, inst->getType(), "");
  } else {
    // hi = (x * m) >> 32
    Constant *movHiConst = ConstantInt::get(type64, APInt(64, 32));
    Instruction *movHi = BinaryOperator::Create(Instruction::LShr, magInst,
      movHiConst, "", inst);
    Instruction *trunc = CastInst::CreateTruncOrBitCast(movHi, inst->getType(),
      "", inst);
    // q = (((x - hi) >> 1) + hi) >> (s - 1); avoids overflowing 32 bits.
    Instruction *sub = BinaryOperator::Create(Instruction::Sub, Op0, trunc, "",
      inst);
    Constant *one = ConstantInt::get(inst->getType(), APInt(32, 1));
    Instruction *shift = BinaryOperator::Create(Instruction::LShr, sub, one, "",
      inst);
    Instruction *add = BinaryOperator::Create(Instruction::Add, shift, trunc,
      "", inst);
    Constant *shr = ConstantInt::get(inst->getType(), APInt(32, magics.s - 1));
    result = BinaryOperator::Create(Instruction::LShr, add, shr, "");
  }

  // Swap `inst` for the final (not yet inserted) instruction; this also
  // rewires every use of `inst` and deletes it.
  ReplaceInstWithInst(inst->getParent()->getInstList(), ii, result);
  return true;
}
예제 #2
0
/// Redirect every call that uses `OldFunc` so that it calls `NewFunc` with the
/// same arguments instead.
///
/// The distinct call users are collected before anything is rewritten.
/// Rewriting a call erases it, and doing that while walking OldFunc's use list
/// would invalidate the iteration.
///
/// Fatal cases (message on errs(), then exit(1)):
///  - a user of `OldFunc` that is not a CallInst (the user is printed);
///  - a call whose result is used while the new call's return type differs
///    (those uses cannot be rewired to a value of another type).
///
/// When the return types differ but the old result is unused, the new call is
/// inserted and the old one erased without replaceAllUsesWith, which asserts
/// on a type mismatch even when there are no uses.
///
/// NOTE(review): a call that merely passes `OldFunc` as an argument is also
/// rewritten to call `NewFunc`; confirm this is intended.
void replaceAllCallsWith(Value* OldFunc, Value* NewFunc) {
  // Snapshot the calls first; the use list changes as calls are replaced.
  std::vector<CallInst*> calls;
  for (auto *U : OldFunc->users()) {
    CallInst *call = dyn_cast<CallInst>(U);
    if (!call) {
      U->print(errs()); errs() << "\n";
      exit(1);
    }
    // A call can use OldFunc in several operands; rewrite it only once.
    if (std::find(calls.begin(), calls.end(), call) == calls.end())
      calls.push_back(call);
  }

  for (CallInst *call : calls) {
    std::vector<Value*> args;
    args.reserve(call->getNumArgOperands());
    for (unsigned i = 0, e = call->getNumArgOperands(); i != e; ++i)
      args.push_back(call->getArgOperand(i));

    CallInst *newCall = CallInst::Create(NewFunc, args);
    if (newCall->getType() == call->getType()) {
      // Same type: swap in place (also rewires all uses of the old call).
      ReplaceInstWithInst(call, newCall);
      continue;
    }

    if (!call->use_empty()) {
      errs() << "Cannot handle usage of non matching return types for " << *call->getType() << " and " << *newCall->getType() << "\n";
      exit(1);
    }

    // Result unused: just put the new call in place of the old one.
    newCall->insertBefore(call);
    call->eraseFromParent();
  }
}
예제 #3
0
/// lowerSDiv - Given an SDiv expressing a divide by a constant, replace it by
/// a multiply with a magic number followed by shifts.  See:
/// <http://the.wall.riscom.net/books/proc/ppc/cwg/code2.html>
///
/// Only `i32` divisions are lowered.  The lowering works as follows:
///  1. The dividend is sign-extended to i64 and multiplied by the
///     sign-extended magic number returned by APInt::magic(); the high half of
///     the product is kept.
///  2. If the signs of the divisor and the magic number differ, the dividend
///     is added to (or subtracted from) that high half.
///  3. The result is arithmetically shifted right by `s`.
///  4. The result is incremented by its own sign bit, so that the quotient
///     rounds toward zero.
///
/// Skipped divisors:
///  - positive powers of two (this includes 1);
///  - 0, because magic() takes a remainder modulo the divisor;
///  - -1, because the sequence below evaluates to 0 instead of -x.
///
/// NOTE(review): the opcode of `inst` is not checked here; callers presumably
/// only pass SDiv instructions.
///
/// @param inst  instruction to lower; replaced in its basic block on success.
/// @return      true if `inst` was replaced, false if it was left untouched.
bool lowerSDiv(Instruction *inst) const {
  ConstantInt *Op1 = dyn_cast<ConstantInt>(inst->getOperand(1));
  // Require an i32 division by a constant with |d| >= 2 that is not a
  // positive power of 2 (1 is a power of 2, so it is excluded here too).
  if (!Op1 || !inst->getType()->isIntegerTy(32) ||
      Op1->getValue().isPowerOf2() || Op1->isZero() || Op1->isMinusOne()) {
    return false;
  }

  BasicBlock::iterator ii(inst);
  Value *Op0 = inst->getOperand(0);
  APInt d = Op1->getValue();

  APInt::ms magics = d.magic();

  IntegerType *type64 = IntegerType::get(Mod->getContext(), 64);

  // Both factors fit in 32 signed bits, so the 64-bit product cannot
  // overflow: NSW is valid here.
  Instruction *sext = CastInst::CreateSExtOrBitCast(Op0, type64, "", inst);
  Constant *magicNum =
      ConstantInt::get(type64, APInt(64, magics.m.getSExtValue()));
  Instruction *magInst = BinaryOperator::CreateNSWMul(sext, magicNum, "", inst);

  // q = high 32 bits of the signed product
  Constant *movHiConst = ConstantInt::get(type64, APInt(64, 32));
  Instruction *movHi = BinaryOperator::Create(Instruction::AShr, magInst,
    movHiConst, "", inst);
  Instruction *q = CastInst::CreateTruncOrBitCast(movHi, inst->getType(),
    "", inst);

  // Correct for a magic number whose sign differs from the divisor's.
  if (d.isStrictlyPositive() && magics.m.isNegative()) {
    q = BinaryOperator::Create(Instruction::Add, q, Op0, "", inst);
  } else if (d.isNegative() && magics.m.isStrictlyPositive()) {
    q = BinaryOperator::Create(Instruction::Sub, q, Op0, "", inst);
  }
  if (magics.s > 0) {
    Constant *magicShift = ConstantInt::get(inst->getType(),
                                            APInt(32, magics.s));
    q = BinaryOperator::Create(Instruction::AShr, q, magicShift, "", inst);
  }

  // Add 1 when q is negative (its sign bit) to round toward zero.
  Constant *thirtyOne = ConstantInt::get(inst->getType(), APInt(32, 31));
  Instruction *sign = BinaryOperator::Create(Instruction::LShr, q,
    thirtyOne, "", inst);

  Instruction *result = BinaryOperator::Create(Instruction::Add, q, sign, "");
  // Swap `inst` for the final instruction; also rewires its uses.
  ReplaceInstWithInst(inst->getParent()->getInstList(), ii, result);
  return true;
}
예제 #4
0
파일: if.cpp 프로젝트: apavlo/peloton
// Start the 'else' part of this if-construct.
//
// Steps, in order:
//  - terminate the current ('then') block with a branch to merge_bb_;
//  - append a new block called `name` to fn_ as the else block;
//  - rebuild the original conditional branch so its false edge targets that
//    block instead of merge_bb_;
//  - move the insertion point into the else block.
void If::ElseBlock(const std::string &name) {
  // Create an unconditional branch from where we are (from the 'then' branch)
  // to the merging block
  cg_->CreateBr(merge_bb_);

  // Create a new else block
  else_bb_ = llvm::BasicBlock::Create(cg_.GetContext(), name, fn_);
  last_bb_in_else_ = else_bb_;

  // Replace the previous branch instruction that normally went to the merging
  // block on a false predicate to now branch into the new else block
  auto *new_branch =
      llvm::BranchInst::Create(then_bb_, else_bb_, branch_->getCondition());
  ReplaceInstWithInst(branch_, new_branch);
  // ReplaceInstWithInst deleted the old branch; keep the member pointing at a
  // live instruction instead of freed memory.
  branch_ = new_branch;

  // The block that just received the branch to merge_bb_ ends the 'then' side.
  last_bb_in_then_ = cg_->GetInsertBlock();

  cg_->SetInsertPoint(else_bb_);
}
예제 #5
0
// Inserts an unwinding annotation (assume or assert, depending on the function
// given in the constructor) and removes the loop.
//
// Steps:
//  - the loop is removed from the LPPassManager queue;
//  - a new "unwinding_annotation" block is created. It calls `function` with an
//    i32 0 argument and then either ends in `unreachable` or branches to an
//    exit block of the loop. That exit block is the first one not created by
//    this pass, and its PHI nodes receive an undef incoming value for the new
//    block;
//  - the latch's edge back to the header is redirected to the new block; the
//    latch's other successor is kept;
//  - the latch's entries are removed from the header's PHI nodes.
//
// Always returns true (the IR is modified).
bool RmLoopPass::runOnLoop(Loop *L, LPPassManager &LPM){

  BasicBlock *latch  = L->getLoopLatch();
  BasicBlock *header = L->getHeader();

  // Pick the first exit block that this pass did not create itself.
  SmallVector<BasicBlock *, 1> exitBBs;
  L->getExitBlocks(exitBBs);
  BasicBlock *exitBB = nullptr;
  for (SmallVector<BasicBlock *, 1>::iterator it = exitBBs.begin();
       it != exitBBs.end() && !exitBB; ++it) {
    if (std::find(createdBB.begin(), createdBB.end(), *it) == createdBB.end())
      exitBB = *it;
  }

  assert(exitBB && "exitBB is null");
  assert(latch && header
        && "Not able to obtain some loop basic block; try to run doInitialization before");

  // Get the latch terminator and, if it is conditional, its condition
  // (the loop's last-iteration branch condition).
  BranchInst *br = cast<BranchInst>(latch->getTerminator());
  Value *cond = nullptr;
  if (br->isConditional())
    cond = br->getCondition();
  else {
    std::cout << "\n************************************************************\n";
    std::cout <<   "*BE CAREFUL!!!! There is a latch with unconditional branch!*\n";
    std::cout <<   "************************************************************\n";
  }

  // In order to remove the back edge, we need to remove the loop from the
  // LPPassManager.
  LPM.deleteLoopFromQueue(L);

  // Create a new BasicBlock with the unwinding annotation.
  BasicBlock *newBB = BasicBlock::Create(header->getContext(),
                                         "unwinding_annotation",
                                         header->getParent());
  createdBB.push_back(newBB);

  // Call `function(i32 0)`. The argument array lives on the stack for the
  // whole call. The former `new ArrayRef<Value *>(c)` pointed at a temporary
  // Value* that died immediately, and the heap ArrayRef itself was never freed.
  Type *t = Type::getInt32Ty(header->getContext());
  Value *args[] = { llvm::ConstantInt::get(t, uint32_t(0)) };
  CallInst::Create(function, args, "", newBB);

  // Terminate newBB: either unreachable, or a jump to the chosen exit block.
  if (unreachable) {
    new UnreachableInst(header->getContext(), newBB);
  } else {
    BranchInst::Create(exitBB, newBB);
    // newBB is a new predecessor of exitBB: give each of its PHIs a value.
    for (BasicBlock::iterator it = exitBB->begin(); it != exitBB->end(); ++it) {
      PHINode *phi = dyn_cast<PHINode>(&*it);
      if (!phi)
        break;
      phi->addIncoming(UndefValue::get(phi->getType()), newBB);
    }
  }

  // Rebuild the latch terminator: the edge that went back to the header now
  // goes to newBB; the other successor is kept as it was.
  BranchInst *newBr = nullptr;
  if (cond) {
    if (br->getSuccessor(0) == header)
      newBr = BranchInst::Create(newBB, br->getSuccessor(1), cond);
    else
      newBr = BranchInst::Create(br->getSuccessor(0), newBB, cond);
  } else {
    newBr = BranchInst::Create(newBB);
  }
  ReplaceInstWithInst(br, newBr);

  // The latch no longer branches to the header, so it must be removed from
  // the header's PHI nodes. Advance the iterator before editing: removing the
  // last incoming value of a PHI also erases the PHI itself.
  for (BasicBlock::iterator it = header->begin(); it != header->end();) {
    PHINode *phi = dyn_cast<PHINode>(&*it);
    if (!phi)
      break;
    ++it;
    int latchIndex = phi->getBasicBlockIndex(latch);
    if (latchIndex >= 0)
      phi->removeIncomingValue(latchIndex);
  }

  return true;
}
// =============================================================================
// replaceCallsInProcess
//
// Replace indirect calls to write() or read() by direct calls
// in the given process.
//
// Steps:
//  1. The process function is cloned (createProcess), and the SystemC process
//     is repointed at the JIT-compiled clone.
//  2. For every plain call in the clone to the function named wFunName, the
//     address argument is evaluated with SCJit and the target module owning
//     that address is looked up.
//  3. The call is replaced by a call to writeFun whose first argument is a
//     constant pointer to *that* call's target module.
//  4. The new call is inlined when possible.
//  5. The function pass manager is run on the clone; the clone is then
//     verified and recompiled.
//
// Calls to rFunName (read) are not rewritten yet.
// =============================================================================
void TLMBasicPassImpl::replaceCallsInProcess(sc_core::sc_module *initiatorMod,
                                         sc_core::sc_process_b *proc) {

    // Get the associated function
    std::string fctName = proc->func_process;
    std::string modType = typeid(*initiatorMod).name();
    std::string mainFctName = "_ZN" + modType +
        utostr(fctName.size()) + fctName + "Ev";
    Function *oldProcf = this->llvmMod->getFunction(mainFctName);
    if (oldProcf == NULL)
        return;

    // We do not modify the original function.
    // Instead, we create a clone.
    Function *procf = createProcess(oldProcf, initiatorMod);
    void *funPtr = this->engine->getPointerToFunction(procf);
    sc_core::SC_ENTRY_FUNC_OPT scfun =
        reinterpret_cast<sc_core::SC_ENTRY_FUNC_OPT>(funPtr);
    proc->m_semantics_p = scfun;
    std::string procfName = procf->getName();
    MSG("      Replace in the process's function : "+procfName+"\n");

    std::ostringstream oss;
    LLVMContext &context = getGlobalContext();

    // Calls to rewrite, and the target module of each one (same index).
    std::vector<CallInfo*> work;
    std::vector<sc_core::sc_module*> targets;

    for (inst_iterator ii = inst_begin(procf); ii != inst_end(procf); ++ii) {
        Instruction &i = *ii;
        CallSite cs(&i);
        if (!cs.getInstruction())
            continue;

        // Candidate for a replacement
        Function *oldfun = cs.getCalledFunction();
        if (oldfun == NULL || oldfun->isDeclaration())
            continue;
        std::string name = oldfun->getName();

        // === Write === (reads, i.e. rFunName, are not yet supported)
        if (strcmp(name.c_str(), wFunName.c_str()) != 0)
            continue;

        // The rewrite below needs a CallInst; an invoke would give a null
        // pointer here and crash later, so leave it alone.
        CallInst *oldcall = dyn_cast<CallInst>(cs.getInstruction());
        if (oldcall == NULL) {
            MSG("       Not a plain call, skipped\n");
            continue;
        }

        MSG("       Checking adress : ");
        // Retrieve the address argument by executing
        // the appropriate piece of code.
        // NOTE(review): this SCJit is never freed; its ownership rules are not
        // visible here, so it is kept as in the original.
        SCJit *scjit = new SCJit(this->llvmMod, this->elab);
        Process *irProc = this->elab->getProcess(proc);
        scjit->setCurrentProcess(irProc);
        bool jitErr = false;
        Value *addrArg = cs.getArgument(1);
        int value = scjit->jitInt(procf, oldcall, addrArg, &jitErr);
        if (jitErr) {
            std::cout << "       cannot get the address value!"
              << std::endl;
            continue;
        }
        oss.str("");  oss << std::hex << value;
        MSG("0x"+oss.str()+"\n");
        basic::addr_t a = static_cast<basic::addr_t>(value);

        // Checking address alignment
        if (value % sizeof(basic::data_t)) {
            std::cerr << "  unaligned write : " <<
            std::hex << value << std::endl;
            abort();
        }

        // Retrieve the target module using the address
        sc_core::sc_module *targetMod = getTargetModule(initiatorMod, a);

        // Save information to build a new call later
        CallInfo *info = new CallInfo();
        info->oldcall = oldcall;
        info->addrArg = addrArg;
        FunctionType *basicWriteFunType =
            this->basicWriteFun->getFunctionType();
        info->targetType = basicWriteFunType->getParamType(0);
        IntegerType *intType = this->is64Bit ? Type::getInt64Ty(context)
                                             : Type::getInt32Ty(context);
        info->targetModVal = ConstantInt::getSigned(intType,
                                reinterpret_cast<intptr_t>(targetMod));
        info->dataArg = cs.getArgument(2);
        work.push_back(info);
        targets.push_back(targetMod);
    }

    // Before
    //procf->dump();

    // Replace calls
    FunctionType *writeFunType = this->writeFun->getFunctionType();
    IntegerType *i64 = Type::getInt64Ty(context);
    DataLayout td(this->llvmMod);
    for (size_t k = 0; k < work.size(); ++k) {
        CallInfo *i = work[k];

        // Get a pointer to the target module of *this* call
        basic::target_module_base *tmb =
            dynamic_cast<basic::target_module_base*>(targets[k]);
        Value *ptr =
            ConstantInt::getSigned(i64, reinterpret_cast<intptr_t>(tmb));
        IntToPtrInst *modPtr = new IntToPtrInst(ptr,
                                                writeFunType->getParamType(0),
                                                "myitp", i->oldcall);
        // Get the address value
        LoadInst *addr = new LoadInst(i->addrArg, "", i->oldcall);

        // Create the new call
        Value *args[] = {modPtr, addr, i->dataArg};
        i->newcall = CallInst::Create(this->writeFun, ArrayRef<Value*>(args, 3));

        // Replace the old call. ReplaceInstWithInst already rewires every use
        // of the old call to the new one and deletes the old call, so the old
        // call must not be touched afterwards.
        BasicBlock::iterator callPos(i->oldcall);
        ReplaceInstWithInst(i->oldcall->getParent()->getInstList(), callPos, i->newcall);
        i->oldcall = NULL;

        // Inline the new call
        InlineFunctionInfo ifi(NULL, &td);
        bool success = InlineFunction(i->newcall, ifi);
        if (!success) {
            MSG("       The call cannot be inlined (it's not an error :D)\n");
        }

        MSG("       Call optimized (^_-)\n");
        callOptCounter++;
    }

    // The CallInfo records were only needed for the rewrite above.
    for (size_t k = 0; k < work.size(); ++k)
        delete work[k];

    // Run preloaded passes on the function to propagate constants
    funPassManager->run(*procf);
    // After
    //procf->dump();
    // Check if the function is corrupt
    verifyFunction(*procf);
    this->engine->recompileAndRelinkFunction(procf);
}
예제 #7
0
  /// Replace every `br` in `Func` with an equivalent `indirectbr` whose target
  /// address is loaded from a table of block addresses:
  ///  - conditional `br %c, T, F`: a fresh private table {F, T}, indexed by
  ///    zext(%c);
  ///  - unconditional `br T`, where T has an entry in `indexmap` and the
  ///    module has "IndirectBranchingGlobalTable": that table, at the index
  ///    recorded in `indexmap`;
  ///  - any other unconditional branch: a fresh private one-entry table {T},
  ///    at index 0.
  ///
  /// @return false for declarations and functions without branches, true
  ///         otherwise (every collected branch is rewritten).
  bool runOnFunction(Function &Func) override {
    if (Func.isDeclaration()) {
      return false;
    }
    auto &Ctx = Func.getParent()->getContext();
    auto *Int32Ty = Type::getInt32Ty(Ctx);

    vector<BranchInst *> BIs;
    for (inst_iterator I = inst_begin(Func); I != inst_end(Func); I++) {
      if (BranchInst *BI = dyn_cast<BranchInst>(&*I)) {
        BIs.push_back(BI);
      }
    } // Finish collecting branching conditions
    Value *zero = ConstantInt::get(Int32Ty, 0);
    for (BranchInst *BI : BIs) {
      IRBuilder<> IRB(BI);
      vector<BasicBlock *> BBs;
      // We use the condition's evaluation result to generate the GEP
      // instruction. False evaluates to 0 while true evaluates to 1, so here
      // we insert the false block first.
      if (BI->isConditional()) {
        BBs.push_back(BI->getSuccessor(1));
      }
      BBs.push_back(BI->getSuccessor(0));

      GlobalVariable *LoadFrom = NULL;
      Value *index = NULL;
      if (BI->isConditional()) {
        index = IRB.CreateZExt(BI->getCondition(), Int32Ty);
      } else {
        // Use find(), not operator[]: indexing would insert a bogus index 0
        // for an unknown block, and a later branch to that block would then
        // be routed through slot 0 of the global table.
        auto Found = indexmap.find(BI->getSuccessor(0));
        if (Found != indexmap.end()) {
          LoadFrom = Func.getParent()->getGlobalVariable(
              "IndirectBranchingGlobalTable", true);
          if (LoadFrom) {
            index = ConstantInt::get(Int32Ty, Found->second);
          }
        }
      }

      if (!LoadFrom) {
        // Create a new GV holding exactly this branch's targets.
        ArrayType *AT = ArrayType::get(Type::getInt8PtrTy(Ctx), BBs.size());
        vector<Constant *> BlockAddresses;
        for (BasicBlock *BB : BBs) {
          BlockAddresses.push_back(BlockAddress::get(BB));
        }
        Constant *BlockAddressArray =
            ConstantArray::get(AT, ArrayRef<Constant *>(BlockAddresses));
        LoadFrom = new GlobalVariable(*Func.getParent(), AT, false,
                                      GlobalValue::LinkageTypes::PrivateLinkage,
                                      BlockAddressArray);
        if (!index) {
          // Unconditional branch: its only target sits at index 0.
          index = zero;
        }
      }

      Value *GEP = IRB.CreateGEP(LoadFrom, {zero, index});
      LoadInst *LI = IRB.CreateLoad(GEP, "IndirectBranchingTargetAddress");
      IndirectBrInst *indirBr = IndirectBrInst::Create(LI, BBs.size());
      for (BasicBlock *BB : BBs) {
        indirBr->addDestination(BB);
      }
      ReplaceInstWithInst(BI, indirBr);
    }
    return !BIs.empty();
  }