Esempio n. 1
0
void trainCard(const Image<PixRGB<byte> > &img, const std::string &cardName)
{
  rutz::shared_ptr<VisualObject>
    vo(new VisualObject(cardName, "NULL", img,
          Point2D<int>(-1,-1),
          std::vector<double>(),
          std::vector< rutz::shared_ptr<Keypoint> >(),
          USECOLOR));

  itsObjectDB.addObject(vo, false);

  itsObjectDB.saveTo("cards.vdb");
}
Esempio n. 2
0
int main(const int argc, const char **argv)
{

  MYLOGVERB = LOG_INFO;
  mgr = new ModelManager("Test ObjRec");

  if (mgr->parseCommandLine(
        (const int)argc, (const char**)argv, "<vdb file> <server ip>", 2, 2) == false)
    return 1;

  mgr->start();

  // catch signals and redirect them to terminate for clean exit:
  signal(SIGHUP, terminateProc); signal(SIGINT, terminateProc);
  signal(SIGQUIT, terminateProc); signal(SIGTERM, terminateProc);
  signal(SIGALRM, terminateProc);

  //get command line options
  const char *vdbFile = mgr->getExtraArg(0).c_str();
  const char *server_ip = mgr->getExtraArg(1).c_str();
  bool train = false;

  LINFO("Loading db from %s\n", vdbFile);
  //vdb.loadFrom(std::string(vdbFile));

  xwin  = new XWinManaged(Dims(256,256),
      -1, -1, "ILab Robot Head Demo");


   labelServer =
    nv2_label_server_create(9930,
        server_ip,
        9931);

  nv2_label_server_set_verbosity(labelServer,1); //allow warnings


  int send_interval = 1;

  while(!terminate)
  {

    Point2D clickLoc = xwin->getLastMouseClick();
    if (clickLoc.isValid())
      train = !train;

    struct nv2_image_patch p;
    const enum nv2_image_patch_result res =
      nv2_label_server_get_current_patch(labelServer, &p);

    std::string objName;
    if (res == NV2_IMAGE_PATCH_END)
    {
      fprintf(stdout, "ok, quitting\n");
      break;
    }
    else if (res == NV2_IMAGE_PATCH_NONE)
    {
      usleep(10000);
      continue;
    }
    else if (res == NV2_IMAGE_PATCH_VALID &&
        p.type == NV2_PIXEL_TYPE_RGB24)
    {
      Image<PixRGB<byte> > img(p.width, p.height, NO_INIT);
      memcpy(img.getArrayPtr(), p.data, p.width*p.height*3);

      Image<PixRGB<byte> > inputImg = rescale(img, 256, 256);

      std::string objName = matchObject(inputImg);

      Image<PixRGB<byte> > disp(320, 240, ZEROS);

      xwin->drawImage(inputImg);

      if (objName == "nomatch")
      {
        if (train)
        {
          printf("Is this %s\n", objName.c_str());
          std::string tmp;
          std::getline(std::cin, tmp);
          if (tmp == "exit") break;
          if (tmp == "no")
          {
            printf("Can you tell me what this is?\n");
            std::getline(std::cin, objName);

            rutz::shared_ptr<VisualObject>
              vo(new VisualObject(objName.c_str(), "NULL", inputImg,
                    Point2D(-1,-1),
                    std::vector<double>(),
                    std::vector< rutz::shared_ptr<Keypoint> >(),
                    USECOLOR));
            vdb.addObject(vo);
            vdb.saveTo(vdbFile);
          }
        }
      } else {
        printf("Object is %s\n", objName.c_str());
        struct nv2_patch_label l;
        l.protocol_version = NV2_LABEL_PROTOCOL_VERSION;
        l.patch_id = p.id;
        snprintf(l.source, sizeof(l.source), "%s",
            "ObjRec");
        snprintf(l.name, sizeof(l.name), "%s", // (%ux%u #%u)",
            objName.c_str());
        //(unsigned int) p.width,
        //(unsigned int) p.height,
        //(unsigned int) p.id);
        snprintf(l.extra_info, sizeof(l.extra_info),
            "auxiliary information");

        if (l.patch_id % send_interval == 0)
        {
          nv2_label_server_send_label(labelServer, &l);

          fprintf(stdout, "sent label '%s (%s)'\n",
              l.name, l.extra_info);
        }
        else
        {
          fprintf(stdout, "DROPPED label '%s (%s)'\n",
              l.name, l.extra_info);
        }
      }

      nv2_image_patch_destroy(&p);
    }

  }
  nv2_label_server_destroy(labelServer);

}
Esempio n. 3
0
/*
XWinManaged xwin(Dims(WIDTH,HEIGHT*2), 1, 1, "Test SIFT");


rutz::shared_ptr<VisualObject> objTop, objBottom;

void showObjs(rutz::shared_ptr<VisualObject> obj1, rutz::shared_ptr<VisualObject> obj2){
        //return ;

        Image<PixRGB<byte> > keyIma = rescale(obj1->getKeypointImage(),
                        WIDTH, HEIGHT);
        objTop = obj1;

        if (obj2.is_valid()){
                keyIma = concatY(keyIma, rescale(obj2->getKeypointImage(),
                                        WIDTH, HEIGHT));
                objBottom = obj2;
        }

        xwin.drawImage(keyIma);
}

void showKeypoint(rutz::shared_ptr<VisualObject> obj, int keypi,
                Keypoint::CHANNEL channel = Keypoint::ORI){

        char winTitle[255];
        switch(channel){
                case Keypoint::ORI:
                        sprintf(winTitle, "Keypoint view (Channel ORI)");
                        break;
                case Keypoint::COL:
                        sprintf(winTitle, "Keypoint view (Channel COL)");
         break;
                default:
                        sprintf(winTitle, "Keypoint view (Channel   )");
                        break;
        }


        rutz::shared_ptr<Keypoint> keyp = obj->getKeypoint(keypi);
        float x = keyp->getX();
        float y = keyp->getY();
        float s = keyp->getS();
        float o = keyp->getO();
        float m = keyp->getM();

        uint FVlength = keyp->getFVlength(channel);
        if (FVlength<=0) return; //dont show the Keypoint if we dont have a FV

        XWinManaged *xwinKey = new XWinManaged(Dims(WIDTH*2,HEIGHT), -1, -1, winTitle);


        //draw the circle around the keypoint
        const float sigma = 1.6F * powf(2.0F, s / float(6 - 3));
        const float sig = 1.5F * sigma;
        const int rad = int(3.0F * sig);

        Image<PixRGB<byte> > img = obj->getImage();
        Point2D<int> loc(int(x + 0.5F), int(y + 0.5F));
        drawCircle(img, loc, rad, PixRGB<byte>(255, 0, 0));
        drawDisk(img, loc, 2, PixRGB<byte>(255,0,0));

        s=s*5.0F; //mag for scale
        if (s > 0.0f) drawLine(img, loc,
                        Point2D<int>(int(x + s * cosf(o)  + 0.5F),
                                int(y + s * sinf(o) + 0.5F)),
                        PixRGB<byte>(255, 0, 0));

        char info[255];
        sprintf(info, "(%0.2f,%0.2f) s=%0.2f o=%0.2f m=%0.2f", x, y, s, o, m);

        writeText(img, Point2D<int>(0, HEIGHT-20), info,
                        PixRGB<byte>(255), PixRGB<byte>(127));


        //draw the vectors from the features vectors

        Image<PixRGB<byte> > fvDisp(WIDTH, HEIGHT, NO_INIT);
        fvDisp.clear(PixRGB<byte>(255, 255, 255));
        int xBins = int((float)WIDTH/4);
        int yBins = int((float)HEIGHT/4);

        drawGrid(fvDisp, xBins, yBins, 1, 1, PixRGB<byte>(0, 0, 0));



        switch (channel){
                case Keypoint::ORI:
                        for (int xx=0; xx<4; xx++){
                                for (int yy=0; yy<4; yy++){
                                        for (int oo=0; oo<8; oo++){
                                                Point2D<int> loc(xBins/2+(xBins*xx), yBins/2+(yBins*yy));
                                                byte mag = keyp->getFVelement(xx*32+yy*8+oo, channel);
                                                mag = mag/4;
                                                drawDisk(fvDisp, loc, 2, PixRGB<byte>(255, 0, 0));
                                                drawLine(fvDisp, loc,
                                                                Point2D<int>(int(loc.i + mag*cosf(oo*M_PI/4)),
                                                                        int(loc.j + mag*sinf(oo*M_PI/4))),
                                                                PixRGB<byte>(255, 0, 0));
                                        }
                                }
                        }
                        break;

                case Keypoint::COL:
                        for (int xx=0; xx<4; xx++){
                                for (int yy=0; yy<4; yy++){
                                        for (int cc=0; cc<3; cc++){
                                                Point2D<int> loc(xBins/2+(xBins*xx), yBins/2+(yBins*yy));
                                                byte mag = keyp->getFVelement(xx*12+yy*3+cc, channel);
                                                mag = mag/4;
                                                drawDisk(fvDisp, loc, 2, PixRGB<byte>(255, 0, 0));
                                                drawLine(fvDisp, loc,
                                                                Point2D<int>(int(loc.i + mag*cosf(-1*cc*M_PI/2)),
                                                                        int(loc.j + mag*sinf(-1*cc*M_PI/2))),
                                                                PixRGB<byte>(255, 0, 0));
                                        }
                                }
                        }
                        break;
                default:
                        break;
        }



        Image<PixRGB<byte> > disp = img;
        disp = concatX(disp, fvDisp);


        xwinKey->drawImage(disp);

        while(!xwinKey->pressedCloseButton()){
                usleep(100);
        }
        delete xwinKey;

}



void analizeImage(){
   int key = -1;

        while(key != 24){ // q to quit window
                key = xwin.getLastKeyPress();
                Point2D<int>  point = xwin.getLastMouseClick();
                if (point.i > -1 && point.j > -1){

                        //get the right object
                        rutz::shared_ptr<VisualObject> obj;
                        if (point.j < HEIGHT){
                                obj = objTop;
                        } else {
                                obj = objBottom;
                                point.j = point.j - HEIGHT;
                        }
                        LINFO("ClickInfo: key = %i, p=%i,%i", key, point.i, point.j);

                        //find the keypoint
                        for(uint i=0; i<obj->numKeypoints(); i++){
                                rutz::shared_ptr<Keypoint> keyp = obj->getKeypoint(i);
                                float x = keyp->getX();
                                float y = keyp->getY();

                                if ( (point.i < (int)x + 5 && point.i > (int)x - 5) &&
                                          (point.j < (int)y + 5 && point.j > (int)y - 5)){
                                        showKeypoint(obj, i, Keypoint::ORI);
                                        showKeypoint(obj, i, Keypoint::COL);
                                }

                        }

                }
        }

}
*/
// SIFT object-database test driver.
//
// Command line: <database file> <trainingLabel>
//
// Loads the object database from <database file>, then reads frames from the
// input frame series until it is complete or yields an uninitialized frame.
// Each frame is written to the output series under the name "Input".
//
// With `train` false (the hard-coded setting), every frame is classified
// with matchObject(); the part of the returned name before the first '_' is
// compared with <trainingLabel> and the number of matches is reported at the
// end. With `train` true, every frame is instead added to the database as
// "<trainingLabel>_<n>" and the database is saved on exit.
//
// Returns 0, also when the command line cannot be parsed.
int main(const int argc, const char **argv)
{

  MYLOGVERB = LOG_INFO;
  ModelManager manager("Test SIFT");

  nub::ref<InputFrameSeries> ifs(new InputFrameSeries(manager));
  manager.addSubComponent(ifs);

  nub::ref<OutputFrameSeries> ofs(new OutputFrameSeries(manager));
  manager.addSubComponent(ofs);

  if (manager.parseCommandLine(argc, argv, "<database file> <trainingLabel>", 2, 2) == false)
    return 0;

  manager.start();

  // Owning copies are kept on purpose: holding only the c_str() of the value
  // returned by getExtraArg() would leave a dangling pointer if that value
  // is a temporary.
  const std::string vdbFile = manager.getExtraArg(0);
  const std::string trainingLabel = manager.getExtraArg(1);

  int numMatches = 0; //the number of correct matches
  int totalObjects = 0; //the number of objects presented to the network
  int uObjId = 0; //a unique obj id for sift

  // Switch between building the database (true) and testing it (false).
  bool train = false;
  //load the database file
  vdb.loadFrom(vdbFile);

  while (true)
  {
    const FrameState is = ifs->updateNext();
    if (is == FRAME_COMPLETE)
      break;

    //grab the images
    GenericFrame input = ifs->readFrame();
    if (!input.initialized())
      break;
    Image< PixRGB<byte> > inputImg = input.asRgb();
    totalObjects++;

    ofs->writeRGB(inputImg, "Input", FrameInfo("Input", SRC_POS));

    if (train)
    {
      //add the object to the database
      // snprintf bounds the write; an over-long label is truncated rather
      // than overrunning the buffer.
      char objName[255];
      snprintf(objName, sizeof(objName), "%s_%i", trainingLabel.c_str(), uObjId);
      uObjId++;
      rutz::shared_ptr<VisualObject>
        vo(new VisualObject(objName, "NULL", inputImg,
              Point2D<int>(-1,-1),
              std::vector<float>(),
              std::vector< rutz::shared_ptr<Keypoint> >(),
              USECOLOR));

      vdb.addObject(vo);
    } else {

      //get the object classification
      const std::string tmpName = matchObject(inputImg);

      // Keep the text before the first '_'. When there is no '_', find()
      // returns npos and substr() yields the whole name.
      // NOTE(review): a training label that itself contains '_' can never
      // compare equal below -- confirm labels are underscore-free.
      const std::string::size_type sep = tmpName.find('_');
      const std::string objName = tmpName.substr(0, sep);
      LINFO("Object name %s", objName.c_str());
      printf("%i %s\n", ifs->frame(), objName.c_str());

      if (objName == trainingLabel)
        numMatches++;

      //printf("objid %i:class %i:rate=%0.2f\n",
      //    objData.description.c_str(), objData.id, cls,
      //    (float)numMatches/(float)totalObjects);
    }
  }

  if (train)
  {
    printf("Trained on %i objects\n", totalObjects);
    printf("Object in db %i\n" , vdb.numObjects());
    vdb.saveTo(vdbFile);
  } else {
    // Report a rate of 0 when no frame was processed instead of printing the
    // NaN that 0/0 would produce.
    const float rate = (totalObjects > 0)
      ? static_cast<float>(numMatches) / static_cast<float>(totalObjects)
      : 0.0F;
    printf("Classification Rate: %i/%i %0.2f\n",
        numMatches, totalObjects, rate);
  }

}