void TestFunctionSerializationDuringTraining(const FunctionPtr& function, const Variable& labels, const MinibatchSourcePtr& minibatchSource, const DeviceDescriptor& device) { auto classifierOutput1 = function; auto featureStreamInfo = minibatchSource->StreamInfo(classifierOutput1->Arguments()[0]); auto labelStreamInfo = minibatchSource->StreamInfo(labels); const size_t minibatchSize = 200; auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device); auto trainer1 = BuildTrainer(classifierOutput1, labels); Dictionary model = classifierOutput1->Serialize(); trainer1.TrainMinibatch({ { classifierOutput1->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device); auto classifierOutput2 = Function::Deserialize(model, device); if (AreEqual(classifierOutput1, classifierOutput2)) { throw std::runtime_error("TestModelSerialization: reloaded function is still identical to the original after it was trained."); } for (int i = 0; i < 3; ++i) { Dictionary model = classifierOutput1->Serialize(); auto classifierOutput2 = Function::Deserialize(model, device); if (!AreEqual(classifierOutput1, classifierOutput2)) { throw std::runtime_error("TestModelSerialization: original and reloaded functions are not identical."); } Trainer trainer2 = BuildTrainer(classifierOutput2, labels); for (int j = 0; j < 3; ++j) { trainer1.TrainMinibatch({ { classifierOutput1->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device); trainer2.TrainMinibatch({ { classifierOutput2->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device); double mbLoss1 = trainer1.PreviousMinibatchLossAverage(); double mbLoss2 = trainer2.PreviousMinibatchLossAverage(); FloatingPointCompare(mbLoss1, mbLoss2, "Post checkpoint restoration training loss does not match expectation"); } } }
void TestTrainingWithCheckpointing(const FunctionPtr& function1, const FunctionPtr& function2, const Variable& labels, const MinibatchSourcePtr& minibatchSource, const DeviceDescriptor& device) { auto featureStreamInfo = minibatchSource->StreamInfo(function1->Arguments()[0]); auto labelStreamInfo = minibatchSource->StreamInfo(labels); const size_t minibatchSize = 50; auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device); auto actualMBSize = minibatchData[labelStreamInfo].m_numSamples; LearningRateSchedule learningRateSchedule({ { 2, 0.005 }, { 2, 0.0025 }, { 2, 0.0005 }, { 2, 0.00025 } }, actualMBSize); MomentumAsTimeConstantSchedule momentumValues({ { 2, 100 }, { 2, 200 }, { 2, 400 }, { 2, 800 } }, actualMBSize); auto trainer1 = BuildTrainer(function1, labels, learningRateSchedule, momentumValues); auto trainer2 = BuildTrainer(function2, labels, learningRateSchedule, momentumValues); assert(AreEqual(function1, function2)); trainer2.SaveCheckpoint(L"trainer.v2.checkpoint", false); trainer2.RestoreFromCheckpoint(L"trainer.v2.checkpoint"); if (!AreEqual(function1, function2)) { throw std::runtime_error("TestModelSerialization: reloaded function is not identical to the original."); } trainer1.TrainMinibatch({ { function1->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device); if (AreEqual(function1, function2)) { throw std::runtime_error("TestModelSerialization: reloaded function is still identical to the original after it was trained."); } trainer2.TrainMinibatch({ { function2->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device); if (!AreEqual(function1, function2)) { throw std::runtime_error("TestModelSerialization: reloaded function is not identical to the original."); } for (int i = 0; i < 3; ++i) { trainer2.SaveCheckpoint(L"trainer.v2.checkpoint", false); trainer2.RestoreFromCheckpoint(L"trainer.v2.checkpoint"); if (!AreEqual(function1, function2)) { throw std::runtime_error("TestModelSerialization: original and reloaded functions are not identical."); } for (int j = 0; j < 3; ++j) { trainer1.TrainMinibatch({ { function1->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device); trainer2.TrainMinibatch({ { function2->Arguments()[0], minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device); double mbLoss1 = trainer1.PreviousMinibatchLossAverage(); double mbLoss2 = trainer2.PreviousMinibatchLossAverage(); FloatingPointCompare(mbLoss1, mbLoss2, "Post checkpoint restoration training loss does not match expectation"); } } }