// Gather the header PHI nodes that could act as the induction variable with
// respect to which the loop might be rerolled. A PHI is appended to
// PossibleIVs when it has integer type and its SCEV is an affine
// add-recurrence on L whose step is a constant that is strictly positive and
// (as an unsigned comparison) below MaxInc.
void LoopReroll::collectPossibleIVs(Loop *L,
                                    SmallInstructionVector &PossibleIVs) {
  BasicBlock *Header = L->getHeader();
  BasicBlock::iterator Cur = Header->begin();
  BasicBlock::iterator Stop = Header->getFirstInsertionPt();
  for (; Cur != Stop; ++Cur) {
    // Only integer-typed PHI nodes are candidates.
    if (!isa<PHINode>(Cur) || !Cur->getType()->isIntegerTy())
      continue;

    // The PHI must evolve as an affine recurrence of this very loop.
    const SCEVAddRecExpr *Rec = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Cur));
    if (!Rec || Rec->getLoop() != L || !Rec->isAffine())
      continue;

    // The per-iteration step has to be a SCEV constant ...
    const SCEVConstant *Step =
      dyn_cast<SCEVConstant>(Rec->getStepRecurrence(*SE));
    if (!Step)
      continue;

    // ... that is strictly positive and smaller than MaxInc.
    if (!Step->getValue()->getValue().isStrictlyPositive() ||
        Step->getValue()->uge(MaxInc))
      continue;

    DEBUG(dbgs() << "LRR: Possible IV: " << *Cur << " = " <<
          *Rec << "\n");
    PossibleIVs.push_back(Cur);
  }
}
// For all selected reductions, remove all parts except those in the first
// iteration (and the PHI). Replace outside uses of the reduced value with uses
// of the first-iteration reduced value (in other words, reroll the selected
// reductions).
void LoopReroll::ReductionTracker::replaceSelected() {
  // Fixup reductions to refer to the last instruction associated with the
  // first iteration (not the last).
  for (DenseSet<int>::iterator RI = Reds.begin(), RIE = Reds.end();
       RI != RIE; ++RI) {
    int i = *RI;
    int j = 0;
    for (int e = PossibleReds[i].size(); j != e; ++j)
      if (PossibleRedIter[PossibleReds[i][j]] != 0) {
        --j;
        break;
      }

    // Replace users with the new end-of-chain value.
    SmallInstructionVector Users;
    for (Value::use_iterator UI =
           PossibleReds[i].getReducedValue()->use_begin(),
         UIE = PossibleReds[i].getReducedValue()->use_end(); UI != UIE; ++UI)
      Users.push_back(cast<Instruction>(*UI));

    for (SmallInstructionVector::iterator J = Users.begin(),
         JE = Users.end(); J != JE; ++J)
      (*J)->replaceUsesOfWith(PossibleReds[i].getReducedValue(),
                              PossibleReds[i][j]);
  }
}
// Detect the multiplicative-scaling pattern, where the real induction variable
// steps by one and its only other use is a multiply by a constant:
//
// %iv = phi [ (preheader, ...), (body, %iv.next) ]
// %scaled.iv = mul %iv, scale
// f(%scaled.iv)
// %scaled.iv.1 = add %scaled.iv, 1
// f(%scaled.iv.1)
// %scaled.iv.2 = add %scaled.iv, 2
// f(%scaled.iv.2)
// %scaled.iv.scale_m_1 = add %scaled.iv, scale-1
// f(%scaled.iv.scale_m_1)
// ...
// %iv.next = add %iv, 1
// %cmp = icmp(%iv, ...)
// br %cmp, header, exit
//
// On success this returns true with IV set to %scaled.iv and Scale set to the
// constant multiplier. %iv.next is appended to LoopIncs as soon as it has been
// identified, so LoopIncs may have been extended even when false is returned.
bool LoopReroll::findScaleFromMul(Instruction *RealIV, uint64_t &Scale,
                                  Instruction *&IV,
                                  SmallInstructionVector &LoopIncs) {
  // Special case: every use of the real IV other than its increment must go
  // through a common constant factor, and the increment must be by one. This
  // captures loops such as:
  //   for (int i = 0; i < 500; ++i) {
  //     foo(3*i); foo(3*i+1); foo(3*i+2);
  //   }
  // Hence exactly two users are expected: the multiply and the increment.
  if (RealIV->getNumUses() != 2)
    return false;

  const SCEVAddRecExpr *RealIVSCEV = cast<SCEVAddRecExpr>(SE->getSCEV(RealIV));

  Instruction *User1 = cast<Instruction>(*RealIV->use_begin());
  Instruction *User2 = cast<Instruction>(*llvm::next(RealIV->use_begin()));
  if (!(SE->isSCEVable(User1->getType()) &&
        SE->isSCEVable(User2->getType())))
    return false;

  // Both users have to be affine add-recurrences.
  const SCEVAddRecExpr *User1SCEV =
    dyn_cast<SCEVAddRecExpr>(SE->getSCEV(User1));
  const SCEVAddRecExpr *User2SCEV =
    dyn_cast<SCEVAddRecExpr>(SE->getSCEV(User2));
  if (!(User1SCEV && User1SCEV->isAffine() &&
        User2SCEV && User2SCEV->isAffine()))
    return false;

  // From here on User1 is taken to be the scale multiply and User2 the
  // increment; exchange them if User1 is actually the increment.
  if (User1SCEV == RealIVSCEV->getPostIncExpr(*SE)) {
    std::swap(User1, User2);
    std::swap(User1SCEV, User2SCEV);
  }

  // Whatever ended up as User2 must be the post-increment value of the IV.
  if (User2SCEV != RealIVSCEV->getPostIncExpr(*SE))
    return false;
  assert(User2SCEV->getStepRecurrence(*SE)->isOne() &&
         "Invalid non-unit step for multiplicative scaling");
  LoopIncs.push_back(User2);

  // The scaled value must advance by a constant amount per iteration.
  const SCEVConstant *MulScale =
    dyn_cast<SCEVConstant>(User1SCEV->getStepRecurrence(*SE));
  if (!MulScale)
    return false;

  // Make sure that both the start and step have the same multiplier.
  if (RealIVSCEV->getStart()->getType() != MulScale->getType())
    return false;
  if (SE->getMulExpr(RealIVSCEV->getStart(), MulScale) !=
      User1SCEV->getStart())
    return false;

  // Accept only multipliers in the range [2, MaxInc).
  ConstantInt *MulScaleCI = MulScale->getValue();
  if (!MulScaleCI->uge(2) || MulScaleCI->uge(MaxInc))
    return false;

  Scale = MulScaleCI->getZExtValue();
  IV = User1;

  DEBUG(dbgs() << "LRR: Found possible scaling " << *User1 << "\n");
  return true;
}
// Collect the root increments relative to the given induction variable
// (normally the PHI, but sometimes a multiply). A root increment is a user of
// IV, normally an add, whose SCEV differs from IV's by a positive constant
// less than Scale; in a rerollable loop each of them is the root of an
// instruction graph isomorphic to the others. A root at offset K is stored in
// Roots[K-1] and also inserted into AllRoots. The user at offset exactly Scale
// (the final induction increment) is appended to LoopIncs, but only when
// Inc > 1.
//
// Returns true when Roots[0] is non-empty and Roots[1] .. Roots[Scale-2] all
// hold the same number of entries as Roots[0]. Roots is indexed here without
// being resized, so the caller presumably sizes it beforehand.
bool LoopReroll::collectAllRoots(Loop *L, uint64_t Inc, uint64_t Scale,
                                 Instruction *IV,
                                 SmallVector<SmallInstructionVector, 32> &Roots,
                                 SmallInstructionSet &AllRoots,
                                 SmallInstructionVector &LoopIncs) {
  for (Value::use_iterator UseIt = IV->use_begin(), UseEnd = IV->use_end();
       UseIt != UseEnd; ++UseIt) {
    Instruction *Candidate = cast<Instruction>(*UseIt);

    // Ignore users that cannot be analyzed, differ in type from the IV, sit
    // outside the loop, or are themselves used outside the loop.
    if (!SE->isSCEVable(Candidate->getType()) ||
        Candidate->getType() != IV->getType() ||
        !L->contains(Candidate) ||
        hasUsesOutsideLoop(Candidate, L))
      continue;

    // Only users at a constant distance from the IV are of interest.
    const SCEVConstant *Diff = dyn_cast<SCEVConstant>(
      SE->getMinusSCEV(SE->getSCEV(Candidate), SE->getSCEV(IV)));
    if (!Diff)
      continue;

    uint64_t Offset = Diff->getValue()->getValue().getZExtValue();
    if (Offset == Scale) {
      // The full-stride increment; only recorded for non-unit Inc.
      if (Inc > 1)
        LoopIncs.push_back(Candidate);
    } else if (Offset > 0 && Offset < Scale) {
      Roots[Offset-1].push_back(Candidate);
      AllRoots.insert(Candidate);
    }
  }

  // There must be at least one root for the first offset ...
  if (Roots[0].empty())
    return false;

  // ... and every other offset must have just as many roots.
  for (unsigned Slot = 1; Slot < Scale-1; ++Slot)
    if (Roots[Slot].size() != Roots[0].size())
      return false;

  return true;
}