int BilingualModel::trainSentence(const string& src_sent, const string& trg_sent) {
    auto src_nodes = src_model.getNodes(src_sent);  // same size as src_sent, OOV words are replaced by <UNK>
    auto trg_nodes = trg_model.getNodes(trg_sent);

    // counts the number of words that are in the vocabulary
    int words = 0;
    words += src_nodes.size() - count(src_nodes.begin(), src_nodes.end(), HuffmanNode::UNK);
    words += trg_nodes.size() - count(trg_nodes.begin(), trg_nodes.end(), HuffmanNode::UNK);

    if (config->subsampling > 0) {
        src_model.subsample(src_nodes);  // puts <UNK> tokens in place of the discarded tokens
        trg_model.subsample(trg_nodes);
    }

    if (src_nodes.empty() || trg_nodes.empty()) {
        return words;
    }

    // The <UNK> tokens are necessary to perform the alignment
    // (the nodes vector should have the same size as the original sentence)
    auto alignment = uniformAlignment(src_nodes, trg_nodes);

    // remove <UNK> tokens
    src_nodes.erase(
        std::remove(src_nodes.begin(), src_nodes.end(), HuffmanNode::UNK),
        src_nodes.end());
    trg_nodes.erase(
        std::remove(trg_nodes.begin(), trg_nodes.end(), HuffmanNode::UNK),
        trg_nodes.end());

    // Monolingual training
    for (int src_pos = 0; src_pos < src_nodes.size(); ++src_pos) {
        trainWord(src_model, src_model, src_nodes, src_nodes, src_pos, src_pos, alpha);
    }

    for (int trg_pos = 0; trg_pos < trg_nodes.size(); ++trg_pos) {
        trainWord(trg_model, trg_model, trg_nodes, trg_nodes, trg_pos, trg_pos, alpha);
    }

    if (config->beta == 0)
        return words;

    // Bilingual training
    for (int src_pos = 0; src_pos < src_nodes.size(); ++src_pos) {
        // 1-1 mapping between src_nodes and trg_nodes
        int trg_pos = alignment[src_pos];

        if (trg_pos != -1) {  // target word isn't OOV
            trainWord(src_model, trg_model, src_nodes, trg_nodes, src_pos, trg_pos, alpha * config->beta);
            trainWord(trg_model, src_model, trg_nodes, src_nodes, trg_pos, src_pos, alpha * config->beta);
        }
    }

    return words;  // returns the number of words processed (for progress estimation)
}
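// A minimal sketch of what uniformAlignment could look like, inferred from how
// trainSentence uses its result; this is an illustration, not the actual
// implementation. Assumptions: each in-vocabulary source word is aligned
// diagonally (source position i maps to the proportionally closest original
// target position), and the returned vector holds one entry per in-vocabulary
// source word, containing the aligned word's index in the compacted target
// vector (after <UNK> removal), or -1 when that target word is OOV. This is
// why the <UNK> placeholders must still be present when the alignment is
// computed: the proportional mapping needs the original sentence positions.
//
// #include <vector> and <cstddef> are assumed at the top of the file.
vector<int> BilingualModel::uniformAlignment(const vector<HuffmanNode>& src_nodes,
                                             const vector<HuffmanNode>& trg_nodes) {
    // For each original target position, its index in the compacted vector,
    // or -1 if the target word is OOV.
    vector<int> trg_compact(trg_nodes.size(), -1);
    int k = 0;
    for (size_t j = 0; j < trg_nodes.size(); ++j) {
        if (!(trg_nodes[j] == HuffmanNode::UNK)) {
            trg_compact[j] = k++;
        }
    }

    // One entry per in-vocabulary source word, in sentence order, so that
    // alignment[src_pos] can be indexed with compacted source positions.
    vector<int> alignment;
    for (size_t i = 0; i < src_nodes.size(); ++i) {
        if (src_nodes[i] == HuffmanNode::UNK) {
            continue;
        }
        // Diagonal (uniform) alignment computed on original positions.
        size_t j = i * trg_nodes.size() / src_nodes.size();
        alignment.push_back(trg_compact[j]);
    }

    return alignment;
}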
int MonolingualModel::trainSentence(const string& sent, int sent_id) {
    auto nodes = getNodes(sent);  // same size as sent, OOV words are replaced by <UNK>

    // counts the number of words that are in the vocabulary
    int words = nodes.size() - count(nodes.begin(), nodes.end(), HuffmanNode::UNK);

    if (config.subsampling > 0) {
        subsample(nodes);  // puts <UNK> tokens in place of the discarded tokens
    }

    if (nodes.empty()) {
        return words;
    }

    // remove <UNK> tokens
    nodes.erase(
        remove(nodes.begin(), nodes.end(), HuffmanNode::UNK),
        nodes.end());

    // Monolingual training
    for (int pos = 0; pos < nodes.size(); ++pos) {
        trainWord(nodes, pos, sent_id);
    }

    return words;  // returns the number of words processed, for progress estimation
}
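// A minimal sketch of the subsampling step used above, assuming word2vec's
// discard rule: a word with unigram frequency f is kept with probability
// sqrt(t / f) + t / f, where t = config.subsampling (typically 1e-3 to 1e-5),
// so very frequent words are dropped more often. Discarded tokens are
// overwritten with <UNK> so the vector keeps the same size as the sentence.
// The field names (node.count, train_words) and the RNG setup are assumptions
// for illustration, not the actual implementation.
//
// #include <cmath> and <random> are assumed at the top of the file.
void MonolingualModel::subsample(vector<HuffmanNode>& nodes) {
    static thread_local mt19937 gen{random_device{}()};
    uniform_real_distribution<float> dist(0.0f, 1.0f);

    for (auto& node : nodes) {
        if (node == HuffmanNode::UNK) {
            continue;  // already a placeholder, nothing to discard
        }

        // Unigram frequency of this word in the training corpus.
        float f = static_cast<float>(node.count) / train_words;
        float t = config.subsampling;
        float keep_prob = sqrt(t / f) + t / f;  // word2vec's keep probability

        if (dist(gen) > keep_prob) {
            node = HuffmanNode::UNK;  // discard this (frequent) word
        }
    }
}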