// Replaces the node registered under nodeName with newNode; used in the KL-regularization-based
// adaptation to avoid copying features.
// Every reference to the old node is redirected:
//  - any node (in the name map) that had the old node as an input now has newNode as that input,
//  - nodeName in the name map now maps to newNode,
//  - newNode receives the old node's inputs,
//  - every entry of every group returned by GetAllNodeGroups() that held the old node now holds newNode.
// Raises InvalidArgument if newNode is null or its OperationName() differs from the old node's.
void ComputationNetwork::ChangeNode(wstring nodeName, ComputationNodeBasePtr newNode)
{
    // reject a null replacement up front instead of dereferencing it below
    if (!newNode)
        InvalidArgument("newNode must not be null.");

    InvalidateCompiledNetwork();

    ComputationNodeBasePtr oldNode = GetNodeFromName(nodeName);
    if (oldNode->OperationName() != newNode->OperationName())
        InvalidArgument("newNode must have the same type as the old node.");

    // rewire consumers: every input slot that pointed at oldNode now points at newNode
    for (const auto& nameAndNode : m_nameToNodeMap)
    {
        const ComputationNodeBasePtr& node = nameAndNode.second;
        for (size_t i = 0; i < node->GetNumInputs(); i++)
            if (node->GetInputs()[i] == oldNode)
                node->SetInput(i, newNode);
    }

    // register newNode under the old name and hand it the old node's inputs
    m_nameToNodeMap[nodeName] = newNode;
    for (size_t i = 0; i < oldNode->GetNumInputs(); i++)
        newNode->SetInput(i, oldNode->GetInputs()[i]);

    // replace oldNode in every node group as well
    for (auto groupPtr : GetAllNodeGroups())
    {
        auto& group = *groupPtr;
        for (size_t i = 0; i < group.size(); i++)
            if (group[i] == oldNode)
                group[i] = newNode;
    }
}
// sets m_learningRateMultiplier in all LearnableParameters feeding into the passed rootNode
// Called from MEL
void ComputationNetwork::SetLearnableNodesBelowLearningRateMultiplier(const float learningRateMultiplier, const ComputationNodeBasePtr& rootNode)
{
    // find nodes from all available nodes
    if (rootNode == nullptr)
    {
        for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); nodeIter++)
        {
            ComputationNodeBasePtr node = nodeIter->second;
            if (node->OperationName() == OperationNameOf(LearnableParameter))
                node->SetLearningRateMultiplier(learningRateMultiplier);
        }
    }
    else
    {
        // for calculating a specific node
        if (!EvalOrderExists(rootNode))
            const_cast<ComputationNetwork&>(*this).FormEvalOrder(rootNode);

        for (const auto& node : GetAllNodesForRoot(rootNode))
        {
            if (node->OperationName() == OperationNameOf(LearnableParameter))
                node->SetLearningRateMultiplier(learningRateMultiplier);
        }
    }
}
// sets m_parameterUpdateRequired in all LearnableParameters feeding into the passed rootNode
// Called from MEL  --TODO: correct?
void ComputationNetwork::SetLearnableNodesBelowNeedGradient(const bool needGradient, const ComputationNodeBasePtr& rootNode)
{
    // find nodes from all available nodes
    if (rootNode == nullptr)
    {
        for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); nodeIter++)
        {
            ComputationNodeBasePtr node = nodeIter->second;
            if (node->OperationName() == OperationNameOf(LearnableParameter))
                node->SetParameterUpdateRequired(needGradient);
        }
    }
    else
    {
        // for calculating a specific node
        for (const auto& node : GetEvalOrder(rootNode))
        {
            if (node->OperationName() == OperationNameOf(LearnableParameter))
                node->SetParameterUpdateRequired(needGradient);
        }
    }
}
// Depth-first walk that recovers the processing order of the nodes inside a recurrent loop.
// Nodes are appended to nodesStack in post-order: a node is pushed only after all of its
// same-loop inputs (those with an equal m_loopId) have been pushed.
// Traversal does not descend through nodes whose GetRecurrenceSteppingDirection() is nonzero
// (delay nodes); these are what break the recurrence.
// visited:  every node already reached by this walk.
// recStack: the nodes on the current recursion path; meeting one of them again means a cycle
//           without an intervening delay node, which is reported via LogicError.
// TODO: Once we only use the nested network for recurrent traversal, this will be no longer necessary.
void ComputationNetwork::DetermineLoopForwardOrder(unordered_set<ComputationNodeBasePtr>& visited,
                                                   unordered_set<ComputationNodeBasePtr>& recStack,
                                                   list<ComputationNodeBasePtr>& nodesStack,
                                                   ComputationNodeBasePtr cur)
{
    const bool alreadySeen = visited.count(cur) != 0;
    if (alreadySeen)
    {
        // reached again while still on the current path -> loop with no delay node in it
        const bool onCurrentPath = recStack.count(cur) != 0;
        if (onCurrentPath)
            LogicError("%ls %ls operation is part of an infinite loop that cannot be unrolled.", cur->NodeName().c_str(), cur->OperationName().c_str());
        return;
    }

    visited.insert(cur);
    recStack.insert(cur);

    const bool isDelayNode = GetRecurrenceSteppingDirection(cur) != 0;
    if (!isDelayNode)
    {
        const size_t inputCount = cur->GetNumInputs();
        for (size_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
        {
            const auto& input = cur->Input(inputIndex);
            if (input->m_loopId == cur->m_loopId) // stay inside the same loop
                DetermineLoopForwardOrder(visited, recStack, nodesStack, input);
        }
    }

    recStack.erase(cur);
    nodesStack.push_back(cur);
}