void MaximumCompositeLikelihood::UpdateTrainingLabeling( const std::vector<labeled_instance_type>& training_update) { assert(fg_orig_index.size() == comp_training_data.size()); // For all decomposed components for (unsigned int cn = 0; cn < fg_orig_index.size(); ++cn) { // Original factor graph index unsigned int n = fg_orig_index[cn]; assert(n < training_update.size()); FactorGraph* fg = training_update[n].first; const FactorGraphObservation* obs = training_update[n].second; size_t var_count = fg->Cardinalities().size(); // Update each component of the current decomposition for (unsigned int ci = 0; ci < fg_cc_count[cn]; ++ci) { std::vector<unsigned int> cond_var_set; cond_var_set.reserve(var_count); // Add all variables not in this component to the conditioning set for (size_t vi = 0; vi < var_count; ++vi) { if (fg_cc_var_label[cn][vi] != ci) cond_var_set.push_back(static_cast<unsigned int>(vi)); } UpdateTrainingComponentCond(fg, obs, cond_var_set, cn); } } // Update fully observed components mle.UpdateTrainingLabeling(comp_training_data); }
void MaximumCompositeLikelihood::SetupTrainingData( const std::vector<labeled_instance_type>& training_data, const std::vector<InferenceMethod*> inference_methods) { assert(comp_training_data.size() == 0); assert(comp_inference_methods.size() == 0); assert(inference_methods.size() == training_data.size()); // Number of times each component will be covered unsigned int cover_count = 1; assert(decomp >= -1); if (decomp == DecomposePseudolikelihood) { cover_count = 1; } else if (decomp > 0) { cover_count = decomp; } // Produce composite factor graphs boost::timer decomp_timer; int training_data_size = static_cast<int>(training_data.size()); fg_cc_var_label.resize(cover_count * training_data_size); fg_cc_count.resize(cover_count * training_data_size); fg_orig_index.resize(cover_count * training_data_size); std::fill(fg_cc_count.begin(), fg_cc_count.end(), 0); unsigned int cn = 0; for (int n = 0; n < training_data_size; ++n) { FactorGraph* fg = training_data[n].first; size_t var_count = fg->Cardinalities().size(); // Get observation const FactorGraphObservation* obs = training_data[n].second; // Obtain one or more decomposition(s) for (unsigned int cover_iter = 0; cover_iter < cover_count; ++cover_iter) { VAcyclicDecomposition vac(fg); std::vector<bool> factor_is_removed; if (decomp == DecomposePseudolikelihood) { factor_is_removed.resize(fg->Factors().size()); std::fill(factor_is_removed.begin(), factor_is_removed.end(), true); } else { std::vector<double> factor_weight(fg->Factors().size(), 0.0); if (decomp == DecomposeUniform) { // Use constant weights std::fill(factor_weight.begin(), factor_weight.end(), 1.0); } else { // Use uniform random weights boost::uniform_real<double> uniform_dist(0.0, 1.0); boost::variate_generator<boost::mt19937&, boost::uniform_real<double> > rgen(RandomSource::GlobalRandomSampler(), uniform_dist); for (unsigned int fi = 0; fi < factor_weight.size(); ++fi) factor_weight[fi] = rgen(); } vac.ComputeDecompositionSP(factor_weight, factor_is_removed); } // Shatter factor graph into trees fg_cc_count[cn] += FactorGraphStructurizer::ConnectedComponents( fg, factor_is_removed, fg_cc_var_label[cn]); #if 0 std::cout << "MCL, instance " << n << " decomposed into " << cc_count << " components" << std::endl; #endif // Add each component as separate factor graph for (unsigned int ci = 0; ci < fg_cc_count[cn]; ++ci) { std::vector<unsigned int> cond_var_set; cond_var_set.reserve(var_count); // Add all variables not in this component to the conditioning set for (size_t vi = 0; vi < var_count; ++vi) { if (fg_cc_var_label[cn][vi] != ci) cond_var_set.push_back(static_cast<unsigned int>(vi)); } AddTrainingComponentCond(fg, obs, inference_methods[n], cond_var_set); } fg_orig_index[cn] = n; cn += 1; } } std::cout << "MCL, decomposed " << training_data.size() << " instances " << "into " << comp_training_data.size() << " instances " << (decomp == DecomposeUniform ? "(uniform)" : "(randomized)") << " in " << decomp_timer.elapsed() << "s." << std::endl; // Initialize MLE training data from created components SetupMLETrainingData(); }