// Decode the label marginals for each candidate node (constituent), ignoring
// the tree constraint: the labels of every node form an independent softmax
// over their scores.
//
// On output:
//  - (*total_scores)[i] holds the LOG of the sum of exp-scores over the labels
//    of node i (the node's log-partition function). When null labels are
//    ignored, this sum also includes the implicit null label (score 0.0).
//  - (*label_marginals)[r - offset] holds the marginal probability of the
//    labeled-node part r, where offset is the index of the first labeled-node
//    part (as reported by GetOffsetNode).
// A message is logged for any node whose marginals (including the implicit
// null label, if any) do not sum to one within a tolerance of 1e-9.
void ConstituencyLabelerDecoder::DecodeLabelMarginals(
  Instance *instance, Parts *parts,
  const std::vector<double> &scores,
  std::vector<double> *total_scores,
  std::vector<double> *label_marginals) {
  ConstituencyLabelerInstanceNumeric *sentence =
    static_cast<ConstituencyLabelerInstanceNumeric*>(instance);
  ConstituencyLabelerParts *labeled_parts =
    static_cast<ConstituencyLabelerParts*>(parts);
  ConstituencyLabelerOptions *labeler_options =
    static_cast<ConstituencyLabelerOptions*>(pipe_->GetOptions());
  // When null labels are ignored there is no part for the null label; it is
  // accounted for implicitly as a label with score 0.0 (exp-score 1.0).
  const bool implicit_null_label = labeler_options->ignore_null_labels();
  const int num_nodes = sentence->GetNumConstituents();

  int offset_labeled_nodes, num_labeled_nodes;
  labeled_parts->GetOffsetNode(&offset_labeled_nodes, &num_labeled_nodes);
  total_scores->clear();
  total_scores->resize(num_nodes, 0.0);
  label_marginals->clear();
  label_marginals->resize(num_labeled_nodes, 0.0);

  for (int i = 0; i < num_nodes; ++i) {
    const std::vector<int> &index_node_parts =
      labeled_parts->FindNodeParts(i);

    // Partition function Z of node i: sum of exp-scores over its labels, plus
    // exp(0.0) = 1 for the implicit null label if there is one.
    LogValD total_score =
      implicit_null_label ? LogValD::One() : LogValD::Zero();
    for (int r : index_node_parts) {
      total_score += LogValD(scores[r], false);
    }
    (*total_scores)[i] = total_score.logabs();

    // Each label's marginal is exp(score) / Z. The implicit null label, if
    // any, contributes exp(0.0) / Z to the normalization check.
    double sum = implicit_null_label ? (1.0 / total_score.as_float()) : 0.0;
    for (int r : index_node_parts) {
      LogValD marginal = LogValD(scores[r], false) / total_score;
      const double marginal_value = marginal.as_float();
      (*label_marginals)[r - offset_labeled_nodes] = marginal_value;
      sum += marginal_value;
    }
    if (!NEARLY_EQ_TOL(sum, 1.0, 1e-9)) {
      LOG(INFO) << "Label marginals don't sum to one: sum = " << sum;
    }
  }
}
// Compute antecedent marginals, the log-partition function and the entropy of
// a coreference tree model in which each mention independently selects one of
// its candidate arcs with probability proportional to exp(score).
//
// On output:
//  - (*predicted_output)[r] holds the marginal probability of arc part r; all
//    other entries (parts not listed as arcs of any mention) stay at 0.0.
//  - *log_partition_function is sum_j log Z_j, where Z_j is the sum of the
//    exp-scores of mention j's candidate arcs.
//  - *entropy is sum_j (log Z_j - sum_r p_r * s_r), i.e. the entropy of this
//    per-mention factorized distribution.
// A message is logged for any mention whose arc marginals do not sum to one
// within a tolerance of 1e-9.
void CoreferenceDecoder::DecodeBasicMarginals(
    Instance *instance, Parts *parts,
    const std::vector<double> &scores,
    std::vector<double> *predicted_output,
    double *log_partition_function,
    double *entropy) {
  CoreferenceDocumentNumeric *document =
    static_cast<CoreferenceDocumentNumeric*>(instance);
  CoreferenceParts *coreference_parts = static_cast<CoreferenceParts*>(parts);

  predicted_output->clear();
  predicted_output->resize(parts->size(), 0.0);

  *log_partition_function = 0.0;
  *entropy = 0.0;
  const std::vector<Mention*> &mentions = document->GetMentions();
  const int num_mentions = static_cast<int>(mentions.size());
  for (int j = 0; j < num_mentions; ++j) {
    // Candidate antecedent arcs of mention j; their probabilities are a
    // softmax over the arc scores.
    const std::vector<int> &arcs = coreference_parts->FindArcParts(j);

    // Local partition function Z_j = sum_r exp(s_r).
    LogValD total_score = LogValD::Zero();
    for (int r : arcs) {
      total_score += LogValD(scores[r], false);
    }
    *log_partition_function += total_score.logabs();

    double sum = 0.0;
    for (int r : arcs) {
      LogValD marginal = LogValD(scores[r], false) / total_score;
      const double marginal_value = marginal.as_float();
      (*predicted_output)[r] = marginal_value;
      // Accumulate -sum_r p_r * s_r. Arcs scored -inf have zero probability
      // and contribute nothing; they are skipped explicitly because
      // (-inf) * 0.0 would yield NaN.
      if (scores[r] != -std::numeric_limits<double>::infinity()) {
        *entropy -= scores[r] * marginal_value;
      } else {
        CHECK_EQ(marginal_value, 0.0);
      }
      sum += marginal_value;
    }
    if (!NEARLY_EQ_TOL(sum, 1.0, 1e-9)) {
      LOG(INFO) << "Antecedent marginals don't sum to one: sum = " << sum;
    }
  }

  // H = sum_j log Z_j - sum_j sum_r p_r * s_r.
  *entropy += *log_partition_function;
}