// Replaces the final-criterion node named oldNodeName with newNode.
//
// Each input of newNode is re-pointed, by name, to the node of the same name that is
// already registered in m_nameToNodeMap, so every input of newNode must name an existing
// node of this network. newNode then takes over the old node's slot in m_finalCriteria
// and is registered in m_nameToNodeMap under its own name.
// NOTE(review): if newNode's name differs from oldNodeName, the old node's entry in
// m_nameToNodeMap is left in place — confirm that this is intended.
//
// Raises RuntimeError if oldNodeName is not a final criterion node, if an input of newNode
// is null, or if an input's name is not registered in this network.
void ComputationNetwork::ReplaceFinalCriterionNode(wstring oldNodeName, ComputationNodeBasePtr newNode)
{
    InvalidateCompiledNetwork();

    // Locate the criterion slot to overwrite; size() acts as the "not found" sentinel.
    size_t index = m_finalCriteria.size();
    for (size_t i = 0; i < m_finalCriteria.size(); ++i)
    {
        if (m_finalCriteria[i]->NodeName() == oldNodeName)
        {
            index = i;
            break;
        }
    }
    if (index == m_finalCriteria.size())
        RuntimeError("ReplaceFinalCriterionNode: the node to be replaced is not a criterion node.");

    // Re-point every input of newNode to this network's node of the same name.
    for (size_t i = 0; i < newNode->GetNumInputs(); ++i)
    {
        // Copy (not reference) the current input: SetInput() below overwrites that slot.
        const ComputationNodeBasePtr input = newNode->GetInputs()[i];
        if (!input)
            RuntimeError("ReplaceFinalCriterionNode: an input of the new criterion node is null.");

        // Single lookup; the iterator is reused instead of a second operator[] search.
        auto found = m_nameToNodeMap.find(input->NodeName());
        if (found == m_nameToNodeMap.end())
            RuntimeError("Child node does not exist.");
        newNode->SetInput(i, found->second);
    }

    // Adds it to the criterion node list and the name map.
    m_finalCriteria[index] = newNode;
    m_nameToNodeMap[newNode->NodeName()] = newNode;
}
// Detaches featureNode from the network without destroying it: every input slot of every
// registered node that points at featureNode is set to null, and featureNode is removed
// from m_features and from m_nameToNodeMap. A caller that still holds featureNode keeps it
// alive ("we only remove the node, not delete it").
// NOTE(review): consumers of featureNode are left with null input slots.
//
// Raises RuntimeError if no node with featureNode's name is registered.
void ComputationNetwork::RemoveFeatureNode(ComputationNodeBasePtr featureNode)
{
    InvalidateCompiledNetwork();

    wstring nodeName = featureNode->NodeName();
    if (!NodeNameExists(nodeName))
        RuntimeError("RemoveFeatureNode: feature node does not exist.");

    // Removes links. A node may consume the same input in more than one slot (e.g. a
    // product of x with itself), so every matching slot is cleared — stopping at the
    // first match would leave a dangling link to the removed node.
    for (auto& nameAndNode : m_nameToNodeMap)
    {
        const ComputationNodeBasePtr& node = nameAndNode.second;
        for (size_t i = 0; i < node->GetNumInputs(); ++i)
        {
            if (node->GetInputs()[i] == featureNode)
                node->SetInput(i, nullptr);
        }
    }

    // Removes from feature list.
    auto search = std::find(m_features.begin(), m_features.end(), featureNode);
    if (search != m_features.end())
        m_features.erase(search);

    m_nameToNodeMap.erase(nodeName);
}
// Renames node to newNodeName and re-registers it in m_nameToNodeMap under the new name
// (via AddNodeToNet()).
// Raises RuntimeError, before changing anything, if newNodeName is already registered for
// a *different* node. Renaming a node to the name it is already registered under is allowed.
void ComputationNetwork::RenameNode(ComputationNodeBasePtr node, const std::wstring& newNodeName)
{
    // Check up front: discovering a clash only after erasing the old entry and renaming the
    // node would leave the network in an inconsistent state. Comparing the mapped pointer
    // (rather than the names) keeps self-renames working whatever the map's key comparison is.
    auto existing = m_nameToNodeMap.find(newNodeName);
    if (existing != m_nameToNodeMap.end() && existing->second != node)
        RuntimeError("RenameNode: a node named '%ls' already exists.", newNodeName.c_str());

    m_nameToNodeMap.erase(node->NodeName());
    node->SetNodeName(newNodeName);
    AddNodeToNet(node);
}
// Replaces the node named nodeName by newNode, which must carry the same name, moving over
// all network links:
//  - every node that uses the old node as an input is re-pointed to newNode (ChangeNodeInputs());
//  - newNode takes over the old node's inputs;
//  - the old node is removed from the network and newNode added in its place;
//  - every node group returned by GetAllNodeGroups() that lists the old node lists newNode instead.
// Used to
//  1. update nodes to their quantized versions;
//  2. the KL-reg based adaptation that reduces feature copies (deprecated).
// Raises InvalidArgument if newNode's name differs from nodeName.
void ComputationNetwork::ReplaceNode(wstring nodeName, ComputationNodeBasePtr newNode)
{
    ComputationNodeBasePtr oldNode = GetNodeFromName(nodeName);

    if (newNode->NodeName() != nodeName) // TODO: This was not tested for earlier; I hope no code depends on this.
        InvalidArgument("ReplaceNode: newNode must have the same name as the old node.");

    InvalidateCompiledNetwork();

    // change all nodes that have old node as input to point to the new node instead
    ChangeNodeInputs(oldNode, newNode);

    // change all inputs of this new node to share the old one's inputs
    for (size_t i = 0; i < oldNode->GetNumInputs(); i++)
    {
        newNode->SetInput(i, oldNode->GetInputs()[i]); // TODO: use AttachInput()?
        //oldNode->SetInput(i, nullptr); // BUGBUG: old node should no longer point into the network
    }

    // replace the node in the network
    RemoveNodeFromNet(oldNode);
    AddNodeToNet(newNode);

    // also update node groups
    for (auto groupIter : GetAllNodeGroups())
    {
        auto& group = *groupIter;
        for (size_t i = 0; i < group.size(); i++)
            if (group[i] == oldNode)
                group[i] = newNode;
    }
}
// Registers featureNode as a new feature node: records it in m_nameToNodeMap under its own
// name and appends it to m_features. The compiled network is invalidated first, even if
// the call then fails.
// Raises RuntimeError if a node with the same name is already registered.
void ComputationNetwork::AddFeatureNode(ComputationNodeBasePtr featureNode)
{
    InvalidateCompiledNetwork();

    const wstring featureName = featureNode->NodeName();
    if (NodeNameExists(featureName))
        RuntimeError("AddFeatureNode: feature node already exists.");

    // Register by name, then list it among the feature nodes.
    m_nameToNodeMap[featureName] = featureNode;
    m_features.push_back(featureNode);
}
// Example #6
// 0
// TODO: how does the file distinguish float vs double nodes?
//
// Serializes the network to fileName using fileFormat (opened for writing).
// Layout, in write order, all inside a "BCN" ... "ECN" section:
//   1. "BVersion"/"EVersion": CURRENT_CNTK_MODEL_VERSION as size_t.
//   2. The number of nodes in m_nameToNodeMap, as size_t.
//   3. "BNodeList"/"ENodeList": each node's own Save() output, in m_nameToNodeMap order.
//   4. "BRelation"/"ERelation": for each node (same order), its name, its input count,
//      then the names of its inputs.
//   5. "BRootNodes"/"ERootNodes", containing, in this order, the feature, label,
//      criterion, eval and output node lists — each as a count followed by node names.
// The write order is the file format; do not reorder these writes.
void ComputationNetwork::SaveToFileImpl(const wstring& fileName, const FileOptions fileFormat) const
{
    File fstream(fileName, fileFormat | FileOptions::fileOptionsWrite);
    fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BCN");

    // model version
    fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BVersion");
    fstream << (size_t) CURRENT_CNTK_MODEL_VERSION;
    fstream.PutMarker(FileMarker::fileMarkerEndSection, L"EVersion");

    // total node count, written outside any named section
    fstream << (size_t) m_nameToNodeMap.size();

    // put all node info first
    fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BNodeList");
    for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); nodeIter++)
    {
        ComputationNodeBasePtr nodePtr = nodeIter->second;
        nodePtr->Save(fstream);
    }

    fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ENodeList");

    // put relationship: per node, its name, its input count, then its input names
    fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BRelation");
    for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); nodeIter++)
    {
        ComputationNodeBasePtr nodePtr = nodeIter->second;
        fstream << nodePtr->NodeName() << nodePtr->GetNumInputs();
        for (size_t i = 0; i < nodePtr->GetNumInputs(); i++)
        {
            // NOTE(review): for a null input only a warning is printed and no name is written,
            // although GetNumInputs() was already written as the count above. A reader that
            // expects that many names would likely misparse the rest of the file — TODO confirm
            // against the loader.
            if (!nodePtr->Input(i))
                fprintf(stderr, "Warning: node %ls 's child is null, please check your ndl/mel file.\n", nodePtr->NodeName().c_str());
            else
                fstream << nodePtr->Input(i)->NodeName();
        }
    }
    fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ERelation");

    // root node lists: each is a count followed by that many node names
    fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BRootNodes");

    fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BFeatureNodes");
    fstream << m_features.size();
    for (size_t i = 0; i < m_features.size(); i++)
        fstream << m_features[i]->NodeName();
    fstream.PutMarker(FileMarker::fileMarkerEndSection, L"EFeatureNodes");

    fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BLabelNodes");
    fstream << m_labels.size();
    for (size_t i = 0; i < m_labels.size(); i++)
        fstream << m_labels[i]->NodeName();
    fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ELabelNodes");

    fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BCriterionNodes");
    fstream << m_finalCriteria.size();
    for (size_t i = 0; i < m_finalCriteria.size(); i++)
        fstream << m_finalCriteria[i]->NodeName();
    fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ECriterionNodes");

    fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BEvalNodes");
    fstream << m_evalNodes.size();
    for (size_t i = 0; i < m_evalNodes.size(); i++)
        fstream << m_evalNodes[i]->NodeName();
    fstream.PutMarker(FileMarker::fileMarkerEndSection, L"EEvalNodes");

    fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BOutputNodes");
    fstream << m_outputNodes.size();
    for (size_t i = 0; i < m_outputNodes.size(); i++)
    {
        fstream << m_outputNodes[i]->NodeName();
    }
    fstream.PutMarker(FileMarker::fileMarkerEndSection, L"EOutputNodes");

    fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ERootNodes");

    fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ECN");

    fstream.Flush();
}
// Recovers the processing order within a recurrent loop.
// Depth-first traversal starting at `cur` that follows only inputs whose m_loopId equals the
// m_loopId of the node being expanded. Nodes are appended to `nodesStack` in post-order:
// a node is pushed only after the same-loop inputs reached from it.
// The traversal does not descend through a node whose GetRecurrenceSteppingDirection() is
// non-zero (recurrence stops at delay nodes).
//   visited   - nodes already entered; shared across the whole traversal.
//   recStack  - nodes on the current DFS path. Reaching one of them again means a cycle with
//               no delay node on it, which is reported through LogicError.
// TODO: Once we only use the nested network for recurrent traversal, this will be no longer necessary.
void ComputationNetwork::DetermineLoopForwardOrder(unordered_set<ComputationNodeBasePtr>& visited,
                                                   unordered_set<ComputationNodeBasePtr>& recStack,
                                                   list<ComputationNodeBasePtr>& nodesStack,
                                                   ComputationNodeBasePtr cur)
{
    const bool alreadyVisited = visited.find(cur) != visited.end();
    if (alreadyVisited)
    {
        // Seen before: harmless if finished, an unbreakable cycle if still on the current path.
        const bool onCurrentPath = recStack.find(cur) != recStack.end();
        if (onCurrentPath)
            LogicError("%ls %ls operation is part of an infinite loop that cannot be unrolled.", cur->NodeName().c_str(), cur->OperationName().c_str());
        return;
    }

    visited.insert(cur);
    recStack.insert(cur);

    const bool stopsRecurrence = GetRecurrenceSteppingDirection(cur) != 0; // recurrence stops at delay nodes
    if (!stopsRecurrence)
    {
        for (size_t inputIndex = 0; inputIndex < cur->GetNumInputs(); ++inputIndex)
        {
            ComputationNodeBasePtr input = cur->Input(inputIndex);
            if (input->m_loopId == cur->m_loopId)
                DetermineLoopForwardOrder(visited, recStack, nodesStack, input);
        }
    }

    // Leaving this node: take it off the path and emit it after its inputs.
    recStack.erase(cur);
    nodesStack.push_back(cur);
}