Exemplo n.º 1
0
void validate(const char *fileName, const int K, const int dim, const int method, const T maxRadius)
{
    typedef Nabo::NearestNeighbourSearch<T, CloudType> NNS;
    typedef vector<NNS*> NNSV;
    typedef typename NNS::Matrix Matrix;
    typedef typename NNS::Vector Vector;
    typedef typename NNS::IndexMatrix IndexMatrix;

    Loader<T, CloudType> loader;
    loader.loadMatrix(fileName);

    // check if file is ok
    const CloudType d = loader.getValue();
    if (d.rows() != dim)
    {
        cerr << "Provided data has " << d.rows() << " dimensions, but the requested dimensions were " << dim << endl;
        exit(2);
    }
    if (K >= d.cols())
    {
        cerr << "Requested more nearest neighbour than points in the data set" << endl;
        exit(2);
    }

    // create different methods
    NNSV nnss;
    unsigned searchTypeCount(NNS::SEARCH_TYPE_COUNT);
#ifndef HAVE_OPENCL
    searchTypeCount -= 3;
#endif // HAVE_OPENCL
    for (unsigned i = 0; i < searchTypeCount; ++i)
        nnss.push_back(NNS::create(d, d.rows(), typename NNS::SearchType(i)));
    //nnss.push_back(new KDTreeBalancedPtInLeavesStack<T>(d, false));


    // check methods together
    const int itCount(method != -1 ? method : d.cols() * 2);

    /*
    // element-by-element search
    typedef typename NNS::IndexVector IndexVector;
    for (int i = 0; i < itCount; ++i)
    {
    	const Vector q(createQuery<T>(d, *nnss[0], i, method));
    	const IndexVector indexes_bf(nnss[0]->knn(q, K, 0, NNS::SORT_RESULTS));
    	for (size_t j = 1; j < nnss.size(); ++j)
    	{
    		const IndexVector indexes_kdtree(nnss[j]->knn(q, K, 0, NNS::SORT_RESULTS));
    		if (indexes_bf.size() != K)
    		{
    			cerr << "Different number of points found between brute force and request" << endl;
    			exit(3);
    		}
    		if (indexes_bf.size() != indexes_kdtree.size())
    		{
    			cerr << "Different number of points found between brute force and NNS type "<< j  << endl;
    			exit(3);
    		}
    		for (size_t k = 0; k < size_t(K); ++k)
    		{
    			Vector pbf(d.col(indexes_bf[k]));
    			//cerr << indexes_kdtree[k] << endl;
    			Vector pkdtree(d.col(indexes_kdtree[k]));
    			if (fabsf((pbf-q).squaredNorm() - (pkdtree-q).squaredNorm()) >= numeric_limits<float>::epsilon())
    			{
    				cerr << "Method " << j << ", cloud point " << i << ", neighbour " << k << " of " << K << " is different between bf and kdtree (dist " << (pbf-pkdtree).norm() << ")\n";
    				cerr << "* query:\n";
    				cerr << q << "\n";
    				cerr << "* indexes " << indexes_bf[k] << " (bf) vs " <<  indexes_kdtree[k] << " (kdtree)\n";
    				cerr << "* coordinates:\n";
    				cerr << "bf: (dist " << (pbf-q).norm() << ")\n";
    				cerr << pbf << "\n";
    				cerr << "kdtree (dist " << (pkdtree-q).norm() << ")\n";
    				cerr << pkdtree << endl;
    				exit(4);
    			}
    		}
    	}
    }
    */
    // create big query
    // check all-in-one query
    Matrix q(createQuery<T>(d, itCount, method));
    IndexMatrix indexes_bf(K, q.cols());
    Matrix dists2_bf(K, q.cols());
    nnss[0]->knn(q, indexes_bf, dists2_bf, K, 0, NNS::SORT_RESULTS, maxRadius);
    assert(indexes_bf.cols() == q.cols());
    for (size_t j = 1; j < nnss.size(); ++j)
    {
        IndexMatrix indexes_kdtree(K, q.cols());
        Matrix dists2_kdtree(K, q.cols());
        nnss[j]->knn(q, indexes_kdtree, dists2_kdtree, K, 0, NNS::SORT_RESULTS, maxRadius);
        if (indexes_bf.rows() != K)
        {
            cerr << "Different number of points found between brute force and request" << endl;
            exit(3);
        }
        if (indexes_bf.cols() != indexes_kdtree.cols())
        {
            cerr << "Different number of points found between brute force and NNS type "<< j  << endl;
            exit(3);
        }

        for (int i = 0; i < q.cols(); ++i)
        {
            for (size_t k = 0; k < size_t(K); ++k)
            {
                if (dists2_bf(k,i) == numeric_limits<float>::infinity())
                    continue;
                const int pbfi(indexes_bf(k,i));
                const Vector pbf(d.col(pbfi));
                const int pkdt(indexes_kdtree(k,i));
                if (pkdt < 0 || pkdt >= d.cols())
                {
                    cerr << "Method " << j << ", query point " << i << ", neighbour " << k << " of " << K << " has invalid index " << pkdt << " out of range [0:" << d.cols() << "[" << endl;
                    exit(4);
                }
                const Vector pkdtree(d.col(pkdt));
                const Vector pq(q.col(i));
                const float distDiff(fabsf((pbf-pq).squaredNorm() - (pkdtree-pq).squaredNorm()));
                if (distDiff > numeric_limits<float>::epsilon())
                {
                    cerr << "Method " << j << ", query point " << i << ", neighbour " << k << " of " << K << " is different between bf and kdtree (dist2 " << distDiff << ")\n";
                    cerr << "* query point:\n";
                    cerr << pq << "\n";
                    cerr << "* indexes " << pbfi << " (bf) vs " << pkdt << " (kdtree)\n";
                    cerr << "* coordinates:\n";
                    cerr << "bf: (dist " << (pbf-pq).norm() << ")\n";
                    cerr << pbf << "\n";
                    cerr << "kdtree (dist " << (pkdtree-pq).norm() << ")\n";
                    cerr << pkdtree << endl;
                    cerr << "* bf neighbours:\n";
                    for (int l = 0; l < K; ++l)
                        cerr << indexes_bf(l,i) << " (dist " << (d.col(indexes_bf(l,i))-pq).norm() << ")\n";
                    cerr << "* kdtree neighbours:\n";
                    for (int l = 0; l < K; ++l)
                        cerr << indexes_kdtree(l,i) << " (dist " << (d.col(indexes_kdtree(l,i))-pq).norm() << ")\n";
                    exit(4);
                }
            }
        }
    }

// 	cout << "\tstats kdtree: "
// 		<< kdt.getStatistics().totalVisitCount << " on "
// 		<< (long long)(itCount) * (long long)(d.cols()) << " ("
// 		<< (100. * double(kdt.getStatistics().totalVisitCount)) /  (double(itCount) * double(d.cols())) << " %"
// 		<< ")\n" << endl;

    // delete searches
    for (typename NNSV::iterator it(nnss.begin()); it != nnss.end(); ++it)
        delete (*it);
}