// Inverts a 1x1 matrix in place.
//
// Returns the original element (the determinant of a 1x1 matrix) on success.
// When its magnitude is <= std::numeric_limits< T >::epsilon() the matrix is
// treated as singular: 0 is returned and m is left unchanged.
inline
typename detail::enable_if< is_floating_point< T >, T >::type
invert( mat< T, 1 >& m )
{
    const T value = m.elem( 0 );

    if ( std::fabs( value ) <= std::numeric_limits< T >::epsilon() )
    {
        return 0; // singular: leave m untouched
    }

    m.elem( 0 ) = static_cast< T >( 1 ) / value;
    return value;
}
// Inverts a 2x2 matrix in place via its adjugate.
//
// With det = e0*e3 - e2*e1 (e_i = m.elem(i)), the inverse is the adjugate
// [e3, -e1, -e2, e0] divided by det. Returns det on success; when
// |det| <= std::numeric_limits< T >::epsilon() the matrix is treated as
// singular, 0 is returned and m is left unchanged.
// NOTE(review): no `template< class T >` line is visible immediately before
// this definition; confirm that T is declared by the surrounding context.
inline
typename detail::enable_if< is_floating_point< T >, T >::type
invert( mat< T, 2 >& m )
{
    const T p = m.elem( 0 );
    const T q = m.elem( 1 );
    const T r = m.elem( 2 );
    const T s = m.elem( 3 );

    const T det = p * s - r * q;

    if ( std::fabs( det ) <= std::numeric_limits< T >::epsilon() )
    {
        return 0; // singular: leave m untouched
    }

    // Write the adjugate: swap the diagonal, negate the off-diagonal.
    m.elem( 0 ) = s;
    m.elem( 1 ) = -q;
    m.elem( 2 ) = -r;
    m.elem( 3 ) = p;

    m /= det;
    return det;
}
// Inverts a 3x3 matrix in place via its adjugate (transposed cofactor matrix).
//
// Returns the determinant of the original matrix on success. When
// |determinant| <= std::numeric_limits< T >::epsilon() the matrix is treated
// as singular: 0 is returned and m is left unchanged.
// NOTE(review): the threshold is absolute, not scaled by the magnitude of the
// entries. Also, no `template< class T >` line is visible immediately before
// this definition; confirm that T is declared by the surrounding context.
inline
typename detail::enable_if< is_floating_point< T >, T >::type
invert( mat< T, 3 >& m )
{
    const T a0 = m.elem( 0 ), a1 = m.elem( 1 ), a2 = m.elem( 2 );
    const T a3 = m.elem( 3 ), a4 = m.elem( 4 ), a5 = m.elem( 5 );
    const T a6 = m.elem( 6 ), a7 = m.elem( 7 ), a8 = m.elem( 8 );

    // Adjugate, built in a separate matrix so m is only overwritten on success.
    mat< T, 3 > adj;
    adj.elem( 0 ) = a4 * a8 - a5 * a7;
    adj.elem( 1 ) = a2 * a7 - a1 * a8;
    adj.elem( 2 ) = a1 * a5 - a2 * a4;
    adj.elem( 3 ) = a5 * a6 - a3 * a8;
    adj.elem( 4 ) = a0 * a8 - a2 * a6;
    adj.elem( 5 ) = a2 * a3 - a0 * a5;
    adj.elem( 6 ) = a3 * a7 - a4 * a6;
    adj.elem( 7 ) = a1 * a6 - a0 * a7;
    adj.elem( 8 ) = a0 * a4 - a1 * a3;

    // Determinant by cofactor expansion over elements 0, 1 and 2.
    const T det = a0 * adj.elem( 0 ) + a1 * adj.elem( 3 ) + a2 * adj.elem( 6 );

    if ( std::fabs( det ) <= std::numeric_limits< T >::epsilon() )
    {
        return 0; // singular: leave m untouched
    }

    const T inv_det = static_cast< T >( 1 ) / det;
    m = adj;
    m *= inv_det;

    return det;
}
// Inverts a 4x4 matrix in place via its adjugate (transposed cofactor matrix).
//
// Returns the determinant of the original matrix on success. When
// |determinant| <= std::numeric_limits< T >::epsilon() the matrix is treated
// as singular: 0 is returned and m is left unchanged.
//
// NOTE(review): the singularity threshold is absolute, not scaled by the
// magnitude of m's entries, so well-conditioned matrices with small entries
// may be reported as singular.
// NOTE(review): no `template< class T >` line is visible immediately before
// this definition; confirm that T is declared by the surrounding context.
inline
typename detail::enable_if< is_floating_point< T >, T >::type
invert( mat< T, 4 >& m )
{
    // Adjugate accumulator; copied into m only once the determinant is known
    // to be above the threshold.
    mat< T, 4 > res;

    // 2x2 minors formed from the element pairs (2,3), (6,7), (10,11), (14,15):
    // one entry for each way of choosing two of those four pairs.
    T t1[ 6 ] =
        {
            m.elem(  2 ) * m.elem(  7 ) - m.elem(  6 ) * m.elem(  3 ),
            m.elem(  2 ) * m.elem( 11 ) - m.elem( 10 ) * m.elem(  3 ),
            m.elem(  2 ) * m.elem( 15 ) - m.elem( 14 ) * m.elem(  3 ),
            m.elem(  6 ) * m.elem( 11 ) - m.elem( 10 ) * m.elem(  7 ),
            m.elem(  6 ) * m.elem( 15 ) - m.elem( 14 ) * m.elem(  7 ),
            m.elem( 10 ) * m.elem( 15 ) - m.elem( 14 ) * m.elem( 11 )
        };

    // Adjugate entries 0..7: each is a signed 3x3 minor expanded over the
    // t1 minors and the remaining elements 0,1,4,5,8,9,12,13.
    res.elem( 0 ) = m.elem(  5 ) * t1[ 5 ] - m.elem(  9 ) * t1[ 4 ] + m.elem( 13 ) * t1[ 3 ];
    res.elem( 1 ) = m.elem(  9 ) * t1[ 2 ] - m.elem( 13 ) * t1[ 1 ] - m.elem(  1 ) * t1[ 5 ];
    res.elem( 2 ) = m.elem( 13 ) * t1[ 0 ] - m.elem(  5 ) * t1[ 2 ] + m.elem(  1 ) * t1[ 4 ];
    res.elem( 3 ) = m.elem(  5 ) * t1[ 1 ] - m.elem(  1 ) * t1[ 3 ] - m.elem(  9 ) * t1[ 0 ];
    res.elem( 4 ) = m.elem(  8 ) * t1[ 4 ] - m.elem(  4 ) * t1[ 5 ] - m.elem( 12 ) * t1[ 3 ];
    res.elem( 5 ) = m.elem(  0 ) * t1[ 5 ] - m.elem(  8 ) * t1[ 2 ] + m.elem( 12 ) * t1[ 1 ];
    res.elem( 6 ) = m.elem(  4 ) * t1[ 2 ] - m.elem( 12 ) * t1[ 0 ] - m.elem(  0 ) * t1[ 4 ];
    res.elem( 7 ) = m.elem(  0 ) * t1[ 3 ] - m.elem(  4 ) * t1[ 1 ] + m.elem(  8 ) * t1[ 0 ];


    // 2x2 minors formed from the element pairs (0,1), (4,5), (8,9), (12,13),
    // in the same pair ordering as t1.
    T t2[ 6 ] =
        {
            m.elem(  0 ) * m.elem(  5 ) - m.elem(  4 ) * m.elem(  1 ),
            m.elem(  0 ) * m.elem(  9 ) - m.elem(  8 ) * m.elem(  1 ),
            m.elem(  0 ) * m.elem( 13 ) - m.elem( 12 ) * m.elem(  1 ),
            m.elem(  4 ) * m.elem(  9 ) - m.elem(  8 ) * m.elem(  5 ),
            m.elem(  4 ) * m.elem( 13 ) - m.elem( 12 ) * m.elem(  5 ),
            m.elem(  8 ) * m.elem( 13 ) - m.elem( 12 ) * m.elem(  9 )
        };

    // Adjugate entries 8..15: signed 3x3 minors expanded over the t2 minors
    // and the remaining elements 2,3,6,7,10,11,14,15.
    res.elem( 8 )  = m.elem(  7 ) * t2[ 5 ] - m.elem( 11 ) * t2[ 4 ] + m.elem( 15 ) * t2[ 3 ];
    res.elem( 9 )  = m.elem( 11 ) * t2[ 2 ] - m.elem( 15 ) * t2[ 1 ] - m.elem(  3 ) * t2[ 5 ];
    res.elem( 10 ) = m.elem( 15 ) * t2[ 0 ] - m.elem(  7 ) * t2[ 2 ] + m.elem(  3 ) * t2[ 4 ];
    res.elem( 11 ) = m.elem(  7 ) * t2[ 1 ] - m.elem(  3 ) * t2[ 3 ] - m.elem( 11 ) * t2[ 0 ];
    res.elem( 12 ) = m.elem( 10 ) * t2[ 4 ] - m.elem(  6 ) * t2[ 5 ] - m.elem( 14 ) * t2[ 3 ];
    res.elem( 13 ) = m.elem(  2 ) * t2[ 5 ] - m.elem( 10 ) * t2[ 2 ] + m.elem( 14 ) * t2[ 1 ];
    res.elem( 14 ) = m.elem(  6 ) * t2[ 2 ] - m.elem( 14 ) * t2[ 0 ] - m.elem(  2 ) * t2[ 4 ];
    res.elem( 15 ) = m.elem(  2 ) * t2[ 3 ] - m.elem(  6 ) * t2[ 1 ] + m.elem( 10 ) * t2[ 0 ];

    // Determinant by cofactor expansion over elements 0, 4, 8 and 12, reusing
    // the already computed adjugate entries 0..3.
    T d =
        m.elem( 0 ) * res.elem( 0 ) + m.elem( 4 ) * res.elem( 1 ) +
        m.elem( 8 ) * res.elem( 2 ) + m.elem( 12 ) * res.elem( 3 );

    if ( std::fabs( d ) <= std::numeric_limits< T >::epsilon() )
    {
        // Treated as singular; m has not been modified.
        return 0;
    }

    // inverse = adjugate / determinant
    T invd = static_cast< T >( 1 ) / d;
    m = res;
    m *= invd;

    return d;
}
Example #5
0
//[[Rcpp::export]]
Rcpp::List nnmf(const mat & A, const unsigned int k, mat W, mat H, umat Wm, umat Hm,
	const vec & alpha, const vec & beta, const unsigned int max_iter, const double rel_tol, 
	const int n_threads, const int verbose, const bool show_warning, const unsigned int inner_max_iter, 
	const double inner_rel_tol, const int method, unsigned int trace)
{
	/******************************************************************************************************
	 *              Non-negative Matrix Factorization(NNMF) using alternating scheme
	 *              ----------------------------------------------------------------
	 * Description:
	 * 	Decompose matrix A such that
	 * 		A = W H
	 * 	Non-finite entries of A are treated as missing and excluded from the error measures.
	 * Arguments:
	 * 	A              : Matrix to be decomposed (n x m)
	 * 	k              : Decomposition rank; used to randomly initialize W and/or H when they are empty
	 * 	W, H           : Initial matrices of W (n x k) and H (k x m). Either may be empty (0 rows/columns)
	 * 	Wm, Hm         : Masks of W and H, s.t. masked entries are not updated and stay fixed at initial values
	 * 	alpha          : [L2, angle, L1] regularization on W (non-masked entries)
	 * 	beta           : [L2, angle, L1] regularization on H (non-masked entries)
	 * 	max_iter       : Maximum number of outer iterations
	 * 	rel_tol        : Relative tolerance between two successive error checks, = |e2-e1|/avg(e1, e2)
	 * 	n_threads      : Number of threads (openMP)
	 * 	verbose        : 0 = no tracking, 1 = progress bar, 2 = print iteration info
	 * 	show_warning   : Whether to warn if the targeted `rel_tol` is not reached
	 * 	inner_max_iter : Maximum number of iterations passed to each inner W or H matrix updating loop
	 * 	inner_rel_tol  : Relative tolerance passed to inner W or H matrix updating loop, = |e2-e1|/avg(e1, e2)
	 * 	method         : Integer of 1, 2, 3 or 4, which encodes methods
	 * 	               : 1 = sequential coordinate-wise minimization using square loss
	 * 	               : 2 = Lee's multiplicative update with square loss, which is re-scaled gradient descent
	 * 	               : 3 = sequentially quadratic approximated minimization with KL-divergence
	 * 	               : 4 = Lee's multiplicative update with KL-divergence, which is re-scaled gradient descent
	 * 	trace          : A positive integer (0 is treated as 1). The error is checked every 'trace' iterations
	 * 	               : and on the last allowed iteration. Computing WH can be very expensive, so one may
	 * 	               : not want to check the error A-WH every single iteration
	 * Return:
	 * 	A list (Rcpp::List) of
	 * 		W, H          : resulting W (n x k) and H (k x m) matrices
	 * 		mse_error     : mean square error at each error check (divided by number of non-missings)
	 * 		mkl_error     : mean KL-divergence at each error check
	 * 		target_error  : loss at each error check (0.5*mse or mkl), plus regularization from add_penalty
	 * 		average_epoch : inner iterations performed since the previous check, divided by (n + m)
	 * 		n_iteration   : number of outer iterations actually performed
	 * Author:
	 * 	Eric Xihui Lin <*****@*****.**>
	 * Version:
	 * 	2015-12-11
	 ******************************************************************************************************/

	unsigned int n = A.n_rows;
	unsigned int m = A.n_cols;
	unsigned int N_non_missing = n*m;

	if (trace < 1) trace = 1;
	// At most ceil(max_iter/trace) periodic checks, plus one forced check on the last iteration.
	unsigned int err_len = (unsigned int)std::ceil(double(max_iter)/double(trace)) + 1;
	vec mse_err(err_len), mkl_err(err_len), terr(err_len), ave_epoch(err_len);

	// check progression
	bool show_progress = (verbose == 1);
	Progress prgrss(max_iter, show_progress);

	double rel_err = rel_tol + 1;
	double terr_last = 1e99;
	uvec non_missing;
	bool any_missing = !A.is_finite();
	if (any_missing) 
	{
		non_missing = find_finite(A);
		N_non_missing = non_missing.n_elem;
		mkl_err.fill(mean((A.elem(non_missing)+TINY_NUM) % log(A.elem(non_missing)+TINY_NUM) - A.elem(non_missing)));
	}
	else
		mkl_err.fill(mean(mean((A+TINY_NUM) % log(A+TINY_NUM) - A))); // fixed part in KL-dist, mean(A log(A) - A)

	// W and Wm are handled transposed (k x n) internally; W is transposed back on return.
	if (Wm.empty())
		Wm.resize(0, n);
	else
		inplace_trans(Wm);
	if (Hm.empty())
		Hm.resize(0, m);

	if (W.empty())
	{
		W.randu(k, n);
		W *= 0.01;
		if (!Wm.empty())
			W.elem(find(Wm > 0)).fill(0.0);
	}
	else
		inplace_trans(W);

	if (H.empty())
	{
		H.randu(k, m);
		H *= 0.01;
		if (!Hm.empty())
			H.elem(find(Hm > 0)).fill(0.0);
	}

	if (verbose == 2)
	{
		Rprintf("\n%10s | %10s | %10s | %10s | %10s\n", "Iteration", "MSE", "MKL", "Target", "Rel. Err.");
		Rprintf("--------------------------------------------------------------\n");
	}

	// A is never modified, so transpose it once instead of on every W update.
	const mat At = A.t();

	int total_raw_iter = 0;
	unsigned int i = 0;
	unsigned int i_e = 0; // index for error checking
	for(; i < max_iter && std::abs(rel_err) > rel_tol; i++) 
	{
		Rcpp::checkUserInterrupt();
		prgrss.increment();

		if (any_missing)
		{
			// update W
			total_raw_iter += update_with_missing(W, H, At, Wm, alpha, inner_max_iter, inner_rel_tol, n_threads, method);
			// update H
			total_raw_iter += update_with_missing(H, W, A, Hm, beta, inner_max_iter, inner_rel_tol, n_threads, method);
		}
		else
		{
			// update W
			total_raw_iter += update(W, H, At, Wm, alpha, inner_max_iter, inner_rel_tol, n_threads, method);
			// update H
			total_raw_iter += update(H, W, A, Hm, beta, inner_max_iter, inner_rel_tol, n_threads, method);
		}

		// Check the error every `trace` iterations and always on the last allowed iteration, so the
		// recorded errors describe the returned W and H. rel_err only changes here, so an early stop
		// always directly follows a check and no separate post-loop check is needed.
		if (i % trace != 0 && i + 1 != max_iter)
			continue;

		const mat Ahat = W.t()*H;
		if (any_missing)
		{
			mse_err(i_e) = mean(square((A - Ahat).eval().elem(non_missing)));
			mkl_err(i_e) += mean((-(A+TINY_NUM) % log(Ahat+TINY_NUM) + Ahat).eval().elem(non_missing));
		}
		else
		{
			mse_err(i_e) = mean(mean(square((A - Ahat))));
			mkl_err(i_e) += mean(mean(-(A+TINY_NUM) % log(Ahat+TINY_NUM) + Ahat));
		}

		ave_epoch(i_e) = double(total_raw_iter)/(n+m);
		if (method < 3) // mse based
			terr(i_e) = 0.5*mse_err(i_e);
		else // KL based
			terr(i_e) = mkl_err(i_e);

		add_penalty(i_e, terr, W, H, N_non_missing, alpha, beta);

		rel_err = 2*(terr_last - terr(i_e)) / (terr_last + terr(i_e) + TINY_NUM );
		terr_last = terr(i_e);
		if (verbose == 2)
			Rprintf("%10u | %10.4f | %10.4f | %10.4f | %10.g\n", i+1, mse_err(i_e), mkl_err(i_e), terr(i_e), rel_err);

		total_raw_iter = 0; // reset to 0
		++i_e;
	}

	if (verbose == 2)
	{
		Rprintf("--------------------------------------------------------------\n");
		Rprintf("%10s | %10s | %10s | %10s | %10s\n\n", "Iteration", "MSE", "MKL", "Target", "Rel. Err.");
	}

	// Trim the error vectors to the number of checks actually performed.
	if (i_e < err_len)
	{
		mse_err.resize(i_e);
		mkl_err.resize(i_e);
		terr.resize(i_e);
		ave_epoch.resize(i_e);
	}

	// Use the same criterion as the loop: the run converged iff |rel_err| <= rel_tol.
	if (show_warning && std::abs(rel_err) > rel_tol)
		Rcpp::warning("Target tolerance not reached. Try a larger max.iter.");

	return Rcpp::List::create(
		Rcpp::Named("W") = W.t(),
		Rcpp::Named("H") = H,
		Rcpp::Named("mse_error") = mse_err,
		Rcpp::Named("mkl_error") = mkl_err,
		Rcpp::Named("target_error") = terr,
		Rcpp::Named("average_epoch") = ave_epoch,
		Rcpp::Named("n_iteration") = i
		);
}