/*
To implement task reduce intents, we follow the steps in
propagateExtraLeaderArgs() and setupOneReduceIntent()

I.e. add the following to the AST:

* before 'call'
    def globalOp = new reduceType(origSym.type);

* pass globalOp to call();
  corresponding formal in 'fn': parentOp

* inside 'fn'
    def currOp = parentOp.clone()
    def symReplace = currOp.identity;
    ...
    currOp.accumulate(symReplace);
    parentOp.combine(currOp);
    delete currOp;

* after 'call' and its _waitEndCount()
    origSym = globalOp.generate();
    delete globalOp;

Put in a different way, a coforall like this:

    var x: int;
    coforall ITER with (OP reduce x) {
      BODY(x);      // will typically include:  x OP= something;
    }

with its corresponding task-function representation:

    var x: int;
    proc coforall_fn() {
      BODY(x);
    }
    call coforall_fn();

is transformed into

    var x: int;
    var globalOp = new OP_SCAN_REDUCE_CLASS(x.type);

    proc coforall_fn(parentOp) {
      var currOp = parentOp.clone()
      var symReplace = currOp.identity;

      BODY(symReplace);

      currOp.accumulate(symReplace);
      parentOp.combine(currOp);
      delete currOp;
    }

    call coforall_fn(globalOp);
    // wait for endCount - not shown
    x = globalOp.generate();
    delete globalOp;

Todo: to support cobegin constructs, need to share 'globalOp'
across all fn+call pairs for the same construct.
*/
// Adds the AST needed to implement one task reduce intent on 'origSym'
// for the task function 'fn' invoked by 'call'.
//
// Outputs (all by reference):
//   newFormal  - the formal to add to 'fn' ("reduceParent")
//   newActual  - the matching actual to pass at 'call' ("reduceGlobal")
//   symReplace - the task-private variable that stands in for 'origSym'
//                inside 'fn'
//   redRef1/2  - anchors inside 'fn', as set up by setupRedRefs(); the
//                task prologue is inserted before redRef1 and the task
//                epilogue before redRef2
//
// When 'isCoforall' is true, the caller-side setup is placed before
// call->parentExpr rather than before 'call' itself.
static void addReduceIntentSupport(FnSymbol* fn, CallExpr* call,
                                   TypeSymbol* reduceType, Symbol* origSym,
                                   ArgSymbol*& newFormal, Symbol*& newActual,
                                   Symbol*& symReplace, bool isCoforall,
                                   Expr*& redRef1, Expr*& redRef2)
{
  setupRedRefs(fn, true, redRef1, redRef2);

  //
  // Caller side, ahead of the call: compute the element type and
  // create the shared reduction op.
  //
  VarSymbol* sharedOp = new VarSymbol("reduceGlobal");
  sharedOp->addFlag(FLAG_NO_CAPTURE_FOR_TASKING);
  newActual = sharedOp;

  VarSymbol* elemTypeTmp = newTemp("redEltType");
  elemTypeTmp->addFlag(FLAG_MAYBE_TYPE);

  Expr* preAnchor = call;
  if (isCoforall) {
    preAnchor = preAnchor->parentExpr;
  }

  preAnchor->insertBefore(new DefExpr(elemTypeTmp));
  preAnchor->insertBefore("'move'(%S, 'typeof'(%S))", elemTypeTmp, origSym);
  preAnchor->insertBefore(new DefExpr(sharedOp));

  AggregateType* opClass = toAggregateType(reduceType->type);
  INT_ASSERT(opClass);

  // sharedOp = <default initializer of the reduce class>(eltType = elemTypeTmp)
  SymExpr*   eltTypeUse = new SymExpr(elemTypeTmp);
  NamedExpr* eltTypeArg = new NamedExpr("eltType", eltTypeUse);
  CallExpr*  makeOp     = new CallExpr(opClass->defaultInitializer->name,
                                       eltTypeArg);
  preAnchor->insertBefore(new CallExpr(PRIM_MOVE, sharedOp, makeOp));

  //
  // Caller side, after the call: store the result and dispose of the op.
  //
  // Each insertAfter() lands immediately after the anchor, so the
  // statements are issued last-to-first. insertBefore() on the anchor's
  // 'next' is not an option because 'next' can be NULL.
  //
  Expr* postAnchor = findTailInsertionPoint(call, isCoforall);
  postAnchor->insertAfter("'delete'(%S)", sharedOp);
  postAnchor->insertAfter("'='(%S, generate(%S,%S))",
                          origSym, gMethodToken, sharedOp);

  //
  // Task-function side.
  //
  ArgSymbol* parentFormal = new ArgSymbol(INTENT_BLANK, "reduceParent",
                                          dtUnknown);
  newFormal = parentFormal;

  VarSymbol* taskOp     = new VarSymbol("reduceCurr");
  VarSymbol* shadowVar  = new VarSymbol(origSym->name, origSym->type);
  symReplace = shadowVar;

  // Prologue: clone the parent op, then start the shadow variable at
  // the op's identity value.
  redRef1->insertBefore(new DefExpr(taskOp));
  redRef1->insertBefore("'move'(%S, clone(%S,%S))",
                        taskOp, gMethodToken, parentFormal);
  redRef1->insertBefore(new DefExpr(shadowVar));
  redRef1->insertBefore("'move'(%S, identity(%S,%S))",
                        shadowVar, gMethodToken, taskOp);

  // Epilogue: accumulate the shadow variable into the task's op, combine
  // that into the parent op, then clean up the task's op.
  CallExpr* accumulateCall = new CallExpr("accumulate",
                                          gMethodToken, taskOp, shadowVar);
  redRef2->insertBefore(accumulateCall);

  CallExpr* combineCall = new CallExpr("chpl__reduceCombine",
                                       parentFormal, taskOp);
  redRef2->insertBefore(combineCall);

  CallExpr* cleanupCall = new CallExpr("chpl__cleanupLocalOp",
                                       parentFormal, taskOp);
  redRef2->insertBefore(cleanupCall);
}
//
// If call has the potential to cause communication, assert that the wide
// reference that might cause communication is local and remove its wide-ness
//
// Only primitive calls are examined; a non-primitive call, or a primitive
// not listed in the switch below, is left untouched. The localization
// itself is delegated to insertLocalTemp() (defined elsewhere), which is
// handed the operand whose type symbol carries FLAG_WIDE_CLASS and/or
// FLAG_WIDE_REF.
//
// The organization of this function follows the order of CallExpr::codegen()
// leaving out primitives that don't communicate.
//
static void localizeCall(CallExpr* call) {
  if (call->primitive) {
    switch (call->primitive->tag) {
    case PRIM_ARRAY_SET: /* Fallthru */
    case PRIM_ARRAY_SET_FIRST:
      // Localize the first operand when its type is flagged as a wide class.
      if (call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS)) {
        insertLocalTemp(call->get(1));
      }
      break;
    case PRIM_MOVE:
    case PRIM_ASSIGN: // Not sure about this one.
      // When the source (operand 2) is itself a call, dispatch on which
      // primitive it is. Every recognized primitive ends with 'break',
      // which leaves the switch and so skips the generic destination
      // handling that follows this 'if'.
      if (CallExpr* rhs = toCallExpr(call->get(2))) {
        if (rhs->isPrimitive(PRIM_DEREF)) {
          if (rhs->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_REF) ||
              rhs->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS)) {
            insertLocalTemp(rhs->get(1));
            // rhs->get(1) is deliberately re-queried after insertLocalTemp().
            // NOTE(review): presumably insertLocalTemp() swaps the operand
            // for a narrow temp - confirm against its definition.
            if (!rhs->get(1)->typeInfo()->symbol->hasFlag(FLAG_REF)) {
              // The operand is not a reference, so it is asserted to be a
              // string and the deref is replaced by the operand itself.
              INT_ASSERT(rhs->get(1)->typeInfo() == dtString);
              // special handling for wide strings
              rhs->replace(rhs->get(1)->remove());
            }
          }
          break;
        } 
        else if (rhs->isPrimitive(PRIM_GET_MEMBER) ||
                   rhs->isPrimitive(PRIM_GET_SVEC_MEMBER) ||
                   rhs->isPrimitive(PRIM_GET_MEMBER_VALUE) ||
                   rhs->isPrimitive(PRIM_GET_SVEC_MEMBER_VALUE)) {
          if (rhs->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_REF) ||
              rhs->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS)) {
            // Operand 2 names the member being read; the base is localized
            // unless that member is flagged FLAG_SUPER_CLASS.
            SymExpr* sym = toSymExpr(rhs->get(2));
            INT_ASSERT(sym);
            if (!sym->var->hasFlag(FLAG_SUPER_CLASS)) {
              insertLocalTemp(rhs->get(1));
            }
          }
          break;
        } else if (rhs->isPrimitive(PRIM_ARRAY_GET) ||
                   rhs->isPrimitive(PRIM_ARRAY_GET_VALUE)) {
          if (rhs->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS)) {
            SymExpr* lhs = toSymExpr(call->get(1));
            Expr* stmt = call->getStmtExpr();
            INT_ASSERT(lhs && stmt);

            SET_LINENO(stmt);
            insertLocalTemp(rhs->get(1));
            // Redirect the result into a fresh "local_<name>" temp. For
            // PRIM_ARRAY_GET the temp takes the type of the "addr" field of
            // the destination's type; for PRIM_ARRAY_GET_VALUE it takes the
            // destination's type unchanged.
            VarSymbol* localVar = NULL;
            if (rhs->isPrimitive(PRIM_ARRAY_GET))
              localVar = newTemp(astr("local_", lhs->var->name),
                                 lhs->var->type->getField("addr")->type);
            else
              localVar = newTemp(astr("local_", lhs->var->name),
                                 lhs->var->type);
            // Define the temp ahead of the statement, make the statement
            // write to it, then copy it into the original destination in a
            // new move placed after the statement. The detached 'lhs' node
            // is reused as that move's destination.
            stmt->insertBefore(new DefExpr(localVar));
            lhs->replace(new SymExpr(localVar));
            stmt->insertAfter(new CallExpr(PRIM_MOVE, lhs,
                                           new SymExpr(localVar)));
          }
          break;
        } else if (rhs->isPrimitive(PRIM_GET_UNION_ID)) {
          if (rhs->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_REF)) {
            insertLocalTemp(rhs->get(1));
          }
          break;
        } else if (rhs->isPrimitive(PRIM_TESTCID) ||
                   rhs->isPrimitive(PRIM_GETCID)) {
          if (rhs->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS)) {
            insertLocalTemp(rhs->get(1));
          }
          break;
        }
        // Empty statement; it has no effect.
        ;
      }
      // Reached when the source is not a call, or is a call to something
      // other than the primitives recognized above.
      //
      // Destination is a wide class but the source is not: leave as is.
      if (call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS) &&
          !call->get(2)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS)) {
        break;
      }
      // Destination is a wide reference and the source is neither a wide
      // reference nor a reference: localize the destination.
      if (call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_REF) &&
          !call->get(2)->typeInfo()->symbol->hasFlag(FLAG_WIDE_REF) &&
          !call->get(2)->typeInfo()->symbol->hasFlag(FLAG_REF)) {
        insertLocalTemp(call->get(1));
      }
      break;
    case PRIM_DYNAMIC_CAST:
      // Localize operand 2 when it is a wide class. If operand 1 is also
      // flagged wide, its symbol is retyped in place to the type of the
      // "addr" field of its current type.
      // NOTE(review): the toSymExpr() result is dereferenced without a
      // NULL check, unlike the INT_ASSERT-guarded uses above.
      if (call->get(2)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS)) {
        insertLocalTemp(call->get(2));
        if (call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS) ||
            call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_REF)) {
          toSymExpr(call->get(1))->var->type = call->get(1)->typeInfo()->getField("addr")->type;
        }
      }
      break;
    case PRIM_SETCID:
      // Localize the first operand when its type is flagged as a wide class.
      if (call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS)) {
        insertLocalTemp(call->get(1));
      }
      break;
    case PRIM_SET_UNION_ID:
      // Localize the first operand when its type is flagged as a wide
      // reference.
      if (call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_REF)) {
        insertLocalTemp(call->get(1));
      }
      break;
    case PRIM_SET_MEMBER:
    case PRIM_SET_SVEC_MEMBER:
      // Localize the first operand when it is flagged as either a wide
      // class or a wide reference.
      if (call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS) ||
          call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_REF)) {
        insertLocalTemp(call->get(1));
      }
      break;
    default:
      // Any other primitive is left untouched.
      break;
    }
  }
}