Example No. 1
/// Handle frees of entire structures whose dependency is a store
/// to a field of that structure.
static bool handleFree(CallInst *F, AliasAnalysis *AA,
                       MemoryDependenceResults *MD, DominatorTree *DT,
                       const TargetLibraryInfo *TLI,
                       InstOverlapIntervalsTy &IOL,
                       DenseMap<Instruction*, size_t> *InstrOrdering) {
  bool MadeChange = false;

  MemoryLocation Loc = MemoryLocation(F->getOperand(0));
  SmallVector<BasicBlock *, 16> Blocks;
  Blocks.push_back(F->getParent());
  const DataLayout &DL = F->getModule()->getDataLayout();

  while (!Blocks.empty()) {
    BasicBlock *BB = Blocks.pop_back_val();
    Instruction *InstPt = BB->getTerminator();
    if (BB == F->getParent()) InstPt = F;

    MemDepResult Dep =
        MD->getPointerDependencyFrom(Loc, false, InstPt->getIterator(), BB);
    while (Dep.isDef() || Dep.isClobber()) {
      Instruction *Dependency = Dep.getInst();
      if (!hasMemoryWrite(Dependency, *TLI) || !isRemovable(Dependency))
        break;

      Value *DepPointer =
          GetUnderlyingObject(getStoredPointerOperand(Dependency), DL);

      // Check for aliasing.
      if (!AA->isMustAlias(F->getArgOperand(0), DepPointer))
        break;

      DEBUG(dbgs() << "DSE: Dead Store to soon to be freed memory:\n  DEAD: "
                   << *Dependency << '\n');

      // DCE instructions only used to calculate that store.
      BasicBlock::iterator BBI(Dependency);
      deleteDeadInstruction(Dependency, &BBI, *MD, *TLI, IOL, InstrOrdering);
      ++NumFastStores;
      MadeChange = true;

      // Inst's old Dependency is now deleted. Compute the next dependency,
      // which may also be dead, as in
      //    s[0] = 0;
      //    s[1] = 0; // This has just been deleted.
      //    free(s);
      Dep = MD->getPointerDependencyFrom(Loc, false, BBI, BB);
    }

    if (Dep.isNonLocal())
      findUnconditionalPreds(Blocks, BB, DT);
  }

  return MadeChange;
}
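
A hedged sketch of a possible caller, not part of the example: it walks a function and hands every recognized free() call to handleFree() above. The driver function itself is hypothetical; isFreeCall() (from llvm/Analysis/MemoryBuiltins.h) and the analysis types are real API of this LLVM era.

static bool eliminateStoresToFreedMemory(Function &F, AliasAnalysis *AA,
                                         MemoryDependenceResults *MD,
                                         DominatorTree *DT,
                                         const TargetLibraryInfo *TLI,
                                         InstOverlapIntervalsTy &IOL,
                                         DenseMap<Instruction*, size_t> *SO) {
  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    for (BasicBlock::iterator BBI = BB.begin(), BBE = BB.end(); BBI != BBE;) {
      Instruction *Inst = &*BBI++; // advance first; handleFree may erase
                                   // instructions preceding the call
      if (CallInst *FC = isFreeCall(Inst, TLI))
        MadeChange |= handleFree(FC, AA, MD, DT, TLI, IOL, SO);
    }
  }
  return MadeChange;
}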
Example No. 2
DataDependence::DepInfo DataDependence::getDepInfo(MemDepResult dep) {
  if (dep.isClobber())
    return DepInfo(dep.getInst(), Clobber);
  if (dep.isDef())
    return DepInfo(dep.getInst(), Def);
  if (dep.isNonFuncLocal())
    return DepInfo(dep.getInst(), NonFuncLocal);
  if (dep.isUnknown())
    return DepInfo(dep.getInst(), Unknown);
  if (dep.isNonLocal())
    return DepInfo(dep.getInst(), NonLocal);
  llvm_unreachable("unknown dependence type");
}
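
One caveat worth flagging, hedged: MemDepResult::getInst() returns nullptr for the non-local, non-function-local, and unknown cases, so the DepInfo built in those branches carries a null instruction. A minimal guard at a call site might look like the sketch below; visitLocalDep, handleDependence, and the Inst member are assumptions about this project's types, not LLVM API.

void DataDependence::visitLocalDep(Instruction *I,
                                   MemoryDependenceAnalysis &MDA) {
  DepInfo DI = getDepInfo(MDA.getDependency(I));
  if (DI.Inst)                 // null for NonLocal/NonFuncLocal/Unknown
    handleDependence(DI.Inst); // hypothetical consumer
}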
Example No. 3
/// HandleFree - Handle frees of entire structures whose dependency is a store
/// to a field of that structure.
bool DSE::HandleFree(CallInst *F) {
  bool MadeChange = false;

  MemoryLocation Loc = MemoryLocation(F->getOperand(0));
  SmallVector<BasicBlock *, 16> Blocks;
  Blocks.push_back(F->getParent());
  const DataLayout &DL = F->getModule()->getDataLayout();

  while (!Blocks.empty()) {
    BasicBlock *BB = Blocks.pop_back_val();
    Instruction *InstPt = BB->getTerminator();
    if (BB == F->getParent()) InstPt = F;

    MemDepResult Dep = MD->getPointerDependencyFrom(Loc, false, InstPt, BB);
    while (Dep.isDef() || Dep.isClobber()) {
      Instruction *Dependency = Dep.getInst();
      if (!hasMemoryWrite(Dependency, *TLI) || !isRemovable(Dependency))
        break;

      Value *DepPointer =
          GetUnderlyingObject(getStoredPointerOperand(Dependency), DL);

      // Check for aliasing.
      if (!AA->isMustAlias(F->getArgOperand(0), DepPointer))
        break;

      Instruction *Next = std::next(BasicBlock::iterator(Dependency));

      // DCE instructions only used to calculate that store
      DeleteDeadInstruction(Dependency, *MD, *TLI);
      ++NumFastStores;
      MadeChange = true;

      // Inst's old Dependency is now deleted. Compute the next dependency,
      // which may also be dead, as in
      //    s[0] = 0;
      //    s[1] = 0; // This has just been deleted.
      //    free(s);
      Dep = MD->getPointerDependencyFrom(Loc, false, Next, BB);
    }

    if (Dep.isNonLocal())
      FindUnconditionalPreds(Blocks, BB, DT);
  }

  return MadeChange;
}
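
Examples 1 and 3 are two revisions of the same DSE routine. Besides receiving its analyses as parameters instead of reading pass members, the newer version scans from a BasicBlock::iterator rather than a raw Instruction*, reflecting the interface change between MemoryDependenceAnalysis and MemoryDependenceResults:

// Older interface (this example):
MemDepResult Dep = MD->getPointerDependencyFrom(Loc, false, InstPt, BB);
// Newer interface (Example No. 1):
MemDepResult Dep =
    MD->getPointerDependencyFrom(Loc, false, InstPt->getIterator(), BB);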
Example No. 4
/// processMemCpy - perform simplification of memcpy's.  If we have memcpy A
/// which copies X to Y, and memcpy B which copies Y to Z, then we can rewrite
/// B to be a memcpy from X to Z (or potentially a memmove, depending on
/// circumstances). This allows later passes to remove the first memcpy
/// altogether.
bool MemCpyOpt::processMemCpy(MemCpyInst *M) {
  // We can only optimize non-volatile memcpy's.
  if (M->isVolatile()) return false;

  // If the source and destination of the memcpy are the same, then zap it.
  if (M->getSource() == M->getDest()) {
    MD->removeInstruction(M);
    M->eraseFromParent();
    return false;
  }

  // If copying from a constant, try to turn the memcpy into a memset.
  if (GlobalVariable *GV = dyn_cast<GlobalVariable>(M->getSource()))
    if (GV->isConstant() && GV->hasDefinitiveInitializer())
      if (Value *ByteVal = isBytewiseValue(GV->getInitializer())) {
        IRBuilder<> Builder(M);
        Builder.CreateMemSet(M->getRawDest(), ByteVal, M->getLength(),
                             M->getAlignment(), false);
        MD->removeInstruction(M);
        M->eraseFromParent();
        ++NumCpyToSet;
        return true;
      }

  MemDepResult DepInfo = MD->getDependency(M);

  // Try to turn a partially redundant memset + memcpy into
  // memcpy + smaller memset.  We don't need the memcpy size for this.
  if (DepInfo.isClobber())
    if (MemSetInst *MDep = dyn_cast<MemSetInst>(DepInfo.getInst()))
      if (processMemSetMemCpyDependence(M, MDep))
        return true;

  // The optimizations after this point require the memcpy size.
  ConstantInt *CopySize = dyn_cast<ConstantInt>(M->getLength());
  if (!CopySize) return false;

  // There are four possible optimizations we can do for memcpy:
  //   a) memcpy-memcpy xform which exposes redundancy for DSE.
  //   b) call-memcpy xform for return slot optimization.
  //   c) memcpy from freshly alloca'd space or space that has just started its
  //      lifetime copies undefined data, and we can therefore eliminate the
  //      memcpy in favor of the data that was already at the destination.
  //   d) memcpy from a just-memset'd source can be turned into memset.
  if (DepInfo.isClobber()) {
    if (CallInst *C = dyn_cast<CallInst>(DepInfo.getInst())) {
      if (performCallSlotOptzn(M, M->getDest(), M->getSource(),
                               CopySize->getZExtValue(), M->getAlignment(),
                               C)) {
        MD->removeInstruction(M);
        M->eraseFromParent();
        return true;
      }
    }
  }

  AliasAnalysis::Location SrcLoc = AliasAnalysis::getLocationForSource(M);
  MemDepResult SrcDepInfo = MD->getPointerDependencyFrom(SrcLoc, true,
                                                         M, M->getParent());

  if (SrcDepInfo.isClobber()) {
    if (MemCpyInst *MDep = dyn_cast<MemCpyInst>(SrcDepInfo.getInst()))
      return processMemCpyMemCpyDependence(M, MDep);
  } else if (SrcDepInfo.isDef()) {
    Instruction *I = SrcDepInfo.getInst();
    bool hasUndefContents = false;

    if (isa<AllocaInst>(I)) {
      hasUndefContents = true;
    } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
      if (II->getIntrinsicID() == Intrinsic::lifetime_start)
        if (ConstantInt *LTSize = dyn_cast<ConstantInt>(II->getArgOperand(0)))
          if (LTSize->getZExtValue() >= CopySize->getZExtValue())
            hasUndefContents = true;
    }

    if (hasUndefContents) {
      MD->removeInstruction(M);
      M->eraseFromParent();
      ++NumMemCpyInstr;
      return true;
    }
  }

  if (SrcDepInfo.isClobber())
    if (MemSetInst *MDep = dyn_cast<MemSetInst>(SrcDepInfo.getInst()))
      if (performMemCpyToMemSetOptzn(M, MDep)) {
        MD->removeInstruction(M);
        M->eraseFromParent();
        ++NumCpyToSet;
        return true;
      }

  return false;
}
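
For transform (d), which the code above delegates to performMemCpyToMemSetOptzn(), a simplified sketch of the rewrite is shown below. It assumes the coverage check (the memset wrote at least the bytes the memcpy reads) has already passed; IRBuilder::CreateMemSet, MemSetInst::getValue(), and the removeInstruction/eraseFromParent pairing are real API, but the helper itself is hypothetical, not the pass's actual implementation.

static void rewriteMemCpyAsMemSet(MemCpyInst *M, MemSetInst *MDep,
                                  MemoryDependenceAnalysis *MD) {
  // The source of M was fully memset to MDep's byte value, so copying it
  // is equivalent to memsetting the destination with the same byte.
  IRBuilder<> Builder(M);
  Builder.CreateMemSet(M->getRawDest(), MDep->getValue(), M->getLength(),
                       M->getAlignment(), false);
  MD->removeInstruction(M); // drop cached dependencies before erasing
  M->eraseFromParent();
}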
Example No. 5
bool DSE::runOnBasicBlock(BasicBlock &BB) {
  bool MadeChange = false;

  // Do a top-down walk on the BB.
  for (BasicBlock::iterator BBI = BB.begin(), BBE = BB.end(); BBI != BBE; ) {
    Instruction *Inst = BBI++;

    // Handle 'free' calls specially.
    if (CallInst *F = isFreeCall(Inst, TLI)) {
      MadeChange |= HandleFree(F);
      continue;
    }

    // If we find something that writes memory, get its memory dependence.
    if (!hasMemoryWrite(Inst, TLI))
      continue;

    MemDepResult InstDep = MD->getDependency(Inst);

    // Ignore any store where we can't find a local dependence.
    // FIXME: cross-block DSE would be fun. :)
    if (!InstDep.isDef() && !InstDep.isClobber())
      continue;

    // If we're storing the same value back to a pointer that we just
    // loaded from, then the store can be removed.
    if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
      if (LoadInst *DepLoad = dyn_cast<LoadInst>(InstDep.getInst())) {
        if (SI->getPointerOperand() == DepLoad->getPointerOperand() &&
            SI->getOperand(0) == DepLoad && isRemovable(SI)) {
          DEBUG(dbgs() << "DSE: Remove Store Of Load from same pointer:\n  "
                       << "LOAD: " << *DepLoad << "\n  STORE: " << *SI << '\n');

          // DeleteDeadInstruction can delete the current instruction.  Save BBI
          // in case we need it.
          WeakVH NextInst(BBI);

          DeleteDeadInstruction(SI, *MD, TLI);

          if (!NextInst)  // Next instruction deleted.
            BBI = BB.begin();
          else if (BBI != BB.begin())  // Revisit this instruction if possible.
            --BBI;
          ++NumFastStores;
          MadeChange = true;
          continue;
        }
      }
    }

    // Figure out what location is being stored to.
    AliasAnalysis::Location Loc = getLocForWrite(Inst, *AA);

    // If we didn't get a useful location, fail.
    if (!Loc.Ptr)
      continue;

    while (InstDep.isDef() || InstDep.isClobber()) {
      // Get the memory clobbered by the instruction we depend on.  MemDep will
      // skip any instructions that 'Loc' clearly doesn't interact with.  If we
      // end up depending on a may- or must-aliased load, then we can't optimize
      // away the store and we bail out.  However, if we depend on something
      // that overwrites the memory location we *can* potentially optimize it.
      //
      // Find out what memory location the dependent instruction stores.
      Instruction *DepWrite = InstDep.getInst();
      AliasAnalysis::Location DepLoc = getLocForWrite(DepWrite, *AA);
      // If we didn't get a useful location, or it has no known size, bail out.
      if (!DepLoc.Ptr)
        break;

      // If we find a write that is a) removable (i.e., non-volatile), b) is
      // completely obliterated by the store to 'Loc', and c) which we know that
      // 'Inst' doesn't load from, then we can remove it.
      if (isRemovable(DepWrite) &&
          !isPossibleSelfRead(Inst, Loc, DepWrite, *AA)) {
        int64_t InstWriteOffset, DepWriteOffset;
        OverwriteResult OR = isOverwrite(Loc, DepLoc, *AA,
                                         DepWriteOffset, InstWriteOffset);
        if (OR == OverwriteComplete) {
          DEBUG(dbgs() << "DSE: Remove Dead Store:\n  DEAD: "
                << *DepWrite << "\n  KILLER: " << *Inst << '\n');

          // Delete the store and now-dead instructions that feed it.
          DeleteDeadInstruction(DepWrite, *MD, TLI);
          ++NumFastStores;
          MadeChange = true;

          // DeleteDeadInstruction can delete the current instruction in loop
          // cases, reset BBI.
          BBI = Inst;
          if (BBI != BB.begin())
            --BBI;
          break;
        } else if (OR == OverwriteEnd && isShortenable(DepWrite)) {
          // TODO: Base this on the target vector size so that if the earlier
          // store was too small to get vector writes anyway, then it's likely
          // a good idea to shorten it.
          // Power-of-2 vector writes are probably always a bad idea to
          // optimize, as any store/memset/memcpy is likely using vector
          // instructions, so shortening it to a non-vector size is likely to
          // be slower.
          MemIntrinsic* DepIntrinsic = cast<MemIntrinsic>(DepWrite);
          unsigned DepWriteAlign = DepIntrinsic->getAlignment();
          if (llvm::isPowerOf2_64(InstWriteOffset) ||
              ((DepWriteAlign != 0) && InstWriteOffset % DepWriteAlign == 0)) {

            DEBUG(dbgs() << "DSE: Remove Dead Store:\n  OW END: "
                  << *DepWrite << "\n  KILLER (offset "
                  << InstWriteOffset << ", "
                  << DepLoc.Size << ")"
                  << *Inst << '\n');

            Value* DepWriteLength = DepIntrinsic->getLength();
            Value* TrimmedLength = ConstantInt::get(DepWriteLength->getType(),
                                                    InstWriteOffset -
                                                    DepWriteOffset);
            DepIntrinsic->setLength(TrimmedLength);
            MadeChange = true;
          }
        }
      }

      // If this is a may-aliased store that is clobbering the store value, we
      // can keep searching past it for another must-aliased pointer that stores
      // to the same location.  For example, in:
      //   store -> P
      //   store -> Q
      //   store -> P
      // we can remove the first store to P even though we don't know if P and Q
      // alias.
      if (DepWrite == &BB.front()) break;

      // Can't look past this instruction if it might read 'Loc'.
      if (AA->getModRefInfo(DepWrite, Loc) & AliasAnalysis::Ref)
        break;

      InstDep = MD->getPointerDependencyFrom(Loc, false, DepWrite, &BB);
    }
  }

  // If this block ends in a return, unwind, or unreachable, all allocas are
  // dead at its end, which means stores to them are also dead.
  if (BB.getTerminator()->getNumSuccessors() == 0)
    MadeChange |= handleEndBlock(BB);

  return MadeChange;
}
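
The store-of-load case near the top of this function targets IR whose C-level shape is roughly the following illustrative snippet:

// Illustrative only: storing back exactly the value just loaded from the
// same pointer changes nothing, so DSE deletes the store.
void roundTrip(int *p) {
  int x = *p; // the load that shows up as the store's dependency
  *p = x;     // dead store: removed by the SI/DepLoad case above
}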
Example No. 6
void DSWP::buildPDG(Loop *L) {
    //Initialize PDG
	for (Loop::block_iterator bi = L->getBlocks().begin(); bi != L->getBlocks().end(); bi++) {
		BasicBlock *BB = *bi;
		for (BasicBlock::iterator ui = BB->begin(); ui != BB->end(); ui++) {
			Instruction *inst = &(*ui);

			//standardize the name for all exprs
			if (util.hasNewDef(inst)) {
				inst->setName(util.genId());
				dname[inst] = inst->getNameStr();
			} else {
				dname[inst] = util.genId();
			}

			pdg[inst] = new vector<Edge>();
			rev[inst] = new vector<Edge>();
		}
	}

	//LoopInfo &li = getAnalysis<LoopInfo>();

	/*
	 * Memory dependency analysis
	 */
	MemoryDependenceAnalysis &mda = getAnalysis<MemoryDependenceAnalysis>();

	for (Loop::block_iterator bi = L->getBlocks().begin(); bi != L->getBlocks().end(); bi++) {
		BasicBlock *BB = *bi;
		for (BasicBlock::iterator ii = BB->begin(); ii != BB->end(); ii++) {
			Instruction *inst = &(*ii);

			//data dependence = register dependence + memory dependence

			//begin register dependence
			for (Value::use_iterator ui = ii->use_begin(); ui != ii->use_end(); ui++) {
				if (Instruction *user = dyn_cast<Instruction>(*ui)) {
					addEdge(inst, user, REG);
				}
			}
			//finish register dependence

			//begin memory dependence
			MemDepResult mdr = mda.getDependency(inst);
			//TODO: not sure what clobbers mean!!

			if (mdr.isDef()) {
				Instruction *dep = mdr.getInst();

				if (isa<LoadInst>(inst)) {
					if (isa<StoreInst>(dep)) {
						addEdge(dep, inst, DTRUE);	//READ AFTER WRITE
					}
				}
				if (isa<StoreInst>(inst)) {
					if (isa<LoadInst>(dep)) {
						addEdge(dep, inst, DANTI);	//WRITE AFTER READ
					}
					if (isa<StoreInst>(dep)) {
						addEdge(dep, inst, DOUT);	//WRITE AFTER WRITE
					}
				}
				//READ AFTER READ IS INSERTED AFTER PDG BUILD
			}
			//end memory dependence
		}//for ii
	}//for bi

	/*
	 * begin control dependence
	 */
	PostDominatorTree &pdt = getAnalysis<PostDominatorTree>();
	//cout << pdt.getRootNode()->getBlock()->getNameStr() << endl;

	/*
	 * alien code part 1
	 */
	LoopInfo *LI = &getAnalysis<LoopInfo>();
	std::set<BranchInst*> backedgeParents;
	for (Loop::block_iterator bi = L->getBlocks().begin(); bi
			!= L->getBlocks().end(); bi++) {
		BasicBlock *BB = *bi;
		for (BasicBlock::iterator ii = BB->begin(); ii != BB->end(); ii++) {
			Instruction *inst = ii;
			if (BranchInst *brInst = dyn_cast<BranchInst>(inst)) {
				// get the loop this instruction (and therefore basic block) belongs to
				Loop *instLoop = LI->getLoopFor(BB);
				bool branchesToHeader = false;
				for (int i = brInst->getNumSuccessors() - 1; i >= 0
						&& !branchesToHeader; i--) {
					// if the branch could exit, store it
					if (LI->getLoopFor(brInst->getSuccessor(i)) != instLoop) {
						branchesToHeader = true;
					}
				}
				if (branchesToHeader) {
					backedgeParents.insert(brInst);
				}
			}
		}
	}

	//build predecessor information for blocks in the post-dominator tree
	for (Function::iterator bi = func->begin(); bi != func->end(); bi++) {
		BasicBlock *BB = bi;
		DomTreeNode *dn = pdt.getNode(BB);

		for (DomTreeNode::iterator di = dn->begin(); di != dn->end(); di++) {
			BasicBlock *CB = (*di)->getBlock();
			pre[CB] = BB;
		}
	}
//
//	//add dependency within a basicblock
//	for (Loop::block_iterator bi = L->getBlocks().begin(); bi != L->getBlocks().end(); bi++) {
//		BasicBlock *BB = *bi;
//		Instruction *pre = NULL;
//		for (BasicBlock::iterator ui = BB->begin(); ui != BB->end(); ui++) {
//			Instruction *inst = &(*ui);
//			if (pre != NULL) {
//				addEdge(pre, inst, CONTROL);
//			}
//			pre = inst;
//		}
//	}

//	//the special kind of dependence that needs loop peeling? I don't know whether this is needed
//	for (Loop::block_iterator bi = L->getBlocks().begin(); bi != L->getBlocks().end(); bi++) {
//		BasicBlock *BB = *bi;
//		for (succ_iterator PI = succ_begin(BB); PI != succ_end(BB); ++PI) {
//			BasicBlock *succ = *PI;
//
//			checkControlDependence(BB, succ, pdt);
//		}
//	}


	/*
	 * alien code part 2
	 */
	// add normal control dependencies
	// loop through each instruction
	for (Loop::block_iterator bbIter = L->block_begin(); bbIter
			!= L->block_end(); ++bbIter) {
		BasicBlock *bb = *bbIter;
		// check the successors of this basic block
		if (BranchInst *branchInst = dyn_cast<BranchInst>(bb->getTerminator())) {
			if (branchInst->getNumSuccessors() > 1) {
				BasicBlock * succ = branchInst->getSuccessor(0);
				// if the successor is nested deeper than the current basic block, skip it
				if (LI->getLoopDepth(bb) < LI->getLoopDepth(succ)) {
					continue;
				}
				// otherwise, add all instructions to graph as control dependence
				while (succ != NULL && succ != bb && LI->getLoopDepth(succ)
						>= LI->getLoopDepth(bb)) {
					Instruction *terminator = bb->getTerminator();
					for (BasicBlock::iterator succInstIter = succ->begin(); &(*succInstIter)
							!= succ->getTerminator(); ++succInstIter) {
						addEdge(terminator, &(*succInstIter), CONTROL);
					}
					if (BranchInst *succBrInst = dyn_cast<BranchInst>(succ->getTerminator())) {
						if (succBrInst->getNumSuccessors() > 1) {
							addEdge(terminator, succ->getTerminator(),
									CONTROL);
						}
					}
					if (BranchInst *br = dyn_cast<BranchInst>(succ->getTerminator())) {
						if (br->getNumSuccessors() == 1) {
							succ = br->getSuccessor(0);
						} else {
							succ = NULL;
						}
					} else {
						succ = NULL;
					}
				}
			}
		}
	}


	/*
	 * alien code part 3
	 */
    for (std::set<BranchInst*>::iterator exitIter = backedgeParents.begin(); exitIter != backedgeParents.end(); ++exitIter) {
        BranchInst *exitBranch = *exitIter;
        if (exitBranch->isConditional()) {
            BasicBlock *header = LI->getLoopFor(exitBranch->getParent())->getHeader();
            for (BasicBlock::iterator ctrlIter = header->begin(); ctrlIter != header->end(); ++ctrlIter) {
                addEdge(exitBranch, &(*ctrlIter), CONTROL);
            }
        }
    }

	//end control dependence
}
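
On the TODO about clobbers: in MemoryDependenceAnalysis a Def is an exact, must-aliased dependence, while a Clobber is an access that may touch the queried location (a may-alias store, a partial overlap, a call, and so on). A deliberately pessimistic PDG could also draw an ordering edge for clobbers, for example with the block below inserted after the isDef() case; this is an assumption about how one might extend the pass, reusing its addEdge/DOUT names, not part of the original code.

			if (mdr.isClobber() && mdr.getInst() != NULL) {
				//conservatively keep the original order of the two accesses
				addEdge(mdr.getInst(), inst, DOUT);
			}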
Example No. 7
/// processMemCpy - perform simplification of memcpy's.  If we have memcpy A
/// which copies X to Y, and memcpy B which copies Y to Z, then we can rewrite
/// B to be a memcpy from X to Z (or potentially a memmove, depending on
/// circumstances). This allows later passes to remove the first memcpy
/// altogether.
bool MemCpyOpt::processMemCpy(MemCpyInst *M) {
  // We can only optimize non-volatile memcpy's.
  if (M->isVolatile()) return false;

  // If the source and destination of the memcpy are the same, then zap it.
  if (M->getSource() == M->getDest()) {
    MD->removeInstruction(M);
    M->eraseFromParent();
    return false;
  }

  // If copying from a constant, try to turn the memcpy into a memset.
  if (GlobalVariable *GV = dyn_cast<GlobalVariable>(M->getSource()))
    if (GV->isConstant() && GV->hasDefinitiveInitializer())
      if (Value *ByteVal = isBytewiseValue(GV->getInitializer())) {
        IRBuilder<> Builder(M);
        Builder.CreateMemSet(M->getRawDest(), ByteVal, M->getLength(),
                             M->getAlignment(), false);
        MD->removeInstruction(M);
        M->eraseFromParent();
        ++NumCpyToSet;
        return true;
      }

  // The optimizations after this point require the memcpy size.
  ConstantInt *CopySize = dyn_cast<ConstantInt>(M->getLength());
  if (!CopySize) return false;

  // There are three possible optimizations we can do for memcpy:
  //   a) memcpy-memcpy xform which exposes redundancy for DSE.
  //   b) call-memcpy xform for return slot optimization.
  //   c) memcpy from freshly alloca'd space copies undefined data, and we can
  //      therefore eliminate the memcpy in favor of the data that was already
  //      at the destination.
  MemDepResult DepInfo = MD->getDependency(M);
  if (DepInfo.isClobber()) {
    if (CallInst *C = dyn_cast<CallInst>(DepInfo.getInst())) {
      if (performCallSlotOptzn(M, M->getDest(), M->getSource(),
                               CopySize->getZExtValue(), M->getAlignment(),
                               C)) {
        MD->removeInstruction(M);
        M->eraseFromParent();
        return true;
      }
    }
  }

  AliasAnalysis::Location SrcLoc = AliasAnalysis::getLocationForSource(M);
  MemDepResult SrcDepInfo = MD->getPointerDependencyFrom(SrcLoc, true,
                                                         M, M->getParent());
  if (SrcDepInfo.isClobber()) {
    if (MemCpyInst *MDep = dyn_cast<MemCpyInst>(SrcDepInfo.getInst()))
      return processMemCpyMemCpyDependence(M, MDep, CopySize->getZExtValue());
  } else if (SrcDepInfo.isDef()) {
    if (isa<AllocaInst>(SrcDepInfo.getInst())) {
      MD->removeInstruction(M);
      M->eraseFromParent();
      ++NumMemCpyInstr;
      return true;
    }
  }

  return false;
}
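
Example No. 7 is an earlier revision of Example No. 4's processMemCpy: it queries the dependency only after the size check, and it has neither the memset+memcpy shortening nor the lifetime_start handling. Its alloca case (c) corresponds, at the source level, to an illustrative pattern like this:

#include <cstring>

// Illustrative only: the copied-from buffer is never written before the
// memcpy, so the copied bytes are undefined and the copy is removable.
void copyFromFreshBuffer(char *dst) {
  char tmp[64];                      // freshly alloca'd, contents undefined
  std::memcpy(dst, tmp, sizeof tmp); // copies undef data
}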
Example No. 8
bool AMDGPURewriteOutArguments::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  // TODO: Could probably handle variadic functions.
  if (F.isVarArg() || F.hasStructRetAttr() ||
      AMDGPU::isEntryFunctionCC(F.getCallingConv()))
    return false;

  MDA = &getAnalysis<MemoryDependenceWrapperPass>().getMemDep();

  unsigned ReturnNumRegs = 0;
  SmallSet<int, 4> OutArgIndexes;
  SmallVector<Type *, 4> ReturnTypes;
  Type *RetTy = F.getReturnType();
  if (!RetTy->isVoidTy()) {
    ReturnNumRegs = DL->getTypeStoreSize(RetTy) / 4;

    if (ReturnNumRegs >= MaxNumRetRegs)
      return false;

    ReturnTypes.push_back(RetTy);
  }

  SmallVector<Argument *, 4> OutArgs;
  for (Argument &Arg : F.args()) {
    if (isOutArgumentCandidate(Arg)) {
      LLVM_DEBUG(dbgs() << "Found possible out argument " << Arg
                        << " in function " << F.getName() << '\n');
      OutArgs.push_back(&Arg);
    }
  }

  if (OutArgs.empty())
    return false;

  using ReplacementVec = SmallVector<std::pair<Argument *, Value *>, 4>;

  DenseMap<ReturnInst *, ReplacementVec> Replacements;

  SmallVector<ReturnInst *, 4> Returns;
  for (BasicBlock &BB : F) {
    if (ReturnInst *RI = dyn_cast<ReturnInst>(&BB.back()))
      Returns.push_back(RI);
  }

  if (Returns.empty())
    return false;

  bool Changing;

  do {
    Changing = false;

    // Keep retrying if we are able to successfully eliminate an argument. This
    // helps with cases with multiple arguments which may alias, such as in a
    // sincos implementation. If we have 2 stores to arguments, on the first
    // attempt the MDA query will succeed for the second store but not the
    // first. On the second iteration we've removed that clobbering out-argument
    // (by effectively moving it into another function) and will find the second
    // argument is OK to move.
    for (Argument *OutArg : OutArgs) {
      bool ThisReplaceable = true;
      SmallVector<std::pair<ReturnInst *, StoreInst *>, 4> ReplaceableStores;

      Type *ArgTy = OutArg->getType()->getPointerElementType();

      // Skip this argument if converting it would push us over the limit on
      // the number of registers available for the return value.

      // TODO: This is an approximation. When legalized this could be more. We
      // can ask TLI for exactly how many.
      unsigned ArgNumRegs = DL->getTypeStoreSize(ArgTy) / 4;
      if (ArgNumRegs + ReturnNumRegs > MaxNumRetRegs)
        continue;

      // An argument is convertible only if all exit blocks are able to replace
      // it.
      for (ReturnInst *RI : Returns) {
        BasicBlock *BB = RI->getParent();

        MemDepResult Q = MDA->getPointerDependencyFrom(MemoryLocation(OutArg),
                                                       true, BB->end(), BB, RI);
        StoreInst *SI = nullptr;
        if (Q.isDef())
          SI = dyn_cast<StoreInst>(Q.getInst());

        if (SI) {
          LLVM_DEBUG(dbgs() << "Found out argument store: " << *SI << '\n');
          ReplaceableStores.emplace_back(RI, SI);
        } else {
          ThisReplaceable = false;
          break;
        }
      }

      if (!ThisReplaceable)
        continue; // Try the next argument candidate.

      for (std::pair<ReturnInst *, StoreInst *> Store : ReplaceableStores) {
        Value *ReplVal = Store.second->getValueOperand();

        auto &ValVec = Replacements[Store.first];
        if (llvm::find_if(ValVec,
              [OutArg](const std::pair<Argument *, Value *> &Entry) {
                 return Entry.first == OutArg;}) != ValVec.end()) {
          LLVM_DEBUG(dbgs()
                     << "Saw multiple out arg stores" << *OutArg << '\n');
          // It is possible to see stores to the same argument multiple times,
          // but we expect these would have been optimized out already.
          ThisReplaceable = false;
          break;
        }

        ValVec.emplace_back(OutArg, ReplVal);
        Store.second->eraseFromParent();
      }

      if (ThisReplaceable) {
        ReturnTypes.push_back(ArgTy);
        OutArgIndexes.insert(OutArg->getArgNo());
        ++NumOutArgumentsReplaced;
        Changing = true;
      }
    }
  } while (Changing);

  if (Replacements.empty())
    return false;

  LLVMContext &Ctx = F.getParent()->getContext();
  StructType *NewRetTy = StructType::create(Ctx, ReturnTypes, F.getName());

  FunctionType *NewFuncTy = FunctionType::get(NewRetTy,
                                              F.getFunctionType()->params(),
                                              F.isVarArg());

  LLVM_DEBUG(dbgs() << "Computed new return type: " << *NewRetTy << '\n');

  Function *NewFunc = Function::Create(NewFuncTy, Function::PrivateLinkage,
                                       F.getName() + ".body");
  F.getParent()->getFunctionList().insert(F.getIterator(), NewFunc);
  NewFunc->copyAttributesFrom(&F);
  NewFunc->setComdat(F.getComdat());

  // We want to preserve the function and param attributes, but need to strip
  // off any return attributes, e.g. zeroext doesn't make sense with a struct.
  NewFunc->stealArgumentListFrom(F);

  AttrBuilder RetAttrs;
  RetAttrs.addAttribute(Attribute::SExt);
  RetAttrs.addAttribute(Attribute::ZExt);
  RetAttrs.addAttribute(Attribute::NoAlias);
  NewFunc->removeAttributes(AttributeList::ReturnIndex, RetAttrs);
  // TODO: How to preserve metadata?

  // Move the body of the function into the new rewritten function, and replace
  // this function with a stub.
  NewFunc->getBasicBlockList().splice(NewFunc->begin(), F.getBasicBlockList());

  for (std::pair<ReturnInst *, ReplacementVec> &Replacement : Replacements) {
    ReturnInst *RI = Replacement.first;
    IRBuilder<> B(RI);
    B.SetCurrentDebugLocation(RI->getDebugLoc());

    int RetIdx = 0;
    Value *NewRetVal = UndefValue::get(NewRetTy);

    Value *RetVal = RI->getReturnValue();
    if (RetVal)
      NewRetVal = B.CreateInsertValue(NewRetVal, RetVal, RetIdx++);

    for (std::pair<Argument *, Value *> ReturnPoint : Replacement.second) {
      Argument *Arg = ReturnPoint.first;
      Value *Val = ReturnPoint.second;
      Type *EltTy = Arg->getType()->getPointerElementType();
      if (Val->getType() != EltTy) {
        Type *EffectiveEltTy = EltTy;
        if (StructType *CT = dyn_cast<StructType>(EltTy)) {
          assert(CT->getNumElements() == 1);
          EffectiveEltTy = CT->getElementType(0);
        }

        if (DL->getTypeSizeInBits(EffectiveEltTy) !=
            DL->getTypeSizeInBits(Val->getType())) {
          assert(isVec3ToVec4Shuffle(EffectiveEltTy, Val->getType()));
          Val = B.CreateShuffleVector(Val, UndefValue::get(Val->getType()),
                                      { 0, 1, 2 });
        }

        Val = B.CreateBitCast(Val, EffectiveEltTy);

        // Re-create single element composite.
        if (EltTy != EffectiveEltTy)
          Val = B.CreateInsertValue(UndefValue::get(EltTy), Val, 0);
      }

      NewRetVal = B.CreateInsertValue(NewRetVal, Val, RetIdx++);
    }

    if (RetVal)
      RI->setOperand(0, NewRetVal);
    else {
      B.CreateRet(NewRetVal);
      RI->eraseFromParent();
    }
  }

  SmallVector<Value *, 16> StubCallArgs;
  for (Argument &Arg : F.args()) {
    if (OutArgIndexes.count(Arg.getArgNo())) {
      // It's easier to preserve the type of the argument list. We rely on
      // DeadArgumentElimination to take care of these.
      StubCallArgs.push_back(UndefValue::get(Arg.getType()));
    } else {
      StubCallArgs.push_back(&Arg);
    }
  }

  BasicBlock *StubBB = BasicBlock::Create(Ctx, "", &F);
  IRBuilder<> B(StubBB);
  CallInst *StubCall = B.CreateCall(NewFunc, StubCallArgs);

  int RetIdx = RetTy->isVoidTy() ? 0 : 1;
  for (Argument &Arg : F.args()) {
    if (!OutArgIndexes.count(Arg.getArgNo()))
      continue;

    PointerType *ArgType = cast<PointerType>(Arg.getType());

    auto *EltTy = ArgType->getElementType();
    unsigned Align = Arg.getParamAlignment();
    if (Align == 0)
      Align = DL->getABITypeAlignment(EltTy);

    Value *Val = B.CreateExtractValue(StubCall, RetIdx++);
    Type *PtrTy = Val->getType()->getPointerTo(ArgType->getAddressSpace());

    // We can peek through bitcasts, so the type may not match.
    Value *PtrVal = B.CreateBitCast(&Arg, PtrTy);

    B.CreateAlignedStore(Val, PtrVal, Align);
  }

  if (!RetTy->isVoidTy()) {
    B.CreateRet(B.CreateExtractValue(StubCall, 0));
  } else {
    B.CreateRetVoid();
  }

  // The function is now a stub we want to inline.
  F.addFnAttr(Attribute::AlwaysInline);

  ++NumOutArgumentFunctionsReplaced;
  return true;
}
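
At the source level, the net effect of this rewrite is roughly the sketch below; all names are invented for illustration. The pointer out-argument becomes an extra member of a struct return value, and the original function survives as an always-inline stub that re-materializes the store.

// Illustrative only: source-level shape of the out-argument rewrite.
struct Ret { float Sin; };

static Ret compute_body(float X) { // the ".body" clone returns the value
  return {X * 0.5f};               // instead of storing through the arg
}

void compute(float X, float *OutSin) { // original signature kept as a stub
  *OutSin = compute_body(X).Sin;       // the stub performs the store; the
}                                      // AlwaysInline attribute folds it away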