Dictionary Trainer::RestoreFromCheckpoint(const std::wstring& modelFilePath)
{
    // Restore the model's parameters.
    m_combinedTrainingFunction->RestoreModel(modelFilePath);

    // Read the trainer state from the sibling checkpoint file.
    std::wstring trainerStateCheckpointFilePath = GetTrainerStateCheckpointFilePath(modelFilePath);
    auto ckpStream = GetFstream(trainerStateCheckpointFilePath, true);
    Dictionary checkpoint;
    *ckpStream >> checkpoint;

    auto learnerState = checkpoint[learnersPropertyName].Value<std::vector<DictionaryValue>>();
    auto externalState = checkpoint[externalStatePropertyName].Value<Dictionary>();

    m_parameterLearners->RestoreFromCheckpoint(learnerState);

    if (!m_distributed)
        return externalState;

    // In distributed mode, synchronize all workers, then hand back this
    // worker's external state, keyed by its global rank; fall back to
    // worker 0's state if no entry exists for this rank.
    DistributedCommunicatorPtr communicator = MPICommunicator();
    communicator->Barrier();

    auto key = std::to_wstring(communicator->CurrentWorker().m_globalRank);
    if (externalState.Contains(key))
        return externalState[key].Value<Dictionary>();

    return externalState[std::to_wstring(0)].Value<Dictionary>();
}
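// The fallback above matters when a job is restored with a different number
// of workers than wrote the checkpoint: a rank with no saved entry reuses
// worker 0's state. Below is a minimal standalone sketch of that lookup
// policy, with std::map standing in for CNTK's Dictionary; WorkerState,
// ExternalState, and ResolveWorkerState are hypothetical names for
// illustration only, not part of the CNTK API.
#include <map>
#include <string>

using WorkerState = std::wstring;
using ExternalState = std::map<std::wstring, WorkerState>;

// Return the state saved for `globalRank`, or worker 0's state when this
// rank has no entry of its own.
WorkerState ResolveWorkerState(const ExternalState& externalState, size_t globalRank)
{
    auto it = externalState.find(std::to_wstring(globalRank));
    if (it != externalState.end())
        return it->second;
    return externalState.at(std::to_wstring(0));
}

int main()
{
    ExternalState state = { { L"0", L"state-of-rank-0" }, { L"1", L"state-of-rank-1" } };
    // Rank 1 finds its own entry; rank 7 falls back to rank 0's entry.
    return (ResolveWorkerState(state, 1) == L"state-of-rank-1" &&
            ResolveWorkerState(state, 7) == L"state-of-rank-0") ? 0 : 1;
}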
void Trainer::RestoreFromCheckpoint(const std::wstring& modelFilePath)
{
    // Restore the model's parameters.
    m_combinedTrainingFunction->RestoreModel(modelFilePath);

    // Read the trainer state from the sibling checkpoint file.
    std::wstring trainerStateCheckpointFilePath = GetTrainerStateCheckpointFilePath(modelFilePath);
    auto ckpStream = GetFstream(trainerStateCheckpointFilePath, true);
    Dictionary checkpoint;
    *ckpStream >> checkpoint;

    const DictionaryValue& learners = checkpoint[learnersPropertyName];
    const std::vector<DictionaryValue>& learnerStates = learners.Value<std::vector<DictionaryValue>>();

    if (learnerStates.size() != m_parameterLearners.size())
    {
        LogicError("Trainer::RestoreFromCheckpoint: "
                   "Number of learners in the checkpoint (%zu) does not match the expected number (%zu).",
                   learnerStates.size(), m_parameterLearners.size());
    }

    for (size_t i = 0; i < m_parameterLearners.size(); ++i)
        m_parameterLearners[i]->RestoreFromCheckpoint(learnerStates[i].Value<Dictionary>());

    // TODO: we should return shared state from this function; otherwise,
    // how can we be sure the minibatch source is in a consistent state?
    if (m_distributedTrainer)
    {
        const DictionaryValue& distributedLearner = checkpoint[distributedLearnerPropertyName];
        m_distributedTrainer->RestoreFromCheckpoint(distributedLearner.Value<Dictionary>());
    }
}
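// The shape check above refuses checkpoints written by a differently
// configured trainer instead of restoring a misaligned state. Below is a
// self-contained sketch of the same guard, with std::runtime_error in place
// of CNTK's LogicError; Learner and RestoreLearners are hypothetical
// simplifications, not CNTK types.
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical simplified learner: restoring it just records the state.
struct Learner
{
    std::string state;
    void RestoreFromCheckpoint(const std::string& s) { state = s; }
};

// Restore one serialized state per learner, failing loudly on a count
// mismatch rather than silently restoring a partial or misaligned state.
void RestoreLearners(std::vector<Learner>& learners, const std::vector<std::string>& learnerStates)
{
    if (learnerStates.size() != learners.size())
        throw std::runtime_error(
            "Number of learners in the checkpoint (" + std::to_string(learnerStates.size()) +
            ") does not match the expected number (" + std::to_string(learners.size()) + ").");

    for (size_t i = 0; i < learners.size(); ++i)
        learners[i].RestoreFromCheckpoint(learnerStates[i]);
}

int main()
{
    std::vector<Learner> learners(2);
    RestoreLearners(learners, { "sgd-state", "momentum-state" }); // matching shape: succeeds
    return 0;
}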
void Trainer::Save(const std::wstring& modelFilePath, const std::vector<DictionaryValue>& learnerState, const Dictionary& externalState)
{
    Dictionary state;
    state[learnersPropertyName] = learnerState;
    state[externalStatePropertyName] = externalState;

    m_combinedTrainingFunction->SaveModel(modelFilePath);
    std::wstring trainerStateCheckpointFilePath = GetTrainerStateCheckpointFilePath(modelFilePath);
    auto ckpStream = GetFstream(trainerStateCheckpointFilePath, false);
    *ckpStream << state;
    ckpStream->flush();
}
void Trainer::Save(const std::wstring& modelFilePath, bool usingLegacyModelFormat, const Dictionary& distributedLearnerState)
{
    // TODO: add a DictionaryValue constructor that takes an rvalue, so the
    // serialized learner states can be moved here instead of copied.
    std::vector<DictionaryValue> learnerStates;
    for (const auto& learner : m_parameterLearners)
        learnerStates.push_back(DictionaryValue(learner->Serialize()));

    Dictionary state;
    state[learnersPropertyName] = learnerStates;
    state[distributedLearnerPropertyName] = distributedLearnerState;

    m_combinedTrainingFunction->SaveModel(modelFilePath, usingLegacyModelFormat);
    std::wstring trainerStateCheckpointFilePath = GetTrainerStateCheckpointFilePath(modelFilePath);
    auto ckpStream = GetFstream(trainerStateCheckpointFilePath, false);
    *ckpStream << state;
    ckpStream->flush();
}
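// Both Save overloads follow the same two-file convention: the model goes
// to modelFilePath and the trainer state to a sibling file derived from it
// by GetTrainerStateCheckpointFilePath. Below is a minimal sketch of that
// convention, assuming a ".ckp" suffix for the derived path (the actual
// suffix is not shown in this section) and plain text in place of CNTK's
// Dictionary serialization; DeriveTrainerStatePath and SaveCheckpointPair
// are hypothetical names, and narrow strings are used for portability where
// the real code uses std::wstring.
#include <fstream>
#include <string>

std::string DeriveTrainerStatePath(const std::string& modelFilePath)
{
    return modelFilePath + ".ckp"; // assumed suffix
}

// Write the model and the trainer state as a pair, flushing the state stream
// so the checkpoint is on disk before the call returns.
void SaveCheckpointPair(const std::string& modelFilePath,
                        const std::string& model, const std::string& trainerState)
{
    std::ofstream(modelFilePath) << model;

    std::ofstream stateStream(DeriveTrainerStatePath(modelFilePath));
    stateStream << trainerState;
    stateStream.flush();
}

int main()
{
    SaveCheckpointPair("model.dnn", "model-bytes", "learner-and-external-state");
    return 0;
}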