コード例 #1
0
// Load the mean file in binaryproto format.
int DeepFeatureExtractor::SetMean(
  const std::string& meanfile) 
{
  BlobProto blob_proto;
  ReadProtoFromBinaryFileOrDie(meanfile.c_str(), &blob_proto);
  // Convert from BlobProto to Blob<float> 
  Blob<float> meanblob;
  meanblob.FromProto(blob_proto);
  CHECK_EQ(meanblob.channels(), m_num_channels)
    << "Number of channels of mean file doesn't match input layer.";
  // The format of the mean file is planar 32-bit float BGR or grayscale.
  std::vector<cv::Mat> channels;
  float* data = meanblob.mutable_cpu_data();
  for (unsigned int i = 0; i < m_num_channels; ++i) {
    // Extract an individual channel.
    cv::Mat channel(meanblob.height(), meanblob.width(), CV_32FC1, data);
    channels.push_back(channel);
    data += meanblob.height() * meanblob.width();
  }
  // Merge the separate channels into a single image.
  cv::Mat mean;
  cv::merge(channels, mean);
  // Compute the global mean pixel value and create a mean image
  // filled with this value. 
  cv::Scalar channel_mean = cv::mean(mean);
  m_mean = cv::Mat(m_input_geometry, mean.type(), channel_mean);
  return 0;
}
コード例 #2
0
void CompactDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top) {
  // Labels are emitted unless the layer has exactly one top blob.
  this->output_labels_ = (top->size() != 1);
  DataLayerSetUp(bottom, top);
  // The subclasses should setup the datum channels, height and width
  CHECK_GT(this->datum_channels_, 0);
  CHECK_GT(this->datum_height_, 0);
  CHECK_GT(this->datum_width_, 0);
  // Cropping is mandatory for this layer. Convert the proto's unsigned value
  // once so the comparisons below do not mix signed and unsigned operands.
  const int crop_size = static_cast<int>(this->transform_param_.crop_size());
  CHECK_GT(crop_size, 0);
  CHECK_GE(this->datum_height_, crop_size);
  CHECK_GE(this->datum_width_, crop_size);

  // The mean blob always has the cropped shape 1 x C x crop x crop.
  this->data_mean_.Reshape(1, this->datum_channels_, crop_size, crop_size);

  // Optionally fill it with the centre crop of a binaryproto mean file.
  // `this->` is required: transform_param_ lives in a dependent base class.
  if (this->transform_param_.has_mean_file()) {
    const string& mean_file = this->transform_param_.mean_file();
    LOG(INFO) << "Loading mean file from " << mean_file;
    BlobProto blob_proto;
    ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto);
    // Load the full-size mean into a temporary only. Calling FromProto on
    // data_mean_ itself would reshape it to the mean file's full H x W and
    // undo the crop-size Reshape above.
    Blob<Dtype> tmp;
    tmp.FromProto(blob_proto);
    CHECK_EQ(tmp.num(), 1);
    CHECK_EQ(tmp.channels(), this->datum_channels_);
    CHECK_GE(tmp.height(), crop_size);
    CHECK_GE(tmp.width(), crop_size);
    const Dtype* src_data = tmp.cpu_data();
    Dtype* dst_data = this->data_mean_.mutable_cpu_data();
    // Offsets of the centred crop window inside the full-size mean.
    const int h_off = (tmp.height() - crop_size) / 2;
    const int w_off = (tmp.width() - crop_size) / 2;
    for (int c = 0; c < this->datum_channels_; ++c) {
      for (int h = 0; h < crop_size; ++h) {
        for (int w = 0; w < crop_size; ++w) {
          const int src_idx =
              (c * tmp.height() + h + h_off) * tmp.width() + w + w_off;
          const int dst_idx = (c * crop_size + h) * crop_size + w;
          dst_data[dst_idx] = src_data[src_idx];
        }
      }
    }
  }
  // Without a mean file, data_mean_ is just the crop-size blob reshaped above
  // (an "empty" mean; presumably zero-filled on first CPU access -- confirm
  // against Blob/SyncedMemory).

  this->mean_ = this->data_mean_.cpu_data();
  this->data_transformer_.InitRand();

  // Touch the prefetch buffers on the CPU before the prefetch thread starts.
  this->prefetch_data_.mutable_cpu_data();
  if (this->output_labels_) {
    this->prefetch_label_.mutable_cpu_data();
  }
  DLOG(INFO) << "Initializing prefetch";
  this->CreatePrefetchThread();
  DLOG(INFO) << "Prefetch initialized.";
}
コード例 #3
0
ファイル: caffe.cpp プロジェクト: aaalgo/xnn
 /// Loads a Caffe network from `dir` for inference with a fixed batch size.
 ///
 /// Files read from `dir`:
 ///   caffe.model  - network definition (instantiated in TEST phase)
 ///   caffe.params - trained weights
 ///   caffe.mean   - optional; either a binary BlobProto mean image (reduced to
 ///                  per-channel averages) or up to three whitespace-separated
 ///                  numbers (a single number is used for all channels)
 ///   blobs        - required; whitespace-separated names of the output blobs
 /// @param dir   model directory
 /// @param batch batch size the input blob is reshaped to; must be >= 1
 CaffeModel (fs::path const& dir, int batch)
     : CaffeSetMode(mode),
     net((dir/"caffe.model").native(), TEST)
 {
     // Use CHECK rather than BOOST_VERIFY: BOOST_VERIFY does not check
     // anything in NDEBUG builds, so an invalid batch would slip through.
     CHECK_GE(batch, 1) << "Batch size must be at least 1: " << batch;
     // Guard the input_blobs()[0] access below.
     CHECK_GE(net.num_inputs(), 1) << "Network has no input blob.";
     input_blob = net.input_blobs()[0];
     shape[0] = batch;
     shape[1] = input_blob->shape(1);
     CHECK(shape[1] == 3 || shape[1] == 1)
         << "Input layer should have 1 or 3 channels." << shape[1];
     net.CopyTrainedLayersFrom((dir/"caffe.params").native());
     // resize to required batch size
     shape[2] = input_blob->shape(2);
     shape[3] = input_blob->shape(3);
     input_blob->Reshape(shape[0], shape[1], shape[2], shape[3]);
     net.Reshape();
     // set mean file (defaults to zero when absent)
     means[0] = means[1] = means[2] = 0;
     fs::path mean_file = dir / "caffe.mean";
     fs::ifstream test(mean_file);
     if (test) {
         BlobProto blob_proto;
         // try the binary BlobProto format first
         if (ReadProtoFromBinaryFile(mean_file.native(), &blob_proto)) {
             /* Convert from BlobProto to Blob<float> */
             Blob<float> meanblob;
             meanblob.FromProto(blob_proto);
             CHECK_EQ(meanblob.channels(), channels())
                 << "Number of channels of mean file doesn't match input layer.";
             /* The format of the mean file is planar 32-bit float BGR or grayscale. */
             vector<cv::Mat> mats;
             float* data = meanblob.mutable_cpu_data();
             for (int i = 0; i < channels(); ++i) {
                 /* Extract an individual channel (aliases the blob memory). */
                 cv::Mat channel(meanblob.height(), meanblob.width(), CV_32FC1, data);
                 mats.push_back(channel);
                 data += meanblob.height() * meanblob.width();
             }
             /* Merge the separate channels and average each channel. */
             cv::Mat merged;
             cv::merge(mats, merged);
             cv::Scalar channel_mean = cv::mean(merged);
             means[0] = means[1] = means[2] = channel_mean[0];
             if (channels() > 1) {
                 means[1] = channel_mean[1];
                 means[2] = channel_mean[2];
             }
         }
         // if not proto format, then the mean file is just a bunch of textual numbers
         else {
             test >> means[0];
             // Default the other channels to the first value; a later
             // extraction that hits end-of-file leaves them unchanged.
             means[1] = means[2] = means[0];
             test >> means[1];
             test >> means[2];
         }
     }
     {
         fs::ifstream is(dir/"blobs");
         string blob;
         CHECK(is) << "cannot open blobs file.";
         while (is >> blob) {
             output_blobs.push_back(net.blob_by_name(blob));
             // blob_by_name yields a null pointer for unknown names; fail here
             // instead of crashing later when the output is read.
             CHECK(output_blobs.back()) << "unknown output blob: " << blob;
         }
     }
 }