void DoTrain(const ConfigRecordType& config) { bool makeMode = config(L"makeMode", true); DEVICEID_TYPE deviceId = DeviceFromConfig(config); int traceLevel = config(L"traceLevel", 0); shared_ptr<SGD<ElemType>> optimizer; if (config.Exists(L"optimizer")) { optimizer = CreateObject<SGD<ElemType>>(config, L"optimizer"); } else // legacy CNTK config syntax: needs a record called 'SGD' { const ConfigRecordType& configSGD(config(L"SGD")); optimizer = make_shared<SGD<ElemType>>(configSGD); } // determine which epoch to start with, including recovering a checkpoint if any and 'makeMode' enabled int startEpoch = optimizer->DetermineStartEpoch(makeMode); if (startEpoch == optimizer->GetMaxEpochs()) { LOGPRINTF(stderr, "No further training is necessary.\n"); return; } wstring modelFileName = optimizer->GetModelNameForEpoch(int(startEpoch) - 1); bool loadNetworkFromCheckpoint = startEpoch >= 0; if (loadNetworkFromCheckpoint) LOGPRINTF(stderr, "\nStarting from checkpoint. Loading network from '%ls'.\n", modelFileName.c_str()); else if (traceLevel > 0) LOGPRINTF(stderr, "\nCreating virgin network.\n"); // determine the network-creation function // We have several ways to create that network. function<ComputationNetworkPtr(DEVICEID_TYPE)> createNetworkFn; createNetworkFn = GetNetworkFactory<ConfigRecordType, ElemType>(config); // create or load from checkpoint shared_ptr<ComputationNetwork> net = !loadNetworkFromCheckpoint ? createNetworkFn(deviceId) : ComputationNetwork::CreateFromFile<ElemType>(deviceId, modelFileName); auto dataReader = CreateObject<DataReader>(config, L"reader"); shared_ptr<DataReader> cvDataReader; if (config.Exists(L"cvReader")) cvDataReader = CreateObject<DataReader>(config, L"cvReader"); optimizer->InitMPI(MPIWrapper::GetInstance()); optimizer->Train(net, deviceId, dataReader.get(), cvDataReader.get(), startEpoch, loadNetworkFromCheckpoint); }
void LibSVMBinaryReader<ElemType>::RenamedMatrices(const ConfigRecordType& config, std::map<std::wstring, std::wstring>& rename)
{
    // Collect renaming requests from the reader config.
    // Every member of 'config' that can be read as a sub-record and that itself (ExistsCurrent,
    // i.e. not inherited from an outer block) has a 'rename' entry contributes the pair
    //     <member name> -> <value of 'rename'>
    // to 'rename'. std::map::emplace leaves an already-present key unchanged.
    for (const auto& memberId : config.GetMemberIds())
    {
        if (config.CanBeConfigRecord(memberId))
        {
            const ConfigRecordType& subRecord = config(memberId);
            if (subRecord.ExistsCurrent(L"rename"))
            {
                std::wstring newName = subRecord(L"rename");
                rename.emplace(msra::strfun::utf16(memberId), msra::strfun::utf16(newName));
            }
        }
    }
}
WriteFormattingOptions::WriteFormattingOptions(const ConfigRecordType& config) :
    WriteFormattingOptions()
{
    // Starts from the defaults set by the delegated constructor and overrides them with the
    // entries of an optional 'format' sub-record. Without 'format', nothing changes.
    if (!config.Exists(L"format"))
        return;

    const ConfigRecordType& formatConfig(config(L"format", ConfigRecordType::Record()));

    // 'type' (and with it 'labelMappingFile') is only honored when given in this very block;
    // ExistsCurrent does not inherit 'type' from an outer block.
    if (formatConfig.ExistsCurrent(L"type"))
    {
        wstring type = formatConfig(L"type");
        if (type == L"category")
            isCategoryLabel = true;
        else if (type == L"sparse")
            isSparse = true;
        else if (type != L"real") // 'real' is the default and needs no change
            InvalidArgument("write: type must be 'real', 'category', or 'sparse'");
        labelMappingFile = (wstring)formatConfig(L"labelMappingFile", L"");
    }

    transpose = formatConfig(L"transpose", transpose);
    prologue = formatConfig(L"prologue", prologue);
    epilogue = formatConfig(L"epilogue", epilogue);

    // The remaining options are narrow strings that are read as wide strings: the current value is
    // converted to UTF-16 to serve as the default, and the result is converted back to UTF-8.
    // TODO: change those strings into wstrings to avoid this conversion mess
    auto readAsWide = [&formatConfig](const wchar_t* key, const std::string& current)
    {
        return msra::strfun::utf8(formatConfig(key, (wstring)msra::strfun::utf16(current)));
    };
    sequenceSeparator = readAsWide(L"sequenceSeparator", sequenceSeparator);
    sequencePrologue = readAsWide(L"sequencePrologue", sequencePrologue);
    sequenceEpilogue = readAsWide(L"sequenceEpilogue", sequenceEpilogue);
    elementSeparator = readAsWide(L"elementSeparator", elementSeparator);
    sampleSeparator = readAsWide(L"sampleSeparator", sampleSeparator);
    precisionFormat = readAsWide(L"precisionFormat", precisionFormat);
}
void LibSVMBinaryReader<ElemType>::InitFromConfig(const ConfigRecordType& readerConfig)
{
    // Configures the reader from 'readerConfig':
    //  - 'rename' entries of sub-records (see RenamedMatrices) are passed to the data input;
    //  - 'randomize': absent or "None" -> 0, "Auto" -> seed derived from the local time of day,
    //    anything else -> read as a number;
    //  - 'minibatchMode': m_partialMinibatch is true iff it equals "Partial" (case-insensitive, default);
    //  - 'file': path of the sparse binary input;
    //  - 'minibatch': if > 0 it must match the minibatch size stored in the data file, otherwise
    //    the data file's minibatch size is used.
    std::map<std::wstring, std::wstring> rename;
    RenamedMatrices(readerConfig, rename);

    if (readerConfig.Exists(L"randomize"))
    {
        string randomizeString = readerConfig(L"randomize");
        if (randomizeString == "None")
        {
            m_randomize = 0L;
        }
        else if (randomizeString == "Auto")
        {
            time_t rawtime;
            time(&rawtime);
            const struct tm* timeinfo = localtime(&rawtime);
            // localtime() returns NULL when the time cannot be converted; do not dereference it then.
            if (timeinfo != nullptr)
                m_randomize = static_cast<unsigned long>(timeinfo->tm_sec + timeinfo->tm_min * 60 + timeinfo->tm_hour * 60 * 60);
            else
                m_randomize = static_cast<unsigned long>(rawtime);
        }
        else
        {
            m_randomize = readerConfig(L"randomize", 0);
        }
    }
    else
    {
        m_randomize = 0L;
    }

    std::string minibatchMode(readerConfig(L"minibatchMode", "Partial"));
    m_partialMinibatch = EqualCI(minibatchMode, "Partial");

    std::wstring file = readerConfig(L"file", L"");
    m_dataInput = make_shared<SparseBinaryInput<ElemType>>(file);
    m_dataInput->Init(rename);

    m_mbSize = static_cast<size_t>(readerConfig(L"minibatch", 0));
    if (m_mbSize > 0)
    {
        // RuntimeError throws, so no explicit return is needed afterwards.
        if (m_dataInput->GetMBSize() != m_mbSize)
            RuntimeError("Data file and config file have mismatched minibatch sizes.\n");
    }
    else
    {
        m_mbSize = m_dataInput->GetMBSize();
    }

    m_prefetchEnabled = true;
}
bool TryGetNetworkFactory(const ConfigRecordType& config, function<ComputationNetworkPtr(DEVICEID_TYPE)>& createNetworkFn) { DEVICEID_TYPE deviceId = DeviceFromConfig(config); int traceLevel = config(L"traceLevel", 0); if (config.Exists(L"createNetwork")) { createNetworkFn = GetCreateNetworkFn(config); // (we need a separate function needed due to template code) return true; } else if (config.Exists(L"SimpleNetworkBuilder")) { const ConfigRecordType& simpleNetworkBuilderConfig(config(L"SimpleNetworkBuilder")); auto netBuilder = make_shared<SimpleNetworkBuilder<ElemType>>(simpleNetworkBuilderConfig); // parses the configuration and stores it in the SimpleNetworkBuilder object createNetworkFn = [netBuilder, traceLevel](DEVICEID_TYPE deviceId) { auto net = shared_ptr<ComputationNetwork>(netBuilder->BuildNetworkFromDescription()); // this operates based on the configuration saved above net->SetTraceLevel(traceLevel); return net; }; return true; } // legacy NDL else if (config.Exists(L"NDLNetworkBuilder")) { const ConfigRecordType& ndlNetworkBuilderConfig(config(L"NDLNetworkBuilder")); shared_ptr<NDLBuilder<ElemType>> netBuilder = make_shared<NDLBuilder<ElemType>>(ndlNetworkBuilderConfig); createNetworkFn = [netBuilder, traceLevel](DEVICEID_TYPE deviceId) { auto net = shared_ptr<ComputationNetwork>(netBuilder->BuildNetworkFromDescription()); net->SetTraceLevel(traceLevel); return net; }; return true; } // legacy test mode for BrainScript. Will go away once we fully integrate with BS. else if (config.Exists(L"BrainScriptNetworkBuilder") || config.Exists(L"ExperimentalNetworkBuilder" /*legacy name*/)) { // We interface with outer old CNTK config by taking the inner part, which we get as a string, as BrainScript. // We prepend a few standard definitions, and also definition of deviceId and precision, which all objects will pull out again when they are being constructed. // BUGBUG: We are not getting TextLocations right in this way! 
Do we need to inject location markers into the source? Moot once we fully switch to BS wstring sourceOfNetwork = config.Exists(L"BrainScriptNetworkBuilder") ? config(L"BrainScriptNetworkBuilder") : config(L"ExperimentalNetworkBuilder"); if (sourceOfNetwork.find_first_of(L"([{") != 0) InvalidArgument("BrainScript network description must be either a BS expression in ( ) or a config record in { }"); // set the include paths to all paths that configs were read from; no additional configurable include paths are supported by BrainScriptNetworkBuilder auto includePaths = ConfigParameters::GetBrainScriptNetworkBuilderIncludePaths(); // inject additional items into the source code // We support two ways of specifying the network in BrainScript: // - BrainScriptNetworkBuilder = ( any BS expression that evaluates to a ComputationNetwork ) // - BrainScriptNetworkBuilder = { constructor parameters for a ComputationNetwork } // For back-compat, [ ] is allowed and means the same as { } if (sourceOfNetwork[0] == '{' || sourceOfNetwork[0] == '[') // if { } form then we turn it into ComputationNetwork by constructing a ComputationNetwork from it sourceOfNetwork = L"new ComputationNetwork " + sourceOfNetwork; let sourceOfBS = msra::strfun::wstrprintf(L"include \'cntk.core.bs\'\n" // include our core lib. Note: Using lowercase here to match the Linux name of the CNTK exe. L"deviceId = %d\n" // deviceId as passed in L"traceLevel = %d\n" L"precision = '%ls'\n" // 'float' or 'double' L"network = %ls", // source code of expression that evaluates to a ComputationNetwork (int)deviceId, traceLevel, ElemTypeName<ElemType>(), sourceOfNetwork.c_str()); let expr = BS::ParseConfigDictFromString(sourceOfBS, L"BrainScriptNetworkBuilder", move(includePaths)); // the rest is done in a lambda that is only evaluated when a virgin network is needed // Note that evaluating the BrainScript *is* instantiating the network, so the evaluate call must be inside the lambda. 
createNetworkFn = [expr](DEVICEID_TYPE /*deviceId*/) { // evaluate the parse tree, particularly the top-level field 'network' // Evaluating it will create the network. let object = EvaluateField(expr, L"network"); // this comes back as a BS::Object let network = dynamic_pointer_cast<ComputationNetwork>(object); // cast it if (!network) LogicError("BuildNetworkFromDescription: ComputationNetwork not what it was meant to be"); return network; }; return true; } else return false; }
void DSSMReader<ElemType>::InitFromConfig(const ConfigRecordType& readerConfig)
{
    // Configures the DSSM reader: resolves the query/document feature sections, resets the
    // reader state, reads 'randomize' / 'minibatchMode', opens both feature files and builds
    // the initial (identity) read order.
    std::vector<std::wstring> features;
    std::vector<std::wstring> labels;

    // Determine the names of the features and labels sections in the config file.
    //   features - [in,out] a vector of feature name strings
    //   labels   - [in,out] a vector of label name strings
    // DSSM data itself carries no labels, but exactly one label section is still required
    // below; it only serves as a placeholder.
    GetFileConfigNames(readerConfig, features, labels);

    // DSSM needs exactly two features. In the config file the query features come first, then the
    // document features; the order here is reversed, hence query = features[1], doc = features[0].
    if (features.size() != 2 || labels.size() != 1)
        RuntimeError("DSSM requires exactly two features and one label. Their names should match those in NDL definition"); // throws
    m_featuresNameQuery = features[1];
    m_featuresNameDoc = features[0];
    m_labelsName = labels[0];

    m_mbStartSample = m_epoch = m_totalSamples = m_epochStartSample = 0;
    m_labelIdMax = m_labelDim = 0;
    m_partialMinibatch = m_endReached = false;
    m_labelType = labelCategory;
    m_readNextSample = 0;
    m_traceLevel = readerConfig(L"traceLevel", 0);

    if (readerConfig.Exists(L"randomize"))
    {
        // BUGBUG: the same parameter is read out both as a string and as a number... ugh
        wstring randomizeString = readerConfig(L"randomize");
        if (randomizeString == L"None")
            m_randomizeRange = randomizeNone;
        else if (randomizeString == L"Auto")
            m_randomizeRange = randomizeAuto;
        else
            m_randomizeRange = readerConfig(L"randomize");
    }
    else
    {
        m_randomizeRange = randomizeNone;
    }

    std::string minibatchMode(readerConfig(L"minibatchMode", "Partial"));
    m_partialMinibatch = EqualCI(minibatchMode, "Partial");

    // Get the config parameters for query feature and doc feature
    ConfigParameters configFeaturesQuery = readerConfig(m_featuresNameQuery, "");
    ConfigParameters configFeaturesDoc = readerConfig(m_featuresNameDoc, "");
    if (configFeaturesQuery.size() == 0)
        RuntimeError("features file not found, required in configuration: i.e. 'features=[file=c:\\myfile.txt;start=1;dim=123]'");
    if (configFeaturesDoc.size() == 0)
        RuntimeError("features file not found, required in configuration: i.e. 'features=[file=c:\\myfile.txt;start=1;dim=123]'");

    // Read in feature size information; it is used to handle OOVs.
    m_featuresDimQuery = configFeaturesQuery(L"dim");
    m_featuresDimDoc = configFeaturesDoc(L"dim");

    std::wstring fileQ = configFeaturesQuery("file");
    std::wstring fileD = configFeaturesDoc("file");
    dssm_queryInput.Init(fileQ, m_featuresDimQuery);
    dssm_docInput.Init(fileD, m_featuresDimDoc);

    m_totalSamples = dssm_queryInput.numRows;
    // Only allocated on the first call; NOTE(review): a later re-init with more rows keeps the old buffer.
    if (read_order == nullptr)
    {
        read_order = new int[m_totalSamples];
        // The counter has the same type as m_totalSamples: no signed/unsigned comparison and no
        // signed int overflow for very large row counts.
        for (decltype(m_totalSamples) c = 0; c < m_totalSamples; c++)
            read_order[c] = static_cast<int>(c);
    }
    m_mbSize = 0;
}