struct nlop_s* nlop_chain(const struct nlop_s* a, const struct nlop_s* b) { assert(1 == nlop_get_nr_in_args(a)); assert(1 == nlop_get_nr_out_args(a)); assert(1 == nlop_get_nr_in_args(b)); assert(1 == nlop_get_nr_out_args(b)); const struct linop_s* la = linop_from_nlop(a); const struct linop_s* lb = linop_from_nlop(b); if ((NULL != la) && (NULL != lb)) return nlop_from_linop(linop_chain(la, lb)); PTR_ALLOC(struct nlop_s, n); const struct linop_s* (*der)[1][1] = TYPE_ALLOC(const struct linop_s*[1][1]); n->derivative = &(*der)[0][0]; if (NULL == la) la = a->derivative[0]; if (NULL == lb) lb = b->derivative[0]; n->op = operator_chain(a->op, b->op); n->derivative[0] = linop_chain(la, lb); return PTR_PASS(n); }
/**
 * Create chain of linear operators.
 *
 * The result applies `a` first and then `b`:
 *
 *   C     = B A
 *   C^H   = A^H B^H
 *   C^H C = A^H B^H B A
 *
 * @param a  operator applied first
 * @param b  operator applied second
 * @return   newly allocated linear operator (caller owns it); its norm_inv
 *           is always NULL
 */
struct linop_s* linop_chain(const struct linop_s* a, const struct linop_s* b)
{
	// `sizeof *c` keeps the allocation size tied to the variable's type.
	struct linop_s* c = xmalloc(sizeof *c);

	// operator_chain(x, y) applies x first, then y (forward: A, then B).
	c->forward = operator_chain(a->forward, b->forward);

	// Adjoint in reverse order: B^H, then A^H.
	c->adjoint = operator_chain(b->adjoint, a->adjoint);

	if (NULL == b->normal) {

		// b has no dedicated normal operator: apply C, then C^H.
		c->normal = operator_chain(c->forward, c->adjoint);

	} else {

		// Reuse b's normal operator: A, then B^H B, then A^H.
		const struct operator_s* top = operator_chain(b->normal, a->adjoint);
		c->normal = operator_chain(a->forward, top);

		// NOTE(review): releasing `top` here assumes operator_chain keeps
		// its own reference to its arguments — confirm.
		operator_free(top);
	}

	// No inverse of the normal operator is provided for a chain.
	c->norm_inv = NULL;

	return c;
}