Beispiel #1
0
void DoCommands(const ConfigParameters& config, const shared_ptr<MPIWrapper>& mpi)
{
    ConfigArray command = config(L"command", "train");

    if (Globals::ShouldForceDeterministicAlgorithms())
        ForceDeterministicAlgorithmsOnCPU();
    else
    {
        // Setting specified number of threads.
        int numCPUThreads = config(L"numCPUThreads", "0");
        numCPUThreads = CPUMatrix<ElemType>::SetNumThreads(numCPUThreads);
        if (numCPUThreads > 0)
        {
            LOGPRINTF(stderr, "Using %d CPU threads.\n", numCPUThreads);
        }
    }

    bool progressTracing = config(L"progressTracing", false);

    // temporary hack to prevent users from failing due to a small breaking change related to the "truncated" flag (will be redone bigger and better some day)
    DisableLegacyUsage(config, command);

    // summarize command info upfront in the log and stdout
    size_t fullTotalMaxEpochs = 0;
    for (int i = 0; i < command.size(); i++)
    {
        // get the configuration parameters that match the command
        ConfigParameters commandParams(config(command[i]));
        ConfigArray action = commandParams("action", "train");

        // determine the action to perform, and do it
        for (int j = 0; j < action.size(); j++)
        {
            if (action[j] == "train" || action[j] == "trainRNN")
            {
                wstring modelPath = commandParams("modelPath");
                size_t maxEpochs = GetMaxEpochs(commandParams);
                if (progressTracing)
                {
                    LOGPRINTF(stderr, "CNTKModelPath: %ls\n", modelPath.c_str());
                    LOGPRINTF(stderr, "CNTKCommandTrainInfo: %s : %d\n", command[i].c_str(), (int)maxEpochs);
                }
                fullTotalMaxEpochs += maxEpochs;
            }
        }
    }
    if (progressTracing)
    {
        LOGPRINTF(stderr, "CNTKCommandTrainInfo: CNTKNoMoreCommands_Total : %d\n", (int)fullTotalMaxEpochs);
    }

    // set up progress tracing for compute cluster management
    if (progressTracing && (!mpi || mpi->IsMainNode()))
    {
        ProgressTracing::SetTracingFlag();
        ProgressTracing::TraceTotalNumberOfSteps(fullTotalMaxEpochs); // enable tracing, using this as the total number of epochs
    }

    size_t fullEpochsOffset = 0;

    // execute the commands
    for (int i = 0; i < command.size(); i++)
    {
        // get the configuration parameters that match the command
        const string thisCommand = command[i];
        ConfigParameters commandParams(config(thisCommand));
        ConfigArray action = commandParams("action", "train");
        int traceLevel = commandParams("traceLevel", "0");

        if (progressTracing && ((mpi == nullptr) || mpi->IsMainNode()))
        {
            ProgressTracing::SetStepOffset(fullEpochsOffset); // this is the epoch number that SGD will log relative to
        }

        // determine the action to perform, and do it
        for (int j = 0; j < action.size(); j++)
        {
            const string thisAction = action[j];

            // print a banner to visually separate each action in the log
            const char* delim = "##############################################################################";
            string showActionAs = thisCommand + " command (" + thisAction + " action)";
            fprintf(stderr, "\n");
            LOGPRINTF(stderr, "%s\n", delim);
            LOGPRINTF(stderr, "#%*s#\n", (int)(strlen(delim) - 2), "");
            LOGPRINTF(stderr, "# %s%*s #\n", showActionAs.c_str(), (int)(strlen(delim) - showActionAs.size() - 4), "");
            LOGPRINTF(stderr, "#%*s#\n", (int)(strlen(delim) - 2), "");
            LOGPRINTF(stderr, "%s\n\n", delim);

            if ((mpi == nullptr) || (commandstoRunOnAllRanks.find(thisAction) != commandstoRunOnAllRanks.end()) || mpi->IsMainNode())
            {
                if (thisAction == "train" || thisAction == "trainRNN")
                {
                    if (progressTracing)
                    {
                        LOGPRINTF(stderr, "CNTKCommandTrainBegin: %s\n", command[i].c_str());
                    }
                    DoTrain<ConfigParameters, ElemType>(commandParams);
                    if (progressTracing)
                    {
                        LOGPRINTF(stderr, "CNTKCommandTrainEnd: %s\n", command[i].c_str());
                    }
                    fullEpochsOffset += GetMaxEpochs(commandParams);
                }
                // TODO: Choose a clearer name.
                else if (thisAction == "pbn")
                {
                    DoEvalBN<ElemType>(commandParams);
                }
                else if (thisAction == "adapt")
                {
                    DoAdapt<ElemType>(commandParams);
                }
                else if (thisAction == "test" || thisAction == "eval")
                {
                    DoEval<ElemType>(commandParams);
                }
                else if (thisAction == "edit")
                {
                    DoEdit<ElemType>(commandParams);
                }
                else if (thisAction == "cv")
                {
                    DoCrossValidate<ElemType>(commandParams);
                }
                else if (thisAction == "write")
                {
                    DoWriteOutput<ElemType>(commandParams);
                }
                else if (thisAction == "devtest")
                {
                    TestCn<ElemType>(config); // for "devtest" action pass the root config instead
                }
                else if (thisAction == "dumpNodes" /*deprecated:*/ || thisAction == "dumpNode" || thisAction == "dumpnode")
                {
                    DoDumpNodes<ElemType>(commandParams);
                }
                else if (thisAction == "convertdbn")
                {
                    DoConvertFromDbn<ElemType>(commandParams);
                }
                else if (thisAction == "exportdbn")
                {
                    DoExportToDbn<ElemType>(commandParams);
                }
                else if (thisAction == "createLabelMap")
                {
                    DoCreateLabelMap<ElemType>(commandParams);
                }
                else if (thisAction == "writeWordAndClass")
                {
                    DoWriteWordAndClassInfo<ElemType>(commandParams);
                }
                else if (thisAction == "plot")
                {
                    DoTopologyPlot<ElemType>(commandParams);
                }
                else if (thisAction == "SVD")
                {
                    DoParameterSVD<ElemType>(commandParams);
                }
                else
                {
                    RuntimeError("unknown action: %s  in command set: %s", thisAction.c_str(), command[i].c_str());
                }
            }

            fprintf(stderr, "\n");
            if (traceLevel > 0)
            {
                LOGPRINTF(stderr, "Action \"%s\" complete.\n\n", thisAction.c_str());
            }

            NDLScript<ElemType> ndlScript;
            ndlScript.ClearGlobal(); // clear global macros between commands

            // Synchronize all ranks before proceeding to next action/command
            if (mpi)
                mpi->WaitAll();
        }
    }
}
Beispiel #2
0
void DoCommands(const ConfigParameters& config)
{
    ConfigArray command = config(L"command", "train");

    int numCPUThreads = config(L"numCPUThreads", "0");
    numCPUThreads = CPUMatrix<ElemType>::SetNumThreads(numCPUThreads);

    if (numCPUThreads > 0)
    {
        std::cerr << "Using " << numCPUThreads << " CPU threads" << endl;
    }

    bool progressTracing = config(L"progressTracing", false);

    // temporary hack to prevent users from failling for a small breaking change related to the "truncated" flag (will be redone bigger and better some day)
    DisableLegacyUsage(config, command);

    // summarize command info upfront in the log and stdout
    size_t fullTotalMaxEpochs = 0;
    for (int i = 0; i < command.size(); i++)
    {
        // get the configuration parameters that match the command
        ConfigParameters commandParams(config(command[i]));
        ConfigArray action = commandParams("action", "train");

        // determine the action to perform, and do it
        for (int j = 0; j < action.size(); j++)
        {
            if (action[j] == "train" || action[j] == "trainRNN")
            {
                wstring modelPath = commandParams("modelPath");
                std::wcerr << "CNTKModelPath: " << modelPath << endl;
                size_t maxEpochs = GetMaxEpochs(commandParams);
                std::cerr << "CNTKCommandTrainInfo: " + command[i] << " : " << maxEpochs << endl;
                fullTotalMaxEpochs += maxEpochs;
            }
        }
    }
    std::cerr << "CNTKCommandTrainInfo: CNTKNoMoreCommands_Total : " << fullTotalMaxEpochs << endl;

    // set up progress tracing for compute cluster management
    if (progressTracing && ((g_mpi == nullptr) || g_mpi->IsMainNode()))
    {
        ProgressTracing::TraceTotalNumberOfSteps(fullTotalMaxEpochs); // enable tracing, using this as the total number of epochs
    }

    size_t fullEpochsOffset = 0;

    // execute the commands
    for (int i = 0; i < command.size(); i++)
    {
        // get the configuration parameters that match the command
        ConfigParameters commandParams(config(command[i]));
        ConfigArray action = commandParams("action", "train");

        if (progressTracing && ((g_mpi == nullptr) || g_mpi->IsMainNode()))
        {
            ProgressTracing::SetStepOffset(fullEpochsOffset); // this is the epoch number that SGD will log relative to
        }

        // determine the action to perform, and do it
        for (int j = 0; j < action.size(); j++)
        {
            if (action[j] == "train" || action[j] == "trainRNN")
            {
                std::cerr << "CNTKCommandTrainBegin: " + command[i] << endl;
                DoTrain<ConfigParameters, ElemType>(commandParams);
                std::cerr << "CNTKCommandTrainEnd: " + command[i] << endl;
                fullEpochsOffset += GetMaxEpochs(commandParams);
            }
            else if (action[j] == "adapt")
            {
                DoAdapt<ElemType>(commandParams);
            }
            else if (action[j] == "test" || action[j] == "eval")
            {
                DoEval<ElemType>(commandParams);
            }
            else if (action[j] == "edit")
            {
                DoEdit<ElemType>(commandParams);
            }
            else if (action[j] == "cv")
            {
                DoCrossValidate<ElemType>(commandParams);
            }
            else if (action[j] == "write")
            {
                DoWriteOutput<ElemType>(commandParams);
            }
            else if (action[j] == "devtest")
            {
                TestCn<ElemType>(config); // for "devtest" action pass the root config instead
            }
            else if (action[j] == "dumpnode")
            {
                DumpNodeInfo<ElemType>(commandParams);
            }
            else if (action[j] == "convertdbn")
            {
                DoConvertFromDbn<ElemType>(commandParams);
            }
            else if (action[j] == "createLabelMap")
            {
                DoCreateLabelMap<ElemType>(commandParams);
            }
            else if (action[j] == "writeWordAndClass")
            {
                DoWriteWordAndClassInfo<ElemType>(commandParams);
            }
            else if (action[j] == "plot")
            {
                DoTopologyPlot<ElemType>(commandParams);
            }
            else if (action[j] == "SVD")
            {
                DoParameterSVD<ElemType>(commandParams);
            }
            else
            {
                RuntimeError("unknown action: %s  in command set: %s", action[j].c_str(), command[i].c_str());
            }

            NDLScript<ElemType> ndlScript;
            ndlScript.ClearGlobal(); // clear global macros between commands
        }
    }
}