// change the node associated with nodeName to newNode; used in the KL-reg based adaptation to reduce feature copy // need to update all the mappings as well childrens void ComputationNetwork::ChangeNode(wstring nodeName, ComputationNodeBasePtr newNode) { InvalidateCompiledNetwork(); ComputationNodeBasePtr oldNode = GetNodeFromName(nodeName); if (oldNode->OperationName() != newNode->OperationName()) InvalidArgument("newNode must have the same type as the old node."); // change children for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); nodeIter++) { ComputationNodeBasePtr node = nodeIter->second; for (int i = 0; i < node->GetNumInputs(); i++) if (node->GetInputs()[i] == oldNode) node->SetInput(i, newNode); } // change name map m_nameToNodeMap[nodeName] = newNode; for (int i = 0; i < oldNode->GetNumInputs(); i++) newNode->SetInput(i, oldNode->GetInputs()[i]); // change other maps for (auto groupIter : GetAllNodeGroups()) { auto& group = *groupIter; for (int i = 0; i < group.size(); i++) if (group[i] == oldNode) group[i] = newNode; } }
// sets m_learningRateMultiplier in all LearnableParameters feeding into the passed rootNode // Called from MEL void ComputationNetwork::SetLearnableNodesBelowLearningRateMultiplier(const float learningRateMultiplier, const ComputationNodeBasePtr& rootNode) { // find nodes from all available nodes if (rootNode == nullptr) { for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); nodeIter++) { ComputationNodeBasePtr node = nodeIter->second; if (node->OperationName() == OperationNameOf(LearnableParameter)) node->SetLearningRateMultiplier(learningRateMultiplier); } } else { // for calculating a specific node if (!EvalOrderExists(rootNode)) const_cast<ComputationNetwork&>(*this).FormEvalOrder(rootNode); for (const auto& node : GetAllNodesForRoot(rootNode)) { if (node->OperationName() == OperationNameOf(LearnableParameter)) node->SetLearningRateMultiplier(learningRateMultiplier); } } }
// sets m_parameterUpdateRequired in all LearnableParameters feeding into the passed rootNode // Called from MEL --TODO: correct? void ComputationNetwork::SetLearnableNodesBelowNeedGradient(const bool needGradient, const ComputationNodeBasePtr& rootNode) { // find nodes from all available nodes if (rootNode == nullptr) { for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); nodeIter++) { ComputationNodeBasePtr node = nodeIter->second; if (node->OperationName() == OperationNameOf(LearnableParameter)) node->SetParameterUpdateRequired(needGradient); } } else { // for calculating a specific node for (const auto& node : GetEvalOrder(rootNode)) { if (node->OperationName() == OperationNameOf(LearnableParameter)) node->SetParameterUpdateRequired(needGradient); } } }
// recovers the processing order within a recurrent loop // TODO: Once we only use the nested network for recurrent traversal, this will be no longer necessary. void ComputationNetwork::DetermineLoopForwardOrder(unordered_set<ComputationNodeBasePtr>& visited, unordered_set<ComputationNodeBasePtr>& recStack, list<ComputationNodeBasePtr>& nodesStack, ComputationNodeBasePtr cur) { if (visited.find(cur) == visited.end()) { visited.insert(cur); recStack.insert(cur); if (GetRecurrenceSteppingDirection(cur) == 0) // recurrence stops at delay nodes { for (size_t i = 0; i < cur->GetNumInputs(); i++) if (cur->Input(i)->m_loopId == cur->m_loopId) DetermineLoopForwardOrder(visited, recStack, nodesStack, cur->Input(i)); } recStack.erase(cur); nodesStack.push_back(cur); } else if (recStack.find(cur) != recStack.end()) LogicError("%ls %ls operation is part of an infinite loop that cannot be unrolled.", cur->NodeName().c_str(), cur->OperationName().c_str()); }