/// Recognize an index_raw_pointer whose offset operand is a stride-of-T times
/// distance product.
///
/// The offset (operand 1) must have the shape
///   TruncOrBitCast(tuple_extract(SMulOver(A, B), 0))
/// where one of A / B is ZExtOrBitCast(StrideofNonZero(metatype)) and the
/// other one is the distance. Both operand orders are accepted.
///
/// \param I              the value to inspect.
/// \param RequiredType   the metatype instruction the stride must be taken
///                       from (compared by pointer identity).
/// \param TruncOrBitCast [out] the builtin that produces the offset operand.
/// \param Ptr            [out] operand 0 of the index_raw_pointer.
/// \param Distance       [out] the non-stride multiplicand.
/// \returns the matched index_raw_pointer, or nullptr if \p I does not have
/// the expected shape. The out-parameters are written only on success.
static IndexRawPointerInst *
matchSizeOfMultiplication(SILValue I, MetatypeInst *RequiredType,
                          BuiltinInst *&TruncOrBitCast, SILValue &Ptr,
                          SILValue &Distance) {
  auto *IndexRaw = dyn_cast<IndexRawPointerInst>(I);
  if (IndexRaw == nullptr)
    return nullptr;

  SILValue Offset = IndexRaw->getOperand(1);
  SILValue Multiplicand;
  MetatypeInst *StrideOf = nullptr;

  // First try "distance * stride" ...
  bool Matched = match(
      Offset,
      m_ApplyInst(
          BuiltinValueKind::TruncOrBitCast,
          m_TupleExtractInst(
              m_ApplyInst(
                  BuiltinValueKind::SMulOver, m_SILValue(Multiplicand),
                  m_ApplyInst(BuiltinValueKind::ZExtOrBitCast,
                              m_ApplyInst(BuiltinValueKind::StrideofNonZero,
                                          m_MetatypeInst(StrideOf)))),
              0)));

  // ... and only if that fails, "stride * distance".
  if (!Matched) {
    Matched = match(
        Offset,
        m_ApplyInst(
            BuiltinValueKind::TruncOrBitCast,
            m_TupleExtractInst(
                m_ApplyInst(
                    BuiltinValueKind::SMulOver,
                    m_ApplyInst(BuiltinValueKind::ZExtOrBitCast,
                                m_ApplyInst(BuiltinValueKind::StrideofNonZero,
                                            m_MetatypeInst(StrideOf))),
                    m_SILValue(Multiplicand)),
                0)));
  }

  // The stride has to be taken from exactly the metatype the caller expects.
  if (!Matched || StrideOf != RequiredType)
    return nullptr;

  TruncOrBitCast = cast<BuiltinInst>(Offset);
  Distance = Multiplicand;
  Ptr = IndexRaw->getOperand(0);
  return IndexRaw;
}
/// Combine a pointer_to_address instruction based on what produces its
/// operand.
///
/// Rewrites attempted, in order:
///   1. (pointer_to_address (address_to_pointer %x)) -> unchecked_addr_cast %x
///   2. pointer_to_address of an index_raw_pointer whose offset is
///      TruncOrBitCast(tuple_extract(SMulOver(distance, stride), 0)), with
///      stride = ZExtOrBitCast(StrideofNonZero(T)) in either operand order
///      -> index_addr of a pointer_to_address of the base pointer.
///   3. pointer_to_address of an index_raw_pointer whose offset is
///      tuple_extract(SMulOver(distance, Strideof/StrideofNonZero(T)), 0)
///      -> index_addr of a pointer_to_address of the base pointer.
///
/// Rewrites 2 and 3 are only performed when the address type of T equals the
/// result type of \p PTAI.
///
/// \returns the instruction built for the rewritten form, or nullptr when no
/// rewrite applies.
SILInstruction *
SILCombiner::
visitPointerToAddressInst(PointerToAddressInst *PTAI) {
  Builder.setCurrentDebugScope(PTAI->getDebugScope());

  // If we reach this point, we know that the types must be different since
  // otherwise simplifyInstruction would have handled the identity case. This is
  // always legal to do since address-to-pointer pointer-to-address implies
  // layout compatibility.
  //
  // (pointer-to-address (address-to-pointer %x)) -> (unchecked_addr_cast %x)
  if (auto *ATPI = dyn_cast<AddressToPointerInst>(PTAI->getOperand())) {
    return Builder.createUncheckedAddrCast(PTAI->getLoc(), ATPI->getOperand(),
                                           PTAI->getType());
  }

  // Turn this also into a index_addr. We generate this pattern after switching
  // the Word type to an explicit Int32 or Int64 in the stdlib.
  //
  // %101 = builtin "strideof_nonzero"<Int>(%84 : $@thick Int.Type) :
  //         $Builtin.Word
  // %102 = builtin "zextOrBitCast_Word_Int64"(%101 : $Builtin.Word) :
  //         $Builtin.Int64
  // %111 = builtin "smul_with_overflow_Int64"(%108 : $Builtin.Int64,
  //                               %102 : $Builtin.Int64, %20 : $Builtin.Int1) :
  //         $(Builtin.Int64, Builtin.Int1)
  // %112 = tuple_extract %111 : $(Builtin.Int64, Builtin.Int1), 0
  // %113 = builtin "truncOrBitCast_Int64_Word"(%112 : $Builtin.Int64) :
  //         $Builtin.Word
  // %114 = index_raw_pointer %100 : $Builtin.RawPointer, %113 : $Builtin.Word
  // %115 = pointer_to_address %114 : $Builtin.RawPointer to $*Int
  SILValue Distance;
  MetatypeInst *Metatype = nullptr;
  IndexRawPointerInst *IndexRawPtr = nullptr;
  BuiltinInst *StrideMul = nullptr;
  if (match(
          PTAI->getOperand(),
          m_IndexRawPointerInst(IndexRawPtr))) {
    SILValue Ptr = IndexRawPtr->getOperand(0);
    SILValue TruncOrBitCast = IndexRawPtr->getOperand(1);
    if (match(TruncOrBitCast,
              m_ApplyInst(BuiltinValueKind::TruncOrBitCast,
                          m_TupleExtractInst(m_BuiltinInst(StrideMul), 0)))) {
      // Accept the stride on either side of the multiplication.
      if (match(StrideMul,
                m_ApplyInst(
                    BuiltinValueKind::SMulOver, m_SILValue(Distance),
                    m_ApplyInst(BuiltinValueKind::ZExtOrBitCast,
                                m_ApplyInst(BuiltinValueKind::StrideofNonZero,
                                            m_MetatypeInst(Metatype))))) ||
          match(StrideMul,
                m_ApplyInst(
                    BuiltinValueKind::SMulOver,
                    m_ApplyInst(BuiltinValueKind::ZExtOrBitCast,
                                m_ApplyInst(BuiltinValueKind::StrideofNonZero,
                                            m_MetatypeInst(Metatype))),
                    m_SILValue(Distance)))) {
        SILType InstanceType =
            Metatype->getType().getMetatypeInstanceType(PTAI->getModule());
        auto *Trunc = cast<BuiltinInst>(TruncOrBitCast);

        // Make sure that the type of the metatype matches the type that we are
        // casting to so we stride by the correct amount.
        if (InstanceType.getAddressType() != PTAI->getType()) {
          return nullptr;
        }

        auto *NewPTAI = Builder.createPointerToAddress(PTAI->getLoc(), Ptr,
                                                        PTAI->getType());
        // Re-apply the original truncOrBitCast builtin to the bare distance
        // so that the index keeps the type of the old offset operand.
        auto DistanceAsWord = Builder.createBuiltin(
            PTAI->getLoc(), Trunc->getName(), Trunc->getType(), {}, Distance);

        return Builder.createIndexAddr(PTAI->getLoc(), NewPTAI, DistanceAsWord);
      }
    }
  }
  // Turn:
  //
  //   %stride = Builtin.strideof(T) * %distance
  //   %ptr' = index_raw_pointer %ptr, %stride
  //   %result = pointer_to_address %ptr', $T
  //
  // To:
  //
  //   %addr = pointer_to_address %ptr, $T
  //   %result = index_addr %addr, %distance
  //
  BuiltinInst *Bytes = nullptr;
  if (match(PTAI->getOperand(),
            m_IndexRawPointerInst(m_ValueBase(),
                                  m_TupleExtractInst(m_BuiltinInst(Bytes),
                                                     0)))) {
    // Only the "distance * stride" operand order is matched here.
    if (match(Bytes, m_ApplyInst(BuiltinValueKind::SMulOver, m_ValueBase(),
                                 m_ApplyInst(BuiltinValueKind::Strideof,
                                             m_MetatypeInst(Metatype)),
                                 m_ValueBase())) ||
        match(Bytes, m_ApplyInst(BuiltinValueKind::SMulOver, m_ValueBase(),
                                 m_ApplyInst(BuiltinValueKind::StrideofNonZero,
                                             m_MetatypeInst(Metatype)),
                                 m_ValueBase()))) {
      SILType InstanceType =
        Metatype->getType().getMetatypeInstanceType(PTAI->getModule());

      // Make sure that the type of the metatype matches the type that we are
      // casting to so we stride by the correct amount.
      if (InstanceType.getAddressType() != PTAI->getType())
        return nullptr;

      auto IRPI = cast<IndexRawPointerInst>(PTAI->getOperand().getDef());
      SILValue BasePtr = IRPI->getOperand(0);
      // The first multiplication argument is the element distance.
      SILValue ElementDistance = Bytes->getArguments()[0];
      auto *NewPTAI = Builder.createPointerToAddress(PTAI->getLoc(), BasePtr,
                                                     PTAI->getType());
      return Builder.createIndexAddr(PTAI->getLoc(), NewPTAI, ElementDistance);
    }
  }

  return nullptr;
}
/// Combine a builtin instruction.
///
/// Optimizations attempted, in order:
///   - CanBeObjCClass and the array builtins (TakeArrayFrontToBack,
///     TakeArrayBackToFront, CopyArray) are delegated to dedicated helpers.
///   - Builtins whose first two operands are the same value are handed to
///     optimizeBuiltinWithSameOperands.
///   - Equality / unsigned comparisons of two ZExtOrBitCast values with the
///     same source type are narrowed to compare the un-extended values.
///   - And / Or / Xor are handed to optimizeBitOp.
///   - DestroyArray of a trivial element type is erased.
///   - ICMP_EQ / ICMP_NE are handed to optimizeBuiltinCompareEq.
///   - sub(ptrtoint(index_raw_pointer(v, x)), ptrtoint(v)) is replaced by x.
///   - SMulOver with a Strideof / StrideofNonZero as first argument gets its
///     first two operands swapped so the stride is the second argument.
///
/// \returns the result of the helper or erase call that handled \p I, \p I
/// itself when it was modified in place, or nullptr when nothing applied.
SILInstruction *SILCombiner::visitBuiltinInst(BuiltinInst *I) {
  if (I->getBuiltinInfo().ID == BuiltinValueKind::CanBeObjCClass)
    return optimizeBuiltinCanBeObjCClass(I);
  if (I->getBuiltinInfo().ID == BuiltinValueKind::TakeArrayFrontToBack ||
      I->getBuiltinInfo().ID == BuiltinValueKind::TakeArrayBackToFront ||
      I->getBuiltinInfo().ID == BuiltinValueKind::CopyArray)
    return optimizeBuiltinArrayOperation(I, Builder);

  if (I->getNumOperands() >= 2 && I->getOperand(0) == I->getOperand(1)) {
    // It's a builtin which has the same value in its first and second operand.
    auto *Replacement = optimizeBuiltinWithSameOperands(Builder, I, this);
    if (Replacement)
      return Replacement;
  }

  // Optimize this case for unsigned and equality comparisons:
  //   cmp_*_T . (zext U->T x, zext U->T y)
  //      => cmp_*_T (x, y)
  switch (I->getBuiltinInfo().ID) {
  case BuiltinValueKind::ICMP_EQ:
  case BuiltinValueKind::ICMP_NE:
  case BuiltinValueKind::ICMP_ULE:
  case BuiltinValueKind::ICMP_ULT:
  case BuiltinValueKind::ICMP_UGE:
  case BuiltinValueKind::ICMP_UGT: {
    SILValue LCast, RCast;
    if (match(I->getArguments()[0],
              m_ApplyInst(BuiltinValueKind::ZExtOrBitCast,
                          m_SILValue(LCast))) &&
        match(I->getArguments()[1],
              m_ApplyInst(BuiltinValueKind::ZExtOrBitCast,
                          m_SILValue(RCast))) &&
        LCast->getType() == RCast->getType()) {

      auto *NewCmp = Builder.createBuiltinBinaryFunction(
          I->getLoc(), getBuiltinName(I->getBuiltinInfo().ID),
          LCast->getType(), I->getType(), {LCast, RCast});

      // Replace the uses only through the combiner's replaceInstUsesWith, the
      // same way the pointer-difference fold below does. A raw
      // replaceAllUsesWith beforehand would leave that helper with no uses of
      // I to see.
      replaceInstUsesWith(*I, NewCmp);
      return eraseInstFromFunction(*I);
    }
    break;
  }
  case BuiltinValueKind::And:
    return optimizeBitOp(I,
      [](APInt &left, const APInt &right) { left &= right; }    /* combine */,
      [](const APInt &i) -> bool { return i.isAllOnesValue(); } /* isNeutral */,
      [](const APInt &i) -> bool { return i.isMinValue(); }     /* isZero */,
      Builder, this);
  case BuiltinValueKind::Or:
    return optimizeBitOp(I,
      [](APInt &left, const APInt &right) { left |= right; }    /* combine */,
      [](const APInt &i) -> bool { return i.isMinValue(); }     /* isNeutral */,
      [](const APInt &i) -> bool { return i.isAllOnesValue(); } /* isZero */,
      Builder, this);
  case BuiltinValueKind::Xor:
    // Xor has no absorbing constant, so the isZero predicate is never true.
    return optimizeBitOp(I,
      [](APInt &left, const APInt &right) { left ^= right; } /* combine */,
      [](const APInt &i) -> bool { return i.isMinValue(); }  /* isNeutral */,
      [](const APInt &i) -> bool { return false; }           /* isZero */,
      Builder, this);
  case BuiltinValueKind::DestroyArray: {
    ArrayRef<Substitution> Substs = I->getSubstitutions();
    // Check if the element type is a trivial type.
    if (Substs.size() == 1) {
      Substitution Subst = Substs[0];
      Type ElemType = Subst.getReplacement();
      if (ElemType->isCanonical() && ElemType->isLegalSILType()) {
        SILType SILElemTy = SILType::getPrimitiveObjectType(CanType(ElemType));
        // Destroying an array of trivial types is a no-op.
        if (SILElemTy.isTrivial(I->getModule()))
          return eraseInstFromFunction(*I);
      }
    }
    break;
  }
  default:
    break;
  }

  // Equality comparisons that were not narrowed above get the dedicated
  // compare-eq optimization.
  if (I->getBuiltinInfo().ID == BuiltinValueKind::ICMP_EQ)
    return optimizeBuiltinCompareEq(I, /*Negate Eq result*/ false);

  if (I->getBuiltinInfo().ID == BuiltinValueKind::ICMP_NE)
    return optimizeBuiltinCompareEq(I, /*Negate Eq result*/ true);

  // Optimize sub(ptrtoint(index_raw_pointer(v, x)), ptrtoint(v)) -> x.
  BuiltinInst *Bytes2;
  IndexRawPointerInst *Indexraw;
  if (I->getNumOperands() == 2 &&
      match(I, m_BuiltinInst(BuiltinValueKind::Sub,
                             m_BuiltinInst(BuiltinValueKind::PtrToInt,
                                           m_IndexRawPointerInst(Indexraw)),
                             m_BuiltinInst(Bytes2)))) {
    if (match(Bytes2,
              m_BuiltinInst(BuiltinValueKind::PtrToInt, m_ValueBase()))) {
      // Both sides must refer to the same base pointer, and the index must
      // already have the type of the subtraction result.
      if (Indexraw->getOperand(0) == Bytes2->getOperand(0) &&
          Indexraw->getOperand(1)->getType() == I->getType()) {
        replaceInstUsesWith(*I, Indexraw->getOperand(1));
        return eraseInstFromFunction(*I);
      }
    }
  }

  // Canonicalize multiplication by a stride to be such that the stride is
  // always the second argument.
  if (I->getNumOperands() != 3)
    return nullptr;

  if (match(I, m_ApplyInst(BuiltinValueKind::SMulOver,
                            m_ApplyInst(BuiltinValueKind::Strideof),
                            m_ValueBase(), m_IntegerLiteralInst())) ||
      match(I, m_ApplyInst(BuiltinValueKind::SMulOver,
                            m_ApplyInst(BuiltinValueKind::StrideofNonZero),
                            m_ValueBase(), m_IntegerLiteralInst()))) {
    I->swapOperands(0, 1);
    return I;
  }

  return nullptr;
}