Code Example #1
void convert_dataset(const string& input_folder, const string& output_folder,
    const string& db_type) {
  scoped_ptr<db::DB> train_db(db::GetDB(db_type));
  train_db->Open(output_folder + "/GCN300_cifar100_train_" + db_type, db::NEW);
  scoped_ptr<db::Transaction> txn(train_db->NewTransaction());
  // Data buffer
  int label;
  char str_buffer[kCIFARImageNBytes];
  // Per-class counters: the loop below keeps at most NUMPERCLASS images of
  // each of the 100 CIFAR-100 fine classes. Neither count nor NUMPERCLASS is
  // declared in this snippet, so a local definition is assumed here.
  int count[100] = {0};
  Datum datum;
  datum.set_channels(3);
  datum.set_height(kCIFARSize);
  datum.set_width(kCIFARSize);

  LOG(INFO) << "Writing Training data";

  // Open files
  LOG(INFO) << "Training Batch " << 1;
  // batch file name changed from the upstream CIFAR-10 converter
  string batchFileName = input_folder + "/train.bin";
  std::ifstream data_file(batchFileName.c_str(),
      std::ios::in | std::ios::binary);
  CHECK(data_file) << "Unable to open train file #" << 1;
  for (int itemid = 0; itemid < kCIFARBatchSize; ++itemid) {
    read_image(&data_file, &label, str_buffer);
    if (count[label] < NUMPERCLASS) {
      datum.set_label(label);
      datum.set_data(str_buffer, kCIFARImageNBytes);
      string out;
      CHECK(datum.SerializeToString(&out));
      txn->Put(caffe::format_int(itemid, 5), out);
      count[label]++;
    }
  }

  txn->Commit();
  train_db->Close();

  LOG(INFO) << "Writing Testing data";
  scoped_ptr<db::DB> test_db(db::GetDB(db_type));
  test_db->Open(output_folder + "/GCN300_cifar100_test_" + db_type, db::NEW);
  txn.reset(test_db->NewTransaction());
  // Open files
  std::ifstream data_file2((input_folder + "/test.bin").c_str(),
      std::ios::in | std::ios::binary);
  CHECK(data_file2) << "Unable to open test file.";
  for (int itemid = 0; itemid < 10000; ++itemid) {  // CIFAR-100 test set has 10000 images
    read_image(&data_file2, &label, str_buffer);
    datum.set_label(label);
    datum.set_data(str_buffer, kCIFARImageNBytes);
    string out;
    CHECK(datum.SerializeToString(&out));
    txn->Put(caffe::format_int(itemid, 5), out);
  }
  txn->Commit();
  test_db->Close();
}
Code Example #2
void convert_dataset(const string& input_folder, const string& output_folder,
    const string& db_type) {
  scoped_ptr<db::DB> train_db(db::GetDB(db_type));
  train_db->Open(output_folder + "/cifar10_train_" + db_type, db::NEW);
  scoped_ptr<db::Transaction> txn(train_db->NewTransaction());
  // Data buffer
  int label;
  char str_buffer[kCIFARImageNBytes];
  Datum datum;
  datum.set_channels(3);
  datum.set_height(kCIFARSize);
  datum.set_width(kCIFARSize);

  LOG(INFO) << "Writing Training data";
  for (int fileid = 0; fileid < kCIFARTrainBatches; ++fileid) {
    // Open files
    LOG(INFO) << "Training Batch " << fileid + 1;
    snprintf(str_buffer, kCIFARImageNBytes, "/data_batch_%d.bin", fileid + 1);
    std::ifstream data_file((input_folder + str_buffer).c_str(),
        std::ios::in | std::ios::binary);
    CHECK(data_file) << "Unable to open train file #" << fileid + 1;
    for (int itemid = 0; itemid < kCIFARBatchSize; ++itemid) {
      read_image(&data_file, &label, str_buffer);
      datum.set_label(label);
      datum.set_data(str_buffer, kCIFARImageNBytes);
      int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d",
          fileid * kCIFARBatchSize + itemid);
      string out;
      CHECK(datum.SerializeToString(&out));
      txn->Put(string(str_buffer, length), out);
    }
  }
  txn->Commit();
  train_db->Close();

  LOG(INFO) << "Writing Testing data";
  scoped_ptr<db::DB> test_db(db::GetDB(db_type));
  test_db->Open(output_folder + "/cifar10_test_" + db_type, db::NEW);
  txn.reset(test_db->NewTransaction());
  // Open files
  std::ifstream data_file((input_folder + "/test_batch.bin").c_str(),
      std::ios::in | std::ios::binary);
  CHECK(data_file) << "Unable to open test file.";
  for (int itemid = 0; itemid < kCIFARBatchSize; ++itemid) {
    read_image(&data_file, &label, str_buffer);
    datum.set_label(label);
    datum.set_data(str_buffer, kCIFARImageNBytes);
    int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", itemid);
    string out;
    CHECK(datum.SerializeToString(&out));
    txn->Put(string(str_buffer, length), out);
  }
  txn->Commit();
  test_db->Close();
}
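Examples #1 and #2 both call a read_image helper that is not shown. In the stock Caffe converter (examples/cifar10/convert_cifar_data.cpp) it reads one label byte followed by the 3072-byte pixel block; a minimal sketch along those lines (the CIFAR-100 file in example #1 would additionally need to skip the coarse-label byte that precedes the fine label):

void read_image(std::ifstream* file, int* label, char* buffer) {
  char label_char;
  file->read(&label_char, 1);             // one label byte
  *label = label_char;
  file->read(buffer, kCIFARImageNBytes);  // 3 * 32 * 32 = 3072 pixel bytes
}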
Code Example #3
TYPED_TEST(DBTest, TestWrite) {
  unique_ptr<db::DB> db(db::GetDB(TypeParam::backend));
  db->Open(this->source_, db::WRITE);
  unique_ptr<db::Transaction> txn(db->NewTransaction());
  Datum datum;
  ReadFileToDatum(this->root_images_ + "cat.jpg", 0, &datum);
  string out;
  CHECK(datum.SerializeToString(&out));
  txn->Put("cat.jpg", out);
  ReadFileToDatum(this->root_images_ + "fish-bike.jpg", 1, &datum);
  CHECK(datum.SerializeToString(&out));
  txn->Put("fish-bike.jpg", out);
  txn->Commit();
}
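For reference, the matching read path would reopen the DB and walk it with a cursor; a minimal sketch assuming Caffe's db::Cursor interface (SeekToFirst/valid/Next), not part of the original test:

  unique_ptr<db::DB> db(db::GetDB(TypeParam::backend));
  db->Open(this->source_, db::READ);
  unique_ptr<db::Cursor> cursor(db->NewCursor());
  for (cursor->SeekToFirst(); cursor->valid(); cursor->Next()) {
    Datum datum;
    CHECK(datum.ParseFromString(cursor->value()));
    // cursor->key() yields "cat.jpg" / "fish-bike.jpg" as written above.
  }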
Code Example #4
 // Fill the DB with data: if unique_pixels, each pixel is unique but
 // all images are the same; else each image is unique but all pixels within
 // an image are the same.
 void Fill(const bool unique_pixels, DataParameter_DB backend) {
   backend_ = backend;
   LOG(INFO) << "Using temporary dataset " << *filename_;
   scoped_ptr<db::DB> db(db::GetDB(backend));
   db->Open(*filename_, db::NEW);
   scoped_ptr<db::Transaction> txn(db->NewTransaction());
   for (int i = 0; i < 5; ++i) {
     Datum datum;
     datum.set_label(i);
     datum.set_channels(2);
     datum.set_height(3);
     datum.set_width(4);
     std::string* data = datum.mutable_data();
     for (int j = 0; j < 24; ++j) {
       int value = unique_pixels ? j : i;  // renamed from "datum", which shadowed the Datum above
       data->push_back(static_cast<uint8_t>(value));
     }
     stringstream ss;
     ss << i;
     string out;
     CHECK(datum.SerializeToString(&out));
     txn->Put(ss.str(), out);
   }
   txn->Commit();
   db->Close();
 }
Code Example #5
void convert_dataset(string phrase) {
  leveldb::DB* db;
  leveldb::Options options;
  options.error_if_exists = true;
  options.create_if_missing = true;
  options.write_buffer_size = 268435456;
  leveldb::WriteBatch* batch = NULL;
  string db_path = ".//examples//language_model//lm_" + phrase + "_leveldb";
  leveldb::Status status = leveldb::DB::Open(
      options, db_path, &db);
  CHECK(status.ok()) << "Failed to open leveldb " << db_path;

  batch = new leveldb::WriteBatch();
  const int kMaxKeyLength = 10;
  char key_cstr[kMaxKeyLength];
  string value;
  Datum datum;
  datum.set_channels(2*maximum_length);
  datum.set_height(1);
  datum.set_width(1);

  for (int i = 0; i < 2 * maximum_length; i++)
    datum.add_float_data(0.0);

  string tmp = ".//data//language_model//"+phrase+"_indices.txt";
  std::ifstream infile(tmp.c_str());  // pre-C++11 ifstream takes const char*
  string s;
  int item_id = 0;
  while (getline(infile, s)) {
    std::vector<float> dt, real_dt;
    std::istringstream iss(s);
    int num;
    while (iss >> num) {
      if (num >= unknown_symbol)
        num = unknown_symbol;
      dt.push_back(num);
    }
    if (dt.size() < maximum_length) {
      int l = maximum_length - dt.size();
      for (int i = 0; i < l; i++)
        dt.push_back(zero_symbol);
    }
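    // Build real_dt: the sequence shifted right by one (zero_symbol first),
    // followed by the sequence itself -- presumably previous-token inputs and
    // then targets. Note that lines longer than maximum_length are not
    // truncated, so only the first 2*maximum_length values reach the Datum.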
    real_dt.push_back(zero_symbol);
    for (int i = 0; i < dt.size() - 1; i++)
      real_dt.push_back(dt[i]);
    for (int i = 0; i < dt.size(); i++)
      real_dt.push_back(dt[i]);
    snprintf(key_cstr, kMaxKeyLength, "%08d", item_id);  // was MSVC-only _snprintf
    string keystr(key_cstr);
    for (int i = 0; i < 2 * maximum_length; i++)
      datum.set_float_data(i, real_dt[i]);
    datum.SerializeToString(&value);
    batch->Put(keystr, value);
    item_id++;
  }
  db->Write(leveldb::WriteOptions(), batch);
  delete batch;
  delete db;  // was leaked in the original
}
Code Example #6
 virtual void SetUp() {
   MakeTempDir(&source_);
   source_ += "/db";
   string keys[] = {"cat.jpg", "fish-bike.jpg"};
   LOG(INFO) << "Using temporary db " << source_;
   unique_ptr<db::DB> db(db::GetDB(TypeParam::backend));
   db->Open(this->source_, db::NEW);
   unique_ptr<db::Transaction> txn(db->NewTransaction());
   for (int i = 0; i < 2; ++i) {
     Datum datum;
     ReadImageToDatum(root_images_ + keys[i], i, &datum);
     string out;
     CHECK(datum.SerializeToString(&out));
     txn->Put(keys[i], out);
   }
   txn->Commit();
 }
Code Example #7
int main(int argc, char** argv) {
	::google::InitGoogleLogging(argv[0]);
	if (argc < 13) {  // argv[8..12] (coarse/local/van/edge/layout folders) are read unconditionally below
		printf(
				"Convert a set of images to the leveldb format used\n"
						"as input for Caffe.\n"
						"Usage:\n"
						"    convert_imageset ROOTFOLDER/ ANNOTATION DB_NAME"
						" MODE[0-train, 1-val, 2-test] RANDOM_SHUFFLE_DATA[0 or 1, default 1] RESIZE_WIDTH[default 256] RESIZE_HEIGHT[default 256](0 indicates no resize)\n"
						"The ImageNet dataset for the training demo is at\n"
						"    http://www.image-net.org/download-images\n");
		return 0;
	}
	std::ifstream infile(argv[2]);
	string root_folder(argv[1]);
	string coarse_folder(argv[8]);
	string local_folder(argv[9]);
	string van_folder(argv[10]);
	string edge_folder(argv[11]);
	string layout_folder(argv[12]);
	std::vector<Seg_Anno> annos;
	std::set<string> fNames;
	string filename;
	float prop = 0;  // never actually read from the file; see the commented-out "infile >> prop" below
	int cc = 0;
	while (infile >> filename)
	{
		if (cc % 1000 == 0)
			LOG(INFO) << filename;
		cc ++;

		Seg_Anno seg_Anno;
		seg_Anno.filename_ = filename;
		int x,y;
		infile >> x >> y;
		for (int i = 0; i < LABEL_LEN; i++)
		{
			//infile >> prop;
			if(!(prop < 1000000 && prop > -1000000))
			{
				printf("123");
			}
			seg_Anno.pos_.push_back(0);
		}
		//string labelFile = filename;
		//labelFile[labelFile.size() - 1] = 't';
		//labelFile[labelFile.size() - 2] = 'x';
		//labelFile[labelFile.size() - 3] = 't';
		//labelFile =  coarse_folder + "/" + labelFile;
		//FILE * tf = fopen(labelFile.c_str(), "rb");
		//if(tf == NULL) continue;
		//fclose(tf);
		if (fNames.find(filename)== fNames.end())
		{
			fNames.insert(filename);
			annos.push_back(seg_Anno);
		}
		//debug
		//if(annos.size() == 10)
		//	break;
	}
	if (argc < 6 || argv[5][0] != '0') {
		// randomly shuffle data
		LOG(INFO)<< "Shuffling data";
		std::random_shuffle(annos.begin(), annos.end());
	}
	LOG(INFO)<< "A total of " << annos.size() << " images.";

	leveldb::DB* db;
	leveldb::Options options;
	options.error_if_exists = true;
	options.create_if_missing = true;
	options.write_buffer_size = 268435456;
	LOG(INFO)<< "Opening leveldb " << argv[3];
	leveldb::Status status = leveldb::DB::Open(options, argv[3], &db);
	CHECK(status.ok()) << "Failed to open leveldb " << argv[3];

	Datum datum;
	int count = 0;
	const int maxKeyLength = 256;
	char key_cstr[maxKeyLength];
	leveldb::WriteBatch* batch = new leveldb::WriteBatch();
	int data_size;
	bool data_size_initialized = false;

	// resize to height * width
    int width = RESIZE_LEN;
    int height = RESIZE_LEN;
    if (argc > 6) width = atoi(argv[6]);
    if (argc > 7) height = atoi(argv[7]);
    if (width == 0 || height == 0)
        LOG(INFO) << "NO RESIZE SHOULD BE DONE";
    else
        LOG(INFO) << "RESIZE DIM: " << width << "*" << height;

	for (int anno_id = 0; anno_id < annos.size(); ++anno_id)
	{
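		// Replace the last three characters of the image filename with "txt"
		// to form the matching annotation filename used in each folder below.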
		string labelFile = annos[anno_id].filename_;
		labelFile[labelFile.size() - 1] = 't';
		labelFile[labelFile.size() - 2] = 'x';
		labelFile[labelFile.size() - 3] = 't';
		if (!MyReadImageToDatum(root_folder + "/" + annos[anno_id].filename_, coarse_folder + "/" + labelFile, local_folder + "/" + labelFile, van_folder + "/" + labelFile,
				edge_folder + '/' + labelFile, layout_folder + '/' + labelFile , annos[anno_id].pos_, height, width, &datum))
		{
			continue;
		}
		if (!data_size_initialized)
		{
			data_size = datum.channels() * datum.height() * datum.width() ;
			data_size_initialized = true;
		}
		else
		{
			int dataLen = datum.float_data_size();
			CHECK_EQ(dataLen, data_size)<< "Incorrect data field size " << dataLen;
		}

		// sequential
		snprintf(key_cstr, maxKeyLength, "%07d_%s", anno_id, annos[anno_id].filename_.c_str());
		string value;
		// get the value
		datum.SerializeToString(&value);
		batch->Put(string(key_cstr), value);
		if (++count % 1000 == 0)
		{
			db->Write(leveldb::WriteOptions(), batch);
			LOG(ERROR)<< "Processed " << count << " files.";
			delete batch;
			batch = new leveldb::WriteBatch();
		}
	}
	// write the last batch
	if (count % 1000 != 0) {
		db->Write(leveldb::WriteOptions(), batch);
		LOG(ERROR)<< "Processed " << count << " files.";
	}

	delete batch;
	delete db;
	return 0;
}
Code Example #8
int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);
  // Print output to stderr (while still logging)
  FLAGS_alsologtostderr = 1;

#ifndef GFLAGS_GFLAGS_H_
  namespace gflags = google;
#endif

  gflags::SetUsageMessage("Convert a set of images to the leveldb/lmdb\n"
        "format used as input for Caffe.\n"
        "Usage:\n"
        "    convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME\n"
        "The ImageNet dataset for the training demo is at\n"
        "    http://www.image-net.org/download-images\n");
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  if (argc < 4) {
    gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/convert_imageset");
    return 1;
  }

  const bool is_color = !FLAGS_gray;
  const bool check_size = FLAGS_check_size;
  const bool encoded = FLAGS_encoded;
  const string encode_type = FLAGS_encode_type;

  std::ifstream infile(argv[2]);
  std::vector<std::pair<std::vector<std::string>, int> > lines;
  std::string filenames[10];  // each line lists ten patch filenames plus a label
  int label;
  while (infile >> filenames[0] >> filenames[1] >> filenames[2] >> filenames[3]
                >> filenames[4] >> filenames[5] >> filenames[6] >> filenames[7]
                >> filenames[8] >> filenames[9] >> label) {
    std::vector<std::string> patches_filenames(filenames, filenames + 10);
    lines.push_back(std::make_pair(patches_filenames, label));
  }
  if (FLAGS_shuffle) {
    // randomly shuffle data
    LOG(INFO) << "Shuffling data";
    shuffle(lines.begin(), lines.end());
  }
  LOG(INFO) << "A total of " << lines.size() << " images.";

  if (encode_type.size() && !encoded)
    LOG(INFO) << "encode_type specified, assuming encoded=true.";

  int resize_height = std::max<int>(0, FLAGS_resize_height);
  int resize_width = std::max<int>(0, FLAGS_resize_width);

  // Create new DB
  scoped_ptr<db::DB> db(db::GetDB(FLAGS_backend));
  db->Open(argv[3], db::NEW);
  scoped_ptr<db::Transaction> txn(db->NewTransaction());

  // Storing to db
  std::string root_folder(argv[1]);
  Datum datum;
  Datum datum_aux;
  int count = 0;
  const int kMaxKeyLength = 256;
  char key_cstr[kMaxKeyLength];
  int data_size = 0;
  bool data_size_initialized = false;


  for (int line_id = 0; line_id < lines.size(); ++line_id) {
      //LOG(INFO) << "Processing image " << line_id << " with "
      //          << lines[line_id].first.size() << " patches.";
      bool status;
      std::string enc = encode_type;
      if (encoded && !enc.size()) {
        // Guess the encoding type from the file name
        string fn = lines[line_id].first[0];
        size_t p = fn.rfind('.');
        if ( p == fn.npos )
          LOG(WARNING) << "Failed to guess the encoding of '" << fn << "'";
        enc = fn.substr(p);
        std::transform(enc.begin(), enc.end(), enc.begin(), ::tolower);
      }
  
      datum.set_channels(NUM_CHANNELS);  // one channel for each image in the group
      std::string buffer;

      for (int patch_id = 0; patch_id < NUM_CHANNELS; patch_id++) {
        status = ReadImageToDatum(root_folder + lines[line_id].first[patch_id],
            lines[line_id].second, resize_height, resize_width, is_color,
            enc, &datum_aux);
        if (status == false) continue;  // caution: a failed read skips this patch; later inserts at a fixed offset may then throw
        if (patch_id == 0)
        {
          datum.set_height(datum_aux.height());
          datum.set_width(datum_aux.width());
        }
        if (check_size) {
          if (!data_size_initialized) {
            data_size = datum_aux.channels() * datum_aux.height() * datum_aux.width();
            data_size_initialized = true;
          } else {
            const std::string& data = datum_aux.data();
            CHECK_EQ(data.size(), data_size) << "Incorrect data field size "
                << data.size();
          }
        }

        const int datum_height = datum.height();
        const int datum_width = datum.width();
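        // Each patch is inserted at its channel offset, so buffer ends up
        // channel-major: all pixels of patch 0, then patch 1, and so on.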
        buffer.insert(datum_height * datum_width * patch_id, datum_aux.data());
        //LOG(INFO) << "          channel " << patch_id << " inserted!";
      }
  
      datum.set_data(buffer);
      datum.set_label(lines[line_id].second);

      // sequential key: "<line_id>_<first two patch filenames>"
      int length = snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id,
          (lines[line_id].first[0]+lines[line_id].first[1]).c_str());

      // Put in db
      string out;
      CHECK(datum.SerializeToString(&out));
      txn->Put(string(key_cstr, length), out);

      if (++count % 1000 == 0) {
        // Commit db
        txn->Commit();
        txn.reset(db->NewTransaction());
        LOG(INFO) << "Processed " << count << " files.";
      }

  }
  // write the last batch
  if (count % 1000 != 0) {
    txn->Commit();
    LOG(INFO) << "Processed " << count << " files.";
  }
  return 0;
}
Code Example #9
int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);
  if (argc < 4 || argc > 9) {
    printf("Convert a set of images to the leveldb format used\n"
        "as input for Caffe.\n"
        "Usage:\n"
        "    convert_imageset [-g] ROOTFOLDER/ LISTFILE DB_NAME"
        " RANDOM_SHUFFLE_DATA[0 or 1] DB_BACKEND[leveldb or lmdb]"
        " [resize_height] [resize_width]\n"
        "The ImageNet dataset for the training demo is at\n"
        "    http://www.image-net.org/download-images\n");
    return 1;
  }

  // Test whether argv[1] == "-g"
  bool is_color= !(string("-g") == string(argv[1]));
  int  arg_offset = (is_color ? 0 : 1);
  std::ifstream infile(argv[arg_offset+2]);
  std::vector<std::pair<string, int> > lines;
  string filename;
  int label;
  while (infile >> filename >> label) {
    lines.push_back(std::make_pair(filename, label));
  }
  if (argc >= (arg_offset+5) && argv[arg_offset+4][0] == '1') {
    // randomly shuffle data
    LOG(INFO) << "Shuffling data";
    shuffle(lines.begin(), lines.end());
  }
  LOG(INFO) << "A total of " << lines.size() << " images.";

  string db_backend = "leveldb";
  if (argc >= (arg_offset+6)) {
    db_backend = string(argv[arg_offset+5]);
    if (!(db_backend == "leveldb") && !(db_backend == "lmdb")) {
      LOG(FATAL) << "Unknown db backend " << db_backend;
    }
  }

  int resize_height = 0;
  int resize_width = 0;
  if (argc >= (arg_offset+7)) {
    resize_height = atoi(argv[arg_offset+6]);
  }
  if (argc >= (arg_offset+8)) {
    resize_width = atoi(argv[arg_offset+7]);
  }

  // Open new db
  // lmdb
  MDB_env *mdb_env;
  MDB_dbi mdb_dbi;
  MDB_val mdb_key, mdb_data;
  MDB_txn *mdb_txn;
  // leveldb
  leveldb::DB* db;
  leveldb::Options options;
  options.error_if_exists = true;
  options.create_if_missing = true;
  options.write_buffer_size = 268435456;
  leveldb::WriteBatch* batch = NULL;

  // Open db
  if (db_backend == "leveldb") {  // leveldb
    LOG(INFO) << "Opening leveldb " << argv[arg_offset+3];
    leveldb::Status status = leveldb::DB::Open(
        options, argv[arg_offset+3], &db);
    CHECK(status.ok()) << "Failed to open leveldb " << argv[arg_offset+3];
    batch = new leveldb::WriteBatch();
  } else if (db_backend == "lmdb") {  // lmdb
    LOG(INFO) << "Opening lmdb " << argv[arg_offset+3];
    CHECK_EQ(mkdir(argv[arg_offset+3], 0744), 0)
        << "mkdir " << argv[arg_offset+3] << " failed";
    CHECK_EQ(mdb_env_create(&mdb_env), MDB_SUCCESS) << "mdb_env_create failed";
    CHECK_EQ(mdb_env_set_mapsize(mdb_env, 1099511627776), MDB_SUCCESS)  // 1TB
        << "mdb_env_set_mapsize failed";
    CHECK_EQ(mdb_env_open(mdb_env, argv[arg_offset+3], 0, 0664), MDB_SUCCESS)
        << "mdb_env_open failed";
    CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS)
        << "mdb_txn_begin failed";
    CHECK_EQ(mdb_open(mdb_txn, NULL, 0, &mdb_dbi), MDB_SUCCESS)
        << "mdb_open failed";
  } else {
    LOG(FATAL) << "Unknown db backend " << db_backend;
  }

  // Storing to db
  string root_folder(argv[arg_offset+1]);
  Datum datum;
  int count = 0;
  const int kMaxKeyLength = 256;
  char key_cstr[kMaxKeyLength];
  int data_size;
  bool data_size_initialized = false;

  for (int line_id = 0; line_id < lines.size(); ++line_id) {
    if (!ReadImageToDatum(root_folder + lines[line_id].first,
        lines[line_id].second, resize_height, resize_width, is_color, &datum)) {
      continue;
    }
    if (!data_size_initialized) {
      data_size = datum.channels() * datum.height() * datum.width();
      data_size_initialized = true;
    } else {
      const string& data = datum.data();
      CHECK_EQ(data.size(), data_size) << "Incorrect data field size "
          << data.size();
    }
    // sequential
    snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id,
        lines[line_id].first.c_str());
    string value;
    datum.SerializeToString(&value);
    string keystr(key_cstr);

    // Put in db
    if (db_backend == "leveldb") {  // leveldb
      batch->Put(keystr, value);
    } else if (db_backend == "lmdb") {  // lmdb
      mdb_data.mv_size = value.size();
      mdb_data.mv_data = reinterpret_cast<void*>(&value[0]);
      mdb_key.mv_size = keystr.size();
      mdb_key.mv_data = reinterpret_cast<void*>(&keystr[0]);
      CHECK_EQ(mdb_put(mdb_txn, mdb_dbi, &mdb_key, &mdb_data, 0), MDB_SUCCESS)
          << "mdb_put failed";
    } else {
      LOG(FATAL) << "Unknown db backend " << db_backend;
    }

    if (++count % 1000 == 0) {
      // Commit txn
      if (db_backend == "leveldb") {  // leveldb
        db->Write(leveldb::WriteOptions(), batch);
        delete batch;
        batch = new leveldb::WriteBatch();
      } else if (db_backend == "lmdb") {  // lmdb
        CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS)
            << "mdb_txn_commit failed";
        CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS)
            << "mdb_txn_begin failed";
      } else {
        LOG(FATAL) << "Unknown db backend " << db_backend;
      }
      LOG(ERROR) << "Processed " << count << " files.";
    }
  }
  // write the last batch
  if (count % 1000 != 0) {
    if (db_backend == "leveldb") {  // leveldb
      db->Write(leveldb::WriteOptions(), batch);
    } else if (db_backend == "lmdb") {  // lmdb
      CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS) << "mdb_txn_commit failed";
    } else {
      LOG(FATAL) << "Unknown db backend " << db_backend;
    }
    LOG(ERROR) << "Processed " << count << " files.";
  } else if (db_backend == "lmdb") {
    mdb_txn_abort(mdb_txn);  // discard the empty transaction begun in the loop
  }
  // release the DB handles (the original leaked them whenever count was a
  // multiple of 1000)
  if (db_backend == "leveldb") {
    delete batch;
    delete db;
  } else if (db_backend == "lmdb") {
    mdb_close(mdb_env, mdb_dbi);
    mdb_env_close(mdb_env);
  }
  return 0;
}
Code Example #10
int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);
  if (argc < 4 || argc > 5) {
    printf("Convert a set of images to the leveldb format used\n"
        "as input for Caffe.\n"
        "Usage:\n"
        "    convert_imageset ROOTFOLDER/ LISTFILE DB_NAME"
        " RANDOM_SHUFFLE_DATA[0 or 1]\n"
        "The ImageNet dataset for the training demo is at\n"
        "    http://www.image-net.org/download-images\n");
    return 1;
  }
  std::ifstream infile(argv[2]);
  if (!infile)
    LOG(FATAL) << "there is no file named " << argv[2];  // was LOG(INFO), which carried on with an empty list
  std::vector<string> lines;
  string infor;
  while (infile >> infor) {
    lines.push_back(infor);
  }
  if (argc == 5 && argv[4][0] == '1') {
    // randomly shuffle data
    LOG(INFO) << "Shuffling data";
    std::random_shuffle(lines.begin()+1, lines.end());
  }
  LOG(INFO) << "A total of " << lines.size() << " images.";

  leveldb::DB* db;
  leveldb::Options options;
  options.error_if_exists = true;
  options.create_if_missing = true;
  options.write_buffer_size = 268435456;
  LOG(INFO) << "Opening leveldb " << argv[3];
  leveldb::Status status = leveldb::DB::Open(
      options, argv[3], &db);
  CHECK(status.ok()) << "Failed to open leveldb " << argv[3];

  Datum datum;
  int count = 0;
  const int kMaxKeyLength = 256;
  char key_cstr[kMaxKeyLength];
  leveldb::WriteBatch* batch = new leveldb::WriteBatch();
  int data_size;
  bool data_size_initialized = false;

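  // The first line of the list file is a "channels,height,width" header; the
  // samples start at line 1 (hence the begin()+1 in the shuffle above).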
  int width = 0 , height = 0 , channel = 0;
  std::string::size_type pos1 = 0 , pos2 = 0;
  pos2 = lines[0].find(",", pos1);
  channel = atoi(lines[0].substr(pos1,pos2-pos1).c_str());

  pos1 = pos2 + 1;
  pos2 = lines[0].find(",", pos1);
  height = atoi(lines[0].substr(pos1,pos2-pos1).c_str());

  pos1 = pos2 + 1;
  pos2 = lines[0].find(",", pos1);
  width = atoi(lines[0].substr(pos1,pos2-pos1).c_str());

  for (int line_id = 1; line_id < lines.size(); ++line_id) {
    if (!ReadCSVToDatum(lines[line_id], channel, width, height, &datum)) {
      continue;
    }

    if (!data_size_initialized) {
      data_size = datum.channels() * datum.height() * datum.width();
      data_size_initialized = true;
    } else {
      const ::google::protobuf::RepeatedField<float>& data = datum.float_data();
      CHECK_EQ(data.size(), data_size) << "Incorrect data field size "
          << data.size();
    }
    // sequential
    snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id,
        lines[line_id].c_str());
    string value;
    // get the value
    datum.SerializeToString(&value);
    batch->Put(string(key_cstr), value);
    if (++count % 1000 == 0) {
      db->Write(leveldb::WriteOptions(), batch);
      LOG(ERROR) << "Processed " << count << " files.";
      delete batch;
      batch = new leveldb::WriteBatch();
    }
  }
  // write the last batch
  if (count % 1000 != 0) {
    db->Write(leveldb::WriteOptions(), batch);
    LOG(ERROR) << "Processed " << count << " files.";
  }

  delete batch;
  delete db;
  return 0;
}
Code Example #11
File: prepare_data.cpp  Project: phecy/caffe_test
int main(int argc, char** argv) 
{
#ifndef GFLAGS_GFLAGS_H_
    namespace gflags = google;
#endif
    gflags::SetUsageMessage("This script converts the images dataset to\n"
        "the leveldb format used by Caffe to load data.\n");
    gflags::ParseCommandLineFlags(&argc, &argv, true);

    /*  check */
    if (argc != 3) 
    {
        gflags::ShowUsageWithFlagsRestrict(argv[0],"./convert_dataset filelist.txt leveldb_path");
        return -2;
    }
    else 
        google::InitGoogleLogging(argv[0]);
    
    /*  create the leveldb, open it */
    leveldb::DB *db;
    leveldb::Options options;
    options.error_if_exists = true;
    options.create_if_missing = true;
    options.write_buffer_size = 268435456;
    unsigned long counter = 0;
    
    /*  open the leveldb file  */
    string output_db_path( argv[2] );
    LOG(INFO)<<"Opening leveldb "<<output_db_path;
    leveldb::Status status = leveldb::DB::Open( options, output_db_path.c_str(), &db );
    CHECK(status.ok()) << "Failed to open leveldb " << output_db_path<< ". Is it already existing?";
    
    /* also save the map from class_name to label ,
     * and map from label to class_name*/
    map<int, int> class_to_label;
    map<int, int> label_to_class;
    
    /* check if argv[1] is a folder */
    if( !bf::is_regular_file( string(argv[1])) )
    {
        cout<<string(argv[1])<<" is not a file !"<<endl;
        return -3;
    }
    /* iterate the folder */
    unsigned int label = 0;
    stringstream ss;

	int r, img_label, bbox_x, bbox_y, bbox_width, bbox_height;
	char img_path[256];  // widened from 30 to avoid overflow in the fscanf below
	FILE *fp = fopen( argv[1], "r");
	if(fp == NULL)
	{
		cout<<"can not open file "<<argv[1]<<endl;
		return -3;
	}

	/*  leveldb writer buffer */
	leveldb::WriteBatch* batch = NULL;
	batch =  new leveldb::WriteBatch();
	while(1)
	{
		r = fscanf(fp, "%255s %d %d %d %d %d\n", img_path, &img_label, &bbox_x, &bbox_y, &bbox_width, &bbox_height);
		if( r == EOF )
			break;
		cv::Mat input_img = cv::imread( string(img_path) );
		if(input_img.empty())
		{
			cout<<"img empty ! "<<endl;
			return -5;
		}
		
		/*  read the adjust the image to the fixed size 256x256 */
		Rect adjusted_rect;
		if( bbox_height > bbox_width )
			adjusted_rect = resizeToFixedRatio(Rect( bbox_x, bbox_y, bbox_width, bbox_height), 1, 1);
		else
			adjusted_rect = resizeToFixedRatio(Rect( bbox_x, bbox_y, bbox_width, bbox_height), 1, 0);

		Mat crop_img = cropImage( input_img, adjusted_rect );
		resize( crop_img, crop_img, Size(256,256), 0, 0, INTER_AREA);

		/*decide the label*/
		if( class_to_label.count(img_label) ==0)
		{
			class_to_label[ img_label] = label;
			label_to_class[label] = img_label;
			label++;
		}
		
		/* write the Datum to leveldb */
		Datum datum;
		string value;
		const int kMaxKeyLength = 10;   /*  enough for this dataset */
		char key_cstr[kMaxKeyLength];
		
        /*  convert Mat to Datum using caffe util */
        CVMatToDatum( crop_img, &datum );
        datum.set_label(class_to_label[img_label]);

        //datum.set_channels(crop_img.channels());
        //datum.set_height(crop_img.rows);
        //datum.set_width(crop_img.cols);
        //datum.set_data(crop_img.data, crop_img.cols*crop_img.cols*crop_img.channels()); /* wrong, caffe's data format is num channel height wid */
        //datum.set_label( class_to_label[img_label] ); /*  remember to set the label */
        
        snprintf(key_cstr, kMaxKeyLength, "%08lu", counter++);  // %08lu: counter is unsigned long
        datum.SerializeToString(&value);

        batch->Put( key_cstr, value);
		if( counter % 1000 == 0)
		{
			db->Write(leveldb::WriteOptions(), batch);
			delete batch;
			batch =  new leveldb::WriteBatch();
		}
		
		cout<<"processing image "<<img_path<<" with label "<<class_to_label[img_label]<<endl;
		cout<<"car type is "<<class_to_label[img_label]<<endl;
		//rectangle( input_img, cv::Rect( bbox_x, bbox_y, bbox_width, bbox_height), cv::Scalar(255,0,0) );
		//imshow( "input", input_img);
		//imshow("adjust", crop_img);
        
        //cout<<"write image to "<<"./show_test/"+string(img_path)<<std::endl;
        //imwrite("./show_test/"+string(img_path), crop_img);
		//waitKey(0);
	}
	if( counter%1000 != 0 )
	{
		db->Write(leveldb::WriteOptions(), batch);
	}
	delete batch;
	delete db;  // previously both leaked when counter was a multiple of 1000
	
    cout<<"size of class_to_label is "<<class_to_label.size()<<endl;
    cout<<"size of label_to_class is "<<label_to_class.size()<<endl;

    save_map( class_to_label, "class_to_label_test.data");
    save_map( label_to_class, "label_to_class_test.data");

    return 0;
}
Code Example #12
File: main.cpp  Project: cciliber/objrecpipe_mat
int main(int argc, char** argv) {
#ifdef USE_OPENCV
  ::google::InitGoogleLogging(argv[0]);
  // Print output to stderr (while still logging)
  FLAGS_alsologtostderr = 1;

#ifndef GFLAGS_GFLAGS_H_
  namespace gflags = google;
#endif

  gflags::SetUsageMessage("Convert a set of images to the leveldb/lmdb\n"
        "format used as input for Caffe.\n"
        "Usage:\n"
        "    convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME");
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  if (argc < 4) {
    gflags::ShowUsageWithFlagsRestrict(argv[0], "convert_imageset");
    return 1;
  }

  const bool is_color = !FLAGS_gray;
  const bool check_size = FLAGS_check_size;
  const bool encoded = FLAGS_encoded;
  const string encode_type = FLAGS_encode_type;

  std::ifstream infile(argv[2]);
  std::vector<std::pair<std::string, int> > lines;
  std::string filename;
  int label;
  while (infile >> filename >> label) {
    lines.push_back(std::make_pair(filename, label));
  }
  if (FLAGS_shuffle) {
    // randomly shuffle data
    LOG(INFO) << "Shuffling data";
    shuffle(lines.begin(), lines.end());
  }
  LOG(INFO) << "A total of " << lines.size() << " images.";

  if (encode_type.size() && !encoded)
    LOG(INFO) << "encode_type specified, assuming encoded=true.";

  int resize_height = std::max<int>(0, FLAGS_resize_height);
  int resize_width = std::max<int>(0, FLAGS_resize_width);

  // Create new DB
  scoped_ptr<db::DB> db(db::GetDB(FLAGS_backend));
  db->Open(argv[3], db::NEW);
  scoped_ptr<db::Transaction> txn(db->NewTransaction());

  // Storing to db
  std::string root_folder(argv[1]);
  Datum datum;
  int count = 0;
  int data_size = 0;
  bool data_size_initialized = false;

  for (int line_id = 0; line_id < lines.size(); ++line_id) {
    bool status;
    std::string enc = encode_type;
    if (encoded && !enc.size()) {
      // Guess the encoding type from the file name
      string fn = lines[line_id].first;
      size_t p = fn.rfind('.');
      if ( p == fn.npos )
        LOG(WARNING) << "Failed to guess the encoding of '" << fn << "'";
      enc = fn.substr(p);
      std::transform(enc.begin(), enc.end(), enc.begin(), ::tolower);
    }
    status = ReadImageToDatum(root_folder + lines[line_id].first,
        lines[line_id].second, resize_height, resize_width, is_color,
        enc, &datum);
    if (status == false) continue;
    if (check_size) {
      if (!data_size_initialized) {
        data_size = datum.channels() * datum.height() * datum.width();
        data_size_initialized = true;
      } else {
        const std::string& data = datum.data();
        CHECK_EQ(data.size(), data_size) << "Incorrect data field size "
            << data.size();
      }
    }
    // sequential
    string key_str = caffe::format_int(line_id, 8) + "_" + lines[line_id].first;

    // Put in db
    string out;
    CHECK(datum.SerializeToString(&out));
    txn->Put(key_str, out);

    if (++count % 1000 == 0) {
      // Commit db
      txn->Commit();
      txn.reset(db->NewTransaction());
      LOG(INFO) << "Processed " << count << " files.";
    }
  }
  // write the last batch
  if (count % 1000 != 0) {
    txn->Commit();
    LOG(INFO) << "Processed " << count << " files.";
  }
#else
  LOG(FATAL) << "This tool requires OpenCV; compile with USE_OPENCV.";
#endif  // USE_OPENCV
  return 0;
}
Code Example #13
  void TestReshape(DataParameter_DB backend) {
    const int num_inputs = 5;
    // Save data of varying shapes.
    LOG(INFO) << "Using temporary dataset " << *filename_;
    scoped_ptr<db::DB> db(db::GetDB(backend));
    db->Open(*filename_, db::NEW);
    scoped_ptr<db::Transaction> txn(db->NewTransaction());
    for (int i = 0; i < num_inputs; ++i) {
      Datum datum;
      datum.set_label(i);
      datum.set_channels(2);
      datum.set_height(i % 2 + 1);
      datum.set_width(i % 4 + 1);
      std::string* data = datum.mutable_data();
      const int data_size = datum.channels() * datum.height() * datum.width();
      for (int j = 0; j < data_size; ++j) {
        data->push_back(static_cast<uint8_t>(j));
      }
      stringstream ss;
      ss << i;
      string out;
      CHECK(datum.SerializeToString(&out));
      txn->Put(ss.str(), out);
    }
    txn->Commit();
    db->Close();

    // Load and check data of various shapes.
    LayerParameter param;
    param.set_phase(TEST);
    DataParameter* data_param = param.mutable_data_param();
    data_param->set_batch_size(1);
    data_param->set_source(filename_->c_str());
    data_param->set_backend(backend);

    DataLayer<Dtype> layer(param);
    layer.SetUp(blob_bottom_vec_, blob_top_vec_);
    EXPECT_EQ(blob_top_data_->num(), 1);
    EXPECT_EQ(blob_top_data_->channels(), 2);
    EXPECT_EQ(blob_top_label_->num(), 1);
    EXPECT_EQ(blob_top_label_->channels(), 1);
    EXPECT_EQ(blob_top_label_->height(), 1);
    EXPECT_EQ(blob_top_label_->width(), 1);

    for (int iter = 0; iter < num_inputs; ++iter) {
      layer.Forward(blob_bottom_vec_, blob_top_vec_);
      EXPECT_EQ(blob_top_data_->height(), iter % 2 + 1);
      EXPECT_EQ(blob_top_data_->width(), iter % 4 + 1);
      EXPECT_EQ(iter, blob_top_label_->cpu_data()[0]);
      const int channels = blob_top_data_->channels();
      const int height = blob_top_data_->height();
      const int width = blob_top_data_->width();
      for (int c = 0; c < channels; ++c) {
        for (int h = 0; h < height; ++h) {
          for (int w = 0; w < width; ++w) {
            const int idx = (c * height + h) * width + w;
            EXPECT_EQ(idx, static_cast<int>(blob_top_data_->cpu_data()[idx]))
                << "debug: iter " << iter << " c " << c
                << " h " << h << " w " << w;
          }
        }
      }
    }
  }
Code Example #14
void helper::videoToDatabase(const string &inputVideoDir, const string &outputDatabaseDir, vec2i dimensions, int sampleCount)
{
    const int frameChannels = 64;
    const int frameX = 16;
    const int frameY = 16;

    const int historyFrames = 4;
    const int imageChannelCount = 3;

    const int totalFrames = historyFrames + 1;
    const int totalChannelCount = imageChannelCount * totalFrames;
    const int pixelCount = dimensions.x * dimensions.y;

    const int samplesPerBlock = 10;

    auto videoImagePaths = Directory::enumerateFilesWithPath(inputVideoDir, ".png");
    sort(videoImagePaths.begin(), videoImagePaths.end());

    cout << "Making video database for " << inputVideoDir << endl;

    // leveldb
    leveldb::DB* db;
    leveldb::Options options;
    options.error_if_exists = true;
    options.create_if_missing = true;
    options.write_buffer_size = 268435456;
    leveldb::WriteBatch* batch = NULL;

    // Open db
    cout << "Opening leveldb " << outputDatabaseDir << endl;
    leveldb::Status status = leveldb::DB::Open(options, outputDatabaseDir, &db);
    if (!status.ok())
    {
        cout << "Failed to open " << outputDatabaseDir << " or it already exists" << endl;
        return;
    }
    batch = new leveldb::WriteBatch();

    // Storing to db
    char* rawData = new char[pixelCount * totalChannelCount];

    int count = 0;
    const int kMaxKeyLength = 10;
    char key_cstr[kMaxKeyLength];
    string value;

    Datum datum;
    datum.set_channels(totalChannelCount);
    datum.set_height(dimensions.x);
    datum.set_width(dimensions.y);

    ColorImageR8G8B8A8 dummyImage(dimensions);

    cout << "A total of " << sampleCount << " samples will be generated." << endl;
    cout << "Rows: " << dimensions.x << " Cols: " << dimensions.y << endl;
    for (int sampleIndex = 0; sampleIndex < sampleCount;)
    {
        if (sampleIndex % 1000 == 0)
            cout << "Sample " << sampleIndex << " / " << sampleCount << endl;

        vector<ColorImageR8G8B8A8> sampleImages;
        const int startImageIndex = util::randomInteger(0, videoImagePaths.size() - 6);
        for (int i = 0; i < 5; i++)
            sampleImages.push_back(LodePNG::load(videoImagePaths[startImageIndex + i]));

        for (int blockSampleIndex = 0; blockSampleIndex < samplesPerBlock; blockSampleIndex++)
        {
            vec2i sampleStart(
                util::randomInteger(0, sampleImages[0].getWidth()  - dimensions.x - 2),
                util::randomInteger(0, sampleImages[0].getHeight() - dimensions.y - 2));

            int pIndex = 0;
            
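            // Pack the five consecutive frames (4 history + 1 current) channel
            // by channel, giving one 15-channel sample per crop.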
            for (int frameIndex = 0; frameIndex < totalFrames; frameIndex++)
            {
                for (int channel = 0; channel < 3; channel++)
                {
                    for (const auto &p : dummyImage)
                    {
                        rawData[pIndex++] = sampleImages[frameIndex](sampleStart.x + p.x, sampleStart.y + p.y)[channel];
                    }
                }
            }

            datum.set_data(rawData, pixelCount * totalChannelCount);
            datum.set_label(0);

            sprintf_s(key_cstr, kMaxKeyLength, "%08d", sampleIndex);
            datum.SerializeToString(&value);
            string keystr(key_cstr);

            // Put in db
            batch->Put(keystr, value);

            if (++count % 1000 == 0) {
                // Commit txn
                db->Write(leveldb::WriteOptions(), batch);
                delete batch;
                batch = new leveldb::WriteBatch();
            }

            sampleIndex++;
        }
    }
    // write the last batch
    if (count % 1000 != 0) {
        db->Write(leveldb::WriteOptions(), batch);
    }
    delete batch;
    delete db;
    delete[] rawData;  // was leaked in the original
    cout << "Processed " << count << " files." << endl;
}
Code Example #15
int main(int argc, char** argv)
{
  ::google::InitGoogleLogging(argv[0]);
  if (argc < 5)
  {
    printf("Convert a set of images to the leveldb format used\n"
           "as input for Caffe.\n"
           "Usage:\n"
           "    convert_imageset ROOTFOLDER/ LABELFILE CONTEXT DB_NAME"
           " RANDOM_SHUFFLE_DATA[0 or 1]\n");
    return 0;
  }

  std::vector<std::pair<string, vector<float> > > lines;
  {
    std::ifstream infile(argv[2]);

    vector<float> label(NUMLABEL, 0);
    while (infile.good())
    {
      string filename;
      infile >> filename;
      if (filename.empty())
        break;

      for (int i = 0; i < NUMLABEL; ++i)
        infile >> label[i];

      lines.push_back(std::make_pair(filename, label));
    }
    infile.close();
    if (argc == 6 && argv[5][0] == '1')
    {
      // randomly shuffle data
      LOG(INFO)<< "Shuffling data";
      std::random_shuffle(lines.begin(), lines.end());
    }
    LOG(INFO)<< "A total of " << lines.size() << " images.";
  }

  std::map<string, vector<float> > map_name_contxt;
  {
    vector<float> contxt(NUMCONTEXT, 0);
    std::ifstream input(argv[3]);
    while (input.good())
    {
      string filename;
      input >> filename;
      if (filename.empty())
        break;

      for (int i = 0; i < NUMCONTEXT; ++i)
        input >> contxt[i];

      map_name_contxt.insert(std::make_pair(filename, contxt));
    }
    input.close();
  }

  leveldb::DB* db;
  leveldb::Options options;
  options.error_if_exists = true;
  options.create_if_missing = true;
  options.write_buffer_size = 268435456;
  LOG(INFO)<< "Opening leveldb " << argv[4];
  leveldb::Status status = leveldb::DB::Open(options, argv[4], &db);
  CHECK(status.ok()) << "Failed to open leveldb " << argv[4];

  string root_folder(argv[1]);
  Datum datum;
  int count = 0;
  leveldb::WriteBatch* batch = new leveldb::WriteBatch();
  int data_size;
  bool data_size_initialized = false;
  for (int line_id = 0; line_id < lines.size(); ++line_id)
  {
    const std::pair<string, vector<float> >& name_label = lines[line_id];
    const string& name = name_label.first;
    const vector<float>& cur_labels = name_label.second;
    const vector<float>& cur_conxts = map_name_contxt.find(name)->second;

    // set image name
    datum.set_img_name(name);

    // set image data
    {
      const string img_full_name = root_folder + name;
      cv::Mat cv_img = cv::imread(img_full_name, CV_LOAD_IMAGE_COLOR);
      if (!cv_img.data)
      {
        LOG(ERROR)<< "Could not open or find file " << img_full_name;
        return -1;  // was "return false", i.e. 0, which signalled success
      }

      datum.set_channels(3);
      datum.set_height(cv_img.rows);
      datum.set_width(cv_img.cols);
      datum.clear_data();
      datum.clear_float_data();
      string* datum_string = datum.mutable_data();
      for (int c = 0; c < 3; ++c)
      {
        for (int h = 0; h < cv_img.rows; ++h)
        {
          for (int w = 0; w < cv_img.cols; ++w)
          {
            datum_string->push_back(
                static_cast<char>(cv_img.at<cv::Vec3b>(h, w)[c]));
          }
        }
      }
    }

    // set multi-label (num_multi_label/multi_label, like img_name and context,
    // are custom Datum fields from this project's modified caffe.proto)
    {
      datum.set_num_multi_label(NUMLABEL);
      datum.clear_multi_label();
      datum.mutable_multi_label()->Reserve(cur_labels.size());
      for (int i = 0; i < cur_labels.size(); ++i)
        datum.add_multi_label(cur_labels[i]);
    }

    // set context
    {
      datum.set_num_context(NUMCONTEXT);
      datum.clear_context();
      datum.mutable_context()->Reserve(cur_conxts.size());
      for (int i = 0; i < cur_conxts.size(); ++i)
        datum.add_context(cur_conxts[i]);
    }

    string value;
    // get the value
    datum.SerializeToString(&value);
    batch->Put(name, value);
    if (++count % 1000 == 0)
    {
      db->Write(leveldb::WriteOptions(), batch);
      LOG(ERROR)<< "Processed " << count << " files.";
      delete batch;
      batch = new leveldb::WriteBatch();
    }
  }
  // write the last batch
  if (count % 1000 != 0)
  {
    db->Write(leveldb::WriteOptions(), batch);
    LOG(ERROR)<< "Processed " << count << " files.";
  }

  delete batch;
  delete db;
  return 0;
}
Code Example #16
File: feature_extractor.cpp  Project: ZhitingHu/NN
void FeatureExtractor<Dtype>::ExtractFeatures(const NetParameter& net_param) {
  util::Context& context = util::Context::get_instance();
  int client_id = context.get_int32("client_id");
  string weights_path = context.get_string("weights");
  string extract_feature_blob_names 
      = context.get_string("extract_feature_blob_names");

  shared_ptr<Net<Dtype> > feature_extraction_net(
      new Net<Dtype>(net_param, thread_id_, 0));
  map<string, vector<int> >::const_iterator it 
      = layer_blobs_global_idx_ptr_->begin();
  for (; it != layer_blobs_global_idx_ptr_->end(); ++it) {
    const shared_ptr<Layer<Dtype> > layer 
        = feature_extraction_net->layer_by_name(it->first);
    layer->SetUpBlobGlobalTable(it->second, false, false);
  }
  if (client_id == 0 && thread_id_ == 0) {
    LOG(INFO) << "Extracting features by " << weights_path;
    feature_extraction_net->CopyTrainedLayersFrom(weights_path, true);
  } 
  petuum::PSTableGroup::GlobalBarrier();

  feature_extraction_net->SyncWithPS(0);

  vector<string> blob_names;
  boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(","));

  string save_feature_leveldb_names  
      = context.get_string("save_feature_leveldb_names");
  vector<string> leveldb_names;
  boost::split(leveldb_names, save_feature_leveldb_names,
               boost::is_any_of(","));
  CHECK_EQ(blob_names.size(), leveldb_names.size()) <<
      " the number of blob names and leveldb names must be equal";
  size_t num_features = blob_names.size();

  for (size_t i = 0; i < num_features; i++) {
    CHECK(feature_extraction_net->has_blob(blob_names[i]))
        << "Unknown feature blob name " << blob_names[i]
        << " in the network ";
  } 
  CHECK(feature_extraction_net->has_blob("label"))
      << "Fail to find label blob in the network ";

  // Differentiate leveldb names
  std::ostringstream suffix;
  suffix  << "_" << client_id << "_" << thread_id_;
  for (size_t i = 0; i < num_features; i++) {
      leveldb_names[i] = leveldb_names[i] + suffix.str();
  }
  
  leveldb::Options options;
  options.error_if_exists = true;
  options.create_if_missing = true;
  options.write_buffer_size = 268435456;
  vector<shared_ptr<leveldb::DB> > feature_dbs;
  for (size_t i = 0; i < num_features; ++i) {
    leveldb::DB* db;
    leveldb::Status status = leveldb::DB::Open(options,
                                               leveldb_names[i].c_str(),
                                               &db);
    CHECK(status.ok()) << "Failed to open leveldb " << leveldb_names[i];
    feature_dbs.push_back(shared_ptr<leveldb::DB>(db));
  }

  int num_mini_batches = context.get_int32("num_mini_batches");
 
  Datum datum;
  vector<shared_ptr<leveldb::WriteBatch> > feature_batches(
      num_features,
      shared_ptr<leveldb::WriteBatch>(new leveldb::WriteBatch()));
  const int kMaxKeyStrLength = 100;
  char key_str[kMaxKeyStrLength];
  vector<Blob<float>*> input_vec;
  vector<int> image_indices(num_features, 0);
  for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
    feature_extraction_net->Forward(input_vec);
    for (int i = 0; i < num_features; ++i) {
      const shared_ptr<Blob<Dtype> > feature_blob 
          = feature_extraction_net->blob_by_name(blob_names[i]);
      const shared_ptr<Blob<Dtype> > label_blob
          = feature_extraction_net->blob_by_name("label");
      const Dtype* labels = label_blob->cpu_data(); 
      int batch_size = feature_blob->num();
      int dim_features = feature_blob->count() / batch_size;
      Dtype* feature_blob_data;
      for (int n = 0; n < batch_size; ++n) {
        datum.set_height(dim_features);
        datum.set_width(1);
        datum.set_channels(1);
        datum.clear_data();
        datum.clear_float_data();
        feature_blob_data = feature_blob->mutable_cpu_data() +
            feature_blob->offset(n);
        for (int d = 0; d < dim_features; ++d) {
          datum.add_float_data(feature_blob_data[d]);
        }
        datum.set_label(static_cast<int>(labels[n]));

        string value;
        datum.SerializeToString(&value);
        snprintf(key_str, kMaxKeyStrLength, "%d", image_indices[i]);
        feature_batches[i]->Put(string(key_str), value);
        ++image_indices[i];
        if (image_indices[i] % 1000 == 0) {
          feature_dbs[i]->Write(leveldb::WriteOptions(),
                                feature_batches[i].get());
          feature_batches[i].reset(new leveldb::WriteBatch());
        }
      }  // for (int n = 0; n < batch_size; ++n)
    }  // for (int i = 0; i < num_features; ++i)
  }  // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
  // write the last batch
  for (int i = 0; i < num_features; ++i) {
    if (image_indices[i] % 1000 != 0) {
      feature_dbs[i]->Write(leveldb::WriteOptions(), feature_batches[i].get());
    }
  }
}
Code Example #17
int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);
  if (argc < 5) {  // ROOTFOLDER, LISTFILE, NUM_LABELS and DB_NAME are all required
    printf("Convert a set of images to the leveldb format used\n"
        "as input for Caffe.\n"
        "Usage:\n"
        "    convert_imageset ROOTFOLDER/ LISTFILE NUM_LABELS DB_NAME"
        " RANDOM_SHUFFLE_DATA[0 or 1]\n");
    return 0;
  }
  //Arguments to our program
  std::ifstream infile(argv[2]);
  int numLabels = atoi(argv[3]);
  
  // Each line is constituted of the path to the file and the vector of 
  // labels
  std::vector<std::pair<string, std::vector<float> > > lines;
  // --------
  string filename;
  std::vector<float> labels(numLabels);
  
  while (infile >> filename) {
    for (int l = 0; l < numLabels; l++)
      infile >> labels[l];
    lines.push_back(std::make_pair(filename, labels));
    /*
    LOG(ERROR) <<  "filepath: " << lines[lines.size()-1].first;
    LOG(ERROR) << "values: " << lines[lines.size()-1].second[0] 
			   << ","		 << lines[lines.size()-1].second[5] 
			   << ","		 << lines[lines.size()-1].second[8];
			   * */
  }
  if (argc == 6 && argv[5][0] == '1') {  // was argc == 5, which dereferenced the NULL argv[5]
    // randomly shuffle data
    LOG(ERROR) << "Shuffling data";
    std::random_shuffle(lines.begin(), lines.end());
  }
  LOG(ERROR) << "A total of " << lines.size() << " images.";

  leveldb::DB* db;
  leveldb::Options options;
  options.error_if_exists = true;
  options.create_if_missing = true;
  options.write_buffer_size = 268435456;
  LOG(ERROR) << "Opening leveldb " << argv[4];
  leveldb::Status status = leveldb::DB::Open(
      options, argv[4], &db);
  CHECK(status.ok()) << "Failed to open leveldb " << argv[4];
 
  string root_folder(argv[1]);
  Datum datum;
  int count = 0;
  const int maxKeyLength = 256;
  char key_cstr[maxKeyLength];
  leveldb::WriteBatch* batch = new leveldb::WriteBatch();
  int data_size;
  bool data_size_initialized = false;
  for (int line_id = 0; line_id < lines.size(); ++line_id) {
    if (!ReadImageWithLabelVectorToDatum(root_folder + lines[line_id].first, lines[line_id].second,
										 &datum)) {
      continue;
    }

    if (!data_size_initialized) {
      data_size = datum.channels() * datum.height() * datum.width();
      data_size_initialized = true;  // was missing, so the size check below never ran
    } else {
      const string& data = datum.data();
      CHECK_EQ(data.size(), data_size) << "Incorrect data field size " << data.size();
    }
    
    // sequential
    snprintf(key_cstr, maxKeyLength, "%08d_%s", line_id, lines[line_id].first.c_str());
    string value;
    
    // get the value
    datum.SerializeToString(&value);
    batch->Put(string(key_cstr), value);
    if (++count % 1000 == 0) {
      db->Write(leveldb::WriteOptions(), batch);
      LOG(ERROR) << "Processed " << count << " files.";
      delete batch;
      batch = new leveldb::WriteBatch();
    }
  }
  // write the last batch
  if (count % 1000 != 0) {
    db->Write(leveldb::WriteOptions(), batch);
    LOG(ERROR) << "Processed " << count << " files.";
  }

  delete batch;
  delete db;
  return 0;
}
Code Example #18
int feature_extraction_pipeline(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);
  const int num_required_args = 6;
  if (argc < num_required_args) {
    LOG(ERROR)<<
    "This program takes in a trained network and an input data layer, and then"
    " extract features of the input data produced by the net.\n"
    "Usage: extract_features  pretrained_net_param"
    "  feature_extraction_proto_file  extract_feature_blob_name1[,name2,...]"
    "  save_feature_leveldb_name1[,name2,...]  num_mini_batches  [CPU/GPU]"
    "  [DEVICE_ID=0]\n"
    "Note: you can extract multiple features in one pass by specifying"
    " multiple feature blob names and leveldb names separated by ','."
    " The names cannot contain white space characters and the number of blobs"
    " and leveldbs must be equal.";
    return 1;
  }
  int arg_pos = num_required_args;
  if (argc > arg_pos && strcmp(argv[arg_pos], "GPU") == 0) {
    LOG(ERROR)<< "Using GPU";
    uint device_id = 0;
    if (argc > arg_pos + 1) {
      device_id = atoi(argv[arg_pos + 1]);
      CHECK_GE(device_id, 0);
    }
    LOG(ERROR) << "Using Device_id=" << device_id;
    Caffe::SetDevice(device_id);
    Caffe::set_mode(Caffe::GPU);
  } else {
    LOG(ERROR) << "Using CPU";
    Caffe::set_mode(Caffe::CPU);
  }
  Caffe::set_phase(Caffe::TEST);

  arg_pos = 0;  // the name of the executable
  string pretrained_binary_proto(argv[++arg_pos]);

  // Expected prototxt contains at least one data layer such as
  //  the layer data_layer_name and one feature blob such as the
  //  fc7 top blob to extract features.
  /*
   layers {
     name: "data_layer_name"
     type: DATA
     data_param {
       source: "/path/to/your/images/to/extract/feature/images_leveldb"
       mean_file: "/path/to/your/image_mean.binaryproto"
       batch_size: 128
       crop_size: 227
       mirror: false
     }
     top: "data_blob_name"
     top: "label_blob_name"
   }
   layers {
     name: "drop7"
     type: DROPOUT
     dropout_param {
       dropout_ratio: 0.5
     }
     bottom: "fc7"
     top: "fc7"
   }
   */
  string feature_extraction_proto(argv[++arg_pos]);
  shared_ptr<Net<Dtype> > feature_extraction_net(
      new Net<Dtype>(feature_extraction_proto));
  feature_extraction_net->CopyTrainedLayersFrom(pretrained_binary_proto);

  string extract_feature_blob_names(argv[++arg_pos]);
  vector<string> blob_names;
  boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(","));

  string save_feature_leveldb_names(argv[++arg_pos]);
  vector<string> leveldb_names;
  boost::split(leveldb_names, save_feature_leveldb_names,
               boost::is_any_of(","));
  CHECK_EQ(blob_names.size(), leveldb_names.size()) <<
      " the number of blob names and leveldb names must be equal";
  size_t num_features = blob_names.size();

  for (size_t i = 0; i < num_features; i++) {
    CHECK(feature_extraction_net->has_blob(blob_names[i]))
        << "Unknown feature blob name " << blob_names[i]
        << " in the network " << feature_extraction_proto;
  }

  leveldb::Options options;
  options.error_if_exists = true;
  options.create_if_missing = true;
  options.write_buffer_size = 268435456;
  vector<shared_ptr<leveldb::DB> > feature_dbs;
  for (size_t i = 0; i < num_features; ++i) {
    LOG(INFO)<< "Opening leveldb " << leveldb_names[i];
    leveldb::DB* db;
    leveldb::Status status = leveldb::DB::Open(options,
                                               leveldb_names[i].c_str(),
                                               &db);
    CHECK(status.ok()) << "Failed to open leveldb " << leveldb_names[i];
    feature_dbs.push_back(shared_ptr<leveldb::DB>(db));
  }

  int num_mini_batches = atoi(argv[++arg_pos]);

  LOG(ERROR)<< "Extracting Features";

  Datum datum;
  vector<shared_ptr<leveldb::WriteBatch> > feature_batches(
      num_features,
      shared_ptr<leveldb::WriteBatch>(new leveldb::WriteBatch()));
  const int kMaxKeyStrLength = 100;
  char key_str[kMaxKeyStrLength];
  vector<Blob<float>*> input_vec;
  vector<int> image_indices(num_features, 0);
  for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
    feature_extraction_net->Forward(input_vec);
    for (int i = 0; i < num_features; ++i) {
      const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net
          ->blob_by_name(blob_names[i]);
      int batch_size = feature_blob->num();
      int dim_features = feature_blob->count() / batch_size;
      Dtype* feature_blob_data;
      for (int n = 0; n < batch_size; ++n) {
        datum.set_height(dim_features);
        datum.set_width(1);
        datum.set_channels(1);
        datum.clear_data();
        datum.clear_float_data();
        feature_blob_data = feature_blob->mutable_cpu_data() +
            feature_blob->offset(n);
        for (int d = 0; d < dim_features; ++d) {
          datum.add_float_data(feature_blob_data[d]);
        }
        string value;
        datum.SerializeToString(&value);
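        // Note: "%d" yields unpadded keys, so leveldb's lexicographic order
        // will not match numeric order (1, 10, 100, 2, ...); the later
        // variant below zero-pads the key with "%010d".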
        snprintf(key_str, kMaxKeyStrLength, "%d", image_indices[i]);
        feature_batches[i]->Put(string(key_str), value);
        ++image_indices[i];
        if (image_indices[i] % 1000 == 0) {
          feature_dbs[i]->Write(leveldb::WriteOptions(),
                                feature_batches[i].get());
          LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
              " query images for feature blob " << blob_names[i];
          feature_batches[i].reset(new leveldb::WriteBatch());
        }
      }  // for (int n = 0; n < batch_size; ++n)
    }  // for (int i = 0; i < num_features; ++i)
  }  // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
  // write the last batch
  for (int i = 0; i < num_features; ++i) {
    if (image_indices[i] % 1000 != 0) {
      feature_dbs[i]->Write(leveldb::WriteOptions(), feature_batches[i].get());
    }
    LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
        " query images for feature blob " << blob_names[i];
  }

  LOG(ERROR)<< "Successfully extracted the features!";
  return 0;
}
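To sanity-check what the tool above writes, a minimal read-back sketch can help. It is not part of the original tool: the function name read_first_feature is hypothetical, and it assumes a leveldb produced by the code above, where each value is a serialized Datum holding one feature vector in float_data (channels=1, height=dim_features, width=1).

#include <string>
#include <vector>
#include "caffe/proto/caffe.pb.h"
#include "glog/logging.h"
#include "leveldb/db.h"

std::vector<float> read_first_feature(const std::string& db_path) {
  leveldb::DB* db;
  leveldb::Options options;
  options.create_if_missing = false;  // the feature db must already exist
  leveldb::Status status = leveldb::DB::Open(options, db_path, &db);
  CHECK(status.ok()) << "Failed to open leveldb " << db_path;
  leveldb::Iterator* it = db->NewIterator(leveldb::ReadOptions());
  it->SeekToFirst();
  CHECK(it->Valid()) << "Empty feature db " << db_path;
  caffe::Datum datum;
  CHECK(datum.ParseFromString(it->value().ToString()));
  // One feature vector per key, stored in float_data.
  std::vector<float> feature(datum.float_data().begin(),
                             datum.float_data().end());
  delete it;
  delete db;
  return feature;
}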
Code example #19
File: convert_imageset.cpp  Project: MaoXu/jade
int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);

#ifndef GFLAGS_GFLAGS_H_
  namespace gflags = google;
#endif

  gflags::SetUsageMessage("Convert a set of images to the leveldb/lmdb\n"
        "format used as input for Caffe.\n"
        "Usage:\n"
        "    convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME\n"
        "The ImageNet dataset for the training demo is at\n"
        "    http://www.image-net.org/download-images\n");
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  if (argc < 4) {
    gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/convert_imageset");
    return 1;
  }

  const bool is_color = !FLAGS_gray;
  const bool check_size = FLAGS_check_size;
  const bool encoded = FLAGS_encoded;

  std::ifstream infile(argv[2]);
  std::vector<std::pair<std::string, int> > lines;
  std::string filename;
  int label;
  while (infile >> filename >> label) {
    lines.push_back(std::make_pair(filename, label));
  }
  if (FLAGS_shuffle) {
    // randomly shuffle data
    LOG(INFO) << "Shuffling data";
    shuffle(lines.begin(), lines.end());
  }
  LOG(INFO) << "A total of " << lines.size() << " images.";

  if (encoded) {
    CHECK_EQ(FLAGS_resize_height, 0) << "With encoded don't resize images";
    CHECK_EQ(FLAGS_resize_width, 0) << "With encoded don't resize images";
    CHECK(!check_size) << "With encoded cannot check_size";
  }

  int resize_height = std::max<int>(0, FLAGS_resize_height);
  int resize_width = std::max<int>(0, FLAGS_resize_width);

  // Create new DB
  scoped_ptr<db::DB> db(db::GetDB(FLAGS_backend));
  db->Open(argv[3], db::NEW);
  scoped_ptr<db::Transaction> txn(db->NewTransaction());

  // Storing to db
  std::string root_folder(argv[1]);
  Datum datum;
  int count = 0;
  const int kMaxKeyLength = 256;
  char key_cstr[kMaxKeyLength];
  int data_size;
  bool data_size_initialized = false;

  for (int line_id = 0; line_id < lines.size(); ++line_id) {
    bool status;
    if (encoded) {
      status = ReadFileToDatum(root_folder + lines[line_id].first,
        lines[line_id].second, &datum);
    } else {
      status = ReadImageToDatum(root_folder + lines[line_id].first,
          lines[line_id].second, resize_height, resize_width, is_color, &datum);
    }
    if (status == false) continue;
    if (check_size) {
      if (!data_size_initialized) {
        data_size = datum.channels() * datum.height() * datum.width();
        data_size_initialized = true;
      } else {
        const std::string& data = datum.data();
        CHECK_EQ(data.size(), data_size) << "Incorrect data field size "
            << data.size();
      }
    }
    // sequential
    int length = snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id,
        lines[line_id].first.c_str());

    // Put in db
    string out;
    CHECK(datum.SerializeToString(&out));
    txn->Put(string(key_cstr, length), out);

    if (++count % 1000 == 0) {
      // Commit db
      txn->Commit();
      txn.reset(db->NewTransaction());
      LOG(ERROR) << "Processed " << count << " files.";
    }
  }
  // write the last batch
  if (count % 1000 != 0) {
    txn->Commit();
    LOG(ERROR) << "Processed " << count << " files.";
  }
  return 0;
}
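The keys written above combine a zero-padded line index with the image path ("%08d_%s"), so a sequential scan returns entries in list order. A minimal verification sketch, assuming the same caffe::db wrapper used above; the helper name scan_db is hypothetical:

#include <string>
#include "boost/scoped_ptr.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/db.hpp"

void scan_db(const std::string& backend, const std::string& db_name) {
  boost::scoped_ptr<caffe::db::DB> db(caffe::db::GetDB(backend));
  db->Open(db_name, caffe::db::READ);
  boost::scoped_ptr<caffe::db::Cursor> cursor(db->NewCursor());
  for (cursor->SeekToFirst(); cursor->valid(); cursor->Next()) {
    caffe::Datum datum;
    CHECK(datum.ParseFromString(cursor->value()));
    LOG(INFO) << cursor->key() << " label=" << datum.label();
  }
  db->Close();
}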
Code example #20
int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);
  if (argc < 4) {
    printf("Convert a set of images to the leveldb format used\n"
        "as input for Caffe.\n"
        "Usage:\n"
        "    convert_imageset ROOTFOLDER/ LISTFILE DB_NAME"
        " RANDOM_SHUFFLE_DATA[0 or 1]\n"
        "The ImageNet dataset for the training demo is at\n"
        "    http://www.image-net.org/download-images\n");
    return 1;
  }
  std::ifstream infile(argv[2]);
  std::vector<std::pair<string, int> > lines;
  string filename;
  int label;
  while (infile >> filename >> label) {
    lines.push_back(std::make_pair(filename, label));
  }
  if (argc == 5 && argv[4][0] == '1') {
    // randomly shuffle data
    LOG(INFO) << "Shuffling data";
    std::random_shuffle(lines.begin(), lines.end());
  }
  LOG(INFO) << "A total of " << lines.size() << " images.";

  leveldb::DB* db;
  leveldb::Options options;
  options.error_if_exists = true;
  options.create_if_missing = true;
  options.write_buffer_size = 268435456;
  LOG(INFO) << "Opening leveldb " << argv[3];
  leveldb::Status status = leveldb::DB::Open(
      options, argv[3], &db);
  CHECK(status.ok()) << "Failed to open leveldb " << argv[3];

  string root_folder(argv[1]);
  Datum datum;
  int count = 0;
  const int kMaxKeyLength = 256;
  char key_cstr[kMaxKeyLength];
  leveldb::WriteBatch* batch = new leveldb::WriteBatch();
  int data_size;
  bool data_size_initialized = false;
  for (int line_id = 0; line_id < lines.size(); ++line_id) {
    if (!ReadImageToDatum(root_folder + lines[line_id].first,
                          lines[line_id].second, &datum)) {
      continue;
    }
    if (!data_size_initialized) {
      data_size = datum.channels() * datum.height() * datum.width();
      data_size_initialized = true;
    } else {
      const string& data = datum.data();
      CHECK_EQ(data.size(), data_size) << "Incorrect data field size "
          << data.size();
    }
    // sequential
    snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id,
        lines[line_id].first.c_str());
    string value;
    // get the value
    datum.SerializeToString(&value);
    batch->Put(string(key_cstr), value);
    if (++count % 1000 == 0) {
      db->Write(leveldb::WriteOptions(), batch);
      LOG(ERROR) << "Processed " << count << " files.";
      delete batch;
      batch = new leveldb::WriteBatch();
    }
  }
  // write the last batch
  if (count % 1000 != 0) {
    db->Write(leveldb::WriteOptions(), batch);
    LOG(ERROR) << "Processed " << count << " files.";
  }

  delete batch;
  delete db;
  return 0;
}
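One note on the listing above: std::random_shuffle, used to shuffle lines, was deprecated in C++14 and removed in C++17. On newer toolchains an equivalent shuffle, assuming a std::mt19937 seeded from std::random_device, is:

#include <algorithm>
#include <random>

// Drop-in replacement for the std::random_shuffle call above.
std::mt19937 rng(std::random_device{}());
std::shuffle(lines.begin(), lines.end(), rng);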
Code example #21
File: extract_features.cpp  Project: koufeifei/caffe
int feature_extraction_pipeline(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);
  const int num_required_args = 7;
  if (argc < num_required_args) {
    LOG(ERROR)<<
    "This program takes in a trained network and an input data layer, and then"
    " extract features of the input data produced by the net.\n"
    "Usage: extract_features  pretrained_net_param"
    "  feature_extraction_proto_file  extract_feature_blob_name1[,name2,...]"
    "  save_feature_dataset_name1[,name2,...]  num_mini_batches  db_type"
    "  [CPU/GPU] [DEVICE_ID=0]\n"
    "Note: you can extract multiple features in one pass by specifying"
    " multiple feature blob names and dataset names seperated by ','."
    " The names cannot contain white space characters and the number of blobs"
    " and datasets must be equal.";
    return 1;
  }
  int arg_pos = num_required_args;
  if (argc > arg_pos && strcmp(argv[arg_pos], "GPU") == 0) {
    LOG(ERROR)<< "Using GPU";
    int device_id = 0;
    if (argc > arg_pos + 1) {
      device_id = atoi(argv[arg_pos + 1]);
      CHECK_GE(device_id, 0);
    }
    LOG(ERROR) << "Using Device_id=" << device_id;
    Caffe::SetDevice(device_id);
    Caffe::set_mode(Caffe::GPU);
  } else {
    LOG(ERROR) << "Using CPU";
    Caffe::set_mode(Caffe::CPU);
  }

  arg_pos = 0;  // the name of the executable
  std::string pretrained_binary_proto(argv[++arg_pos]);

  // Expected prototxt contains at least one data layer such as
  //  the layer data_layer_name and one feature blob such as the
  //  fc7 top blob to extract features.
  /*
   layers {
     name: "data_layer_name"
     type: DATA
     data_param {
       source: "/path/to/your/images/to/extract/feature/images_leveldb"
       mean_file: "/path/to/your/image_mean.binaryproto"
       batch_size: 128
       crop_size: 227
       mirror: false
     }
     top: "data_blob_name"
     top: "label_blob_name"
   }
   layers {
     name: "drop7"
     type: DROPOUT
     dropout_param {
       dropout_ratio: 0.5
     }
     bottom: "fc7"
     top: "fc7"
   }
   */
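  // (Note: newer Caffe releases replace the V1 "layers { type: DATA }"
  // syntax shown above with "layer { type: "Data" }".)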
  std::string feature_extraction_proto(argv[++arg_pos]);
  shared_ptr<Net<Dtype> > feature_extraction_net(
      new Net<Dtype>(feature_extraction_proto, caffe::TEST));
  feature_extraction_net->CopyTrainedLayersFrom(pretrained_binary_proto);

  std::string extract_feature_blob_names(argv[++arg_pos]);
  std::vector<std::string> blob_names;
  boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(","));

  std::string save_feature_dataset_names(argv[++arg_pos]);
  std::vector<std::string> dataset_names;
  boost::split(dataset_names, save_feature_dataset_names,
               boost::is_any_of(","));
  CHECK_EQ(blob_names.size(), dataset_names.size()) <<
      " the number of blob names and dataset names must be equal";
  size_t num_features = blob_names.size();

  for (size_t i = 0; i < num_features; i++) {
    CHECK(feature_extraction_net->has_blob(blob_names[i]))
        << "Unknown feature blob name " << blob_names[i]
        << " in the network " << feature_extraction_proto;
  }

  int num_mini_batches = atoi(argv[++arg_pos]);

  std::vector<shared_ptr<db::DB> > feature_dbs;
  std::vector<shared_ptr<db::Transaction> > txns;
  const char* db_type = argv[++arg_pos];
  for (size_t i = 0; i < num_features; ++i) {
    LOG(INFO)<< "Opening dataset " << dataset_names[i];
    shared_ptr<db::DB> db(db::GetDB(db_type));
    db->Open(dataset_names.at(i), db::NEW);
    feature_dbs.push_back(db);
    shared_ptr<db::Transaction> txn(db->NewTransaction());
    txns.push_back(txn);
  }

  LOG(ERROR)<< "Extacting Features";

  Datum datum;
  const int kMaxKeyStrLength = 100;
  char key_str[kMaxKeyStrLength];
  std::vector<Blob<float>*> input_vec;
  std::vector<int> image_indices(num_features, 0);
  for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
    feature_extraction_net->Forward(input_vec);
    for (int i = 0; i < num_features; ++i) {
      const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net
          ->blob_by_name(blob_names[i]);
      int batch_size = feature_blob->num();
      int dim_features = feature_blob->count() / batch_size;
      const Dtype* feature_blob_data;
      for (int n = 0; n < batch_size; ++n) {
        datum.set_height(feature_blob->height());
        datum.set_width(feature_blob->width());
        datum.set_channels(feature_blob->channels());
        datum.clear_data();
        datum.clear_float_data();
        feature_blob_data = feature_blob->cpu_data() +
            feature_blob->offset(n);
        for (int d = 0; d < dim_features; ++d) {
          datum.add_float_data(feature_blob_data[d]);
        }
        // int length = snprintf(key_str, kMaxKeyStrLength, "%08d",
        int length = snprintf(key_str, kMaxKeyStrLength, "%010d",
            image_indices[i]);
        string out;
        CHECK(datum.SerializeToString(&out));
        txns.at(i)->Put(std::string(key_str, length), out);
        ++image_indices[i];
        if (image_indices[i] % 1000 == 0) {
          txns.at(i)->Commit();
          txns.at(i).reset(feature_dbs.at(i)->NewTransaction());
          LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
              " query images for feature blob " << blob_names[i];
        }
      }  // for (int n = 0; n < batch_size; ++n)
    }  // for (int i = 0; i < num_features; ++i)
  }  // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
  // write the last batch
  for (int i = 0; i < num_features; ++i) {
    if (image_indices[i] % 1000 != 0) {
      txns.at(i)->Commit();
    }
    LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
        " query images for feature blob " << blob_names[i];
    feature_dbs.at(i)->Close();
  }

  LOG(ERROR)<< "Successfully extracted the features!";
  return 0;
}
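Unlike the earlier extract_features variant above, which flattens each feature to height=dim_features and width=1, this one preserves the blob shape in the Datum (channels/height/width). A reader can flatten either layout the same way; a minimal sketch, where the helper name datum_to_vector is hypothetical:

#include <vector>
#include "caffe/proto/caffe.pb.h"

// Flatten one stored feature Datum back into a vector. Works for both
// layouts above, since float_data has channels * height * width entries.
std::vector<float> datum_to_vector(const caffe::Datum& datum) {
  const int dim = datum.channels() * datum.height() * datum.width();
  std::vector<float> feature(dim);
  for (int d = 0; d < dim; ++d) {
    feature[d] = datum.float_data(d);
  }
  return feature;
}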
Code example #22
int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);

#ifndef GFLAGS_GFLAGS_H_
  namespace gflags = google;
#endif

  gflags::SetUsageMessage("Convert a set of images to the leveldb/lmdb\n"
        "format used as input for Caffe.\n"
        "Usage:\n"
        "    convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME\n"
        "The ImageNet dataset for the training demo is at\n"
        "    http://www.image-net.org/download-images\n");
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  if (argc != 4) {
    gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/convert_imageset");
    return 1;
  }


  bool is_color = !FLAGS_gray;
  std::ifstream infile(argv[2]);
  std::vector<std::pair<string, int> > lines;
  string filename;
  int label;
  while (infile >> filename >> label) {
    lines.push_back(std::make_pair(filename, label));
  }
  if (FLAGS_shuffle) {
    // randomly shuffle data
    LOG(INFO) << "Shuffling data";
    shuffle(lines.begin(), lines.end());
  }
  LOG(INFO) << "A total of " << lines.size() << " images.";

  const string& db_backend = FLAGS_backend;
  const char* db_path = argv[3];

  int resize_height = std::max<int>(0, FLAGS_resize_height);
  int resize_width = std::max<int>(0, FLAGS_resize_width);

  // Open new db
  // lmdb
  MDB_env *mdb_env;
  MDB_dbi mdb_dbi;
  MDB_val mdb_key, mdb_data;
  MDB_txn *mdb_txn;
  // leveldb
  leveldb::DB* db;
  leveldb::Options options;
  options.error_if_exists = true;
  options.create_if_missing = true;
  options.write_buffer_size = 268435456;
  leveldb::WriteBatch* batch = NULL;

  // Open db
  if (db_backend == "leveldb") {  // leveldb
    LOG(INFO) << "Opening leveldb " << db_path;
    leveldb::Status status = leveldb::DB::Open(
        options, db_path, &db);
    CHECK(status.ok()) << "Failed to open leveldb " << db_path
        << ". Is it already existing?";
    batch = new leveldb::WriteBatch();
  } else if (db_backend == "lmdb") {  // lmdb
    LOG(INFO) << "Opening lmdb " << db_path;
    CHECK_EQ(mkdir(db_path, 0744), 0)
        << "mkdir " << db_path << " failed";
    CHECK_EQ(mdb_env_create(&mdb_env), MDB_SUCCESS) << "mdb_env_create failed";
    CHECK_EQ(mdb_env_set_mapsize(mdb_env, 1099511627776), MDB_SUCCESS)  // 1TB
        << "mdb_env_set_mapsize failed";
    CHECK_EQ(mdb_env_open(mdb_env, db_path, 0, 0664), MDB_SUCCESS)
        << "mdb_env_open failed";
    CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS)
        << "mdb_txn_begin failed";
    CHECK_EQ(mdb_open(mdb_txn, NULL, 0, &mdb_dbi), MDB_SUCCESS)
        << "mdb_open failed. Does the lmdb already exist? ";
  } else {
    LOG(FATAL) << "Unknown db backend " << db_backend;
  }

  // Storing to db
  string root_folder(argv[1]);
  Datum datum;
  int count = 0;
  const int kMaxKeyLength = 256;
  char key_cstr[kMaxKeyLength];
  int data_size;
  bool data_size_initialized = false;

  for (int line_id = 0; line_id < lines.size(); ++line_id) {
    if (!ReadImageToDatum(root_folder + lines[line_id].first,
        lines[line_id].second, resize_height, resize_width, is_color, &datum)) {
      continue;
    }
    if (!data_size_initialized) {
      data_size = datum.channels() * datum.height() * datum.width();
      data_size_initialized = true;
    } else {
      const string& data = datum.data();
      CHECK_EQ(data.size(), data_size) << "Incorrect data field size "
          << data.size();
    }
    // sequential
    snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id,
        lines[line_id].first.c_str());
    string value;
    datum.SerializeToString(&value);
    string keystr(key_cstr);

    // Put in db
    if (db_backend == "leveldb") {  // leveldb
      batch->Put(keystr, value);
    } else if (db_backend == "lmdb") {  // lmdb
      mdb_data.mv_size = value.size();
      mdb_data.mv_data = reinterpret_cast<void*>(&value[0]);
      mdb_key.mv_size = keystr.size();
      mdb_key.mv_data = reinterpret_cast<void*>(&keystr[0]);
      CHECK_EQ(mdb_put(mdb_txn, mdb_dbi, &mdb_key, &mdb_data, 0), MDB_SUCCESS)
          << "mdb_put failed";
    } else {
      LOG(FATAL) << "Unknown db backend " << db_backend;
    }

    if (++count % 1000 == 0) {
      // Commit txn
      if (db_backend == "leveldb") {  // leveldb
        db->Write(leveldb::WriteOptions(), batch);
        delete batch;
        batch = new leveldb::WriteBatch();
      } else if (db_backend == "lmdb") {  // lmdb
        CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS)
            << "mdb_txn_commit failed";
        CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS)
            << "mdb_txn_begin failed";
      } else {
        LOG(FATAL) << "Unknown db backend " << db_backend;
      }
      LOG(ERROR) << "Processed " << count << " files.";
    }
  }
  // write the last batch
  if (count % 1000 != 0) {
    if (db_backend == "leveldb") {  // leveldb
      db->Write(leveldb::WriteOptions(), batch);
    } else if (db_backend == "lmdb") {  // lmdb
      CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS) << "mdb_txn_commit failed";
    } else {
      LOG(FATAL) << "Unknown db backend " << db_backend;
    }
    LOG(ERROR) << "Processed " << count << " files.";
  } else if (db_backend == "lmdb") {
    // The loop opened a fresh, empty transaction right after the last commit.
    mdb_txn_abort(mdb_txn);
  }
  // Release db handles in both cases, not only when a partial batch remained.
  if (db_backend == "leveldb") {
    delete batch;
    delete db;
  } else if (db_backend == "lmdb") {
    mdb_close(mdb_env, mdb_dbi);
    mdb_env_close(mdb_env);
  }
  return 0;
}