// Returns a string corresponding to a given single label id, falling back to // a default of ".." for part of a multi-label unichar-id. const char* LSTMRecognizer::DecodeSingleLabel(int label) { if (label == null_char_) return "<null>"; if (IsRecoding()) { // Decode label via recoder_. RecodedCharID code; code.Set(0, label); label = recoder_.DecodeUnichar(code); if (label == INVALID_UNICHAR_ID) return ".."; // Part of a bigger code. } if (label == UNICHAR_SPACE) return " "; return GetUnicharset().get_normed_unichar(label); }
// Returns a string corresponding to the label starting at start. Sets *end // to the next start and if non-null, *decoded to the unichar id. const char* LSTMRecognizer::DecodeLabel(const GenericVector<int>& labels, int start, int* end, int* decoded) { *end = start + 1; if (IsRecoding()) { // Decode labels via recoder_. RecodedCharID code; if (labels[start] == null_char_) { if (decoded != NULL) { code.Set(0, null_char_); *decoded = recoder_.DecodeUnichar(code); } return "<null>"; } int index = start; while (index < labels.size() && code.length() < RecodedCharID::kMaxCodeLen) { code.Set(code.length(), labels[index++]); while (index < labels.size() && labels[index] == null_char_) ++index; int uni_id = recoder_.DecodeUnichar(code); // If the next label isn't a valid first code, then we need to continue // extending even if we have a valid uni_id from this prefix. if (uni_id != INVALID_UNICHAR_ID && (index == labels.size() || code.length() == RecodedCharID::kMaxCodeLen || recoder_.IsValidFirstCode(labels[index]))) { *end = index; if (decoded != NULL) *decoded = uni_id; if (uni_id == UNICHAR_SPACE) return " "; return GetUnicharset().get_normed_unichar(uni_id); } } return "<Undecodable>"; } else { if (decoded != NULL) *decoded = labels[start]; if (labels[start] == null_char_) return "<null>"; if (labels[start] == UNICHAR_SPACE) return " "; return GetUnicharset().get_normed_unichar(labels[start]); } }
// Prints debug information on the results. void ShapeClassifier::UnicharPrintResults( const char* context, const GenericVector<UnicharRating>& results) const { tprintf("%s\n", context); for (int i = 0; i < results.size(); ++i) { tprintf("%g: c_id=%d=%s", results[i].rating, results[i].unichar_id, GetUnicharset().id_to_unichar(results[i].unichar_id)); if (results[i].fonts.size() != 0) { tprintf(" Font Vector:"); for (int f = 0; f < results[i].fonts.size(); ++f) { tprintf(" %d", results[i].fonts[f].fontinfo_id); } } tprintf("\n"); } }
// Loads the Recoder. bool LSTMRecognizer::LoadRecoder(TFile* fp) { if (IsRecoding()) { if (!recoder_.DeSerialize(fp)) return false; RecodedCharID code; recoder_.EncodeUnichar(UNICHAR_SPACE, &code); if (code(0) != UNICHAR_SPACE) { tprintf("Space was garbled in recoding!!\n"); return false; } } else { recoder_.SetupPassThrough(GetUnicharset()); training_flags_ |= TF_COMPRESS_UNICHARSET; } return true; }
// Recognizes the line image, contained within image_data, returning the // ratings matrix and matching box_word for each WERD_RES in the output. void LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert, bool debug, double worst_dict_cert, const TBOX& line_box, PointerVector<WERD_RES>* words) { NetworkIO outputs; float scale_factor; NetworkIO inputs; if (!RecognizeLine(image_data, invert, debug, false, &scale_factor, &inputs, &outputs)) return; if (search_ == NULL) { search_ = new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_); } search_->Decode(outputs, kDictRatio, kCertOffset, worst_dict_cert, NULL); search_->ExtractBestPathAsWords(line_box, scale_factor, debug, &GetUnicharset(), words); }
// Writes to the given file. Returns false in case of error. bool LSTMRecognizer::Serialize(const TessdataManager* mgr, TFile* fp) const { bool include_charsets = mgr == nullptr || !mgr->IsComponentAvailable(TESSDATA_LSTM_RECODER) || !mgr->IsComponentAvailable(TESSDATA_LSTM_UNICHARSET); if (!network_->Serialize(fp)) return false; if (include_charsets && !GetUnicharset().save_to_file(fp)) return false; if (!network_str_.Serialize(fp)) return false; if (fp->FWrite(&training_flags_, sizeof(training_flags_), 1) != 1) return false; if (fp->FWrite(&training_iteration_, sizeof(training_iteration_), 1) != 1) return false; if (fp->FWrite(&sample_iteration_, sizeof(sample_iteration_), 1) != 1) return false; if (fp->FWrite(&null_char_, sizeof(null_char_), 1) != 1) return false; if (fp->FWrite(&weight_range_, sizeof(weight_range_), 1) != 1) return false; if (fp->FWrite(&learning_rate_, sizeof(learning_rate_), 1) != 1) return false; if (fp->FWrite(&momentum_, sizeof(momentum_), 1) != 1) return false; if (include_charsets && IsRecoding() && !recoder_.Serialize(fp)) return false; return true; }
// Visual debugger classifies the given sample, displays the results and // solicits user input to display other classifications. Returns when // the user has finished with debugging the sample. // Probably doesn't need to be overridden if the subclass provides // DisplayClassifyAs. void ShapeClassifier::DebugDisplay(const TrainingSample& sample, Pix* page_pix, UNICHAR_ID unichar_id) { #ifndef GRAPHICS_DISABLED static ScrollView* terminator = NULL; if (terminator == NULL) { terminator = new ScrollView("XIT", 0, 0, 50, 50, 50, 50, true); } ScrollView* debug_win = CreateFeatureSpaceWindow("ClassifierDebug", 0, 0); // Provide a right-click menu to choose the class. SVMenuNode* popup_menu = new SVMenuNode(); popup_menu->AddChild("Choose class to debug", 0, "x", "Class to debug"); popup_menu->BuildMenu(debug_win, false); // Display the features in green. const INT_FEATURE_STRUCT* features = sample.features(); int num_features = sample.num_features(); for (int f = 0; f < num_features; ++f) { RenderIntFeature(debug_win, &features[f], ScrollView::GREEN); } debug_win->Update(); GenericVector<UnicharRating> results; // Debug classification until the user quits. const UNICHARSET& unicharset = GetUnicharset(); SVEvent* ev; SVEventType ev_type; do { PointerVector<ScrollView> windows; if (unichar_id >= 0) { tprintf("Debugging class %d = %s\n", unichar_id, unicharset.id_to_unichar(unichar_id)); UnicharClassifySample(sample, page_pix, 1, unichar_id, &results); DisplayClassifyAs(sample, page_pix, unichar_id, 1, &windows); } else { tprintf("Invalid unichar_id: %d\n", unichar_id); UnicharClassifySample(sample, page_pix, 1, -1, &results); } if (unichar_id >= 0) { tprintf("Debugged class %d = %s\n", unichar_id, unicharset.id_to_unichar(unichar_id)); } tprintf("Right-click in ClassifierDebug window to choose debug class,"); tprintf(" Left-click or close window to quit...\n"); UNICHAR_ID old_unichar_id; do { old_unichar_id = unichar_id; ev = debug_win->AwaitEvent(SVET_ANY); ev_type = ev->type; if (ev_type == SVET_POPUP) { if (unicharset.contains_unichar(ev->parameter)) { unichar_id = unicharset.unichar_to_id(ev->parameter); } else { tprintf("Char class '%s' not found in unicharset", ev->parameter); } } delete ev; } while (unichar_id == old_unichar_id && ev_type != SVET_CLICK && ev_type != SVET_DESTROY); } while (ev_type != SVET_CLICK && ev_type != SVET_DESTROY); delete debug_win; #endif // GRAPHICS_DISABLED }