Ejemplo n.º 1
0
// TODO(tb): this is disabled because the double ReduceBy creates some deadlock
// in destruction of Channels. We have to fix that later.
//
// Checks node_refcount() of DIAs that share the Generate node (through a
// FlatMap) and of a cached node that has two ReduceBy children.
TEST(Stage, DISABLED_AdditionalChildReferences) {

    std::function<void(Context&)> start_func =
        [](Context& ctx) {

            // 16 items with values 1..16
            auto integers = Generate(
                ctx,
                [](const size_t& index) {
                    return static_cast<int>(index) + 1;
                },
                16);

            // emits every item twice
            auto duplicate_elements = [](int in, auto emit) {
                                          emit(in);
                                          emit(in);
                                      };

            // reduction key: parity of the item (0 or 1)
            auto modulo_two = [](int in) {
                                  return (in % 2);
                              };

            auto add_function = [](int in1, int in2) {
                                    return in1 + in2;
                                };

            // Create a new DIA reference to Generate
            auto doubles = integers.FlatMap(duplicate_elements);

            // Create a child reference to Generate
            // Create a new DIA reference to LOpNode
            DIA<int> quadruples = integers.FlatMap(duplicate_elements).Cache();

            // Create a child reference to LOpNode
            DIA<int> octuples = quadruples.ReduceBy(modulo_two, add_function).Cache();
            // Create a second child reference to LOpNode
            DIA<int> octuples_second = quadruples.ReduceBy(modulo_two, add_function).Cache();

            // Trigger execution
            std::vector<int> out_vec = octuples.AllGather();

            // quadruples holds 1..16 twice each, so reducing by parity yields
            // one item per key: 2 * (1 + 3 + ... + 15) = 128 and
            // 2 * (2 + 4 + ... + 16) = 144. The order is not relied upon.
            ASSERT_EQ(2u, out_vec.size());
            ASSERT_EQ(128 + 144, out_vec[0] + out_vec[1]);

            // 2x DIA reference + 1x child reference
            ASSERT_EQ(integers.node_refcount(), 3u);
            ASSERT_EQ(doubles.node_refcount(), 3u);
            // 1x DIA reference + 2x child references
            ASSERT_EQ(quadruples.node_refcount(), 3u);
            // 1x DIA reference + 0x child references
            ASSERT_EQ(octuples.node_refcount(), 1u);
            // 1x DIA reference + 0x child references
            ASSERT_EQ(octuples_second.node_refcount(), 1u);
        };

    api::RunLocalTests(start_func);
}
Ejemplo n.º 2
0
// Repeatedly duplicates and doubles a DIA inside a while loop whose condition
// depends on the DIA's current size, then checks the final size.
TEST(Graph, WhileLoop) {

    std::function<void(Context&)> start_func =
        [](Context& ctx) {

            // 16 items with values 0..15
            auto integers = Generate(
                ctx,
                [](const size_t& index) -> size_t {
                    return index;
                },
                16);

            // emits every item twice, doubling the number of items
            auto flatmap_duplicate = [](size_t in, auto emit) {
                                         emit(in);
                                         emit(in);
                                     };

            // doubles every value
            auto map_multiply = [](size_t in) {
                                    return 2 * in;
                                };

            DIA<size_t> values = integers.Collapse();
            size_t num_items = 0;

            // Each iteration doubles the item count, so the loop runs twice:
            // 16 -> 32 -> 64 items, stopping once the size reaches 64.
            while (num_items < 64) {
                auto pairs = values.FlatMap(flatmap_duplicate);
                auto multiplied = pairs.Map(map_multiply);
                values = multiplied.Cache();
                num_items = values.Size();
            }

            std::vector<size_t> out_vec = values.AllGather();

            ASSERT_EQ(64u, out_vec.size());
            ASSERT_EQ(64u, values.Size());

            // dump the stats graph layout (output target "loop.out")
            ctx.stats_graph().BuildLayout("loop.out");
        };

    api::RunLocalTests(start_func);
}
Ejemplo n.º 3
0
// Zips two DIAs obtained by filtering the same cached input with different
// predicates, and compares the result against a locally computed expectation.
TEST(ZipNode, TwoDisbalancedStringArrays) {

    // first DIA is heavily balanced to the first workers, second DIA is
    // balanced to the last workers.
    std::function<void(Context&)> start_func =
        [](Context& ctx) {

            // generate random strings: 10..20 lowercase letters followed by
            // the item's index
            auto input_gen = Generate(
                ctx,
                [](size_t index) -> std::string {
                    // seeded from the index only, so each string is
                    // deterministic for a given index
                    std::default_random_engine rng(
                        123456 + static_cast<unsigned>(index));
                    std::uniform_int_distribution<size_t> length(10, 20);
                    rng(); // skip one number

                    return common::RandomString(
                        length(rng), rng, "abcdefghijklmnopqrstuvwxyz")
                    + std::to_string(index);
                },
                test_size);

            DIA<std::string> input = input_gen.Cache();

            std::vector<std::string> vinput = input.AllGather();
            ASSERT_EQ(test_size, vinput.size());

            // keep only strings that start with a-e
            auto input1 = input.Filter(
                [](const std::string& str) { return str[0] <= 'e'; });

            // keep only strings that start with w-z
            auto input2 = input.Filter(
                [](const std::string& str) { return str[0] >= 'w'; });

            // zip: concatenate the i-th items of both inputs
            auto zip_result = input1.Zip(
                input2, [](const std::string& a, const std::string& b) {
                    return a + b;
                });

            std::vector<std::string> res = zip_result.AllGather();

            // recalculate the expected result locally from the gathered input
            std::vector<std::string> check;
            {
                std::vector<std::string> v1, v2;

                for (const std::string& s : vinput) {
                    if (s[0] <= 'e') v1.push_back(s);
                    if (s[0] >= 'w') v2.push_back(s);
                }

                ASSERT_EQ(v1, input1.AllGather());
                ASSERT_EQ(v2, input2.AllGather());

                // the expected zip output is as long as the shorter input
                const size_t zip_size = std::min(v1.size(), v2.size());
                for (size_t i = 0; i < zip_size; ++i) {
                    check.push_back(v1[i] + v2[i]);
                }
            }

            // Log side by side. The bound uses both sizes so that a length
            // mismatch cannot index past the end of check; the mismatch
            // itself is reported by the assertion below.
            const size_t log_size = std::min(res.size(), check.size());
            for (size_t i = 0; i < log_size; ++i) {
                sLOG0 << res[i] << " " << check[i] << (res[i] == check[i]);
            }

            ASSERT_EQ(check.size(), res.size());
            ASSERT_EQ(check, res);
        };

    api::RunLocalTests(start_func);
}