// Parses the lines of "inputs/test1" as integers, sends them through an
// identity Map and a Collapse, and then checks that the collapsed DIA sums to
// 136 and holds 16 items (presumably the fixture contains 1..16 -- confirm
// against the input file).
TEST(SumNode, GenerateAndSumHaveEqualAmount2) {
    std::function<void(Context&)> job =
        [](Context& ctx) {
            // TODO(ms): Replace this with some test-specific rendered file
            auto parsed =
                ReadLines(ctx, "inputs/test1")
                .Map([](const std::string& text) { return std::stoi(text); });

            // identity mapping: values pass through unchanged
            auto passthrough = parsed.Map([](int value) { return value; });

            // binary reducer handed to Sum()
            auto plus = [](int lhs, int rhs) { return lhs + rhs; };

            DIA<int> collapsed = passthrough.Collapse();

            ASSERT_EQ(136, collapsed.Sum(plus));
            ASSERT_EQ(16u, collapsed.Size());
        };

    api::RunLocalTests(job);
}
// Re-assigns a DIA inside a loop: every round duplicates each item, multiplies
// it by two and caches the result back into the same DIA variable. Afterwards
// the item count is verified and the stats graph layout is written to
// "loop.out".
TEST(Graph, WhileLoop) {
    std::function<void(Context&)> job =
        [](Context& ctx) {
            // emits every incoming item twice
            auto emit_twice = [](size_t item, auto emit) {
                                  emit(item);
                                  emit(item);
                              };
            // doubles the value of an item
            auto times_two = [](size_t item) { return 2 * item; };

            auto indices = Generate(
                ctx,
                [](const size_t& idx) -> size_t { return idx; },
                16);

            DIA<size_t> current = indices.Collapse();
            size_t item_count = 0;

            // Each round doubles the number of items: 16 -> 32 -> 64. The
            // counter starts at zero, so the body always runs at least once
            // and the loop stops as soon as 64 items are reached.
            do {
                auto duplicated = current.FlatMap(emit_twice);
                auto scaled = duplicated.Map(times_two);
                current = scaled.Cache();
                item_count = current.Size();
            } while (item_count < 64);

            std::vector<size_t> gathered = current.AllGather();

            ASSERT_EQ(64u, gathered.size());
            ASSERT_EQ(64u, current.Size());

            ctx.stats_graph().BuildLayout("loop.out");
        };

    api::RunLocalTests(job);
}
ValueType Select(const DIA<ValueType, InStack>& data, size_t rank, const Compare& compare = Compare()) { api::Context& ctx = data.context(); const size_t size = data.Size(); assert(0 <= rank && rank < size); if (size <= base_case_size) { // base case, gather all data at worker with rank 0 ValueType result = ValueType(); auto elements = data.Gather(); if (ctx.my_rank() == 0) { assert(rank < elements.size()); std::nth_element(elements.begin(), elements.begin() + rank, elements.end(), compare); result = elements[rank]; LOG << "base case: " << size << " elements remaining, result is " << result; } result = ctx.net.Broadcast(result); return result; } ValueType left_pivot, right_pivot; std::tie(left_pivot, right_pivot) = PickPivots(data, size, rank, compare); size_t left_size, middle_size, right_size; using PartSizes = std::pair<size_t, size_t>; std::tie(left_size, middle_size) = data.Map( [&](const ValueType& elem) -> PartSizes { if (compare(elem, left_pivot)) return PartSizes { 1, 0 }; else if (!compare(right_pivot, elem)) return PartSizes { 0, 1 }; else return PartSizes { 0, 0 }; }) .Sum( [](const PartSizes& a, const PartSizes& b) -> PartSizes { return PartSizes { a.first + b.first, a.second + b.second }; }, PartSizes { 0, 0 }); right_size = size - left_size - middle_size; LOGM << "left_size = " << left_size << ", middle_size = " << middle_size << ", right_size = " << right_size << ", rank = " << rank; if (rank == left_size) { // all the elements strictly smaller than the left pivot are on the left // side -> left_size-th element is the left pivot LOGM << "result is left pivot: " << left_pivot; return left_pivot; } else if (rank == left_size + middle_size - 1) { // only the elements strictly greater than the right pivot are on the // right side, so the result is the right pivot in this case LOGM << "result is right pivot: " << right_pivot; return right_pivot; } else if (rank < left_size) { // recurse on the left partition LOGM << "Recursing left, " << left_size << " 
elements remaining (rank = " << rank << ")\n"; auto left = data.Filter( [&](const ValueType& elem) -> bool { return compare(elem, left_pivot); }).Collapse(); assert(left.Size() == left_size); return Select(left, rank, compare); } else if (left_size + middle_size <= rank) { // recurse on the right partition LOGM << "Recursing right, " << right_size << " elements remaining (rank = " << rank - left_size - middle_size << ")\n"; auto right = data.Filter( [&](const ValueType& elem) -> bool { return compare(right_pivot, elem); }).Collapse(); assert(right.Size() == right_size); return Select(right, rank - left_size - middle_size, compare); } else { // recurse on the middle partition LOGM << "Recursing middle, " << middle_size << " elements remaining (rank = " << rank - left_size << ")\n"; auto middle = data.Filter( [&](const ValueType& elem) -> bool { return !compare(elem, left_pivot) && !compare(right_pivot, elem); }).Collapse(); assert(middle.Size() == middle_size); return Select(middle, rank - left_size, compare); } }