예제 #1
0
// Entry point for the Petuum K-means application.
//
// Sequence:
//   1. Parse command-line flags (gflags) and initialise logging (glog).
//   2. Load the training data into a single KMeans instance.
//   3. Configure and initialise the Petuum table group.
//   4. Create the four tables the solver uses.
//   5. Run KMeans::Start on FLAGS_num_app_threads worker threads.
//   6. Shut the table group down after all workers have joined.
//
// Table configuration note: one ClientTableConfig object is reused for every
// table, so any field that is not reassigned carries over from the previously
// created table. Each table below therefore sets row_capacity and
// dense_row_oplog_capacity explicitly and keeps them equal. Without this, the
// count/update tables would inherit FLAGS_num_epochs + 1 as their dense oplog
// capacity from the objective table, while their rows hold FLAGS_num_centers
// entries.
int main(int argc, char *argv[]) {
	google::ParseCommandLineFlags(&argc, &argv, true);
	google::InitGoogleLogging(argv[0]);

	petuum::HighResolutionTimer data_loading_timer;
	LOG(INFO)<< "training file location: " << FLAGS_train_file;
	KMeans kmeans;
	kmeans.ReadData();
	LOG(INFO)<< "Data Loading Complete. Loaded "  << kmeans.GetTrainingDataSize() << " in "  <<data_loading_timer.elapsed();

	petuum::TableGroupConfig table_group_config;
	table_group_config.num_comm_channels_per_client =
			FLAGS_num_comm_channels_per_client;
	table_group_config.num_total_clients = FLAGS_num_clients;

	// Centers, objective values, center counts and center updates; must match
	// the number of CreateTable() calls below.
	table_group_config.num_tables = 4;
	// + 1 for the main() thread.
	table_group_config.num_local_app_threads = FLAGS_num_app_threads + 1;
	table_group_config.client_id = FLAGS_client_id;
	table_group_config.stats_path = FLAGS_stats_path;

	petuum::GetHostInfos(FLAGS_hostfile, &table_group_config.host_map);
	if (FLAGS_consistency_model == "SSP") {
		table_group_config.consistency_model = petuum::SSP;
	} else if (FLAGS_consistency_model == "SSPPush") {
		table_group_config.consistency_model = petuum::SSPPush;
	} else if (FLAGS_consistency_model == "LocalOOC") {
		table_group_config.consistency_model = petuum::LocalOOC;
	} else {
		LOG(FATAL)<< "Unknown consistency model: " << FLAGS_consistency_model;
	}

	petuum::PSTableGroup::RegisterRow<petuum::DenseRow<float> >(
			kDenseRowFloatTypeID);
	petuum::PSTableGroup::RegisterRow<petuum::DenseRow<int> >(
			kDenseRowIntTypeID);

	petuum::PSTableGroup::Init(table_group_config, false);

	// Centers table: float rows of FLAGS_dimensionality entries; the process
	// cache is sized to hold FLAGS_num_centers rows.
	petuum::ClientTableConfig table_config;
	table_config.table_info.row_type = kDenseRowFloatTypeID;
	table_config.table_info.table_staleness = FLAGS_staleness;
	table_config.table_info.row_capacity = FLAGS_dimensionality;
	table_config.table_info.row_oplog_type = FLAGS_row_oplog_type;
	table_config.table_info.oplog_dense_serialized =
			FLAGS_oplog_dense_serialized;
	table_config.table_info.dense_row_oplog_capacity =
			table_config.table_info.row_capacity;
	table_config.process_cache_capacity = FLAGS_num_centers;
	table_config.oplog_capacity = table_config.process_cache_capacity;
	petuum::PSTableGroup::CreateTable(FLAGS_centres_table_id, table_config);

	LOG(INFO) << "created centers table";

	// Objective function table: float rows of FLAGS_num_epochs + 1 entries
	// (presumably one per epoch plus the initial value; TODO confirm against
	// KMeans), read with zero staleness. Row type, oplog settings and cache
	// capacity are inherited from the centers table above.
	table_config.table_info.row_capacity = FLAGS_num_epochs + 1;
	table_config.table_info.dense_row_oplog_capacity =
			table_config.table_info.row_capacity;
	table_config.table_info.table_staleness = 0;
	petuum::PSTableGroup::CreateTable(FLAGS_objective_function_value_tableId, table_config);
	LOG(INFO) << "created objective values table";

	// Center count table: int rows of FLAGS_num_centers entries.
	table_config.table_info.row_type = kDenseRowIntTypeID;
	table_config.table_info.table_staleness = FLAGS_count_table_staleness;
	table_config.table_info.row_capacity = FLAGS_num_centers;
	table_config.table_info.dense_row_oplog_capacity =
			table_config.table_info.row_capacity;
	table_config.process_cache_capacity = 1000;
	table_config.oplog_capacity = table_config.process_cache_capacity;
	petuum::PSTableGroup::CreateTable(FLAGS_center_count_tableId, table_config);

	// Table to hold the local deltas; uses the same config as the count table.
	// NOTE(review): int rows of FLAGS_num_centers entries; confirm this matches
	// how KMeans writes its deltas.
	petuum::PSTableGroup::CreateTable(FLAGS_update_centres_table_id, table_config);

	LOG(INFO) << "Completed creating tables" ;

	petuum::PSTableGroup::CreateTableDone();

	// All workers share the same KMeans instance; assumes KMeans::Start is
	// written to be run concurrently (TODO confirm).
	std::vector<std::thread> threads(FLAGS_num_app_threads);
	for (auto& thr : threads) {
		thr = std::thread(&KMeans::Start, &kmeans);
	}
	for (auto& thr : threads) {
		thr.join();
	}

	petuum::PSTableGroup::ShutDown();
	LOG(INFO)<< "Kmeans finished and shut down!";
	return 0;
}