Example #1
    void Trainer::SaveCheckpoint(const std::wstring& modelFilePath, Dictionary externalState)
    {
        auto learnersState = m_parameterLearners->CreateCheckpoint();
        if (!m_distributed)
            return Save(modelFilePath, learnersState, externalState);

        // Collect distributed external state.
        DistributedCommunicatorPtr communicator = MPICommunicator();
        communicator->Barrier();

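        // Gather the external state dictionaries from all workers; remoteState is indexed by global rank.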
        std::vector<DictionaryPtr> remoteState;
        communicator->Gather(externalState, remoteState, communicator->Workers());

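        // Merge the gathered dictionaries into a single checkpoint dictionary keyed by global rank.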
        Dictionary aggregatedState;
        for (const auto& w : communicator->Workers())
        {
            aggregatedState[std::to_wstring(w.m_globalRank)] = *remoteState[w.m_globalRank];
        }

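        // Only the main worker writes the checkpoint file; it embeds every worker's state.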
        if (communicator->CurrentWorker().IsMain())
            Save(modelFilePath, learnersState, aggregatedState);

        // All workers need to sync up after saving the model to avoid a read-after-write hazard,
        // i.e. one worker is still in the middle of writing while another tries to read.
        communicator->Barrier();
    }
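
In the distributed path, every worker gathers its external state, the main rank writes a single checkpoint keyed by global rank, and the final barrier keeps readers from racing the write. A minimal calling sketch follows; the trainer pointer, the checkpoint path, and the `epoch` entry are illustrative assumptions, not taken from the snippet above.

    // Sketch: calling SaveCheckpoint with caller-defined external state (hypothetical names).
    Dictionary externalState;
    externalState[L"epoch"] = currentEpoch;                 // anything the caller wants restored later
    trainer->SaveCheckpoint(L"model.ckpt", externalState);  // every rank must call this in the distributed case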
Example #2
    bool Evaluator::TestMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, std::pair<ValuePtr, size_t>& result, const DeviceDescriptor& computeDevice, bool distributed)
    {
        result = TestLocalMinibatch(arguments, outputsToFetch, computeDevice);
        if (distributed)
        {
            if (!outputsToFetch.empty())
                RuntimeError("Custom outputs are not yet supported in distributed evaluation.");

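            // Pack the local error value and the sample count so a single in-place aggregation sums both across workers.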
            double localSampleCount = static_cast<double>(result.second);

            auto values = std::vector<NDArrayViewPtr>{ result.first->Data(), MakeSharedObject<NDArrayView>(NDShape{}, &localSampleCount, 1, DeviceDescriptor::CPUDevice()) };
            DistributedCommunicatorPtr communicator = MPICommunicator();
            communicator->AggregateInPlace(values, communicator->Workers());
            result.second = static_cast<size_t>(localSampleCount);
        }

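        // A sample count of zero means there was no data for this minibatch (on any worker, in the distributed case).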
        bool hasData = (result.second != 0);
        if (hasData)
            UpdateTestProgress(result.second, result.first, computeDevice);

        return hasData;
    }
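
TestMinibatch first evaluates locally and then, in the distributed case, all-reduces the error value together with the sample count so every worker ends up with the aggregated result. A minimal calling sketch follows; the evaluator pointer and the input/label variables and values are illustrative assumptions.

    // Sketch: distributed evaluation of one minibatch (hypothetical variables and values).
    std::unordered_map<Variable, ValuePtr> arguments = { { inputVar, inputValue }, { labelVar, labelValue } };
    std::unordered_map<Variable, ValuePtr> outputsToFetch;  // must stay empty for distributed evaluation
    std::pair<ValuePtr, size_t> result;
    bool hasData = evaluator->TestMinibatch(arguments, outputsToFetch, result,
                                            DeviceDescriptor::UseDefaultDevice(), /*distributed=*/true);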