Deserializer HTKFeatureDeserializer(const std::vector<HTKFeatureConfiguration>& streams)
{
    // Build one sub-dictionary per feature stream; the sub-dictionaries are keyed by
    // each stream's m_streamName inside the deserializer's "input" entry.
    Dictionary perStreamConfigs;
    for (const auto& featureStream : streams)
    {
        const std::vector<DictionaryValue> contextWindow{ DictionaryValue(featureStream.m_left), DictionaryValue(featureStream.m_right) };

        Dictionary streamConfig;
        streamConfig[L"scpFile"] = featureStream.m_scp;
        streamConfig[L"dim"] = featureStream.m_dim;
        streamConfig[L"contextWindow"] = contextWindow;
        streamConfig[L"expandToUtterance"] = featureStream.m_broadcast;
        streamConfig[L"definesMBSize"] = featureStream.m_definesMbSize;

        perStreamConfigs[featureStream.m_streamName] = streamConfig;
    }

    Deserializer htkDeserializer;
    htkDeserializer[L"type"] = L"HTKFeatureDeserializer";
    htkDeserializer[L"input"] = perStreamConfigs;
    return htkDeserializer;
}
/*virtual*/ Dictionary TrainingParameterSchedule<T>::Serialize() const
{
    // Encode the schedule entries: each key is converted with std::to_wstring and each
    // value is wrapped in a DictionaryValue.
    Dictionary encodedSchedule;
    for (const auto& scheduleEntry : m_schedule)
    {
        const auto entryKey = std::to_wstring(scheduleEntry.first);
        encodedSchedule[entryKey] = DictionaryValue(scheduleEntry.second);
    }

    // Version/type metadata, the schedule's own settings, then the encoded entries.
    Dictionary result;
    result[versionKey] = CurrentVersion();
    result[typeKey] = s_trainingParameterScheduleTypeValue;
    result[epochSizeKey] = m_epochSize;
    result[unitKey] = static_cast<size_t>(m_unit);
    result[scheduleKey] = encodedSchedule;
    return result;
}
void Trainer::Save(const std::wstring& modelFilePath, bool usinglegacyModelFormat, const Dictionary& distributedLearnerState) { vector<DictionaryValue> learnerStates; for (const auto& learner : m_parameterLearners) { // TODO: add DictionaryValue(T&&) learnerStates.push_back(DictionaryValue(learner->Serialize())); } // add DictionaryValue ctor that takes an rvalue! Dictionary state; state[learnersPropertyName] = learnerStates; state[distributedLearnerPropertyName] = distributedLearnerState; m_combinedTrainingFunction->SaveModel(modelFilePath, usinglegacyModelFormat); std::wstring trainerStateCheckpointFilePath = GetTrainerStateCheckpointFilePath(modelFilePath); auto ckpStream = GetFstream(trainerStateCheckpointFilePath, false); *ckpStream << state; ckpStream->flush(); }
/*virtual*/ Dictionary Variable::Serialize() const
{
    // Serializes this variable's metadata (uid, kind, data type, dynamic axes, sparsity,
    // name, gradient flag, shape) and, for Parameters/Constants, its value.
    // Raises LogicError for Output variables and for Parameters/Constants whose value is null.
    if (IsOutput())
    {
        LogicError("Output variables cannot be saved");
    }

    Dictionary dict;
    dict[versionKey] = CurrentVersion();
    dict[typeKey] = s_variableTypeValue;
    dict[uidKey] = Uid();
    dict[kindKey] = static_cast<size_t>(Kind());
    dict[dataTypeKey] = static_cast<size_t>(GetDataType());

    const auto& dynamicAxes = DynamicAxes();
    std::vector<DictionaryValue> serializedAxes;
    serializedAxes.reserve(dynamicAxes.size());
    for (const auto& axis : dynamicAxes)
        serializedAxes.push_back(axis);
    dict[dynamicAxisKey] = serializedAxes;

    dict[isSparseKey] = IsSparse();
    if (!Name().empty())
        dict[nameKey] = Name();
    dict[needsGradientKey] = NeedsGradient();
    dict[shapeKey] = Shape();

    if (IsParameter() || IsConstant())
    {
        // Hold the owning pointer returned by Value() for as long as the view is used,
        // rather than a raw pointer extracted from a temporary that is destroyed at the
        // end of the statement (which is only safe if some other owner keeps it alive).
        const auto value = Value();
        if (value == nullptr)
        {
            LogicError("Uninitialized Parameter variable cannot be saved");
        }
        // TODO: add a dictionary value constructor with an rvalue parameter.
        dict[valueKey] = DictionaryValue(*value);
    }
    return dict;
}
Dictionary ToDictionary(const ::CNTK::MinibatchSourceConfig& configuration)
{
    // Translates a MinibatchSourceConfig into the Dictionary consumed by the CNTK composite reader:
    // randomization/truncation/frame-mode/trace settings, the multithreading flag, and the list of
    // deserializer configs, each augmented with its reader "module" name.
    // Raises InvalidArgument for a deserializer config without a "type" or with an unknown type.
    Validate(configuration);

    Dictionary augmentedConfiguration;

    // A sample-based randomization window is checked first; a chunk-based window is used only
    // when the sample-based one is 0. With both 0, randomization is disabled.
    if (configuration.randomizationWindowInSamples != 0)
    {
        augmentedConfiguration[L"randomize"] = true;
        augmentedConfiguration[L"randomizationWindow"] = configuration.randomizationWindowInSamples;
        augmentedConfiguration[L"sampleBasedRandomizationWindow"] = true;
        augmentedConfiguration[L"randomizationSeed"] = configuration.randomizationSeed;
    }
    else if (configuration.randomizationWindowInChunks != 0)
    {
        augmentedConfiguration[L"randomize"] = true;
        augmentedConfiguration[L"randomizationWindow"] = configuration.randomizationWindowInChunks;
        augmentedConfiguration[L"sampleBasedRandomizationWindow"] = false;
        augmentedConfiguration[L"randomizationSeed"] = configuration.randomizationSeed;
    }
    else
    {
        augmentedConfiguration[L"randomize"] = false;
    }

    if (configuration.truncationLength != 0)
    {
        augmentedConfiguration[L"truncated"] = true;
        augmentedConfiguration[L"truncationLength"] = configuration.truncationLength;
    }

    augmentedConfiguration[L"frameMode"] = configuration.isFrameModeEnabled;
    augmentedConfiguration[L"traceLevel"] = static_cast<size_t>(configuration.traceLevel);

    // Set to true below when an ImageDeserializer is present; an explicitly initialized
    // configuration.isMultithreaded always takes precedence over this default.
    bool defaultMultithreaded = false;

    // The CNTK reader implementation requires for each deserializer both the module and deserializer type be specified.
    // This is redundant and the V2 API users will just specify type from which the module is automatically inferred.
    // TODO: This should be done in the same manner for CNTK exe as well.
    static const std::unordered_map<std::wstring, std::wstring> deserializerTypeNameToModuleNameMap = {
        { L"CNTKTextFormatDeserializer", L"CNTKTextFormatReader" },
        { L"ImageDeserializer",          L"ImageReader" },
        { L"Base64ImageDeserializer",    L"ImageReader" },
        { L"HTKFeatureDeserializer",     L"HTKDeserializers" },
        { L"HTKMLFDeserializer",         L"HTKDeserializers" },
    };

    std::vector<DictionaryValue> deserializers;
    // Iterate by value: the loop body modifies each deserializer config (module name,
    // image transforms), while `configuration` itself is const.
    for (auto deserializerConfig : configuration.deserializers)
    {
        // Check explicitly: a non-const operator[] on a missing key presumably inserts an empty
        // entry, and reading that as a std::wstring would fail with a less descriptive error.
        if (!deserializerConfig.Contains(L"type"))
            InvalidArgument("Deserializer configuration without a 'type' entry specified for CNTK built-in composite MinibatchSource construction.");

        // Copy (not reference) the type name: deserializerConfig is modified further below.
        const auto deserializerTypeName = deserializerConfig[L"type"].Value<std::wstring>();

        // Single lookup; the iterator stays valid because the map is a static const.
        const auto moduleEntry = deserializerTypeNameToModuleNameMap.find(deserializerTypeName);
        if (moduleEntry == deserializerTypeNameToModuleNameMap.end())
            InvalidArgument("Unknown deserializer type '%S' specified for CNTK built-in composite MinibatchSource construction.", deserializerTypeName.c_str());

        if (deserializerTypeName == L"ImageDeserializer")
        {
            defaultMultithreaded = true;

            // Add a transpose transform since the image data is read in HWC (CWH in column-major format) form while
            // the CNTK convolution engine supports WHC (in column-major format).
            // Only input streams that already have a "transforms" list get the transpose appended.
            auto& inputStreamsConfig = deserializerConfig[L"input"].Value<Dictionary>();
            for (auto& inputStreamEntry : inputStreamsConfig)
            {
                auto& inputStreamConfig = inputStreamEntry.second.Value<Dictionary>();
                if (inputStreamConfig.Contains(L"transforms"))
                {
                    auto& transforms = inputStreamConfig[L"transforms"].Value<std::vector<DictionaryValue>>();

                    Dictionary transposeTransform;
                    transposeTransform[L"type"] = L"Transpose";
                    transforms.push_back(DictionaryValue(transposeTransform));
                }
            }
        }

        deserializerConfig[L"module"] = moduleEntry->second;
        deserializers.push_back(deserializerConfig);
    }

    augmentedConfiguration[L"multiThreadedDeserialization"] =
        configuration.isMultithreaded.IsInitialized() ? configuration.isMultithreaded.Get() : defaultMultithreaded;
    augmentedConfiguration[L"deserializers"] = deserializers;
    return augmentedConfiguration;
}