/**
 * Computes the element-wise mean of the Datum records stored in a leveldb and
 * writes it to a binary BlobProto file.
 *
 * Usage: demo_compute_image_mean input_leveldb output_file
 *
 * The blob shape (channels/height/width) is taken from the first record; every
 * later record must carry exactly channels*height*width data bytes. Iteration
 * stops after 100000 records.
 *
 * Returns 0 in all cases that return; failures abort through CHECK.
 */
int main(int argc, char** argv) {
	::google::InitGoogleLogging(argv[0]);
	if (argc != 3) {
		LOG(ERROR)<< "Usage: demo_compute_image_mean input_leveldb output_file";
		return(0);
	}

	leveldb::DB* db;
	leveldb::Options options;
	options.create_if_missing = false;

	LOG(INFO) << "Opening leveldb " << argv[1];
	leveldb::Status status = leveldb::DB::Open(options, argv[1], &db);
	CHECK(status.ok()) << "Failed to open leveldb " << argv[1];

	leveldb::ReadOptions read_options;
	read_options.fill_cache = false;
	leveldb::Iterator* it = db->NewIterator(read_options);
	it->SeekToFirst();
	// An empty database leaves the iterator invalid; reading value() from it
	// is not allowed, so fail loudly instead.
	CHECK(it->Valid()) << "leveldb " << argv[1] << " contains no records";

	// The first record defines the shape of the accumulator blob.
	Datum datum;
	BlobProto sum_blob;
	int count = 0;
	datum.ParseFromString(it->value().ToString());
	sum_blob.set_num(1);
	sum_blob.set_channels(datum.channels());
	sum_blob.set_height(datum.height());
	sum_blob.set_width(datum.width());
	const int data_size = datum.channels() * datum.height() * datum.width();
	for (int i = 0; i < data_size; ++i) {
		sum_blob.add_data(0.);
	}

	LOG(INFO) << "Starting Iteration";
	for (it->SeekToFirst(); it->Valid(); it->Next()) {
		datum.ParseFromString(it->value().ToString());
		const string& data = datum.data();
		CHECK_EQ(data.size(), data_size)<< "Incorrect data field size " << data.size();
		// Bytes are accumulated as unsigned values (0..255).
		for (int i = 0; i < data_size; ++i) {
			sum_blob.set_data(i, sum_blob.data(i) + static_cast<uint8_t>(data[i]));
		}
		++count;
		if (count % 10000 == 0) {
			LOG(ERROR)<< "Processed " << count << " files.";
			if (count == 100000) break;
		}
	}

	// count >= 1 here because the iterator was valid at the first record.
	for (int i = 0; i < sum_blob.data_size(); ++i) {
		sum_blob.set_data(i, sum_blob.data(i) / count);
	}
	// Write to disk
	LOG(INFO) << "Write to " << argv[2];
	WriteProtoToBinaryFile(sum_blob, argv[2]);

	// leveldb requires every iterator to be released before its DB is deleted.
	delete it;
	delete db;
	return 0;
}
コード例 #2
0
ファイル: io.cpp プロジェクト: ravi-teja-mullapudi/Halide-NN
// Builds a cv::Mat from a Datum.
//
// Encoded datums are delegated to DecodeDatumToCVMatNative. Otherwise the
// datum must have exactly three channels; its data is read as channel planes
// (index = (c * height + h) * width + w) and written into an interleaved
// CV_8UC3 image of size height x width.
cv::Mat DatumToCVMat(const Datum& datum) {
    if (datum.encoded()) {
        return DecodeDatumToCVMatNative(datum);
    }

    const int datum_channels = datum.channels();
    const int rows = datum.height();
    const int cols = datum.width();

    CHECK(datum_channels==3);

    const string& planes = datum.data();
    cv::Mat image(rows, cols, CV_8UC3);

    for (int r = 0; r < rows; ++r) {
        for (int col = 0; col < cols; ++col) {
            cv::Vec3b& pixel = image.at<cv::Vec3b>(r, col);
            for (int ch = 0; ch < datum_channels; ++ch) {
                // Same element the planar layout stores for (ch, r, col).
                const int src = (ch * rows + r) * cols + col;
                pixel[ch] = static_cast<uchar>(planes[src]);
            }
        }
    }

    return image;
}
コード例 #3
0
ファイル: test_io.cpp プロジェクト: azrael417/caffe
// Reads the example JPEG through ReadFileToDatum (called without a label) and
// checks the resulting datum: it is flagged as encoded, its label is -1, and
// its data field holds 140391 bytes.
TEST_F(IOTest, TestReadFileToDatum) {
  string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg";
  Datum datum;
  EXPECT_TRUE(ReadFileToDatum(filename, &datum));
  EXPECT_TRUE(datum.encoded());
  EXPECT_EQ(datum.label(), -1);
  // presumably the on-disk size of cat.jpg — TODO confirm against the fixture
  EXPECT_EQ(datum.data().size(), 140391);
}
コード例 #4
0
ファイル: test_io.cpp プロジェクト: azrael417/caffe
// Converts the example image to a Datum via cv::Mat + CVMatToDatum and checks
// that shape and every data byte agree with ReadImageToDatumReference.
TEST_F(IOTest, TestCVMatToDatumReference) {
  const string filename(EXAMPLES_SOURCE_DIR "images/cat.jpg");

  // Datum under test: image file -> cv::Mat -> Datum.
  Datum datum;
  const cv::Mat cv_img = ReadImageToCVMat(filename);
  CVMatToDatum(cv_img, &datum);

  // Reference datum produced directly from the file (color = true).
  Datum datum_ref;
  ReadImageToDatumReference(filename, 0, 0, 0, true, &datum_ref);

  EXPECT_EQ(datum.channels(), datum_ref.channels());
  EXPECT_EQ(datum.height(), datum_ref.height());
  EXPECT_EQ(datum.width(), datum_ref.width());
  EXPECT_EQ(datum.data().size(), datum_ref.data().size());

  // Byte-by-byte comparison over the length of the datum under test.
  const string& data = datum.data();
  const string& data_ref = datum_ref.data();
  const int num_bytes = static_cast<int>(data.size());
  int i = 0;
  while (i < num_bytes) {
    EXPECT_TRUE(data[i] == data_ref[i]);
    ++i;
  }
}
コード例 #5
0
ファイル: test_io.cpp プロジェクト: azrael417/caffe
// Loads the grayscale example JPEG as an encoded Datum, decodes it in place,
// and checks the result against ReadImageToDatumReference (color = false).
// A second DecodeDatumNative call on the same datum is expected to return
// false.
TEST_F(IOTest, TestDecodeDatumNativeGray) {
  const string filename(EXAMPLES_SOURCE_DIR "images/cat_gray.jpg");

  Datum datum;
  EXPECT_TRUE(ReadFileToDatum(filename, &datum));
  EXPECT_TRUE(DecodeDatumNative(&datum));
  EXPECT_FALSE(DecodeDatumNative(&datum));

  Datum datum_ref;
  ReadImageToDatumReference(filename, 0, 0, 0, false, &datum_ref);

  EXPECT_EQ(datum.channels(), datum_ref.channels());
  EXPECT_EQ(datum.height(), datum_ref.height());
  EXPECT_EQ(datum.width(), datum_ref.width());
  EXPECT_EQ(datum.data().size(), datum_ref.data().size());

  // Byte-by-byte comparison over the length of the decoded datum.
  const string& data = datum.data();
  const string& data_ref = datum_ref.data();
  const int num_bytes = static_cast<int>(data.size());
  int i = 0;
  while (i < num_bytes) {
    EXPECT_TRUE(data[i] == data_ref[i]);
    ++i;
  }
}
コード例 #6
0
ファイル: io.cpp プロジェクト: xieguotian/caffe
// Decodes the encoded bytes held in `datum` with cv::imdecode using read flag
// -1. The datum must be flagged as encoded (CHECK). If decoding yields no
// pixel data an error is logged and the empty Mat is returned.
cv::Mat DecodeDatumToCVMatNative(const Datum& datum) {
  CHECK(datum.encoded()) << "Datum not encoded";

  // cv::imdecode takes its input as a buffer, so copy the bytes out.
  const string& encoded_bytes = datum.data();
  std::vector<char> buffer(encoded_bytes.begin(), encoded_bytes.end());

  cv::Mat decoded = cv::imdecode(buffer, -1);
  if (decoded.data) {
    return decoded;
  }
  LOG(ERROR) << "Could not decode datum ";
  return decoded;
}
コード例 #7
0
/**
 * Walks the cursor to its end and accumulates, for up to three channels, the
 * mean byte value of each record's channel planes.
 *
 * Each record is parsed into a Datum and passed through DecodeDatumNative.
 * Single-channel records contribute their one mean to all three slots.
 *
 * Aborts (LOG(FATAL)) on a record with more than three channels or with fewer
 * data bytes than channels*width*height, since either would index outside
 * the local buffers.
 *
 * @param cursor  database cursor; advanced until it is no longer valid.
 * @return three accumulated values.
 *         NOTE(review): the returned vector holds the SUM of per-record means;
 *         only the logged values are divided by the record count. Confirm
 *         whether callers expect the sum or the mean.
 */
vector<double> GetChannelMean(scoped_ptr<db::Cursor>& cursor)
{
	vector<double> meanv(3, 0);

	int count = 0;
	LOG(INFO) << "Starting Iteration";
	while (cursor->valid()) {
		Datum datum;
		datum.ParseFromString(cursor->value());
		DecodeDatumNative(&datum);

		const std::string& data = datum.data();
		const int w = datum.width(), h = datum.height();
		const int ch = datum.channels();
		const int dim = w*h;

		// chmean has three slots; more channels would write past its end.
		if (ch < 0 || ch > 3 || dim < 0) {
			LOG(FATAL) << "Unsupported datum shape: channels=" << ch
				<< " width=" << w << " height=" << h;
		}
		// Every channel plane must be fully present in the data string.
		if (data.size() < static_cast<std::string::size_type>(ch)
				* static_cast<std::string::size_type>(dim)) {
			LOG(FATAL) << "Datum holds " << data.size()
				<< " data bytes, expected at least channels*width*height";
		}

		double chmean[3] = { 0,0,0 };
		for (int i = 0; i < ch;i++)
		{
			const int chstart = i*dim;
			for (int j = 0; j < dim;j++)
				chmean[i] += static_cast<uint8_t>(data[chstart+j]);
			chmean[i] /= dim;
		}
		if (ch == 1)
		{
			// Grayscale: replicate the single mean across all three slots.
			meanv[0] += chmean[0];
			meanv[1] += chmean[0];
			meanv[2] += chmean[0];
		}
		else
		{
			meanv[0] += chmean[0];
			meanv[1] += chmean[1];
			meanv[2] += chmean[2];
		}

		++count;
		if (count % 10000 == 0) {
			LOG(INFO) << "Processed " << count << " files.";
		}
		cursor->Next();
	}

	if (count % 10000 != 0) {
		LOG(INFO) << "Processed " << count << " files.";
	}

	for (int c = 0; c < 3; ++c) {
		LOG(INFO) << "mean_value channel [" << c << "]:" << meanv[c] / count;
	}

	return meanv;
}
コード例 #8
0
ファイル: io.cpp プロジェクト: xieguotian/caffe
// Decodes the encoded bytes held in `datum` with cv::imdecode, as a color
// image when `is_color` is true and as grayscale otherwise. The datum must be
// flagged as encoded (CHECK). If decoding yields no pixel data an error is
// logged and the empty Mat is returned.
cv::Mat DecodeDatumToCVMat(const Datum& datum, bool is_color) {
  CHECK(datum.encoded()) << "Datum not encoded";

  int read_flag = CV_LOAD_IMAGE_GRAYSCALE;
  if (is_color) {
    read_flag = CV_LOAD_IMAGE_COLOR;
  }

  // cv::imdecode takes its input as a buffer, so copy the bytes out.
  const string& encoded_bytes = datum.data();
  std::vector<char> buffer(encoded_bytes.begin(), encoded_bytes.end());

  cv::Mat decoded = cv::imdecode(buffer, read_flag);
  if (decoded.data) {
    return decoded;
  }
  LOG(ERROR) << "Could not decode datum ";
  return decoded;
}
コード例 #9
0
// Reads the record under the database iterator into the output arguments and
// advances the iterator.
//
// key    - receives the record key.
// retVec - resized to channels*height*width and filled either from the byte
//          data (each byte as an unsigned value 0..255) or from float_data.
// label  - receives the datum label.
//
// Returns false when the iterator is exhausted, or when the stored payload
// size does not match channels*height*width (debug builds assert first). In
// the mismatch case key and label have already been written and the iterator
// is NOT advanced.
bool MostCV::LevelDBReader::GetNextEntry(string &key, vector<double> &retVec, int &label) {
  if (!database_iter_->Valid())
    return false;

  Datum datum;
  datum.ParseFromString(database_iter_->value().ToString());

  key = database_iter_->key().ToString();
  label = datum.label();

  const int expected_data_size = std::max<int>(datum.data().size(), datum.float_data_size());
  const int datum_volume_size = datum.channels() * datum.height() * datum.width();
  if (expected_data_size != datum_volume_size) {
    cout << "Something wrong in saved data.";
    assert(false);
    // With NDEBUG the assert disappears; bail out rather than index past the
    // end of the stored payload below.
    return false;
  }

  retVec.resize(datum_volume_size);

  const string& data = datum.data();
  if (data.size() != 0 && static_cast<int>(data.size()) == datum_volume_size) {
    // Data stored in string, e.g. just pixel values of 196608 = 256 * 256 * 3.
    // Go through unsigned char so bytes above 127 do not turn negative.
    for (int i = 0; i < datum_volume_size; ++i)
      retVec[i] = static_cast<unsigned char>(data[i]);
  } else {
    // Data stored in real feature vector such as 4096 from feature extraction.
    // Reaching here means float_data_size() == datum_volume_size.
    for (int i = 0; i < datum_volume_size; ++i)
      retVec[i] = datum.float_data(i);
  }

  database_iter_->Next();
  ++record_idx_;

  return true;
}
コード例 #10
0
ファイル: test_io.cpp プロジェクト: azrael417/caffe
// Reads the example image as grayscale both into a Datum and into a cv::Mat,
// then checks that the shapes agree and that the datum bytes, taken in
// row-major order, equal the Mat pixels.
TEST_F(IOTest, TestReadImageToDatumContentGray) {
  const string filename(EXAMPLES_SOURCE_DIR "images/cat.jpg");
  const bool is_color = false;

  Datum datum;
  ReadImageToDatum(filename, 0, is_color, &datum);
  cv::Mat cv_img = ReadImageToCVMat(filename, is_color);

  EXPECT_EQ(datum.channels(), cv_img.channels());
  EXPECT_EQ(datum.height(), cv_img.rows);
  EXPECT_EQ(datum.width(), cv_img.cols);

  const string& data = datum.data();
  const int rows = datum.height();
  const int cols = datum.width();
  // `index` walks the datum bytes sequentially while (h, w) walks the image.
  int index = 0;
  for (int h = 0; h < rows; ++h) {
    for (int w = 0; w < cols; ++w) {
      EXPECT_TRUE(data[index++] == static_cast<char>(cv_img.at<uchar>(h, w)));
    }
  }
}
コード例 #11
0
ファイル: io.cpp プロジェクト: williford/seg-caffe
// Decodes the encoded bytes held in `datum` and optionally resizes the result.
//
// datum    - must be flagged as encoded (CHECK).
// height,
// width    - when both are > 0 the decoded image is resized to
//            width x height; otherwise it is returned at its decoded size.
// is_color - decode as color when true, as grayscale otherwise.
//
// Returns the image, or an empty Mat (after logging an error) when the bytes
// cannot be decoded. The decode result is checked before resizing so that a
// failed decode never reaches cv::resize with an empty source.
cv::Mat DecodeDatumToCVMat(const Datum& datum,
    const int height, const int width, const bool is_color) {
  CHECK(datum.encoded()) << "Datum not encoded";
  const int cv_read_flag = (is_color ? CV_LOAD_IMAGE_COLOR :
    CV_LOAD_IMAGE_GRAYSCALE);
  const string& data = datum.data();
  std::vector<char> vec_data(data.c_str(), data.c_str() + data.size());

  cv::Mat cv_img = cv::imdecode(vec_data, cv_read_flag);
  if (!cv_img.data) {
    LOG(ERROR) << "Could not decode datum ";
    return cv_img;
  }

  if (height > 0 && width > 0) {
    cv::Mat cv_img_resized;
    cv::resize(cv_img, cv_img_resized, cv::Size(width, height));
    return cv_img_resized;
  }
  return cv_img;
}
コード例 #12
0
ファイル: test_io.cpp プロジェクト: naibaf7/caffe
// Reads the example image both into a Datum and into a cv::Mat, then checks
// that the shapes agree and that the datum bytes, taken channel by channel
// and row-major within a channel, equal the corresponding Mat components.
TEST_F(IOTest, TestReadImageToDatumContent) {
  const string filename(EXAMPLES_SOURCE_DIR "images/cat.jpg");

  Datum datum;
  ReadImageToDatum(filename, 0, &datum);
  cv::Mat cv_img = ReadImageToCVMat(filename);

  EXPECT_EQ(datum.channels(), cv_img.channels());
  EXPECT_EQ(datum.height(), cv_img.rows);
  EXPECT_EQ(datum.width(), cv_img.cols);

  const string& data = datum.data();
  const int_tp num_channels = datum.channels();
  const int_tp rows = datum.height();
  const int_tp cols = datum.width();
  // `index` walks the datum bytes sequentially while (c, h, w) walks the image.
  int_tp index = 0;
  for (int_tp c = 0; c < num_channels; ++c) {
    for (int_tp h = 0; h < rows; ++h) {
      for (int_tp w = 0; w < cols; ++w) {
        EXPECT_TRUE(data[index++] ==
          static_cast<char>(cv_img.at<cv::Vec3b>(h, w)[c]));
      }
    }
  }
}
コード例 #13
0
ファイル: data_reader.cpp プロジェクト: XinLiuNvidia/caffe
// Moves one record from the database into the queue pair.
//
// A Datum is popped from qp->free_, filled from the cursor's current value,
// and pushed onto qp->full_. When `dblt` is non-null, the entry stored there
// under the same key is parsed as a second Datum and folded in: its channel
// count is added, and its float_data and data are appended. Afterwards the
// cursor advances, wrapping to the first record once it runs off the end.
void DataReader::Body::read_one(db::Cursor* cursor, db::Transaction* dblt, QueuePair* qp) {
  Datum* const record = qp->free_.pop();
  // TODO deserialize in-place instead of copy?
  record->ParseFromString(cursor->value());

  if (dblt) {
    string labels;
    CHECK_EQ(dblt->Get(cursor->key(), labels), 0);

    Datum label_record;
    label_record.ParseFromString(labels);

    // Field-by-field merge (a plain MergeFrom is not used here).
    const int merged_channels = record->channels() + label_record.channels();
    record->set_channels(merged_channels);
    record->mutable_float_data()->MergeFrom(label_record.float_data());
    record->mutable_data()->append(label_record.data());
  }

  qp->full_.push(record);

  // Step forward; restart from the beginning when the cursor is exhausted.
  cursor->Next();
  if (cursor->valid()) {
    return;
  }
  DLOG(INFO) << "Restarting data prefetching from start.";
  cursor->SeekToFirst();
}
コード例 #14
0
/**
 * Post-transform hook. As written this is a stub: it reads the datum's
 * geometry and payload and the crop/mirror/scale settings from param_, but
 * performs no work on them and never writes to transformed_data.
 *
 * Notes carried over from the original author:
 *   - only works for uint8 data.
 *   - planned post-transform parameters:
 *       int    post_random_translation_size
 *       string post_ground_truth_pooling_param :
 *              [num_of_pooling] [pooling_h_1] [pooling_w_1] [pooling_h_2], ...
 *       int    post_channel_for_additional_translation
 */
void DataTransformer<Dtype>::PostTransform(const int batch_item_id,
                                       const Datum& datum,
                                       const Dtype* mean,
                                       Dtype* transformed_data)
{
	  // Datum geometry and raw payload.
	  const int num_channels = datum.channels();
	  const int rows = datum.height();
	  const int cols = datum.width();
	  const int volume = datum.channels() * datum.height() * datum.width();
	  const string& payload = datum.data();

	  // Transformer settings.
	  const Dtype scale_factor = param_.scale();
	  const bool do_mirror = param_.mirror();
	  const int crop = param_.crop_size();

	  // Placeholders for the three planned parameters; none is handled yet:
	  //   param_.has_post_random_translation_size()
	  //   param_.has_post_ground_truth_pooling_param()
	  //   param_.has_post_channel_for_additional_translation()
}
コード例 #15
0
// Prefetch thread body: fills the prefetch data, label, hist and marker blobs
// for one batch of sequence_size * sequence_num items.
//
// Items are laid out time-major (item_id = time_id * sequence_num + seq_id).
// For each item the record whose key is the zero-padded 8-digit value of
// buffer_key[seq_id] is fetched from db_ and parsed as a Datum, then:
//   - top_data   receives the datum bytes (as unsigned values) minus the mean;
//   - top_label  receives float_data(0 .. para_dim-1);
//   - top_marker receives float_data(para_dim), forced to 0 when
//     buffer_marker[seq_id] is 0;
//   - top_hist   receives zeros when the marker is below 0.5, otherwise a
//     per-index affine rescaling of the previous step's labels (taken from
//     hist_blob at time step 0, from top_label of the previous time step
//     otherwise).
//
// NOTE(review): the status returned by db_->Get is not checked, and the copy
// loop reads `size` bytes from the datum without checking data.size().
// NOTE(review): the hist rescaling hard-codes indices 0..13, so it presumably
// requires para_dim >= 14 — confirm against the layer setup.
void DataLstmTrainHistLayer<Dtype>::InternalThreadEntry() {
  CPUTimer batch_timer;
  batch_timer.Start();
  double read_time = 0;
  double trans_time = 0;
  CPUTimer timer;
  CHECK(this->prefetch_data_.count());

  Datum datum;
  Dtype* top_data = this->prefetch_data_.mutable_cpu_data();
  Dtype* top_label = this->prefetch_label_.mutable_cpu_data();
  Dtype* top_hist = this->prefetch_hist_.mutable_cpu_data();
  Dtype* top_marker = this->prefetch_marker_.mutable_cpu_data();

  // datum scales
  // Elements per item; the literal 3 assumes three-channel images.
  const int size = resize_height*resize_width*3;
  const Dtype* mean = this->data_mean_.mutable_cpu_data();

  string value;
  const int kMaxKeyLength = 256;
  char key_cstr[kMaxKeyLength];
  int key;

  const int sequence_size = this->layer_param_.data_lstm_train_hist_param().sequence_size();
  const int ind_seq_num=this->layer_param_.data_lstm_train_hist_param().sequence_num();
  const int interval=this->layer_param_.data_lstm_train_hist_param().interval();
  int item_id;

  for (int time_id = 0; time_id < sequence_size; ++time_id) {
     for (int seq_id = 0; seq_id < ind_seq_num; ++seq_id) {
        // Time-major position of this item within the batch.
        item_id=time_id*ind_seq_num+seq_id;
        timer.Start();
        // get a blob

        key=buffer_key[seq_id];  // MUST be changed according to the size of the training set

        // Database keys are the frame number as 8 zero-padded decimal digits.
        snprintf(key_cstr, kMaxKeyLength, "%08d", key);
        db_->Get(leveldb::ReadOptions(), string(key_cstr), &value);
        datum.ParseFromString(value);
        const string& data = datum.data();

        read_time += timer.MicroSeconds();
        timer.Start();

        // Copy the bytes as unsigned values and subtract the mean element-wise.
        for (int j = 0; j < size; ++j) {
           Dtype datum_element = static_cast<Dtype>(static_cast<uint8_t>(data[j]));
           top_data[item_id * size + j] = (datum_element - mean[j]);
        }

        // The first para_dim floats of the datum are the labels.
        for (int j = 0; j < para_dim; ++j) { 
           top_label[item_id * para_dim + j] = datum.float_data(j); 
        }

        // The float stored right after the labels is the marker.
        top_marker[item_id] = datum.float_data(para_dim);

        // A zero buffer_marker overrides the stored marker with 0 for this
        // item only; the flag is then set back to 1.
        if (buffer_marker[seq_id] == 0) {
            top_marker[item_id] = 0;   
            buffer_marker[seq_id] = 1;
        }

        //////////////////////////////////// for hist
        if (top_marker[item_id] < 0.5) {
           // Marker below 0.5: the history for this item is all zeros.
           for (int j = 0; j < para_dim; ++j)
               top_hist[item_id * para_dim + j] = 0; 
        } else {
           if (time_id == 0) {
              // First time step: history comes from hist_blob, which is
              // written at the last time step further below.
              top_hist[item_id * para_dim + 0] = hist_blob[seq_id * para_dim + 0]/1.1+0.5;
              top_hist[item_id * para_dim + 1] = hist_blob[seq_id * para_dim + 1]*0.17778+1.34445;
              top_hist[item_id * para_dim + 2] = hist_blob[seq_id * para_dim + 2]*0.14545+0.39091;
              top_hist[item_id * para_dim + 3] = hist_blob[seq_id * para_dim + 3]*0.17778-0.34445;
              top_hist[item_id * para_dim + 4] = hist_blob[seq_id * para_dim + 4]/95.0+0.12;
              top_hist[item_id * para_dim + 5] = hist_blob[seq_id * para_dim + 5]/95.0+0.12;
              top_hist[item_id * para_dim + 6] = hist_blob[seq_id * para_dim + 6]*0.14545+1.48181;
              top_hist[item_id * para_dim + 7] = hist_blob[seq_id * para_dim + 7]*0.16+0.98;
              top_hist[item_id * para_dim + 8] = hist_blob[seq_id * para_dim + 8]*0.16+0.02;
              top_hist[item_id * para_dim + 9] = hist_blob[seq_id * para_dim + 9]*0.14545-0.48181;
              top_hist[item_id * para_dim + 10] = hist_blob[seq_id * para_dim + 10]/95.0+0.12;
              top_hist[item_id * para_dim + 11] = hist_blob[seq_id * para_dim + 11]/95.0+0.12;
              top_hist[item_id * para_dim + 12] = hist_blob[seq_id * para_dim + 12]/95.0+0.12;
              top_hist[item_id * para_dim + 13] = hist_blob[seq_id * para_dim + 13]*0.6+0.2;
           } else {
              // Later time steps: history comes from the labels of the same
              // sequence at the previous time step, with the same constants.
              int pre_id=(time_id-1)*ind_seq_num+seq_id;
              top_hist[item_id * para_dim + 0] = top_label[pre_id * para_dim + 0]/1.1+0.5;
              top_hist[item_id * para_dim + 1] = top_label[pre_id * para_dim + 1]*0.17778+1.34445;
              top_hist[item_id * para_dim + 2] = top_label[pre_id * para_dim + 2]*0.14545+0.39091;
              top_hist[item_id * para_dim + 3] = top_label[pre_id * para_dim + 3]*0.17778-0.34445;
              top_hist[item_id * para_dim + 4] = top_label[pre_id * para_dim + 4]/95.0+0.12;
              top_hist[item_id * para_dim + 5] = top_label[pre_id * para_dim + 5]/95.0+0.12;
              top_hist[item_id * para_dim + 6] = top_label[pre_id * para_dim + 6]*0.14545+1.48181;
              top_hist[item_id * para_dim + 7] = top_label[pre_id * para_dim + 7]*0.16+0.98;
              top_hist[item_id * para_dim + 8] = top_label[pre_id * para_dim + 8]*0.16+0.02;
              top_hist[item_id * para_dim + 9] = top_label[pre_id * para_dim + 9]*0.14545-0.48181;
              top_hist[item_id * para_dim + 10] = top_label[pre_id * para_dim + 10]/95.0+0.12;
              top_hist[item_id * para_dim + 11] = top_label[pre_id * para_dim + 11]/95.0+0.12;
              top_hist[item_id * para_dim + 12] = top_label[pre_id * para_dim + 12]/95.0+0.12;
              top_hist[item_id * para_dim + 13] = top_label[pre_id * para_dim + 13]*0.6+0.2;
           }
        }
        //////////////////////////////////// for hist

        trans_time += timer.MicroSeconds();

        // Advance this sequence to its next frame. Once the key passes
        // total_frames or more than `interval` frames have been consumed,
        // restart from a key drawn with random(total_frames)+1 and clear the
        // marker flag so the next item of this sequence gets marker 0.
        buffer_key[seq_id]++;
        buffer_total[seq_id]++;
        if (buffer_key[seq_id]>total_frames || buffer_total[seq_id]>interval) {
           buffer_key[seq_id]=random(total_frames)+1;
           buffer_marker[seq_id]=0;
           buffer_total[seq_id]=0;
        }

        //////////////////////////////////// for hist
        // At the last time step, save this sequence's labels into hist_blob;
        // they are read back at time step 0 above.
        if (time_id==sequence_size-1) {
           for (int j = 0; j < para_dim; ++j) 
               hist_blob[seq_id * para_dim + j] = datum.float_data(j); 
        }
        //////////////////////////////////// for hist

/*
        if (seq_id == 0) {
           for (int h = 0; h < resize_height; ++h) {
              for (int w = 0; w < resize_width; ++w) {
                 leveldbTrain->imageData[(h*resize_width+w)*3+0]=(uint8_t)data[h*resize_width+w];
                 leveldbTrain->imageData[(h*resize_width+w)*3+1]=(uint8_t)data[resize_height*resize_width+h*resize_width+w];
                 leveldbTrain->imageData[(h*resize_width+w)*3+2]=(uint8_t)data[resize_height*resize_width*2+h*resize_width+w];

                 //leveldbTrain->imageData[(h*resize_width+w)*3+0]=(uint8_t)top_data[item_id * size+h*resize_width+w];
                 //leveldbTrain->imageData[(h*resize_width+w)*3+1]=(uint8_t)top_data[item_id * size+resize_height*resize_width+h*resize_width+w];
                 //leveldbTrain->imageData[(h*resize_width+w)*3+2]=(uint8_t)top_data[item_id * size+resize_height*resize_width*2+h*resize_width+w];
               }
           }
           cvShowImage("Image from leveldb", leveldbTrain);
           cvWaitKey( 1 );
        }
*/
     }
  }

  batch_timer.Stop();
  DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";
  // read_time and trans_time accumulate microseconds; report milliseconds.
  DLOG(INFO) << "     Read time: " << read_time / 1000 << " ms.";
  DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";
}
コード例 #16
0
ファイル: main.cpp プロジェクト: cciliber/objrecpipe_mat
/**
 * convert_imageset: converts a list of image files into a Caffe database.
 *
 * Usage: convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME
 *
 * LISTFILE holds "<relative path> <integer label>" pairs. Each image is read
 * through ReadImageToDatum (optionally resized and/or encoded) and stored
 * under the key "<8-digit line index>_<relative path>". Writes are committed
 * in batches of 1000 records. Images that fail to load are skipped.
 *
 * Returns 1 when too few arguments are given, 0 otherwise. Without
 * USE_OPENCV the tool only logs a fatal error.
 */
int main(int argc, char** argv) {
#ifdef USE_OPENCV
  ::google::InitGoogleLogging(argv[0]);
  // Print output to stderr (while still logging)
  FLAGS_alsologtostderr = 1;

#ifndef GFLAGS_GFLAGS_H_
  namespace gflags = google;
#endif

  gflags::SetUsageMessage("Convert a set of images to the leveldb/lmdb\n"
        "format used as input for Caffe.\n"
        "Usage:\n"
        "    convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME");
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  if (argc < 4) {
    gflags::ShowUsageWithFlagsRestrict(argv[0], "convert_imageset");
    return 1;
  }

  const bool is_color = !FLAGS_gray;
  const bool check_size = FLAGS_check_size;
  const bool encoded = FLAGS_encoded;
  const string encode_type = FLAGS_encode_type;

  // Read the whole list up front so it can be shuffled.
  std::ifstream infile(argv[2]);
  std::vector<std::pair<std::string, int> > lines;
  std::string filename;
  int label;
  while (infile >> filename >> label) {
    lines.push_back(std::make_pair(filename, label));
  }
  if (FLAGS_shuffle) {
    // randomly shuffle data
    LOG(INFO) << "Shuffling data";
    shuffle(lines.begin(), lines.end());
  }
  LOG(INFO) << "A total of " << lines.size() << " images.";

  if (encode_type.size() && !encoded)
    LOG(INFO) << "encode_type specified, assuming encoded=true.";

  // Negative resize flags are treated as 0.
  int resize_height = std::max<int>(0, FLAGS_resize_height);
  int resize_width = std::max<int>(0, FLAGS_resize_width);

  // Create new DB
  scoped_ptr<db::DB> db(db::GetDB(FLAGS_backend));
  db->Open(argv[3], db::NEW);
  scoped_ptr<db::Transaction> txn(db->NewTransaction());

  // Storing to db
  std::string root_folder(argv[1]);
  Datum datum;
  int count = 0;
  int data_size = 0;
  bool data_size_initialized = false;

  for (int line_id = 0; line_id < lines.size(); ++line_id) {
    bool status;
    std::string enc = encode_type;
    if (encoded && !enc.size()) {
      // Guess the encoding type from the file name
      const string& fn = lines[line_id].first;
      const size_t p = fn.rfind('.');
      if (p == fn.npos) {
        // No extension: substr(npos) would throw std::out_of_range, so leave
        // `enc` empty and pass the empty encoding on to ReadImageToDatum.
        LOG(WARNING) << "Failed to guess the encoding of '" << fn << "'";
      } else {
        // Lower-cased extension including the dot.
        enc = fn.substr(p);
        std::transform(enc.begin(), enc.end(), enc.begin(), ::tolower);
      }
    }
    status = ReadImageToDatum(root_folder + lines[line_id].first,
        lines[line_id].second, resize_height, resize_width, is_color,
        enc, &datum);
    if (status == false) continue;
    if (check_size) {
      // The first stored datum fixes the expected size; every later datum
      // must carry exactly that many data bytes.
      if (!data_size_initialized) {
        data_size = datum.channels() * datum.height() * datum.width();
        data_size_initialized = true;
      } else {
        const std::string& data = datum.data();
        CHECK_EQ(data.size(), data_size) << "Incorrect data field size "
            << data.size();
      }
    }
    // sequential
    string key_str = caffe::format_int(line_id, 8) + "_" + lines[line_id].first;

    // Put in db
    string out;
    CHECK(datum.SerializeToString(&out));
    txn->Put(key_str, out);

    if (++count % 1000 == 0) {
      // Commit db
      txn->Commit();
      txn.reset(db->NewTransaction());
      LOG(INFO) << "Processed " << count << " files.";
    }
  }
  // write the last batch
  if (count % 1000 != 0) {
    txn->Commit();
    LOG(INFO) << "Processed " << count << " files.";
  }
#else
  LOG(FATAL) << "This tool requires OpenCV; compile with USE_OPENCV.";
#endif  // USE_OPENCV
  return 0;
}
コード例 #17
0
ファイル: convert_imageset.cpp プロジェクト: MaoXu/jade
int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);

#ifndef GFLAGS_GFLAGS_H_
  namespace gflags = google;
#endif

  gflags::SetUsageMessage("Convert a set of images to the leveldb/lmdb\n"
        "format used as input for Caffe.\n"
        "Usage:\n"
        "    convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME\n"
        "The ImageNet dataset for the training demo is at\n"
        "    http://www.image-net.org/download-images\n");
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  if (argc < 4) {
    gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/convert_imageset");
    return 1;
  }

  const bool is_color = !FLAGS_gray;
  const bool check_size = FLAGS_check_size;
  const bool encoded = FLAGS_encoded;

  std::ifstream infile(argv[2]);
  std::vector<std::pair<std::string, int> > lines;
  std::string filename;
  int label;
  while (infile >> filename >> label) {
    lines.push_back(std::make_pair(filename, label));
  }
  if (FLAGS_shuffle) {
    // randomly shuffle data
    LOG(INFO) << "Shuffling data";
    shuffle(lines.begin(), lines.end());
  }
  LOG(INFO) << "A total of " << lines.size() << " images.";

  if (encoded) {
    CHECK_EQ(FLAGS_resize_height, 0) << "With encoded don't resize images";
    CHECK_EQ(FLAGS_resize_width, 0) << "With encoded don't resize images";
    CHECK(!check_size) << "With encoded cannot check_size";
  }

  int resize_height = std::max<int>(0, FLAGS_resize_height);
  int resize_width = std::max<int>(0, FLAGS_resize_width);

  // Create new DB
  scoped_ptr<db::DB> db(db::GetDB(FLAGS_backend));
  db->Open(argv[3], db::NEW);
  scoped_ptr<db::Transaction> txn(db->NewTransaction());

  // Storing to db
  std::string root_folder(argv[1]);
  Datum datum;
  int count = 0;
  const int kMaxKeyLength = 256;
  char key_cstr[kMaxKeyLength];
  int data_size;
  bool data_size_initialized = false;

  for (int line_id = 0; line_id < lines.size(); ++line_id) {
    bool status;
    if (encoded) {
      status = ReadFileToDatum(root_folder + lines[line_id].first,
        lines[line_id].second, &datum);
    } else {
      status = ReadImageToDatum(root_folder + lines[line_id].first,
          lines[line_id].second, resize_height, resize_width, is_color, &datum);
    }
    if (status == false) continue;
    if (check_size) {
      if (!data_size_initialized) {
        data_size = datum.channels() * datum.height() * datum.width();
        data_size_initialized = true;
      } else {
        const std::string& data = datum.data();
        CHECK_EQ(data.size(), data_size) << "Incorrect data field size "
            << data.size();
      }
    }
    // sequential
    int length = snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id,
        lines[line_id].first.c_str());

    // Put in db
    string out;
    CHECK(datum.SerializeToString(&out));
    txn->Put(string(key_cstr, length), out);

    if (++count % 1000 == 0) {
      // Commit db
      txn->Commit();
      txn.reset(db->NewTransaction());
      LOG(ERROR) << "Processed " << count << " files.";
    }
  }
  // write the last batch
  if (count % 1000 != 0) {
    txn->Commit();
    LOG(ERROR) << "Processed " << count << " files.";
  }
  return 0;
}
コード例 #18
0
ファイル: calc_mean_stddev.cpp プロジェクト: caomw/ssai
std::vector<float> calc_mean(const std::string &db_fname) {
  scoped_ptr<db::DB> db(db::GetDB(FLAGS_backend));
  db->Open(db_fname, db::READ);
  scoped_ptr<db::Cursor> cursor(db->NewCursor());

  BlobProto sum_blob;
  int count = 0;
  // load first datum
  Datum datum;
  datum.ParseFromString(cursor->value());

  if (DecodeDatumNative(&datum)) {
    LOG(INFO) << "Decoding Datum";
  }

  sum_blob.set_num(1);
  sum_blob.set_channels(datum.channels());
  sum_blob.set_height(datum.height());
  sum_blob.set_width(datum.width());
  const int data_size = datum.channels() * datum.height() * datum.width();
  int size_in_datum = std::max<int>(datum.data().size(),
                                    datum.float_data_size());
  for (int i = 0; i < size_in_datum; ++i) {
    sum_blob.add_data(0.);
  }
  LOG(INFO) << "Starting Iteration";
  while (cursor->valid()) {
    Datum datum;
    datum.ParseFromString(cursor->value());
    DecodeDatumNative(&datum);

    const std::string& data = datum.data();
    size_in_datum = std::max<int>(datum.data().size(),
                                  datum.float_data_size());
    CHECK_EQ(size_in_datum, data_size)
      << "Incorrect data field size " << size_in_datum;

    if (data.size() != 0) {
      CHECK_EQ(data.size(), size_in_datum);
      for (int i = 0; i < size_in_datum; ++i) {
        sum_blob.set_data(i, sum_blob.data(i) + (uint8_t)data[i]);
      }
    } else {
      CHECK_EQ(datum.float_data_size(), size_in_datum);
      for (int i = 0; i < size_in_datum; ++i) {
        sum_blob.set_data(i, sum_blob.data(i) +
            static_cast<float>(datum.float_data(i)));
      }
    }
    ++count;
    if (count % 10000 == 0) {
      LOG(INFO) << "Processed " << count << " files.";
    }
    cursor->Next();
  }

  if (count % 10000 != 0) {
    LOG(INFO) << "Processed " << count << " files.";
  }
  for (int i = 0; i < sum_blob.data_size(); ++i) {
    sum_blob.set_data(i, sum_blob.data(i) / count);
  }

  const int channels = sum_blob.channels();
  const int dim = sum_blob.height() * sum_blob.width();
  std::vector<float> mean_values(channels, 0.0);
  LOG(INFO) << "Number of channels: " << channels;
  for (int c = 0; c < channels; ++c) {
    for (int i = 0; i < dim; ++i) {
      mean_values[c] += sum_blob.data(dim * c + i);
    }
    mean_values[c] /= dim;
    LOG(INFO) << "mean_value channel [" << c << "]:" << mean_values[c];
  }

  return mean_values;
}
コード例 #19
0
int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);
  if (argc < 4) {
    printf("Convert a set of images to the leveldb format used\n"
        "as input for Caffe.\n"
        "Usage:\n"
        "    convert_imageset ROOTFOLDER/ LISTFILE NUM_LABELS DB_NAME"
        " RANDOM_SHUFFLE_DATA[0 or 1]\n");
    return 0;
  }
  //Arguments to our program
  std::ifstream infile(argv[2]);
  int numLabels = atoi(argv[3]);
  
  // Each line is constituted of the path to the file and the vector of 
  // labels
  std::vector<std::pair<string, std::vector<float> > > lines;
  // --------
  string filename;
  std::vector<float> labels(numLabels);
  
  while (infile >> filename) {
	  for(int l=0; l<numLabels; l++)
		  infile >> (labels[l]);
    lines.push_back(std::make_pair(filename, labels));
    /*
    LOG(ERROR) <<  "filepath: " << lines[lines.size()-1].first;
    LOG(ERROR) << "values: " << lines[lines.size()-1].second[0] 
			   << ","		 << lines[lines.size()-1].second[5] 
			   << ","		 << lines[lines.size()-1].second[8];
			   * */
  }
  if (argc == 5 && argv[5][0] == '1') {
    // randomly shuffle data
    LOG(ERROR) << "Shuffling data";
    std::random_shuffle(lines.begin(), lines.end());
  }
  LOG(ERROR) << "A total of " << lines.size() << " images.";

  leveldb::DB* db;
  leveldb::Options options;
  options.error_if_exists = true;
  options.create_if_missing = true;
  options.write_buffer_size = 268435456;
  LOG(ERROR) << "Opening leveldb " << argv[4];
  leveldb::Status status = leveldb::DB::Open(
      options, argv[4], &db);
  CHECK(status.ok()) << "Failed to open leveldb " << argv[4];
 
  string root_folder(argv[1]);
  Datum datum;
  int count = 0;
  const int maxKeyLength = 256;
  char key_cstr[maxKeyLength];
  leveldb::WriteBatch* batch = new leveldb::WriteBatch();
  int data_size;
  bool data_size_initialized = false;
  for (int line_id = 0; line_id < lines.size(); ++line_id) {
    if (!ReadImageWithLabelVectorToDatum(root_folder + lines[line_id].first, lines[line_id].second,
										 &datum)) {
      continue;
    };
  
    if (!data_size_initialized) {
      data_size = datum.channels() * datum.height() * datum.width();
    } else {
      const string& data = datum.data();
      CHECK_EQ(data.size(), data_size) << "Incorrect data field size " << data.size();
    }
    
    // sequential
    snprintf(key_cstr, maxKeyLength, "%08d_%s", line_id, lines[line_id].first.c_str());
    string value;
    
    // get the value
    datum.SerializeToString(&value);
    batch->Put(string(key_cstr), value);
    if (++count % 1000 == 0) {
      db->Write(leveldb::WriteOptions(), batch);
      LOG(ERROR) << "Processed " << count << " files.";
      delete batch;
      batch = new leveldb::WriteBatch();
    }
  }
  // write the last batch
  if (count % 1000 != 0) {
    db->Write(leveldb::WriteOptions(), batch);
    LOG(ERROR) << "Processed " << count << " files.";
  }

  delete batch;
  delete db;
  return 0;
}
コード例 #20
0
int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);
  if (argc < 4 || argc > 9) {
    printf("Convert a set of images to the leveldb format used\n"
        "as input for Caffe.\n"
        "Usage:\n"
        "    convert_imageset [-g] ROOTFOLDER/ LISTFILE DB_NAME"
        " RANDOM_SHUFFLE_DATA[0 or 1] DB_BACKEND[leveldb or lmdb]"
        " [resize_height] [resize_width]\n"
        "The ImageNet dataset for the training demo is at\n"
        "    http://www.image-net.org/download-images\n");
    return 1;
  }

  // Test whether argv[1] == "-g"
  bool is_color= !(string("-g") == string(argv[1]));
  int  arg_offset = (is_color ? 0 : 1);
  std::ifstream infile(argv[arg_offset+2]);
  std::vector<std::pair<string, int> > lines;
  string filename;
  int label;
  while (infile >> filename >> label) {
    lines.push_back(std::make_pair(filename, label));
  }
  if (argc >= (arg_offset+5) && argv[arg_offset+4][0] == '1') {
    // randomly shuffle data
    LOG(INFO) << "Shuffling data";
    shuffle(lines.begin(), lines.end());
  }
  LOG(INFO) << "A total of " << lines.size() << " images.";

  string db_backend = "leveldb";
  if (argc >= (arg_offset+6)) {
    db_backend = string(argv[arg_offset+5]);
    if (!(db_backend == "leveldb") && !(db_backend == "lmdb")) {
      LOG(FATAL) << "Unknown db backend " << db_backend;
    }
  }

  int resize_height = 0;
  int resize_width = 0;
  if (argc >= (arg_offset+7)) {
    resize_height = atoi(argv[arg_offset+6]);
  }
  if (argc >= (arg_offset+8)) {
    resize_width = atoi(argv[arg_offset+7]);
  }

  // Open new db
  // lmdb
  MDB_env *mdb_env;
  MDB_dbi mdb_dbi;
  MDB_val mdb_key, mdb_data;
  MDB_txn *mdb_txn;
  // leveldb
  leveldb::DB* db;
  leveldb::Options options;
  options.error_if_exists = true;
  options.create_if_missing = true;
  options.write_buffer_size = 268435456;
  leveldb::WriteBatch* batch = NULL;

  // Open db
  if (db_backend == "leveldb") {  // leveldb
    LOG(INFO) << "Opening leveldb " << argv[arg_offset+3];
    leveldb::Status status = leveldb::DB::Open(
        options, argv[arg_offset+3], &db);
    CHECK(status.ok()) << "Failed to open leveldb " << argv[arg_offset+3];
    batch = new leveldb::WriteBatch();
  } else if (db_backend == "lmdb") {  // lmdb
    LOG(INFO) << "Opening lmdb " << argv[arg_offset+3];
    CHECK_EQ(mkdir(argv[arg_offset+3], 0744), 0)
        << "mkdir " << argv[arg_offset+3] << "failed";
    CHECK_EQ(mdb_env_create(&mdb_env), MDB_SUCCESS) << "mdb_env_create failed";
    CHECK_EQ(mdb_env_set_mapsize(mdb_env, 1099511627776), MDB_SUCCESS)  // 1TB
        << "mdb_env_set_mapsize failed";
    CHECK_EQ(mdb_env_open(mdb_env, argv[3], 0, 0664), MDB_SUCCESS)
        << "mdb_env_open failed";
    CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS)
        << "mdb_txn_begin failed";
    CHECK_EQ(mdb_open(mdb_txn, NULL, 0, &mdb_dbi), MDB_SUCCESS)
        << "mdb_open failed";
  } else {
    LOG(FATAL) << "Unknown db backend " << db_backend;
  }

  // Storing to db
  string root_folder(argv[arg_offset+1]);
  Datum datum;
  int count = 0;
  const int kMaxKeyLength = 256;
  char key_cstr[kMaxKeyLength];
  int data_size;
  bool data_size_initialized = false;

  for (int line_id = 0; line_id < lines.size(); ++line_id) {
    if (!ReadImageToDatum(root_folder + lines[line_id].first,
        lines[line_id].second, resize_height, resize_width, is_color, &datum)) {
      continue;
    }
    if (!data_size_initialized) {
      data_size = datum.channels() * datum.height() * datum.width();
      data_size_initialized = true;
    } else {
      const string& data = datum.data();
      CHECK_EQ(data.size(), data_size) << "Incorrect data field size "
          << data.size();
    }
    // sequential
    snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id,
        lines[line_id].first.c_str());
    string value;
    datum.SerializeToString(&value);
    string keystr(key_cstr);

    // Put in db
    if (db_backend == "leveldb") {  // leveldb
      batch->Put(keystr, value);
    } else if (db_backend == "lmdb") {  // lmdb
      mdb_data.mv_size = value.size();
      mdb_data.mv_data = reinterpret_cast<void*>(&value[0]);
      mdb_key.mv_size = keystr.size();
      mdb_key.mv_data = reinterpret_cast<void*>(&keystr[0]);
      CHECK_EQ(mdb_put(mdb_txn, mdb_dbi, &mdb_key, &mdb_data, 0), MDB_SUCCESS)
          << "mdb_put failed";
    } else {
      LOG(FATAL) << "Unknown db backend " << db_backend;
    }

    if (++count % 1000 == 0) {
      // Commit txn
      if (db_backend == "leveldb") {  // leveldb
        db->Write(leveldb::WriteOptions(), batch);
        delete batch;
        batch = new leveldb::WriteBatch();
      } else if (db_backend == "lmdb") {  // lmdb
        CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS)
            << "mdb_txn_commit failed";
        CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS)
            << "mdb_txn_begin failed";
      } else {
        LOG(FATAL) << "Unknown db backend " << db_backend;
      }
      LOG(ERROR) << "Processed " << count << " files.";
    }
  }
  // write the last batch
  if (count % 1000 != 0) {
    if (db_backend == "leveldb") {  // leveldb
      db->Write(leveldb::WriteOptions(), batch);
      delete batch;
      delete db;
    } else if (db_backend == "lmdb") {  // lmdb
      CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS) << "mdb_txn_commit failed";
      mdb_close(mdb_env, mdb_dbi);
      mdb_env_close(mdb_env);
    } else {
      LOG(FATAL) << "Unknown db backend " << db_backend;
    }
    LOG(ERROR) << "Processed " << count << " files.";
  }
  return 0;
}
コード例 #21
0
ファイル: calc_mean_stddev.cpp プロジェクト: caomw/ssai
void calc_stddev(
  const std::string &db_fname,
  std::vector<float> mean_values) {
  
  scoped_ptr<db::DB> db(db::GetDB(FLAGS_backend));
  db->Open(db_fname, db::READ);
  scoped_ptr<db::Cursor> cursor(db->NewCursor());

  // load first datum
  Datum datum;
  datum.ParseFromString(cursor->value());

  if (DecodeDatumNative(&datum)) {
    LOG(INFO) << "Decoding Datum";
  }

  std::vector<double> stddev_values;
  for (int c = 0; c < datum.channels(); ++c) {
    stddev_values.push_back(0.0);
  }

  int files = 0;
  unsigned long count = 0;
  LOG(INFO) << "Starting Iteration";

  while (cursor->valid()) {
    Datum datum;
    datum.ParseFromString(cursor->value());
    DecodeDatumNative(&datum);

    const int channels = datum.channels();
    const int height = datum.height();
    const int width = datum.width();
    const std::string& data = datum.data();

    for (int c = 0; c < channels; ++c) {
      for (int h = 0; h < height; ++h) {
        for (int w = 0; w < width; ++w) {
          const int index = c * height * width + h * width + w;
          const int pixel = static_cast<uint8_t>(data[index]);
          stddev_values[c] += pow((double)pixel - mean_values[c], 2.0);
        }
      }
    }

    count += width * height;
    ++files;
    if (count % 10000 == 0) {
      LOG(INFO) << "Processed " << files << " files.";
      LOG(INFO) << "count:" << count;
    }
    cursor->Next();
  }
  if (files % 10000 != 0) {
    LOG(INFO) << "Processed " << files << " files.";
    LOG(INFO) << "count: " << count;
  }
  LOG(INFO) << "Finished Iteration";

  std::cout.precision(15);
  LOG(INFO) << "Number of channels: " << datum.channels();
  for (int c = 0; c < datum.channels(); ++c) {
    stddev_values[c] /= (double)count;
    stddev_values[c] = sqrt(stddev_values[c]);
    LOG(INFO) << "stddev_value channel [" << c << "]:"
              << std::fixed << stddev_values[c];
  }
}
コード例 #22
0
ファイル: data_transformer.cpp プロジェクト: chprasad/caffe
// Copies one datum into transformed_data, applying the configured crop,
// horizontal mirror, mean subtraction, scaling and (in TRAIN phase) a random
// square zero-mask.
//
// datum:            source image; pixels come from data() as unsigned bytes
//                   when it is non-empty, otherwise from float_data().
// transformed_data: destination buffer, written at
//                   (c * height + h) * width + w for the output height/width
//                   (crop_size when cropping, the datum size otherwise).
//
// The random draws happen in this order: mirror, crop offsets, mask frequency,
// mask position.
void DataTransformer<Dtype>::Transform(const Datum& datum,
                                       Dtype* transformed_data) {
  const string& data = datum.data();
  const int datum_channels = datum.channels();
  const int datum_height = datum.height();
  const int datum_width = datum.width();

  const int crop_size = param_.crop_size();
  const Dtype scale = param_.scale();
  const bool do_mirror = param_.mirror() && Rand(2);
  const bool has_mean_file = param_.has_mean_file();
  const bool has_uint8 = data.size() > 0;
  const bool has_mean_values = mean_values_.size() > 0;
  // mask_size is defaulted to 0 in caffe/proto/caffe.proto
  const int mask_size = param_.mask_size();
  // mask_freq is defaulted to 1 in 3 in caffe/proto/caffe.proto
  const int mask_freq = param_.mask_freq();

  CHECK_GT(datum_channels, 0);
  CHECK_GE(datum_height, crop_size);
  CHECK_GE(datum_width, crop_size);

  Dtype* mean = NULL;
  if (has_mean_file) {
    // The mean blob is indexed with the datum index, so the shapes must match.
    CHECK_EQ(datum_channels, data_mean_.channels());
    CHECK_EQ(datum_height, data_mean_.height());
    CHECK_EQ(datum_width, data_mean_.width());
    mean = data_mean_.mutable_cpu_data();
  }
  if (has_mean_values) {
    CHECK(mean_values_.size() == 1 || mean_values_.size() == datum_channels) <<
     "Specify either 1 mean_value or as many as channels: " << datum_channels;
    if (datum_channels > 1 && mean_values_.size() == 1) {
      // Replicate the mean_value for simplicity
      for (int c = 1; c < datum_channels; ++c) {
        mean_values_.push_back(mean_values_[0]);
      }
    }
  }

  int height = datum_height;
  int width = datum_width;

  int h_off = 0;
  int w_off = 0;
  if (crop_size) {
    height = crop_size;
    width = crop_size;
    // We only do random crop when we do training.
    if (phase_ == TRAIN) {
      h_off = Rand(datum_height - crop_size + 1);
      w_off = Rand(datum_width - crop_size + 1);
    } else {
      h_off = (datum_height - crop_size) / 2;
      w_off = (datum_width - crop_size) / 2;
    }
  }

  // initialize masking offsets to be same as cropping offsets
  // so that there is no conflict
  bool masking = (phase_ == TRAIN) && (mask_size > 0) && (Rand(mask_freq) == 0);
  int h_mask_start = h_off;
  int w_mask_start = w_off;
  if (masking) {
    int h_effective = datum_height;
    int w_effective = datum_width;
    if (crop_size) { h_effective = w_effective = crop_size; }
    CHECK_GE(h_effective, mask_size);
    CHECK_GE(w_effective, mask_size);
    h_mask_start += Rand(h_effective-mask_size+1);
    w_mask_start += Rand(w_effective-mask_size+1);
  }
  // The mask covers [start, end) in datum coordinates, i.e. it always lies
  // inside the cropped window.
  int h_mask_end = h_mask_start + mask_size;
  int w_mask_end = w_mask_start + mask_size;

  Dtype datum_element;
  int top_index, data_index;
  for (int c = 0; c < datum_channels; ++c) {
    for (int h = 0; h < height; ++h) {
      for (int w = 0; w < width; ++w) {
        data_index = (c * datum_height + h_off + h) * datum_width + w_off + w;
        if (do_mirror) {
          top_index = (c * height + h) * width + (width - 1 - w);
        } else {
          top_index = (c * height + h) * width + w;
        }
        if (has_uint8) {
          datum_element =
            static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
        } else {
          datum_element = datum.float_data(data_index);
        }
        // A mean file takes precedence over per-channel mean values.
        if (has_mean_file) {
          transformed_data[top_index] =
            (datum_element - mean[data_index]) * scale;
        } else {
          if (has_mean_values) {
            transformed_data[top_index] =
              (datum_element - mean_values_[c]) * scale;
          } else {
            transformed_data[top_index] = datum_element * scale;
          }
        }
        if (masking) {
          // The mask bounds include the crop offsets, while h and w are
          // relative to the crop, so translate before comparing. The start
          // bound is inclusive so the mask is exactly mask_size pixels wide
          // (previously crop-relative h/w were tested with a strict '>',
          // which shifted the mask by the crop offset and shrank it by one).
          const int h_abs = h_off + h;
          const int w_abs = w_off + w;
          if ((h_abs >= h_mask_start) && (w_abs >= w_mask_start) &&
              (h_abs < h_mask_end) && (w_abs < w_mask_end)) {
            transformed_data[top_index] = 0;
          }
        }
      }
    }
  }
}
コード例 #23
0
ファイル: compute_image_mean.cpp プロジェクト: pl8787/caffe
int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);

  std::ifstream infile(argv[1]);
  std::vector<std::pair<string, int> > lines;
  string filename;
  int label;
  while (infile >> filename >> label) {
    lines.push_back(std::make_pair(filename, label));
  }
  LOG(INFO) << "A total of " << lines.size() << " images.";

  Datum datum;
  BlobProto sum_blob;
  int count = 0;

  if (!ReadImageToDatum(lines[0].first, lines[0].second, 
         resize_height, resize_width, is_color, &datum)) {
    return -1;
  }

  sum_blob.set_num(1);
  sum_blob.set_channels(datum.channels());
  sum_blob.set_height(datum.height());
  sum_blob.set_width(datum.width());
  const int data_size = datum.channels() * datum.height() * datum.width();
  int size_in_datum = std::max<int>(datum.data().size(),
                                    datum.float_data_size());
  for (int i = 0; i < size_in_datum; ++i) {
    sum_blob.add_data(0.);
  }

  LOG(INFO) << "Starting Iteration";
  for (int line_id = 0; line_id < lines.size(); ++line_id) {
    if (!ReadImageToDatum(lines[line_id].first, lines[line_id].second, 
           resize_height, resize_width, is_color, &datum)) {
      continue;
    }

    const string& data = datum.data();
    size_in_datum = std::max<int>(datum.data().size(),
        datum.float_data_size());
    CHECK_EQ(size_in_datum, data_size) << "Incorrect data field size " <<
        size_in_datum;
    if (data.size() != 0) {
      for (int i = 0; i < size_in_datum; ++i) {
        sum_blob.set_data(i, sum_blob.data(i) + (uint8_t)data[i]);
      }
    } else {
      for (int i = 0; i < size_in_datum; ++i) {
        sum_blob.set_data(i, sum_blob.data(i) +
            static_cast<float>(datum.float_data(i)));
      }
    }
    ++count;
  }

  for (int i = 0; i < sum_blob.data_size(); ++i) {
    sum_blob.set_data(i, sum_blob.data(i) / count);
  }

  // Write to disk
  LOG(INFO) << "Write to " << argv[2];
  WriteProtoToBinaryFile(sum_blob, argv[2]);

  return 0;
}
コード例 #24
0
int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);
  if (argc < 3 || argc > 4) {
    LOG(ERROR) << "Usage: compute_image_mean input_db output_file"
               << " db_backend[leveldb or lmdb]";
    return 1;
  }

  string db_backend = "lmdb";
  if (argc == 4) {
    db_backend = string(argv[3]);
  }

  // Open leveldb
  leveldb::DB* db;
  leveldb::Options options;
  options.create_if_missing = false;
  leveldb::Iterator* it = NULL;
  // lmdb
  MDB_env* mdb_env;
  MDB_dbi mdb_dbi;
  MDB_val mdb_key, mdb_value;
  MDB_txn* mdb_txn;
  MDB_cursor* mdb_cursor;

  // Open db
  if (db_backend == "leveldb") {  // leveldb
    LOG(INFO) << "Opening leveldb " << argv[1];
    leveldb::Status status = leveldb::DB::Open(
        options, argv[1], &db);
    CHECK(status.ok()) << "Failed to open leveldb " << argv[1];
    leveldb::ReadOptions read_options;
    read_options.fill_cache = false;
    it = db->NewIterator(read_options);
    it->SeekToFirst();
  } else if (db_backend == "lmdb") {  // lmdb
    LOG(INFO) << "Opening lmdb " << argv[1];
    CHECK_EQ(mdb_env_create(&mdb_env), MDB_SUCCESS) << "mdb_env_create failed";
    CHECK_EQ(mdb_env_set_mapsize(mdb_env, 1099511627776), MDB_SUCCESS);  // 1TB
    CHECK_EQ(mdb_env_open(mdb_env, argv[1], MDB_RDONLY, 0664),
        MDB_SUCCESS) << "mdb_env_open failed";
    CHECK_EQ(mdb_txn_begin(mdb_env, NULL, MDB_RDONLY, &mdb_txn), MDB_SUCCESS)
        << "mdb_txn_begin failed";
    CHECK_EQ(mdb_open(mdb_txn, NULL, 0, &mdb_dbi), MDB_SUCCESS)
        << "mdb_open failed";
    CHECK_EQ(mdb_cursor_open(mdb_txn, mdb_dbi, &mdb_cursor), MDB_SUCCESS)
        << "mdb_cursor_open failed";
    CHECK_EQ(mdb_cursor_get(mdb_cursor, &mdb_key, &mdb_value, MDB_FIRST),
        MDB_SUCCESS);
  } else {
    LOG(FATAL) << "Unknown db backend " << db_backend;
  }

  // set size info
  Datum datum;
  BlobProto sum_blob;
  int count = 0;
  // load first datum
  if (db_backend == "leveldb") {
    datum.ParseFromString(it->value().ToString());
  } else if (db_backend == "lmdb") {
    datum.ParseFromArray(mdb_value.mv_data, mdb_value.mv_size);
  } else {
    LOG(FATAL) << "Unknown db backend " << db_backend;
  }

  sum_blob.set_num(1);
  sum_blob.set_channels(datum.channels());
  sum_blob.set_height(datum.height());
  sum_blob.set_width(datum.width());
  const int data_size = datum.channels() * datum.height() * datum.width();
  int size_in_datum = std::max<int>(datum.data().size(),
                                    datum.float_data_size());
  for (int i = 0; i < size_in_datum; ++i) {
    sum_blob.add_data(0.);
  }
  // start collecting
  LOG(INFO) << "Starting Iteration";

  if (db_backend == "leveldb") {  // leveldb
    for (it->SeekToFirst(); it->Valid(); it->Next()) {
      // just a dummy operation
      datum.ParseFromString(it->value().ToString());
      const string& data = datum.data();
      size_in_datum = std::max<int>(datum.data().size(),
          datum.float_data_size());
      CHECK_EQ(size_in_datum, data_size) << "Incorrect data field size " <<
          size_in_datum;
      if (data.size() != 0) {
        for (int i = 0; i < size_in_datum; ++i) {
          sum_blob.set_data(i, sum_blob.data(i) + (uint8_t)data[i]);
        }
      } else {
        for (int i = 0; i < size_in_datum; ++i) {
          sum_blob.set_data(i, sum_blob.data(i) +
              static_cast<float>(datum.float_data(i)));
        }
      }
      ++count;
      if (count % 10000 == 0) {
        LOG(ERROR) << "Processed " << count << " files.";
      }
    }
  } else if (db_backend == "lmdb") {  // lmdb
    CHECK_EQ(mdb_cursor_get(mdb_cursor, &mdb_key, &mdb_value, MDB_FIRST),
        MDB_SUCCESS);
    do {
      // just a dummy operation
      datum.ParseFromArray(mdb_value.mv_data, mdb_value.mv_size);
      const string& data = datum.data();
      size_in_datum = std::max<int>(datum.data().size(),
          datum.float_data_size());
      CHECK_EQ(size_in_datum, data_size) << "Incorrect data field size " <<
          size_in_datum;
      if (data.size() != 0) {
        for (int i = 0; i < size_in_datum; ++i) {
          sum_blob.set_data(i, sum_blob.data(i) + (uint8_t)data[i]);
        }
      } else {
        for (int i = 0; i < size_in_datum; ++i) {
          sum_blob.set_data(i, sum_blob.data(i) +
              static_cast<float>(datum.float_data(i)));
        }
      }
      ++count;
      if (count % 10000 == 0) {
        LOG(ERROR) << "Processed " << count << " files.";
      }
    } while (mdb_cursor_get(mdb_cursor, &mdb_key, &mdb_value, MDB_NEXT)
        == MDB_SUCCESS);
  } else {
    LOG(FATAL) << "Unknown db backend " << db_backend;
  }

  for (int i = 0; i < sum_blob.data_size(); ++i) {
    sum_blob.set_data(i, sum_blob.data(i) / count);
  }

  caffe::Blob<float> vis;
  vis.FromProto(sum_blob);
  caffe::imshow(&vis, 1, "mean img");
  cv::waitKey(0);
  
  google::protobuf::RepeatedField<float>* tmp = sum_blob.mutable_data();
  std::vector<float> mean_data(tmp->begin(), tmp->end());
  double sum = std::accumulate(mean_data.begin(), mean_data.end(), 0.0);
  double mean2 = sum / mean_data.size();
  double sq_sum = std::inner_product(mean_data.begin(), mean_data.end(), mean_data.begin(), 0.0);
  double stdev = std::sqrt(sq_sum / mean_data.size() - mean2 * mean2);

  LOG(INFO) << "mean of mean image: " << mean2 << " std: " << stdev;

  // Write to disk
  LOG(INFO) << "Write to " << argv[2];
  WriteProtoToBinaryFile(sum_blob, argv[2]);

  // Clean up
  if (db_backend == "leveldb") {
    delete db;
  } else if (db_backend == "lmdb") {
    mdb_cursor_close(mdb_cursor);
    mdb_close(mdb_env, mdb_dbi);
    mdb_txn_abort(mdb_txn);
    mdb_env_close(mdb_env);
  } else {
    LOG(FATAL) << "Unknown db backend " << db_backend;
  }
  return 0;
}
コード例 #25
0
int main(int argc, char** argv) {
	::google::InitGoogleLogging(argv[0]);
	if (argc < 5) {
		printf(
				"Convert a set of images to the leveldb format used\n"
						"as input for Caffe.\n"
						"Usage:\n"
						"    convert_imageset ROOTFOLDER/ ANNOTATION DB_NAME"
						" MODE[0-train, 1-val, 2-test] RANDOM_SHUFFLE_DATA[0 or 1, default 1] RESIZE_WIDTH[default 256] RESIZE_HEIGHT[default 256](0 indicates no resize)\n"
						"The ImageNet dataset for the training demo is at\n"
						"    http://www.image-net.org/download-images\n");
		return 0;
	}
	std::ifstream infile(argv[2]);
	std::vector<Seg_Anno> annos;
	std::set<string> fNames;
	string filename;
	int prop;
	while (infile >> filename)
	{
		LOG(INFO)<<filename;

		Seg_Anno seg_Anno;
		seg_Anno.filename_ = filename;
		for (int i = 0; i < LABEL_LEN; i++)
		{
			infile >> prop;
			seg_Anno.pos_.push_back(prop);
		}
		if (fNames.find(filename)== fNames.end())
		{
			fNames.insert(filename);
			annos.push_back(seg_Anno);
		}
		//debug
		//if(annos.size() == 10)
		//	break;
	}
	if (argc < 6 || argv[5][0] != '0') {
		// randomly shuffle data
		LOG(INFO)<< "Shuffling data";
		std::random_shuffle(annos.begin(), annos.end());
	}
	LOG(INFO)<< "A total of " << annos.size() << " images.";

	leveldb::DB* db;
	leveldb::Options options;
	options.error_if_exists = true;
	options.create_if_missing = true;
	options.write_buffer_size = 268435456;
	LOG(INFO)<< "Opening leveldb " << argv[3];
	leveldb::Status status = leveldb::DB::Open(options, argv[3], &db);
	CHECK(status.ok()) << "Failed to open leveldb " << argv[3];

	string root_folder(argv[1]);
	string fchannel_folder(argv[8]);
	Datum datum;
	int count = 0;
	const int maxKeyLength = 256;
	char key_cstr[maxKeyLength];
	leveldb::WriteBatch* batch = new leveldb::WriteBatch();
	int data_size;
	bool data_size_initialized = false;

	// resize to height * width
    int width = RESIZE_LEN;
    int height = RESIZE_LEN;
    if (argc > 6) width = atoi(argv[6]);
    if (argc > 7) height = atoi(argv[7]);
    if (width == 0 || height == 0)
        LOG(INFO) << "NO RESIZE SHOULD BE DONE";
    else
        LOG(INFO) << "RESIZE DIM: " << width << "*" << height;

	for (int anno_id = 0; anno_id < annos.size(); ++anno_id)
	{
		string filename2 = parseString(annos[anno_id].filename_);
		if (!MyReadImageToDatum(root_folder + "/" + annos[anno_id].filename_, fchannel_folder + "/" + filename2,
				annos[anno_id].pos_, height, width, &datum))
		{
			continue;
		}
		if (!data_size_initialized)
		{
			data_size = datum.channels() * datum.height() * datum.width();
			data_size_initialized = true;
		}
		else
		{
			const string& data = datum.data();
			CHECK_EQ(data.size(), data_size)<< "Incorrect data field size " << data.size();
		}

		// sequential
		snprintf(key_cstr, maxKeyLength, "%07d_%s", anno_id, annos[anno_id].filename_.c_str());
		string value;
		// get the value
		datum.SerializeToString(&value);
		batch->Put(string(key_cstr), value);
		if (++count % 1000 == 0)
		{
			db->Write(leveldb::WriteOptions(), batch);
			LOG(ERROR)<< "Processed " << count << " files.";
			delete batch;
			batch = new leveldb::WriteBatch();
		}
	}
	// write the last batch
	if (count % 1000 != 0) {
		db->Write(leveldb::WriteOptions(), batch);
		LOG(ERROR)<< "Processed " << count << " files.";
	}

	delete batch;
	delete db;
	return 0;
}
コード例 #26
0
int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);
  if (argc < 4) {
    printf("Convert a set of images to the leveldb format used\n"
        "as input for Caffe.\n"
        "Usage:\n"
        "    convert_imageset ROOTFOLDER/ LISTFILE DB_NAME"
        " RANDOM_SHUFFLE_DATA[0 or 1]\n"
        "The ImageNet dataset for the training demo is at\n"
        "    http://www.image-net.org/download-images\n");
    return 0;
  }
  std::ifstream infile(argv[2]);
  std::vector<std::pair<string, int> > lines;
  string filename;
  int label;
  while (infile >> filename >> label) {
    lines.push_back(std::make_pair(filename, label));
  }
  if (argc == 5 && argv[4][0] == '1') {
    // randomly shuffle data
    LOG(INFO) << "Shuffling data";
    std::random_shuffle(lines.begin(), lines.end());
  }
  LOG(INFO) << "A total of " << lines.size() << " images.";

  leveldb::DB* db;
  leveldb::Options options;
  options.error_if_exists = true;
  options.create_if_missing = true;
  options.write_buffer_size = 268435456;
  LOG(INFO) << "Opening leveldb " << argv[3];
  leveldb::Status status = leveldb::DB::Open(
      options, argv[3], &db);
  CHECK(status.ok()) << "Failed to open leveldb " << argv[3];

  string root_folder(argv[1]);
  Datum datum;
  int count = 0;
  const int kMaxKeyLength = 256;
  char key_cstr[kMaxKeyLength];
  leveldb::WriteBatch* batch = new leveldb::WriteBatch();
  int data_size;
  bool data_size_initialized = false;
  for (int line_id = 0; line_id < lines.size(); ++line_id) {
    if (!ReadImageToDatum(root_folder + lines[line_id].first,
                          lines[line_id].second, &datum)) {
      continue;
    }
    if (!data_size_initialized) {
      data_size = datum.channels() * datum.height() * datum.width();
      data_size_initialized = true;
    } else {
      const string& data = datum.data();
      CHECK_EQ(data.size(), data_size) << "Incorrect data field size "
          << data.size();
    }
    // sequential
    snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id,
        lines[line_id].first.c_str());
    string value;
    // get the value
    datum.SerializeToString(&value);
    batch->Put(string(key_cstr), value);
    if (++count % 1000 == 0) {
      db->Write(leveldb::WriteOptions(), batch);
      LOG(ERROR) << "Processed " << count << " files.";
      delete batch;
      batch = new leveldb::WriteBatch();
    }
  }
  // write the last batch
  if (count % 1000 != 0) {
    db->Write(leveldb::WriteOptions(), batch);
    LOG(ERROR) << "Processed " << count << " files.";
  }

  delete batch;
  delete db;
  return 0;
}
コード例 #27
0
// Loads the image at lines_[lines_id_] into the prefetch buffers.
//
// The image is read (resized to new_height x new_width), optionally
// center-cropped to crop_size, and written to prefetch_data_ as
// (pixel - mean) * scale; its label goes to prefetch_label_. Exactly one image
// is processed per call. If the image cannot be read, the buffers are left
// untouched.
void MyImageDataLayer<Dtype>::fetchData() {
	Datum image_datum;
	CHECK(prefetch_data_.count());
	Dtype* out_data = prefetch_data_.mutable_cpu_data();
	Dtype* out_label = prefetch_label_.mutable_cpu_data();
	// Parameters of the image data layer.
	ImageDataParameter layer_params = this->layer_param_.image_data_param();
	const Dtype scale_factor = layer_params.scale();
	// image_data_param.batch_size() is deliberately ignored: only a single
	// image is needed here.
	const int num_items = 1;

	const int crop = layer_params.crop_size();
	const bool mirror_requested = layer_params.mirror();
	const int resized_height = layer_params.new_height();
	const int resized_width = layer_params.new_width();

	if (mirror_requested && crop == 0) {
	    LOG(FATAL) << "Current implementation requires mirror and crop_size to be "
				   << "set at the same time.";
	}

	// Datum geometry cached on the layer.
	const int num_channels = datum_channels_;
	const int src_height = datum_height_;
	const int src_width = datum_width_;
	const int elems_per_item = datum_size_;
	const int num_lines = lines_.size();
	const Dtype* mean_data = data_mean_.cpu_data();

	for (int item = 0; item < num_items; ++item) {  // read one image
		// get a blob
		CHECK_GT(num_lines, lines_id_);
		const bool loaded = ReadImageToDatum(lines_[lines_id_].first,
											 lines_[lines_id_].second,
											 resized_height, resized_width,
											 &image_datum);
		if (!loaded) {
			continue;
		}

		const string& bytes = image_datum.data();
		if (crop) {
			CHECK(bytes.size()) << "Image cropping only support uint8 data";
			// The crop window is always centred; no random offset is drawn.
			const int row_offset = (src_height - crop) / 2;
			const int col_offset = (src_width - crop) / 2;

			// Plain (non-mirrored) copy of the cropped window into out_data.
			for (int c = 0; c < num_channels; ++c) {
				for (int h = 0; h < crop; ++h) {
					const int dst_row = ((item * num_channels + c) * crop + h) * crop;
					const int src_row = (c * src_height + h + row_offset) * src_width;
					for (int w = 0; w < crop; ++w) {
						const int dst_index = dst_row + w;
						const int src_index = src_row + w + col_offset;
						const Dtype pixel =
							static_cast<Dtype>(static_cast<uint8_t>(bytes[src_index]));
						out_data[dst_index] = (pixel - mean_data[src_index]) * scale_factor;
					}
				}
			}
		} else if (bytes.size()) {
			// No crop, byte data: copy the whole image into out_data.
			Dtype* dst = out_data + item * elems_per_item;
			for (int j = 0; j < elems_per_item; ++j) {
				const Dtype pixel =
					static_cast<Dtype>(static_cast<uint8_t>(bytes[j]));
				dst[j] = (pixel - mean_data[j]) * scale_factor;
			}
		} else {
			// No crop, no byte data: fall back to the float values.
			Dtype* dst = out_data + item * elems_per_item;
			for (int j = 0; j < elems_per_item; ++j) {
				dst[j] = (image_datum.float_data(j) - mean_data[j]) * scale_factor;
			}
		}

		// Record the label of this image.
		out_label[item] = image_datum.label();
	}
}
コード例 #28
0
// Writes one datum into slot batch_item_id of transformed_data as
// (value - mean) * scale, optionally cropping and mirroring.
//
// batch_item_id:    index of the item inside the output batch.
// datum:            source image. With crop_size set, data() must hold the
//                   pixels as bytes; otherwise data() is preferred and
//                   float_data() is the fallback.
// mean:             mean values, indexed with the same index as the datum.
// transformed_data: output buffer; each item occupies
//                   channels * crop_size * crop_size elements when cropping
//                   and channels * height * width elements otherwise.
//
// When cropping, the offset is random in the TRAIN phase and centred
// otherwise; mirroring (if enabled) is applied with probability 1/2.
// Aborts if mirror is set without crop_size, or if the image is smaller than
// the crop.
void DataTransformer<Dtype>::Transform(const int batch_item_id,
                                       const Datum& datum,
                                       const Dtype* mean,
                                       Dtype* transformed_data) {
  const string& data = datum.data();
  const int channels = datum.channels();
  const int height = datum.height();
  const int width = datum.width();
  const int size = datum.channels() * datum.height() * datum.width();

  const int crop_size = param_.crop_size();
  const bool mirror = param_.mirror();
  const Dtype scale = param_.scale();

  if (mirror && crop_size == 0) {
    LOG(FATAL) << "Current implementation requires mirror and crop_size to be "
               << "set at the same time.";
  }

  if (crop_size) {
    CHECK(data.size()) << "Image cropping only support uint8 data";
    // A crop larger than the image would yield negative offsets and
    // out-of-range reads below.
    CHECK(height >= crop_size) << "crop_size exceeds the image height";
    CHECK(width >= crop_size) << "crop_size exceeds the image width";
    int h_off, w_off;
    // We only do random crop when we do training.
    if (phase_ == Caffe::TRAIN) {
      // Valid offsets are 0 .. (dim - crop_size) inclusive, hence the + 1.
      // Without it an image exactly as large as the crop caused a modulo by
      // zero, and the last offset could never be chosen.
      h_off = Rand() % (height - crop_size + 1);
      w_off = Rand() % (width - crop_size + 1);
    } else {
      h_off = (height - crop_size) / 2;
      w_off = (width - crop_size) / 2;
    }
    // Rand() is only consulted when mirroring is enabled.
    const bool flip = mirror && Rand() % 2;
    for (int c = 0; c < channels; ++c) {
      for (int h = 0; h < crop_size; ++h) {
        for (int w = 0; w < crop_size; ++w) {
          const int data_index = (c * height + h + h_off) * width + w + w_off;
          // Mirroring reverses the column order inside the cropped window.
          const int top_w = flip ? (crop_size - 1 - w) : w;
          const int top_index =
              ((batch_item_id * channels + c) * crop_size + h) * crop_size
              + top_w;
          const Dtype datum_element =
              static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
          transformed_data[top_index] =
              (datum_element - mean[data_index]) * scale;
        }
      }
    }
  } else {
    // we will prefer to use data() first, and then try float_data()
    if (data.size()) {
      for (int j = 0; j < size; ++j) {
        Dtype datum_element =
            static_cast<Dtype>(static_cast<uint8_t>(data[j]));
        transformed_data[j + batch_item_id * size] =
            (datum_element - mean[j]) * scale;
      }
    } else {
      for (int j = 0; j < size; ++j) {
        transformed_data[j + batch_item_id * size] =
            (datum.float_data(j) - mean[j]) * scale;
      }
    }
  }
}
コード例 #29
0
ファイル: data_transformer.cpp プロジェクト: Rt0220/caffe
// Transforms one Datum into transformed_data: optional square crop, optional
// horizontal mirroring, optional mean subtraction, then multiplication by the
// configured scale.
//
//   datum             source record; pixels are read from data() (one uint8
//                     per element) when that string is non-empty, otherwise
//                     from float_data().
//   transformed_data  destination buffer; element (c, h, w) of the result is
//                     written at index (c * out_height + h) * out_width + w,
//                     where out_height/out_width equal crop_size when cropping
//                     is enabled and the datum's own dimensions otherwise.
//
// Mean handling: with a mean file the per-element mean is looked up using the
// index into the *uncropped* datum; otherwise, when mean values are present,
// the value for the current channel is subtracted; otherwise nothing is.
void DataTransformer<Dtype>::Transform(const Datum& datum,
                                       Dtype* transformed_data) {
  const string& bytes = datum.data();
  const int datum_channels = datum.channels();
  const int datum_height = datum.height();
  const int datum_width = datum.width();

  const int crop_size = param_.crop_size();
  const Dtype scale = param_.scale();
  // The coin flip is only drawn when mirroring is enabled in the parameters.
  const bool flip_horizontally = param_.mirror() && Rand(2);
  const bool subtract_mean_file = param_.has_mean_file();
  const bool read_bytes = bytes.size() > 0;
  const bool subtract_mean_values = mean_values_.size() > 0;

  CHECK_GT(datum_channels, 0);
  CHECK_GE(datum_height, crop_size);
  CHECK_GE(datum_width, crop_size);

  // Per-element mean, only set when a mean file is configured.
  Dtype* mean = NULL;
  if (subtract_mean_file) {
    CHECK_EQ(datum_channels, data_mean_.channels());
    CHECK_EQ(datum_height, data_mean_.height());
    CHECK_EQ(datum_width, data_mean_.width());
    mean = data_mean_.mutable_cpu_data();
  }
  if (subtract_mean_values) {
    CHECK(mean_values_.size() == 1 || mean_values_.size() == datum_channels) <<
     "Specify either 1 mean_value or as many as channels: " << datum_channels;
    if (mean_values_.size() == 1 && datum_channels > 1) {
      // A single mean value stands for every channel: append copies of it so
      // the per-channel lookup below needs no special case.
      for (int missing = datum_channels - 1; missing > 0; --missing) {
        mean_values_.push_back(mean_values_[0]);
      }
    }
  }

  // Output dimensions and the top-left corner of the region that is copied.
  int out_height = datum_height;
  int out_width = datum_width;
  int h_off = 0;
  int w_off = 0;
  if (crop_size) {
    out_height = crop_size;
    out_width = crop_size;
    if (phase_ == TRAIN) {
      // Training: random crop position.
      h_off = Rand(datum_height - crop_size + 1);
      w_off = Rand(datum_width - crop_size + 1);
    } else {
      // Any other phase: centered crop.
      h_off = (datum_height - crop_size) / 2;
      w_off = (datum_width - crop_size) / 2;
    }
  }

  for (int c = 0; c < datum_channels; ++c) {
    for (int h = 0; h < out_height; ++h) {
      // Offsets of the first element of this row in source and destination.
      const int src_row = (c * datum_height + h_off + h) * datum_width + w_off;
      const int dst_row = (c * out_height + h) * out_width;
      for (int w = 0; w < out_width; ++w) {
        const int src_index = src_row + w;
        // Mirroring reverses the column order inside the output row.
        const int dst_index =
            dst_row + (flip_horizontally ? (out_width - 1 - w) : w);

        Dtype element;
        if (read_bytes) {
          // Go through uint8_t so bytes >= 128 are not read as negative.
          element = static_cast<Dtype>(static_cast<uint8_t>(bytes[src_index]));
        } else {
          element = datum.float_data(src_index);
        }

        if (subtract_mean_file) {
          transformed_data[dst_index] = (element - mean[src_index]) * scale;
        } else if (subtract_mean_values) {
          transformed_data[dst_index] = (element - mean_values_[c]) * scale;
        } else {
          transformed_data[dst_index] = element * scale;
        }
      }
    }
  }
}
コード例 #30
0
int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);

#ifndef GFLAGS_GFLAGS_H_
  namespace gflags = google;
#endif

  gflags::SetUsageMessage("Convert a set of images to the leveldb/lmdb\n"
        "format used as input for Caffe.\n"
        "Usage:\n"
        "    convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME\n"
        "The ImageNet dataset for the training demo is at\n"
        "    http://www.image-net.org/download-images\n");
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  if (argc != 4) {
    gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/convert_imageset");
    return 1;
  }


  bool is_color = !FLAGS_gray;
  std::ifstream infile(argv[2]);
  std::vector<std::pair<string, int> > lines;
  string filename;
  int label;
  while (infile >> filename >> label) {
    lines.push_back(std::make_pair(filename, label));
  }
  if (FLAGS_shuffle) {
    // randomly shuffle data
    LOG(INFO) << "Shuffling data";
    shuffle(lines.begin(), lines.end());
  }
  LOG(INFO) << "A total of " << lines.size() << " images.";

  const string& db_backend = FLAGS_backend;
  const char* db_path = argv[3];

  int resize_height = std::max<int>(0, FLAGS_resize_height);
  int resize_width = std::max<int>(0, FLAGS_resize_width);

  // Open new db
  // lmdb
  MDB_env *mdb_env;
  MDB_dbi mdb_dbi;
  MDB_val mdb_key, mdb_data;
  MDB_txn *mdb_txn;
  // leveldb
  leveldb::DB* db;
  leveldb::Options options;
  options.error_if_exists = true;
  options.create_if_missing = true;
  options.write_buffer_size = 268435456;
  leveldb::WriteBatch* batch = NULL;

  // Open db
  if (db_backend == "leveldb") {  // leveldb
    LOG(INFO) << "Opening leveldb " << db_path;
    leveldb::Status status = leveldb::DB::Open(
        options, db_path, &db);
    CHECK(status.ok()) << "Failed to open leveldb " << db_path
        << ". Is it already existing?";
    batch = new leveldb::WriteBatch();
  } else if (db_backend == "lmdb") {  // lmdb
    LOG(INFO) << "Opening lmdb " << db_path;
    CHECK_EQ(mkdir(db_path, 0744), 0)
        << "mkdir " << db_path << "failed";
    CHECK_EQ(mdb_env_create(&mdb_env), MDB_SUCCESS) << "mdb_env_create failed";
    CHECK_EQ(mdb_env_set_mapsize(mdb_env, 1099511627776), MDB_SUCCESS)  // 1TB
        << "mdb_env_set_mapsize failed";
    CHECK_EQ(mdb_env_open(mdb_env, db_path, 0, 0664), MDB_SUCCESS)
        << "mdb_env_open failed";
    CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS)
        << "mdb_txn_begin failed";
    CHECK_EQ(mdb_open(mdb_txn, NULL, 0, &mdb_dbi), MDB_SUCCESS)
        << "mdb_open failed. Does the lmdb already exist? ";
  } else {
    LOG(FATAL) << "Unknown db backend " << db_backend;
  }

  // Storing to db
  string root_folder(argv[1]);
  Datum datum;
  int count = 0;
  const int kMaxKeyLength = 256;
  char key_cstr[kMaxKeyLength];
  int data_size;
  bool data_size_initialized = false;

  for (int line_id = 0; line_id < lines.size(); ++line_id) {
    if (!ReadImageToDatum(root_folder + lines[line_id].first,
        lines[line_id].second, resize_height, resize_width, is_color, &datum)) {
      continue;
    }
    if (!data_size_initialized) {
      data_size = datum.channels() * datum.height() * datum.width();
      data_size_initialized = true;
    } else {
      const string& data = datum.data();
      CHECK_EQ(data.size(), data_size) << "Incorrect data field size "
          << data.size();
    }
    // sequential
    snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id,
        lines[line_id].first.c_str());
    string value;
    datum.SerializeToString(&value);
    string keystr(key_cstr);

    // Put in db
    if (db_backend == "leveldb") {  // leveldb
      batch->Put(keystr, value);
    } else if (db_backend == "lmdb") {  // lmdb
      mdb_data.mv_size = value.size();
      mdb_data.mv_data = reinterpret_cast<void*>(&value[0]);
      mdb_key.mv_size = keystr.size();
      mdb_key.mv_data = reinterpret_cast<void*>(&keystr[0]);
      CHECK_EQ(mdb_put(mdb_txn, mdb_dbi, &mdb_key, &mdb_data, 0), MDB_SUCCESS)
          << "mdb_put failed";
    } else {
      LOG(FATAL) << "Unknown db backend " << db_backend;
    }

    if (++count % 1000 == 0) {
      // Commit txn
      if (db_backend == "leveldb") {  // leveldb
        db->Write(leveldb::WriteOptions(), batch);
        delete batch;
        batch = new leveldb::WriteBatch();
      } else if (db_backend == "lmdb") {  // lmdb
        CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS)
            << "mdb_txn_commit failed";
        CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS)
            << "mdb_txn_begin failed";
      } else {
        LOG(FATAL) << "Unknown db backend " << db_backend;
      }
      LOG(ERROR) << "Processed " << count << " files.";
    }
  }
  // write the last batch
  if (count % 1000 != 0) {
    if (db_backend == "leveldb") {  // leveldb
      db->Write(leveldb::WriteOptions(), batch);
      delete batch;
      delete db;
    } else if (db_backend == "lmdb") {  // lmdb
      CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS) << "mdb_txn_commit failed";
      mdb_close(mdb_env, mdb_dbi);
      mdb_env_close(mdb_env);
    } else {
      LOG(FATAL) << "Unknown db backend " << db_backend;
    }
    LOG(ERROR) << "Processed " << count << " files.";
  }
  return 0;
}