/* To implement task reduce intents, we follow the steps in propagateExtraLeaderArgs() and setupOneReduceIntent() I.e. add the following to the AST: * before 'call' def globalOp = new reduceType(origSym.type); * pass globalOp to call(); corresponding formal in 'fn': parentOp * inside 'fn' def currOp = parentOp.clone() def symReplace = currOp.identify; ... currOp.accumulate(symReplace); parentOp.combine(currOp); delete currOp; * after 'call' and its _waitEndCount() origSym = parentOp.generate(); delete parentOp; Put in a different way, a coforall like this: var x: int; coforall ITER with (OP reduce x) { BODY(x); // will typically include: x OP= something; } with its corresponding task-function representation: var x: int; proc coforall_fn() { BODY(x); } call coforall_fn(); is transformed into var x: int; var globalOp = new OP_SCAN_REDUCE_CLASS(x.type); proc coforall_fn(parentOp) { var currOp = parentOp.clone() var symReplace = currOp.identify; BODY(symReplace); currOp.accumulate(symReplace); parentOp.combine(currOp); delete currOp; } call coforall_fn(globalOp); // wait for endCount - not shown x = globalOp.generate(); delete globalOp; Todo: to support cobegin constructs, need to share 'globalOp' across all fn+call pairs for the same construct. */ static void addReduceIntentSupport(FnSymbol* fn, CallExpr* call, TypeSymbol* reduceType, Symbol* origSym, ArgSymbol*& newFormal, Symbol*& newActual, Symbol*& symReplace, bool isCoforall, Expr*& redRef1, Expr*& redRef2) { setupRedRefs(fn, true, redRef1, redRef2); VarSymbol* globalOp = new VarSymbol("reduceGlobal"); globalOp->addFlag(FLAG_NO_CAPTURE_FOR_TASKING); newActual = globalOp; VarSymbol* eltType = newTemp("redEltType"); eltType->addFlag(FLAG_MAYBE_TYPE); Expr* headAnchor = call; if (isCoforall) headAnchor = headAnchor->parentExpr; headAnchor->insertBefore(new DefExpr(eltType)); headAnchor->insertBefore("'move'(%S, 'typeof'(%S))", eltType, origSym); headAnchor->insertBefore(new DefExpr(globalOp)); AggregateType* reduceAt = toAggregateType(reduceType->type); INT_ASSERT(reduceAt); CallExpr* newOp = new CallExpr(reduceAt->defaultInitializer->name, new NamedExpr("eltType", new SymExpr(eltType))); headAnchor->insertBefore(new CallExpr(PRIM_MOVE, globalOp, newOp)); Expr* tailAnchor = findTailInsertionPoint(call, isCoforall); // Doing insertAfter() calls in reverse order. // Can't insertBefore() on tailAnchor->next - that can be NULL. tailAnchor->insertAfter("'delete'(%S)", globalOp); tailAnchor->insertAfter("'='(%S, generate(%S,%S))", origSym, gMethodToken, globalOp); ArgSymbol* parentOp = new ArgSymbol(INTENT_BLANK, "reduceParent", dtUnknown); newFormal = parentOp; VarSymbol* currOp = new VarSymbol("reduceCurr"); VarSymbol* svar = new VarSymbol(origSym->name, origSym->type); symReplace = svar; redRef1->insertBefore(new DefExpr(currOp)); redRef1->insertBefore("'move'(%S, clone(%S,%S))", // init currOp, gMethodToken, parentOp); redRef1->insertBefore(new DefExpr(svar)); redRef1->insertBefore("'move'(%S, identity(%S,%S))", // init svar, gMethodToken, currOp); redRef2->insertBefore(new CallExpr("accumulate", gMethodToken, currOp, svar)); redRef2->insertBefore(new CallExpr("chpl__reduceCombine", parentOp, currOp)); redRef2->insertBefore(new CallExpr("chpl__cleanupLocalOp", parentOp, currOp)); }
// // If call has the potential to cause communication, assert that the wide // reference that might cause communication is local and remove its wide-ness // // The organization of this function follows the order of CallExpr::codegen() // leaving out primitives that don't communicate. // static void localizeCall(CallExpr* call) { if (call->primitive) { switch (call->primitive->tag) { case PRIM_ARRAY_SET: /* Fallthru */ case PRIM_ARRAY_SET_FIRST: if (call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS)) { insertLocalTemp(call->get(1)); } break; case PRIM_MOVE: case PRIM_ASSIGN: // Not sure about this one. if (CallExpr* rhs = toCallExpr(call->get(2))) { if (rhs->isPrimitive(PRIM_DEREF)) { if (rhs->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_REF) || rhs->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS)) { insertLocalTemp(rhs->get(1)); if (!rhs->get(1)->typeInfo()->symbol->hasFlag(FLAG_REF)) { INT_ASSERT(rhs->get(1)->typeInfo() == dtString); // special handling for wide strings rhs->replace(rhs->get(1)->remove()); } } break; } else if (rhs->isPrimitive(PRIM_GET_MEMBER) || rhs->isPrimitive(PRIM_GET_SVEC_MEMBER) || rhs->isPrimitive(PRIM_GET_MEMBER_VALUE) || rhs->isPrimitive(PRIM_GET_SVEC_MEMBER_VALUE)) { if (rhs->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_REF) || rhs->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS)) { SymExpr* sym = toSymExpr(rhs->get(2)); INT_ASSERT(sym); if (!sym->var->hasFlag(FLAG_SUPER_CLASS)) { insertLocalTemp(rhs->get(1)); } } break; } else if (rhs->isPrimitive(PRIM_ARRAY_GET) || rhs->isPrimitive(PRIM_ARRAY_GET_VALUE)) { if (rhs->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS)) { SymExpr* lhs = toSymExpr(call->get(1)); Expr* stmt = call->getStmtExpr(); INT_ASSERT(lhs && stmt); SET_LINENO(stmt); insertLocalTemp(rhs->get(1)); VarSymbol* localVar = NULL; if (rhs->isPrimitive(PRIM_ARRAY_GET)) localVar = newTemp(astr("local_", lhs->var->name), lhs->var->type->getField("addr")->type); else localVar = newTemp(astr("local_", lhs->var->name), lhs->var->type); stmt->insertBefore(new DefExpr(localVar)); lhs->replace(new SymExpr(localVar)); stmt->insertAfter(new CallExpr(PRIM_MOVE, lhs, new SymExpr(localVar))); } break; } else if (rhs->isPrimitive(PRIM_GET_UNION_ID)) { if (rhs->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_REF)) { insertLocalTemp(rhs->get(1)); } break; } else if (rhs->isPrimitive(PRIM_TESTCID) || rhs->isPrimitive(PRIM_GETCID)) { if (rhs->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS)) { insertLocalTemp(rhs->get(1)); } break; } ; } if (call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS) && !call->get(2)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS)) { break; } if (call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_REF) && !call->get(2)->typeInfo()->symbol->hasFlag(FLAG_WIDE_REF) && !call->get(2)->typeInfo()->symbol->hasFlag(FLAG_REF)) { insertLocalTemp(call->get(1)); } break; case PRIM_DYNAMIC_CAST: if (call->get(2)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS)) { insertLocalTemp(call->get(2)); if (call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS) || call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_REF)) { toSymExpr(call->get(1))->var->type = call->get(1)->typeInfo()->getField("addr")->type; } } break; case PRIM_SETCID: if (call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS)) { insertLocalTemp(call->get(1)); } break; case PRIM_SET_UNION_ID: if (call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_REF)) { insertLocalTemp(call->get(1)); } break; case PRIM_SET_MEMBER: case PRIM_SET_SVEC_MEMBER: if (call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS) || call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_REF)) { insertLocalTemp(call->get(1)); } break; default: break; } } }