/// This method check when nullable pointer or null value is returned from a
/// function that has nonnull return type.
void NullabilityChecker::checkPreStmt(const ReturnStmt *S,
                                      CheckerContext &C) const {
  auto RetExpr = S->getRetValue();
  if (!RetExpr)
    return;

  if (!RetExpr->getType()->isAnyPointerType())
    return;

  ProgramStateRef State = C.getState();
  if (State->get<InvariantViolated>())
    return;

  auto RetSVal = C.getSVal(S).getAs<DefinedOrUnknownSVal>();
  if (!RetSVal)
    return;

  bool InSuppressedMethodFamily = false;

  QualType RequiredRetType;
  AnalysisDeclContext *DeclCtxt =
      C.getLocationContext()->getAnalysisDeclContext();
  const Decl *D = DeclCtxt->getDecl();
  if (auto *MD = dyn_cast<ObjCMethodDecl>(D)) {
    // HACK: This is a big hammer to avoid warning when there are defensive
    // nil checks in -init and -copy methods. We should add more sophisticated
    // logic here to suppress on common defensive idioms but still
    // warn when there is a likely problem.
    ObjCMethodFamily Family = MD->getMethodFamily();
    if (OMF_init == Family || OMF_copy == Family || OMF_mutableCopy == Family)
      InSuppressedMethodFamily = true;

    RequiredRetType = MD->getReturnType();
  } else if (auto *FD = dyn_cast<FunctionDecl>(D)) {
    RequiredRetType = FD->getReturnType();
  } else {
    return;
  }

  NullConstraint Nullness = getNullConstraint(*RetSVal, State);

  Nullability RequiredNullability = getNullabilityAnnotation(RequiredRetType);

  // If the returned value is null but the type of the expression
  // generating it is nonnull then we will suppress the diagnostic.
  // This enables explicit suppression when returning a nil literal in a
  // function with a _Nonnull return type:
  //    return (NSString * _Nonnull)0;
  Nullability RetExprTypeLevelNullability =
        getNullabilityAnnotation(lookThroughImplicitCasts(RetExpr)->getType());

  bool NullReturnedFromNonNull = (RequiredNullability == Nullability::Nonnull &&
                                  Nullness == NullConstraint::IsNull);
  if (Filter.CheckNullReturnedFromNonnull &&
      NullReturnedFromNonNull &&
      RetExprTypeLevelNullability != Nullability::Nonnull &&
      !InSuppressedMethodFamily &&
      C.getLocationContext()->inTopFrame()) {
    static CheckerProgramPointTag Tag(this, "NullReturnedFromNonnull");
    ExplodedNode *N = C.generateErrorNode(State, &Tag);
    if (!N)
      return;

    SmallString<256> SBuf;
    llvm::raw_svector_ostream OS(SBuf);
    OS << (RetExpr->getType()->isObjCObjectPointerType() ? "nil" : "Null");
    OS << " returned from a " << C.getDeclDescription(D) <<
          " that is expected to return a non-null value";
    reportBugIfInvariantHolds(OS.str(),
                              ErrorKind::NilReturnedToNonnull, N, nullptr, C,
                              RetExpr);
    return;
  }

  // If null was returned from a non-null function, mark the nullability
  // invariant as violated even if the diagnostic was suppressed.
  if (NullReturnedFromNonNull) {
    State = State->set<InvariantViolated>(true);
    C.addTransition(State);
    return;
  }

  const MemRegion *Region = getTrackRegion(*RetSVal);
  if (!Region)
    return;

  const NullabilityState *TrackedNullability =
      State->get<NullabilityMap>(Region);
  if (TrackedNullability) {
    Nullability TrackedNullabValue = TrackedNullability->getValue();
    if (Filter.CheckNullableReturnedFromNonnull &&
        Nullness != NullConstraint::IsNotNull &&
        TrackedNullabValue == Nullability::Nullable &&
        RequiredNullability == Nullability::Nonnull) {
      static CheckerProgramPointTag Tag(this, "NullableReturnedFromNonnull");
      ExplodedNode *N = C.addTransition(State, C.getPredecessor(), &Tag);

      SmallString<256> SBuf;
      llvm::raw_svector_ostream OS(SBuf);
      OS << "Nullable pointer is returned from a " << C.getDeclDescription(D) <<
            " that is expected to return a non-null value";

      reportBugIfInvariantHolds(OS.str(),
                                ErrorKind::NullableReturnedToNonnull, N,
                                Region, C);
    }
    return;
  }
  if (RequiredNullability == Nullability::Nullable) {
    State = State->set<NullabilityMap>(Region,
                                       NullabilityState(RequiredNullability,
                                                        S));
    C.addTransition(State);
  }
}