Esempio n. 1
0
TEST(ZipNode, TwoDisbalancedStringArrays) {

    // Zips two DIAs that are derived from the same cached input by two
    // different filters and compares the distributed result with a locally
    // recomputed one.
    // NOTE(review): per the test name the two filtered DIAs are presumably
    // meant to be distributed unevenly over the workers, so that Zip has to
    // realign them -- confirm.
    std::function<void(Context&)> start_func =
        [](Context& ctx) {

            // generate random lowercase strings with 10..20 characters, each
            // followed by the decimal item index
            auto input_gen = Generate(
                ctx,
                [](size_t index) -> std::string {
                    // seed depends only on the index, so every run produces
                    // the same string for the same index
                    std::default_random_engine rng(
                        123456 + static_cast<unsigned>(index));
                    std::uniform_int_distribution<size_t> length(10, 20);
                    rng(); // skip one number

                    return common::RandomString(
                        length(rng), rng, "abcdefghijklmnopqrstuvwxyz")
                    + std::to_string(index);
                },
                test_size);

            DIA<std::string> input = input_gen.Cache();

            std::vector<std::string> vinput = input.AllGather();
            ASSERT_EQ(test_size, vinput.size());

            // keep only the strings that start with a-e
            auto input1 = input.Filter(
                [](const std::string& str) { return str[0] <= 'e'; });

            // keep only the strings that start with w-z
            auto input2 = input.Filter(
                [](const std::string& str) { return str[0] >= 'w'; });

            // zip: concatenate the i-th items of both filtered DIAs
            auto zip_result = input1.Zip(
                input2, [](const std::string& a, const std::string& b) {
                    return a + b;
                });

            // check result
            std::vector<std::string> res = zip_result.AllGather();

            // recalculate result locally
            std::vector<std::string> check;
            {
                std::vector<std::string> v1, v2;

                // apply the same two predicates as the filters above
                for (const std::string& s1 : vinput) {
                    if (s1[0] <= 'e') v1.push_back(s1);
                    if (s1[0] >= 'w') v2.push_back(s1);
                }

                ASSERT_EQ(v1, input1.AllGather());
                ASSERT_EQ(v2, input2.AllGather());

                // the expected zip has as many items as the shorter input
                const size_t zip_size = std::min(v1.size(), v2.size());
                check.reserve(zip_size);
                for (size_t i = 0; i < zip_size; ++i) {
                    check.push_back(v1[i] + v2[i]);
                    // sLOG1 << check.back();
                }
            }

            // Debug output of both results side by side. The sizes have not
            // been verified yet, so only the common prefix is indexed; this
            // keeps the loop in bounds even when the log is switched on to
            // diagnose a size mismatch.
            const size_t common_size = std::min(res.size(), check.size());
            for (size_t i = 0; i != common_size; ++i) {
                sLOG0 << res[i] << " " << check[i] << (res[i] == check[i]);
            }

            ASSERT_EQ(check.size(), res.size());
            ASSERT_EQ(check, res);
        };

    api::RunLocalTests(start_func);
}
Esempio n. 2
0
/*!
 * Returns the element that has 0-based position \p rank in the order defined
 * by \p compare among all items of \p data.
 *
 * Inputs of at most base_case_size items are gathered, selected with
 * std::nth_element on the worker with rank 0, and the result is broadcast.
 * Larger inputs are split around two pivots into a left part
 * (elem < left_pivot), a middle part (left_pivot <= elem <= right_pivot) and a
 * right part (elem > right_pivot); the function then either returns a pivot
 * directly or recurses into the part that contains the requested rank.
 *
 * NOTE(review): ValueType, InStack and Compare are not declared in the visible
 * snippet; presumably they are template parameters declared directly above
 * this function -- confirm.
 *
 * \param data    input DIA
 * \param rank    0-based rank of the requested element, must be < data.Size()
 * \param compare strict ordering used for all comparisons
 * \return the selected element
 */
ValueType Select(const DIA<ValueType, InStack>& data, size_t rank,
                 const Compare& compare = Compare()) {
    api::Context& ctx = data.context();
    const size_t size = data.Size();

    // rank is unsigned, so only the upper bound can be violated
    assert(rank < size);

    if (size <= base_case_size) {
        // small input: collect everything and select sequentially
        ValueType answer = ValueType();
        auto gathered = data.Gather();

        // only the worker with rank 0 computes the answer
        if (ctx.my_rank() == 0) {
            assert(rank < gathered.size());
            std::nth_element(gathered.begin(), gathered.begin() + rank,
                             gathered.end(), compare);

            answer = gathered[rank];

            LOG << "base case: " << size << " elements remaining, result is "
                << answer;
        }

        // hand the answer of worker 0 to all workers
        return ctx.net.Broadcast(answer);
    }

    // NOTE(review): the pivot shortcuts below presumably rely on PickPivots
    // returning two items of data with left_pivot not greater than
    // right_pivot -- confirm against PickPivots.
    ValueType left_pivot, right_pivot;
    std::tie(left_pivot, right_pivot) = PickPivots(data, size, rank, compare);

    // (number of items in the left part, number of items in the middle part)
    using PartSizes = std::pair<size_t, size_t>;

    // maps an item to a unit count for the part it belongs to; items of the
    // right part contribute to neither counter
    auto classify =
        [&](const ValueType& elem) -> PartSizes {
            if (compare(elem, left_pivot))
                return PartSizes { 1, 0 };
            return compare(right_pivot, elem)
                   ? PartSizes { 0, 0 } : PartSizes { 0, 1 };
        };

    // component-wise addition of two counter pairs
    auto add_counts =
        [](const PartSizes& a, const PartSizes& b) -> PartSizes {
            return PartSizes { a.first + b.first, a.second + b.second };
        };

    const PartSizes counts =
        data.Map(classify).Sum(add_counts, PartSizes { 0, 0 });

    const size_t num_left = counts.first;
    const size_t num_middle = counts.second;
    const size_t num_right = size - num_left - num_middle;
    // rank of the first item that belongs to the right part
    const size_t right_begin = num_left + num_middle;

    LOGM << "left_size = " << num_left << ", middle_size = " << num_middle
         << ", right_size = " << num_right << ", rank = " << rank;

    if (rank == num_left) {
        // exactly the items smaller than the left pivot precede this rank,
        // hence the requested item is the left pivot itself
        LOGM << "result is left pivot: " << left_pivot;
        return left_pivot;
    }

    if (rank == right_begin - 1) {
        // exactly the items greater than the right pivot follow this rank,
        // hence the requested item is the right pivot itself
        LOGM << "result is right pivot: " << right_pivot;
        return right_pivot;
    }

    if (rank < num_left) {
        // continue in the left part, the rank stays unchanged
        LOGM << "Recursing left, " << num_left
             << " elements remaining (rank = " << rank << ")\n";

        auto left_part = data.Filter(
            [&](const ValueType& elem) -> bool {
                return compare(elem, left_pivot);
            }).Collapse();
        assert(left_part.Size() == num_left);

        return Select(left_part, rank, compare);
    }

    if (right_begin <= rank) {
        // continue in the right part, the rank is shifted by the number of
        // items in the left and middle parts
        const size_t shifted_rank = rank - right_begin;

        LOGM << "Recursing right, " << num_right
             << " elements remaining (rank = " << shifted_rank
             << ")\n";

        auto right_part = data.Filter(
            [&](const ValueType& elem) -> bool {
                return compare(right_pivot, elem);
            }).Collapse();
        assert(right_part.Size() == num_right);

        return Select(right_part, shifted_rank, compare);
    }

    // remaining case: continue in the middle part, the rank is shifted by the
    // number of items in the left part
    const size_t shifted_rank = rank - num_left;

    LOGM << "Recursing middle, " << num_middle
         << " elements remaining (rank = " << shifted_rank << ")\n";

    auto middle_part = data.Filter(
        [&](const ValueType& elem) -> bool {
            return !(compare(elem, left_pivot) ||
                     compare(right_pivot, elem));
        }).Collapse();
    assert(middle_part.Size() == num_middle);

    return Select(middle_part, shifted_rank, compare);
}