// Example 1
/**
 * Randomized (sketch-based) SVD of A.
 *
 * The computation runs in three stages:
 *   1. Sketch A rowwise into Q (input_height x sketch_size), where
 *      sketch_size = target_rank + params.oversampling.
 *   2. Approximate the range of A: orthonormalize Q, then optionally refine
 *      it with params.num_iterations steps of subspace iteration.
 *   3. Project A onto that range (B = Q^H * A), take the SVD of the small
 *      matrix B, and map its left singular vectors back (U = Q * U_B).
 *
 * NOTE(review): the outputs are not truncated to target_rank. U and SV carry
 * sketch_size components. Confirm whether callers expect that.
 *
 * @param A            input matrix; must provide Height() and Width().
 * @param target_rank  requested rank of the approximation.
 * @param U            output: approximate left singular vectors.
 * @param SV           output: approximate singular values.
 * @param V            output: approximate right singular vectors.
 * @param params       reads oversampling, num_iterations and skip_qr.
 * @param context      Skylark context passed to the sketch data constructor.
 * @throws skylark::base::skylark_exception if target_rank / sketch_size are
 *         incompatible with the dimensions of A.
 */
void operator()(InputMatrixType &A,
                int target_rank,
                UMatrixType &U,
                SingularValuesMatrixType &SV,
                VMatrixType &V,
                rand_svd_params_t params,
                skylark::base::context_t& context) {

    // TODO: input matrix should provide Height() and Width()
    const int input_height = A.Height();
    const int input_width  = A.Width();
    const int sketch_size  = target_rank + params.oversampling;

    /**
     * Sanity checks, raise an exception if:
     *   i)   the target rank is too large for the given input matrix or
     *   ii)  the number of columns of the sketched matrix either:
     *        - exceeds the width of A or
     *        - is less than the target rank (negative oversampling)
     */
    if ((target_rank > std::min(input_height, input_width)) ||
        (sketch_size > input_width) ||
        (sketch_size < target_rank)) {
        std::ostringstream msg;
        msg << "Incompatible matrix dimensions and target rank\n";
        SKYLARK_THROW_EXCEPTION(
            skylark::base::skylark_exception()
            << skylark::base::error_msg(msg.str()));
    }

    /** Stage 1: apply the sketch transform rowwise, Q = A * S. */
    UMatrixType Q(input_height, sketch_size);

    typedef typename SketchTransform<InputMatrixType, UMatrixType>::data_type
        sketch_data_type;
    sketch_data_type sketch_data(input_width, sketch_size, context);
    SketchTransform<InputMatrixType, UMatrixType> sketch_transform(sketch_data);
    sketch_transform.apply(A, Q, skylark::sketch::rowwise_tag());
    // NOTE(review): a disabled (#if 0) runtime dispatch on params.transform
    // (JLT / FJLT / CWT) used to live here. The transform is currently fixed
    // by the SketchTransform template parameter.

    /** Stage 2: approximate the range of A with orthonormal columns of Q. */
    skylark::base::qr::Explicit(Q);

    /** q steps of subspace iteration. */
    UMatrixType Y;
    for (int step = 0; step < params.num_iterations; ++step) {
        // Y = QR(A^H * Q)
        skylark::base::Gemm(elem::ADJOINT, elem::NORMAL,
                            double(1), A, Q, Y);
        skylark::base::qr::Explicit(Y);

        // Q = A * Y, re-orthonormalized unless the caller asked to skip it.
        skylark::base::Gemm(elem::NORMAL, elem::NORMAL,
                            double(1), A, Y, Q);
        if (!params.skip_qr)
            skylark::base::qr::Explicit(Q);
    }

    // With skip_qr the loop leaves Q = A * Y, which is not orthonormal.
    // The projection below needs an orthonormal Q, otherwise the singular
    // values and U come out wrong, so orthonormalize once here.
    if (params.skip_qr && params.num_iterations > 0)
        skylark::base::qr::Explicit(Q);

    /** Stage 3: SVD of the projected A, then project back the left vectors. */
    UMatrixType B;
    skylark::base::Gemm(elem::ADJOINT, elem::NORMAL,
                        double(1), Q, A, B);
    // Presumably skylark::base::SVD overwrites B with its left singular
    // vectors (U_B) -- B is reused as such in the next product; confirm.
    skylark::base::SVD(B, SV, V);
    skylark::base::Gemm(elem::NORMAL, elem::NORMAL,
                        double(1), Q, B, U);
}
// Example 2
/**
 * Sketching example: fills a height x width matrix A with El::Uniform and
 * applies sketch_transform_t to it.
 *   - ROWWISE defined:  A (height x width) -> sketched_A (height x sketch_size)
 *   - otherwise:        A (height x width) -> sketched_A (sketch_size x width)
 * The result is printed with El::Print. With ROOT_OUTPUT defined, only
 * rank 0 prints.
 * With LOCAL defined, A is an input_matrix_t filled directly. Otherwise it is
 * filled as a [CIRC,CIRC] matrix and then assigned to the distributed
 * input_matrix_t.
 */
int main(int argc, char* argv[]) {

    /** Initialize MPI  */
    boost::mpi::environment env(argc, argv);
    boost::mpi::communicator world;

    /** Initialize Elemental */
    El::Initialize(argc, argv);

    // Every Elemental object (grid, matrices, sketch) lives in this inner
    // scope, so it is destroyed before El::Finalize() below. Before this
    // change they were locals of main and were destroyed after Finalize.
    {
        MPI_Comm mpi_world(world);
        El::Grid grid(mpi_world);

        /** Example parameters */
        const int height      = 20;
        const int width       = 10;
        const int sketch_size = 5;

        /** Define input matrix A */
#ifdef LOCAL
        input_matrix_t A;
        El::Uniform(A, height, width);
#else
        dist_CIRC_CIRC_dense_matrix_t A_CIRC_CIRC(grid);
        input_matrix_t A(grid);
        El::Uniform(A_CIRC_CIRC, height, width);
        A = A_CIRC_CIRC;
#endif

        /** Initialize context */
        skylark::base::context_t context(0);

#ifdef ROWWISE
        /** Sketch transform (rowwise): the width dimension is reduced. */
        const int size = width;
        output_matrix_t sketched_A(height, sketch_size);
        sketch_transform_t sketch_transform(size, sketch_size, context);
        sketch_transform.apply(A, sketched_A, skylark::sketch::rowwise_tag());
#else
        /** Sketch transform (columnwise): the height dimension is reduced. */
        const int size = height;
        output_matrix_t sketched_A(sketch_size, width);
        sketch_transform_t sketch_transform(size, sketch_size, context);
        sketch_transform.apply(A, sketched_A,
                               skylark::sketch::columnwise_tag());
#endif

#ifdef ROOT_OUTPUT
        if (world.rank() == 0) {
#endif
            El::Print(sketched_A, "sketched_A");
#ifdef ROOT_OUTPUT
        }
#endif
    }

    El::Finalize();
    return 0;
}