示例#1
0
// Returns a string corresponding to a given single label id, falling back to
// a default of ".." for part of a multi-label unichar-id.
const char* LSTMRecognizer::DecodeSingleLabel(int label) {
  if (label == null_char_) return "<null>";
  if (IsRecoding()) {
    // Decode label via recoder_.
    RecodedCharID code;
    code.Set(0, label);
    label = recoder_.DecodeUnichar(code);
    if (label == INVALID_UNICHAR_ID) return "..";  // Part of a bigger code.
  }
  if (label == UNICHAR_SPACE) return " ";
  return GetUnicharset().get_normed_unichar(label);
}
示例#2
0
// Returns a string corresponding to the label (or run of labels) that begins
// at index start in labels.
//   labels:  sequence of label ids.
//   start:   index of the first label to decode.
//   end:     output, must be non-null. Set to the index at which the next
//            decode should start; start + 1 unless a multi-label code was
//            consumed.
//   decoded: optional output. If non-null, receives the unichar id that was
//            decoded, or INVALID_UNICHAR_ID when the labels could not be
//            decoded.
// Returns "<null>" for the null character, " " for the space unichar,
// "<Undecodable>" when recoding is active and no prefix starting at start
// decodes to a unichar, and the normalized unichar string otherwise.
const char* LSTMRecognizer::DecodeLabel(const GenericVector<int>& labels,
                                        int start, int* end, int* decoded) {
  *end = start + 1;
  if (IsRecoding()) {
    // Decode labels via recoder_.
    RecodedCharID code;
    if (labels[start] == null_char_) {
      if (decoded != nullptr) {
        code.Set(0, null_char_);
        *decoded = recoder_.DecodeUnichar(code);
      }
      return "<null>";
    }
    int index = start;
    while (index < labels.size() &&
           code.length() < RecodedCharID::kMaxCodeLen) {
      code.Set(code.length(), labels[index++]);
      // Nulls between the labels of a code are skipped, not added to it.
      while (index < labels.size() && labels[index] == null_char_) ++index;
      int uni_id = recoder_.DecodeUnichar(code);
      // If the next label isn't a valid first code, then we need to continue
      // extending even if we have a valid uni_id from this prefix.
      if (uni_id != INVALID_UNICHAR_ID &&
          (index == labels.size() ||
           code.length() == RecodedCharID::kMaxCodeLen ||
           recoder_.IsValidFirstCode(labels[index]))) {
        *end = index;
        if (decoded != nullptr) *decoded = uni_id;
        if (uni_id == UNICHAR_SPACE) return " ";
        return GetUnicharset().get_normed_unichar(uni_id);
      }
    }
    // Nothing decoded: report that explicitly instead of leaving the caller's
    // output untouched (and possibly uninitialized).
    if (decoded != nullptr) *decoded = INVALID_UNICHAR_ID;
    return "<Undecodable>";
  } else {
    // Without recoding, a label is used directly as the unichar id.
    if (decoded != nullptr) *decoded = labels[start];
    if (labels[start] == null_char_) return "<null>";
    if (labels[start] == UNICHAR_SPACE) return " ";
    return GetUnicharset().get_normed_unichar(labels[start]);
  }
}
示例#3
0
// Prints debug information on the results.
void ShapeClassifier::UnicharPrintResults(
    const char* context, const GenericVector<UnicharRating>& results) const {
  tprintf("%s\n", context);
  for (int i = 0; i < results.size(); ++i) {
    tprintf("%g: c_id=%d=%s", results[i].rating, results[i].unichar_id,
            GetUnicharset().id_to_unichar(results[i].unichar_id));
    if (results[i].fonts.size() != 0) {
      tprintf(" Font Vector:");
      for (int f = 0; f < results[i].fonts.size(); ++f) {
        tprintf(" %d", results[i].fonts[f].fontinfo_id);
      }
    }
    tprintf("\n");
  }
}
示例#4
0
// Sets up recoder_ from fp.
// When recoding is not in use, recoder_ is configured as a pass-through over
// the unicharset, TF_COMPRESS_UNICHARSET is added to training_flags_, and
// nothing is read from fp. Otherwise recoder_ is deserialized from fp and
// checked to make sure the first code of the space unichar is still
// UNICHAR_SPACE.
// Returns false if deserialization fails or the space check fails.
bool LSTMRecognizer::LoadRecoder(TFile* fp) {
  if (!IsRecoding()) {
    recoder_.SetupPassThrough(GetUnicharset());
    training_flags_ |= TF_COMPRESS_UNICHARSET;
    return true;
  }
  if (!recoder_.DeSerialize(fp)) {
    return false;
  }
  // Sanity check: space must encode to itself.
  RecodedCharID space_code;
  recoder_.EncodeUnichar(UNICHAR_SPACE, &space_code);
  const bool space_intact = space_code(0) == UNICHAR_SPACE;
  if (!space_intact) {
    tprintf("Space was garbled in recoding!!\n");
  }
  return space_intact;
}
示例#5
0
// Recognizes the line image, contained within image_data, returning the
// ratings matrix and matching box_word for each WERD_RES in the output.
void LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
                                   bool debug, double worst_dict_cert,
                                   const TBOX& line_box,
                                   PointerVector<WERD_RES>* words) {
  NetworkIO outputs;
  float scale_factor;
  NetworkIO inputs;
  if (!RecognizeLine(image_data, invert, debug, false, &scale_factor, &inputs,
                     &outputs))
    return;
  if (search_ == NULL) {
    search_ =
        new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_);
  }
  search_->Decode(outputs, kDictRatio, kCertOffset, worst_dict_cert, NULL);
  search_->ExtractBestPathAsWords(line_box, scale_factor, debug,
                                  &GetUnicharset(), words);
}
示例#6
0
// Writes the recognizer to fp. Returns false as soon as any write fails.
// The unicharset and recoder are written into fp unless mgr is non-null and
// reports both TESSDATA_LSTM_RECODER and TESSDATA_LSTM_UNICHARSET as
// available. Write order: network, [unicharset], network string, scalar
// training state, [recoder, only when recoding].
bool LSTMRecognizer::Serialize(const TessdataManager* mgr, TFile* fp) const {
  const bool charsets_in_mgr =
      mgr != nullptr && mgr->IsComponentAvailable(TESSDATA_LSTM_RECODER) &&
      mgr->IsComponentAvailable(TESSDATA_LSTM_UNICHARSET);
  const bool include_charsets = !charsets_in_mgr;

  if (!network_->Serialize(fp)) {
    return false;
  }
  if (include_charsets && !GetUnicharset().save_to_file(fp)) {
    return false;
  }
  if (!network_str_.Serialize(fp)) {
    return false;
  }

  // Scalar members, written in a fixed order. Short-circuit evaluation stops
  // at the first failed write.
  const bool scalars_written =
      fp->FWrite(&training_flags_, sizeof(training_flags_), 1) == 1 &&
      fp->FWrite(&training_iteration_, sizeof(training_iteration_), 1) == 1 &&
      fp->FWrite(&sample_iteration_, sizeof(sample_iteration_), 1) == 1 &&
      fp->FWrite(&null_char_, sizeof(null_char_), 1) == 1 &&
      fp->FWrite(&weight_range_, sizeof(weight_range_), 1) == 1 &&
      fp->FWrite(&learning_rate_, sizeof(learning_rate_), 1) == 1 &&
      fp->FWrite(&momentum_, sizeof(momentum_), 1) == 1;
  if (!scalars_written) {
    return false;
  }

  if (include_charsets && IsRecoding()) {
    return recoder_.Serialize(fp);
  }
  return true;
}
示例#7
0
// Visual debugger classifies the given sample, displays the results and
// solicits user input to display other classifications. Returns when
// the user has finished with debugging the sample.
// Probably doesn't need to be overridden if the subclass provides
// DisplayClassifyAs.
void ShapeClassifier::DebugDisplay(const TrainingSample& sample,
                                   Pix* page_pix,
                                   UNICHAR_ID unichar_id) {
#ifndef GRAPHICS_DISABLED
  static ScrollView* terminator = NULL;
  if (terminator == NULL) {
    terminator = new ScrollView("XIT", 0, 0, 50, 50, 50, 50, true);
  }
  ScrollView* debug_win = CreateFeatureSpaceWindow("ClassifierDebug", 0, 0);
  // Provide a right-click menu to choose the class.
  SVMenuNode* popup_menu = new SVMenuNode();
  popup_menu->AddChild("Choose class to debug", 0, "x", "Class to debug");
  popup_menu->BuildMenu(debug_win, false);
  // Display the features in green.
  const INT_FEATURE_STRUCT* features = sample.features();
  int num_features = sample.num_features();
  for (int f = 0; f < num_features; ++f) {
    RenderIntFeature(debug_win, &features[f], ScrollView::GREEN);
  }
  debug_win->Update();
  GenericVector<UnicharRating> results;
  // Debug classification until the user quits.
  const UNICHARSET& unicharset = GetUnicharset();
  SVEvent* ev;
  SVEventType ev_type;
  do {
    PointerVector<ScrollView> windows;
    if (unichar_id >= 0) {
      tprintf("Debugging class %d = %s\n",
              unichar_id, unicharset.id_to_unichar(unichar_id));
      UnicharClassifySample(sample, page_pix, 1, unichar_id, &results);
      DisplayClassifyAs(sample, page_pix, unichar_id, 1, &windows);
    } else {
      tprintf("Invalid unichar_id: %d\n", unichar_id);
      UnicharClassifySample(sample, page_pix, 1, -1, &results);
    }
    if (unichar_id >= 0) {
      tprintf("Debugged class %d = %s\n",
              unichar_id, unicharset.id_to_unichar(unichar_id));
    }
    tprintf("Right-click in ClassifierDebug window to choose debug class,");
    tprintf(" Left-click or close window to quit...\n");
    UNICHAR_ID old_unichar_id;
    do {
      old_unichar_id = unichar_id;
      ev = debug_win->AwaitEvent(SVET_ANY);
      ev_type = ev->type;
      if (ev_type == SVET_POPUP) {
        if (unicharset.contains_unichar(ev->parameter)) {
          unichar_id = unicharset.unichar_to_id(ev->parameter);
        } else {
          tprintf("Char class '%s' not found in unicharset", ev->parameter);
        }
      }
      delete ev;
    } while (unichar_id == old_unichar_id &&
             ev_type != SVET_CLICK && ev_type != SVET_DESTROY);
  } while (ev_type != SVET_CLICK && ev_type != SVET_DESTROY);
  delete debug_win;
#endif  // GRAPHICS_DISABLED
}