void Trainer::SaveCheckpoint(const std::wstring& modelFilePath, Dictionary externalState) { auto learnersState = m_parameterLearners->CreateCheckpoint(); if (!m_distributed) return Save(modelFilePath, learnersState, externalState); // Collect distrbuted external state. DistributedCommunicatorPtr communicator = MPICommunicator(); communicator->Barrier(); std::vector<DictionaryPtr> remoteState; communicator->Gather(externalState, remoteState, communicator->Workers()); Dictionary aggregatedState; for (const auto& w : communicator->Workers()) { aggregatedState[std::to_wstring(w.m_globalRank)] = *remoteState[w.m_globalRank]; } if (communicator->CurrentWorker().IsMain()) Save(modelFilePath, learnersState, aggregatedState); // all workers need to sync up after saving model to avoid read-after-write hazard // i.e. one worker is in the middle of write while another tries to read communicator->Barrier(); }
// Restores the trainer from a checkpoint produced by SaveCheckpoint: model
// parameters, learner state, and the external state originally supplied by the
// caller. Returns this worker's external state Dictionary (falling back to the
// main worker's state when no per-worker entry exists).
Dictionary Trainer::RestoreFromCheckpoint(const std::wstring& modelFilePath)
{
    // Restore the model's parameters
    m_combinedTrainingFunction->RestoreModel(modelFilePath);

    std::wstring trainerStateCheckpointFilePath = GetTrainerStateCheckpointFilePath(modelFilePath);
    auto ckpStream = GetFstream(trainerStateCheckpointFilePath, true);
    Dictionary checkpoint;
    *ckpStream >> checkpoint;

    auto learnerState = checkpoint[learnersPropertyName].Value<std::vector<DictionaryValue>>();
    auto externalState = checkpoint[externalStatePropertyName].Value<Dictionary>();

    // Learner state is restored identically in both the local and the distributed
    // case, so do it once before branching (previously duplicated in each branch).
    m_parameterLearners->RestoreFromCheckpoint(learnerState);

    if (!m_distributed)
        return externalState;

    // Barrier ensures nobody starts writing new checkpoint files until every
    // worker has finished reading this one.
    DistributedCommunicatorPtr communicator = MPICommunicator();
    communicator->Barrier();

    // External state was saved keyed by global worker rank; fall back to the
    // main worker's (rank 0) state if this rank has no entry.
    auto key = std::to_wstring(communicator->CurrentWorker().m_globalRank);
    if (externalState.Contains(key))
        return externalState[key].Value<Dictionary>();

    return externalState[std::to_wstring(0)].Value<Dictionary>();
}
// Restores the trainer from a checkpoint written by SaveCheckpoint: model
// parameters, learner state, and the external/progress state. Returns the
// external state Dictionary for this worker. Handles two checkpoint layouts,
// distinguished by the stored version (version 0 predates the version field).
Dictionary Trainer::RestoreFromCheckpoint(const std::wstring& modelFilePath)
{
    // Restore the model's parameters
    m_combinedTrainingFunction->Restore(modelFilePath);

    Dictionary checkpoint = Dictionary::Load(GetTrainerStateCheckpointFilePath(modelFilePath));

    // Checkpoints written before the version property existed are treated as version 0.
    size_t version = 0;

    if (checkpoint.Contains(versionPropertyName))
        version = checkpoint[versionPropertyName].Value<size_t>();

    auto learnerState = checkpoint[learnersPropertyName].Value<std::vector<DictionaryValue>>();
    auto externalState = checkpoint[externalStatePropertyName].Value<Dictionary>();

    m_parameterLearners->RestoreFromCheckpoint(learnerState);

    if (!m_distributed)
    {
        return externalState;
    }

    // this ensures that nobody will start writing to the model/checkpoint files, until
    // everybody is done reading them.
    DistributedCommunicatorPtr communicator = MPICommunicator();
    communicator->Barrier();

    auto mainWorkerId = std::to_wstring(0);
    auto localWorkerId = std::to_wstring(communicator->CurrentWorker().m_globalRank);

    // before version 1, there was no distributed state per se. Instead, the external state
    // contained a dictionary of worker-specific external states.
    if (version == 0)
    {
        // Fall back to the main worker's (rank 0) entry if this rank has none recorded.
        auto key = externalState.Contains(localWorkerId) ? localWorkerId : mainWorkerId;
        return externalState[key].Value<Dictionary>();
    }

    Dictionary distributedState = checkpoint[distributedStatePropertyName].Value<Dictionary>();

    // The main worker — or any worker without a recorded per-worker entry — uses
    // the shared external state as-is.
    if (communicator->CurrentWorker().IsMain() || !distributedState.Contains(localWorkerId))
    {
        return externalState;
    }

    // the checkpoint contains internal state for this worker.
    Dictionary localState = distributedState[localWorkerId].Value<Dictionary>();

    auto internalState = localState[internalWorkerStateKey].Value<Dictionary>();

    // SetInternalState is only available on CompositeFunction, so the combined
    // training function must be one.
    auto compositeFunction = std::dynamic_pointer_cast<CompositeFunction>(m_combinedTrainingFunction);
    if (compositeFunction == nullptr)
        RuntimeError("Combined training function is not a CompositeFunction.");

    // this assumes the compositeFunction (restored from a checkpoint made by the main node) and
    // the internal worker state both have identical UIDs.
    compositeFunction->SetInternalState(internalState);

    return localState[externalWorkerStateKey].Value<Dictionary>();
}