示例#1
0
文件: threshold.c 项目: hcmh/bart
/**
 * Build a one-shot lrthresh proximal operator and apply it to a buffer.
 *
 * Block dimensions are obtained from llr_blkdims() using the bitwise
 * complement of `flags`; the same complemented mask is handed to
 * lrthresh_create(). The operator is applied once with scale 1. and is
 * released before returning.
 * NOTE(review): presumably this performs (locally) low-rank thresholding
 * with threshold `lambda` - confirm against lrthresh_create().
 *
 * @param D       number of dimensions (length of dims)
 * @param dims    dimensions shared by the input and output arrays
 * @param llrblk  block size forwarded to llr_blkdims()
 * @param lambda  parameter forwarded to lrthresh_create()
 * @param flags   bitmask; its complement is what the helpers receive
 * @param out     destination buffer
 * @param in      source buffer
 */
static void lrthresh(unsigned int D, const long dims[D], int llrblk, float lambda, unsigned int flags, complex float* out, const complex float* in)
{
	// both helpers take the complement of the caller-supplied mask
	const unsigned int inv_flags = ~flags;

	long block_dims[MAX_LEV][D];

	// the level count is not needed here; only block_dims is consumed
	int nlevels = llr_blkdims(block_dims, inv_flags, dims, llrblk);
	UNUSED(nlevels);

	const struct operator_p_s* thresh_op = lrthresh_create(dims, false, inv_flags,
						(const long (*)[])block_dims, lambda, false, false);

	operator_p_apply(thresh_op, 1., D, dims, out, D, dims, in);
	operator_p_free(thresh_op);
}
示例#2
0
文件: optreg.c 项目: welcheb/bart
void opt_reg_configure(unsigned int N, const long img_dims[N], struct opt_reg_s* ropts, const struct operator_p_s* prox_ops[NUM_REGS], const struct linop_s* trafos[NUM_REGS], unsigned int llr_blk, bool randshift, bool use_gpu)
{
	float lambda = ropts->lambda;

	if (-1. == lambda)
		lambda = 0.;

	// if no penalties are specified but a regularization
	// parameter is given, add an l2 penalty

	struct reg_s* regs = ropts->regs;

	if ((0 == ropts->r) && (lambda > 0.)) {

		regs[0].xform = L2IMG;
		regs[0].xflags = 0u;
		regs[0].jflags = 0u;
		regs[0].lambda = lambda;
		ropts->r = 1;
	}



	int nr_penalties = ropts->r;
	long blkdims[MAX_LEV][DIMS];
	int levels;


	for (int nr = 0; nr < nr_penalties; nr++) {

		// fix up regularization parameter
		if (-1. == regs[nr].lambda)
			regs[nr].lambda = lambda;

		switch (regs[nr].xform) {

			case L1WAV:
				debug_printf(DP_INFO, "l1-wavelet regularization: %f\n", regs[nr].lambda);

				if (0 != regs[nr].jflags)
					debug_printf(DP_WARN, "joint l1-wavelet thresholding not currently supported.\n");

				long minsize[DIMS] = { [0 ... DIMS - 1] = 1 };
				minsize[0] = MIN(img_dims[0], 16);
				minsize[1] = MIN(img_dims[1], 16);
				minsize[2] = MIN(img_dims[2], 16);


				unsigned int wflags = 0;
				for (unsigned int i = 0; i < DIMS; i++) {

					if ((1 < img_dims[i]) && MD_IS_SET(regs[nr].xflags, i)) {

						wflags = MD_SET(wflags, i);
						minsize[i] = MIN(img_dims[i], 16);
					}
				}

				trafos[nr] = linop_identity_create(DIMS, img_dims);
				prox_ops[nr] = prox_wavelet3_thresh_create(DIMS, img_dims, wflags, minsize, regs[nr].lambda, randshift);
				break;

			case TV:
				debug_printf(DP_INFO, "TV regularization: %f\n", regs[nr].lambda);

				trafos[nr] = linop_grad_create(DIMS, img_dims, regs[nr].xflags);
				prox_ops[nr] = prox_thresh_create(DIMS + 1,
						linop_codomain(trafos[nr])->dims,
						regs[nr].lambda, regs[nr].jflags | MD_BIT(DIMS), use_gpu);
				break;

			case LLR:
				debug_printf(DP_INFO, "lowrank regularization: %f\n", regs[nr].lambda);

				// add locally lowrank penalty
				levels = llr_blkdims(blkdims, regs[nr].jflags, img_dims, llr_blk);

				assert(1 == levels);
				assert(levels == img_dims[LEVEL_DIM]);

				for(int l = 0; l < levels; l++)
#if 0
					blkdims[l][MAPS_DIM] = img_dims[MAPS_DIM];
#else
				blkdims[l][MAPS_DIM] = 1;
#endif

				int remove_mean = 0;

				trafos[nr] = linop_identity_create(DIMS, img_dims);
				prox_ops[nr] = lrthresh_create(img_dims, randshift, regs[nr].xflags, (const long (*)[DIMS])blkdims, regs[nr].lambda, false, remove_mean, use_gpu);
				break;

			case MLR:
#if 0
				// FIXME: multiscale low rank changes the output image dimensions 
				// and requires the forward linear operator. This should be decoupled...
				debug_printf(DP_INFO, "multi-scale lowrank regularization: %f\n", regs[nr].lambda);

				levels = multilr_blkdims(blkdims, regs[nr].jflags, img_dims, 8, 1);

				img_dims[LEVEL_DIM] = levels;
				max_dims[LEVEL_DIM] = levels;

				for(int l = 0; l < levels; l++)
					blkdims[l][MAPS_DIM] = 1;

				trafos[nr] = linop_identity_create(DIMS, img_dims);
				prox_ops[nr] = lrthresh_create(img_dims, randshift, regs[nr].xflags, (const long (*)[DIMS])blkdims, regs[nr].lambda, false, 0, use_gpu);

				const struct linop_s* decom_op = sum_create( img_dims, use_gpu );
				const struct linop_s* tmp_op = forward_op;
				forward_op = linop_chain(decom_op, forward_op);

				linop_free(decom_op);
				linop_free(tmp_op);
#else
				debug_printf(DP_WARN, "multi-scale lowrank regularization not yet supported: %f\n", regs[nr].lambda);
#endif

				break;

			case IMAGL1:
				debug_printf(DP_INFO, "l1 regularization of imaginary part: %f\n", regs[nr].lambda);

				trafos[nr] = linop_rdiag_create(DIMS, img_dims, 0, &(complex float){ 1.i });
				prox_ops[nr] = prox_thresh_create(DIMS, img_dims, regs[nr].lambda, regs[nr].jflags, use_gpu);
				break;

			case IMAGL2:
				debug_printf(DP_INFO, "l2 regularization of imaginary part: %f\n", regs[nr].lambda);

				trafos[nr] = linop_rdiag_create(DIMS, img_dims, 0, &(complex float){ 1.i });
				prox_ops[nr] = prox_leastsquares_create(DIMS, img_dims, regs[nr].lambda, NULL);
				break;

			case L1IMG:
				debug_printf(DP_INFO, "l1 regularization: %f\n", regs[nr].lambda);

				trafos[nr] = linop_identity_create(DIMS, img_dims);
				prox_ops[nr] = prox_thresh_create(DIMS, img_dims, regs[nr].lambda, regs[nr].jflags, use_gpu);
				break;

			case L2IMG:
				debug_printf(DP_INFO, "l2 regularization: %f\n", regs[nr].lambda);

				trafos[nr] = linop_identity_create(DIMS, img_dims);
				prox_ops[nr] = prox_leastsquares_create(DIMS, img_dims, regs[nr].lambda, NULL);
				break;

			case FTL1:
				debug_printf(DP_INFO, "l1 regularization of Fourier transform: %f\n", regs[nr].lambda);

				trafos[nr] = linop_fft_create(DIMS, img_dims, regs[nr].xflags);
				prox_ops[nr] = prox_thresh_create(DIMS, img_dims, regs[nr].lambda, regs[nr].jflags, use_gpu);
				break;
		}