// Builds the configuration of an "HTKFeatureDeserializer" from a list of stream configurations.
// Each stream contributes one entry to the "input" dictionary, keyed by its m_streamName, carrying
// scpFile, dim, contextWindow ({ m_left, m_right }), expandToUtterance and definesMBSize.
Deserializer HTKFeatureDeserializer(const std::vector<HTKFeatureConfiguration>& streams)
{
    Dictionary inputStreams;
    for (const auto& config : streams)
    {
        // The context window is stored as a two-element list: { left, right }.
        const std::vector<DictionaryValue> contextWindow{ DictionaryValue(config.m_left), DictionaryValue(config.m_right) };

        Dictionary streamConfig;
        streamConfig.Add(L"scpFile", config.m_scp,
                         L"dim", config.m_dim,
                         L"contextWindow", contextWindow,
                         L"expandToUtterance", config.m_broadcast);
        streamConfig[L"definesMBSize"] = config.m_definesMbSize;

        inputStreams[config.m_streamName] = streamConfig;
    }

    Deserializer deserializer;
    deserializer.Add(L"type", L"HTKFeatureDeserializer", L"input", inputStreams);
    return deserializer;
}
示例#2
0
 // Out-of-class definition of a class-template member: the template parameter list is required,
 // otherwise `T` is undeclared and the definition is ill-formed.
 //
 // Serializes the schedule into a Dictionary holding the version, type tag, epoch size, unit
 // (as size_t) and the schedule entries. Each entry's key is converted with std::to_wstring,
 // so it is assumed to be numeric (presumably a sample/epoch boundary - confirm against m_schedule's type).
 template <typename T>
 /*virtual*/ Dictionary TrainingParameterSchedule<T>::Serialize() const
 {
     Dictionary schedule;
     for (const auto& entry : m_schedule)
     {
         schedule[std::to_wstring(entry.first)] = DictionaryValue(entry.second);
     }

     Dictionary dict;
     dict[versionKey] = CurrentVersion();
     dict[typeKey] = s_trainingParameterScheduleTypeValue;
     dict[epochSizeKey] = m_epochSize;
     dict[unitKey] = static_cast<size_t>(m_unit);
     dict[scheduleKey] = schedule;
     return dict;
 }
示例#3
0
文件: Trainer.cpp 项目: Soukiy/CNTK
    void Trainer::Save(const std::wstring& modelFilePath, bool usinglegacyModelFormat, const Dictionary& distributedLearnerState)
    {
        vector<DictionaryValue> learnerStates;
        for (const auto& learner : m_parameterLearners)
        {
            // TODO: add DictionaryValue(T&&)
            learnerStates.push_back(DictionaryValue(learner->Serialize()));
        }

        // add DictionaryValue ctor that takes an rvalue!
        Dictionary state;
        state[learnersPropertyName] = learnerStates;
        state[distributedLearnerPropertyName] = distributedLearnerState;

        m_combinedTrainingFunction->SaveModel(modelFilePath, usinglegacyModelFormat);
        std::wstring trainerStateCheckpointFilePath = GetTrainerStateCheckpointFilePath(modelFilePath);
        auto ckpStream = GetFstream(trainerStateCheckpointFilePath, false);
        *ckpStream << state;
        ckpStream->flush();
    }
示例#4
0
文件: Variable.cpp 项目: rlugojr/CNTK
    // Serializes this variable into a Dictionary. The dictionary holds:
    //   - version, type tag, uid, kind and data type;
    //   - the dynamic axes, the sparsity flag and the optional name;
    //   - the needs-gradient flag and the shape;
    //   - for Parameters and Constants, a copy of the value.
    // Reports a LogicError for Output variables and for Parameters/Constants that have no value.
    /*virtual*/ Dictionary Variable::Serialize() const
    {
        if (IsOutput())
        {
            LogicError("Output variables cannot be saved");
        }
        Dictionary dict;

        dict[versionKey] = CurrentVersion();
        dict[typeKey] = s_variableTypeValue;
        dict[uidKey] = Uid();
        dict[kindKey] = static_cast<size_t>(Kind());
        dict[dataTypeKey] = static_cast<size_t>(GetDataType());
        const auto& dynamicAxes = DynamicAxes();
        vector<DictionaryValue> dictionaryValueVector;
        dictionaryValueVector.reserve(dynamicAxes.size());
        for (const auto& axis : dynamicAxes)
            dictionaryValueVector.push_back(axis);

        dict[dynamicAxisKey] = dictionaryValueVector;
        dict[isSparseKey] = IsSparse();
        if (!Name().empty())
            dict[nameKey] = Name();
        dict[needsGradientKey] = NeedsGradient();
        dict[shapeKey] = Shape();
        if (IsParameter() || IsConstant())
        {
            // Keep the owning smart pointer alive while the view is copied. Calling get() on the
            // temporary returned by Value() would leave only a raw pointer, and the view's lifetime
            // would then depend on some other owner.
            const auto value = Value();
            if (value == nullptr)
            {
                LogicError("Uninitialized Parameter variable cannot be saved");
            }

            // TODO: add a dictionary value constructor with an rvalue parameter.
            dict[valueKey] = DictionaryValue(*value);
        }

        return dict;
    }
        // Converts a MinibatchSourceConfig into the Dictionary layout consumed by the CNTK composite reader.
        // The layout covers randomization, truncation, frame mode, trace level and multithreaded deserialization.
        // It also includes the list of deserializers, each augmented with the reader module inferred from its type.
        // Reports InvalidArgument for a deserializer type with no known module.
        Dictionary ToDictionary(const ::CNTK::MinibatchSourceConfig& configuration)
        {
            Validate(configuration);

            Dictionary augmentedConfiguration;

            // A sample-based randomization window takes precedence over a chunk-based one; zero for both disables randomization.
            if (configuration.randomizationWindowInSamples != 0)
            {
                augmentedConfiguration[L"randomize"] = true;
                augmentedConfiguration[L"randomizationWindow"] = configuration.randomizationWindowInSamples;
                augmentedConfiguration[L"sampleBasedRandomizationWindow"] = true;
                augmentedConfiguration[L"randomizationSeed"] = configuration.randomizationSeed;
            }
            else if (configuration.randomizationWindowInChunks != 0)
            {
                augmentedConfiguration[L"randomize"] = true;
                augmentedConfiguration[L"randomizationWindow"] = configuration.randomizationWindowInChunks;
                augmentedConfiguration[L"sampleBasedRandomizationWindow"] = false;
                augmentedConfiguration[L"randomizationSeed"] = configuration.randomizationSeed;
            }
            else
            {
                augmentedConfiguration[L"randomize"] = false;
            }

            if (configuration.truncationLength != 0)
            {
                augmentedConfiguration[L"truncated"] = true;
                augmentedConfiguration[L"truncationLength"] = configuration.truncationLength;
            }

            augmentedConfiguration[L"frameMode"] = configuration.isFrameModeEnabled;
            augmentedConfiguration[L"traceLevel"] = static_cast<size_t>(configuration.traceLevel);

            // The CNTK reader implementation requires both the module and the deserializer type to be specified for
            // each deserializer. This is redundant, so V2 API users specify only the type and the module is inferred here.
            // TODO: This should be done in the same manner for CNTK exe as well.
            static const std::unordered_map<std::wstring, std::wstring> deserializerTypeNameToModuleNameMap = {
                { L"CNTKTextFormatDeserializer", L"CNTKTextFormatReader" },
                { L"ImageDeserializer",          L"ImageReader" },
                { L"Base64ImageDeserializer",    L"ImageReader" },
                { L"HTKFeatureDeserializer",     L"HTKDeserializers" },
                { L"HTKMLFDeserializer",         L"HTKDeserializers" },
            };

            // Becomes true if any ImageDeserializer is present. Used only when configuration.isMultithreaded is not set.
            bool defaultMultithreaded = false;
            vector<DictionaryValue> deserializers;

            // Taken by value on purpose: each config is augmented below without modifying `configuration`.
            for (auto deserializerConfig : configuration.deserializers)
            {
                const auto deserializerTypeName = deserializerConfig[L"type"].Value<std::wstring>();

                // Validate the type up front; the lookup result is reused when setting "module" below.
                const auto moduleEntry = deserializerTypeNameToModuleNameMap.find(deserializerTypeName);
                if (moduleEntry == deserializerTypeNameToModuleNameMap.end())
                    InvalidArgument("Unknown deserializer type '%S' specified for CNTK built-in composite MinibatchSource construction.", deserializerTypeName.c_str());

                if (deserializerTypeName == L"ImageDeserializer")
                {
                    defaultMultithreaded = true;

                    // Add a transpose transform. The image data is read in HWC form (CWH in column-major format),
                    // while the CNTK convolution engine supports WHC (in column-major format).
                    auto& inputStreamsConfig = deserializerConfig[L"input"].Value<Dictionary>();
                    for (auto& inputStreamEntry : inputStreamsConfig)
                    {
                        auto& inputStreamConfig = inputStreamEntry.second.Value<Dictionary>();
                        // NOTE(review): the transpose is appended only to streams that already carry a "transforms" list.
                        if (inputStreamConfig.Contains(L"transforms"))
                        {
                            auto& transforms = inputStreamConfig[L"transforms"].Value<std::vector<DictionaryValue>>();

                            Dictionary transposeTransform;
                            transposeTransform[L"type"] = L"Transpose";
                            transforms.push_back(DictionaryValue(transposeTransform));
                        }
                    }
                }

                deserializerConfig[L"module"] = moduleEntry->second;
                deserializers.push_back(deserializerConfig);
            }

            augmentedConfiguration[L"multiThreadedDeserialization"] =
                (configuration.isMultithreaded.IsInitialized()) ? configuration.isMultithreaded.Get() : defaultMultithreaded;

            augmentedConfiguration[L"deserializers"] = deserializers;

            return augmentedConfiguration;
        }