Example #1
void DoTrain(const ConfigRecordType& config)
{
    bool makeMode = config(L"makeMode", true);
    DEVICEID_TYPE deviceId = DeviceFromConfig(config);
    int traceLevel = config(L"traceLevel", 0);

    shared_ptr<SGD<ElemType>> optimizer;
    if (config.Exists(L"optimizer"))
    {
        optimizer = CreateObject<SGD<ElemType>>(config, L"optimizer");
    }
    else // legacy CNTK config syntax: needs a record called 'SGD'
    {
        const ConfigRecordType& configSGD(config(L"SGD"));
        optimizer = make_shared<SGD<ElemType>>(configSGD);
    }

    // determine which epoch to start with, including recovering from a checkpoint if one exists and 'makeMode' is enabled
    int startEpoch = optimizer->DetermineStartEpoch(makeMode);
    if (startEpoch == optimizer->GetMaxEpochs())
    {
        LOGPRINTF(stderr, "No further training is necessary.\n");
        return;
    }

    wstring modelFileName = optimizer->GetModelNameForEpoch(int(startEpoch) - 1);
    bool loadNetworkFromCheckpoint = startEpoch >= 0;
    if (loadNetworkFromCheckpoint)
        LOGPRINTF(stderr, "\nStarting from checkpoint. Loading network from '%ls'.\n", modelFileName.c_str());
    else if (traceLevel > 0)
        LOGPRINTF(stderr, "\nCreating virgin network.\n");

    // determine the network-creation function
    // We have several ways to create that network.
    function<ComputationNetworkPtr(DEVICEID_TYPE)> createNetworkFn;

    createNetworkFn = GetNetworkFactory<ConfigRecordType, ElemType>(config);

    // create or load from checkpoint
    shared_ptr<ComputationNetwork> net = !loadNetworkFromCheckpoint ? createNetworkFn(deviceId) : ComputationNetwork::CreateFromFile<ElemType>(deviceId, modelFileName);

    auto dataReader = CreateObject<DataReader>(config, L"reader");

    shared_ptr<DataReader> cvDataReader;
    if (config.Exists(L"cvReader"))
        cvDataReader = CreateObject<DataReader>(config, L"cvReader");

    optimizer->InitMPI(MPIWrapper::GetInstance());
    optimizer->Train(net, deviceId, dataReader.get(), cvDataReader.get(), startEpoch, loadNetworkFromCheckpoint);
}
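The function above is template code (it refers to ConfigRecordType and ElemType). A minimal sketch of how a caller might dispatch it by element type, assuming a 'precision' config key and a hypothetical DoTrainDispatch helper (neither is taken from the original source):

template <class ConfigRecordType>
void DoTrainDispatch(const ConfigRecordType& config)
{
    // 'precision' selects the element type; 'float' is assumed as the default here
    wstring precision = config(L"precision", L"float");
    if (precision == L"float")
        DoTrain<ConfigRecordType, float>(config);
    else if (precision == L"double")
        DoTrain<ConfigRecordType, double>(config);
    else
        InvalidArgument("DoTrainDispatch: precision must be 'float' or 'double'");
}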
Example #2
void LibSVMBinaryReader<ElemType>::RenamedMatrices(const ConfigRecordType& config, std::map<std::wstring, std::wstring>& rename)
{
    for (const auto& id : config.GetMemberIds())
    {
        if (!config.CanBeConfigRecord(id))
            continue;
        const ConfigRecordType& temp = config(id);
        // if this sub-record contains a 'rename' element, record the requested name mapping
        if (temp.ExistsCurrent(L"rename"))
        {
            std::wstring ren = temp(L"rename");
            rename.emplace(msra::strfun::utf16(id), msra::strfun::utf16(ren));
        }
    }
}
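As a hedged illustration (the section name 'features' and the target name 'f0' are invented here, not taken from the original source): a reader section containing features = [ dim = 123 ; rename = "f0" ] would make the loop above add the pair (L"features", L"f0") to the rename map, which InitFromConfig later hands to the data input (see Example #4).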
Example #3
WriteFormattingOptions::WriteFormattingOptions(const ConfigRecordType& config) :
    WriteFormattingOptions()
{
    // gather additional formatting options
    if (config.Exists(L"format"))
    {
        const ConfigRecordType& formatConfig(config(L"format", ConfigRecordType::Record()));
        if (formatConfig.ExistsCurrent(L"type")) // do not inherit 'type' from outer block
        {
            wstring type = formatConfig(L"type");
            if      (type == L"real")     ; // default
            else if (type == L"category") isCategoryLabel = true;
            else if (type == L"sparse")   isSparse = true;
            else                         InvalidArgument("write: type must be 'real', 'category', or 'sparse'");
            labelMappingFile = (wstring)formatConfig(L"labelMappingFile", L"");
        }
        transpose = formatConfig(L"transpose", transpose);
        prologue  = formatConfig(L"prologue",  prologue);
        epilogue  = formatConfig(L"epilogue",  epilogue);
        sequenceSeparator = msra::strfun::utf8(formatConfig(L"sequenceSeparator", (wstring)msra::strfun::utf16(sequenceSeparator)));
        sequencePrologue  = msra::strfun::utf8(formatConfig(L"sequencePrologue",  (wstring)msra::strfun::utf16(sequencePrologue)));
        sequenceEpilogue  = msra::strfun::utf8(formatConfig(L"sequenceEpilogue",  (wstring)msra::strfun::utf16(sequenceEpilogue)));
        elementSeparator  = msra::strfun::utf8(formatConfig(L"elementSeparator",  (wstring)msra::strfun::utf16(elementSeparator)));
        sampleSeparator   = msra::strfun::utf8(formatConfig(L"sampleSeparator",   (wstring)msra::strfun::utf16(sampleSeparator)));
        precisionFormat   = msra::strfun::utf8(formatConfig(L"precisionFormat",   (wstring)msra::strfun::utf16(precisionFormat)));
        // TODO: change those strings into wstrings to avoid this conversion mess
    }
}
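The options above correspond to an optional 'format' sub-record of the write configuration. A hedged illustrative fragment (all values invented): format = [ type = "category" ; labelMappingFile = "labels.txt" ; transpose = false ]; the separators, prologue/epilogue strings, and precision format keep their defaults unless overridden.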
Example #4
void LibSVMBinaryReader<ElemType>::InitFromConfig(const ConfigRecordType& readerConfig)
{
    std::map<std::wstring, std::wstring> rename;
    RenamedMatrices(readerConfig, rename);

    if (readerConfig.Exists(L"randomize"))
    {
        string randomizeString = readerConfig(L"randomize");
        if (randomizeString == "None")
        {
            m_randomize = 0L;
        }
        else if (randomizeString == "Auto")
        {
            time_t rawtime;
            struct tm* timeinfo;
            time(&rawtime);
            timeinfo = localtime(&rawtime);
            m_randomize = (unsigned long) (timeinfo->tm_sec + timeinfo->tm_min * 60 + timeinfo->tm_hour * 60 * 60);
        }
        else
        {
            m_randomize = readerConfig(L"randomize", 0);
        }
    }
    else
    {
        m_randomize = 0L;
    }

    m_partialMinibatch = true;
    std::string minibatchMode(readerConfig(L"minibatchMode", "Partial"));
    m_partialMinibatch = EqualCI(minibatchMode, "Partial");

    std::wstring file = readerConfig(L"file", L"");

    m_dataInput = make_shared<SparseBinaryInput<ElemType>>(file);
    m_dataInput->Init(rename);

    m_mbSize = (size_t) readerConfig(L"minibatch", 0);
    if (m_mbSize > 0)
    {
        if (m_dataInput->GetMBSize() != m_mbSize)
        {
            RuntimeError("Data file and config file have mismatched minibatch sizes.\n");
            return;
        }
    }
    else
    {
        m_mbSize = m_dataInput->GetMBSize();
    }

    m_prefetchEnabled = true;
}
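Taken together, the keys read above suggest a reader section along the lines of reader = [ file = "train.bin" ; randomize = "Auto" ; minibatchMode = "Partial" ] (a hedged illustration; names and values are invented). 'randomize' accepts None, Auto, or an explicit number, and a 'minibatch' entry, if given, must match the minibatch size recorded in the data file.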
Example #5
bool TryGetNetworkFactory(const ConfigRecordType& config, function<ComputationNetworkPtr(DEVICEID_TYPE)>& createNetworkFn)
{
    DEVICEID_TYPE deviceId = DeviceFromConfig(config);

    int traceLevel = config(L"traceLevel", 0);
    if (config.Exists(L"createNetwork"))
    {
        createNetworkFn = GetCreateNetworkFn(config); // (a separate function is needed due to template code)
        return true;
    }
    else if (config.Exists(L"SimpleNetworkBuilder"))
    {
        const ConfigRecordType& simpleNetworkBuilderConfig(config(L"SimpleNetworkBuilder"));
        auto netBuilder = make_shared<SimpleNetworkBuilder<ElemType>>(simpleNetworkBuilderConfig); // parses the configuration and stores it in the SimpleNetworkBuilder object
        createNetworkFn = [netBuilder, traceLevel](DEVICEID_TYPE deviceId)
        {
            auto net = shared_ptr<ComputationNetwork>(netBuilder->BuildNetworkFromDescription()); // this operates based on the configuration saved above
            net->SetTraceLevel(traceLevel);
            return net;
        };
        return true;
    }
    // legacy NDL
    else if (config.Exists(L"NDLNetworkBuilder"))
    {
        const ConfigRecordType& ndlNetworkBuilderConfig(config(L"NDLNetworkBuilder"));
        shared_ptr<NDLBuilder<ElemType>> netBuilder = make_shared<NDLBuilder<ElemType>>(ndlNetworkBuilderConfig);
        createNetworkFn = [netBuilder, traceLevel](DEVICEID_TYPE deviceId)
        {
            auto net = shared_ptr<ComputationNetwork>(netBuilder->BuildNetworkFromDescription());
            net->SetTraceLevel(traceLevel);
            return net;
        };
        return true;
    }
    // legacy test mode for BrainScript. Will go away once we fully integrate with BS.
    else if (config.Exists(L"BrainScriptNetworkBuilder") || config.Exists(L"ExperimentalNetworkBuilder" /*legacy name*/))
    {
        // We interface with outer old CNTK config by taking the inner part, which we get as a string, as BrainScript.
        // We prepend a few standard definitions, and also definition of deviceId and precision, which all objects will pull out again when they are being constructed.
        // BUGBUG: We are not getting TextLocations right in this way! Do we need to inject location markers into the source? Moot once we fully switch to BS
        wstring sourceOfNetwork = config.Exists(L"BrainScriptNetworkBuilder") ? config(L"BrainScriptNetworkBuilder") : config(L"ExperimentalNetworkBuilder");
        if (sourceOfNetwork.find_first_of(L"([{") != 0)
            InvalidArgument("BrainScript network description must be either a BS expression in ( ) or a config record in { }");

        // set the include paths to all paths that configs were read from; no additional configurable include paths are supported by BrainScriptNetworkBuilder
        auto includePaths = ConfigParameters::GetBrainScriptNetworkBuilderIncludePaths();

        // inject additional items into the source code
        // We support two ways of specifying the network in BrainScript:
        //  - BrainScriptNetworkBuilder = ( any BS expression that evaluates to a ComputationNetwork )
        //  - BrainScriptNetworkBuilder = { constructor parameters for a ComputationNetwork }
        // For back-compat, [ ] is allowed and means the same as { }
        if (sourceOfNetwork[0] == '{' || sourceOfNetwork[0] == '[') // if { } form then we turn it into ComputationNetwork by constructing a ComputationNetwork from it
            sourceOfNetwork = L"new ComputationNetwork " + sourceOfNetwork;
        let sourceOfBS = msra::strfun::wstrprintf(L"include \'cntk.core.bs\'\n" // include our core lib. Note: Using lowercase here to match the Linux name of the CNTK exe.
            L"deviceId = %d\n"            // deviceId as passed in
            L"traceLevel = %d\n"
            L"precision = '%ls'\n"        // 'float' or 'double'
            L"network = %ls",             // source code of expression that evaluates to a ComputationNetwork
            (int)deviceId, traceLevel, ElemTypeName<ElemType>(), sourceOfNetwork.c_str());
        let expr = BS::ParseConfigDictFromString(sourceOfBS, L"BrainScriptNetworkBuilder", move(includePaths));

        // the rest is done in a lambda that is only evaluated when a virgin network is needed
        // Note that evaluating the BrainScript *is* instantiating the network, so the evaluate call must be inside the lambda.
        createNetworkFn = [expr](DEVICEID_TYPE /*deviceId*/)
        {
            // evaluate the parse tree, particularly the top-level field 'network'
            // Evaluating it will create the network.
            let object = EvaluateField(expr, L"network");                   // this comes back as a BS::Object
            let network = dynamic_pointer_cast<ComputationNetwork>(object); // cast it
            if (!network)
                LogicError("BuildNetworkFromDescription: ComputationNetwork not what it was meant to be");
            return network;
        };
        return true;
    }
    else
        return false;
}
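A minimal sketch of consuming this factory (the helper name BuildFreshNetwork is hypothetical, and it assumes TryGetNetworkFactory is templated on ConfigRecordType and ElemType, as its body suggests):

template <class ConfigRecordType, typename ElemType>
shared_ptr<ComputationNetwork> BuildFreshNetwork(const ConfigRecordType& config)
{
    function<ComputationNetworkPtr(DEVICEID_TYPE)> createNetworkFn;
    if (!TryGetNetworkFactory<ConfigRecordType, ElemType>(config, createNetworkFn))
        InvalidArgument("no network-builder section (createNetwork, SimpleNetworkBuilder, NDLNetworkBuilder, or BrainScriptNetworkBuilder) found in the configuration");
    // the factory creates the network only when invoked, on the device chosen by the config
    return createNetworkFn(DeviceFromConfig(config));
}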
Example #6
void DSSMReader<ElemType>::InitFromConfig(const ConfigRecordType& readerConfig)
{
    std::vector<std::wstring> features;
    std::vector<std::wstring> labels;

    // Determine the names of the features and labels sections in the config file.
    // features - [in,out] a vector of feature name strings
    // labels - [in,out] a vector of label name strings
    // For the DSSM dataset, we only need features; no labels are necessary. The following "labels" just serves as a placeholder
    GetFileConfigNames(readerConfig, features, labels);

    // The DSSM dataset must have exactly two features.
    // In the config file, the query features must be specified first, then the document features;
    // note that the assignment below is in the opposite order (features[1] becomes the query, features[0] the document).
    if (features.size() == 2 && labels.size() == 1)
    {
        m_featuresNameQuery = features[1];
        m_featuresNameDoc = features[0];
        m_labelsName = labels[0];
    }
    else
    {
        RuntimeError("DSSM requires exactly two features and one label. Their names should match those in NDL definition");
        return;
    }

    m_mbStartSample = m_epoch = m_totalSamples = m_epochStartSample = 0;
    m_labelIdMax = m_labelDim = 0;
    m_partialMinibatch = m_endReached = false;
    m_labelType = labelCategory;
    m_readNextSample = 0;
    m_traceLevel = readerConfig(L"traceLevel", 0);

    if (readerConfig.Exists(L"randomize"))
    {
        // BUGBUG: reading out string and number... ugh
        wstring randomizeString = readerConfig(L"randomize");
        if (randomizeString == L"None")
        {
            m_randomizeRange = randomizeNone;
        }
        else if (randomizeString == L"Auto")
        {
            m_randomizeRange = randomizeAuto;
        }
        else
        {
            m_randomizeRange = readerConfig(L"randomize");
        }
    }
    else
    {
        m_randomizeRange = randomizeNone;
    }

    std::string minibatchMode(readerConfig(L"minibatchMode", "Partial"));
    m_partialMinibatch = EqualCI(minibatchMode, "Partial");

    // Get the config parameters for query feature and doc feature
    ConfigParameters configFeaturesQuery = readerConfig(m_featuresNameQuery, "");
    ConfigParameters configFeaturesDoc   = readerConfig(m_featuresNameDoc, "");

    if (configFeaturesQuery.size() == 0)
        RuntimeError("features file not found, required in configuration: i.e. 'features=[file=c:\\myfile.txt;start=1;dim=123]'");
    if (configFeaturesDoc.size() == 0)
        RuntimeError("features file not found, required in configuration: i.e. 'features=[file=c:\\myfile.txt;start=1;dim=123]'");

    // Read in feature size information
    // This information will be used to handle OOVs
    m_featuresDimQuery = configFeaturesQuery(L"dim");
    m_featuresDimDoc   = configFeaturesDoc(L"dim");

    std::wstring fileQ = configFeaturesQuery("file");
    std::wstring fileD = configFeaturesDoc("file");

    dssm_queryInput.Init(fileQ, m_featuresDimQuery);
    dssm_docInput.Init(fileD, m_featuresDimDoc);

    m_totalSamples = dssm_queryInput.numRows;
    if (read_order == NULL)
    {
        read_order = new int[m_totalSamples];
        for (int c = 0; c < m_totalSamples; c++)
        {
            read_order[c] = c;
        }
    }
    m_mbSize = 0;
}
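The feature sections this reader expects can be sketched (names and values invented, in the style of the error message above) as query = [ file = "query.txt" ; start = 1 ; dim = 123 ] and doc = [ file = "doc.txt" ; start = 1 ; dim = 123 ]: each feature section must provide at least a 'file' and a 'dim' entry, while the label section carries no data of its own and only serves to make GetFileConfigNames return exactly two features and one label.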