示例#1
0
文件: nnls.hpp 项目: beckgom/smallk
bool NnlsHals(const MatrixType<T>& A,
              DenseMatrix<T>& W, 
              DenseMatrix<T>& H,
              const T tol,
              const bool verbose,
              const unsigned int max_iter)
{
    unsigned int n = A.Width();
    unsigned int k = W.Width();

    if (static_cast<unsigned int>(W.Height()) != static_cast<unsigned int>(A.Height()))
        throw std::logic_error("NnlsHals: W and A must have identical height");
    if (static_cast<unsigned int>(H.Width()) != static_cast<unsigned int>(A.Width()))
        throw std::logic_error("NnlsHals: H and A must have identical width");
    if (H.Height() != W.Width())
        throw std::logic_error("NnlsHals: non-conformant W and H");

    DenseMatrix<T> WtW(k, k), WtA(k, n), WtWH_r(1, n), gradH(k, n);

    if (verbose)
        std::cout << "\nRunning NNLS solver..." << std::endl;
    
    // compute W'W and W'A for the normal equations
    Gemm(TRANSPOSE, NORMAL, T(1.0), W, W, T(0.0), WtW);
    Gemm(TRANSPOSE, NORMAL, T(1.0), W, A, T(0.0), WtA);

    bool success = false;

    T pg0 = T(0), pg;
    for (unsigned int i=0; i<max_iter; ++i)
    {
        // compute the new matrix H
        UpdateH_Hals(H, WtWH_r, WtW, WtA);
        
        // compute gradH = WtW*H - WtA
        Gemm(NORMAL, NORMAL, T(1.0), WtW, H, T(0.0), gradH);
        Axpy( T(-1.0), WtA, gradH);

        // compute progress metric
        if (0 == i)
        {
            pg0 = ProjectedGradientNorm(gradH, H);
            if (verbose)
                ReportProgress(i+1, T(1.0));
            continue;
        }
        else
        {
            pg = ProjectedGradientNorm(gradH, H);
        }

        if (verbose)
            ReportProgress(i+1, pg/pg0);
        
        // check progress vs. desired tolerance
        if (pg < tol * pg0)
        {
            success = true;
            NormalizeAndScale<T>(W, H);
            break;
        }
    }

    if (!success)
        std::cerr << "NNLS solver reached iteration limit." << std::endl;
    
    return success;
}