// Replaces the final-criterion node named oldNodeName (an entry of m_finalCriteria) by newNode.
// Each input of newNode is re-bound, by name, to the node of that name already registered in m_nameToNodeMap,
// so newNode ends up wired into this network rather than into whatever graph it was built in.
// Finally newNode is stored in m_finalCriteria in place of the old criterion and registered in m_nameToNodeMap
// under its own name (overwriting any existing entry of that name).
// NOTE(review): if newNode's name differs from oldNodeName, the old criterion node stays in m_nameToNodeMap — confirm intended.
// Raises RuntimeError if oldNodeName is not a criterion node, or if an input of newNode is null or has no
// same-named node in this network.
void ComputationNetwork::ReplaceFinalCriterionNode(wstring oldNodeName, ComputationNodeBasePtr newNode)
{
    InvalidateCompiledNetwork();

    // Checks if the node is a criterion node.
    auto criterionIter = std::find_if(m_finalCriteria.begin(), m_finalCriteria.end(),
                                      [&oldNodeName](const ComputationNodeBasePtr& node)
                                      {
                                          return node->NodeName() == oldNodeName;
                                      });
    if (criterionIter == m_finalCriteria.end())
        RuntimeError("ReplaceFinalCriterionNode: the node to be replaced is not a criterion node.");

    // Replaces children: point each input at the network's node of the same name (single map lookup per input).
    for (size_t i = 0; i < newNode->GetNumInputs(); ++i)
    {
        ComputationNodeBasePtr input = newNode->GetInputs()[i];
        if (!input) // a null input cannot be resolved by name; report it instead of dereferencing it
            RuntimeError("Child node does not exist.");
        auto found = m_nameToNodeMap.find(input->NodeName());
        if (found == m_nameToNodeMap.end())
            RuntimeError("Child node does not exist.");
        newNode->SetInput(i, found->second);
    }

    // Adds it to criterion node list.
    *criterionIter = newNode;
    m_nameToNodeMap[newNode->NodeName()] = newNode;
}
// We only remove the node, not delete it. void ComputationNetwork::RemoveFeatureNode(ComputationNodeBasePtr featureNode) { InvalidateCompiledNetwork(); wstring nodeName = featureNode->NodeName(); if (!NodeNameExists(nodeName)) RuntimeError("RemoveFeatureNode: feature node does not exist."); // Removes links. for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); ++nodeIter) { ComputationNodeBasePtr node = nodeIter->second; for (size_t i = 0; i < node->GetNumInputs(); ++i) { ComputationNodeBasePtr child = node->GetInputs()[i]; if (child == featureNode) { node->SetInput(i, NULL); break; } } } // Removes from feature list. auto search = std::find(m_features.begin(), m_features.end(), featureNode); if (search != m_features.end()) m_features.erase(search); m_nameToNodeMap.erase(nodeName); }
// Renames node to newNodeName and re-registers it in the network under the new name
// (erase the old map entry, change the node's name, then AddNodeToNet).
// Raises RuntimeError, before anything is modified, if newNodeName is already used by a different node.
// The check must run first: if a name clash were only found while re-adding, the node would already have
// been removed from m_nameToNodeMap and renamed, leaving the network inconsistent.
void ComputationNetwork::RenameNode(ComputationNodeBasePtr node, const std::wstring& newNodeName)
{
    auto existing = m_nameToNodeMap.find(newNodeName);
    if (existing != m_nameToNodeMap.end() && existing->second != node)
        RuntimeError("RenameNode: a node named '%ls' already exists.", newNodeName.c_str());

    m_nameToNodeMap.erase(node->NodeName());
    node->SetNodeName(newNodeName);
    AddNodeToNet(node);
}
// replace a named node by newNode of the same type under the same name, including moving over all network links // This is used in // 1. Update nodes to quantized versions. // 2. The KL-reg based adaptation to reduce feature copy (deprecated) // need to update all the mappings as well childrens. void ComputationNetwork::ReplaceNode(wstring nodeName, ComputationNodeBasePtr newNode) { ComputationNodeBasePtr oldNode = GetNodeFromName(nodeName); if (newNode->NodeName() != nodeName) // TODO: This was not tested for earlier; I hope no code depends on this. InvalidArgument("ChangeNode: newNode must have the same name as the old node."); InvalidateCompiledNetwork(); // change all nodes that have old node as input to point to the new node instead ChangeNodeInputs(oldNode, newNode); // change all inputs of this new node to share the old one's inputs for (int i = 0; i < oldNode->GetNumInputs(); i++) { newNode->SetInput(i, oldNode->GetInputs()[i]); // TODO: use AttachInput()? //oldNode->SetInput(i, nullptr); // BUGBUG: old node should no longer point into the network } // replace the node in the network RemoveNodeFromNet(oldNode); AddNodeToNet(newNode); // also update node groups for (auto groupIter : GetAllNodeGroups()) { auto& group = *groupIter; for (int i = 0; i < group.size(); i++) if (group[i] == oldNode) group[i] = newNode; } }
// Registers featureNode as a feature node of the network: it is stored in m_nameToNodeMap under its own name
// and appended to m_features. The compiled network is invalidated first, even if the call then fails.
// Raises RuntimeError if a node with featureNode's name already exists.
void ComputationNetwork::AddFeatureNode(ComputationNodeBasePtr featureNode)
{
    InvalidateCompiledNetwork();

    const wstring& featureName = featureNode->NodeName();
    const bool nameTaken = NodeNameExists(featureName);
    if (nameTaken)
        RuntimeError("AddFeatureNode: feature node already exists.");

    m_nameToNodeMap[featureName] = featureNode;
    m_features.push_back(featureNode);
}
// TODO: how does the file distinguish float vs double nodes? void ComputationNetwork::SaveToFileImpl(const wstring& fileName, const FileOptions fileFormat) const { File fstream(fileName, fileFormat | FileOptions::fileOptionsWrite); fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BCN"); // model version fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BVersion"); fstream << (size_t) CURRENT_CNTK_MODEL_VERSION; fstream.PutMarker(FileMarker::fileMarkerEndSection, L"EVersion"); fstream << (size_t) m_nameToNodeMap.size(); // put all node info first fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BNodeList"); for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); nodeIter++) { ComputationNodeBasePtr nodePtr = nodeIter->second; nodePtr->Save(fstream); } fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ENodeList"); // put relationship fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BRelation"); for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); nodeIter++) { ComputationNodeBasePtr nodePtr = nodeIter->second; fstream << nodePtr->NodeName() << nodePtr->GetNumInputs(); for (size_t i = 0; i < nodePtr->GetNumInputs(); i++) { if (!nodePtr->Input(i)) fprintf(stderr, "Warning: node %ls 's child is null, please check your ndl/mel file.\n", nodePtr->NodeName().c_str()); else fstream << nodePtr->Input(i)->NodeName(); } } fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ERelation"); fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BRootNodes"); fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BFeatureNodes"); fstream << m_features.size(); for (size_t i = 0; i < m_features.size(); i++) fstream << m_features[i]->NodeName(); fstream.PutMarker(FileMarker::fileMarkerEndSection, L"EFeatureNodes"); fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BLabelNodes"); fstream << m_labels.size(); for (size_t i = 0; i < m_labels.size(); i++) fstream << m_labels[i]->NodeName(); 
fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ELabelNodes"); fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BCriterionNodes"); fstream << m_finalCriteria.size(); for (size_t i = 0; i < m_finalCriteria.size(); i++) fstream << m_finalCriteria[i]->NodeName(); fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ECriterionNodes"); fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BEvalNodes"); fstream << m_evalNodes.size(); for (size_t i = 0; i < m_evalNodes.size(); i++) fstream << m_evalNodes[i]->NodeName(); fstream.PutMarker(FileMarker::fileMarkerEndSection, L"EEvalNodes"); fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BOutputNodes"); fstream << m_outputNodes.size(); for (size_t i = 0; i < m_outputNodes.size(); i++) { fstream << m_outputNodes[i]->NodeName(); } fstream.PutMarker(FileMarker::fileMarkerEndSection, L"EOutputNodes"); fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ERootNodes"); fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ECN"); fstream.Flush(); }
// recovers the processing order within a recurrent loop // TODO: Once we only use the nested network for recurrent traversal, this will be no longer necessary. void ComputationNetwork::DetermineLoopForwardOrder(unordered_set<ComputationNodeBasePtr>& visited, unordered_set<ComputationNodeBasePtr>& recStack, list<ComputationNodeBasePtr>& nodesStack, ComputationNodeBasePtr cur) { if (visited.find(cur) == visited.end()) { visited.insert(cur); recStack.insert(cur); if (GetRecurrenceSteppingDirection(cur) == 0) // recurrence stops at delay nodes { for (size_t i = 0; i < cur->GetNumInputs(); i++) if (cur->Input(i)->m_loopId == cur->m_loopId) DetermineLoopForwardOrder(visited, recStack, nodesStack, cur->Input(i)); } recStack.erase(cur); nodesStack.push_back(cur); } else if (recStack.find(cur) != recStack.end()) LogicError("%ls %ls operation is part of an infinite loop that cannot be unrolled.", cur->NodeName().c_str(), cur->OperationName().c_str()); }