Esempio n. 1
0
/**
 * Runs a single mult_MH pass of a cuNFFTOperator<float,N> over the given
 * samples and returns the result via to_host().
 *
 * @param input_data   Complex samples; copied into a cuNDArray before use.
 * @param trajectory   N-dimensional sample positions; copied into a cuNDArray
 *                     and handed to the operator's preprocess().
 * @param matrix_size  First argument to the operator's setup(); also supplies
 *                     the first N dimensions of the output array.
 * @param W            Third argument to the operator's setup()
 *                     (presumably the gridding kernel width -- TODO confirm).
 * @param dcw          Optional weights. When non-null, their element-wise
 *                     square root is registered on the operator via set_dcw()
 *                     and also multiplied into the samples.
 * @return Host-side array of size matrix_size x (number of samples / number of
 *         trajectory points).
 *
 * NOTE(review): input_data and trajectory are dereferenced unconditionally and
 * the trajectory element count is used as a divisor; callers must pass
 * non-null, non-empty arrays.
 */
template<unsigned int N> static boost::shared_ptr<hoNDArray<float_complext> > gadgetronNFFT_instance(hoNDArray<float_complext> * input_data, hoNDArray<vector_td<float,N> >* trajectory,
		vector_td<uint64_t,N> matrix_size, float W, hoNDArray<float>* dcw = nullptr){

	// cuNDArray copies of the two host arrays.
	cuNDArray<float_complext> device_samples(*input_data);
	cuNDArray<vector_td<float,N> > device_traj(*trajectory);

	// Second setup() argument is twice the matrix size
	// (presumably the oversampled grid -- TODO confirm).
	boost::shared_ptr<cuNFFTOperator<float,N> > nfft_op = boost::make_shared<cuNFFTOperator<float,N> >();
	nfft_op->setup(matrix_size,matrix_size*size_t(2),W);
	nfft_op->preprocess(&device_traj);

	if (dcw != nullptr){
		// The square-rooted weights are applied twice on purpose: once
		// directly to the samples here, once through the operator.
		boost::shared_ptr<cuNDArray<float> > sqrt_weights = boost::make_shared<cuNDArray<float> >(*dcw);
		sqrt_inplace(sqrt_weights.get());
		nfft_op->set_dcw(sqrt_weights);

		device_samples *= *sqrt_weights;
	}

	// Output shape: the N matrix dimensions followed by one trailing
	// dimension holding (sample count / trajectory point count) entries.
	std::vector<size_t> image_dims;
	image_dims.reserve(N + 1);
	for (unsigned int d = 0; d < N; ++d)
		image_dims.push_back(matrix_size[d]);
	image_dims.push_back(device_samples.get_number_of_elements()/device_traj.get_number_of_elements());
/*
	Disabled alternative: solve iteratively with conjugate gradients instead
	of applying mult_MH once.

	nfft_op->set_domain_dimensions(&image_dims);
	nfft_op->set_codomain_dimensions(device_samples.get_dimensions().get());
	cuCgSolver<float_complext> cg;
	cg.set_max_iterations(10);
	cg.set_tc_tolerance(1e-8);
	cg.set_encoding_operator(nfft_op);
	auto image = cg.solve(&device_samples);
*/
	cuNDArray<float_complext> image(image_dims);
	nfft_op->mult_MH(&device_samples,&image);
	return image.to_host();
}
Esempio n. 2
0
// MEX gateway: validates the MATLAB inputs, allocates the output arrays and
// forwards everything to wrapper_func, dispatched on the class of the image A.
//
// Inputs (prhs):
//   [0] A      - input image; its first two dimensions are read as
//                height x width, any further dimensions as channels.
//   [1] X      - sample coordinates.
//   [2] Y      - sample coordinates; same class and element count as X.
//   [3] method - optional string. A leading '*' and then a leading "bi" are
//                skipped; the next character is passed to wrapper_func.
//                Defaults to 'l'.
//   [4] oobv   - optional scalar out-of-bounds value. Its class becomes the
//                output class. Defaults to NaN with the class of X.
// Outputs (plhs):
//   [0] B - size [M N C...], M and N taken from X, C... the trailing
//           dimensions of A.
//   [1] G - optional; same as B with an extra leading dimension of 2
//           (presumably gradients -- confirm against wrapper_func).
// None of the inputs may be complex.
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
	// Check number of arguments. nlhs == 0 (call without an assignment) is
	// accepted: MATLAB still takes plhs[0] in that case and stores it in ans.
	if (nrhs < 3 || nrhs > 5)
		mexErrMsgTxt("Unexpected number of input arguments.");
	if (nlhs > 2)
		mexErrMsgTxt("Unexpected number of output arguments.");

	// Check argument types are valid
	mxClassID in_class = mxGetClassID(prhs[1]);
	if (mxGetClassID(prhs[2]) != in_class)
		mexErrMsgTxt("X and Y must be arrays of the same type");
	for (int i = 0; i < nrhs; ++i) {
		if (mxIsComplex(prhs[i]))
			mexErrMsgTxt("Inputs cannot be complex.");
	}

	// Get and check array dimensions. The element counts are compared at
	// full width before narrowing to int, so counts that differ only above
	// the int range cannot compare equal.
	const size_t num_elements = mxGetNumberOfElements(prhs[1]);
	if (num_elements != mxGetNumberOfElements(prhs[2]))
		mexErrMsgTxt("X and Y must have the same dimensions");
	int num_points = static_cast<int>(num_elements);
	int ndims = static_cast<int>(mxGetNumberOfDimensions(prhs[0]));
	// out_dims holds [2, M, N, C...]. B uses the entries from index 1 on,
	// G uses them from index 0 on.
	std::vector<size_t> out_dims(ndims+2);
	out_dims[0] = 2;
	out_dims[1] = mxGetM(prhs[1]);
	out_dims[2] = mxGetN(prhs[1]);
	const mwSize *dims = mxGetDimensions(prhs[0]);
	out_dims[3] = 1;
	int nchannels = 1;
	for (int i = 2; i < ndims; ++i) {
		out_dims[i+1] = dims[i];
		nchannels *= static_cast<int>(dims[i]);
	}

	// Get the out of bounds value (oobv) and set the output class to the same class as the oobv.
	double oobv;
	mxClassID out_class;
	if (nrhs > 4) {
		// Get the value for oobv
		if (mxGetNumberOfElements(prhs[4]) != 1)
			mexErrMsgTxt("oobv must be a scalar.");
		oobv = mxGetScalar(prhs[4]);
		out_class = mxGetClassID(prhs[4]);
	} else {
		// Use the default value for oobv
		oobv = mxGetNaN();
		out_class = in_class;
	}

	// Create the output arrays (uninitialised; wrapper_func receives the
	// raw data pointers). G stays NULL when the second output is not requested.
	plhs[0] = mxCreateUninitNumericArray(ndims, &out_dims[1], out_class, mxREAL);
	void *B = mxGetData(plhs[0]);
	void *G = NULL;
	if (nlhs > 1) {
		plhs[1] = mxCreateUninitNumericArray(ndims+1, &out_dims[0], out_class, mxREAL);
		G = mxGetData(plhs[1]);
	}

	// Get the interpolation method. The buffer is zero-filled after the
	// default 'l', so the look-ahead reads below stay inside it.
	char buffer[10] = {'l'};
	int k = 0;
	if (nrhs > 3) {
		// Read in the method string (fails if it is not a string or does not fit)
		if (mxGetString(prhs[3], buffer, sizeof(buffer)))
			mexErrMsgTxt("Unrecognised interpolation method");
		// Remove '*' from the start
		k += (buffer[k] == '*');
		// Remove 'bi' from the start
		k += 2 * ((buffer[k] == 'b') & (buffer[k+1] == 'i'));
	}

	// Get pointer to the input image
	const void *A = mxGetData(prhs[0]);

	// Call the first wrapper function according to the input image type
	const int width = static_cast<int>(dims[1]);
	const int height = static_cast<int>(dims[0]);
	switch (mxGetClassID(prhs[0])) {
		case mxDOUBLE_CLASS:
			wrapper_func(B, G, static_cast<const double *>(A), prhs, num_points, width, height, nchannels, oobv, buffer[k], out_class, in_class);
			break;
		case mxSINGLE_CLASS:
			wrapper_func(B, G, static_cast<const float *>(A), prhs, num_points, width, height, nchannels, oobv, buffer[k], out_class, in_class);
			break;
		case mxINT8_CLASS:
			wrapper_func(B, G, static_cast<const int8_t *>(A), prhs, num_points, width, height, nchannels, oobv, buffer[k], out_class, in_class);
			break;
		case mxUINT8_CLASS:
			wrapper_func(B, G, static_cast<const uint8_t *>(A), prhs, num_points, width, height, nchannels, oobv, buffer[k], out_class, in_class);
			break;
		case mxINT16_CLASS:
			wrapper_func(B, G, static_cast<const int16_t *>(A), prhs, num_points, width, height, nchannels, oobv, buffer[k], out_class, in_class);
			break;
		case mxUINT16_CLASS:
			wrapper_func(B, G, static_cast<const uint16_t *>(A), prhs, num_points, width, height, nchannels, oobv, buffer[k], out_class, in_class);
			break;
		case mxLOGICAL_CLASS:
			wrapper_func(B, G, static_cast<const mxLogical *>(A), prhs, num_points, width, height, nchannels, oobv, buffer[k], out_class, in_class);
			break;
		default:
			mexErrMsgTxt("A is of an unsupported type");
			break;
	}
}