// TODO(tb): this is disabled because the double ReduceBy creates some deadlock // in destruction of Channels. We have to fix that later. TEST(Stage, DISABLED_AdditionalChildReferences) { std::function<void(Context&)> start_func = [](Context& ctx) { auto integers = Generate( ctx, [](const size_t& index) { return static_cast<int>(index) + 1; }, 16); auto duplicate_elements = [](int in, auto emit) { emit(in); emit(in); }; auto modulo_two = [](int in) { return (in % 2); }; auto add_function = [](int in1, int in2) { return in1 + in2; }; // Create a new DIA references to Generate auto doubles = integers.FlatMap(duplicate_elements); // Create a child references to Generate // Create a new DIA reference to LOpNode DIA<int> quadruples = integers.FlatMap(duplicate_elements).Cache(); // Create a child reference to LOpNode DIA<int> octuples = quadruples.ReduceBy(modulo_two, add_function).Cache(); // Create a child reference to LOpNode DIA<int> octuples_second = quadruples.ReduceBy(modulo_two, add_function).Cache(); // Trigger execution std::vector<int> out_vec = octuples.AllGather(); // 2x DIA reference + 1x child reference ASSERT_EQ(integers.node_refcount(), 3u); ASSERT_EQ(doubles.node_refcount(), 3u); // 1x DIA reference + 2x child reference ASSERT_EQ(quadruples.node_refcount(), 3u); // 1x DIA reference + 0x child reference ASSERT_EQ(octuples.node_refcount(), 1u); // 1x DIA reference + 0x child reference ASSERT_EQ(octuples_second.node_refcount(), 1u); }; api::RunLocalTests(start_func); }
// Runs a data-dependent while loop that repeatedly duplicates and doubles the
// items of a DIA until it holds at least target_size items, checks the final
// size, and writes the stats graph layout to "loop.out".
TEST(Graph, WhileLoop) {
    std::function<void(Context&)> start_func =
        [](Context& ctx) {
            // number of items at which the loop stops
            const size_t target_size = 64;

            // 16 items: 0, 1, ..., 15
            auto integers = Generate(
                ctx,
                [](const size_t& index) -> size_t {
                    return index;
                },
                16);

            // emits every item twice, doubling the item count
            auto flatmap_duplicate = [](size_t in, auto emit) {
                                         emit(in);
                                         emit(in);
                                     };

            auto map_multiply = [](size_t in) {
                                    return 2 * in;
                                };

            DIA<size_t> values = integers.Collapse();

            size_t current_size = 0;

            // Each iteration doubles the number of items, so starting from 16
            // items the loop runs twice: 16 -> 32 -> 64.
            while (current_size < target_size) {
                auto pairs = values.FlatMap(flatmap_duplicate);
                auto multiplied = pairs.Map(map_multiply);
                values = multiplied.Cache();
                current_size = values.Size();
            }

            std::vector<size_t> out_vec = values.AllGather();

            ASSERT_EQ(target_size, out_vec.size());
            ASSERT_EQ(target_size, values.Size());

            ctx.stats_graph().BuildLayout("loop.out");
        };

    api::RunLocalTests(start_func);
}
// Zips two filtered views of the same cached input of random strings and
// compares the result against an expectation computed locally from the
// gathered input.
TEST(ZipNode, TwoDisbalancedStringArrays) {
    // first DIA is heavily balanced to the first workers, second DIA is
    // balanced to the last workers.
    // NOTE(review): each string's first character comes from a per-index RNG,
    // so the filters below appear to pick items roughly evenly across the
    // index range; confirm this still produces the intended imbalance.
    std::function<void(Context&)> start_func =
        [](Context& ctx) {
            // generate random strings with 10..20 characters
            auto input_gen = Generate(
                ctx,
                [](size_t index) -> std::string {
                    std::default_random_engine rng(
                        123456 + static_cast<unsigned>(index));
                    std::uniform_int_distribution<size_t> length(10, 20);
                    rng(); // skip one number

                    return common::RandomString(
                        length(rng), rng, "abcdefghijklmnopqrstuvwxyz")
                    + std::to_string(index);
                },
                test_size);

            DIA<std::string> input = input_gen.Cache();

            std::vector<std::string> vinput = input.AllGather();
            ASSERT_EQ(test_size, vinput.size());

            // Keep only strings that start with a-e
            auto input1 = input.Filter(
                [](const std::string& str) { return str[0] <= 'e'; });

            // Keep only strings that start with w-z
            auto input2 = input.Filter(
                [](const std::string& str) { return str[0] >= 'w'; });

            // zip
            auto zip_result = input1.Zip(
                input2, [](const std::string& a, const std::string& b) {
                    return a + b;
                });

            // check result
            std::vector<std::string> res = zip_result.AllGather();

            // recalculate result locally
            std::vector<std::string> check;
            {
                std::vector<std::string> v1, v2;

                for (const std::string& s : vinput) {
                    if (s[0] <= 'e') v1.push_back(s);
                    if (s[0] >= 'w') v2.push_back(s);
                }

                ASSERT_EQ(v1, input1.AllGather());
                ASSERT_EQ(v2, input2.AllGather());

                // the expected zip result is truncated to the shorter input
                const size_t zip_size = std::min(v1.size(), v2.size());
                for (size_t i = 0; i < zip_size; ++i) {
                    check.push_back(v1[i] + v2[i]);
                }
            }

            // Compare sizes before the debug dump so that check[i] below is
            // always in range, even if the log level is raised.
            ASSERT_EQ(check.size(), res.size());

            for (size_t i = 0; i != res.size(); ++i) {
                sLOG0 << res[i] << " " << check[i] << (res[i] == check[i]);
            }

            ASSERT_EQ(check, res);
        };

    api::RunLocalTests(start_func);
}