コード例 #1
0
ファイル: itebd_itime.hpp プロジェクト: wqren/tensor
 /*
  * Store `delta` in the history buffer `cum` at slot (step % cum.size()),
  * overwriting whatever value that slot held, and report the accumulated
  * change over the buffer.
  *
  * step   iteration counter; selects the slot that is overwritten.
  * cum    history buffer, modified in place. Must be non-empty: the slot is
  *        computed as step % cum.size(), which is undefined for size 0.
  * delta  value to record for this step.
  *
  * Returns the SUM (not the mean, despite the name) of every entry of `cum`
  * once step >= cum.size(); for smaller steps it returns the constant 1.0.
  * NOTE(review): 1.0 presumably acts as a "not converged yet" sentinel for
  * callers comparing against a tolerance -- confirm at the call sites.
  */
 static inline double
 avg_change(size_t step, RTensor &cum, double delta)
 {
   const size_t window = cum.size();
   cum.at(step % window) = delta;
   if (step < window)
     return 1.0;
   return std::accumulate(cum.begin(), cum.end(), 0.0);
 }
コード例 #2
0
ファイル: block_svd.hpp プロジェクト: cadarso/tensor
  /*
   * Singular value decomposition of A that exploits block structure.
   *
   * find_blocks() splits A into `nblocks` independent blocks, each described
   * by a set of row indices and a set of column indices. Each block is
   * decomposed separately: a 1x1 block directly, larger ones with svd().
   * The resulting factors are scattered into the output. The singular
   * values are then reordered with sort_indices(s, true), presumably in
   * decreasing order (confirm against sort_indices). The columns of *pU and
   * the rows of *pVT are permuted to match.
   *
   * The function falls back to a plain svd(A, pU, pVT, economic) when:
   *  - A is rectangular and a full (non-economic) decomposition is requested,
   *  - find_blocks() reports failure,
   *  - there is a single block covering at least half of the rows and half
   *    of the columns, so blocking gains little,
   *  - a full decomposition is requested and some block is not square.
   *
   * A         matrix to decompose.
   * pU        if non-null, receives U: rows x minrc when economic,
   *           rows x rows otherwise.
   * pVT       if non-null, receives V^T: minrc x cols when economic,
   *           cols x cols otherwise.
   * economic  selects between the two shapes above.
   * Returns the minrc = min(rows, cols) singular values.
   */
  RTensor
  do_block_svd(const Tensor &A, Tensor *pU, Tensor *pVT, bool economic)
  {
    index rows = A.rows();
    index cols = A.columns();
    if (rows != cols && !economic)
      return svd(A, pU, pVT, economic);
    index minrc = std::min(rows, cols);

    index nblocks;
    Indices *block_rows, *block_cols;
    if (!find_blocks<Tensor>(A, &nblocks, &block_rows, &block_cols)) {
      // As before, the arrays are only released when find_blocks() succeeds.
      // NOTE(review): confirm it allocates nothing when it returns false.
      return svd(A, pU, pVT, economic);
    }
    // Own the index arrays allocated by find_blocks() so they are delete[]'d
    // on every return path, including when svd() throws.
    struct BlockIndexOwner {
      Indices *row_sets;
      Indices *col_sets;
      ~BlockIndexOwner() { delete[] row_sets; delete[] col_sets; }
    } owner = { block_rows, block_cols };

    if ((nblocks == 1) &&
        (block_rows[0].size() >= rows/2) &&
        (block_cols[0].size() >= cols/2)) {
      return svd(A, pU, pVT, economic);
    }

    if (!economic) {
      // A is square here, but an individual block need not be. The full U
      // (V^T) of an r x c block is r x r (c x c). That does not fit the
      // min(r,c)-wide slot reserved for it below, so let the dense routine
      // handle the whole matrix.
      for (index b = 0; b < nblocks; b++) {
        if (block_rows[b].size() != block_cols[b].size())
          return svd(A, pU, pVT, economic);
      }
    }

    RTensor s(minrc);
    s.fill_with_zeros();
    if (pU) {
      *pU = Tensor::zeros(rows, economic? minrc : rows);
    }
    if (pVT) {
      *pVT = Tensor::zeros(economic? minrc : cols, cols);
    }

    RTensor stemp;
    Tensor Utemp, Vtemp;
    // Only ask svd() for the block factors the caller actually wants.
    Tensor *pUtemp = pU? &Utemp : 0;
    Tensor *pVtemp = pVT? &Vtemp : 0;
    // sndx is the next free slot in s, i.e. the next column of U and the
    // next row of V^T.
    for (index b = 0, sndx = 0; b < nblocks; b++) {
      if (block_rows[b].size() == 0 || block_cols[b].size() == 0) {
        // An empty block has no singular values. Skipping it also avoids
        // the [0] indexing done by the 1x1 branch below.
        continue;
      }
      Tensor m = A(range(block_rows[b]), range(block_cols[b]));
      if (m.size() > 1) {
        stemp = svd(m, pUtemp, pVtemp, economic);
        index slast = sndx + stemp.size() - 1;
        s.at(range(sndx, slast)) = stemp;
        if (pU) {
          (*pU).at(range(block_rows[b]), range(sndx, slast)) = Utemp;
        }
        if (pVT) {
          (*pVT).at(range(sndx, slast), range(block_cols[b])) = Vtemp;
        }
        sndx = slast + 1;
      } else {
        // 1x1 block a = A(row, col), decomposed as a = 1 * |a| * (a / |a|).
        index row = block_rows[b][0];
        index col = block_cols[b][0];
        double aux = abs(m[0]);
        s.at(sndx) = aux;
        if (pU) {
          (*pU).at(row,sndx) = 1.0;
        }
        if (pVT) {
          // For an exact zero, any unit-modulus phase is valid. Using 1
          // avoids the 0/0 = NaN that the plain division would produce.
          if (aux == 0)
            (*pVT).at(sndx,col) = 1.0;
          else
            (*pVT).at(sndx,col) = m[0]/aux;
        }
        ++sndx;
      }
    }

    Indices ndx = sort_indices(s, true);
    s = s(range(ndx));
    if (pU)
      *pU = (*pU)(range(), range(ndx));
    if (pVT)
      *pVT = (*pVT)(range(ndx), range());
    return s;
  }