static int applicable0_buf(const S *ego, rdft_kind kind, INT r, INT rs, INT m, INT ms, INT v, INT vs, const R *cr, const R *ci, const planner *plnr, INT *extra_iter) { const hc2c_desc *e = ego->desc; INT batchsz, brs; UNUSED(v); UNUSED(rs); UNUSED(ms); UNUSED(vs); return ( 1 && r == e->radix && kind == e->genus->kind /* ignore cr, ci, use buffer */ && (cr = (const R *)0, ci = cr + 1, batchsz = compute_batchsize(r), brs = 4 * batchsz, 1) && e->genus->okp(cr, ci, cr + brs - 2, ci + brs - 2, brs, 1, 1+batchsz, 2, plnr) && ((*extra_iter = 0, e->genus->okp(cr, ci, cr + brs - 2, ci + brs - 2, brs, 1, 1 + (((m-1)/2) % batchsz), 2, plnr)) || (*extra_iter = 1, e->genus->okp(cr, ci, cr + brs - 2, ci + brs - 2, brs, 1, 1 + 1 + (((m-1)/2) % batchsz), 2, plnr))) ); }
/*
 * Buffered apply routine operating in place on IO.
 *
 * For each of the ego->v vectors (spaced ego->vs apart):
 *   1. run the child plan cld0 in place at IO,
 *   2. process indices [mb, me) through dobatch() in chunks of at most
 *      compute_batchsize(r) indices, using one shared scratch buffer,
 *   3. run the child plan cldm in place at IO + ms * (m/2).
 *
 * The scratch buffer of r * batchsz * 2 elements of R is allocated once
 * for the whole call and released at the end.
 */
static void apply_buf(const plan *ego_, R *IO)
{
    const P *ego = (const P *) ego_;
    plan_rdft *first = (plan_rdft *) ego->cld0;
    plan_rdft *middle = (plan_rdft *) ego->cldm;
    INT m = ego->m, v = ego->v, r = ego->r;
    INT mb = ego->mb, me = ego->me, ms = ego->ms;
    INT batchsz = compute_batchsize(r);
    INT vleft, j;
    size_t bufsz = r * batchsz * 2 * sizeof(R);
    R *buf;

    BUF_ALLOC(R *, buf, bufsz);

    for (vleft = v; vleft > 0; --vleft) {
        R *lo = IO;                     /* walks upward from the start */
        R *hi = IO + m * ms;            /* mirror pointer, one row past */
        R *mid = IO + ms * (m / 2);     /* where the middle child works */

        first->apply((plan *) first, IO, IO);

        /* whole batches first, then whatever is left up to me */
        j = mb;
        while (j + batchsz < me) {
            dobatch(ego, lo, hi, j, j + batchsz, buf);
            j += batchsz;
        }
        dobatch(ego, lo, hi, j, me, buf);

        middle->apply((plan *) middle, mid, mid);

        IO += ego->vs;
    }

    BUF_FREE(buf, bufsz);
}
/*
 * Buffered apply routine for the hc2c direct solver, operating in place on
 * the (cr, ci) array pair.
 *
 * For each of the ego->v vectors (spaced ego->vs apart):
 *   1. run the child plan cld0 in place at (cr, ci),
 *   2. process indices [1, (m+1)/2) through dobatch() in chunks of at most
 *      compute_batchsize(r) indices; only the final chunk is passed
 *      ego->extra_iter, all earlier chunks get 0,
 *   3. run the child plan cldm in place at offset me * ms.
 *
 * The scratch buffer is obtained with BUF_ALLOC/BUF_FREE, which receive
 * the buffer size, instead of an unconditional stack allocation.
 * NOTE(review): BUF_ALLOC is expected to fall back to the heap when the
 * requested size is large -- confirm against its definition.
 */
static void apply_buf(const plan *ego_, R *cr, R *ci)
{
    const P *ego = (const P *) ego_;
    plan_rdft2 *cld0 = (plan_rdft2 *) ego->cld0;
    plan_rdft2 *cldm = (plan_rdft2 *) ego->cldm;
    INT i, j, ms = ego->ms, v = ego->v;
    INT batchsz = compute_batchsize(ego->r);
    R *buf;
    INT mb = 1, me = (ego->m + 1) / 2;
    size_t bufsz = ego->r * batchsz * 2 * sizeof(R);

    BUF_ALLOC(R *, buf, bufsz);

    for (i = 0; i < v; ++i, cr += ego->vs, ci += ego->vs) {
        R *Rp = cr;
        R *Ip = ci;
        R *Rm = cr + ego->m * ms;
        R *Im = ci + ego->m * ms;

        cld0->apply((plan *) cld0, Rp, Ip, Rp, Ip);

        /* full batches never need the extra iteration */
        for (j = mb; j + batchsz < me; j += batchsz)
            dobatch(ego, Rp, Ip, Rm, Im, j, j + batchsz, 0, buf);

        /* last (possibly partial) batch carries extra_iter */
        dobatch(ego, Rp, Ip, Rm, Im, j, me, ego->extra_iter, buf);

        cldm->apply((plan *) cldm,
                    Rp + me * ms, Ip + me * ms,
                    Rp + me * ms, Ip + me * ms);
    }

    BUF_FREE(buf, bufsz);
}
static int applicable0_buf(const S *ego, INT r, INT irs, INT ors, INT m, INT ms, INT v, INT ivs, INT ovs, INT mb, INT me, R *rio, R *iio, const planner *plnr) { const ct_desc *e = ego->desc; INT batchsz; UNUSED(v); UNUSED(ms); UNUSED(rio); UNUSED(iio); return ( 1 && r == e->radix && irs == ors /* in-place along R */ && ivs == ovs /* in-place along V */ /* check for alignment/vector length restrictions, both for batchsize and for the remainder */ && (batchsz = compute_batchsize(r), 1) && (e->genus->okp(e, 0, ((const R *)0) + 1, 2 * batchsz, 0, m, mb, mb + batchsz, 2, plnr)) && (e->genus->okp(e, 0, ((const R *)0) + 1, 2 * batchsz, 0, m, mb, me, 2, plnr)) ); }
static int applicable_buf(const solver *ego_, const problem *p_) { const S *ego = (const S *) ego_; const kr2c_desc *desc = ego->desc; const problem_rdft *p = (const problem_rdft *) p_; INT vl, ivs, ovs, batchsz; return ( 1 && p->sz->rnk == 1 && p->vecsz->rnk <= 1 && p->sz->dims[0].n == desc->n && p->kind[0] == desc->genus->kind /* check strides etc */ && X(tensor_tornk1)(p->vecsz, &vl, &ivs, &ovs) && (batchsz = compute_batchsize(desc->n), 1) && (0 /* can operate out-of-place */ || p->I != p->O /* can operate in-place as long as strides are the same */ || X(tensor_inplace_strides2)(p->sz, p->vecsz) /* can do it if the problem fits in the buffer, no matter what the strides are */ || vl <= batchsz ) ); }
static plan *mkcldw(const ct_solver *ego_, INT r, INT irs, INT ors, INT m, INT ms, INT v, INT ivs, INT ovs, INT mstart, INT mcount, R *rio, R *iio, planner *plnr) { const S *ego = (const S *) ego_; P *pln; const ct_desc *e = ego->desc; INT extra_iter; static const plan_adt padt = { 0, awake, print, destroy }; A(mstart >= 0 && mstart + mcount <= m); if (!applicable(ego, r, irs, ors, m, ms, v, ivs, ovs, mstart, mstart + mcount, rio, iio, plnr, &extra_iter)) return (plan *)0; if (ego->bufferedp) { pln = MKPLAN_DFTW(P, &padt, apply_buf); } else { pln = MKPLAN_DFTW(P, &padt, extra_iter ? apply_extra_iter : apply); } pln->k = ego->k; pln->rs = X(mkstride)(r, irs); pln->td = 0; pln->r = r; pln->m = m; pln->ms = ms; pln->v = v; pln->vs = ivs; pln->mb = mstart; pln->me = mstart + mcount; pln->slv = ego; pln->brs = X(mkstride)(r, 2 * compute_batchsize(r)); pln->extra_iter = extra_iter; X(ops_zero)(&pln->super.super.ops); X(ops_madd2)(v * (mcount/e->genus->vl), &e->ops, &pln->super.super.ops); if (ego->bufferedp) { /* 8 load/stores * N * V */ pln->super.super.ops.other += 8 * r * mcount * v; } pln->super.super.could_prune_now_p = (!ego->bufferedp && r >= 5 && r < 64 && m >= r); return &(pln->super.super); }
/*
 * Emit a one-line description of a direct DFT plan through p->print.
 *
 * Unbuffered plans report the codelet size, the vector length and the
 * codelet name; buffered plans additionally report the batch size given by
 * compute_batchsize(size) ahead of the size.
 */
static void print(const plan *ego_, printer *p)
{
    const P *ego = (const P *) ego_;
    const S *s = ego->slv;
    const kdft_desc *d = s->desc;

    if (!s->bufferedp) {
        p->print(p, "(dft-direct-%D%v \"%s\")", d->sz, ego->vl, d->nam);
        return;
    }

    p->print(p, "(dft-directbuf/%D-%D%v \"%s\")",
             compute_batchsize(d->sz), d->sz, ego->vl, d->nam);
}
/*
 * Emit a one-line description of a direct dftw plan through p->print.
 *
 * Both variants report the radix, the twiddle length obtained from
 * twiddle_length(r, e->tw), the vector count and the codelet name; the
 * buffered variant also reports compute_batchsize(r) ahead of the radix.
 */
static void print(const plan *ego_, printer *p)
{
    const P *ego = (const P *) ego_;
    const S *slv = ego->slv;
    const ct_desc *e = slv->desc;

    if (!slv->bufferedp) {
        p->print(p, "(dftw-direct-%D/%D%v \"%s\")",
                 ego->r, X(twiddle_length)(ego->r, e->tw), ego->v, e->nam);
        return;
    }

    p->print(p, "(dftw-directbuf/%D-%D/%D%v \"%s\")",
             compute_batchsize(ego->r), ego->r,
             X(twiddle_length)(ego->r, e->tw), ego->v, e->nam);
}
/*
 * Emit a one-line description of a direct hc2hc plan through p->print.
 *
 * Both variants report the radix, the twiddle length obtained from
 * twiddle_length(r, e->tw), the vector count, the codelet name and the two
 * child plans cld0 and cldm; the buffered variant also reports the batch
 * size ahead of the radix.
 */
static void print(const plan *ego_, printer *p)
{
    const P *ego = (const P *) ego_;
    const S *slv = ego->slv;
    const hc2hc_desc *e = slv->desc;
    INT batchsz = compute_batchsize(ego->r);

    if (!slv->bufferedp) {
        p->print(p, "(hc2hc-direct-%D/%D%v \"%s\"%(%p%)%(%p%))",
                 ego->r, X(twiddle_length)(ego->r, e->tw), ego->v, e->nam,
                 ego->cld0, ego->cldm);
        return;
    }

    p->print(p, "(hc2hc-directbuf/%D-%D/%D%v \"%s\"%(%p%)%(%p%))",
             batchsz, ego->r, X(twiddle_length)(ego->r, e->tw),
             ego->v, e->nam, ego->cld0, ego->cldm);
}
/*
 * Create a plan for the direct DFT solver, or return a null plan when the
 * solver does not apply to problem p_.
 *
 * Buffered solvers are vetted by applicable_buf() and use apply_buf.
 * Unbuffered solvers are vetted by applicable(), which also reports
 * whether the apply_extra_iter variant must be used instead of apply.
 *
 * The plan records the codelet, the transform size, the input/output and
 * buffer strides, the vector loop (vl, ivs, ovs) and an operation-count
 * estimate.  Only unbuffered plans are marked could_prune_now_p.
 */
static plan *mkplan(const solver *ego_, const problem *p_, planner *plnr)
{
    static const plan_adt padt = {
        X(dft_solve), X(null_awake), print, destroy
    };

    const S *ego = (const S *) ego_;
    const kdft_desc *e = ego->desc;
    const problem_dft *p;
    iodim *d;
    P *pln;

    UNUSED(plnr);

    if (!ego->bufferedp) {
        int extra_iterp = 0;

        if (!applicable(ego_, p_, plnr, &extra_iterp))
            return (plan *)0;
        pln = MKPLAN_DFT(P, &padt, extra_iterp ? apply_extra_iter : apply);
    } else {
        if (!applicable_buf(ego_, p_, plnr))
            return (plan *)0;
        pln = MKPLAN_DFT(P, &padt, apply_buf);
    }

    p = (const problem_dft *) p_;
    d = p->sz->dims;

    pln->slv = ego;
    pln->k = ego->k;
    pln->n = d[0].n;

    /* strides of the transform dimension: input, output, batch buffer */
    pln->is = X(mkstride)(pln->n, d[0].is);
    pln->os = X(mkstride)(pln->n, d[0].os);
    pln->bufstride = X(mkstride)(pln->n, 2 * compute_batchsize(pln->n));

    /* vector loop */
    X(tensor_tornk1)(p->vecsz, &pln->vl, &pln->ivs, &pln->ovs);

    /* cost estimate; buffering adds 4 * n * vl "other" operations */
    X(ops_zero)(&pln->super.super.ops);
    X(ops_madd2)(pln->vl / e->genus->vl, &e->ops, &pln->super.super.ops);
    if (ego->bufferedp)
        pln->super.super.ops.other += 4 * pln->n * pln->vl;

    pln->super.super.could_prune_now_p = !ego->bufferedp;

    return &(pln->super.super);
}
static int applicable_buf(const solver *ego_, const problem *p_, const planner *plnr) { const S *ego = (const S *) ego_; const problem_dft *p = (const problem_dft *) p_; const kdft_desc *d = ego->desc; INT vl; INT ivs, ovs; INT batchsz; return ( 1 && p->sz->rnk == 1 && p->vecsz->rnk == 1 && p->sz->dims[0].n == d->sz /* check strides etc */ && X(tensor_tornk1)(p->vecsz, &vl, &ivs, &ovs) /* UGLY if IS <= IVS */ && !(NO_UGLYP(plnr) && X(iabs)(p->sz->dims[0].is) <= X(iabs)(ivs)) && (batchsz = compute_batchsize(d->sz), 1) && (d->genus->okp(d, 0, ((const R *)0) + 1, p->ro, p->io, 2 * batchsz, p->sz->dims[0].os, batchsz, 2, ovs, plnr)) && (d->genus->okp(d, 0, ((const R *)0) + 1, p->ro, p->io, 2 * batchsz, p->sz->dims[0].os, vl % batchsz, 2, ovs, plnr)) && (0 /* can operate out-of-place */ || p->ri != p->ro /* can operate in-place as long as strides are the same */ || X(tensor_inplace_strides2)(p->sz, p->vecsz) /* can do it if the problem fits in the buffer, no matter what the strides are */ || vl <= batchsz ) ); }
/*
 * Buffered apply routine for the direct DFT solver.
 *
 * Walks the vl vectors in chunks of compute_batchsize(n): each full chunk
 * is handed to dobatch() with a count of batchsz, after which the input
 * pointers advance by batchsz * ivs and the output pointers by
 * batchsz * ovs.  The last call to dobatch() receives whatever is left,
 * which is between 1 and batchsz vectors when vl >= 1.
 *
 * One scratch buffer of n * batchsz * 2 elements of R is shared by all
 * chunks.
 */
static void apply_buf(const plan *ego_, R *ri, R *ii, R *ro, R *io)
{
    const P *ego = (const P *) ego_;
    INT n = ego->n;
    INT batchsz = compute_batchsize(n);
    INT left = ego->vl;             /* vectors not yet transformed */
    size_t bufsz = n * batchsz * 2 * sizeof(R);
    R *buf;

    BUF_ALLOC(R *, buf, bufsz);

    while (left > batchsz) {
        dobatch(ego, ri, ii, ro, io, buf, batchsz);

        ri += batchsz * ego->ivs;
        ii += batchsz * ego->ivs;
        ro += batchsz * ego->ovs;
        io += batchsz * ego->ovs;

        left -= batchsz;
    }

    /* final, possibly short, batch */
    dobatch(ego, ri, ii, ro, io, buf, left);

    BUF_FREE(buf, bufsz);
}
/*
 * Buffered apply routine for the direct dftw solver, operating in place on
 * the (rio, iio) array pair.
 *
 * For each of the ego->v vectors (spaced ego->vs apart), indices [mb, me)
 * are processed through dobatch() in chunks of at most
 * compute_batchsize(r) indices.  One scratch buffer of r * batchsz * 2
 * elements of R is shared by every chunk.
 */
static void apply_buf(const plan *ego_, R *rio, R *iio)
{
    const P *ego = (const P *) ego_;
    INT v = ego->v, r = ego->r;
    INT batchsz = compute_batchsize(r);
    INT mb = ego->mb, me = ego->me;
    INT vleft, j;
    size_t bufsz = r * batchsz * 2 * sizeof(R);
    R *buf;

    BUF_ALLOC(R *, buf, bufsz);

    for (vleft = v; vleft > 0; --vleft) {
        /* whole batches first ... */
        j = mb;
        while (j + batchsz < me) {
            dobatch(ego, rio, iio, j, j + batchsz, buf);
            j += batchsz;
        }

        /* ... then whatever remains up to me */
        dobatch(ego, rio, iio, j, me, buf);

        rio += ego->vs;
        iio += ego->vs;
    }

    BUF_FREE(buf, bufsz);
}
/*
 * Drive a batch routine over all vl vectors of a plan.
 *
 * `dobatch` is called with a count of compute_batchsize(n) for every full
 * chunk, with I advanced by batchsz * ivs and O by batchsz * ovs between
 * calls, and once more at the end with the number of vectors left over.
 * All calls share one scratch buffer of n * batchsz elements of R.
 */
static void iterate(const P *ego, R *I, R *O, void (*dobatch)(const P *ego, R *I, R *O, R *buf, INT batchsz))
{
    INT n = ego->n;
    INT batchsz = compute_batchsize(n);
    INT left = ego->vl;             /* vectors not yet processed */
    size_t bufsz = n * batchsz * sizeof(R);
    R *buf;

    BUF_ALLOC(R *, buf, bufsz);

    while (left > batchsz) {
        dobatch(ego, I, O, buf, batchsz);

        I += batchsz * ego->ivs;
        O += batchsz * ego->ovs;

        left -= batchsz;
    }

    /* final, possibly short, batch */
    dobatch(ego, I, O, buf, left);

    BUF_FREE(buf, bufsz);
}
static plan *mkcldw(const hc2hc_solver *ego_, rdft_kind kind, INT r, INT m, INT ms, INT v, INT vs, INT mstart, INT mcount, R *IO, planner *plnr) { const S *ego = (const S *) ego_; P *pln; const hc2hc_desc *e = ego->desc; plan *cld0 = 0, *cldm = 0; INT imid = (m / 2) * ms; INT rs = m * ms; static const plan_adt padt = { 0, awake, print, destroy }; if (!applicable(ego, kind, r, m, v, plnr)) return (plan *)0; cld0 = X(mkplan_d)( plnr, X(mkproblem_rdft_1_d)((CLD0P(mstart) ? X(mktensor_1d)(r, rs, rs) : X(mktensor_0d)()), X(mktensor_0d)(), TAINT(IO, vs), TAINT(IO, vs), kind)); if (!cld0) goto nada; cldm = X(mkplan_d)( plnr, X(mkproblem_rdft_1_d)((CLDMP(m, mstart, mcount) ? X(mktensor_1d)(r, rs, rs) : X(mktensor_0d)()), X(mktensor_0d)(), TAINT(IO + imid, vs), TAINT(IO + imid, vs), kind == R2HC ? R2HCII : HC2RIII)); if (!cldm) goto nada; pln = MKPLAN_HC2HC(P, &padt, ego->bufferedp ? apply_buf : apply); pln->k = ego->k; pln->td = 0; pln->r = r; pln->rs = X(mkstride)(r, rs); pln->m = m; pln->ms = ms; pln->v = v; pln->vs = vs; pln->slv = ego; pln->brs = X(mkstride)(r, 2 * compute_batchsize(r)); pln->cld0 = cld0; pln->cldm = cldm; pln->mb = mstart + CLD0P(mstart); pln->me = mstart + mcount - CLDMP(m, mstart, mcount); X(ops_zero)(&pln->super.super.ops); X(ops_madd2)(v * ((pln->me - pln->mb) / e->genus->vl), &e->ops, &pln->super.super.ops); X(ops_madd2)(v, &cld0->ops, &pln->super.super.ops); X(ops_madd2)(v, &cldm->ops, &pln->super.super.ops); if (ego->bufferedp) pln->super.super.ops.other += 4 * r * (pln->me - pln->mb) * v; pln->super.super.could_prune_now_p = (!ego->bufferedp && r >= 5 && r < 64 && m >= r); return &(pln->super.super); nada: X(plan_destroy_internal)(cld0); X(plan_destroy_internal)(cldm); return 0; }
static plan *mkcldw(const hc2c_solver *ego_, rdft_kind kind, INT r, INT rs, INT m, INT ms, INT v, INT vs, R *cr, R *ci, planner *plnr) { const S *ego = (const S *) ego_; P *pln; const hc2c_desc *e = ego->desc; plan *cld0 = 0, *cldm = 0; INT imid = (m / 2) * ms; INT extra_iter; static const plan_adt padt = { 0, awake, print, destroy }; if (!applicable(ego, kind, r, rs, m, ms, v, vs, cr, ci, plnr, &extra_iter)) return (plan *)0; cld0 = X(mkplan_d)( plnr, X(mkproblem_rdft2_d)(X(mktensor_1d)(r, rs, rs), X(mktensor_0d)(), TAINT(cr, vs), TAINT(ci, vs), TAINT(cr, vs), TAINT(ci, vs), kind)); if (!cld0) goto nada; cldm = X(mkplan_d)( plnr, X(mkproblem_rdft2_d)(((m % 2) ? X(mktensor_0d)() : X(mktensor_1d)(r, rs, rs) ), X(mktensor_0d)(), TAINT(cr + imid, vs), TAINT(ci + imid, vs), TAINT(cr + imid, vs), TAINT(ci + imid, vs), kind == R2HC ? R2HCII : HC2RIII)); if (!cldm) goto nada; if (ego->bufferedp) pln = MKPLAN_HC2C(P, &padt, apply_buf); else pln = MKPLAN_HC2C(P, &padt, extra_iter ? apply_extra_iter : apply); pln->k = ego->k; pln->td = 0; pln->r = r; pln->rs = X(mkstride)(r, rs); pln->m = m; pln->ms = ms; pln->v = v; pln->vs = vs; pln->slv = ego; pln->brs = X(mkstride)(r, 4 * compute_batchsize(r)); pln->cld0 = cld0; pln->cldm = cldm; pln->extra_iter = extra_iter; X(ops_zero)(&pln->super.super.ops); X(ops_madd2)(v * (((m - 1) / 2) / e->genus->vl), &e->ops, &pln->super.super.ops); X(ops_madd2)(v, &cld0->ops, &pln->super.super.ops); X(ops_madd2)(v, &cldm->ops, &pln->super.super.ops); if (ego->bufferedp) pln->super.super.ops.other += 4 * r * m * v; return &(pln->super.super); nada: X(plan_destroy_internal)(cld0); X(plan_destroy_internal)(cldm); return 0; }
/*
 * Create a plan for the direct r2c solver, or return a null plan when the
 * solver does not apply to problem p_.
 *
 * Buffered solvers are vetted by applicable_buf(), unbuffered ones by
 * applicable().  For R2HC-type kinds the real stride is the input stride
 * and the complex stride is the output stride; for the other kinds the
 * roles are swapped.  The apply routine is chosen from the four
 * combinations of direction and buffering.
 *
 * Two sets of strides/offsets are recorded: one for the user arrays
 * (rs = 2 * real stride, csr = cs, csi = -cs, ioffset) and one for the
 * batch buffer, where b = compute_batchsize(n) takes the place of both
 * strides (brs = 2 * b, bcsr = b, bcsi = -b, bioffset).
 *
 * Buffering adds 2 * n * vl "other" operations to the cost estimate, and
 * only unbuffered plans are marked could_prune_now_p.
 */
static plan *mkplan(const solver *ego_, const problem *p_, planner *plnr)
{
    static const plan_adt padt = {
        X(rdft_solve), X(null_awake), print, destroy
    };

    const S *ego = (const S *) ego_;
    const problem_rdft *p;
    iodim *d;
    INT rs, cs, b, n;
    P *pln;

    UNUSED(plnr);

    if (!ego->bufferedp) {
        if (!applicable(ego_, p_))
            return (plan *)0;
    } else {
        if (!applicable_buf(ego_, p_))
            return (plan *)0;
    }

    p = (const problem_rdft *) p_;
    d = p->sz->dims;

    if (R2HC_KINDP(p->kind[0])) {
        /* real data is the input, halfcomplex data the output */
        rs = d[0].is;
        cs = d[0].os;
        pln = MKPLAN_RDFT(P, &padt,
                          ego->bufferedp ? apply_buf_r2hc : apply_r2hc);
    } else {
        /* halfcomplex data is the input, real data the output */
        rs = d[0].os;
        cs = d[0].is;
        pln = MKPLAN_RDFT(P, &padt,
                          ego->bufferedp ? apply_buf_hc2r : apply_hc2r);
    }

    n = d[0].n;
    b = compute_batchsize(n);

    pln->slv = ego;
    pln->k = ego->k;
    pln->n = n;

    /* strides and offset for the user arrays */
    pln->rs0 = rs;
    pln->rs = X(mkstride)(n, 2 * rs);
    pln->csr = X(mkstride)(n, cs);
    pln->csi = X(mkstride)(n, -cs);
    pln->ioffset = ioffset(p->kind[0], n, cs);

    /* the same quantities for the batch buffer */
    pln->brs = X(mkstride)(n, 2 * b);
    pln->bcsr = X(mkstride)(n, b);
    pln->bcsi = X(mkstride)(n, -b);
    pln->bioffset = ioffset(p->kind[0], n, b);

    /* vector loop */
    X(tensor_tornk1)(p->vecsz, &pln->vl, &pln->ivs, &pln->ovs);

    /* cost estimate */
    X(ops_zero)(&pln->super.super.ops);
    X(ops_madd2)(pln->vl / ego->desc->genus->vl,
                 &ego->desc->ops,
                 &pln->super.super.ops);
    if (ego->bufferedp)
        pln->super.super.ops.other += 2 * n * pln->vl;

    pln->super.super.could_prune_now_p = !ego->bufferedp;

    return &(pln->super.super);
}