void BilingualModel::trainChunk(const string& src_file, const string& trg_file, const vector<long long>& src_chunks, const vector<long long>& trg_chunks, int chunk_id) { ifstream src_infile(src_file); ifstream trg_infile(trg_file); try { check_is_open(src_infile, src_file); check_is_open(trg_infile, trg_file); check_is_non_empty(src_infile, src_file); check_is_non_empty(trg_infile, trg_file); } catch (...) { throw; } float starting_alpha = config->learning_rate; int max_iterations = config->iterations; long long training_words = src_model.training_words + trg_model.training_words; for (int k = 0; k < max_iterations; ++k) { int word_count = 0, last_count = 0; src_infile.clear(); trg_infile.clear(); src_infile.seekg(src_chunks[chunk_id], src_infile.beg); trg_infile.seekg(trg_chunks[chunk_id], trg_infile.beg); string src_sent, trg_sent; while (getline(src_infile, src_sent) && getline(trg_infile, trg_sent)) { word_count += trainSentence(src_sent, trg_sent); // update learning rate if (word_count - last_count > 10000) { words_processed += word_count - last_count; // asynchronous update last_count = word_count; alpha = starting_alpha * (1 - static_cast<float>(words_processed) / (max_iterations * training_words)); alpha = std::max(alpha, starting_alpha * 0.0001f); if (config->verbose) { printf("\rAlpha: %f Progress: %.2f%%", alpha, 100.0 * words_processed / (max_iterations * training_words)); fflush(stdout); } } // stop when reaching the end of a chunk if (chunk_id < src_chunks.size() - 1 && src_infile.tellg() >= src_chunks[chunk_id + 1]) break; } words_processed += word_count - last_count; } }
void MonolingualModel::trainChunk(const string& training_file, const vector<long long>& chunks, int chunk_id) { ifstream infile(training_file); float starting_alpha = config.starting_alpha; int max_iterations = config.max_iterations; if (!infile.is_open()) { throw runtime_error("couldn't open file " + training_file); } for (int k = 0; k < max_iterations; ++k) { int word_count = 0, last_count = 0; infile.clear(); infile.seekg(chunks[chunk_id], infile.beg); int chunk_size = training_lines / chunks.size(); int sent_id = chunk_id * chunk_size; string sent; while (getline(infile, sent)) { word_count += trainSentence(sent, sent_id++); // asynchronous update (possible race conditions) // update learning rate if (word_count - last_count > 10000) { words_processed += word_count - last_count; // asynchronous update last_count = word_count; if (!config.freeze) { // FIXME // decreasing learning rate alpha = starting_alpha * (1 - static_cast<float>(words_processed) / (max_iterations * training_words)); alpha = max(alpha, starting_alpha * 0.0001f); if (config.verbose) { printf("\rAlpha: %f Progress: %.2f%%", alpha, 100.0 * words_processed / (max_iterations * training_words)); fflush(stdout); } } } // stop when reaching the end of a chunk if (chunk_id < chunks.size() - 1 && infile.tellg() >= chunks[chunk_id + 1]) break; } words_processed += word_count - last_count; } }