// Builds a training set from `filename`, trains `model` with an L1
// regularizer, and writes the trained model to the file "model".
//
// Each input line is passed to read_line(); every Token it returns becomes one
// ME_Sample (via sample()) added to `model`. At most kMaxTrainingLines input
// lines are consumed.
//
// Terminates the process with exit(1) if `filename` cannot be opened.
void train(ME_Model & model, const string & filename)
{
  // Upper bound on input lines read, so this training run stays short.
  const int kMaxTrainingLines = 10000;

  ifstream ifile(filename.c_str());
  if (!ifile) {
    cerr << "error: cannot open " << filename << endl;
    exit(1);
  }

  string line;
  int n = 0;
  while (getline(ifile, line)) {
    vector<Token> vs = read_line(line);
    for (int j = 0; j < static_cast<int>(vs.size()); j++) {
      ME_Sample mes = sample(vs, j);
      model.add_training_sample(mes);
    }
    // Count the line first, then compare, so exactly kMaxTrainingLines lines
    // are read. The previous `n++ > 10000` test let 10002 lines through.
    if (++n >= kMaxTrainingLines) break;
  }

  model.use_l1_regularizer(1.0);
  // Alternatives: model.use_l2_regularizer(1.0);  or  model.use_SGD();
  model.set_heldout(100);  // NOTE(review): presumably holds out 100 samples for evaluation — confirm in ME_Model
  model.train();
  model.save_to_file("model");
}
// Train model RcppExport SEXP train_model(double l1=0, double l2=0, bool sgd=FALSE, int sgd_iter=30, double sgd_eta0=1, double sgd_alpha=0.85, int heldout=0) { Rprintf("Training the new model...\n"); if (heldout > 0) model.set_heldout(heldout); if (l1 > 0) model.use_l1_regularizer(l1); else if (l2 > 0) model.use_l2_regularizer(l2); else if (sgd) model.use_SGD(); model.train(); string model_data = model.save_to_string(); vector< vector<string> > weights = export_weights(); List rs = List::create(model_data,weights[0],weights[1],weights[2]); return rs; }