/// Do a syntactic pattern match to determine whether the call is a call
/// to swap(&base[index1], &base[index2]), which can
/// be replaced with a call to MutableCollection.swapAt(_:_:) on base.
///
/// Returns true if the call can be replaced. Returns the call expression,
/// the base expression, and the two indices as out expressions.
///
/// This method takes an array of all the ApplyInsts for calls to swap()
/// in the function to avoid needing to construct a parent map over the AST
/// to find the CallExpr for the inout accesses.
static bool
canReplaceWithCallToCollectionSwapAt(const BeginAccessInst *Access1,
                                     const BeginAccessInst *Access2,
                                     ArrayRef<ApplyInst *> CallsToSwap,
                                     ASTContext &Ctx,
                                     CallExpr *&FoundCall,
                                     Expr *&Base,
                                     Expr *&Index1,
                                     Expr *&Index2) {
  if (CallsToSwap.empty())
    return false;

  // Inout arguments must be modifications.
  if (Access1->getAccessKind() != SILAccessKind::Modify ||
      Access2->getAccessKind() != SILAccessKind::Modify) {
    return false;
  }

  SILLocation Loc1 = Access1->getLoc();
  SILLocation Loc2 = Access2->getLoc();
  if (Loc1.isNull() || Loc2.isNull())
    return false;

  auto *InOut1 = Loc1.getAsASTNode<InOutExpr>();
  auto *InOut2 = Loc2.getAsASTNode<InOutExpr>();
  if (!InOut1 || !InOut2)
    return false;

  FoundCall = nullptr;
  // Look through all the calls to swap() recorded in the function to find
  // which one we're diagnosing.
  for (ApplyInst *AI : CallsToSwap) {
    SILLocation CallLoc = AI->getLoc();
    if (CallLoc.isNull())
      continue;

    auto *CE = CallLoc.getAsASTNode<CallExpr>();
    if (!CE)
      continue;

    assert(isCallToStandardLibrarySwap(CE, Ctx));
    // swap() takes two arguments.
    auto *ArgTuple = cast<TupleExpr>(CE->getArg());
    const Expr *Arg1 = ArgTuple->getElement(0);
    const Expr *Arg2 = ArgTuple->getElement(1);
    if ((Arg1 == InOut1 && Arg2 == InOut2)) {
        FoundCall = CE;
      break;
    }
  }
  if (!FoundCall)
    return false;

  // We found a call to swap(&e1, &e2). Now check to see whether it
  // matches the form swap(&someCollection[index1], &someCollection[index2]).
  auto *SE1 = dyn_cast<SubscriptExpr>(InOut1->getSubExpr());
  if (!SE1)
    return false;
  auto *SE2 = dyn_cast<SubscriptExpr>(InOut2->getSubExpr());
  if (!SE2)
    return false;

  // Do the two subscripts refer to the same subscript declaration?
  auto *Decl1 = cast<SubscriptDecl>(SE1->getDecl().getDecl());
  auto *Decl2 = cast<SubscriptDecl>(SE2->getDecl().getDecl());
  if (Decl1 != Decl2)
    return false;

  ProtocolDecl *MutableCollectionDecl = Ctx.getMutableCollectionDecl();

  // Is the subcript either (1) on MutableCollection itself or (2) a
  // a witness for a subscript on MutableCollection?
  bool IsSubscriptOnMutableCollection = false;
  ProtocolDecl *ProtocolForDecl =
      Decl1->getDeclContext()->getAsProtocolOrProtocolExtensionContext();
  if (ProtocolForDecl) {
    IsSubscriptOnMutableCollection = (ProtocolForDecl == MutableCollectionDecl);
  } else {
    for (ValueDecl *Req : Decl1->getSatisfiedProtocolRequirements()) {
      DeclContext *ReqDC = Req->getDeclContext();
      ProtocolDecl *ReqProto = ReqDC->getAsProtocolOrProtocolExtensionContext();
      assert(ReqProto && "Protocol requirement not in a protocol?");

      if (ReqProto == MutableCollectionDecl) {
        IsSubscriptOnMutableCollection = true;
        break;
      }
    }
  }

  if (!IsSubscriptOnMutableCollection)
    return false;

  // We're swapping two subscripts on mutable collections -- but are they
  // the same collection? Approximate this by checking for textual
  // equality on the base expressions. This is just an approximation,
  // but is fine for a best-effort Fix-It.
  SourceManager &SM = Ctx.SourceMgr;
  StringRef Base1Text = extractExprText(SE1->getBase(), SM);
  StringRef Base2Text = extractExprText(SE2->getBase(), SM);

  if (Base1Text != Base2Text)
    return false;

  auto *Index1Paren = dyn_cast<ParenExpr>(SE1->getIndex());
  if (!Index1Paren)
    return false;

  auto *Index2Paren = dyn_cast<ParenExpr>(SE2->getIndex());
  if (!Index2Paren)
    return false;

  Base = SE1->getBase();
  Index1 = Index1Paren->getSubExpr();
  Index2 = Index2Paren->getSubExpr();
  return true;
}