// Generates a TrainingSample from a TBLOB. Extracts features and sets // the bounding box, so classifiers that operate on the image can work. // TODO(rays) Make BlobToTrainingSample a member of Classify now that // the FlexFx and FeatureDescription code have been removed and LearnBlob // is now a member of Classify. TrainingSample* BlobToTrainingSample( const TBLOB& blob, bool nonlinear_norm, INT_FX_RESULT_STRUCT* fx_info, GenericVector<INT_FEATURE_STRUCT>* bl_features) { GenericVector<INT_FEATURE_STRUCT> cn_features; Classify::ExtractFeatures(blob, nonlinear_norm, bl_features, &cn_features, fx_info, nullptr); // TODO(rays) Use blob->PreciseBoundingBox() instead. TBOX box = blob.bounding_box(); TrainingSample* sample = nullptr; int num_features = fx_info->NumCN; if (num_features > 0) { sample = TrainingSample::CopyFromFeatures(*fx_info, box, &cn_features[0], num_features); } if (sample != nullptr) { // Set the bounding box (in original image coordinates) in the sample. TPOINT topleft, botright; topleft.x = box.left(); topleft.y = box.top(); botright.x = box.right(); botright.y = box.bottom(); TPOINT original_topleft, original_botright; blob.denorm().DenormTransform(nullptr, topleft, &original_topleft); blob.denorm().DenormTransform(nullptr, botright, &original_botright); sample->set_bounding_box(TBOX(original_topleft.x, original_botright.y, original_botright.x, original_topleft.y)); } return sample; }
// Computes the gradient contribution of a single training sample.
// Runs process() on the sample input, back-propagates the output error
// through the layers, and returns a pair of:
//   first  - a Tensor with one layer added per network layer
//            (layer delta times the transposed biased layer input);
//   second - the sum of squared differences between the network output
//            and sample.expectedOutput.
// Side effect: ctx.layerDeltas is resized to numLayers and overwritten.
// NOTE(review): relies on ctx.layerOutputs holding the per-layer outputs of
// this forward pass - presumably filled in by process(); confirm.
pair<Tensor, float> computeSampleGradient(const TrainingSample &sample, NetworkContext &ctx) {
  Vector output = process(sample.input, ctx);

  ctx.layerDeltas.resize(numLayers);
  // cross entropy error function.
  ctx.layerDeltas[ctx.layerDeltas.size() - 1] = output - sample.expectedOutput;

  // Walk backwards from the layer just before the output layer.
  for (int layer = ctx.layerDeltas.size() - 2; layer >= 0; layer--) {
    const int next = layer + 1;
    const auto &nextWeights = layerWeights(next);
    // Drop the first column (the bias weights) before propagating.
    Matrix noBiasWeights =
        nextWeights.bottomRightCorner(nextWeights.rows(), nextWeights.cols() - 1);

    ctx.layerDeltas[layer] = noBiasWeights.transpose() * ctx.layerDeltas[next];
    assert(ctx.layerDeltas[layer].rows() == ctx.layerOutputs[layer].rows());

    // Scale each delta by out * (1 - out) of the matching layer output.
    for (unsigned row = 0; row < ctx.layerDeltas[layer].rows(); row++) {
      float activation = ctx.layerOutputs[layer](row);
      ctx.layerDeltas[layer](row) *= activation * (1.0f - activation);
    }
  }

  pair<Tensor, float> gradient(Tensor(), 0.0f);

  // Per-layer gradient: delta times the transposed (biased) layer input.
  for (unsigned layer = 0; layer < numLayers; layer++) {
    auto biasedInput =
        getInputWithBias(layer == 0 ? sample.input : ctx.layerOutputs[layer - 1]);
    gradient.first.AddLayer(ctx.layerDeltas[layer] * biasedInput.transpose());
  }

  // Accumulate the squared output error.
  for (unsigned k = 0; k < output.rows(); k++) {
    const auto diff = output(k) - sample.expectedOutput(k);
    gradient.second += diff * diff;
  }

  return gradient;
}
// Adjust the weights of all the samples to be uniform in the given charset. // Returns the number of samples in the iterator. int SampleIterator::UniformSamples() { int num_good_samples = 0; for (Begin(); !AtEnd(); Next()) { TrainingSample* sample = MutableSample(); sample->set_weight(1.0); ++num_good_samples; } NormalizeSamples(); return num_good_samples; }
// Displays classification as the given shape_id. Creates as many windows // as it feels fit, using index as a guide for placement. Adds any created // windows to the windows output and returns a new index that may be used // by any subsequent classifiers. Caller waits for the user to view and // then destroys the windows by clearing the vector. int TessClassifier::DisplayClassifyAs( const TrainingSample& sample, Pix* page_pix, int unichar_id, int index, PointerVector<ScrollView>* windows) { int shape_id = unichar_id; if (GetShapeTable() != NULL) shape_id = BestShapeForUnichar(sample, page_pix, unichar_id, NULL); if (shape_id < 0) return index; if (UnusedClassIdIn(classify_->PreTrainedTemplates, shape_id)) { tprintf("No built-in templates for class/shape %d\n", shape_id); return index; } classify_->ShowBestMatchFor(shape_id, sample.features(), sample.num_features()); return index; }
// Classifies the given [training] sample, writing to results. // See ShapeClassifier for a full description. int CubeClassifier::ClassifySample(const TrainingSample& sample, Pix* page_pix, int debug, int keep_this, GenericVector<ShapeRating>* results) { results->clear(); if (page_pix == NULL) return 0; ASSERT_HOST(cube_cntxt_ != NULL); const TBOX& char_box = sample.bounding_box(); CubeObject* cube_obj = new tesseract::CubeObject( cube_cntxt_, page_pix, char_box.left(), pixGetHeight(page_pix) - char_box.top(), char_box.width(), char_box.height()); CharAltList* alt_list = cube_obj->RecognizeChar(); alt_list->Sort(); CharSet* char_set = cube_cntxt_->CharacterSet(); if (alt_list != NULL) { for (int i = 0; i < alt_list->AltCount(); ++i) { // Convert cube representation to a shape_id. int alt_id = alt_list->Alt(i); int unichar_id = char_set->UnicharID(char_set->ClassString(alt_id)); int shape_id = shape_table_.FindShape(unichar_id, -1); if (shape_id >= 0) results->push_back(ShapeRating(shape_id, alt_list->AltProb(i))); } delete alt_list; } delete cube_obj; return results->size(); }
// Normalize the weights of all the samples in the charset_map so they sum // to 1. Returns the minimum assigned sample weight. double SampleIterator::NormalizeSamples() { double total_weight = 0.0; int sample_count = 0; for (Begin(); !AtEnd(); Next()) { const TrainingSample& sample = GetSample(); total_weight += sample.weight(); ++sample_count; } // Normalize samples. double min_assigned_sample_weight = 1.0; if (total_weight > 0.0) { for (Begin(); !AtEnd(); Next()) { TrainingSample* sample = MutableSample(); double weight = sample->weight() / total_weight; if (weight < min_assigned_sample_weight) min_assigned_sample_weight = weight; sample->set_weight(weight); } } return min_assigned_sample_weight; }
// Replicates the samples to a minimum frequency defined by // 2 * kSampleRandomSize, or for larger counts duplicates all samples. // After replication, the replicated samples are perturbed slightly, but // in a predictable and repeatable way. // Use after OrganizeByFontAndClass(). void TrainingSampleSet::ReplicateAndRandomizeSamples() { ASSERT_HOST(font_class_array_ != NULL); int font_size = font_id_map_.CompactSize(); for (int font_index = 0; font_index < font_size; ++font_index) { for (int c = 0; c < unicharset_size_; ++c) { FontClassInfo &fcinfo = (*font_class_array_)(font_index, c); int sample_count = fcinfo.samples.size(); int min_samples = 2 * MAX(kSampleRandomSize, sample_count); if (sample_count > 0 && sample_count < min_samples) { int base_count = sample_count; for (int base_index = 0; sample_count < min_samples; ++sample_count) { int src_index = fcinfo.samples[base_index++]; if (base_index >= base_count) base_index = 0; TrainingSample *sample = samples_[src_index]->RandomizedCopy( sample_count % kSampleRandomSize); int sample_index = samples_.size(); sample->set_sample_index(sample_index); samples_.push_back(sample); fcinfo.samples.push_back(sample_index); } } } } }
// Classifies the given [training] sample, writing to results. // See ShapeClassifier for a full description. int CubeTessClassifier::ClassifySample(const TrainingSample& sample, Pix* page_pix, int debug, int keep_this, GenericVector<ShapeRating>* results) { int num_results = pruner_->ClassifySample(sample, page_pix, debug, keep_this, results); if (page_pix == NULL) return num_results; ASSERT_HOST(cube_cntxt_ != NULL); const TBOX& char_box = sample.bounding_box(); CubeObject* cube_obj = new tesseract::CubeObject( cube_cntxt_, page_pix, char_box.left(), pixGetHeight(page_pix) - char_box.top(), char_box.width(), char_box.height()); CharAltList* alt_list = cube_obj->RecognizeChar(); CharSet* char_set = cube_cntxt_->CharacterSet(); if (alt_list != NULL) { for (int r = 0; r < num_results; ++r) { const Shape& shape = shape_table_.GetShape((*results)[r].shape_id); // Get the best cube probability of all unichars in the shape. double best_prob = 0.0; for (int i = 0; i < alt_list->AltCount(); ++i) { int alt_id = alt_list->Alt(i); int unichar_id = char_set->UnicharID(char_set->ClassString(alt_id)); if (shape.ContainsUnichar(unichar_id) && alt_list->AltProb(i) > best_prob) { best_prob = alt_list->AltProb(i); } } (*results)[r].rating = best_prob; } delete alt_list; // Re-sort by rating. results->sort(&ShapeRating::SortDescendingRating); } delete cube_obj; return results->size(); }
// Visual debugger classifies the given sample, displays the results and // solicits user input to display other classifications. Returns when // the user has finished with debugging the sample. // Probably doesn't need to be overridden if the subclass provides // DisplayClassifyAs. void ShapeClassifier::DebugDisplay(const TrainingSample& sample, Pix* page_pix, UNICHAR_ID unichar_id) { #ifndef GRAPHICS_DISABLED static ScrollView* terminator = NULL; if (terminator == NULL) { terminator = new ScrollView("XIT", 0, 0, 50, 50, 50, 50, true); } ScrollView* debug_win = CreateFeatureSpaceWindow("ClassifierDebug", 0, 0); // Provide a right-click menu to choose the class. SVMenuNode* popup_menu = new SVMenuNode(); popup_menu->AddChild("Choose class to debug", 0, "x", "Class to debug"); popup_menu->BuildMenu(debug_win, false); // Display the features in green. const INT_FEATURE_STRUCT* features = sample.features(); int num_features = sample.num_features(); for (int f = 0; f < num_features; ++f) { RenderIntFeature(debug_win, &features[f], ScrollView::GREEN); } debug_win->Update(); GenericVector<UnicharRating> results; // Debug classification until the user quits. 
const UNICHARSET& unicharset = GetUnicharset(); SVEvent* ev; SVEventType ev_type; do { PointerVector<ScrollView> windows; if (unichar_id >= 0) { tprintf("Debugging class %d = %s\n", unichar_id, unicharset.id_to_unichar(unichar_id)); UnicharClassifySample(sample, page_pix, 1, unichar_id, &results); DisplayClassifyAs(sample, page_pix, unichar_id, 1, &windows); } else { tprintf("Invalid unichar_id: %d\n", unichar_id); UnicharClassifySample(sample, page_pix, 1, -1, &results); } if (unichar_id >= 0) { tprintf("Debugged class %d = %s\n", unichar_id, unicharset.id_to_unichar(unichar_id)); } tprintf("Right-click in ClassifierDebug window to choose debug class,"); tprintf(" Left-click or close window to quit...\n"); UNICHAR_ID old_unichar_id; do { old_unichar_id = unichar_id; ev = debug_win->AwaitEvent(SVET_ANY); ev_type = ev->type; if (ev_type == SVET_POPUP) { if (unicharset.contains_unichar(ev->parameter)) { unichar_id = unicharset.unichar_to_id(ev->parameter); } else { tprintf("Char class '%s' not found in unicharset", ev->parameter); } } delete ev; } while (unichar_id == old_unichar_id && ev_type != SVET_CLICK && ev_type != SVET_DESTROY); } while (ev_type != SVET_CLICK && ev_type != SVET_DESTROY); delete debug_win; #endif // GRAPHICS_DISABLED }
// Apply the supplied feature_space/feature_map transform to all samples // accessed by this iterator. void SampleIterator::MapSampleFeatures(const IntFeatureMap& feature_map) { for (Begin(); !AtEnd(); Next()) { TrainingSample* sample = MutableSample(); sample->MapFeatures(feature_map); } }
// Returns a string debug representation of the given sample: // font, unichar_str, bounding box, page. STRING TrainingSampleSet::SampleToString(const TrainingSample &sample) const { STRING boxfile_str; MakeBoxFileStr(unicharset_.id_to_unichar(sample.class_id()), sample.bounding_box(), sample.page_num(), &boxfile_str); return STRING(fontinfo_table_.get(sample.font_id()).name) + " " + boxfile_str; }