int GaussianProcessNormal::precomputePrediction() { size_t n = mGPXX.size(); size_t p = mMean->nFeatures(); mKF = trans(mFeatM); inplace_solve(mL,mKF,ublas::lower_tag()); //TODO: make one line matrixd DD(p,p); DD = prod(trans(mKF),mKF); utils::addToDiagonal(DD,mInvVarW); utils::cholesky_decompose(DD,mD); vectord vn = mGPY; inplace_solve(mL,vn,ublas::lower_tag()); mWMap = prod(mFeatM,vn) + utils::ublas_elementwise_prod(mInvVarW,mW0); utils::cholesky_solve(mD,mWMap,ublas::lower()); mVf = mGPY - prod(trans(mFeatM),mWMap); inplace_solve(mL,mVf,ublas::lower_tag()); if (boost::math::isnan(mWMap(0))) { FILE_LOG(logERROR) << "Error in precomputed prediction. NaN found."; return -1; } return 0; }
void StudentTProcessNIG::precomputePrediction() { size_t n = mData.getNSamples(); size_t p = mMean.nFeatures(); mKF = trans(mMean.mFeatM); inplace_solve(mL,mKF,ublas::lower_tag()); //TODO: make one line matrixd DD(p,p); DD = prod(trans(mKF),mKF); utils::add_to_diagonal(DD,mInvVarW); utils::cholesky_decompose(DD,mD); vectord vn = mData.mY; inplace_solve(mL,vn,ublas::lower_tag()); mWMap = prod(mMean.mFeatM,vn) + utils::ublas_elementwise_prod(mInvVarW,mW0); utils::cholesky_solve(mD,mWMap,ublas::lower()); mVf = mData.mY - prod(trans(mMean.mFeatM),mWMap); inplace_solve(mL,mVf,ublas::lower_tag()); vectord v0 = mData.mY - prod(trans(mMean.mFeatM),mW0); //TODO: check for "cheaper" version //matrixd KK = prod(mL,trans(mL)); matrixd KK = computeCorrMatrix(); matrixd WW = zmatrixd(p,p); //TODO: diagonal matrix utils::add_to_diagonal(WW,mInvVarW); const matrixd FW = prod(trans(mMean.mFeatM),WW); KK += prod(FW,mMean.mFeatM); matrixd BB(n,n); utils::cholesky_decompose(KK,BB); inplace_solve(BB,v0,ublas::lower_tag()); mSigma = (mBeta/mAlpha + inner_prod(v0,v0))/(n+2*mAlpha); int dof = static_cast<int>(n+2*mAlpha); if ((boost::math::isnan(mWMap(0))) || (boost::math::isnan(mSigma))) { throw std::runtime_error("Error in precomputed prediction. NaN found."); } if (dof <= 0) { dof = n; FILE_LOG(logERROR) << "ERROR: Incorrect alpha. Dof invalid." << "Forcing Dof <= num of points."; } d_->setDof(dof); }