Example #1
Conf& Conf::def_choice(std::string name,
                       std::vector<std::string> choices,
                       std::string default_value) {
    assert2(in_vector(choices, default_value),
        MS() << default_value << " is not an option for " << name);
    assert2(choices.size() >= 2,
        MS() << "At least two choices are needed for " << name);
    auto c = make_shared<Choice>();
    c->choices = choices;
    c->default_value = default_value;
    c->value = default_value;


    items[name] = c;
    return *this;
}
Example #2
Conf& Conf::def_int(std::string name,
            int lower_bound,
            int upper_bound,
            int default_value) {
    assert2(lower_bound <= default_value && default_value <= upper_bound,
                MS() << "Default value for " << name << "not in range.");
    auto i = make_shared<Int>();
    i->lower_bound = lower_bound;
    i->upper_bound = upper_bound;
    i->default_value = default_value;
    i->value = default_value;

    items[name] = i;
    return *this;
}
Example #3
Conf& Conf::def_float(std::string name,
            double lower_bound,
            double upper_bound,
            double default_value) {
    assert2(lower_bound <= default_value && default_value <= upper_bound,
                MS() << "Default value for " << name << "not in range.");
    auto f = make_shared<Float>();
    f->lower_bound = lower_bound;
    f->upper_bound = upper_bound;
    f->default_value = default_value;
    f->value = default_value;

    items[name] = f;
    return *this;
}
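Since each of the def_* methods above returns *this, definitions can be chained on a single Conf object. A minimal usage sketch (the option names, bounds, and defaults below are illustrative, not taken from the source):

Conf conf;
conf.def_choice("activation", {"relu", "tanh", "sigmoid"}, "relu")
    .def_int("hidden_size", 1, 4096, 256)
    .def_float("learning_rate", 1e-6, 1.0, 0.01);
// A default outside the listed choices or bounds trips the assert2 checks above.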
Example #4
void training_loop(std::shared_ptr<Solver::AbstractSolver<REAL_t>> solver,
                   model_t& model,
                   std::function<vector<uint>(vector<uint>&)> pred_fun,
                   vector<numeric_example_t>& train,
                   vector<numeric_example_t>& validate) {
    auto& vocab = arithmetic::vocabulary;

    auto params = model.parameters();

    int epoch = 0;
    int difficulty_waiting = 0;
    auto end_symbol_idx = vocab.word2index[utils::end_symbol];

    int beam_width = FLAGS_beam_width;

    if (beam_width < 1)
        utils::exit_with_message(MS() << "Beam width must be strictly positive (got " << beam_width << ")");

    Throttled throttled_examples;
    Throttled throttled_validation;

    bool target_accuracy_reached = false;

    // train until the validation accuracy target is reached or the epoch budget (FLAGS_graduation_time) runs out
    while (!target_accuracy_reached && epoch++ < FLAGS_graduation_time) {

        auto indices = utils::random_arange(train.size());

        REAL_t minibatch_error = 0.0;

        // one minibatch
        for (auto indices_begin = indices.begin();
                indices_begin < indices.begin() + std::min((size_t)FLAGS_minibatch, train.size());
                indices_begin++) {
            // <training>
            auto& example = train[*indices_begin];

            auto error = model.error(example, beam_width);
            error.grad();
            graph::backward();
            minibatch_error += error.w(0);
            // </training>
            // <reporting>
            throttled_examples.maybe_run(seconds(10), [&]() {
                graph::NoBackprop nb;
                auto random_example_index = utils::randint(0, validate.size() -1);
                auto& expression = validate[random_example_index].first;
                auto predictions = model.predict(expression,
                                                 beam_width,
                                                 MAX_OUTPUT_LENGTH,
                                                 vocab.word2index.at(utils::end_symbol));

                auto expression_string = arithmetic::vocabulary.decode(&expression);
                if (expression_string.back() == utils::end_symbol)
                    expression_string.resize(expression_string.size() - 1);
                std::cout << utils::join(expression_string) << std::endl;


                vector<string> prediction_string;
                vector<double> prediction_probability;

                for (auto& prediction : predictions) {
                    if (validate[random_example_index].second == prediction.prediction) {
                        std::cout << utils::green;
                    }
                    prediction_probability.push_back(prediction.get_probability().w(0));
                    std::cout << "= (" << std::setprecision( 3 ) << prediction.get_probability().log().w(0) << ") ";
                    auto digits = vocab.decode(&prediction.prediction);
                    if (digits.back() == utils::end_symbol)
                        digits.pop_back();
                    auto joined_digits = utils::join(digits);
                    prediction_string.push_back(joined_digits);
                    std::cout << joined_digits << utils::reset_color << std::endl;
                }
                auto vgrid = make_shared<visualizable::GridLayout>();

                assert2(predictions[0].derivations.size() == predictions[0].nodes.size(),
                        "Szymon messed up.");
                for (int didx = 0;
                        didx < min((size_t)FLAGS_visualizer_trees, predictions[0].derivations.size());
                        ++didx) {
                    auto visualization = visualize_derivation(
                                             predictions[0].derivations[didx],
                                             vocab.decode(&expression)
                                         );
                    auto tree_prob = predictions[0].nodes[didx].log_probability.exp().w(0,0);
                    vgrid->add_in_column(0, make_shared<visualizable::Probability<double>>(tree_prob));
                    vgrid->add_in_column(0, visualization);
                }
                vgrid->add_in_column(1, make_shared<visualizable::Sentence<double>>(expression_string));
                vgrid->add_in_column(1, make_shared<visualizable::FiniteDistribution<double>>(
                                         prediction_probability,
                                         prediction_string
                                     ));

                if (visualizer)
                    visualizer->feed(vgrid->to_json());

            });
            double current_accuracy = -1;
            throttled_validation.maybe_run(seconds(30), [&]() {
                current_accuracy = arithmetic::average_recall(validate, pred_fun, FLAGS_j);
                std::cout << "epoch: " << epoch << ", accuracy = " << std::setprecision( 3 )
                          << 100.0 * current_accuracy << "%" << std::endl;
            });
            if (current_accuracy != -1 && current_accuracy > 0.9) {
                std::cout << "Current accuracy is now " << current_accuracy << std::endl;
                target_accuracy_reached = true;
                break;
            }
            // </reporting>
        }
        solver->step(params); // One step of gradient descent
    }
}
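Both reporting blocks in training_loop are rate-limited with Throttled::maybe_run, so console output and visualizer updates happen at most once per interval no matter how fast the minibatch loop spins. A minimal self-contained sketch of that pattern, using a hypothetical ThrottledSketch class built on std::chrono rather than Dali's actual Throttled utility:

#include <chrono>
#include <functional>
#include <iostream>

// Hypothetical stand-in for Throttled: runs the callback only if at least
// `interval` has elapsed since the last time the callback actually ran.
struct ThrottledSketch {
    bool has_run = false;
    std::chrono::steady_clock::time_point last_run;

    void maybe_run(std::chrono::seconds interval, std::function<void()> callback) {
        auto now = std::chrono::steady_clock::now();
        if (!has_run || now - last_run >= interval) {
            callback();
            last_run = now;
            has_run = true;
        }
    }
};

int main() {
    ThrottledSketch reporter;
    for (long step = 0; step < 100000000; ++step) {
        // ... training work would go here ...
        reporter.maybe_run(std::chrono::seconds(10), [&]() {
            std::cout << "step " << step << std::endl; // printed at most every 10 seconds
        });
    }
    return 0;
}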