/// @brief Build a distributed evaluator that contracts @p left with @p right
///        using the SUMMA implementation (`TiledArray::detail::Summa`).
///
/// The outer (non-contracted) dimensions of the left argument are its first
/// `op.left_rank() - op.num_contract_ranks()` dimensions; the outer dimensions
/// of the right argument are its dimensions `[op.num_contract_ranks(),
/// op.right_rank())`. These outer dimensions, in that order, form the result
/// tiled range (each one relocated through @p perm when @p perm is
/// non-trivial).
///
/// @param left  Left-hand argument evaluator
/// @param right Right-hand argument evaluator
/// @param world The world handed to the process grid and the Summa impl
/// @param shape Shape of the result
/// @param pmap  Process map of the result
/// @param perm  Permutation applied to the result (may be empty/trivial)
/// @param op    Contraction operation; supplies the rank bookkeeping
/// @return A DistEval wrapping a newly allocated Summa implementation
TiledArray::detail::DistEval<typename Op::result_type, Policy> make_contract_eval(
    const TiledArray::detail::DistEval<LeftTile, Policy>& left,
    const TiledArray::detail::DistEval<RightTile, Policy>& right,
    TiledArray::World& world,
    const typename TiledArray::detail::DistEval<typename Op::result_type, Policy>::shape_type& shape,
    const std::shared_ptr<typename TiledArray::detail::DistEval<typename Op::result_type, Policy>::pmap_interface>& pmap,
    const Permutation& perm, const Op& op)
{
  // The argument ranks must agree with what the operation expects, and a
  // non-trivial permutation must act on exactly the result rank.
  TA_ASSERT(left.range().rank() == op.left_rank());
  TA_ASSERT(right.range().rank() == op.right_rank());
  TA_ASSERT((perm.dim() == op.result_rank()) || !perm);

  using left_eval_type = TiledArray::detail::DistEval<LeftTile, Policy>;
  using right_eval_type = TiledArray::detail::DistEval<RightTile, Policy>;
  using result_eval_type = TiledArray::detail::DistEval<typename Op::result_type, Policy>;
  using impl_type = TiledArray::detail::Summa<left_eval_type, right_eval_type, Op, Policy>;

  // Rank bookkeeping: left = [outer | contracted], right = [contracted | outer].
  const unsigned int contracted_rank = op.num_contract_ranks();
  const unsigned int left_rank = op.left_rank();
  const unsigned int left_outer_rank = left_rank - contracted_rank;
  const unsigned int right_rank = op.right_rank();

  // Maps the k-th result dimension (pre-permutation) to its final slot.
  auto result_slot = [&perm](const std::size_t k) -> std::size_t {
    return perm ? perm[k] : k;
  };

  // Gather the result TiledRange1 objects and, at the same time, fuse the
  // outer dimensions into matrix-like sizes (tiles and elements) for SUMMA.
  typename impl_type::trange_type::Ranges result_ranges(op.result_rank());
  std::size_t row_tiles = 1ul;
  std::size_t row_elements = 1ul;
  std::size_t col_tiles = 1ul;
  std::size_t col_elements = 1ul;
  std::size_t next_dim = 0ul;

  for (std::size_t d = 0ul; d < left_outer_rank; ++d, ++next_dim) {
    result_ranges[result_slot(next_dim)] = left.trange().data()[d];
    row_tiles *= left.range().extent_data()[d];
    row_elements *= left.trange().elements().extent_data()[d];
  }

  for (std::size_t d = contracted_rank; d < right_rank; ++d, ++next_dim) {
    result_ranges[result_slot(next_dim)] = right.trange().data()[d];
    col_tiles *= right.range().extent_data()[d];
    col_elements *= right.trange().elements().extent_data()[d];
  }

  // Number of tiles along the fused contracted (inner) dimension, taken from
  // the trailing dimensions of the left argument.
  std::size_t inner_tiles = 1ul;
  for (std::size_t d = left_outer_rank; d < left_rank; ++d)
    inner_tiles *= left.range().extent_data()[d];

  typename impl_type::trange_type result_trange(result_ranges.begin(),
                                                result_ranges.end());

  TiledArray::detail::ProcGrid grid(world, row_tiles, col_tiles,
                                    row_elements, col_elements);

  return result_eval_type(std::shared_ptr<impl_type>(
      new impl_type(left, right, world, result_trange, shape, pmap, perm, op,
                    inner_tiles, grid)));
}
/// @brief Build a distributed evaluator that applies the element-wise binary
///        operation @p op to @p left and @p right
///        (`TiledArray::detail::BinaryEvalImpl`).
///
/// The result uses the tiled range of the left argument, permuted by
/// @p perm when @p perm is non-trivial.
///
/// @param left  Left-hand argument evaluator
/// @param right Right-hand argument evaluator
/// @param world The world handed to the implementation
/// @param shape Shape of the result
/// @param pmap  Process map of the result
/// @param perm  Permutation applied to the result (may be empty/trivial)
/// @param op    Binary tile operation
/// @return A DistEval wrapping a newly allocated BinaryEvalImpl
static TiledArray::detail::DistEval<typename Op::result_type, Policy> make_binary_eval(
    const TiledArray::detail::DistEval<LeftTile, Policy>& left,
    const TiledArray::detail::DistEval<RightTile, Policy>& right,
    TiledArray::World& world,
    const typename TiledArray::detail::DistEval<typename Op::result_type, Policy>::shape_type& shape,
    const std::shared_ptr<typename TiledArray::detail::DistEval<typename Op::result_type, Policy>::pmap_interface>& pmap,
    const Permutation& perm, const Op& op)
{
  using left_eval_type = TiledArray::detail::DistEval<LeftTile, Policy>;
  using right_eval_type = TiledArray::detail::DistEval<RightTile, Policy>;
  using result_eval_type = TiledArray::detail::DistEval<typename Op::result_type, Policy>;
  using impl_type =
      TiledArray::detail::BinaryEvalImpl<left_eval_type, right_eval_type, Op, Policy>;

  // The result tiling is the left argument's, relocated through perm when
  // perm is non-trivial.
  return result_eval_type(std::shared_ptr<impl_type>(
      new impl_type(left, right, world,
                    (perm ? perm * left.trange() : left.trange()),
                    shape, pmap, perm, op)));
}