Exemple #1
0
    // Restores the training function and the parameter learners from the checkpoint
    // that belongs to `modelFilePath`, then returns the external state stored in it.
    //
    // In the non-distributed case the whole external-state dictionary is returned.
    // In the distributed case the entry keyed by this worker's global rank is
    // returned, or the entry keyed by "0" when no entry for this rank exists.
    Dictionary Trainer::RestoreFromCheckpoint(const std::wstring& modelFilePath)
    {
        // Model parameters are restored first.
        m_combinedTrainingFunction->RestoreModel(modelFilePath);

        // Read the trainer-state dictionary from its companion checkpoint file.
        std::wstring statePath = GetTrainerStateCheckpointFilePath(modelFilePath);
        auto stateStream = GetFstream(statePath, true);
        Dictionary trainerState;
        *stateStream >> trainerState;

        auto savedLearners = trainerState[learnersPropertyName].Value<std::vector<DictionaryValue>>();
        auto savedExternal = trainerState[externalStatePropertyName].Value<Dictionary>();

        // The learners are restored the same way whether or not training is distributed.
        m_parameterLearners->RestoreFromCheckpoint(savedLearners);

        if (m_distributed)
        {
            // Wait on the communicator barrier before selecting this worker's state.
            // NOTE(review): presumably this keeps workers in step around checkpoint
            // file access -- confirm against the communicator's contract.
            DistributedCommunicatorPtr communicator = MPICommunicator();
            communicator->Barrier();

            // External state is looked up by the worker's global rank; fall back to
            // the entry for rank 0 when this rank has none.
            auto rankKey = std::to_wstring(communicator->CurrentWorker().m_globalRank);
            if (!savedExternal.Contains(rankKey))
                rankKey = std::to_wstring(0);

            return savedExternal[rankKey].Value<Dictionary>();
        }

        return savedExternal;
    }
Exemple #2
0
    // Restores the trainer from the checkpoint that belongs to `modelFilePath`:
    //   1. restores the combined training function's model,
    //   2. reads the trainer-state dictionary from the companion checkpoint file,
    //   3. restores every parameter learner from its saved state, and
    //   4. restores the distributed trainer's state when one is configured.
    //
    // Calls LogicError when the number of learner states in the checkpoint differs
    // from the number of parameter learners held by this trainer.
    void Trainer::RestoreFromCheckpoint(const std::wstring& modelFilePath)
    {
        // Restore the model's parameters
        m_combinedTrainingFunction->RestoreModel(modelFilePath);

        std::wstring trainerStateCheckpointFilePath = GetTrainerStateCheckpointFilePath(modelFilePath);
        auto ckpStream = GetFstream(trainerStateCheckpointFilePath, true);
        Dictionary checkpoint;
        *ckpStream >> checkpoint;

        // Both references point into `checkpoint`, which outlives every use below.
        const DictionaryValue& learners = checkpoint[learnersPropertyName];
        const vector<DictionaryValue>& learnerStates = learners.Value<vector<DictionaryValue>>();

        if (learnerStates.size() != m_parameterLearners.size())
        {
            LogicError("Trainer::RestoreFromCheckpoint: "
                       "Number of learners in the checkpoint (%zu) does not match the expected number (%zu)",
                       learnerStates.size(), m_parameterLearners.size());
        }

        // Saved states are matched to learners by position. The index is unsigned so
        // that the comparison against size() does not mix signed and unsigned values.
        for (size_t i = 0; i < m_parameterLearners.size(); ++i)
        {
            m_parameterLearners[i]->RestoreFromCheckpoint(learnerStates[i].Value<Dictionary>());
        }

        // TODO: we should return shared state from this function,
        // otherwise how can we be sure the minibatch source is in consistent state?
        if (m_distributedTrainer)
        {
            const DictionaryValue& distributedLearner = checkpoint[distributedLearnerPropertyName];
            m_distributedTrainer->RestoreFromCheckpoint(distributedLearner.Value<Dictionary>());
        }
    }
Exemple #3
0
    void Trainer::Save(const std::wstring& modelFilePath, const std::vector<DictionaryValue>& learnerState, const Dictionary& externalState)
    {
        Dictionary state;
        state[learnersPropertyName] = learnerState;
        state[externalStatePropertyName] = externalState;

        m_combinedTrainingFunction->SaveModel(modelFilePath);
        std::wstring trainerStateCheckpointFilePath = GetTrainerStateCheckpointFilePath(modelFilePath);
        auto ckpStream = GetFstream(trainerStateCheckpointFilePath, false);
        *ckpStream << state;
        ckpStream->flush();
    }
Exemple #4
0
    // Saves a checkpoint for `modelFilePath`: serializes every parameter learner,
    // writes the model through the combined training function (passing
    // `usinglegacyModelFormat` through to SaveModel), and writes a dictionary holding
    // the learner states and `distributedLearnerState` to the companion trainer-state
    // checkpoint file.
    void Trainer::Save(const std::wstring& modelFilePath, bool usinglegacyModelFormat, const Dictionary& distributedLearnerState)
    {
        // One serialized state per learner, in the learners' iteration order. The
        // final size is known up front, so reserve once instead of growing repeatedly.
        vector<DictionaryValue> learnerStates;
        learnerStates.reserve(m_parameterLearners.size());
        for (const auto& learner : m_parameterLearners)
        {
            // Construct the DictionaryValue in place from the serialized dictionary.
            // TODO: add a DictionaryValue(T&&) constructor that takes an rvalue.
            learnerStates.emplace_back(learner->Serialize());
        }

        Dictionary state;
        state[learnersPropertyName] = learnerStates;
        state[distributedLearnerPropertyName] = distributedLearnerState;

        // The model is written first, then the trainer state.
        m_combinedTrainingFunction->SaveModel(modelFilePath, usinglegacyModelFormat);
        std::wstring trainerStateCheckpointFilePath = GetTrainerStateCheckpointFilePath(modelFilePath);
        auto ckpStream = GetFstream(trainerStateCheckpointFilePath, false);
        *ckpStream << state;
        ckpStream->flush();
    }