std::unique_ptr<feature_selector>
    make_selector(const cpptoml::table& config,
                  std::shared_ptr<index::forward_index> idx)
{
    auto table = config.get_table("features");
    if (!table)
        throw selector_factory_exception{
            "[features] table missing from config file"};

    auto prefix = table->get_as<std::string>("prefix");
    if (!prefix)
        throw selector_factory_exception{"no prefix in [features] table"};

    auto method = table->get_as<std::string>("method");
    if (!method)
        throw selector_factory_exception{
            "feature selection method required in [features] table"};

    auto features_per_class = static_cast<uint64_t>(
        table->get_as<int64_t>("features-per-class").value_or(20));

    auto selector
        = selector_factory::get().create(*method, *table, std::move(idx));
    selector->init(features_per_class); // make_selector is a friend
    return selector;
}
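// Usage sketch (not part of the factory above, headers omitted): the keys
// shown in the comment mirror exactly what make_selector reads from the
// [features] table; the "chi-square" method id, the config path, and the
// select(50) call are illustrative assumptions, as is the exact
// feature_selector::select signature.
//
//     [features]
//     method = "chi-square"        # id registered with selector_factory
//     prefix = "features"          # prefix for the selector's output files
//     features-per-class = 20      # optional; defaults to 20
//
void feature_selection_example()
{
    auto config = cpptoml::parse_file("config.toml");
    auto f_idx = index::make_index<index::forward_index>(*config);
    auto selector = make_selector(*config, f_idx);
    selector->select(50); // assumed interface: keep top 50 features per class
}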
template <>
std::unique_ptr<classifier>
    make_classifier<dual_perceptron>(const cpptoml::table& config,
                                     multiclass_dataset_view training)
{
    auto alpha = config.get_as<double>("alpha")
                     .value_or(dual_perceptron::default_alpha);
    auto gamma = config.get_as<double>("gamma")
                     .value_or(dual_perceptron::default_gamma);
    auto bias = config.get_as<double>("bias")
                    .value_or(dual_perceptron::default_bias);
    auto max_iter = config.get_as<int64_t>("max-iter")
                        .value_or(dual_perceptron::default_max_iter);

    auto kernel_cfg = config.get_table("kernel");
    if (!kernel_cfg)
        return make_unique<dual_perceptron>(std::move(training),
                                            make_unique<kernel::polynomial>(),
                                            alpha, gamma, bias, max_iter);

    return make_unique<dual_perceptron>(std::move(training),
                                        kernel::make_kernel(*kernel_cfg),
                                        alpha, gamma, bias, max_iter);
}
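// Illustrative only: the keys below mirror the ones the specialization above
// reads (alpha, gamma, bias, max-iter, plus an optional [kernel] table). The
// values, the in-memory parse via cpptoml::parser, and the <sstream> include
// are assumptions for this sketch; a real setup would normally come from
// cpptoml::parse_file on the full configuration.
std::shared_ptr<cpptoml::table> dual_perceptron_config_example()
{
    std::istringstream toml{
        "alpha = 0.1\n"
        "gamma = 0.05\n"
        "bias = 0.0\n"
        "max-iter = 100\n"
        "# no [kernel] table -> the default polynomial kernel is used\n"};
    cpptoml::parser parser{toml};
    return parser.parse();
}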
template <>
std::unique_ptr<classifier>
    make_classifier<one_vs_all>(const cpptoml::table& config,
                                std::shared_ptr<index::forward_index> idx)
{
    auto base = config.get_table("base");
    if (!base)
        throw classifier_factory::exception{
            "one-vs-all missing base-classifier parameter in config file"};

    // one binary classifier is created per class, with that class as the
    // positive label and everything else mapped to "negative"
    return make_unique<one_vs_all>(idx, [&](class_label positive_label)
                                   {
                                       return make_binary_classifier(
                                           *base, idx, positive_label,
                                           class_label{"negative"});
                                   });
}
std::unique_ptr<analyzer> load(const cpptoml::table& config)
{
    using namespace analyzers;
    std::vector<std::unique_ptr<analyzer>> toks;

    auto analyzers = config.get_table_array("analyzers");
    if (!analyzers)
        throw analyzer_exception{
            "[[analyzers]] tables missing from config file"};

    for (const auto& group : analyzers->get())
    {
        auto method = group->get_as<std::string>("method");
        if (!method)
            throw analyzer_exception{"failed to find analyzer method"};
        toks.emplace_back(
            analyzer_factory::get().create(*method, config, *group));
    }
    return make_unique<multi_analyzer>(std::move(toks));
}
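// Example of the [[analyzers]] groups that load() iterates over; the shape
// matches what the loop expects (a "method" key per group), but the specific
// method, ngram value, and config path are illustrative assumptions. Each
// group becomes one analyzer inside the returned multi_analyzer.
//
//     [[analyzers]]
//     method = "ngram-word"
//     ngram = 1
//     filter = "default-unigram-chain"
//
void load_analyzer_example()
{
    auto config = cpptoml::parse_file("config.toml"); // hypothetical path
    auto ana = analyzers::load(*config);              // multi_analyzer
}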
metadata::schema_type metadata_schema(const cpptoml::table& config)
{
    metadata::schema_type schema;
    if (auto metadata = config.get_table_array("metadata"))
    {
        const auto& arr = metadata->get();
        schema.reserve(arr.size());
        for (const auto& table : arr)
        {
            auto name = table->get_as<std::string>("name");
            auto type = table->get_as<std::string>("type");
            if (!name)
                throw metadata_exception{"name needed for metadata field"};
            if (!type)
                throw metadata_exception{"type needed for metadata field"};

            metadata::field_type ftype;
            if (*type == "int")
            {
                ftype = metadata::field_type::SIGNED_INT;
            }
            else if (*type == "uint")
            {
                ftype = metadata::field_type::UNSIGNED_INT;
            }
            else if (*type == "double")
            {
                ftype = metadata::field_type::DOUBLE;
            }
            else if (*type == "string")
            {
                ftype = metadata::field_type::STRING;
            }
            else
            {
                throw metadata_exception{"invalid metadata type: \"" + *type
                                         + "\""};
            }
            schema.emplace_back(*name, ftype);
        }
    }
    return schema;
}
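// Illustrative [[metadata]] configuration for metadata_schema above; the
// field names and config path are made up, but the "type" values are drawn
// from the four strings the parser accepts (int, uint, double, string).
//
//     [[metadata]]
//     name = "title"
//     type = "string"
//
//     [[metadata]]
//     name = "year"
//     type = "uint"
//
void metadata_schema_example()
{
    auto config = cpptoml::parse_file("config.toml"); // hypothetical path
    auto schema = metadata_schema(*config);
    // schema now holds one (name, field_type) entry per [[metadata]] table
}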
language_model::language_model(const cpptoml::table& config)
{
    auto table = config.get_table("language-model");
    if (!table)
        throw language_model_exception{
            "[language-model] table missing from config file"};

    auto arpa_file = table->get_as<std::string>("arpa-file");
    auto binary_file = table->get_as<std::string>("binary-file-prefix");

    N_ = 0;
    if (binary_file && filesystem::file_exists(*binary_file + "0.binlm"))
    {
        LOG(info) << "Loading language model from binary files: "
                  << *binary_file << "*" << ENDLG;
        auto time = common::time([&]()
        {
            prefix_ = *binary_file;
            load_vocab();
            while (filesystem::file_exists(*binary_file + std::to_string(N_)
                                           + ".binlm"))
                lm_.emplace_back(*binary_file + std::to_string(N_++)
                                 + ".binlm");
        });
        LOG(info) << "Done. (" << time.count() << "ms)" << ENDLG;
    }
    else if (arpa_file && binary_file)
    {
        LOG(info) << "Loading language model from .arpa file: " << *arpa_file
                  << ENDLG;
        prefix_ = *binary_file;
        auto time = common::time([&]() { read_arpa_format(*arpa_file); });
        LOG(info) << "Done. (" << time.count() << "ms)" << ENDLG;
    }
    else
        throw language_model_exception{
            "arpa-file or binary-file-prefix needed in config file"};

    // cache the unigram node for the unknown-word token
    auto unk = vocabulary_.at("<unk>");
    unk_node_ = *lm_[0].find(&unk, &unk + 1);
}
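// Usage sketch for the constructor above (headers omitted): the
// [language-model] keys mirror what the constructor reads, while the file
// names, the sentence type's string constructor, and the perplexity() call
// are assumptions for illustration rather than guaranteed API.
//
//     [language-model]
//     arpa-file = "english.arpa"
//     binary-file-prefix = "english-"
//
void language_model_example()
{
    auto config = cpptoml::parse_file("config.toml"); // hypothetical path
    language_model model{*config};
    sentence sent{"this is a test sentence"};  // assumed constructor
    auto ppl = model.perplexity(sent);         // assumed member function
}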
std::unique_ptr<token_stream> load_filters(const cpptoml::table& global,
                                           const cpptoml::table& config)
{
    // a single string selects one of the pre-built filter chains
    auto check = config.get_as<std::string>("filter");
    if (check)
    {
        if (*check == "default-chain")
            return default_filter_chain(global);
        else if (*check == "default-unigram-chain")
            return default_unigram_chain(global);
        else
            throw analyzer_exception{"unknown filter option: " + *check};
    }

    // otherwise, chain the filters listed as an array of tables
    auto filters = config.get_table_array("filter");
    if (!filters)
        throw analyzer_exception{"analyzer group missing filter configuration"};

    std::unique_ptr<token_stream> result;
    for (const auto& filter : filters->get())
        result = load_filter(std::move(result), *filter);
    return result;
}
/**
 * Parses all sentences in a text file.
 */
void parse(const std::string& file, const cpptoml::table& config)
{
    std::cout << "Running parser" << std::endl;

    auto seq_grp = config.get_table("sequence");
    if (!seq_grp)
    {
        std::cerr << "[sequence] group needed in config file" << std::endl;
        return;
    }
    auto prefix = seq_grp->get_as<std::string>("prefix");
    if (!prefix)
    {
        std::cerr << "[sequence] group needs a prefix key" << std::endl;
        return;
    }

    auto parser_grp = config.get_table("parser");
    if (!parser_grp)
    {
        std::cerr << "[parser] group needed in config file" << std::endl;
        return;
    }
    auto parser_prefix = parser_grp->get_as<std::string>("prefix");
    if (!parser_prefix)
    {
        std::cerr << "[parser] group needs a prefix key" << std::endl;
        return;
    }

    std::cout << "Loading tagging model" << std::endl;
    // load POS-tagging model
    sequence::perceptron tagger{*prefix};

    std::cout << "Loading parser model" << std::endl;
    // load parser model
    parser::sr_parser parser{*parser_prefix};

    // construct the token filter chain
    std::unique_ptr<analyzers::token_stream> stream
        = make_unique<analyzers::tokenizers::icu_tokenizer>();
    stream = make_unique<analyzers::filters::ptb_normalizer>(std::move(stream));
    stream->set_content(filesystem::file_text(file));

    // parse each sentence in the file and write its output to the output file
    auto out_name = no_ext(file) + ".parsed.txt";
    std::ofstream outfile{out_name};
    sequence::sequence seq;
    while (*stream)
    {
        auto token = stream->next();
        if (token == "<s>")
        {
            seq = {};
        }
        else if (token == "</s>")
        {
            tagger.tag(seq);
            parser.parse(seq).pretty_print(outfile);
        }
        else
        {
            seq.add_symbol(sequence::symbol_t{token});
        }
    }
    std::cout << " -> file saved as " << out_name << std::endl;
}
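// The [sequence] and [parser] groups read above correspond to configuration
// like the following; the prefix values and file names are illustrative
// paths to previously trained model directories, and the config path below
// is hypothetical.
//
//     [sequence]
//     prefix = "perceptron-tagger"
//
//     [parser]
//     prefix = "parser"
//
void parse_example()
{
    auto config = cpptoml::parse_file("config.toml");
    parse("doc.txt", *config); // writes doc.parsed.txt next to the input
}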
/**
 * Performs part-of-speech tagging on a text file.
 * @param file The input file
 * @param config Configuration settings
 * @param replace Whether or not to replace words with their POS tags
 */
void pos(const std::string& file, const cpptoml::table& config, bool replace)
{
    std::cout << "Running POS-tagging with replace = " << std::boolalpha
              << replace << std::endl;

    auto seq_grp = config.get_table("sequence");
    if (!seq_grp)
    {
        std::cerr << "[sequence] group needed in config file" << std::endl;
        return;
    }
    auto prefix = seq_grp->get_as<std::string>("prefix");
    if (!prefix)
    {
        std::cerr << "[sequence] group needs a prefix key" << std::endl;
        return;
    }

    std::cout << "Loading tagging model" << std::endl;
    sequence::perceptron tagger{*prefix};

    // construct the token filter chain
    std::unique_ptr<analyzers::token_stream> stream
        = make_unique<analyzers::tokenizers::icu_tokenizer>();
    stream = make_unique<analyzers::filters::ptb_normalizer>(std::move(stream));
    stream->set_content(filesystem::file_text(file));

    // tag each sentence in the file and write its output to the output file
    auto out_name
        = no_ext(file) + (replace ? ".pos-replace.txt" : ".pos-tagged.txt");
    std::ofstream outfile{out_name};
    sequence::sequence seq;
    while (*stream)
    {
        auto token = stream->next();
        if (token == "<s>")
        {
            seq = {};
        }
        else if (token == "</s>")
        {
            tagger.tag(seq);
            for (const auto& obs : seq)
            {
                if (replace)
                    outfile << obs.tag() << " ";
                else
                    outfile << obs.symbol() << "_" << obs.tag() << " ";
            }
            outfile << std::endl;
        }
        else
        {
            seq.add_symbol(sequence::symbol_t{token});
        }
    }
    std::cout << " -> file saved as " << out_name << std::endl;
}