// recovers the processing order within a recurrent loop // TODO: Once we only use the nested network for recurrent traversal, this will be no longer necessary. void ComputationNetwork::DetermineLoopForwardOrder(unordered_set<ComputationNodeBasePtr>& visited, unordered_set<ComputationNodeBasePtr>& recStack, list<ComputationNodeBasePtr>& nodesStack, ComputationNodeBasePtr cur) { if (visited.find(cur) == visited.end()) { visited.insert(cur); recStack.insert(cur); if (GetRecurrenceSteppingDirection(cur) == 0) // recurrence stops at delay nodes { for (size_t i = 0; i < cur->GetNumInputs(); i++) if (cur->Input(i)->m_loopId == cur->m_loopId) DetermineLoopForwardOrder(visited, recStack, nodesStack, cur->Input(i)); } recStack.erase(cur); nodesStack.push_back(cur); } else if (recStack.find(cur) != recStack.end()) LogicError("%ls %ls operation is part of an infinite loop that cannot be unrolled.", cur->NodeName().c_str(), cur->OperationName().c_str()); }
// traverse sub-graph feeding this node (which is a top-level node at start, e.g. training criterion) and list // - all nodes that participate in a loop -> recurrentResult[loopId][] // - all nodes that don't -> noRecurrentResult[] // in order of traversal (depth-first). // This is part of the FormRecurrentLoops() process, and only called from there from one place. void ComputationNetwork::GatherLoopNodesR(const ComputationNodeBasePtr& node, unordered_set<ComputationNodeBasePtr>& visited, map<int, list<ComputationNodeBasePtr>>& recurrentResult, list<ComputationNodeBasePtr>& noRecurrentResult) { if (visited.find(node) != visited.end()) return; // do each node only once visited.insert(node); for (int i = 0; i < node->GetNumInputs(); i++) GatherLoopNodesR(node->Input(i), visited, recurrentResult, noRecurrentResult); if (node->m_loopId >= 0) recurrentResult[node->m_loopId].push_back(node); else noRecurrentResult.push_back(node); }
// TODO: how does the file distinguish float vs double nodes? void ComputationNetwork::SaveToFileImpl(const wstring& fileName, const FileOptions fileFormat) const { File fstream(fileName, fileFormat | FileOptions::fileOptionsWrite); fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BCN"); // model version fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BVersion"); fstream << (size_t) CURRENT_CNTK_MODEL_VERSION; fstream.PutMarker(FileMarker::fileMarkerEndSection, L"EVersion"); fstream << (size_t) m_nameToNodeMap.size(); // put all node info first fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BNodeList"); for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); nodeIter++) { ComputationNodeBasePtr nodePtr = nodeIter->second; nodePtr->Save(fstream); } fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ENodeList"); // put relationship fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BRelation"); for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); nodeIter++) { ComputationNodeBasePtr nodePtr = nodeIter->second; fstream << nodePtr->NodeName() << nodePtr->GetNumInputs(); for (size_t i = 0; i < nodePtr->GetNumInputs(); i++) { if (!nodePtr->Input(i)) fprintf(stderr, "Warning: node %ls 's child is null, please check your ndl/mel file.\n", nodePtr->NodeName().c_str()); else fstream << nodePtr->Input(i)->NodeName(); } } fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ERelation"); fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BRootNodes"); fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BFeatureNodes"); fstream << m_features.size(); for (size_t i = 0; i < m_features.size(); i++) fstream << m_features[i]->NodeName(); fstream.PutMarker(FileMarker::fileMarkerEndSection, L"EFeatureNodes"); fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BLabelNodes"); fstream << m_labels.size(); for (size_t i = 0; i < m_labels.size(); i++) fstream << m_labels[i]->NodeName(); fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ELabelNodes"); fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BCriterionNodes"); fstream << m_finalCriteria.size(); for (size_t i = 0; i < m_finalCriteria.size(); i++) fstream << m_finalCriteria[i]->NodeName(); fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ECriterionNodes"); fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BEvalNodes"); fstream << m_evalNodes.size(); for (size_t i = 0; i < m_evalNodes.size(); i++) fstream << m_evalNodes[i]->NodeName(); fstream.PutMarker(FileMarker::fileMarkerEndSection, L"EEvalNodes"); fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BOutputNodes"); fstream << m_outputNodes.size(); for (size_t i = 0; i < m_outputNodes.size(); i++) { fstream << m_outputNodes[i]->NodeName(); } fstream.PutMarker(FileMarker::fileMarkerEndSection, L"EOutputNodes"); fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ERootNodes"); fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ECN"); fstream.Flush(); }
// (recursive part of DetermineSCCs()) void ComputationNetwork::DetermineSCCsR(ComputationNodeBasePtr cur, list<ComputationNodeBasePtr>& sccStack, size_t& index, size_t& loopId) { assert(!cur->m_visited); // set the index (in order of visitation) cur->m_index = index; // TODO: can this be used as m_visitedOrder? cur->m_minIndex = index; // also set m_minIndex index++; cur->m_visited = true; sccStack.push_back(cur); cur->m_inStack = true; // set m_minIndex to min over m_lowLinks of children for (int i = 0; i < cur->GetNumInputs(); i++) { if (!cur->Input(i)->m_visited) { DetermineSCCsR(cur->Input(i), sccStack, index, loopId); cur->m_minIndex = min(cur->m_minIndex, cur->Input(i)->m_minIndex); } else if (cur->Input(i)->m_inStack) { cur->m_minIndex = min(cur->m_minIndex, cur->Input(i)->m_minIndex); } } // if we closed a loop then create an entry in m_allSEQNodes if (cur->m_minIndex == cur->m_index) // m_minIndex is still equal to m_index, as we set it at the start of this function: we closed a loop { // gather the list of all nodes in this loop vector<ComputationNodeBasePtr> nestedNodes; // TODO: build array first in a local array. Only if succeeds, then construct the node off it. SEQTraversalFlowControlNode rInfo(m_allSEQNodes.size() /*loopId*/, cur); for (;;) { ComputationNodeBasePtr w = sccStack.back(); sccStack.pop_back(); w->m_inStack = false; nestedNodes.push_back(w); if (w == cur) // hit our starting point: done break; } // insert loop into m_allSEQNodes if (nestedNodes.size() > 1) // non-looped nodes are detected here as loops of size 1 --skip those { // only add to the array if the loop is not already there // We end up producing the same loop multiple times because: // - FormRecurrentLoops() is called multiple times from different roots // - depth-first traversal might have led us to enter a loop multiple times? // TODO: Check whether this edge case of idempotence is done correctly: // - a recurrent loop with two delay nodes // - two root nodes // - the first root takes the first delay node's value, the second root that of the second delay node // I.e. the depth-first tree traversals enter the loop at two different places (m_sourceNode). // -> Are these two loops detected as identical? (determined by m_minIndex, but m_index depends on traversal from each root, so maybe not) bool bFound = false; // find a dup --TODO: check whether there is an STL algorithm for this for (const auto& iter2 : m_allSEQNodes) { if (iter2->m_sourceNode == cur) { bFound = true; break; } } if (!bFound) { #if 1 if (loopId != m_allSEQNodes.size()) LogicError("DetermineSCCsR(): inconsistent loopId (%d) vs. m_allSEQNodes.size() (%d)", (int) loopId, (int) m_allSEQNodes.size()); SEQTraversalFlowControlNode rInfo(m_allSEQNodes.size(), cur); #else assert(loopId == m_allSEQNodes.size()); // BUGBUG: Only true if all loops are shared among roots. Fix: use m_allSEQNodes.size() instead SEQTraversalFlowControlNode rInfo(loopId, cur); #endif // TODO: can we prove that 'cur' == nestedNodes.front()? If so, we won't need to store it separately. rInfo.m_nestedNodes = move(nestedNodes); // TODO: make these two part of the constructor rInfo.m_steppingDirection = DetermineLoopDirection(rInfo.m_nestedNodes); m_allSEQNodes.push_back(make_shared<SEQTraversalFlowControlNode>(move(rInfo))); loopId++; // and count it TODO: may be removed } } } }