/// Reads a Caffe network description, and optionally its trained weights.
///
/// @param pototxt    Path to the text-format network definition; it is parsed
///                   into the `net` member.
/// @param caffeModel Path to the binary model file. It is parsed into the
///                   `netBinary` member only when the pointer is non-null and
///                   the string is non-empty; otherwise it is ignored.
///
/// NOTE(review): both readers are the "...OrDie" variants, so a failed read is
/// presumably fatal rather than reported to the caller -- confirm against the
/// reader implementations.
CaffeImporter(const char *pototxt, const char *caffeModel)
{
    CV_TRACE_FUNCTION();

    ReadNetParamsFromTextFileOrDie(pototxt, &net);

    // The weights file is optional: skip it for a null or empty path.
    const bool hasModelPath = caffeModel && caffeModel[0] != '\0';
    if (hasModelPath)
    {
        ReadNetParamsFromBinaryFileOrDie(caffeModel, &netBinary);
    }
}
void Solver::InitTrainNet() { const int num_train_nets = param_.has_net() + param_.has_net_param() + param_.has_train_net() + param_.has_train_net_param(); const string& field_names = "net, net_param, train_net, train_net_param"; CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net " << "using one of these fields: " << field_names; CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than " << "one of these fields specifying a train_net: " << field_names; NetParameter net_param; if (param_.has_train_net_param()) { LOG_IF(INFO, Caffe::root_solver()) << "Creating training net specified in train_net_param."; net_param.CopyFrom(param_.train_net_param()); } else if (param_.has_train_net()) { LOG_IF(INFO, Caffe::root_solver()) << "Creating training net from train_net file: " << param_.train_net(); ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param); } if (param_.has_net_param()) { LOG_IF(INFO, Caffe::root_solver()) << "Creating training net specified in net_param."; net_param.CopyFrom(param_.net_param()); } if (param_.has_net()) { LOG_IF(INFO, Caffe::root_solver()) << "Creating training net from net file: " << param_.net(); ReadNetParamsFromTextFileOrDie(param_.net(), &net_param); } // Set the correct NetState. We start with the solver defaults (lowest // precedence); then, merge in any NetState specified by the net_param itself; // finally, merge in any NetState specified by the train_state (highest // precedence). NetState net_state; net_state.set_phase(TRAIN); net_state.MergeFrom(net_param.state()); net_state.MergeFrom(param_.train_state()); net_param.mutable_state()->CopyFrom(net_state); if (Caffe::root_solver()) { net_.reset(new Net(net_param, rank_, &init_flag_, &iter0_flag_)); } else { net_.reset(new Net(net_param, rank_, &init_flag_, &iter0_flag_, root_solver_->net_.get())); } }
// Builds every test net requested by the SolverParameter and stores them in
// test_nets_, in this order: nets from test_net_param, nets from test_net
// files, then instances of the generic net (net_param or net), if any.
//
// The configuration is validated with CHECKs along the way: at most one
// generic net source, a test_iter entry for each test net, test_state either
// absent or given once per net, and a positive test_interval whenever there is
// at least one test net.
void Solver<Dtype>::InitTestNets() {
  CHECK(Caffe::root_solver());

  // A "generic" net is one given through net_param / net instead of the
  // dedicated test_net_param / test_net fields.
  const bool generic_from_param = param_.has_net_param();
  const bool generic_from_file = param_.has_net();
  const int num_generic_nets =
      (generic_from_param ? 1 : 0) + (generic_from_file ? 1 : 0);
  CHECK_LE(num_generic_nets, 1)
      << "Both net_param and net_file may not be specified.";

  const int explicit_param_count = param_.test_net_param_size();
  const int explicit_file_count = param_.test_net_size();
  const int num_test_nets = explicit_param_count + explicit_file_count;
  if (num_generic_nets == 0) {
    // Without a generic net, test_iter must match the explicit nets exactly.
    CHECK_EQ(param_.test_iter_size(), num_test_nets)
        << "test_iter must be specified for each test network.";
  } else {
    // With a generic net, surplus test_iter entries are allowed.
    CHECK_GE(param_.test_iter_size(), num_test_nets)
        << "test_iter must be specified for each test network.";
  }

  // A generic net can be instantiated any number of times: once for every
  // test_iter entry left over after the explicitly listed test nets have each
  // claimed one.
  const int generic_instance_count = param_.test_iter_size() - num_test_nets;
  const int num_test_net_instances = num_test_nets + generic_instance_count;
  if (param_.test_state_size() != 0) {
    CHECK_EQ(param_.test_state_size(), num_test_net_instances)
        << "test_state must be unspecified or specified once per test net.";
  }
  if (num_test_net_instances != 0) {
    CHECK_GT(param_.test_interval(), 0);
  }

  // origin[k] is the human-readable source logged for net k; configs[k] is
  // its definition. `slot` is the next index to fill.
  vector<string> origin(num_test_net_instances);
  vector<NetParameter> configs(num_test_net_instances);
  int slot = 0;

  for (int k = 0; k < explicit_param_count; ++k) {
    origin[slot] = "test_net_param";
    configs[slot].CopyFrom(param_.test_net_param(k));
    ++slot;
  }
  for (int k = 0; k < explicit_file_count; ++k) {
    origin[slot] = "test_net file: " + param_.test_net(k);
    ReadNetParamsFromTextFileOrDie(param_.test_net(k), &configs[slot]);
    ++slot;
  }

  // Whatever test_iter entries are still unclaimed go to the generic net.
  const int unfilled = param_.test_iter_size() - slot;
  if (generic_from_param) {
    for (int k = 0; k < unfilled; ++k) {
      origin[slot] = "net_param";
      configs[slot].CopyFrom(param_.net_param());
      ++slot;
    }
  }
  if (generic_from_file) {
    for (int k = 0; k < unfilled; ++k) {
      origin[slot] = "net file: " + param_.net();
      ReadNetParamsFromTextFileOrDie(param_.net(), &configs[slot]);
      ++slot;
    }
  }

  test_nets_.resize(num_test_net_instances);
  for (int n = 0; n < num_test_net_instances; ++n) {
    // NetState precedence, lowest to highest: the solver default (TEST phase),
    // then the state embedded in the net definition, then the matching
    // test_state entry when test_state was provided.
    NetState state;
    state.set_phase(TEST);
    state.MergeFrom(configs[n].state());
    if (param_.test_state_size() != 0) {
      state.MergeFrom(param_.test_state(n));
    }
    configs[n].mutable_state()->CopyFrom(state);

    LOG(INFO)
        << "Creating test net (#" << n << ") specified by " << origin[n];
    if (!Caffe::root_solver()) {
      // Non-root solvers also pass the root solver's corresponding test net.
      test_nets_[n].reset(new Net<Dtype>(configs[n],
          root_solver_->test_nets_[n].get()));
    } else {
      test_nets_[n].reset(new Net<Dtype>(configs[n]));
    }
    test_nets_[n]->set_debug_info(param_.debug_info());
  }
}
// Constructs a Net from a network definition file: the file at `param_file`
// is parsed as a text-format NetParameter, and the result is passed to Init().
Net<Dtype>::Net(const string& param_file) {
  NetParameter param;
  // The reader fills the message through a pointer, so pass the local's
  // address.
  ReadNetParamsFromTextFileOrDie(param_file, &param);
  Init(param);
}