void Trainer::SaveCheckpoint(const std::wstring& modelFilePath, Dictionary externalState)
{
    auto learnersState = m_parameterLearners->CreateCheckpoint();
    if (!m_distributed)
        return Save(modelFilePath, learnersState, externalState);

    // Collect distributed external state from all workers.
    DistributedCommunicatorPtr communicator = MPICommunicator();
    communicator->Barrier();

    std::vector<DictionaryPtr> remoteState;
    communicator->Gather(externalState, remoteState, communicator->Workers());

    // Key each worker's state by its global rank so it can be restored per worker later.
    Dictionary aggregatedState;
    for (const auto& w : communicator->Workers())
    {
        aggregatedState[std::to_wstring(w.m_globalRank)] = *remoteState[w.m_globalRank];
    }

    if (communicator->CurrentWorker().IsMain())
        Save(modelFilePath, learnersState, aggregatedState);

    // All workers need to sync up after saving the model to avoid a read-after-write hazard,
    // i.e. one worker is still in the middle of the write while another tries to read.
    communicator->Barrier();
}
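// Usage sketch (illustrative, not part of the original source): because the Gather and
// Barrier above are collective operations, SaveCheckpoint must be called on every worker
// in a distributed run, even though only the main worker writes the file. The names
// 'trainer', 'checkpointFile', 'minibatchIndex', and 'checkpointFrequency' below are
// assumptions about the caller's scope.
//
//     if ((minibatchIndex + 1) % checkpointFrequency == 0)
//         trainer->SaveCheckpoint(checkpointFile); // call on ALL ranks, not just the main one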
bool Evaluator::TestMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, std::pair<ValuePtr, size_t>& result, const DeviceDescriptor& computeDevice, bool distributed)
{
    // Evaluate the minibatch locally first; result holds the local error value and sample count.
    result = TestLocalMinibatch(arguments, outputsToFetch, computeDevice);
    if (distributed)
    {
        if (!outputsToFetch.empty())
            RuntimeError("Custom outputs are not yet supported in distributed evaluation.");

        // Wrap the local sample count in a CPU NDArrayView so AggregateInPlace can sum it,
        // together with the local error value, across all workers.
        double localSampleCount = static_cast<double>(result.second);
        auto values = std::vector<NDArrayViewPtr>{ result.first->Data(), MakeSharedObject<NDArrayView>(NDShape{}, &localSampleCount, 1, DeviceDescriptor::CPUDevice()) };

        DistributedCommunicatorPtr communicator = MPICommunicator();
        communicator->AggregateInPlace(values, communicator->Workers());

        // localSampleCount now holds the global sample count.
        result.second = static_cast<size_t>(localSampleCount);
    }

    bool hasData = (result.second != 0);
    if (hasData)
        UpdateTestProgress(result.second, result.first, computeDevice);

    return hasData;
}
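// Usage sketch (illustrative, not part of the original source): with distributed == true,
// TestMinibatch issues a collective AggregateInPlace, so every worker must call it for each
// evaluation step and outputsToFetch must stay empty. The names 'evaluator', 'minibatchData',
// and 'device' are assumptions about the caller's scope.
//
//     std::unordered_map<Variable, ValuePtr> outputs; // must remain empty in distributed mode
//     std::pair<ValuePtr, size_t> evalResult;
//     bool hasData = evaluator->TestMinibatch(minibatchData, outputs, evalResult, device, /*distributed=*/true);
//     // evalResult.second now holds the sample count aggregated across all workers.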