void TestFunctionSerializationDuringTraining(const FunctionPtr& function, const Variable& labels, const MinibatchSourcePtr& minibatchSource, const DeviceDescriptor& device)
{
    auto classifierOutput1 = function;

    auto featureStreamInfo = minibatchSource->StreamInfo(classifierOutput1->Arguments()[0]);
    auto labelStreamInfo = minibatchSource->StreamInfo(labels);

    const size_t minibatchSize = 200;
    auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);

    auto trainer1 = BuildTrainer(classifierOutput1, labels);

    Dictionary model = classifierOutput1->Serialize();

    trainer1.TrainMinibatch({ { classifierOutput1->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);

    auto classifierOutput2 = Function::Deserialize(model, device);

    if (AreEqual(classifierOutput1, classifierOutput2))
    {
        throw std::runtime_error("TestModelSerialization: reloaded function is still identical to the original after it was trained.");
    }

    for (int i = 0; i < 3; ++i)
    {
        Dictionary model = classifierOutput1->Serialize();

        auto classifierOutput2 = Function::Deserialize(model, device);

        if (!AreEqual(classifierOutput1, classifierOutput2))
        {
            throw std::runtime_error("TestModelSerialization: original and reloaded functions are not identical.");
        }
      
        Trainer trainer2 = BuildTrainer(classifierOutput2, labels);

        for (int j = 0; j < 3; ++j)
        {
            trainer1.TrainMinibatch({ { classifierOutput1->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
            trainer2.TrainMinibatch({ { classifierOutput2->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);

            double mbLoss1 = trainer1.PreviousMinibatchLossAverage();
            double mbLoss2 = trainer2.PreviousMinibatchLossAverage();
            FloatingPointCompare(mbLoss1, mbLoss2, "Post checkpoint restoration training loss does not match expectation");
        }
    }
}
// Example #2
// 0
    void TrainingSession::GetNextMinibatch(const MinibatchSourcePtr& source, std::unordered_map<Variable, ValuePtr>& minibatch, size_t mbSize, size_t workerRank, size_t numberOfWorkers, const DeviceDescriptor& computeDevice)
    {
        minibatch.clear();

        if (mbSize == 0)
            return;

        auto minibatchData = source->GetNextMinibatch(0 /*numberOfSequences*/, mbSize, numberOfWorkers, workerRank, computeDevice);
        if (minibatchData.empty())
            return;

        for (auto v : m_modelInputToMinibatchSourceStream)
            minibatch.insert({ v.first, minibatchData[v.second].data });
    }
// Example #3
// 0
    // Fills `minibatch` with the next minibatch from `source`, keyed by model input.
    //
    // `minibatch` is always cleared first and stays empty when mbSize is 0 or the source
    // returns no data. Otherwise, for every (input, stream) entry in inputVarToStream, the
    // stream's `.data` is inserted under the input variable.
    // workerRank / numberOfWorkers / computeDevice are forwarded to source->GetNextMinibatch.
    void TrainingSession::GetNextMinibatch(
        const MinibatchSourcePtr& source,
        std::unordered_map<Variable, ValuePtr>& minibatch,
        const std::unordered_map<Variable, StreamInformation>& inputVarToStream,
        size_t mbSize,
        size_t workerRank,
        size_t numberOfWorkers,
        const DeviceDescriptor& computeDevice)
    {
        minibatch.clear();

        if (mbSize == 0)
            return;

        // TODO: is copy really necessary here?
        auto minibatchData = source->GetNextMinibatch(0 /*numberOfSequences*/, mbSize, numberOfWorkers, workerRank, computeDevice);
        if (minibatchData.empty())
            return;

        // Bind by const reference: the original `auto v` copied each
        // std::pair<const Variable, StreamInformation> on every iteration.
        // NOTE(review): if minibatchData is a std::unordered_map, operator[] default-inserts an
        // entry for a stream missing from this minibatch; assumes every mapped stream is present.
        for (const auto& inputAndStream : inputVarToStream)
            minibatch.insert({ inputAndStream.first, minibatchData[inputAndStream.second].data });
    }
/// Checks that Trainer checkpointing (SaveCheckpoint / RestoreFromCheckpoint) preserves
/// training state: a trainer that is repeatedly checkpointed and restored must keep
/// producing the same model and the same minibatch losses as one that never is.
///
/// @param function1       Reference model; its trainer is never checkpointed.
/// @param function2       Model that must start out identical to function1; its trainer is
///                        saved to and restored from a checkpoint file.
/// @param labels          Label variable fed from its associated stream.
/// @param minibatchSource Source of the single minibatch reused for every training step.
/// @param device          Device passed to data fetching and training.
/// @throws std::runtime_error when the two models differ where they should match, or match
///         where they should differ.
/// @note Writes "trainer.v2.checkpoint" relative to the current working directory; this
///       function does not delete it.
void TestTrainingWithCheckpointing(const FunctionPtr& function1, const FunctionPtr& function2, const Variable& labels, const MinibatchSourcePtr& minibatchSource, const DeviceDescriptor& device)
{
    const wchar_t* const checkpointFile = L"trainer.v2.checkpoint";

    auto featureStreamInfo = minibatchSource->StreamInfo(function1->Arguments()[0]);
    auto labelStreamInfo = minibatchSource->StreamInfo(labels);

    const size_t minibatchSize = 50;
    auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
    auto actualMBSize = minibatchData[labelStreamInfo].m_numSamples;

    LearningRateSchedule learningRateSchedule({ { 2, 0.005 }, { 2, 0.0025 }, { 2, 0.0005 }, { 2, 0.00025 } }, actualMBSize);
    MomentumAsTimeConstantSchedule momentumValues({ { 2, 100 }, { 2, 200 }, { 2, 400 }, { 2, 800 } }, actualMBSize);

    auto trainer1 = BuildTrainer(function1, labels, learningRateSchedule, momentumValues);
    auto trainer2 = BuildTrainer(function2, labels, learningRateSchedule, momentumValues);

    // Precondition: both models start out identical. Checked with an exception rather than
    // assert() so the check is not compiled out in NDEBUG (release) builds.
    if (!AreEqual(function1, function2))
    {
        throw std::runtime_error("TestModelSerialization: initial functions are not identical.");
    }

    trainer2.SaveCheckpoint(checkpointFile, false);
    trainer2.RestoreFromCheckpoint(checkpointFile);

    if (!AreEqual(function1, function2))
    {
        throw std::runtime_error("TestModelSerialization: reloaded function is not identical to the original.");
    }

    trainer1.TrainMinibatch({ { function1->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);

    // Only function1 has been trained, so the two models must now differ.
    if (AreEqual(function1, function2))
    {
        throw std::runtime_error("TestModelSerialization: reloaded function is still identical to the original after it was trained.");
    }

    trainer2.TrainMinibatch({ { function2->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);

    // After one identical training step each, the models must agree again.
    if (!AreEqual(function1, function2))
    {
        throw std::runtime_error("TestModelSerialization: reloaded function is not identical to the original.");
    }

    const int numCheckpoints = 3;
    const int numMinibatchesPerCheckpoint = 3;

    for (int i = 0; i < numCheckpoints; ++i)
    {
        trainer2.SaveCheckpoint(checkpointFile, false);
        trainer2.RestoreFromCheckpoint(checkpointFile);

        if (!AreEqual(function1, function2))
        {
            throw std::runtime_error("TestModelSerialization: original and reloaded functions are not identical.");
        }

        for (int j = 0; j < numMinibatchesPerCheckpoint; ++j)
        {
            trainer1.TrainMinibatch({ { function1->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
            trainer2.TrainMinibatch({ { function2->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);

            double mbLoss1 = trainer1.PreviousMinibatchLossAverage();
            double mbLoss2 = trainer2.PreviousMinibatchLossAverage();
            FloatingPointCompare(mbLoss1, mbLoss2, "Post checkpoint restoration training loss does not match expectation");
        }
    }
}