// sets m_learningRateMultiplier in all LearnableParameters feeding into the passed rootNode // Called from MEL void ComputationNetwork::SetLearnableNodesBelowLearningRateMultiplier(const float learningRateMultiplier, const ComputationNodeBasePtr& rootNode) { // find nodes from all available nodes if (rootNode == nullptr) { for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); nodeIter++) { ComputationNodeBasePtr node = nodeIter->second; if (node->OperationName() == OperationNameOf(LearnableParameter)) node->SetLearningRateMultiplier(learningRateMultiplier); } } else { // for calculating a specific node if (!EvalOrderExists(rootNode)) const_cast<ComputationNetwork&>(*this).FormEvalOrder(rootNode); for (const auto& node : GetAllNodesForRoot(rootNode)) { if (node->OperationName() == OperationNameOf(LearnableParameter)) node->SetLearningRateMultiplier(learningRateMultiplier); } } }
// sets m_parameterUpdateRequired in all LearnableParameters feeding into the passed rootNode // Called from MEL --TODO: correct? void ComputationNetwork::SetLearnableNodesBelowNeedGradient(const bool needGradient, const ComputationNodeBasePtr& rootNode) { // find nodes from all available nodes if (rootNode == nullptr) { for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); nodeIter++) { ComputationNodeBasePtr node = nodeIter->second; if (node->OperationName() == OperationNameOf(LearnableParameter)) node->SetParameterUpdateRequired(needGradient); } } else { // for calculating a specific node for (const auto& node : GetEvalOrder(rootNode)) { if (node->OperationName() == OperationNameOf(LearnableParameter)) node->SetParameterUpdateRequired(needGradient); } } }
void ComputationNetwork::SetBatchNormlizationNodesBelowEvalMode(const bool evalMode, const ComputationNodeBasePtr& rootNode /* = nullptr */) { vector<ComputationNodeBasePtr> nodes; if (rootNode == nullptr) { for (auto pair : m_nameToNodeMap) { nodes.push_back(pair.second); } } else { auto allnodes = rootNode->EnumerateNodes(); for (auto node : allnodes) nodes.push_back(node); } for (auto& node : nodes) { if (node->OperationName() == OperationNameOf(BatchNormalizationNode)) { auto pNode = dynamic_pointer_cast<BatchNormalizationNode<float>>(node); if (!pNode) { auto pNode2 = dynamic_pointer_cast<BatchNormalizationNode<double>>(node); if (!pNode2) { RuntimeError("Invalid node type: node name=%ls. We assume either BatchNormalizationNode<float> or BatchNormalizationNode<double>\n", node->NodeName().c_str()); } } else { pNode->SetEvalMode(evalMode); } } } }