Exemplo n.º 1
0
void TestFunctionSerializationDuringTraining(const FunctionPtr& function, const Variable& labels, const MinibatchSourcePtr& minibatchSource, const DeviceDescriptor& device)
{
    auto classifierOutput1 = function;

    auto featureStreamInfo = minibatchSource->StreamInfo(classifierOutput1->Arguments()[0]);
    auto labelStreamInfo = minibatchSource->StreamInfo(labels);

    const size_t minibatchSize = 200;
    auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);

    auto trainer1 = BuildTrainer(classifierOutput1, labels);

    Dictionary model = classifierOutput1->Serialize();

    trainer1.TrainMinibatch({ { classifierOutput1->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);

    auto classifierOutput2 = Function::Deserialize(model, device);

    if (AreEqual(classifierOutput1, classifierOutput2))
    {
        throw std::runtime_error("TestModelSerialization: reloaded function is still identical to the original after it was trained.");
    }

    for (int i = 0; i < 3; ++i)
    {
        Dictionary model = classifierOutput1->Serialize();

        auto classifierOutput2 = Function::Deserialize(model, device);

        if (!AreEqual(classifierOutput1, classifierOutput2))
        {
            throw std::runtime_error("TestModelSerialization: original and reloaded functions are not identical.");
        }
      
        Trainer trainer2 = BuildTrainer(classifierOutput2, labels);

        for (int j = 0; j < 3; ++j)
        {
            trainer1.TrainMinibatch({ { classifierOutput1->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
            trainer2.TrainMinibatch({ { classifierOutput2->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);

            double mbLoss1 = trainer1.PreviousMinibatchLossAverage();
            double mbLoss2 = trainer2.PreviousMinibatchLossAverage();
            FloatingPointCompare(mbLoss1, mbLoss2, "Post checkpoint restoration training loss does not match expectation");
        }
    }
}
Exemplo n.º 2
0
void TestTrainingWithCheckpointing(const FunctionPtr& function1, const FunctionPtr& function2, const Variable& labels, const MinibatchSourcePtr& minibatchSource, const DeviceDescriptor& device)
{
    auto featureStreamInfo = minibatchSource->StreamInfo(function1->Arguments()[0]);
    auto labelStreamInfo = minibatchSource->StreamInfo(labels);

    const size_t minibatchSize = 50;
    auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
     auto actualMBSize = minibatchData[labelStreamInfo].m_numSamples;

    LearningRateSchedule learningRateSchedule({ { 2, 0.005 }, { 2, 0.0025 }, { 2, 0.0005 }, { 2, 0.00025 } }, actualMBSize);
    MomentumAsTimeConstantSchedule momentumValues({ { 2, 100 }, { 2, 200 }, { 2, 400 }, { 2, 800 } }, actualMBSize);


    auto trainer1 = BuildTrainer(function1, labels, learningRateSchedule, momentumValues);
    auto trainer2 = BuildTrainer(function2, labels, learningRateSchedule, momentumValues);

    assert(AreEqual(function1, function2));

    trainer2.SaveCheckpoint(L"trainer.v2.checkpoint", false);
    trainer2.RestoreFromCheckpoint(L"trainer.v2.checkpoint");

    if (!AreEqual(function1, function2))
    {
        throw std::runtime_error("TestModelSerialization: reloaded function is not identical to the original.");
    }

    trainer1.TrainMinibatch({ { function1->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);

    if (AreEqual(function1, function2))
    {
        throw std::runtime_error("TestModelSerialization: reloaded function is still identical to the original after it was trained.");
    }

    trainer2.TrainMinibatch({ { function2->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);

    if (!AreEqual(function1, function2))
    {
        throw std::runtime_error("TestModelSerialization: reloaded function is not identical to the original.");
    }

    for (int i = 0; i < 3; ++i)
    {
        trainer2.SaveCheckpoint(L"trainer.v2.checkpoint", false);
        trainer2.RestoreFromCheckpoint(L"trainer.v2.checkpoint");

        if (!AreEqual(function1, function2))
        {
            throw std::runtime_error("TestModelSerialization: original and reloaded functions are not identical.");
        }
      
        for (int j = 0; j < 3; ++j)
        {
            trainer1.TrainMinibatch({ { function1->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
            trainer2.TrainMinibatch({ { function2->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);

            double mbLoss1 = trainer1.PreviousMinibatchLossAverage();
            double mbLoss2 = trainer2.PreviousMinibatchLossAverage();
            FloatingPointCompare(mbLoss1, mbLoss2, "Post checkpoint restoration training loss does not match expectation");
        }
    }
}