Пример #1
0
// Trains on a whole collection of ranking pairs.
//
// The samples are first checked with is_ranking_problem(); a failed check is
// reported through pyassert with the message "Invalid inputs" (pyassert's
// exact failure behavior is defined outside this snippet).
//
// trainer: object whose train() member performs the actual training.
// samples: the ranking pairs to train on.
// Returns whatever trainer.train(samples) produces.
typename trainer_type::trained_function_type train2 (
    const trainer_type& trainer,
    const std::vector<ranking_pair<typename trainer_type::sample_type> >& samples
)
{
    const bool inputs_ok = is_ranking_problem(samples);
    pyassert(inputs_ok, "Invalid inputs");

    return trainer.train(samples);
}
Пример #2
0
typename trainer_type::trained_function_type train1 (
    const trainer_type& trainer,
    const ranking_pair<typename trainer_type::sample_type>& sample
)
{
    typedef ranking_pair<typename trainer_type::sample_type> st;
    pyassert(is_ranking_problem(std::vector<st>(1, sample)), "Invalid inputs");
    return trainer.train(sample);
}
// Trains on labeled samples.
//
// The (samples, labels) pair is first checked with
// is_binary_classification_problem(); a failed check is reported through
// pyassert with the message "Invalid inputs".
//
// trainer: object whose train() member performs the actual training.
// samples: the training samples.
// labels:  one label per sample (presumably +1/-1 -- confirm against
//          is_binary_classification_problem's contract).
// Returns whatever trainer.train(samples, labels) produces.
typename trainer_type::trained_function_type train (
    const trainer_type& trainer,
    const std::vector<typename trainer_type::sample_type>& samples,
    const std::vector<double>& labels
)
{
    const bool inputs_ok = is_binary_classification_problem(samples, labels);
    pyassert(inputs_ok, "Invalid inputs");

    return trainer.train(samples, labels);
}
Пример #4
0
    // Runs one training batch and folds its statistics into `context`.
    //
    // In order, this:
    //   1. increments the `batches` counter;
    //   2. wraps both iterator ranges with make_batch() and passes them, with
    //      `context`, to trainer->train_batch();
    //   3. adds context.batch_error and context.batch_sparsity to the running
    //      context.reconstruction_error and context.sparsity totals;
    //   4. if EnableWatcher && layer_traits<rbm_t>::free_energy() holds, adds the
    //      free energy of every element of the input batch to context.free_energy;
    //   5. if the watcher is enabled and the layer is verbose, calls
    //      watcher.batch_end().
    //
    // `trainer` is accessed with ->, so trainer_type must be pointer-like.
    void train_batch(IIT input_first, IIT input_last, EIT expected_first, EIT expected_last, trainer_type& trainer, rbm_training_context& context, rbm_t& rbm) {
        ++batches;

        auto inputs  = make_batch(input_first, input_last);
        auto targets = make_batch(expected_first, expected_last);
        trainer->train_batch(inputs, targets, context);

        // Fold this batch's figures into the running totals.
        context.sparsity += context.batch_sparsity;
        context.reconstruction_error += context.batch_error;

        // cpp::static_if is expected to run the lambda only when the condition is
        // true; `defer` is presumably an identity wrapper that delays lookup of
        // free_energy() until then -- confirm against cpp::static_if.
        cpp::static_if<EnableWatcher && layer_traits<rbm_t>::free_energy()>([&](auto defer) {
            for (auto& sample : inputs) {
                context.free_energy += defer(rbm).free_energy(sample);
            }
        });

        const bool notify_watcher = EnableWatcher && layer_traits<rbm_t>::is_verbose();
        if (notify_watcher) {
            watcher.batch_end(rbm, context, batches, total_batches);
        }
    }
Пример #5
0
// Returns the trainer's current C value, as reported by trainer.get_c().
double get_c (const trainer_type& trainer)
{
    const double current_c = trainer.get_c();
    return current_c;
}
Пример #6
0
// Sets the trainer's C value via trainer.set_c().
//
// C must be strictly greater than zero; otherwise the violation is reported
// through pyassert with the message "C must be > 0" (pyassert's exact failure
// behavior is defined outside this snippet).
void set_c ( trainer_type& trainer, double C)
{
    const bool c_is_positive = C > 0;
    pyassert(c_is_positive, "C must be > 0");

    trainer.set_c(C);
}
Пример #7
0
// Returns the trainer's current epsilon value, as reported by
// trainer.get_epsilon().
double get_epsilon ( const trainer_type& trainer)
{
    const double current_eps = trainer.get_epsilon();
    return current_eps;
}
Пример #8
0
// Sets the trainer's epsilon value via trainer.set_epsilon().
//
// eps must be strictly greater than zero; otherwise the violation is reported
// through pyassert with the message "epsilon must be > 0".
void set_epsilon ( trainer_type& trainer, double eps)
{
    const bool eps_is_positive = eps > 0;
    pyassert(eps_is_positive, "epsilon must be > 0");

    trainer.set_epsilon(eps);
}
// Returns the value reported by trainer.get_c_class2() (presumably the C value
// applied to the second class -- confirm against trainer_type).
double get_c_class2 ( const trainer_type& trainer)
{
    const double current_c = trainer.get_c_class2();
    return current_c;
}
// Returns the trainer's current cache size, as reported by
// trainer.get_cache_size().
long get_cache_size ( const trainer_type& trainer)
{
    const long current_size = trainer.get_cache_size();
    return current_size;
}
// Sets the trainer's cache size via trainer.set_cache_size().
//
// cache_size must be strictly greater than zero; otherwise the violation is
// reported through pyassert with the message "cache size must be > 0".
void set_cache_size ( trainer_type& trainer, long cache_size)
{
    const bool size_is_positive = cache_size > 0;
    pyassert(size_is_positive, "cache size must be > 0");

    trainer.set_cache_size(cache_size);
}