// Eight spheres of radius 2.4, one centred in each 5-unit cell of the
// [0,10]^3 region; a ray fired along +z through x = y = 7.5 must report the
// sphere centred at (7.5, 7.5, 2.5) as the first hit.
TEST(KDTree, shouldSearchInTree)
{
    KDTree tree;

    RTSphere sphere0(Vector(2.5, 2.5, 2.5), 2.4);
    RTSphere sphere1(Vector(2.5, 2.5, 7.5), 2.4);
    RTSphere sphere2(Vector(2.5, 7.5, 7.5), 2.4);
    RTSphere sphere3(Vector(2.5, 7.5, 2.5), 2.4);
    RTSphere sphere4(Vector(7.5, 2.5, 2.5), 2.4);
    RTSphere sphere5(Vector(7.5, 2.5, 7.5), 2.4);
    RTSphere sphere6(Vector(7.5, 7.5, 7.5), 2.4);
    RTSphere sphere7(Vector(7.5, 7.5, 2.5), 2.4);

    // Same insertion order as the sphere numbering.
    vector<RTShape*> shapes = {
        &sphere0, &sphere1, &sphere2, &sphere3,
        &sphere4, &sphere5, &sphere6, &sphere7
    };

    BoundingBox bounds(Vector(0,0,0), Vector(10,10,10));
    tree.setBoundingBox(bounds);
    tree.setTerminationCondition(1);
    tree.build(shapes, 0);

    Ray probe(Vector(7.5, 7.5, -2), Vector(0,0,1));
    IntersectionPtr hit = tree.intersect(probe);

    CHECK( hit != nullptr );
    CHECK( hit->getShape() == &sphere7 );
}
// Four unit spheres in the z = 0 plane, one per (+/-5, +/-5) quadrant.
// After two levels of splitting, each grandchild of the root must hold
// exactly one sphere.
TEST(KDTree, itSplitsXthenYthenZ)
{
    KDTree tree;

    RTSphere sphere0(Vector(5,5,0), 1);
    RTSphere sphere1(Vector(5,-5,0), 1);
    RTSphere sphere2(Vector(-5,5,0), 1);
    RTSphere sphere3(Vector(-5,-5,0), 1);

    vector<RTShape*> shapes = { &sphere0, &sphere1, &sphere2, &sphere3 };

    BoundingBox bounds(Vector(-6,-6,-6), Vector(12,12,12));
    tree.setBoundingBox(bounds);
    tree.setTerminationCondition(1);
    tree.build(shapes, 0);

    CHECK_EQUAL( 1, tree.getLeft()->getLeft()->size() );
    CHECK_EQUAL( 1, tree.getLeft()->getRight()->size() );
    CHECK_EQUAL( 1, tree.getRight()->getLeft()->size() );
    CHECK_EQUAL( 1, tree.getRight()->getRight()->size() );
}
// Four unit spheres at (+/-5, +/-5, 0). A ray starting at x = -10 and moving
// along +x at y = 5 must hit the sphere centred at (-5, 5, 0) before the one
// at (5, 5, 0).
TEST(KDTree, shouldSupportIntersectionSearchForRegularNodes)
{
    KDTree tree;

    RTSphere sphere0(Vector(5,5,0), 1);
    RTSphere sphere1(Vector(5,-5,0), 1);
    RTSphere sphere2(Vector(-5,5,0), 1);
    RTSphere sphere3(Vector(-5,-5,0), 1);

    vector<RTShape*> shapes = { &sphere0, &sphere1, &sphere2, &sphere3 };

    BoundingBox bounds(Vector(-6,-6,-6), Vector(12,12,12));
    tree.setBoundingBox(bounds);
    tree.setTerminationCondition(1);
    tree.build(shapes, 0);

    Ray probe(Vector(-10, 5, 0), Vector(1,0,0));
    IntersectionPtr hit = tree.intersect(probe);

    CHECK( hit != nullptr );
    CHECK( hit->getShape() == &sphere2 );
}
// One triangle with all-positive vertex coordinates and one with all-negative
// coordinates; building the tree must place one triangle in each child.
TEST(KDTree, shouldSupportTriangles)
{
    KDTree tree;

    RTTriangle positiveTri(Vector(0.1,0.1,0.1), Vector(1,1,1), Vector(0.1,2,0.1));
    RTTriangle negativeTri(Vector(-0.1,-0.1,-0.1), Vector(-1,-1,-1), Vector(-0.1,-2,-0.1));

    vector<RTShape*> shapes = { &positiveTri, &negativeTri };

    BoundingBox bounds(Vector(-1,-2,-1), Vector(2,4,2));
    tree.setBoundingBox(bounds);
    tree.build(shapes, 0);

    CHECK_EQUAL( 1, tree.getLeft()->size() );
    CHECK_EQUAL( 1, tree.getRight()->size() );
}
// A triangle in the positive octant and a unit sphere centred at (-1,-1,-1)
// go into the same tree; each child must end up with exactly one shape.
TEST(KDTree, mixingTrianglesAndSphereIsNotAProblem)
{
    KDTree tree;

    RTTriangle tri(Vector(0.1,0.1,0.1), Vector(1,1,1), Vector(0.1,2,0.1));
    RTSphere ball(Vector(-1,-1,-1), 1);

    vector<RTShape*> shapes = { &tri, &ball };

    BoundingBox bounds(Vector(-2,-2,-2), Vector(2,4,2));
    tree.setBoundingBox(bounds);
    tree.build(shapes, 0);

    CHECK_EQUAL( 1, tree.getLeft()->size() );
    CHECK_EQUAL( 1, tree.getRight()->size() );
}
// Three unit spheres along the x axis at x = -5, 0 and 5. The middle sphere
// straddles the split, so it is expected in both children, giving each child
// two shapes.
TEST(KDTree, shapesCanBeInTwoVoxelsIfOnEdge)
{
    KDTree tree;

    RTSphere leftBall(Vector(-5,0,0), 1);
    RTSphere rightBall(Vector(5,0,0), 1);
    RTSphere middleBall(Vector(0,0,0), 1);

    vector<RTShape*> shapes = { &leftBall, &rightBall, &middleBall };

    BoundingBox bounds(Vector(-6, -1, -1), Vector(12, 2, 2));
    tree.setBoundingBox(bounds);
    tree.setTerminationCondition(2);
    tree.build(shapes, 0);

    CHECK_EQUAL( 2, tree.getLeft()->size() );
    CHECK_EQUAL( 2, tree.getRight()->size() );
}
// Two triangles lying in the planes x = 0 (t1) and x = -1 (t2). A ray moving
// along +x from x = -5 at (y, z) = (0.1, 0.1) must hit t2 first.
TEST(KDTree, shouldSupportIntersectionSearchForChildNode)
{
    KDTree tree;
    vector<RTShape*> shapes;

    RTTriangle t1(Vector(0,0,0), Vector(0,1,0), Vector(0,0,1));
    RTTriangle t2(Vector(-1,0,0), Vector(-1,1,0), Vector(-1,0,1));
    shapes.push_back(&t1);
    shapes.push_back(&t2);

    BoundingBox box(Vector(0,0,0), Vector(1,0,1));
    tree.setBoundingBox(box);
    tree.setTerminationCondition(2);
    tree.build(shapes, 0);

    Ray ray(Vector(-5,0.1,0.1), Vector(1,0,0));
    IntersectionPtr intersection = tree.intersect(ray);

    // Same guard as the other intersection tests: a miss is reported as a
    // test failure instead of crashing on a null dereference below.
    CHECK( intersection != nullptr );
    CHECK( intersection->getShape() == &t2 );
}
// A single RTPolySet holding two distant triangles is handed to the tree. The
// tree must pull the triangles out of the poly set and place one in each child.
TEST(KDTree, shouldExtractTrianglesFromPolySet)
{
    KDTree tree;
    vector<RTShape*> shapes;

    RTPolySet polySet;
    RTTriangle nearOrigin(Vector(0,0,0), Vector(1,0,0), Vector(1,0,1));
    polySet.addTriangle(nearOrigin);
    RTTriangle farCorner(Vector(5,5,5), Vector(6,6,5), Vector(6,6,6));
    polySet.addTriangle(farCorner);

    shapes.push_back(&polySet);

    BoundingBox bounds(Vector(0,0,0), Vector(6,6,6));
    tree.setBoundingBox(bounds);
    tree.setTerminationCondition(2);
    tree.build(shapes, 0);

    // NOTE(review): the original author remarked "This should'a been one with
    // two." Presumably, with termination condition 2 and only two triangles,
    // a single leaf holding both was expected, but the tree currently splits
    // them. The assertions below pin the current behaviour; confirm intent.
    CHECK_EQUAL( 1, tree.getLeft()->size() );
    CHECK_EQUAL( 1, tree.getRight()->size() );
}
//=======================================================================//
/// Iteratively pushes the vectors in `vectors` apart from their nearest
/// neighbours, then re-normalises each one. The update is done in place.
///
/// Each of the `numIterations` passes does the following:
///  1. Builds a fresh KDTree from the current vector positions (split axes x, y, z).
///  2. For every vector, queries up to numVectors/100 nearest items.
///  3. Sums a repulsion direction from each neighbour (excluding itself), weighted by
///     (1 - squaredDistance / searchRadiusSq_).
///  4. Moves the vector along that sum by 0.1 x (distance to its closest neighbour).
///  5. Calls normalise() on the result.
///
/// @param vectors        Vectors to relax; modified in place.
/// @param numIterations  Number of full relaxation passes.
///
/// NOTE(review): the K passed to KNearestItems is numVectors/100 (integer
/// division), so it is 0 when numVectors < 100. The query result includes the
/// vector itself, which is skipped below. If KNearestItems honours K == 0 or
/// K == 1 literally, small inputs are never moved. Confirm this is intended.
void stitch::Vec3::relaxEquidistantVectorsII(std::vector<stitch::Vec3> &vectors, uint32_t numIterations)
{
    const uint32_t numVectors = vectors.size();
    
    for (uint32_t iterationsDone=0; iterationsDone<numIterations; ++iterationsDone)
    {
        // Diagnostics only: these are read solely by the commented-out
        // std::cout at the end of the iteration.
        float minResult=0.0f;
        float maxResult=0.0f;
        float sumResult=0.0f;
        
        KDTree kdTree;
        
        // Index every vector by its position at the start of this iteration.
        // The third argument is read back below as userIndex_.
        // NOTE(review): ownership of these heap allocations is not visible
        // here. Confirm KDTree deletes them; otherwise they leak every iteration.
        for (uint32_t acteeVectorNum=0; acteeVectorNum<numVectors; ++acteeVectorNum)
        {
            kdTree.addItem(new BoundingVolume(vectors[acteeVectorNum], 0.0f, acteeVectorNum));
        }
        
        // Split axes for the tree: x, y, z.
        std::vector<stitch::Vec3> splitAxisVec;
        splitAxisVec.push_back(Vec3(1.0f, 0.0f, 0.0f));
        splitAxisVec.push_back(Vec3(0.0f, 1.0f, 0.0f));
        splitAxisVec.push_back(Vec3(0.0f, 0.0f, 1.0f));
        kdTree.build(KDTREE_DEFAULT_CHUNK_SIZE, 0, 1000, splitAxisVec);
        
        for (uint32_t acteeVectorNum=0; acteeVectorNum<numVectors; ++acteeVectorNum)
        {
            stitch::Vec3 relaxDelta;//Initialised to zero in the constructor.
            float minActDistance=10.0;//Some large initial distance.
            
            // Presumably (centre, search radius, K). searchRadiusSq_ is read
            // back below. TODO confirm against KNearestItems.
            KNearestItems kNearestItems(vectors[acteeVectorNum], 1.0f, numVectors/100);
            kdTree.getNearestK(&kNearestItems);
            
            const uint32_t numNearestItems=kNearestItems.numItems_;
            
            for (uint32_t i=0; i<numNearestItems; ++i)
            {
                uint32_t actingVectorNum=kNearestItems.heapArray_[i].second->userIndex_;
                
                // The vector being relaxed is itself among the results; skip it.
                if (actingVectorNum!=acteeVectorNum)
                {
                    // Squared distance normalised by the squared search radius.
                    // Uses the tree's centre_, i.e. the neighbour's position
                    // when the tree was built this iteration.
                    const float actDistance=(kNearestItems.heapArray_[i].second->centre_-vectors[acteeVectorNum]).lengthSq() / kNearestItems.searchRadiusSq_;
                    const float actWeight=1.0f - actDistance;
                    
                    // NOTE(review): the direction uses vectors[actingVectorNum].
                    // Vectors with a lower index have already been moved during
                    // this pass, so it can differ from the centre_ used for the
                    // distance above.
                    stitch::Vec3 actNormal=(vectors[acteeVectorNum] - vectors[actingVectorNum]);
                    actNormal.normalise();
                    
                    relaxDelta+=actNormal * (actWeight);
                    
                    if (actDistance < minActDistance) minActDistance=actDistance;
                }
            }
            
            // Step size is 0.1 x the distance to the closest neighbour
            // (sqrt of the un-normalised squared distance). With no neighbours,
            // relaxDelta stays zero, so only the normalise() below has an effect.
            vectors[acteeVectorNum] += relaxDelta*sqrtf(minActDistance*kNearestItems.searchRadiusSq_)*0.1;//+Vec3::randNorm()*0.005*minActDistance;
            vectors[acteeVectorNum].normalise();
            
            // Running min/max/sum of minActDistance for the diagnostics above.
            if (acteeVectorNum>0)
            {
                sumResult+=minActDistance;
                if (minActDistance<minResult) minResult=minActDistance;
                if (minActDistance>maxResult) maxResult=minActDistance;
            } else
            {
                sumResult=minResult=maxResult=minActDistance;
            }
        }
        
        //std::cout << "relaxActingDistance = [" << minResult << "|" << (sumResult/numVectors) << "|" << maxResult << "]\n";
        //std::cout.flush();
    }
}
int main(int argc, char** argv) { std::ifstream imageFile; std::ifstream labelFile; if(argc != 5) { std::cerr << "Invalid command." << std::endl; return -1; } // Open training images and their labels for reading. imageFile.open(argv[1], std::ios::binary); labelFile.open(argv[2], std::ios::binary); labelFile.seekg(8); BitInputStream* input = new BitInputStream(imageFile); BitInputStream* label = new BitInputStream(labelFile); // Get the magic number, the rows, and the columns of each training image. int magic = input->readInt(); int total = input->readInt(); std::cerr << "Loading training images" << std::endl; std::cerr << "Total image: " << total << std::endl; int rows = input->readInt(); std::cerr << "Each image contains " << rows << " rows." << std::endl; int columns = input->readInt(); std::cerr << "Each image contains " << columns << " columns." << std::endl; // Load the training data to memory. Training* training = new Training(total, rows * columns); for(int i = 0; i < total; i++) { int lbl = (int)(label->readChar()); TRPoint* point = new TRPoint(lbl, rows * columns, i); for(int j = 0; j < (rows * columns); j++) { int pixel = (int)(input->readChar()); point->addPixel(pixel); } training->addElement(point); } imageFile.close(); labelFile.close(); std::ifstream testImageFile; std::ifstream testLabelFile; // Open testing images and their actual labels for reading. testImageFile.open(argv[3], std::ios::binary); testLabelFile.open(argv[4], std::ios::binary); testLabelFile.seekg(8); input = new BitInputStream(testImageFile); label = new BitInputStream(testLabelFile); magic = input->readInt(); total = input->readInt(); std::cerr << "Loading testing images" << std::endl; std::cerr << "Total image: " << total << std::endl; rows = input->readInt(); std::cerr << "Each image contains " << rows << " rows." << std::endl; columns = input->readInt(); std::cerr << "Each image contains " << columns << " columns." << std::endl; // Construct the K-D for nearest neighbor search. 
std::cerr << "Please enter the number of elements in each leaf for your K-D tree: "; std::string numImages; getline(std::cin, numImages); std::cerr << "Ok, at least " << atoi(numImages.c_str()) << " images." << std::endl; std::cerr << "Constructing K-D tree for training set." << std::endl; KDTree* tree = new KDTree(); tree->root = tree->build(tree->root, training, training->size(), atoi(numImages.c_str())); // Load the testing data to memory. Training* testing = new Training(total, rows * columns); for(int i = 0; i < total; i++) { int lbl = (int)(label->readChar()); TRPoint* point = new TRPoint(lbl, rows * columns); for(int j = 0; j < (rows * columns); j++) { int pixel = (int)(input->readChar()); point->addPixel(pixel); } testing->addElement(point); } testImageFile.close(); testLabelFile.close(); // Loading the actual true nearest neighbor of each testing images to memory. std::ifstream actualLabels; actualLabels.open("actual"); std::string in; int* trueLabels = new int[total]; int index = 0; for(int i = 0; getline(actualLabels, in); i++) trueLabels[i] = atoi(in.c_str()); int errors = 0; int notTrueNN = 0; // Perform the nearest neighbor search. for(int o = 0; o < testing->size(); o++) { std::cerr << "classifying image " << o+1 << std::endl; int currentDist = INT_MAX; int currentLabel = -1; int imageNumber = -1; // Find the leaf that contains the elements closest to the test point. TRPoint* testPoint = testing->getElement(o); Training* set = tree->find(testing->getElement(o))->set; // Compute the shortest distance from training images, update if needed. 
for(int i = 0; i < set->size(); i++) { TRPoint* trainPoint = set->getElement(i); int distance = 0; for(int j = 0; j < trainPoint->size; j++) { int difference = trainPoint->feature[j] - testPoint->feature[j]; if(difference < 0) difference = (-1) * difference; distance = distance + difference; } if(distance < currentDist) { currentDist = distance; currentLabel = trainPoint->label; imageNumber = trainPoint->index; } } std::cerr << "classified label: " << currentLabel << std::endl; std::cerr << "actual label: " << testPoint->label << std::endl; // There's an error if this image is not classified correctly. if(currentLabel != testPoint->label) { std::cerr << "image " << o+1 << " has an error" << std::endl; //std::cout << "image " << o+1 << " has an error" << std::endl; std::cerr << "classified with label " << currentLabel << std::endl; //std::cout << "classified with label " << currentLabel << std::endl; //printImage(training->getElement(imageNumber), columns, rows); std::cerr << "actual label is " << testPoint->label << std::endl; //std::cout << "actual label is " << testPoint->label << std::endl; //printImage(testPoint, columns, rows); errors++; } // Not a true nearest neighbor. if(imageNumber != trueLabels[o]) notTrueNN++; } // Print out the result here. std::cerr << "errors: " << errors << std::endl; std::cerr << "not true nearest neighbors: " << notTrueNN << std::endl; return 0; }