/*
  To implement task reduce intents, we follow the steps in
  propagateExtraLeaderArgs() and setupOneReduceIntent().

  I.e. add the following to the AST:

  * before 'call'
      def eltType  = typeof(origSym);
      def globalOp = new reduceType(eltType = eltType);

  * pass globalOp to call(); the corresponding formal in 'fn' is parentOp

  * inside 'fn' (at redRef1 / redRef2)
      def currOp = parentOp.clone();
      def symReplace = currOp.identity;
      ...
      currOp.accumulate(symReplace);
      chpl__reduceCombine(parentOp, currOp);
      chpl__cleanupLocalOp(parentOp, currOp);   // presumably frees currOp

  * after 'call' and its _waitEndCount()
      origSym = globalOp.generate();
      delete globalOp;

  Put in a different way, a coforall like this:

    var x: int;
    coforall ITER with (OP reduce x) {
      BODY(x);      // will typically include:  x OP= something;
    }

  with its corresponding task-function representation:

    var x: int;
    proc coforall_fn() {
      BODY(x);
    }
    call coforall_fn();

  is transformed into

    var x: int;
    var globalOp = new OP_SCAN_REDUCE_CLASS(eltType = x.type);
    proc coforall_fn(parentOp) {
      var currOp = parentOp.clone();
      var symReplace = currOp.identity;
      BODY(symReplace);
      currOp.accumulate(symReplace);
      chpl__reduceCombine(parentOp, currOp);
      chpl__cleanupLocalOp(parentOp, currOp);
    }
    call coforall_fn(globalOp);
    // wait for endCount - not shown
    x = globalOp.generate();
    delete globalOp;

  Parameters:
    fn          - the task function; setupRedRefs() picks the two insertion
                  points inside it (redRef1 / redRef2).
    call        - the call to fn.  Setup code goes before it, or before its
                  parentExpr when isCoforall (presumably the enclosing
                  coforall construct - confirm).
    reduceType  - the reduction class; its type must be an AggregateType
                  (asserted) whose default initializer builds globalOp.
    origSym     - the outer variable being reduced into.
    newFormal   - out: set to the new "reduceParent" formal.  It is not
                  added to fn here; presumably the caller does that.
    newActual   - out: set to the new "reduceGlobal" variable.
    symReplace  - out: set to the fresh per-task variable (same name and type
                  as origSym) that should replace origSym inside fn.
    isCoforall  - selects the head anchor (see 'call') and is forwarded to
                  findTailInsertionPoint().
    redRef1/2   - out: filled in by setupRedRefs().

  Todo: to support cobegin constructs, need to share 'globalOp'
  across all fn+call pairs for the same construct.
*/
static void addReduceIntentSupport(FnSymbol* fn, CallExpr* call,
                                   TypeSymbol* reduceType,
                                   Symbol* origSym,
                                   ArgSymbol*& newFormal,
                                   Symbol*& newActual,
                                   Symbol*& symReplace,
                                   bool isCoforall,
                                   Expr*& redRef1,
                                   Expr*& redRef2)
{
  setupRedRefs(fn, true, redRef1, redRef2);
  // The reduction object that lives outside the task function.
  VarSymbol* globalOp = new VarSymbol("reduceGlobal");
  globalOp->addFlag(FLAG_NO_CAPTURE_FOR_TASKING);
  newActual = globalOp;

  // Holds typeof(origSym); passed as the 'eltType' argument below.
  VarSymbol* eltType = newTemp("redEltType");
  eltType->addFlag(FLAG_MAYBE_TYPE);

  Expr* headAnchor = call;
  if (isCoforall) headAnchor = headAnchor->parentExpr;
  headAnchor->insertBefore(new DefExpr(eltType));
  headAnchor->insertBefore("'move'(%S, 'typeof'(%S))", eltType, origSym);
  headAnchor->insertBefore(new DefExpr(globalOp));

  AggregateType* reduceAt = toAggregateType(reduceType->type);
  INT_ASSERT(reduceAt);

  // globalOp = new reduceType(eltType = eltType)
  CallExpr* newOp = new CallExpr(reduceAt->defaultInitializer->name,
                                 new NamedExpr("eltType", new SymExpr(eltType)));
  headAnchor->insertBefore(new CallExpr(PRIM_MOVE, globalOp, newOp));

  Expr* tailAnchor = findTailInsertionPoint(call, isCoforall);
  // Doing insertAfter() calls in reverse order.
  // Can't insertBefore() on tailAnchor->next - that can be NULL.
  // Resulting order: origSym = generate(globalOp); delete globalOp;
  tailAnchor->insertAfter("'delete'(%S)",
                          globalOp);
  tailAnchor->insertAfter("'='(%S, generate(%S,%S))",
                          origSym, gMethodToken, globalOp);

  // The formal through which fn receives globalOp.
  ArgSymbol* parentOp = new ArgSymbol(INTENT_BLANK, "reduceParent", dtUnknown);
  newFormal = parentOp;

  // Per-task reduction object and the per-task accumulator variable.
  VarSymbol* currOp = new VarSymbol("reduceCurr");
  VarSymbol* svar  = new VarSymbol(origSym->name, origSym->type);
  symReplace = svar;

  // Per-task prologue, before redRef1.
  redRef1->insertBefore(new DefExpr(currOp));
  redRef1->insertBefore("'move'(%S, clone(%S,%S))", // init currOp
                        currOp, gMethodToken, parentOp);
  redRef1->insertBefore(new DefExpr(svar));
  redRef1->insertBefore("'move'(%S, identity(%S,%S))", // init svar
                        svar, gMethodToken, currOp);
  // Per-task epilogue, before redRef2.
  redRef2->insertBefore(new CallExpr("accumulate", gMethodToken, currOp, svar));
  redRef2->insertBefore(new CallExpr("chpl__reduceCombine", parentOp, currOp));
  redRef2->insertBefore(new CallExpr("chpl__cleanupLocalOp", parentOp, currOp));
}
// // The argument expr is a use of a wide reference. Insert a check to ensure // that it is on the current locale, then drop its wideness by moving the // addr field into a non-wide of otherwise the same type. Then, replace its // use with the non-wide version. // static void insertLocalTemp(Expr* expr) { SymExpr* se = toSymExpr(expr); Expr* stmt = expr->getStmtExpr(); INT_ASSERT(se && stmt); SET_LINENO(se); VarSymbol* var = newTemp(astr("local_", se->var->name), se->var->type->getField("addr")->type); if (!fNoLocalChecks) { stmt->insertBefore(new CallExpr(PRIM_LOCAL_CHECK, se->copy())); } stmt->insertBefore(new DefExpr(var)); stmt->insertBefore(new CallExpr(PRIM_MOVE, var, se->copy())); se->replace(new SymExpr(var)); }
// Rewrite the end of `fn` so that the value it returns is assigned into
// `formal`.
//
// The operand of the trailing return call (the last statement of fn's body)
// is removed and becomes the RHS of a PRIM_ASSIGN into `formal`.  That
// assignment is inserted before the return.  If auto-destroy calls
// immediately precede the return, it goes before the first of them, so the
// value is stored before those cleanups run.
void ReturnByRef::insertAssignmentToFormal(FnSymbol* fn, ArgSymbol* formal)
{
  Expr* returnPrim = fn->body->body.tail;

  SET_LINENO(returnPrim);

  CallExpr* returnCall = toCallExpr(returnPrim);
  // The body is expected to end in the return call; report an internal
  // error instead of dereferencing NULL if it does not.
  INT_ASSERT(returnCall);

  Expr*     returnValue = returnCall->get(1)->remove();
  CallExpr* assignExpr  = new CallExpr(PRIM_ASSIGN, formal, returnValue);

  // Walk backward over the contiguous run of auto-destroy calls that
  // directly precede the return.
  Expr* insertionPoint = returnPrim;
  while (Expr* prevStmt = insertionPoint->prev) {
    CallExpr* prevCall = toCallExpr(prevStmt);
    FnSymbol* calledFn = (prevCall != NULL) ? prevCall->isResolved() : NULL;

    if (calledFn == NULL || !calledFn->hasFlag(FLAG_AUTO_DESTROY_FN))
      break;

    insertionPoint = prevStmt;
  }

  // Add the assignment before the first auto-destroy (or the return itself
  // when there are none).  At this point we could also invoke some other
  // function if that turns out to be necessary.  It might well be necessary
  // in order to return array slices by value.
  insertionPoint->insertBefore(assignExpr);
}
// Insert every expression of `exprs` immediately before this expression,
// keeping their original relative order.  Each element is detached with
// remove() before being reinserted.
//
// The list is walked from tail to head, and each element is inserted in
// front of the one placed previously, so the final order matches `exprs`.
void Expr::insertBefore(AList exprs) {
  Expr* curr = this;
  for_alist_backward(prev, exprs) {
    prev->remove();
    curr->insertBefore(prev);
    curr = prev;
  }
}
// Rewrite the end of `fn` so that the value it returns is moved into
// `formal`.  The operand of the trailing return call (the last statement of
// fn's body) is removed, and a PRIM_MOVE of it into `formal` is inserted
// directly before the return.
//
// NOTE(review): this file also contains another definition of
// ReturnByRef::insertAssignmentToFormal, which uses PRIM_ASSIGN and places
// the assignment before any trailing auto-destroy calls.  Two definitions of
// the same member function cannot coexist in one translation unit.  Confirm
// which version is current and drop the other.
void ReturnByRef::insertAssignmentToFormal(FnSymbol* fn, ArgSymbol* formal)
{
  Expr* returnPrim = fn->body->body.tail;

  SET_LINENO(returnPrim);

  CallExpr* returnCall = toCallExpr(returnPrim);
  // The body is expected to end in the return call; report an internal
  // error instead of dereferencing NULL if it does not.
  INT_ASSERT(returnCall);

  Expr*     returnValue = returnCall->get(1)->remove();
  CallExpr* moveExpr    = new CallExpr(PRIM_MOVE, formal, returnValue);

  returnPrim->insertBefore(moveExpr);
}
/* This function copies the body of a param for loop while adjusting it slightly - to stamp out each iteration. * Inserts the body before the expression beforeHere * i should be a loop variable index (used to label iterations) * Assumes that map already contains the mapping redifining the index variable. * continueSym is the symbol for the loop's continue label. This function will replace that with a new continue label local to this iteration. */ static void copyBodyHelper(Expr* beforeHere, int64_t i, SymbolMap* map, ParamForLoop* loop, Symbol* continueSym) { // Replace the continue label with a per-iteration label // that is at the end of that iteration. LabelSymbol* continueLabel = new LabelSymbol(astr("_continueLabel", istr(i))); Expr* defContinueLabel = new DefExpr(continueLabel); beforeHere->insertBefore(defContinueLabel); map->put(continueSym, continueLabel); defContinueLabel->insertBefore(loop->copyBody(map)); }
// Lower every PRIM_CHECK_ERROR(err) call into a boolean temp holding
// (err != nil).  The temp is defined and assigned immediately before the
// enclosing statement, and the primitive call is replaced by a use of it.
forv_Vec(CallExpr, call, gCallExprs) {
  if (call->isPrimitive(PRIM_CHECK_ERROR)) {
    SET_LINENO(call);

    // The operand must be a direct symbol reference; assert rather than
    // dereference a NULL toSymExpr() result.
    SymExpr* errSe = toSymExpr(call->get(1));
    INT_ASSERT(errSe);
    Symbol* errorVar = errSe->symbol();

    VarSymbol* errorExistsVar = newTemp("errorExists", dtBool);
    DefExpr*   def            = new DefExpr(errorExistsVar);
    CallExpr*  errorExists    = new CallExpr(PRIM_NOTEQUAL, errorVar, gNil);
    CallExpr*  move           = new CallExpr(PRIM_MOVE, errorExistsVar, errorExists);

    // Resulting order: def errorExists; errorExists = (err != nil); stmt
    Expr* stmt = call->getStmtExpr();
    stmt->insertBefore(def);
    def->insertAfter(move);

    call->replace(new SymExpr(errorExistsVar));
  }
}
//
// If call has the potential to cause communication, assert that the wide
// reference that might cause communication is local and remove its wide-ness
// (via insertLocalTemp()).
//
// The organization of this function follows the order of CallExpr::codegen()
// leaving out primitives that don't communicate.
//
// Note: several `break` statements below sit inside nested `if`s; each one
// exits the whole switch, not just the `if`.
//
static void localizeCall(CallExpr* call) {
  if (call->primitive) {
    switch (call->primitive->tag) {
    case PRIM_ARRAY_SET: /* Fallthru */
    case PRIM_ARRAY_SET_FIRST:
      // Localize the array operand if it is a wide class.
      if (call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS)) {
        insertLocalTemp(call->get(1));
      }
      break;
    case PRIM_MOVE:
    case PRIM_ASSIGN: // Not sure about this one.
      // First, special-case RHS primitives that read through a wide operand.
      // Each matched kind breaks out of the switch, skipping the generic
      // LHS checks further down.
      if (CallExpr* rhs = toCallExpr(call->get(2))) {
        if (rhs->isPrimitive(PRIM_DEREF)) {
          if (rhs->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_REF) ||
              rhs->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS)) {
            insertLocalTemp(rhs->get(1));
            // After localizing, a non-ref operand must be a string (asserted);
            // the DEREF is then dropped and its operand used directly.
            if (!rhs->get(1)->typeInfo()->symbol->hasFlag(FLAG_REF)) {
              INT_ASSERT(rhs->get(1)->typeInfo() == dtString);
              // special handling for wide strings
              rhs->replace(rhs->get(1)->remove());
            }
          }
          break;
        }
        else if (rhs->isPrimitive(PRIM_GET_MEMBER) ||
                 rhs->isPrimitive(PRIM_GET_SVEC_MEMBER) ||
                 rhs->isPrimitive(PRIM_GET_MEMBER_VALUE) ||
                 rhs->isPrimitive(PRIM_GET_SVEC_MEMBER_VALUE)) {
          // Member read through a wide base: localize the base unless the
          // member symbol is flagged FLAG_SUPER_CLASS.
          if (rhs->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_REF) ||
              rhs->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS)) {
            SymExpr* sym = toSymExpr(rhs->get(2));
            INT_ASSERT(sym);
            if (!sym->var->hasFlag(FLAG_SUPER_CLASS)) {
              insertLocalTemp(rhs->get(1));
            }
          }
          break;
        } else if (rhs->isPrimitive(PRIM_ARRAY_GET) ||
                   rhs->isPrimitive(PRIM_ARRAY_GET_VALUE)) {
          // Array read from a wide-class array: localize the array, then
          // redirect the result into a fresh local temp and move that temp
          // back into the original LHS after the statement.
          if (rhs->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS)) {
            SymExpr* lhs = toSymExpr(call->get(1));
            Expr* stmt = call->getStmtExpr();
            INT_ASSERT(lhs && stmt);

            SET_LINENO(stmt);

            insertLocalTemp(rhs->get(1));

            // ARRAY_GET: temp has the LHS type's "addr" field type;
            // ARRAY_GET_VALUE: temp has the LHS's own type.
            VarSymbol* localVar = NULL;
            if (rhs->isPrimitive(PRIM_ARRAY_GET))
              localVar = newTemp(astr("local_", lhs->var->name),
                                 lhs->var->type->getField("addr")->type);
            else
              localVar = newTemp(astr("local_", lhs->var->name),
                                 lhs->var->type);

            stmt->insertBefore(new DefExpr(localVar));
            // lhs is detached here and reused as the target of the copy-back.
            lhs->replace(new SymExpr(localVar));
            stmt->insertAfter(new CallExpr(PRIM_MOVE, lhs,
                                           new SymExpr(localVar)));
          }
          break;
        } else if (rhs->isPrimitive(PRIM_GET_UNION_ID)) {
          if (rhs->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_REF)) {
            insertLocalTemp(rhs->get(1));
          }
          break;
        } else if (rhs->isPrimitive(PRIM_TESTCID) ||
                   rhs->isPrimitive(PRIM_GETCID)) {
          if (rhs->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS)) {
            insertLocalTemp(rhs->get(1));
          }
          break;
        }
        ;
      }
      // Generic case (RHS not special-cased above).
      // Wide-class LHS with a non-wide-class RHS: nothing to localize here.
      if (call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS) &&
          !call->get(2)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS)) {
        break;
      }
      // Wide-ref LHS receiving a value that is neither a wide ref nor a ref:
      // localize the LHS.
      if (call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_REF) &&
          !call->get(2)->typeInfo()->symbol->hasFlag(FLAG_WIDE_REF) &&
          !call->get(2)->typeInfo()->symbol->hasFlag(FLAG_REF)) {
        insertLocalTemp(call->get(1));
      }
      break;
    case PRIM_DYNAMIC_CAST:
      // Localize a wide-class operand (arg 2).  If arg 1 is also wide, retype
      // its symbol in place to the "addr" field type.
      // NOTE(review): toSymExpr(call->get(1)) is not NULL-checked here.
      if (call->get(2)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS)) {
        insertLocalTemp(call->get(2));
        if (call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS) ||
            call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_REF)) {
          toSymExpr(call->get(1))->var->type = call->get(1)->typeInfo()->getField("addr")->type;
        }
      }
      break;
    case PRIM_SETCID:
      if (call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS)) {
        insertLocalTemp(call->get(1));
      }
      break;
    case PRIM_SET_UNION_ID:
      if (call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_REF)) {
        insertLocalTemp(call->get(1));
      }
      break;
    case PRIM_SET_MEMBER:
    case PRIM_SET_SVEC_MEMBER:
      // Member write through a wide base: localize the base.
      if (call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_CLASS) ||
          call->get(1)->typeInfo()->symbol->hasFlag(FLAG_WIDE_REF)) {
        insertLocalTemp(call->get(1));
      }
      break;
    default:
      break;
    }
  }
}
void AutoDestroyScope::variablesDestroy(Expr* refStmt, VarSymbol* excludeVar, std::set<VarSymbol*>* ignored) const { // Handle the primary locals if (mLocalsHandled == false) { Expr* insertBeforeStmt = refStmt; Expr* noop = NULL; size_t count = mLocalsAndDefers.size(); // If this is a simple nested block, insert after the final stmt // But always insert the destruction calls in reverse declaration order. // Do not get tricked by sequences of unreachable code if (refStmt->next == NULL) { if (mParent != NULL && isGotoStmt(refStmt) == false) { SET_LINENO(refStmt); // Add a PRIM_NOOP to insert before noop = new CallExpr(PRIM_NOOP); refStmt->insertAfter(noop); insertBeforeStmt = noop; } } for (size_t i = 1; i <= count; i++) { BaseAST* localOrDefer = mLocalsAndDefers[count - i]; VarSymbol* var = toVarSymbol(localOrDefer); DeferStmt* defer = toDeferStmt(localOrDefer); // This code only handles VarSymbols and DeferStmts. // It handles both in one vector because the order // of interleaving matters. INT_ASSERT(var || defer); if (var != NULL && var != excludeVar && (ignored == NULL || ignored->count(var) == 0)) { if (FnSymbol* autoDestroyFn = autoDestroyMap.get(var->type)) { SET_LINENO(var); INT_ASSERT(autoDestroyFn->hasFlag(FLAG_AUTO_DESTROY_FN)); CallExpr* autoDestroy = new CallExpr(autoDestroyFn, var); insertBeforeStmt->insertBefore(autoDestroy); } } if (defer != NULL) { SET_LINENO(defer); BlockStmt* deferBlockCopy = defer->body()->copy(); insertBeforeStmt->insertBefore(deferBlockCopy); deferBlockCopy->flattenAndRemove(); } } // remove the PRIM_NOOP if we added one. if (noop != NULL) noop->remove(); } // Handle the formal temps if (isReturnStmt(refStmt) == true) { size_t count = mFormalTemps.size(); for (size_t i = 1; i <= count; i++) { VarSymbol* var = mFormalTemps[count - i]; if (FnSymbol* autoDestroyFn = autoDestroyMap.get(var->type)) { SET_LINENO(var); refStmt->insertBefore(new CallExpr(autoDestroyFn, var)); } } } }