typename trainer_type::trained_function_type train2 ( const trainer_type& trainer, const std::vector<ranking_pair<typename trainer_type::sample_type> >& samples ) { pyassert(is_ranking_problem(samples), "Invalid inputs"); return trainer.train(samples); }
typename trainer_type::trained_function_type train1 ( const trainer_type& trainer, const ranking_pair<typename trainer_type::sample_type>& sample ) { typedef ranking_pair<typename trainer_type::sample_type> st; pyassert(is_ranking_problem(std::vector<st>(1, sample)), "Invalid inputs"); return trainer.train(sample); }
// Train a binary classifier from parallel sample/label vectors.
// Inputs are rejected up front when they do not form a valid binary
// classification problem, so the trainer only ever sees well-formed data.
typename trainer_type::trained_function_type train (
    const trainer_type& trainer,
    const std::vector<typename trainer_type::sample_type>& samples,
    const std::vector<double>& labels
)
{
    const bool inputs_ok = is_binary_classification_problem(samples, labels);
    pyassert(inputs_ok, "Invalid inputs");
    return trainer.train(samples, labels);
}
//! Train one mini-batch on the given RBM and fold the per-batch statistics
//! into the running training context.
//!
//! NOTE(review): `batches`, `total_batches` and `watcher` are not parameters
//! or locals, so they are presumably members of the enclosing class — confirm
//! against the full class definition.
//! NOTE(review): `trainer` is declared `trainer_type&` yet dereferenced with
//! `->`, so `trainer_type` is presumably a (smart) pointer alias — confirm.
void train_batch(IIT input_first, IIT input_last, EIT expected_first, EIT expected_last, trainer_type& trainer, rbm_training_context& context, rbm_t& rbm) {
    // Running count of batches processed so far (used for progress reporting).
    ++batches;

    // Materialize the two iterator ranges into concrete batch containers.
    auto input_batch = make_batch(input_first, input_last);
    auto expected_batch = make_batch(expected_first, expected_last);

    // Delegate the actual weight update to the concrete trainer; it is
    // expected to fill context.batch_error / context.batch_sparsity
    // (they are read immediately below).
    trainer->train_batch(input_batch, expected_batch, context);

    // Accumulate this batch's statistics into the running totals.
    context.reconstruction_error += context.batch_error;
    context.sparsity += context.batch_sparsity;

    // Compile-time guard: the free-energy accumulation is only instantiated
    // when a watcher is enabled AND the layer supports free energy. The
    // lambda parameter `f` is the usual static_if identity trick that keeps
    // the body a dependent expression so it is not compiled when the
    // condition is false.
    cpp::static_if<EnableWatcher && layer_traits<rbm_t>::free_energy()>([&](auto f) { for (auto& v : input_batch) { context.free_energy += f(rbm).free_energy(v); } });

    // Both operands are compile-time constants, so this branch folds away
    // entirely when watching/verbosity is disabled.
    if (EnableWatcher && layer_traits<rbm_t>::is_verbose()) { watcher.batch_end(rbm, context, batches, total_batches); }
}
// Read back the trainer's C parameter.
double get_c (const trainer_type& trainer)
{
    const double c_value = trainer.get_c();
    return c_value;
}
// Set the trainer's C parameter, rejecting non-positive values with a
// Python-visible error before touching the trainer.
void set_c ( trainer_type& trainer, double C)
{
    const bool is_positive = (C > 0);
    pyassert(is_positive, "C must be > 0");
    trainer.set_c(C);
}
// Read back the trainer's epsilon (stopping tolerance) parameter.
double get_epsilon ( const trainer_type& trainer)
{
    const double eps_value = trainer.get_epsilon();
    return eps_value;
}
// Set the trainer's epsilon parameter, rejecting non-positive values with a
// Python-visible error before touching the trainer.
void set_epsilon ( trainer_type& trainer, double eps)
{
    const bool is_positive = (eps > 0);
    pyassert(is_positive, "epsilon must be > 0");
    trainer.set_epsilon(eps);
}
// Read back the trainer's class-2 C parameter.
double get_c_class2 ( const trainer_type& trainer)
{
    const double c2_value = trainer.get_c_class2();
    return c2_value;
}
// Read back the trainer's kernel cache size.
long get_cache_size ( const trainer_type& trainer)
{
    const long cache = trainer.get_cache_size();
    return cache;
}
// Set the trainer's cache size, rejecting non-positive values with a
// Python-visible error before touching the trainer.
void set_cache_size ( trainer_type& trainer, long cache_size)
{
    const bool is_positive = (cache_size > 0);
    pyassert(is_positive, "cache size must be > 0");
    trainer.set_cache_size(cache_size);
}