// Restores this solver from the binary SolverState file at `state_file`.
//
// Steps, in order:
//   1. Parse `state_file` into a SolverState message.
//   2. If the message carries a learned-net path, log it, read that file as
//      a NetParameter and copy its trained layers into `net_`.
//   3. Restore the iteration counter, then delegate the rest of the state
//      to RestoreSolverState().
//
// NOTE(review): how a missing or corrupt file is handled depends on
// ReadProtoFromBinaryFile, which is not visible here — confirm whether its
// result needs to be checked.
void FeedbackSolver<Dtype>::Restore(const char* state_file) {
  SolverState snapshot;
  ReadProtoFromBinaryFile(state_file, &snapshot);

  // Weights are only reloaded when the snapshot names a learned-net file.
  if (snapshot.has_learned_net()) {
    LOG(INFO)<<"Copy trained model from "<<snapshot.learned_net();
    NetParameter trained_params;
    ReadProtoFromBinaryFile(snapshot.learned_net().c_str(), &trained_params);
    net_->CopyTrainedLayersFrom(trained_params);
  }

  iter_ = snapshot.iter();
  RestoreSolverState(snapshot);
}
// Restores this solver from the binary SolverState file at `state_file`.
//
// The state file is parsed into a SolverState message. When that message
// names a learned-net file, the net parameters are read from it and their
// trained layers are copied into `net_`. Afterwards the iteration counter
// and the current step are taken from the message, and the remaining state
// is handed to RestoreSolverState().
//
// NOTE(review): judging by its name, ReadNetParamsFromBinaryFileOrDie
// presumably aborts on failure, whereas the ReadProtoFromBinaryFile result
// is not inspected here — confirm against their definitions.
void Solver<Dtype>::Restore(const char* state_file) {
  SolverState snapshot;
  ReadProtoFromBinaryFile(state_file, &snapshot);

  // Weights are only reloaded when the snapshot names a learned-net file.
  const bool has_weights = snapshot.has_learned_net();
  if (has_weights) {
    NetParameter trained_params;
    ReadNetParamsFromBinaryFileOrDie(snapshot.learned_net().c_str(),
                                     &trained_params);
    net_->CopyTrainedLayersFrom(trained_params);
  }

  // Progress counters are restored before the solver-specific state.
  iter_ = snapshot.iter();
  current_step_ = snapshot.current_step();
  RestoreSolverState(snapshot);
}
// Restores the SGD solver from the binary SolverState file `state_file`.
//
// Steps, in order:
//   1. Parse `state_file` into a SolverState message.
//   2. Restore the iteration counter.
//   3. If the message names a learned-net file, read its net parameters and
//      copy the trained layers into the net.
//   4. Restore the current step.
//   5. Check that the message holds exactly one history entry per element
//      of `history_`, then load each entry into the matching element via
//      FromProto().
//
// NOTE(review): the counter and the net weights are already overwritten by
// the time the history-size check runs; this only matters if a failed
// CHECK_EQ does not terminate the process — confirm.
void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
    const string& state_file) {
  SolverState snapshot;
  ReadProtoFromBinaryFile(state_file, &snapshot);

  this->iter_ = snapshot.iter();

  // Weights are only reloaded when the snapshot names a learned-net file.
  if (snapshot.has_learned_net()) {
    NetParameter trained_params;
    ReadNetParamsFromBinaryFileOrDie(snapshot.learned_net().c_str(),
                                     &trained_params);
    this->net_->CopyTrainedLayersFrom(trained_params);
  }

  this->current_step_ = snapshot.current_step();

  // The serialized history must line up one-to-one with history_.
  CHECK_EQ(snapshot.history_size(), history_.size())
      << "Incorrect length of history blobs.";

  LOG(INFO) << "SGDSolver: restoring history";
  for (int idx = 0; idx < history_.size(); ++idx) {
    history_[idx]->FromProto(snapshot.history(idx));
  }
}