// Zips two string DIAs that are derived from the same cached input by
// complementary filters, and compares the distributed result against a
// locally recomputed one.
TEST(ZipNode, TwoDisbalancedStringArrays) {

    // first DIA is heavily balanced to the first workers, second DIA is
    // balanced to the last workers.
    std::function<void(Context&)> start_func =
        [](Context& ctx) {

            // generate random strings with 10..20 characters, followed by the
            // decimal index of the item
            auto input_gen = Generate(
                ctx,
                [](size_t index) -> std::string {
                    std::default_random_engine rng(
                        123456 + static_cast<unsigned>(index));
                    std::uniform_int_distribution<size_t> length(10, 20);
                    rng(); // skip one number

                    return common::RandomString(
                        length(rng), rng, "abcdefghijklmnopqrstuvwxyz")
                           + std::to_string(index);
                },
                test_size);

            DIA<std::string> input = input_gen.Cache();

            std::vector<std::string> vinput = input.AllGather();
            ASSERT_EQ(test_size, vinput.size());

            // keep only the strings whose first character is <= 'e' (a-e)
            auto input1 = input.Filter(
                [](const std::string& str) { return str[0] <= 'e'; });

            // keep only the strings whose first character is >= 'w' (w-z)
            auto input2 = input.Filter(
                [](const std::string& str) { return str[0] >= 'w'; });

            // zip: concatenate the i-th item of input1 with the i-th of input2
            auto zip_result = input1.Zip(
                input2, [](const std::string& a, const std::string& b) {
                    return a + b;
                });

            // check result
            std::vector<std::string> res = zip_result.AllGather();

            // recalculate result locally
            std::vector<std::string> check;
            {
                std::vector<std::string> v1, v2;

                for (const std::string& s1 : vinput) {
                    if (s1[0] <= 'e') v1.push_back(s1);
                    if (s1[0] >= 'w') v2.push_back(s1);
                }

                ASSERT_EQ(v1, input1.AllGather());
                ASSERT_EQ(v2, input2.AllGather());

                // the expected zip is as long as the shorter filtered list
                const size_t zip_size = std::min(v1.size(), v2.size());
                check.reserve(zip_size);
                for (size_t i = 0; i < zip_size; ++i) {
                    check.push_back(v1[i] + v2[i]);
                    // sLOG1 << check.back();
                }
            }

            // Debug output of both results side by side. The sizes are only
            // asserted below, so stay within the shorter vector: otherwise
            // enabling this log on a failing run would index out of bounds.
            const size_t log_size = std::min(res.size(), check.size());
            for (size_t i = 0; i != log_size; ++i) {
                sLOG0 << res[i] << " " << check[i] << (res[i] == check[i]);
            }

            ASSERT_EQ(check.size(), res.size());
            ASSERT_EQ(check, res);
        };

    api::RunLocalTests(start_func);
}
ValueType Select(const DIA<ValueType, InStack>& data, size_t rank, const Compare& compare = Compare()) { api::Context& ctx = data.context(); const size_t size = data.Size(); assert(0 <= rank && rank < size); if (size <= base_case_size) { // base case, gather all data at worker with rank 0 ValueType result = ValueType(); auto elements = data.Gather(); if (ctx.my_rank() == 0) { assert(rank < elements.size()); std::nth_element(elements.begin(), elements.begin() + rank, elements.end(), compare); result = elements[rank]; LOG << "base case: " << size << " elements remaining, result is " << result; } result = ctx.net.Broadcast(result); return result; } ValueType left_pivot, right_pivot; std::tie(left_pivot, right_pivot) = PickPivots(data, size, rank, compare); size_t left_size, middle_size, right_size; using PartSizes = std::pair<size_t, size_t>; std::tie(left_size, middle_size) = data.Map( [&](const ValueType& elem) -> PartSizes { if (compare(elem, left_pivot)) return PartSizes { 1, 0 }; else if (!compare(right_pivot, elem)) return PartSizes { 0, 1 }; else return PartSizes { 0, 0 }; }) .Sum( [](const PartSizes& a, const PartSizes& b) -> PartSizes { return PartSizes { a.first + b.first, a.second + b.second }; }, PartSizes { 0, 0 }); right_size = size - left_size - middle_size; LOGM << "left_size = " << left_size << ", middle_size = " << middle_size << ", right_size = " << right_size << ", rank = " << rank; if (rank == left_size) { // all the elements strictly smaller than the left pivot are on the left // side -> left_size-th element is the left pivot LOGM << "result is left pivot: " << left_pivot; return left_pivot; } else if (rank == left_size + middle_size - 1) { // only the elements strictly greater than the right pivot are on the // right side, so the result is the right pivot in this case LOGM << "result is right pivot: " << right_pivot; return right_pivot; } else if (rank < left_size) { // recurse on the left partition LOGM << "Recursing left, " << left_size << " 
elements remaining (rank = " << rank << ")\n"; auto left = data.Filter( [&](const ValueType& elem) -> bool { return compare(elem, left_pivot); }).Collapse(); assert(left.Size() == left_size); return Select(left, rank, compare); } else if (left_size + middle_size <= rank) { // recurse on the right partition LOGM << "Recursing right, " << right_size << " elements remaining (rank = " << rank - left_size - middle_size << ")\n"; auto right = data.Filter( [&](const ValueType& elem) -> bool { return compare(right_pivot, elem); }).Collapse(); assert(right.Size() == right_size); return Select(right, rank - left_size - middle_size, compare); } else { // recurse on the middle partition LOGM << "Recursing middle, " << middle_size << " elements remaining (rank = " << rank - left_size << ")\n"; auto middle = data.Filter( [&](const ValueType& elem) -> bool { return !compare(elem, left_pivot) && !compare(right_pivot, elem); }).Collapse(); assert(middle.Size() == middle_size); return Select(middle, rank - left_size, compare); } }