// Example #1
// 0
// Sets the mapped_features_ from the features_ using the provided
// feature_space to the indexed versions of the features.
void TrainingSample::IndexFeatures(const IntFeatureSpace& feature_space) {
  // Index and sort the raw features directly into mapped_features_.
  feature_space.IndexAndSortFeatures(features_, num_features_,
                                     &mapped_features_);
  // The features are now indexed; mapping (if any) must be redone.
  features_are_indexed_ = true;
  features_are_mapped_ = false;
}
// Display the samples with the given indexed feature that also match
// the given shape.
void TrainingSampleSet::DisplaySamplesWithFeature(int f_index,
                                                  const Shape& shape,
                                                  const IntFeatureSpace& space,
                                                  ScrollView::Color color,
                                                  ScrollView* window) const {
  for (int s = 0; s < num_raw_samples(); ++s) {
    const TrainingSample* sample = GetSample(s);
    // Only samples whose class belongs to the shape are of interest.
    if (!shape.ContainsUnichar(sample->class_id())) continue;
    GenericVector<int> indexed_features;
    space.IndexAndSortFeatures(sample->features(), sample->num_features(),
                               &indexed_features);
    for (int f = 0; f < indexed_features.size(); ++f) {
      if (indexed_features[f] == f_index) {
        sample->DisplayFeatures(color, window);
        // Displaying once per sample is enough: stop scanning its
        // features so duplicate indexed features don't redraw it.
        break;
      }
    }
  }
}
// Example #3
// 0
/**
 * Creates a MasterTraininer and loads the training data into it:
 * Initializes feature_defs and IntegerFX.
 * Loads the shape_table if shape_table != nullptr.
 * Loads initial unicharset from -U command-line option.
 * If FLAGS_T is set, loads the majority of data from there, else:
 *  - Loads font info from -F option.
 *  - Loads xheights from -X option.
 *  - Loads samples from .tr files in remaining command-line args.
 *  - Deletes outliers and computes canonical samples.
 *  - If FLAGS_output_trainer is set, saves the trainer for future use.
 * Computes canonical and cloud features.
 * If shape_table is not nullptr, but failed to load, make a fake flat one,
 * as shape clustering was not run.
 */
MasterTrainer* LoadTrainingData(int argc, const char* const * argv,
                                bool replication,
                                ShapeTable** shape_table,
                                STRING* file_prefix) {
  InitFeatureDefs(&feature_defs);
  InitIntegerFX();
  // Build the file prefix from the -D directory option, if given.
  *file_prefix = "";
  if (!FLAGS_D.empty()) {
    *file_prefix += FLAGS_D.c_str();
    *file_prefix += "/";
  }
  // If we are shape clustering (nullptr shape_table) or we successfully load
  // a shape_table written by a previous shape clustering, then
  // shape_analysis will be true, meaning that the MasterTrainer will replace
  // some members of the unicharset with their fragments.
  bool shape_analysis = false;
  if (shape_table != nullptr) {
    *shape_table = LoadShapeTable(*file_prefix);
    if (*shape_table != nullptr) shape_analysis = true;
  } else {
    shape_analysis = true;
  }
  MasterTrainer* trainer = new MasterTrainer(NM_CHAR_ANISOTROPIC,
                                             shape_analysis,
                                             replication,
                                             FLAGS_debug_level);
  IntFeatureSpace fs;
  fs.Init(kBoostXYBuckets, kBoostXYBuckets, kBoostDirBuckets);
  if (FLAGS_T.empty()) {
    trainer->LoadUnicharset(FLAGS_U.c_str());
    // Get basic font information from font_properties.
    if (!FLAGS_F.empty()) {
      if (!trainer->LoadFontInfo(FLAGS_F.c_str())) {
        delete trainer;
        return nullptr;
      }
    }
    if (!FLAGS_X.empty()) {
      if (!trainer->LoadXHeights(FLAGS_X.c_str())) {
        delete trainer;
        return nullptr;
      }
    }
    trainer->SetFeatureSpace(fs);
    const char* page_name;
    // Load training data from .tr files on the command line.
    while ((page_name = GetNextFilename(argc, argv)) != nullptr) {
      tprintf("Reading %s ...\n", page_name);
      trainer->ReadTrainingSamples(page_name, feature_defs, false);

      // If there is a file with [lang].[fontname].exp[num].fontinfo present,
      // read font spacing information in to fontinfo_table.
      // Chop off the "tr" extension and append "fontinfo", using the same
      // STRING idiom as the image name below (avoids a raw exact-fit
      // char buffer with strncpy/strcpy arithmetic).
      STRING fontinfo_file_name = page_name;
      fontinfo_file_name.truncate_at(fontinfo_file_name.length() - 2);
      fontinfo_file_name += "fontinfo";
      trainer->AddSpacingInfo(fontinfo_file_name.string());

      // Load the images into memory if required by the classifier.
      if (FLAGS_load_images) {
        STRING image_name = page_name;
        // Chop off the tr and replace with tif. Extension must be tif!
        image_name.truncate_at(image_name.length() - 2);
        image_name += "tif";
        trainer->LoadPageImages(image_name.string());
      }
    }
    trainer->PostLoadCleanup();
    // Write the master trainer if required.
    if (!FLAGS_output_trainer.empty()) {
      FILE* fp = fopen(FLAGS_output_trainer.c_str(), "wb");
      if (fp == nullptr) {
        tprintf("Can't create saved trainer data!\n");
      } else {
        trainer->Serialize(fp);
        fclose(fp);
      }
    }
  } else {
    // Load a previously-saved master trainer instead of raw .tr files.
    bool success = false;
    tprintf("Loading master trainer from file:%s\n",
            FLAGS_T.c_str());
    FILE* fp = fopen(FLAGS_T.c_str(), "rb");
    if (fp == nullptr) {
      tprintf("Can't read file %s to initialize master trainer\n",
              FLAGS_T.c_str());
    } else {
      success = trainer->DeSerialize(false, fp);
      fclose(fp);
    }
    if (!success) {
      tprintf("Deserialize of master trainer failed!\n");
      delete trainer;
      return nullptr;
    }
    trainer->SetFeatureSpace(fs);
  }
  trainer->PreTrainingSetup();
  // Optionally save the (possibly fragment-modified) unicharset via -O.
  if (!FLAGS_O.empty() &&
      !trainer->unicharset().save_to_file(FLAGS_O.c_str())) {
    fprintf(stderr, "Failed to save unicharset to file %s\n", FLAGS_O.c_str());
    delete trainer;
    return nullptr;
  }
  if (shape_table != nullptr) {
    // If we previously failed to load a shapetable, then shape clustering
    // wasn't run so make a flat one now.
    if (*shape_table == nullptr) {
      *shape_table = new ShapeTable;
      trainer->SetupFlatShapeTable(*shape_table);
      tprintf("Flat shape table summary: %s\n",
              (*shape_table)->SummaryStr().string());
    }
    (*shape_table)->set_unicharset(trainer->unicharset());
  }
  return trainer;
}
// Delete outlier samples with few features that are shared with others.
// IndexFeatures must have been called already.
    void TrainingSampleSet::DeleteOutliers(const IntFeatureSpace &feature_space,
                                           bool debug) {
        if (font_class_array_ == NULL)
            OrganizeByFontAndClass();
        Pixa *pixa = NULL;
        if (debug)
            pixa = pixaCreate(0);
        GenericVector <int> feature_counts;
        int fs_size = feature_space.Size();
        int font_size = font_id_map_.CompactSize();
        for (int font_index = 0; font_index < font_size; ++font_index) {
            for (int c = 0; c < unicharset_size_; ++c) {
                // Create a histogram of the features used by all samples of this
                // font/class combination.
                feature_counts.init_to_size(fs_size, 0);
                FontClassInfo &fcinfo = (*font_class_array_)(font_index, c);
                int sample_count = fcinfo.samples.size();
                if (sample_count < kMinOutlierSamples)
                    continue;
                for (int i = 0; i < sample_count; ++i) {
                    int s = fcinfo.samples[i];
                    const GenericVector <int> &features = samples_[s]->indexed_features();
                    for (int f = 0; f < features.size(); ++f) {
                        ++feature_counts[features[f]];
                    }
                }
                for (int i = 0; i < sample_count; ++i) {
                    int s = fcinfo.samples[i];
                    const TrainingSample &sample = *samples_[s];
                    const GenericVector <int> &features = sample.indexed_features();
                    // A feature that has a histogram count of 1 is only used by this
                    // sample, making it 'bad'. All others are 'good'.
                    int good_features = 0;
                    int bad_features = 0;
                    for (int f = 0; f < features.size(); ++f) {
                        if (feature_counts[features[f]] > 1)
                            ++good_features;
                        else
                            ++bad_features;
                    }
                    // If more than 1/3 features are bad, then this is an outlier.
                    if (bad_features * 2 > good_features) {
                        tprintf("Deleting outlier sample of %s, %d good, %d bad\n",
                                SampleToString(sample).string(),
                                good_features, bad_features);
                        if (debug) {
                            pixaAddPix(pixa, sample.RenderToPix(&unicharset_), L_INSERT);
                            // Add the previous sample as well, so it is easier to see in
                            // the output what is wrong with this sample.
                            int t;
                            if (i == 0)
                                t = fcinfo.samples[1];
                            else
                                t = fcinfo.samples[i - 1];
                            const TrainingSample &csample = *samples_[t];
                            pixaAddPix(pixa, csample.RenderToPix(&unicharset_), L_INSERT);
                        }
                        // Mark the sample for deletion.
                        KillSample(samples_[s]);
                    }
                }
            }
        }
        // Truly delete all bad samples and renumber everything.
        DeleteDeadSamples();
        if (pixa != NULL) {
            Pix *pix = pixaDisplayTiledInRows(pixa, 1, 2600, 1.0, 0, 10, 10);
            pixaDestroy(&pixa);
            pixWrite("outliers.png", pix, IFF_PNG);
            pixDestroy(&pix);
        }
    }