// sets m_learningRateMultiplier in all LearnableParameters feeding into the passed rootNode
// Called from MEL
void ComputationNetwork::SetLearnableNodesBelowLearningRateMultiplier(const float learningRateMultiplier, const ComputationNodeBasePtr& rootNode)
{
    // find nodes from all available nodes
    if (rootNode == nullptr)
    {
        for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); nodeIter++)
        {
            ComputationNodeBasePtr node = nodeIter->second;
            if (node->OperationName() == OperationNameOf(LearnableParameter))
                node->SetLearningRateMultiplier(learningRateMultiplier);
        }
    }
    else
    {
        // for calculating a specific node
        if (!EvalOrderExists(rootNode))
            const_cast<ComputationNetwork&>(*this).FormEvalOrder(rootNode);

        for (const auto& node : GetAllNodesForRoot(rootNode))
        {
            if (node->OperationName() == OperationNameOf(LearnableParameter))
                node->SetLearningRateMultiplier(learningRateMultiplier);
        }
    }
}
// Example #2
// 0
// sets m_parameterUpdateRequired in all LearnableParameters feeding into the passed rootNode
// Called from MEL  --TODO: correct?
void ComputationNetwork::SetLearnableNodesBelowNeedGradient(const bool needGradient, const ComputationNodeBasePtr& rootNode)
{
    // find nodes from all available nodes
    if (rootNode == nullptr)
    {
        for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); nodeIter++)
        {
            ComputationNodeBasePtr node = nodeIter->second;
            if (node->OperationName() == OperationNameOf(LearnableParameter))
                node->SetParameterUpdateRequired(needGradient);
        }
    }
    else
    {
        // for calculating a specific node
        for (const auto& node : GetEvalOrder(rootNode))
        {
            if (node->OperationName() == OperationNameOf(LearnableParameter))
                node->SetParameterUpdateRequired(needGradient);
        }
    }
}
// Example #3
// 0
void ComputationNetwork::SetBatchNormlizationNodesBelowEvalMode(const bool evalMode, const ComputationNodeBasePtr& rootNode /* = nullptr */)
{
    vector<ComputationNodeBasePtr> nodes;
    if (rootNode == nullptr)
    {
        for (auto pair : m_nameToNodeMap)
        {
            nodes.push_back(pair.second);
        }
    }
    else
    {
        auto allnodes = rootNode->EnumerateNodes();
        for (auto node : allnodes)
            nodes.push_back(node);
    }

    for (auto& node : nodes)
    {
        if (node->OperationName() == OperationNameOf(BatchNormalizationNode))
        {
            auto pNode = dynamic_pointer_cast<BatchNormalizationNode<float>>(node);
            if (!pNode)
            {
                auto pNode2 = dynamic_pointer_cast<BatchNormalizationNode<double>>(node);
                if (!pNode2)
                {
                    RuntimeError("Invalid node type: node name=%ls. We assume either BatchNormalizationNode<float> or BatchNormalizationNode<double>\n", node->NodeName().c_str());
                }
            }
            else
            {
                pNode->SetEvalMode(evalMode);
            }
        }
    }
}