示例#1
0
void Consensus::findConsensus(const vector<Point2f> & points, const vector<int> & classes,
        const float scale, const float rotation,
        Point2f & center, vector<Point2f> & points_inlier, vector<int> & classes_inlier)
{
    FILE_LOG(logDEBUG) << "Consensus::findConsensus() call";

    //If no points are available, reteurn nan
    if (points.size() == 0)
    {
        center.x = numeric_limits<float>::quiet_NaN();
        center.y = numeric_limits<float>::quiet_NaN();

        FILE_LOG(logDEBUG) << "Consensus::findConsensus() return";

        return;
    }

    //Compute votes
    vector<Point2f> votes(points.size());
    for (size_t i = 0; i < points.size(); i++)
    {
        votes[i] = points[i] - scale * rotate(points_normalized[classes[i]], rotation);
    }

    t_index N = points.size();

    float * D = new float[N*(N-1)/2]; //This is a lot of memory, so we put it on the heap
    cluster_result Z(N-1);

    //Compute pairwise distances between votes
    int index = 0;
    for (size_t i = 0; i < points.size(); i++)
    {
        for (size_t j = i+1; j < points.size(); j++)
        {
            //TODO: This index calculation is correct, but is it a good thing?
            //int index = i * (points.size() - 1) - (i*i + i) / 2 + j - 1;
            D[index] = norm(votes[i] - votes[j]);
            index++;
        }
    }

    FILE_LOG(logDEBUG) << "Consensus::MST_linkage_core() call";
    MST_linkage_core(N,D,Z);
    FILE_LOG(logDEBUG) << "Consensus::MST_linkage_core() return";

    union_find nodes(N);

    //Sort linkage by distance ascending
    std::stable_sort(Z[0], Z[N-1]);

    //S are cluster sizes
	int* S = (int*)alloca((2*N-1)*sizeof(int));
    //TODO: Why does this loop go to 2*N-1? Shouldn't it be simply N? Everything > N gets overwritten later
    for(int i = 0; i < 2*N-1; i++)
    {
        S[i] = 1;
    }

    t_index parent = 0; //After the loop ends, parent contains the index of the last cluster
    for (node const * NN=Z[0]; NN!=Z[N-1]; ++NN)
    {
        // Get two data points whose clusters are merged in step i.
        // Find the cluster identifiers for these points.
        t_index node1 = nodes.Find(NN->node1);
        t_index node2 = nodes.Find(NN->node2);

        // Merge the nodes in the union-find data structure by making them
        // children of a new node
        // if the distance is appropriate
        if (NN->dist < thr_cutoff)
        {
            parent = nodes.Union(node1, node2);
            S[parent] = S[node1] + S[node2];
        }
    }

    //Get cluster labels
    int* T = (int*)alloca(N*sizeof(int));
    for (t_index i = 0; i < N; i++)
    {
        T[i] = nodes.Find(i);
    }

    //Find largest cluster
    int S_max = distance(S, max_element(S, S + 2*N-1));

    //Find inliers, compute center of votes
    points_inlier.reserve(S[S_max]);
    classes_inlier.reserve(S[S_max]);
    center.x = center.y = 0;

    for (size_t i = 0; i < points.size(); i++)
    {
        //If point is in consensus cluster
        if (T[i] == S_max)
        {
            points_inlier.push_back(points[i]);
            classes_inlier.push_back(classes[i]);
            center.x += votes[i].x;
            center.y += votes[i].y;
        }

    }

    center.x /= points_inlier.size();
    center.y /= points_inlier.size();

    delete[] D;

    FILE_LOG(logDEBUG) << "Consensus::findConsensus() return";
}
  /*
    R entry point: hierarchical clustering of a condensed dissimilarity
    vector.

    N_        single integer, the number of data points (at least 2).
    method_   single integer, 1-based index of the linkage method; after
              subtracting 1 it must lie in METHOD_METR_SINGLE ..
              METHOD_METR_MEDIAN.
    D_        dissimilarities, coerced with AS_NUMERIC; must have length
              N*(N-1)/2.
    members_  number of members of each node, coerced with AS_NUMERIC and of
              length N. Only read for the "average", "ward" and "centroid"
              methods. If it is NULL, every node gets one member.

    Returns a list with the fields
      merge   (N-1) x 2 integer matrix,
      height  numeric vector of length N-1,
      order   integer vector of length N,
    filled in by generate_R_dendrogram.

    All failures are reported to R through Rf_error, which does not return.
    Input errors are first thrown as C++ exceptions, so the C++ work arrays
    (members, D__) are freed by normal stack unwinding before R takes over.
    Exception messages are passed to Rf_error through a "%s" format, so a '%'
    inside a message is printed verbatim instead of being read as a format
    directive.
  */
  SEXP fastcluster(SEXP const N_, SEXP const method_, SEXP D_, SEXP members_) {
    SEXP r = NULL; // return value

    try{
      /*
        Input checks
        (thrown as exceptions, not passed to Rf_error directly, so that the
        auto_array_ptr buffers below are released by their destructors)
      */
      // Parameter N: number of data points
      PROTECT(N_);
      if (!IS_INTEGER(N_) || LENGTH(N_)!=1)
        throw std::runtime_error("'N' must be a single integer.");
      const int N = *INTEGER_POINTER(N_);
      if (N<2)
        throw std::runtime_error("N must be at least 2.");
      const std::ptrdiff_t NN = static_cast<std::ptrdiff_t>(N)*(N-1)/2;
      UNPROTECT(1); // N_

      // Parameter method: dissimilarity index update method
      PROTECT(method_);
      if (!IS_INTEGER(method_) || LENGTH(method_)!=1)
        throw std::runtime_error("'method' must be a single integer.");
      const int method = *INTEGER_POINTER(method_) - 1; // index-0 based;
      if (method<METHOD_METR_SINGLE || method>METHOD_METR_MEDIAN) {
        throw std::runtime_error("Invalid method index.");
      }
      UNPROTECT(1); // method_

      // Parameter members: number of members in each node
      auto_array_ptr<t_float> members;
      if (method==METHOD_METR_AVERAGE ||
          method==METHOD_METR_WARD ||
          method==METHOD_METR_CENTROID) {
        members.init(N);
        if (Rf_isNull(members_)) {
          for (t_index i=0; i<N; ++i) members[i] = 1;
        }
        else {
          PROTECT(members_ = AS_NUMERIC(members_));
          if (LENGTH(members_)!=N)
            throw std::runtime_error("'members' must have length N.");
          const t_float * const m = NUMERIC_POINTER(members_);
          for (t_index i=0; i<N; ++i) members[i] = m[i];
          UNPROTECT(1); // members
        }
      }

      // Parameter D_: dissimilarity matrix
      PROTECT(D_ = AS_NUMERIC(D_));
      if (LENGTH(D_)!=NN)
        throw std::runtime_error("'D' must have length (N \\choose 2).");
      const double * const D = NUMERIC_POINTER(D_);
      // Make a working copy of the dissimilarity array
      // for all methods except "single".
      auto_array_ptr<double> D__;
      if (method!=METHOD_METR_SINGLE) {
        D__.init(NN);
        for (std::ptrdiff_t i=0; i<NN; ++i)
          D__[i] = D[i];
      }
      UNPROTECT(1); // D_
      // NOTE(review): for "single", D still points into D_ after this
      // UNPROTECT. That is safe only as long as no R allocation happens
      // before MST_linkage_core has finished reading it.

      /*
        Clustering step
      */
      cluster_result Z2(N-1);
      switch (method) {
      case METHOD_METR_SINGLE:
        MST_linkage_core(N, D, Z2);
        break;
      case METHOD_METR_COMPLETE:
        NN_chain_core<METHOD_METR_COMPLETE, t_float>(N, D__, NULL, Z2);
        break;
      case METHOD_METR_AVERAGE:
        NN_chain_core<METHOD_METR_AVERAGE, t_float>(N, D__, members, Z2);
        break;
      case METHOD_METR_WEIGHTED:
        NN_chain_core<METHOD_METR_WEIGHTED, t_float>(N, D__, NULL, Z2);
        break;
      case METHOD_METR_WARD:
        NN_chain_core<METHOD_METR_WARD, t_float>(N, D__, members, Z2);
        break;
      case METHOD_METR_CENTROID:
        generic_linkage<METHOD_METR_CENTROID, t_float>(N, D__, members, Z2);
        break;
      case METHOD_METR_MEDIAN:
        generic_linkage<METHOD_METR_MEDIAN, t_float>(N, D__, NULL, Z2);
        break;
      default:
        throw std::runtime_error(std::string("Invalid method."));
      }

      D__.free();     // Free the memory now
      members.free(); // (not strictly necessary).

      SEXP m; // return field "merge"
      PROTECT(m = NEW_INTEGER(2*(N-1)));
      int * const merge = INTEGER_POINTER(m);

      SEXP dim_m; // Specify that m is an (N-1)×2 matrix
      PROTECT(dim_m = NEW_INTEGER(2));
      INTEGER(dim_m)[0] = N-1;
      INTEGER(dim_m)[1] = 2;
      SET_DIM(m, dim_m);

      SEXP h; // return field "height"
      PROTECT(h = NEW_NUMERIC(N-1));
      double * const height = NUMERIC_POINTER(h);

      SEXP o; // return field "order"
      PROTECT(o = NEW_INTEGER(N));
      int * const order = INTEGER_POINTER(o);

      if (method==METHOD_METR_CENTROID ||
          method==METHOD_METR_MEDIAN)
        generate_R_dendrogram<true>(merge, height, order, Z2, N);
      else
        generate_R_dendrogram<false>(merge, height, order, Z2, N);

      SEXP n; // names
      PROTECT(n = NEW_CHARACTER(3));
      SET_STRING_ELT(n, 0, COPY_TO_USER_STRING("merge"));
      SET_STRING_ELT(n, 1, COPY_TO_USER_STRING("height"));
      SET_STRING_ELT(n, 2, COPY_TO_USER_STRING("order"));

      PROTECT(r = NEW_LIST(3)); // field names in the output list
      SET_ELEMENT(r, 0, m);
      SET_ELEMENT(r, 1, h);
      SET_ELEMENT(r, 2, o);
      SET_NAMES(r, n);

      UNPROTECT(6); // m, dim_m, h, o, n, r
    } // try
    catch (const std::bad_alloc&) {
      Rf_error( "Memory overflow.");
    }
    catch(const std::exception& e){
      // Never use an arbitrary message as the printf-style format string.
      Rf_error( "%s", e.what() );
    }
    catch(const nan_error&){
      Rf_error("NaN dissimilarity value.");
    }
    #ifdef FE_INVALID
    catch(const fenv_error&){
      Rf_error( "NaN dissimilarity value in intermediate results.");
    }
    #endif
    catch(...){
      Rf_error( "C++ exception (unknown reason)." );
    }

    return r;
  }