Example #1
0
  /// Unblocked, upper-triangular factorization of a CRS view, performed
  /// cooperatively by one team using the 2x2 / 3x3 partition helpers.
  ///
  /// At each step the diagonal entry alpha11 is replaced by its square root,
  /// the row a12t is scaled by 1/alpha11, and A22 receives a hermitian rank
  /// update from a12t (her_r with factor -1.0).
  ///
  /// Because member.team_barrier() is used, every member of the team must
  /// execute this function; all members therefore also take the error
  /// return together.
  ///
  /// @param policy  task/team policy (not used by this variant)
  /// @param member  team member executing the call
  /// @param A       CRS view, factored in place
  /// @return 0 on success, or -(row + 1) if the diagonal of `row` is zero
  ///         or is not stored.
  KOKKOS_INLINE_FUNCTION
  int
  IChol<Uplo::Upper,AlgoIChol::Unblocked>
  ::invoke(typename CrsExecViewType::policy_type &policy,
           const typename CrsExecViewType::policy_type::member_type &member,
           CrsExecViewType &A) {
    typedef typename CrsExecViewType::value_type        value_type;
    typedef typename CrsExecViewType::row_view_type     row_view_type;

    CrsExecViewType ATL, ATR,      A00,  a01,     A02,
      /**/          ABL, ABR,      a10t, alpha11, a12t,
      /**/                         A20,  a21,     A22;

    Part_2x2(A,   ATL, ATR,
             /**/ ABL, ABR,
             0, 0, Partition::TopLeft);

    value_type zero = 0.0;
    row_view_type alpha;

    while (ATL.NumRows() < A.NumRows()) {
      Part_2x2_to_3x3(ATL, ATR, /**/  A00,  a01,     A02,
                      /*******/ /**/  a10t, alpha11, a12t,
                      ABL, ABR, /**/  A20,  a21,     A22,
                      1, 1, Partition::BottomRight);
      // -----------------------------------------------------

      // Extract the diagonal from alpha11. An empty row view, or a first
      // entry that is not column 0, means the diagonal is not stored; in
      // that case alpha_val aliases `zero` so the pivot check below fails.
      alpha.setView(alpha11, 0);
      value_type &alpha_val =
        ((alpha.NumNonZeros() == 0 || alpha.Col(0) != 0) ? zero : alpha.Value(0));

      // Every member performs this read-only check so the whole team returns
      // together. Returning from rank 0 alone would leave the other members
      // running the rest of the step on their own.
      if (abs(alpha_val) == 0.0)
        return -(ATL.NumRows() + 1);

      // All members have read the pivot before rank 0 overwrites it.
      member.team_barrier();
      if (member.team_rank() == 0)
        alpha_val = sqrt(real(alpha_val));   // sqrt on diag
      // Publish the new diagonal before the team scales with it.
      member.team_barrier();

      // sparse inverse scale
      scaleCrsMatrix<ParallelForType>(member, 1.0/real(alpha_val), a12t);
      // The rank update reads the scaled a12t.
      member.team_barrier();

      // hermitian rank update
      her_r<ParallelForType>(member, -1.0, a12t, A22);
      // The next step's pivot is in A22, so the update must be complete
      // before any member reads it.
      member.team_barrier();

      // -----------------------------------------------------
      Merge_3x3_to_2x2(A00,  a01,     A02,  /**/ ATL, ATR,
                       a10t, alpha11, a12t, /**/ /******/
                       A20,  a21,     A22,  /**/ ABL, ABR,
                       Partition::TopLeft);
    }

    return 0;
  }
  /// Row-by-row upper-triangular factorization of a CRS view, performed
  /// cooperatively by one team.
  ///
  /// For each row k:
  ///   - the diagonal (which must be the first stored entry) is replaced by
  ///     its square root;
  ///   - the off-diagonal entries are divided by it;
  ///   - for every pair i <= j of off-diagonal positions, row Col(i) is
  ///     updated with entry[Col(j)] -= conj(v_i) * v_j. Only entries that
  ///     already exist in that row are touched; no new entries are created.
  ///
  /// Because member.team_barrier() is used, every member of the team must
  /// execute this function; all members therefore also take the error
  /// return together.
  ///
  /// @param policy  task/team policy (not used by this variant)
  /// @param member  team member executing the call
  /// @param A       CRS view, factored in place
  /// @return 0 on success, or -(k + 1) if row k is empty, does not start
  ///         with its diagonal, or has a zero diagonal.
  KOKKOS_INLINE_FUNCTION
  int
  IChol<Uplo::Upper,AlgoIChol::UnblockedOpt2>
  ::invoke(typename CrsExecViewType::policy_type &policy,
           const typename CrsExecViewType::policy_type::member_type &member,
           CrsExecViewType &A) {

    typedef typename CrsExecViewType::value_type        value_type;
    typedef typename CrsExecViewType::ordinal_type      ordinal_type;
    typedef typename CrsExecViewType::row_view_type     row_view_type;
    typedef typename CrsExecViewType::team_factory_type team_factory_type;

    for (ordinal_type k=0;k<A.NumRows();++k) {
      row_view_type &r1t = A.RowView(k);
      const ordinal_type nnz_r1t = r1t.NumNonZeros();

      // If we encounter an empty row, a wrong leading index or a null
      // diagonal, return -(row + 1). nnz is checked first so that Col(0)
      // and Value(0) are never read on an empty row. Every member evaluates
      // this read-only check so the whole team returns together; a
      // rank-0-only return would leave the others stuck at the barrier.
      if (nnz_r1t == 0 || r1t.Col(0) != k || abs(r1t.Value(0)) == 0.0)
        return -(k + 1);

      // diagonal entry of row k (the first stored entry)
      value_type &alpha = r1t.Value(0);

      // All members have read the pivot before rank 0 overwrites it.
      member.team_barrier();
      if (member.team_rank() == 0)
        alpha = sqrt(real(alpha));   // sqrt on diag
      member.team_barrier();

      // inverse scale
      ParallelForType(team_factory_type::createThreadLoopRegion(member, 1, nnz_r1t),
                      [&](const ordinal_type j) {
                        r1t.Value(j) /= alpha;
                      });

      member.team_barrier();

      // hermitian rank update
      for (ordinal_type i=1;i<nnz_r1t;++i) {
        const ordinal_type row_at_i = r1t.Col(i);
        const value_type   val_at_i = conj(r1t.Value(i));

        row_view_type &r2t = A.RowView(row_at_i);

        // Search hint for r2t.Index, indexed by team rank. The loop stops
        // once Index returns -2; NOTE(review): -2 presumably means
        // "search ran past the end of r2t" — confirm against CrsRowView.
        ordinal_type idx_team[MAX_TEAM_SIZE] = {};
        ParallelForType(team_factory_type::createThreadLoopRegion(member, i, nnz_r1t),
                        [&](const ordinal_type j) {
                          ordinal_type &idx = idx_team[member.team_rank()];
                          if (idx > -2) {
                            const ordinal_type col_at_j = r1t.Col(j);
                            idx = r2t.Index(col_at_j, idx);
                            if (idx >= 0) {
                              const value_type   val_at_j = r1t.Value(j);
                              r2t.Value(idx) -= val_at_i*val_at_j;
                            }
                          }
                        });
      }

      // The rows updated above can include row k+1, whose diagonal is read
      // at the start of the next iteration.
      member.team_barrier();
    }
    return 0;
  }
  /// Row-by-row upper-triangular factorization of a CRS view, performed by
  /// one team using Kokkos TeamThreadRange loops.
  ///
  /// For each row k:
  ///   - the diagonal (which must be the first stored entry) is replaced by
  ///     its square root;
  ///   - the off-diagonal entries are divided by it;
  ///   - each team thread takes one off-diagonal position i and updates row
  ///     Col(i) with entry[Col(j)] -= conj(v_i) * v_j for j >= i. Only
  ///     entries already present in that row are updated.
  ///
  /// Because member.team_barrier() is used, every member of the team must
  /// execute this function; all members therefore also take the error
  /// return together.
  ///
  /// @param policy  task/team policy (not used by this variant)
  /// @param member  team member executing the call
  /// @param A       CRS view, factored in place
  /// @return 0 on success, or -(k + 1) if row k is empty, does not start
  ///         with its diagonal, or has a zero diagonal.
  KOKKOS_INLINE_FUNCTION
  int
  Chol<Uplo::Upper,AlgoChol::UnblockedOpt,Variant::One>
  ::invoke(typename CrsExecViewType::policy_type &policy,
           const typename CrsExecViewType::policy_type::member_type &member,
           CrsExecViewType &A) {

    typedef typename CrsExecViewType::value_type        value_type;
    typedef typename CrsExecViewType::ordinal_type      ordinal_type;
    typedef typename CrsExecViewType::row_view_type     row_view_type;

    for (ordinal_type k=0;k<A.NumRows();++k) {
      row_view_type &r1t = A.RowView(k);
      const ordinal_type nnz_r1t = r1t.NumNonZeros();

      // If we encounter an empty row, a wrong leading index or a null
      // diagonal, return -(row + 1). nnz is checked first so that Col(0)
      // and Value(0) are never read on an empty row. Every member evaluates
      // this read-only check so the whole team returns together; a
      // rank-0-only return would leave the others stuck at the barrier.
      if (nnz_r1t == 0 || r1t.Col(0) != k || abs(r1t.Value(0)) == 0.0)
        return -(k + 1);

      // diagonal entry of row k (the first stored entry)
      value_type &alpha = r1t.Value(0);

      // All members have read the pivot before rank 0 overwrites it.
      member.team_barrier();
      if (member.team_rank() == 0)
        alpha = sqrt(real(alpha));   // sqrt on diag
      member.team_barrier();

      // inverse scale
      Kokkos::parallel_for(Kokkos::TeamThreadRange(member, 1, nnz_r1t),
                           [&](const ordinal_type j) {
                             r1t.Value(j) /= alpha;
                           });

      member.team_barrier();

      // Hermitian rank update. Each i targets a distinct row Col(i), so
      // threads never write the same row concurrently.
      Kokkos::parallel_for(Kokkos::TeamThreadRange(member, 1, nnz_r1t),
                           [&](const ordinal_type i) {
                             const ordinal_type row_at_i = r1t.Col(i);
                             const value_type   val_at_i = conj(r1t.Value(i));

                             row_view_type &r2t = A.RowView(row_at_i);
                             // search hint for r2t.Index; scan stops on -2
                             ordinal_type idx = 0;

                             for (ordinal_type j=i;j<nnz_r1t && (idx > -2);++j) {
                               const ordinal_type col_at_j = r1t.Col(j);
                               idx = r2t.Index(col_at_j, idx);

                               if (idx >= 0) {
                                 const value_type val_at_j = r1t.Value(j);
                                 r2t.Value(idx) -= val_at_i*val_at_j;
                               }
                             }
                           });

      // Nested TeamThreadRange loops imply no barrier. The rows updated above
      // can include row k+1, whose diagonal is read in the next iteration.
      member.team_barrier();
    }
    return 0;
  }