示例#1
0
    /**
     * Processes one sample per call:
     *   1. reads a vector from `inVec` (blocking),
     *   2. publishes the estimator's prediction on `pred`,
     *   3. publishes the running performance measure selected by `perfType`
     *      ("nMSE", "RMSE" or "MSE") on `perf`,
     *   4. stores the error history and saves it to CSV after `savedPerfNum` samples,
     *   5. updates the estimator with the new (input, target) pair.
     *
     * The incoming Bottle must hold at least d + t values. The first d are
     * the input features and the next t are the targets. Trailing extra values
     * are ignored. Shorter vectors are reported and skipped.
     *
     * @return false once the specified number of predictions (numPred) has been
     *         reached, which is reported as shutting down the module. A negative
     *         numPred means no limit. Returns true otherwise.
     */
    bool updateModule()
    {
        ++updateCount;

        // A negative numPred means "no limit" (same convention as the check at
        // the end of this function), so only enforce it when non-negative.
        if (numPred >= 0 && updateCount > numPred)
        {
            cout << "Specified number of predictions reached. Shutting down the module." << endl;
            return false;
        }

        if(verbose) cout << "updateModule #" << updateCount << endl;

        // Wait for input feature vector
        if(verbose) cout << "Expecting input vector" << endl;

        Bottle *bin = inVec.read();    // blocking call

        if (bin != 0 && bin->size() < d + t)
        {
            // Some entries of Xnew/ynew would never be assigned and would then
            // be used for both prediction and the estimator update. Skip the
            // sample and do not count it, so that numPred and the running error
            // averages (which divide by updateCount) only reflect processed samples.
            cout << "Received vector has " << bin->size() << " elements, expected at least "
                 << (d + t) << ". Sample skipped." << endl;
            --updateCount;
        }
        else if (bin != 0)
        {
            if(verbose) cout << "Got it!" << endl << bin->toString() << endl;

            // Store the received sample in gMat2D format for it to be compatible with gurls++.
            // The size check above guarantees indices [0, d+t) are available.
            gMat2D<T> Xnew(1,d);
            gMat2D<T> ynew(1,t);
            for (int i = 0 ; i < d ; ++i)
                Xnew(0,i) = bin->get(i).asDouble();
            for (int i = 0 ; i < t ; ++i)
                ynew(0,i) = bin->get(d + i).asDouble();

            if(verbose) cout << "Xnew: " << endl << Xnew << endl;
            if(verbose) cout << "ynew: " << endl << ynew.rows() << " x " << ynew.cols() << endl;
            if(verbose) cout<< ynew << endl;

            //-----------------------------------
            //          Prediction
            //-----------------------------------

            // Test on the incoming sample.
            // GURLS wrapper eval() returns a newly allocated matrix owned by the
            // caller. It is deleted below once prediction and performance are done.
            // NOTE(review): confirm ownership against the GURLS version in use.
            gMat2D<T> *resptr = estimator.eval(Xnew);

            Bottle& bpred = pred.prepare(); // Get a place to store things.
            bpred.clear();  // clear is important - b might be a reused object

            for (int i = 0 ; i < t ; ++i)
            {
                bpred.addDouble((*resptr)(0 , i));
            }

            if(verbose) printf("Sending prediction!!! %s\n", bpred.toString().c_str());
            pred.write();
            if(verbose) printf("Prediction written to port\n");

            //----------------------------------
            // performance

            Bottle& bperf = perf.prepare(); // Get a place to store things.
            bperf.clear();  // clear is important - b might be a reused object

            if (perfType == "nMSE")     // WARNING: The estimated variance could be unreliable...
            {
                // Compute nMSE and store
                // NOTE: In GURLS, "/" operator works like matlab's "\".
                // `/` and `*` have equal precedence and associate left-to-right,
                // so this evaluates as (varCols / diff) * diff.
                error += varCols / ( ynew - *resptr )*( ynew - *resptr ) ;
                gMat2D<T> tmp = error  / (updateCount);   // WARNING: Check
                for (int i = 0 ; i < t ; ++i)
                {
                    bperf.addDouble(tmp(0 , i));
                }
            }
            else if (perfType == "RMSE")
            {
                gMat2D<T> tmp(1,t);
                tmp = ( ynew - *resptr )*( ynew - *resptr );

                // Incremental running mean of sqrt(squared error) per output.
                // NOTE(review): this is the mean absolute error, not
                // sqrt(mean squared error). Confirm which metric is intended.
                error = error * (updateCount-1);
                for (int i = 0 ; i < ynew.cols() ; ++i)
                    error(0,i) += sqrt(tmp(0,i));
                error = error / updateCount;

                if (t >= 6)
                {
                    // WARNING: Temporary averaged computation assuming outputs
                    // 0-2 are forces and 3-5 are torques.
                    bperf.addDouble( (error(0 , 0) + error(0 , 1) + error(0 , 2))/ 3.0);    // Average RMSE on forces
                    bperf.addDouble( (error(0 , 3) + error(0 , 4) + error(0 , 5))/ 3.0);    // Average RMSE on torques
                }
                else
                {
                    // Fewer than 6 outputs: the force/torque split does not
                    // apply. Publish per-output values instead of reading past
                    // the end of `error`.
                    for (int i = 0 ; i < t ; ++i)
                    {
                        bperf.addDouble(error(0 , i));
                    }
                }
            }
            else if (perfType == "MSE")
            {
                // Compute MSE (incremental running mean) and store
                error = ( error * (updateCount-1) + ( ynew - *resptr )*( ynew - *resptr ) ) / updateCount;
                for (int i = 0 ; i < t ; ++i)
                {
                    bperf.addDouble(error(0 , i));
                }
            }

            // Prediction is no longer needed: release the matrix returned by eval().
            delete resptr;
            resptr = 0;

            // Error storage matrix management
            // Update error storage matrix (row index is the 0-based sample number)
            if (updateCount <= savedPerfNum)
            {
                gVec<T> errRow = error[0];
                storedError.setRow( errRow, updateCount-1);
            }

            // Save to CSV file once the storage matrix has been filled
            if (updateCount == savedPerfNum)
            {
                std::ostringstream ss;
                ss << experimentCount;

                storedError.saveCSV("storedError" + ss.str() + ".csv");
                cout << "Error measurement matrix saved." << endl;
            }

            // Write computed error to output port
            if(verbose) printf("Sending %s measurement: %s\n", perfType.c_str(), bperf.toString().c_str());
            perf.write();

            //-----------------------------------
            //             Update
            //-----------------------------------

            // Update estimator with a new input pair
            if(verbose) cout << "Now performing RRLS update" << endl;
            if(verbose) cout << "Xnew" << Xnew << endl;
            if(verbose) cout << "ynew" << ynew << endl;
            estimator.update(Xnew, ynew);
            if(verbose) cout << "Update completed" << endl;
        }

        if ( numPred >=0 && (updateCount == numPred) )
        {
            cout << "Specified number of predictions reached. Shutting down the module." << endl;
            return false;
        }
        return true;
    }