Example #1
0
mat nnls_solver(const mat & H, mat mu, const umat & mask, int max_iter, double rel_tol, int n_threads)
{
	/****************************************************************************************************
	 * Description: sequential Coordinate-wise algorithm for non-negative least square regression problem
	 * 		A x = b, s.t. x[!m] >= 0, x[m] == 0
	 * Arguments:
	 * 	H         : A^T * A
	 * 	mu        : -A^T * b (taken by value; column j is updated in place so that it always equals
	 * 	            the current gradient H * x.col(j) + mu0.col(j))
	 * 	mask      : a mask matrix (m) of same dim of x; masked entries are never updated and stay 0.
	 * 	            An empty mask means nothing is masked
	 * 	max_iter  : maximum number of iterations (sweeps over all coordinates) per column
	 * 	rel_tol   : stop criterion on the relative change of the largest per-sweep coordinate
	 * 	            update between two successive sweeps
	 * 	n_threads : number of threads (negative values are replaced by 0)
	 * Return:
	 * 	x : solution to argmin_{x, x>=0} ||Ax - b||_F^2
	 * Notes:
	 * 	Each column of x only touches x.col(j) and mu.col(j), so columns are solved independently
	 * 	(in parallel when OpenMP is enabled).
	 * 	A coordinate with H(k,k) <= 0 (an all-zero column of A, which cannot lower the loss) is left
	 * 	at 0: dividing by a zero diagonal would yield NaN and, through the mu update, corrupt every
	 * 	other coordinate of that column.
	 * Reference:
	 * 	http://cmp.felk.cvut.cz/ftp/articles/franc/Franc-TR-2005-06.pdf
	 * Author:
	 * 	Eric Xihui Lin <*****@*****.**>
	 * Version:
	 * 	2015-11-16
	 ****************************************************************************************************/

	mat x(H.n_cols, mu.n_cols, fill::zeros);
	if (n_threads < 0) n_threads = 0;
	const bool is_masked = !mask.empty();

	// signed bounds: OpenMP loops need a signed index, and this avoids signed/unsigned comparisons
	const int n_cols = static_cast<int>(mu.n_cols);
	const int n_coef = static_cast<int>(H.n_cols);

	#pragma omp parallel for num_threads(n_threads) schedule(dynamic)
	for (int j = 0; j < n_cols; j++)
	{
		// skip if all entries of col_j are masked (x.col(j) stays 0)
		if (is_masked && arma::all(mask.col(j)))
			continue;
		vec x0(n_coef);
		double tmp;
		int i = 0;
		double err1, err2 = 9999;
		do {
			x0 = x.col(j);
			err1 = err2;
			err2 = 0;
			for (int k = 0; k < n_coef; k++)
			{
				if (is_masked && mask(k,j) > 0) continue;
				// zero column of A: the coordinate has no effect; keep it at 0 instead of dividing by 0
				if (H(k,k) <= 0) continue;
				// exact minimizer along coordinate k, projected onto x >= 0
				tmp = x(k,j) - mu(k,j) / H(k,k);
				if (tmp < 0) tmp = 0;
				if (tmp != x(k,j))
				{
					// keep mu equal to the gradient after changing x(k,j)
					mu.col(j) += (tmp - x(k,j)) * H.col(k);
				}
				x(k,j) = tmp;
				tmp = std::abs(x(k,j) - x0(k));
				if (tmp > err2) err2 = tmp;
			}
		} while(++i < max_iter && std::abs(err1 - err2) / (err1 + 1e-9) > rel_tol);
	}
	return x;
}
Example #2
0
mat concentrate_step(mat x, umat x_nonmiss, vec pu, vec pmd, uvec x_miss_group_match,
    umat miss_group_unique, uvec miss_group_counts,
    umat miss_group_obs_col, umat miss_group_mis_col, uvec miss_group_p, int miss_group_n,
    int n, int n_half, int p, vec theta0, mat G, int d, int EM_maxits,
    double* xh_mem, unsigned  int* misgrpuqh_mem, unsigned  int* msgrpcth_mem,
    unsigned int* msgrpoch_mem, unsigned int* msgrpmch_mem, unsigned int* msgrpph_mem)
{
    // Concentration step.
    //  1. Convert each observation's pmd into R::pchisq(pmd(i), pu(i), 1, 0) and keep the
    //     first n_half observations whose value is <= the median.
    //  2. Copy those rows into x_half and rebuild the missing-pattern tables for them.
    //  3. Run CovEM on the subset and return its result without its first two rows.
    // The *_mem pointers are wrapped without copying (copy_aux_mem = false, strict = true),
    // so the subset data and pattern tables are written straight into caller-owned buffers.
    // NOTE(review): a new pattern group is opened whenever a row differs from the previous
    // group only, so rows sharing a missing pattern are assumed to be stored contiguously.

    // views on the caller-owned buffers; the pattern tables are cleared before use
    mat x_half(xh_mem, n_half, p, false, true);
    umat miss_grp_uq_half(misgrpuqh_mem, miss_group_n, p, false, true);
    uvec miss_grp_cts_half(msgrpcth_mem, miss_group_n, false, true);
    umat miss_grp_oc_half(msgrpoch_mem, miss_group_n, p, false, true);
    umat miss_grp_mc_half(msgrpmch_mem, miss_group_n, p, false, true);
    uvec miss_grp_p_half(msgrpph_mem, n_half, false, true);
    miss_grp_uq_half.zeros();
    miss_grp_cts_half.zeros();
    miss_grp_oc_half.zeros();
    miss_grp_mc_half.zeros();
    miss_grp_p_half.zeros();

    // adjusted pmd and the rows at or below its median
    vec pmd_adj(n);
    for (int obs = 0; obs < n; ++obs)
        pmd_adj(obs) = R::pchisq(pmd(obs), (double) pu(obs), 1, 0);
    const double cutoff = median(pmd_adj);
    const uvec keep = find(pmd_adj <= cutoff, n_half);

    // copy the kept rows and regroup them by missing pattern
    int n_groups = 0; // number of pattern groups opened so far
    for (int row = 0; row < n_half; ++row) {
        const int src = (int) keep(row);
        x_half.row(row) = x.row(src);

        // element-wise equality with the most recent group (all zeros for the first row)
        urowvec same(p);
        same.zeros();
        if (row > 0)
            same = (x_nonmiss.row(src) == miss_grp_uq_half.row(n_groups - 1));

        if (sum(same) == (unsigned int) p) {
            // same pattern as the current group: just count it
            miss_grp_cts_half(n_groups - 1)++;
            continue;
        }

        // new pattern: open a group, copying its metadata from the full-data group table
        const uword src_group = x_miss_group_match(src) - 1;
        miss_grp_uq_half.row(n_groups) = x_nonmiss.row(src);
        miss_grp_cts_half(n_groups)++;
        miss_grp_oc_half.row(n_groups) = miss_group_obs_col.row(src_group);
        miss_grp_mc_half.row(n_groups) = miss_group_mis_col.row(src_group);
        miss_grp_p_half(n_groups) = miss_group_p(src_group);
        n_groups++;
    }

    mat res = CovEM(x_half, n_half, p, theta0, G, d, miss_grp_uq_half, miss_grp_cts_half,
        miss_grp_oc_half, miss_grp_mc_half, miss_grp_p_half, n_groups,
        0.0001, EM_maxits);
    res.shed_rows(0, 1);

    return res;
}
Example #3
0
//' Rank elements within column of a matrix
//'
//' This function inverts, column by column, a matrix of sort indices: if
//' \code{sortedIdx[i, j]} is the (0-based) row of the \code{i}th element of
//' column \code{j} in sorted order, the result holds \code{i} at row
//' \code{sortedIdx[i, j]} of column \code{j}. Ranks are 0-based and follow the
//' order that was used to build \code{sortedIdx}.
//'
//' @param sortedIdx is the input matrix; each column should be a permutation of
//'   \code{0, ..., nrow(sortedIdx) - 1}
//' @return a rank matrix
// [[Rcpp::export]]
umat rankIndex(const umat& sortedIdx) {
    const uword N = sortedIdx.n_rows;
    const uword M = sortedIdx.n_cols;
    // zero-initialised: a column that is not a full permutation no longer
    // leaves uninitialised memory in the result
    umat rankedIdx(N, M, fill::zeros);
    // column-major traversal (matches Armadillo's storage order); within each
    // column rows are still visited in increasing order, as before
    for (uword jX = 0; jX < M; jX++) {
        for (uword iX = 0; iX < N; iX++) {
            // operator() is bounds-checked (unless ARMA_NO_DEBUG is defined), so
            // an entry outside 0..N-1 raises an error instead of writing out of bounds
            rankedIdx(sortedIdx.at(iX, jX), jX) = iX;
        }
    }
    return rankedIdx;
}
Example #4
0
void subsampling(double* subsample_mem, unsigned int* subsamp_nonmis_mem,
    mat x, umat x_nonmiss, int nSubsampleSize, int p, uvec subsample_id)
{
    mat subsamp( subsample_mem, nSubsampleSize, p, false, true);
    umat subsamp_nonmiss( subsamp_nonmis_mem, nSubsampleSize, p, false, true);
    uvec subsample_id_shuff = shuffle(subsample_id);
    for(int i = 0; i < nSubsampleSize; i++)
    {
        subsamp.row(i) = x.row( subsample_id_shuff(i) );
        subsamp_nonmiss.row(i) = x_nonmiss.row( subsample_id_shuff(i) );
    }
}
// Return a copy of vphi in which, for each position i of vxicat (linear,
// column-major order), the entries vphi[l*len + i] and vphi[(l+L)*len + i]
// are set to 0 for every l in [vxicat[i], L), where len = vxicat.n_elem and
// L = vphi.n_elem / 2 / len.
// NOTE(review): the indexing implies vphi holds 2*L blocks of len coefficients
// (a lag-l block in each half) — confirm against the caller.
rowvec star(const rowvec& vphi, const umat& vxicat){
	rowvec vphis = vphi;
	const int len = vxicat.n_elem;
	// nothing to truncate; also avoids the integer division by zero in L below
	if (len == 0)
		return vphis;
	const int L = vphi.n_elem/2/len;
	for(int i=0; i<len; i++){
		const int first_zeroed = vxicat[i];
		for(int l=first_zeroed; l<L; l++){
			vphis[l*len+i] = 0;
			vphis[(l+L)*len+i] = 0;
		}
	}
	return vphis;
}
// Fill the design cube wZ: one slice per entry of WTIME, built from lagged rows of U.
// For slice wt, with t = WTIME[wt], P = U.n_cols and Psq = P*P, and for each lag l = 1..L:
//  - the slice's rows are addressed in 2*L blocks of Psq rows; lag l writes to block l-1
//    when sti[t-l] is nonzero and to block L+l-1 otherwise (the other block is zeroed);
//  - inside the chosen block, column p receives rows p*P .. (p+1)*P-1, equal to
//    trans(U.row(t-l)) with entry q kept only where vxicat(q, p) >= l (otherwise 0).
// The whole cube is zeroed first, so the explicit block zeroing below is redundant but harmless.
// NOTE(review): assumes wZ already has at least 2*L*P*P rows, P columns and WTIME.n_elem
// slices, vxicat has P rows, and every WTIME[wt] >= L so that t-l is never negative —
// none of this is checked here; confirm with the caller.
void phi_design(mat& U, const uvec& sti, const int L, umat& vxicat, cube& wZ, const uvec& WTIME){
	int P=U.n_cols, wT = WTIME.n_elem, t;
	int Psq=P*P;
	wZ.zeros();
	for(int wt=0; wt<wT; wt++){
		t = WTIME[wt];
		for(int l=1;l<=L;l++){
			if(sti[t-l]){
				// indicator set: lag-l terms go to block l-1; block L+l-1 stays zero
				for(int p=0;p<P;p++){
					// keep entry q of U.row(t-l) only where vxicat(q,p) >= l
					uvec cond= (vxicat.col(p) >= l);
					wZ.slice(wt).col(p).subvec((l-1)*Psq+ p*P,(l-1)*Psq+(p+1)*P-1) = (cond)%trans(U.row(t-l));
				}
				wZ.slice(wt).rows((L+l-1)*Psq,(L+l)*Psq-1).zeros();
			}else{
				// indicator not set: block l-1 stays zero; lag-l terms go to block L+l-1
				wZ.slice(wt).rows((l-1)*Psq,l*Psq-1).zeros();
				for(int p=0;p<P;p++){
					uvec cond= (vxicat.col(p) >= l);
					wZ.slice(wt).col(p).subvec((L+l-1)*Psq+p*P,(L+l-1)*Psq+(p+1)*P-1)=cond%trans(U.row(t-l));
				}
			}
		}
	}

}
Example #7
0
// Award every item to its best bidder. For each item, getMaxItemBid(item, bids)
// yields (winner index, winning bid). A negative winning bid leaves the item
// untouched. Otherwise the item's column of `assignments` is reset to 0 and a 1
// is written at the winner's row.
// NOTE(review): `prices` is taken by value, so the price increment below only
// changes this function's local copy and is invisible to the caller — confirm
// whether a reference parameter was intended.
void assignWinners(mat bids, rowvec prices, umat & assignments) {
	const uword nItems = prices.size();

	for (int item = 0; item < nItems; ++item) {
		const vec best = getMaxItemBid(item, bids);
		const double winningBid = best(1);

		// no valid bid for this item
		if (winningBid < 0.0) {
			continue;
		}

		const uword winnerIdx = static_cast<uword>(best(0));
		prices(item) += winningBid;
		assignments.col(item).fill(0);
		assignments(winnerIdx, item) = 1;
	}
}
Example #8
0
// function used internally, which computes lasso fits for subsets containing a
// small number of observations (typically only 3) and returns the indices of
// the respective h observations with the smallest absolute residuals
//
// Column k of `subsets` is passed to fastLasso as the subset for fit k. Column k
// of the returned h x subsets.n_cols matrix is findSmallest(abs(residuals), h)
// for that fit.
umat sparseSubsets(const mat& x, const vec& y, const double& lambda,
    const uword& h, const umat& subsets, const bool& normalize,
    const bool& useIntercept, const double& eps, const bool& useGram) {
	const uword nsamp = subsets.n_cols;
	umat indices(h, nsamp);
	for (uword k = 0; k < nsamp; ++k) {
		// output arguments filled in by fastLasso; declared fresh for each subset
		double fitIntercept, fitCrit;
		vec fitCoef, fitResid;

		fastLasso(x, y, lambda, true, subsets.unsafe_col(k), normalize,
			useIntercept, eps, useGram, false, fitIntercept, fitCoef,
			fitResid, fitCrit);

		// h observations with the smallest absolute residuals for this fit
		indices.col(k) = findSmallest(abs(fitResid), h);
	}
	return indices;
}
Example #9
0
//' C++ wrapper for Gale-Shapley Algorithm
//'
//' This function provides an R wrapper for the C++ backend. Users should not
//' call this function directly and instead use
//' \code{\link{galeShapley.marriageMarket}} or
//' \code{\link{galeShapley.collegeAdmissions}}.
//'
//' @param proposerPref is a matrix with the preference order of the proposing
//'   side of the market. If there are \code{n} proposers and \code{m} reviewers
//'   in the market, then this matrix will be of dimension \code{m} by \code{n}.
//'   The \code{i,j}th element refers to \code{j}'s \code{i}th most favorite
//'   partner. Preference orders must be complete and specified using C++
//'   indexing (starting at 0).
//' @param reviewerUtils is a matrix with cardinal utilities of the courted side
//'   of the market. If there are \code{n} proposers and \code{m} reviewers, then
//'   this matrix will be of dimension \code{n} by \code{m}. The \code{i,j}th
//'   element refers to the payoff that individual \code{j} receives from being
//'   matched to individual \code{i}.
//' @return A list with elements that specify who is matched to whom. Suppose
//'   there are \code{n} proposers and \code{m} reviewers. The list contains
//'   the following items:
//'   \itemize{
//'    \item{\code{proposals} is a vector of length \code{n} whose \code{i}th
//'    element contains the number of the reviewer that proposer \code{i} is
//'    matched to using C++ indexing. Proposers that remain unmatched will be
//'    listed as being matched to \code{m}.}
//'    \item{\code{engagements} is a vector of length \code{m} whose \code{j}th
//'    element contains the number of the proposer that reviewer \code{j} is
//'    matched to using C++ indexing. Reviewers that remain unmatched will be
//'    listed as being matched to \code{n}.}
//'   }
// [[Rcpp::export]]
List cpp_wrapper_galeshapley(const umat& proposerPref, const mat& reviewerUtils) {

    // number of proposers (men)
    const int M = proposerPref.n_cols;

    // number of reviewers (women)
    const int N = proposerPref.n_rows;

    // engagements(w): proposer currently matched to reviewer w (`M` = unmatched)
    // proposals(p):   reviewer currently matched to proposer p (`N` = unmatched)
    vec engagements(N), proposals(M);
    proposals.fill(N);
    engagements.fill(M);

    // nextChoice(p): position in p's preference list of the next reviewer p will
    // propose to. A reviewer only switches to a partner she strictly prefers, so
    // once she has rejected or dropped p she will never accept him again.
    // Resuming from this position, rather than rescanning the list from the top
    // each time p re-enters the queue, yields the same matching with at most
    // M*N proposals in total.
    uvec nextChoice(M, fill::zeros);

    // create an integer queue of bachelors
    // the idea of using queues for this problem is borrowed from
    // http://rosettacode.org/wiki/Stable_marriage_problem#C.2B.2B
    queue<int> bachelors;

    // every proposer starts out as a bachelor
    for(int iX=M-1; iX >= 0; iX--) {
        bachelors.push(iX);
    }

    // loop until there are no more proposals to be made
    while (!bachelors.empty()) {

        // take the next unmatched proposer off the queue
        const int proposer = bachelors.front();
        bachelors.pop();

        // raw pointer into column `proposer`, to avoid copying the preference list
        const uword * proposerPrefcol = proposerPref.colptr(proposer);

        // propose down the list until accepted; if the list runs out, the
        // proposer stays unmatched
        while (nextChoice(proposer) < (uword) N) {

            const uword wX = proposerPrefcol[nextChoice(proposer)];
            nextChoice(proposer)++;

            // wX is available (`M` means unmatched): form a match
            if (engagements(wX) == M) {
                engagements(wX) = proposer;
                proposals(proposer) = wX;
                break;
            }

            // wX is already matched, let's see if wX can be poached
            if (reviewerUtils(proposer, wX) > reviewerUtils(engagements(wX), wX)) {

                // wX's previous partner becomes unmatched (`N` means unmatched)
                proposals(engagements(wX)) = N;
                bachelors.push(engagements(wX));

                // proposer and wX form a match
                engagements(wX) = proposer;
                proposals(proposer) = wX;
                break;
            }
        }
    }

    return List::create(
      _["proposals"]   = proposals,
      _["engagements"] = engagements);
}
Example #10
0
//[[Rcpp::export]]
Rcpp::List nnmf(const mat & A, const unsigned int k, mat W, mat H, umat Wm, umat Hm,
	const vec & alpha, const vec & beta, const unsigned int max_iter, const double rel_tol, 
	const int n_threads, const int verbose, const bool show_warning, const unsigned int inner_max_iter, 
	const double inner_rel_tol, const int method, unsigned int trace)
{
	/******************************************************************************************************
	 *              Non-negative Matrix Factorization(NNMF) using alternating scheme
	 *              ----------------------------------------------------------------
	 * Description:
	 * 	Decompose matrix A such that
	 * 		A = W H
	 * Arguments:
	 * 	A              : Matrix to be decomposed
	 * 	W, H           : Initial matrices of W and H, where ncol(W) = nrow(H) = k. # of rows/columns of W/H could be 0
	 * 	Wm, Hm         : Masks of W and H, s.t. masked entries are no-updated and fixed to initial values
	 * 	alpha          : [L2, angle, L1] regularization on W (non-masked entries)
	 * 	beta           : [L2, angle, L1] regularization on H (non-masked entries)
	 * 	max_iter       : Maximum number of iteration
	 * 	rel_tol        : Relative tolerance between two successive iterations, = |e2-e1|/avg(e1, e2)
	 * 	n_threads      : Number of threads (openMP)
	 * 	verbose        : Either 0 = no any tracking, 1 == progression bar, 2 == print iteration info
	 * 	show_warning   : If to show warning if targeted `tol` is not reached
	 * 	inner_max_iter : Maximum number of iterations passed to each inner W or H matrix updating loop
	 * 	inner_rel_tol  : Relative tolerance passed to inner W or H matrix updating loop, = |e2-e1|/avg(e1, e2)
	 * 	method         : Integer of 1, 2, 3 or 4, which encodes methods
	 * 	               : 1 = sequential coordinate-wise minimization using square loss
	 * 	               : 2 = Lee's multiplicative update with square loss, which is re-scaled gradient descent
	 * 	               : 3 = sequentially quadratic approximated minimization with KL-divergence
	 * 	               : 4 = Lee's multiplicative update with KL-divergence, which is re-scaled gradient descent
	 * 	trace          : A positive integer, error will be checked very 'trace' iterations. Computing WH can be very expansive,
	 * 	               : so one may not want to check error A-WH every single iteration
	 * Return:
	 * 	A list (Rcpp::List) of 
	 * 		W, H          : resulting W and H matrices
	 * 		mse_error     : a vector of mean square error (divided by number of non-missings)
	 * 		mkl_error     : a vector (length = number of iterations) of mean KL-distance
	 * 		target_error  : a vector of loss (0.5*mse or mkl), plus constraints
	 * 		average_epoch : a vector of average epochs (one complete swap over W and H)
	 * Author:
	 * 	Eric Xihui Lin <*****@*****.**>
	 * Version:
	 * 	2015-12-11
	 ******************************************************************************************************/

	unsigned int n = A.n_rows;
	unsigned int m = A.n_cols;
	//int k = H.n_rows; // decomposition rank k
	unsigned int N_non_missing = n*m;

	if (trace < 1) trace = 1;
	unsigned int err_len = (unsigned int)std::ceil(double(max_iter)/double(trace)) + 1;
	vec mse_err(err_len), mkl_err(err_len), terr(err_len), ave_epoch(err_len);

	// check progression
	bool show_progress = false;
	if (verbose == 1) show_progress = true;
	Progress prgrss(max_iter, show_progress);

	double rel_err = rel_tol + 1;
	double terr_last = 1e99;
	uvec non_missing;
	bool any_missing = !A.is_finite();
	if (any_missing) 
	{
		non_missing = find_finite(A);
		N_non_missing = non_missing.n_elem;
		mkl_err.fill(mean((A.elem(non_missing)+TINY_NUM) % log(A.elem(non_missing)+TINY_NUM) - A.elem(non_missing)));
	}
	else
		mkl_err.fill(mean(mean((A+TINY_NUM) % log(A+TINY_NUM) - A))); // fixed part in KL-dist, mean(A log(A) - A)

	if (Wm.empty())
		Wm.resize(0, n);
	else
		inplace_trans(Wm);
	if (Hm.empty())
		Hm.resize(0, m);

	if (W.empty())
	{
		W.randu(k, n);
		W *= 0.01;
		if (!Wm.empty())
			W.elem(find(Wm > 0)).fill(0.0);
	}
	else
		inplace_trans(W);

	if (H.empty())
	{
		H.randu(k, m);
		H *= 0.01;
		if (!Hm.empty())
			H.elem(find(Hm > 0)).fill(0.0);
	}

	if (verbose == 2)
	{
		Rprintf("\n%10s | %10s | %10s | %10s | %10s\n", "Iteration", "MSE", "MKL", "Target", "Rel. Err.");
		Rprintf("--------------------------------------------------------------\n");
	}

	int total_raw_iter = 0;
	unsigned int i = 0;
	unsigned int i_e = 0; // index for error checking
	for(; i < max_iter && std::abs(rel_err) > rel_tol; i++) 
	{
		Rcpp::checkUserInterrupt();
		prgrss.increment();

		if (any_missing)
		{
			// update W
			total_raw_iter += update_with_missing(W, H, A.t(), Wm, alpha, inner_max_iter, inner_rel_tol, n_threads, method);
			// update H
			total_raw_iter += update_with_missing(H, W, A, Hm, beta, inner_max_iter, inner_rel_tol, n_threads, method);

			if (i % trace == 0)
			{
				const mat & Ahat = W.t()*H;
				mse_err(i_e) = mean(square((A - Ahat).eval().elem(non_missing)));
				mkl_err(i_e) += mean((-(A+TINY_NUM) % log(Ahat+TINY_NUM) + Ahat).eval().elem(non_missing));
			}
		}
		else
		{
			// update W
			total_raw_iter += update(W, H, A.t(), Wm, alpha, inner_max_iter, inner_rel_tol, n_threads, method);
			// update H
			total_raw_iter += update(H, W, A, Hm, beta, inner_max_iter, inner_rel_tol, n_threads, method);

			if (i % trace == 0)
			{
				const mat & Ahat = W.t()*H;
				mse_err(i_e) = mean(mean(square((A - Ahat))));
				mkl_err(i_e) += mean(mean(-(A+TINY_NUM) % log(Ahat+TINY_NUM) + Ahat));
			}
		}

		if (i % trace == 0)
		{
			ave_epoch(i_e) = double(total_raw_iter)/(n+m);
			if (method < 3) // mse based
				terr(i_e) = 0.5*mse_err(i_e);
			else // KL based
				terr(i_e) = mkl_err(i_e);

			add_penalty(i_e, terr, W, H, N_non_missing, alpha, beta);

			rel_err = 2*(terr_last - terr(i_e)) / (terr_last + terr(i_e) + TINY_NUM );
			terr_last = terr(i_e);
			if (verbose == 2)
				Rprintf("%10d | %10.4f | %10.4f | %10.4f | %10.g\n", i+1, mse_err(i_e), mkl_err(i_e), terr(i_e), rel_err);

			total_raw_iter = 0; // reset to 0
			++i_e;
		}
	}

	// compute error of the last iteration
	if ((i-1) % trace != 0)
	{
		if (any_missing)
		{
			const mat & Ahat = W.t()*H;
			mse_err(i_e) = mean(square((A - Ahat).eval().elem(non_missing)));
			mkl_err(i_e) += mean((-(A+TINY_NUM) % log(Ahat+TINY_NUM) + Ahat).eval().elem(non_missing));
		}
		else
		{
			const mat & Ahat = W.t()*H;
			mse_err(i_e) = mean(mean(square((A - Ahat))));
			mkl_err(i_e) += mean(mean(-(A+TINY_NUM) % log(Ahat+TINY_NUM) + Ahat));
		}

		ave_epoch(i_e) = double(total_raw_iter)/(n+m);
		if (method < 3) // mse based
			terr(i_e) = 0.5*mse_err(i_e);
		else // KL based
			terr(i_e) = mkl_err(i_e);
		add_penalty(i_e, terr, W, H, N_non_missing, alpha, beta);

		rel_err = 2*(terr_last - terr(i_e)) / (terr_last + terr(i_e) + TINY_NUM );
		terr_last = terr(i_e);
		if (verbose == 2)
			Rprintf("%10d | %10.4f | %10.4f | %10.4f | %10.g\n", i+1, mse_err(i_e), mkl_err(i_e), terr(i_e), rel_err);

		++i_e;
	}

	if (verbose == 2)
	{
		Rprintf("--------------------------------------------------------------\n");
		Rprintf("%10s | %10s | %10s | %10s | %10s\n\n", "Iteration", "MSE", "MKL", "Target", "Rel. Err.");
	}

	if (i_e < err_len)
	{
		mse_err.resize(i_e);
		mkl_err.resize(i_e);
		terr.resize(i_e);
		ave_epoch.resize(i_e);
	}

	if (show_warning && rel_err > rel_tol)
		Rcpp::warning("Target tolerance not reached. Try a larger max.iter.");

	return Rcpp::List::create(
		Rcpp::Named("W") = W.t(),
		Rcpp::Named("H") = H,
		Rcpp::Named("mse_error") = mse_err,
		Rcpp::Named("mkl_error") = mkl_err,
		Rcpp::Named("target_error") = terr,
		Rcpp::Named("average_epoch") = ave_epoch,
		Rcpp::Named("n_iteration") = i
		);
}
Example #11
0
mat nnls_solver_with_missing(const mat & A, const mat & W, const mat & W1, const mat & H2, const umat & mask, 
	const double & eta, const double & beta, int max_iter, double rel_tol, int n_threads)
{
	// A = [W, W1, W2] [H, H1, H2]^T.
	// Where A may have missing values
	// Note that here in the input W = [W, W2]
	// compute x = [H, H1]^T given W, W2
	// A0 = W2*H2 is empty when H2 is empty (no partial info in H)
	// Return: x = [H, H1]
	//
	// Each column j of x is fitted by sequential coordinate descent, using only the finite
	// rows of A.col(j). eta (> 0) is added to the diagonal of the Gram matrix and beta (> 0)
	// to all of its entries. Masked entries are never updated and stay 0. A coordinate whose
	// Gram diagonal is not positive (e.g. an all-zero column on the observed rows, with no
	// regularization) is left at 0, because dividing by it would put NaN into mu and through
	// it into the whole column.

	const int m = A.n_cols;
	const int k = W.n_cols - H2.n_cols;
	const int kW = W1.n_cols;
	const int nH = k+kW;

	mat x(nH, m, fill::zeros);

	if (n_threads < 0) n_threads = 0;
	const bool is_masked = !mask.empty();

	#pragma omp parallel for num_threads(n_threads) schedule(dynamic)
	for (int j = 0; j < m; j++)
	{
		// skip if all entries of col_j are masked (x.col(j) stays 0)
		if (is_masked && arma::all(mask.col(j))) 
			continue;

		// Gram matrix on the observed rows of column j, plus regularization
		uvec non_missing = find_finite(A.col(j));
		mat WtW(nH, nH); // WtW
		update_WtW(WtW, W.rows(non_missing), W1.rows(non_missing), H2);
		if (beta > 0) WtW += beta;
		if (eta > 0) WtW.diag() += eta;

		mat mu(nH, 1); // -WtA
		uvec jv(1);
		jv(0) = j;
		if (H2.empty())
			update_WtA(mu, W.rows(non_missing), W1.rows(non_missing), H2, A.submat(non_missing, jv));
		else
			update_WtA(mu, W.rows(non_missing), W1.rows(non_missing), H2.rows(j, j), A.submat(non_missing, jv));

		vec x0(nH);
		double tmp;
		int i = 0;
		double err1, err2 = 9999;
		do {
			x0 = x.col(j);
			err1 = err2;
			err2 = 0;
			for (int l = 0; l < nH; l++)
			{
				if (is_masked && mask(l,j) > 0) continue;
				// degenerate coordinate: keep it at 0 instead of dividing by a zero diagonal
				if (WtW(l,l) <= 0) continue;
				// exact minimizer along coordinate l, projected onto x >= 0
				tmp = x(l,j) - mu(l,0) / WtW(l,l);
				if (tmp < 0) tmp = 0;
				if (tmp != x(l,j))
				{
					// keep mu equal to the gradient after changing x(l,j)
					mu.col(0) += (tmp - x(l,j)) * WtW.col(l);
				}
				x(l,j) = tmp;
				tmp = std::abs(x(l,j) - x0(l));
				if (tmp > err2) err2 = tmp;
			}
		} while(++i < max_iter && std::abs(err1 - err2) / (err1 + 1e-9) > rel_tol);
	}
	return x;
}