bool updateModule() { ++updateCount; if (updateCount > numPred) { cout << "Specified number of predictions reached. Shutting down the module." << endl; return false; } // DEBUG if(verbose) cout << "updateModule #" << updateCount << endl; // Recursive update support and storage variables declaration and initialization gMat2D<T> Xnew(1,d); gMat2D<T> ynew(1,t); gVec<T> Xnew_v(d); gVec<T> ynew_v(t); gMat2D<T> *resptr = 0; // Wait for input feature vector if(verbose) cout << "Expecting input vector" << endl; Bottle *bin = inVec.read(); // blocking call if (bin != 0) { if(verbose) cout << "Got it!" << endl << bin->toString() << endl; //Store the received sample in gMat2D format for it to be compatible with gurls++ for (int i = 0 ; i < bin->size() ; ++i) { if ( i < d ) { Xnew(0,i) = bin->get(i).asDouble(); } else if ( (i>=d) && (i<d+t) ) { ynew(0, i - d ) = bin->get(i).asDouble(); } } if(verbose) cout << "Xnew: " << endl << Xnew << endl; if(verbose) cout << "ynew: " << endl << ynew.rows() << " x " << ynew.cols() << endl; if(verbose) cout<< ynew << endl; //----------------------------------- // Prediction //----------------------------------- // Test on the incoming sample resptr = estimator.eval(Xnew); Bottle& bpred = pred.prepare(); // Get a place to store things. bpred.clear(); // clear is important - b might be a reused object for (int i = 0 ; i < t ; ++i) { bpred.addDouble((*resptr)(0 , i)); } if(verbose) printf("Sending prediction!!! %s\n", bpred.toString().c_str()); pred.write(); if(verbose) printf("Prediction written to port\n"); //---------------------------------- // performance Bottle& bperf = perf.prepare(); // Get a place to store things. bperf.clear(); // clear is important - b might be a reused object if (perfType == "nMSE") // WARNING: The estimated variance could be unreliable... { // Compute nMSE and store //NOTE: In GURLS, "/" operator works like matlab's "\". error += varCols / ( ynew - *resptr )*( ynew - *resptr ) ; gMat2D<T> tmp = error / (updateCount); // WARNING: Check for (int i = 0 ; i < t ; ++i) { bperf.addDouble(tmp(0 , i)); } } else if (perfType == "RMSE") { gMat2D<T> tmp(1,t); tmp = ( ynew - *resptr )*( ynew - *resptr ); //error = ( error * (updateCount-1) + sqrt(( ynew - *resptr )*( ynew - *resptr )) ) / updateCount; error = error * (updateCount-1); for (int i = 0 ; i < ynew.cols() ; ++i) error(0,i) += sqrt(tmp(0,i)); error = error / updateCount; /* for (int i = 0 ; i < t ; ++i) { bperf.addDouble(sqrt(MSE(0 , i))); } */ // WARNING: Temporary avg RMSE computation bperf.addDouble( (error(0 , 0) + error(0 , 1) + error(0 , 2))/ 3.0); // Average MSE on forces bperf.addDouble( (error(0 , 3) + error(0 , 4) + error(0 , 5))/ 3.0); // Average MSE on torques } else if (perfType == "MSE") { //Compute MSE and store error = ( error * (updateCount-1) + ( ynew - *resptr )*( ynew - *resptr ) ) / updateCount; for (int i = 0 ; i < t ; ++i) { bperf.addDouble(error(0 , i)); } } // Error storage matrix management // Update error storage matrix if (updateCount <= savedPerfNum) { gVec<T> errRow = error[0]; storedError.setRow( errRow, updateCount-1); } // Save to CSV file if (updateCount == savedPerfNum) { std::ostringstream ss; ss << experimentCount; //string tmp(std::to_string(experimentCount)); storedError.saveCSV("storedError" + ss.rdbuf()->str() + ".csv"); cout << "Error measurement matrix saved." << endl; } // Write computed error to output port if(verbose) printf("Sending %s measurement: %s\n", perfType.c_str(), bperf.toString().c_str()); perf.write(); //----------------------------------- // Update //----------------------------------- // Update estimator with a new input pair //if(verbose) std::cout << "Update # " << i+1 << std::endl; if(verbose) cout << "Now performing RRLS update" << endl; if(verbose) cout << "Xnew" << Xnew << endl; if(verbose) cout << "ynew" << ynew << endl; estimator.update(Xnew, ynew); if(verbose) cout << "Update completed" << endl; } if ( numPred >=0 && (updateCount == numPred) ) { cout << "Specified number of predictions reached. Shutting down the module." << endl; return false; } return true; }