int main(int argc, char **argv) { srand ( time(NULL) ); try { CMDLine cmdline(argc, argv); std::cout << "----------------------------------------------------------------------------" << std::endl; std::cout << "libFM" << std::endl; std::cout << " Version: 1.40" << std::endl; std::cout << " Author: Steffen Rendle, [email protected]" << std::endl; std::cout << " WWW: http://www.libfm.org/" << std::endl; std::cout << " License: Free for academic use. See license.txt." << std::endl; std::cout << "----------------------------------------------------------------------------" << std::endl; const std::string param_task = cmdline.registerParameter("task", "r=regression, c=binary classification [MANDATORY]"); const std::string param_meta_file = cmdline.registerParameter("meta", "filename for meta information about data set"); const std::string param_train_file = cmdline.registerParameter("train", "filename for training data [MANDATORY]"); const std::string param_test_file = cmdline.registerParameter("test", "filename for test data [MANDATORY]"); const std::string param_val_file = cmdline.registerParameter("validation", "filename for validation data (only for SGDA)"); const std::string param_out = cmdline.registerParameter("out", "filename for output"); const std::string param_dim = cmdline.registerParameter("dim", "'k0,k1,k2': k0=use bias, k1=use 1-way interactions, k2=dim of 2-way interactions; default=1,1,8"); const std::string param_regular = cmdline.registerParameter("regular", "'r0,r1,r2' for SGD and ALS: r0=bias regularization, r1=1-way regularization, r2=2-way regularization"); const std::string param_init_stdev = cmdline.registerParameter("init_stdev", "stdev for initialization of 2-way factors; default=0.1"); const std::string param_num_iter = cmdline.registerParameter("iter", "number of iterations; default=100"); const std::string param_learn_rate = cmdline.registerParameter("learn_rate", "learn_rate for SGD; default=0.1"); const std::string param_method = cmdline.registerParameter("method", "learning method (SGD, SGDA, ALS, MCMC); default=MCMC"); const std::string param_verbosity = cmdline.registerParameter("verbosity", "how much infos to print; default=0"); const std::string param_r_log = cmdline.registerParameter("rlog", "write measurements within iterations to a file; default=''"); const std::string param_help = cmdline.registerParameter("help", "this screen"); const std::string param_relation = cmdline.registerParameter("relation", "BS: filenames for the relations, default=''"); const std::string param_cache_size = cmdline.registerParameter("cache_size", "cache size for data storage (only applicable if data is in binary format), default=infty"); const std::string param_do_sampling = "do_sampling"; const std::string param_do_multilevel = "do_multilevel"; const std::string param_num_eval_cases = "num_eval_cases"; if (cmdline.hasParameter(param_help) || (argc == 1)) { cmdline.print_help(); return 0; } cmdline.checkParameters(); if (! cmdline.hasParameter(param_method)) { cmdline.setValue(param_method, "mcmc"); } if (! cmdline.hasParameter(param_init_stdev)) { cmdline.setValue(param_init_stdev, "0.1"); } if (! cmdline.hasParameter(param_dim)) { cmdline.setValue(param_dim, "1,1,8"); } if (! cmdline.getValue(param_method).compare("als")) { // als is an mcmc without sampling and hyperparameter inference cmdline.setValue(param_method, "mcmc"); if (! cmdline.hasParameter(param_do_sampling)) { cmdline.setValue(param_do_sampling, "0"); } if (! cmdline.hasParameter(param_do_multilevel)) { cmdline.setValue(param_do_multilevel, "0"); } } // (1) Load the data std::cout << "Loading train...\t" << std::endl; Data train( cmdline.getValue(param_cache_size, 0), ! (!cmdline.getValue(param_method).compare("mcmc")), // no original data for mcmc ! (!cmdline.getValue(param_method).compare("sgd") || !cmdline.getValue(param_method).compare("sgda")) // no transpose data for sgd, sgda ); if (cmdline.hasParameter(param_train_file)) { train.loadFromFile(cmdline.getValue(param_train_file)); } else { std::cout << "No train data file provided. Waited for stdin..." << std::endl << "WARNING! DO NOT USE STDIN DIRECTLY! THIS MODE PROVIDED ONLY FOR JMLL!" << std::endl; train.loadFromStdin(&std::cin); } if (cmdline.getValue(param_verbosity, 0) > 0) { train.debug(); } Data* test = NULL; if (cmdline.hasParameter(param_test_file)) { std::cout << "Loading test... \t" << std::endl; test = new Data( cmdline.getValue(param_cache_size, 0), ! (!cmdline.getValue(param_method).compare("mcmc")), // no original data for mcmc ! (!cmdline.getValue(param_method).compare("sgd") || !cmdline.getValue(param_method).compare("sgda")) // no transpose data for sgd, sgda ); test->loadFromFile(cmdline.getValue(param_test_file)); if (cmdline.getValue(param_verbosity, 0) > 0) test->debug(); } Data* validation = NULL; if (cmdline.hasParameter(param_val_file)) { if (cmdline.getValue(param_method).compare("sgda")) { std::cout << "WARNING: Validation data is only used for SGDA. The data is ignored." << std::endl; } else { std::cout << "Loading validation set...\t" << std::endl; validation = new Data( cmdline.getValue(param_cache_size, 0), ! (!cmdline.getValue(param_method).compare("mcmc")), // no original data for mcmc ! (!cmdline.getValue(param_method).compare("sgd") || !cmdline.getValue(param_method).compare("sgda")) // no transpose data for sgd, sgda ); validation->loadFromFile(cmdline.getValue(param_val_file)); if (cmdline.getValue(param_verbosity, 0) > 0) { validation->debug(); } } } DVector<RelationData*> relation; // (1.2) Load relational data { vector<std::string> rel = cmdline.getStrValues(param_relation); std::cout << "#relations: " << rel.size() << std::endl; relation.setSize(rel.size()); train.relation.setSize(rel.size()); if (test != NULL) test->relation.setSize(rel.size()); for (uint i = 0; i < rel.size(); i++) { relation(i) = new RelationData( cmdline.getValue(param_cache_size, 0), ! (!cmdline.getValue(param_method).compare("mcmc")), // no original data for mcmc ! (!cmdline.getValue(param_method).compare("sgd") || !cmdline.getValue(param_method).compare("sgda")) // no transpose data for sgd, sgda ); relation(i)->load(rel[i]); train.relation(i).data = relation(i); train.relation(i).load(rel[i] + ".train", train.num_cases); if (test != NULL) { test->relation(i).data = relation(i); test->relation(i).load(rel[i] + ".test", test->num_cases); } } } // (1.3) Load meta data std::cout << "Loading meta data...\t" << std::endl; // (main table) uint num_all_attribute = train.num_feature; if (test != NULL) { num_all_attribute = std::max(train.num_feature, test->num_feature); } if (validation != NULL) { num_all_attribute = std::max(num_all_attribute, (uint) validation->num_feature); } DataMetaInfo meta_main(num_all_attribute); if (cmdline.hasParameter(param_meta_file)) { meta_main.loadGroupsFromFile(cmdline.getValue(param_meta_file)); } // build the joined meta table for (uint r = 0; r < train.relation.dim; r++) { train.relation(r).data->attr_offset = num_all_attribute; num_all_attribute += train.relation(r).data->num_feature; } DataMetaInfo meta(num_all_attribute); { meta.num_attr_groups = meta_main.num_attr_groups; for (uint r = 0; r < relation.dim; r++) { meta.num_attr_groups += relation(r)->meta->num_attr_groups; } meta.num_attr_per_group.setSize(meta.num_attr_groups); meta.num_attr_per_group.init(0); for (uint i = 0; i < meta_main.attr_group.dim; i++) { meta.attr_group(i) = meta_main.attr_group(i); meta.num_attr_per_group(meta.attr_group(i))++; } uint attr_cntr = meta_main.attr_group.dim; uint attr_group_cntr = meta_main.num_attr_groups; for (uint r = 0; r < relation.dim; r++) { for (uint i = 0; i < relation(r)->meta->attr_group.dim; i++) { meta.attr_group(i+attr_cntr) = attr_group_cntr + relation(r)->meta->attr_group(i); meta.num_attr_per_group(attr_group_cntr + relation(r)->meta->attr_group(i))++; } attr_cntr += relation(r)->meta->attr_group.dim; attr_group_cntr += relation(r)->meta->num_attr_groups; } if (cmdline.getValue(param_verbosity, 0) > 0) { meta.debug(); } } meta.num_relations = train.relation.dim; // (2) Setup the factorization machine fm_model fm; { fm.num_attribute = num_all_attribute; fm.init_stdev = cmdline.getValue(param_init_stdev, 0.1); // set the number of dimensions in the factorization { vector<int> dim = cmdline.getIntValues(param_dim); assert(dim.size() == 3); fm.k0 = dim[0] != 0; fm.k1 = dim[1] != 0; fm.num_factor = dim[2]; } fm.init(); } // (3) Setup the learning method: fm_learn* fml; if (! cmdline.getValue(param_method).compare("sgd")) { fml = new fm_learn_sgd_element(); ((fm_learn_sgd*)fml)->num_iter = cmdline.getValue(param_num_iter, 100); } else if (! cmdline.getValue(param_method).compare("sgda")) { assert(validation != NULL); fml = new fm_learn_sgd_element_adapt_reg(); ((fm_learn_sgd*)fml)->num_iter = cmdline.getValue(param_num_iter, 100); ((fm_learn_sgd_element_adapt_reg*)fml)->validation = validation; } else if (! cmdline.getValue(param_method).compare("mcmc")) { fm.w.init_normal(fm.init_mean, fm.init_stdev); fml = new fm_learn_mcmc_simultaneous(); fml->validation = validation; ((fm_learn_mcmc*)fml)->num_iter = cmdline.getValue(param_num_iter, 100); ((fm_learn_mcmc*)fml)->num_eval_cases = cmdline.getValue(param_num_eval_cases, (test != NULL)? test->num_cases : 0); ((fm_learn_mcmc*)fml)->do_sample = cmdline.getValue(param_do_sampling, true); ((fm_learn_mcmc*)fml)->do_multilevel = cmdline.getValue(param_do_multilevel, true); } else { throw "unknown method"; } if (test != NULL) { fml->test = test; } fml->fm = &fm; fml->max_target = train.max_target; fml->min_target = train.min_target; fml->meta = &meta; if (! cmdline.getValue("task").compare("r") ) { fml->task = 0; } else if (! cmdline.getValue("task").compare("c") ) { fml->task = 1; for (uint i = 0; i < train.target.dim; i++) { if (train.target(i) <= 0.0) { train.target(i) = -1.0; } else { train.target(i) = 1.0; } } if (test != NULL) { for (uint i = 0; i < test->target.dim; i++) { if (test->target(i) <= 0.0) { test->target(i) = -1.0; } else { test->target(i) = 1.0; } } } if (validation != NULL) { for (uint i = 0; i < validation->target.dim; i++) { if (validation->target(i) <= 0.0) { validation->target(i) = -1.0; } else { validation->target(i) = 1.0; } } } } else { throw "unknown task"; } // (4) init the logging RLog* rlog = NULL; if (cmdline.hasParameter(param_r_log)) { ofstream* out_rlog = NULL; std::string r_log_str = cmdline.getValue(param_r_log); out_rlog = new ofstream(r_log_str.c_str()); if (! out_rlog->is_open()) { throw "Unable to open file " + r_log_str; } std::cout << "logging to " << r_log_str.c_str() << std::endl; rlog = new RLog(out_rlog); } fml->log = rlog; fml->init(); if (! cmdline.getValue(param_method).compare("mcmc")) { // set the regularization; for als and mcmc this can be individual per group { vector<double> reg = cmdline.getDblValues(param_regular); assert((reg.size() == 0) || (reg.size() == 1) || (reg.size() == 3) || (reg.size() == (1+meta.num_attr_groups*2))); if (reg.size() == 0) { fm.reg0 = 0.0; fm.regw = 0.0; fm.regv = 0.0; ((fm_learn_mcmc*)fml)->w_lambda.init(fm.regw); ((fm_learn_mcmc*)fml)->v_lambda.init(fm.regv); } else if (reg.size() == 1) { fm.reg0 = reg[0]; fm.regw = reg[0]; fm.regv = reg[0]; ((fm_learn_mcmc*)fml)->w_lambda.init(fm.regw); ((fm_learn_mcmc*)fml)->v_lambda.init(fm.regv); } else if (reg.size() == 3) { fm.reg0 = reg[0]; fm.regw = reg[1]; fm.regv = reg[2]; ((fm_learn_mcmc*)fml)->w_lambda.init(fm.regw); ((fm_learn_mcmc*)fml)->v_lambda.init(fm.regv); } else { fm.reg0 = reg[0]; fm.regw = 0.0; fm.regv = 0.0; int j = 1; for (uint g = 0; g < meta.num_attr_groups; g++) { ((fm_learn_mcmc*)fml)->w_lambda(g) = reg[j]; j++; } for (uint g = 0; g < meta.num_attr_groups; g++) { for (int f = 0; f < fm.num_factor; f++) { ((fm_learn_mcmc*)fml)->v_lambda(g,f) = reg[j]; } j++; } } } } else { // set the regularization; for standard SGD, groups are not supported { vector<double> reg = cmdline.getDblValues(param_regular); assert((reg.size() == 0) || (reg.size() == 1) || (reg.size() == 3)); if (reg.size() == 0) { fm.reg0 = 0.0; fm.regw = 0.0; fm.regv = 0.0; } else if (reg.size() == 1) { fm.reg0 = reg[0]; fm.regw = reg[0]; fm.regv = reg[0]; } else { fm.reg0 = reg[0]; fm.regw = reg[1]; fm.regv = reg[2]; } } } { fm_learn_sgd* fmlsgd= dynamic_cast<fm_learn_sgd*>(fml); if (fmlsgd) { // set the learning rates (individual per layer) { vector<double> lr = cmdline.getDblValues(param_learn_rate); assert((lr.size() == 1) || (lr.size() == 3)); if (lr.size() == 1) { fmlsgd->learn_rate = lr[0]; fmlsgd->learn_rates.init(lr[0]); } else { fmlsgd->learn_rate = 0; fmlsgd->learn_rates(0) = lr[0]; fmlsgd->learn_rates(1) = lr[1]; fmlsgd->learn_rates(2) = lr[2]; } } } } if (rlog != NULL) { rlog->init(); } if (cmdline.getValue(param_verbosity, 0) > 0) { fm.debug(); fml->debug(); } // () learn fml->learn(train); // () Prediction at the end (not for mcmc and als) if (cmdline.getValue(param_method).compare("mcmc")) { std::cout << "Final\t" << "Train=" << fml->evaluate(train); if (test != NULL) std::cout << "\tTest=" << fml->evaluate(*test); std::cout << std::endl; } if (!cmdline.hasParameter(param_train_file)) { fml->fm->out(&std::cout); } // () Save prediction if (cmdline.hasParameter(param_out)) { DVector<double> pred; pred.setSize(test->num_cases); fml->predict(*test, pred); pred.saveToFile(cmdline.getValue(param_out)); } } catch (std::string &e) { std::cerr << std::endl << "ERROR: " << e << std::endl; } catch (char const* &e) { std::cerr << std::endl << "ERROR: " << e << std::endl; } }