// Build a distributed evaluator for the contraction of `left` with `right`,
// backed by a TiledArray::detail::Summa implementation object.
//
// Dimension layout (taken from the rank queries on `op`):
//   - left  outer dims: [0, left_rank - num_contract_ranks)
//   - left  inner dims: [left_rank - num_contract_ranks, left_rank)
//   - right inner dims: [0, num_contract_ranks)
//   - right outer dims: [num_contract_ranks, right_rank)
// The result dims are the left outer dims followed by the right outer dims.
// When `perm` is non-empty, these are relocated according to `perm`.
//
// \param left   Left-hand argument evaluator
// \param right  Right-hand argument evaluator
// \param world  World passed to the implementation object and process grid
// \param shape  Shape of the result
// \param pmap   Process map of the result
// \param perm   Permutation applied to the result dimensions (may be empty)
// \param op     Contraction operation; supplies left/right/result/contracted ranks
// \return A DistEval that owns the newly created Summa implementation object
TiledArray::detail::DistEval<typename Op::result_type, Policy> make_contract_eval(
      const TiledArray::detail::DistEval<LeftTile, Policy>& left,
      const TiledArray::detail::DistEval<RightTile, Policy>& right,
      TiledArray::World& world,
      const typename TiledArray::detail::DistEval<typename Op::result_type, Policy>::shape_type& shape,
      const std::shared_ptr<typename TiledArray::detail::DistEval<typename Op::result_type, Policy>::pmap_interface>& pmap,
      const Permutation& perm,
      const Op& op)
  {
    TA_ASSERT(left.range().rank() == op.left_rank());
    TA_ASSERT(right.range().rank() == op.right_rank());
    TA_ASSERT((perm.dim() == op.result_rank()) || !perm);

    // Define the impl type
    typedef TiledArray::detail::Summa<
        TiledArray::detail::DistEval<LeftTile, Policy>,
        TiledArray::detail::DistEval<RightTile, Policy>, Op, Policy> impl_type;

    // Precompute iteration range data. All indices use std::size_t so the
    // loops below do not mix unsigned int and std::size_t.
    const std::size_t num_contract_ranks = op.num_contract_ranks();
    const std::size_t left_end = op.left_rank();
    const std::size_t right_end = op.right_rank();

    // The subtraction below is unsigned, and the right-hand loop starts at
    // num_contract_ranks; both require that neither argument has fewer dims
    // than are being contracted.
    TA_ASSERT(num_contract_ranks <= left_end);
    TA_ASSERT(num_contract_ranks <= right_end);
    const std::size_t left_middle = left_end - num_contract_ranks;

    // `ranges` is sized by op.result_rank() but receives one entry per outer
    // dimension of each argument; these counts must agree.
    TA_ASSERT(op.result_rank() == left_middle + (right_end - num_contract_ranks));

    // The contracted (inner) dimensions of left and right must be tiled
    // identically for their tiles to pair up.
    for(std::size_t i = 0ul; i < num_contract_ranks; ++i) {
      TA_ASSERT(left.trange().data()[left_middle + i] == right.trange().data()[i]);
    }

    // Construct a vector of TiledRange1 objects from the left- and right-hand
    // arguments that will be used to construct the result TiledRange. Also,
    // compute the fused outer dimension sizes: M/N are the tile counts and
    // m/n the element counts of the left/right outer dimensions.
    typename impl_type::trange_type::Ranges ranges(op.result_rank());
    std::size_t M = 1ul, m = 1ul, N = 1ul, n = 1ul;
    // pi is the index of the next result dimension before permutation
    std::size_t pi = 0ul;
    for(std::size_t i = 0ul; i < left_middle; ++i, ++pi) {
      ranges[perm ? perm[pi] : pi] = left.trange().data()[i];
      M *= left.range().extent_data()[i];
      m *= left.trange().elements().extent_data()[i];
    }
    for(std::size_t i = num_contract_ranks; i < right_end; ++i, ++pi) {
      ranges[perm ? perm[pi] : pi] = right.trange().data()[i];
      N *= right.range().extent_data()[i];
      n *= right.trange().elements().extent_data()[i];
    }

    // Compute the number of tiles in the fused inner (contracted) dimension.
    std::size_t K = 1ul;
    for(std::size_t i = left_middle; i < left_end; ++i)
      K *= left.range().extent_data()[i];

    // Construct the result range
    typename impl_type::trange_type trange(ranges.begin(), ranges.end());

    // Construct the process grid from the fused outer tile counts (M, N) and
    // element counts (m, n).
    TiledArray::detail::ProcGrid proc_grid(world, M, N, m, n);

    // make_shared: single allocation, no naked new
    return TiledArray::detail::DistEval<typename Op::result_type, Policy>(
        std::make_shared<impl_type>(left, right, world, trange,
            shape, pmap, perm, op, K, proc_grid));
  }
Ejemplo n.º 2
0
 // Build a distributed evaluator that combines `left` and `right` with `op`,
 // backed by a TiledArray::detail::BinaryEvalImpl implementation object.
 //
 // The result tiled range is taken from `left` (permuted by `perm` when `perm`
 // is non-empty), so `right` is required to have the same tiled range.
 //
 // \param left   Left-hand argument evaluator
 // \param right  Right-hand argument evaluator
 // \param world  World passed to the implementation object
 // \param shape  Shape of the result
 // \param pmap   Process map of the result
 // \param perm   Permutation applied to the result (may be empty)
 // \param op     Binary tile operation
 // \return A DistEval that owns the newly created BinaryEvalImpl object
 static TiledArray::detail::DistEval<typename Op::result_type, Policy>
 make_binary_eval(
     const TiledArray::detail::DistEval<LeftTile, Policy>& left,
     const TiledArray::detail::DistEval<RightTile, Policy>& right,
     TiledArray::World& world,
     const typename TiledArray::detail::DistEval<typename Op::result_type, Policy>::shape_type& shape,
     const std::shared_ptr<typename TiledArray::detail::DistEval<typename Op::result_type, Policy>::pmap_interface>& pmap,
     const Permutation& perm,
     const Op& op)
 {
   // Only left's tiled range is used for the result, so both must match.
   TA_ASSERT(left.trange() == right.trange());
   // Mirrors the permutation-rank precondition checked by make_contract_eval.
   TA_ASSERT((perm.dim() == left.range().rank()) || !perm);

   typedef TiledArray::detail::BinaryEvalImpl<
       TiledArray::detail::DistEval<LeftTile, Policy>,
       TiledArray::detail::DistEval<RightTile, Policy>, Op, Policy> impl_type;

   // make_shared: single allocation, no naked new
   return TiledArray::detail::DistEval<typename Op::result_type, Policy>(
       std::make_shared<impl_type>(left, right, world,
           (perm ? perm * left.trange() : left.trange()), shape, pmap, perm, op));
 }