예제 #1
0
파일: Trainer.cpp 프로젝트: Microsoft/CNTK
    // Restores the trainer from the checkpoint associated with modelFilePath and
    // returns the external state stored in it.
    //
    // Steps performed:
    //   1. The model is restored through m_combinedTrainingFunction->RestoreModel.
    //   2. The trainer-state dictionary is read from the file whose path is derived
    //      by GetTrainerStateCheckpointFilePath(modelFilePath).
    //   3. The learner entries of that dictionary are handed to m_parameterLearners.
    //
    // Return value:
    //   - not distributed: the external-state dictionary exactly as stored.
    //   - distributed: the sub-dictionary keyed by this worker's global rank, or,
    //     when no such key exists, the sub-dictionary keyed by "0".
    Dictionary Trainer::RestoreFromCheckpoint(const std::wstring& modelFilePath)
    {
        m_combinedTrainingFunction->RestoreModel(modelFilePath);

        // Deserialize the trainer state that accompanies the model file.
        std::wstring statePath = GetTrainerStateCheckpointFilePath(modelFilePath);
        auto stateStream = GetFstream(statePath, true);
        Dictionary savedState;
        *stateStream >> savedState;

        auto learnerCheckpoints = savedState[learnersPropertyName].Value<std::vector<DictionaryValue>>();
        auto savedExternalState = savedState[externalStatePropertyName].Value<Dictionary>();

        // Learners are restored the same way regardless of the distribution mode.
        m_parameterLearners->RestoreFromCheckpoint(learnerCheckpoints);

        if (m_distributed)
        {
            // Wait on the barrier before handing back the per-worker state
            // (presumably so that all workers finish reading first -- confirm).
            DistributedCommunicatorPtr mpi = MPICommunicator();
            mpi->Barrier();

            // Use this worker's own entry when present, otherwise the entry for rank 0.
            std::wstring rankKey = std::to_wstring(mpi->CurrentWorker().m_globalRank);
            std::wstring chosenKey = savedExternalState.Contains(rankKey) ? rankKey : std::to_wstring(0);
            return savedExternalState[chosenKey].Value<Dictionary>();
        }

        return savedExternalState;
    }
예제 #2
0
파일: Trainer.cpp 프로젝트: AllanYiin/CNTK
    // Restores the trainer from the checkpoint associated with modelFilePath and
    // returns the external state that applies to the calling worker.
    //
    // The model is restored via m_combinedTrainingFunction->Restore, the trainer
    // state dictionary is loaded from GetTrainerStateCheckpointFilePath(modelFilePath),
    // and the learner state found in it is passed to m_parameterLearners.
    //
    // Return value:
    //   - not distributed: the stored external state.
    //   - distributed, checkpoint version 0 (no version entry): the external-state
    //     entry keyed by this worker's global rank, else the entry keyed by "0".
    //   - distributed, version >= 1: the stored external state for the main worker
    //     or when the distributed state has no entry for this worker; otherwise the
    //     worker's own external state, after its internal state has been applied
    //     to the composite training function.
    //
    // Calls RuntimeError when worker-specific internal state must be applied but
    // m_combinedTrainingFunction is not a CompositeFunction.
    Dictionary Trainer::RestoreFromCheckpoint(const std::wstring& modelFilePath)
    {
        m_combinedTrainingFunction->Restore(modelFilePath);

        Dictionary savedState = Dictionary::Load(GetTrainerStateCheckpointFilePath(modelFilePath));

        // Checkpoints without a version entry are treated as version 0.
        const size_t formatVersion = savedState.Contains(versionPropertyName)
            ? savedState[versionPropertyName].Value<size_t>()
            : 0;

        auto learnerCheckpoints = savedState[learnersPropertyName].Value<std::vector<DictionaryValue>>();
        auto externalState = savedState[externalStatePropertyName].Value<Dictionary>();

        m_parameterLearners->RestoreFromCheckpoint(learnerCheckpoints);

        if (!m_distributed)
            return externalState;

        // Nobody may start writing to the model/checkpoint files until every
        // worker has finished reading them; the barrier enforces that.
        DistributedCommunicatorPtr mpi = MPICommunicator();
        mpi->Barrier();

        auto mainKey = std::to_wstring(0);
        auto localKey = std::to_wstring(mpi->CurrentWorker().m_globalRank);

        // Version 0 had no separate distributed state: the external state itself
        // was a dictionary of worker-specific external states.
        if (formatVersion == 0)
        {
            if (externalState.Contains(localKey))
                return externalState[localKey].Value<Dictionary>();

            return externalState[mainKey].Value<Dictionary>();
        }

        Dictionary perWorkerState = savedState[distributedStatePropertyName].Value<Dictionary>();

        // Only a non-main worker with its own entry has worker-specific state to restore.
        const bool hasOwnState = !mpi->CurrentWorker().IsMain() && perWorkerState.Contains(localKey);
        if (!hasOwnState)
            return externalState;

        // The checkpoint carries internal state for this worker.
        Dictionary workerState = perWorkerState[localKey].Value<Dictionary>();
        auto workerInternalState = workerState[internalWorkerStateKey].Value<Dictionary>();

        auto composite = std::dynamic_pointer_cast<CompositeFunction>(m_combinedTrainingFunction);
        if (!composite)
            RuntimeError("Combined training function is not a CompositeFunction.");

        // Assumption: the composite function (restored from a checkpoint made by the
        // main node) and the worker's internal state use identical UIDs.
        composite->SetInternalState(workerInternalState);

        return workerState[externalWorkerStateKey].Value<Dictionary>();
    }
예제 #3
0
파일: Trainer.cpp 프로젝트: Microsoft/CNTK
    void Trainer::SaveCheckpoint(const std::wstring& modelFilePath, Dictionary externalState)
    {
        auto learnersState = m_parameterLearners->CreateCheckpoint();
        if (!m_distributed)
            return Save(modelFilePath, learnersState, externalState);

        // Collect distrbuted external state.
        DistributedCommunicatorPtr communicator = MPICommunicator();
        communicator->Barrier();

        std::vector<DictionaryPtr> remoteState;
        communicator->Gather(externalState, remoteState, communicator->Workers());

        Dictionary aggregatedState;
        for (const auto& w : communicator->Workers())
        {
            aggregatedState[std::to_wstring(w.m_globalRank)] = *remoteState[w.m_globalRank];
        }

        if (communicator->CurrentWorker().IsMain())
            Save(modelFilePath, learnersState, aggregatedState);

        // all workers need to sync up after saving model to avoid read-after-write hazard
        // i.e. one worker is in the middle of write while another tries to read
        communicator->Barrier();
    }
예제 #4
0
    bool Evaluator::TestMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, std::pair<ValuePtr, size_t>& result, const DeviceDescriptor& computeDevice, bool distributed)
    {
        result = TestLocalMinibatch(arguments, outputsToFetch, computeDevice);
        if (distributed)
        {
            if (!outputsToFetch.empty())
                RuntimeError("Custom outputs are not yet supported in distributed evaluation.");

            double localSampleCount = static_cast<double>(result.second);

            auto values = std::vector<NDArrayViewPtr>{ result.first->Data(), MakeSharedObject<NDArrayView>(NDShape{}, &localSampleCount, 1, DeviceDescriptor::CPUDevice()) };
            DistributedCommunicatorPtr communicator = MPICommunicator();
            communicator->AggregateInPlace(values, communicator->Workers());
            result.second = static_cast<size_t>(localSampleCount);
        }

        bool hasData = (result.second != 0);
        if (hasData)
            UpdateTestProgress(result.second, result.first, computeDevice);

        return hasData;
    }