Ejemplo n.º 1
0
void OutputSVG(const std::string& svg_path, double svg_scale,
               const DIA<Point<2> >& point_dia,
               const KMeansModel<Point<2> >& model) {
    double width = 0, height = 0;

    using Point2D = Point<2>;

    const std::vector<Point2D>& centroids = model.centroids();
    std::vector<PointClusterId<Point2D> > list =
        model.ClassifyPairs(point_dia).Gather();

    for (const PointClusterId<Point2D>& p : list) {
        width = std::max(width, p.first.x[0]);
        height = std::max(height, p.first.x[1]);
    }

    if (point_dia.context().my_rank() != 0) return;

    std::ofstream os(svg_path);

    os << "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n";
    os << "<svg\n";
    os << "   xmlns:dc=\"http://purl.org/dc/elements/1.1/\"\n";
    os << "   xmlns:cc=\"http://creativecommons.org/ns#\"\n";
    os << "   xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\"\n";
    os << "   xmlns:svg=\"http://www.w3.org/2000/svg\"\n";
    os << "   xmlns=\"http://www.w3.org/2000/svg\"\n";
    os << "   version=\"1.1\" id=\"svg2\" width=\"" << width * svg_scale
       << "\" height=\"" << height * svg_scale << "\">\n";
    os << "  <g id=\"layer1\">\n";

    for (const PointClusterId<Point2D>& p : list) {
        os << "    <circle r=\"1\" cx=\"" << p.first.x[0] * svg_scale
           << "\" cy=\"" << p.first.x[1] * svg_scale
           << "\" style=\"stroke:none;stroke-opacity:1;fill:"
           << SVGColor(p.second) << ";fill-opacity:1\" />\n";
    }
    for (size_t i = 0; i < centroids.size(); ++i) {
        const Point2D& p = centroids[i];
        os << "    <circle r=\"4\" cx=\"" << p.x[0] * svg_scale
           << "\" cy=\"" << p.x[1] * svg_scale
           << "\" style=\"stroke:black;stroke-opacity:1;fill:"
           << SVGColor(i) << ";fill-opacity:1\" />\n";
    }
    os << " </g>\n";
    os << "</svg>\n";
}
Ejemplo n.º 2
0
// Chooses two pivots that (with high probability) bracket the rank-th smallest
// element of `data`.
//
// A Bernoulli sample of `data` is gathered, sorted on worker 0, and the sample
// elements at positions rank * |sample| / size -/+ size^(0.25 + delta) (clamped
// to the sample) are picked. The pair is broadcast, so every worker returns
// the same pivots; presumably every worker must call this (it performs a
// Gather and a Broadcast).
//
// data:    distributed input
// size:    total number of elements in `data` (the caller passes data.Size())
// rank:    0-based sorted position the caller is looking for; must be < size
// compare: ordering used to sort the sample
//
// Returns (lower pivot, upper pivot), with !compare(upper, lower).
std::pair<ValueType, ValueType>
PickPivots(const DIA<ValueType, InStack>& data, size_t size, size_t rank,
           const Compare& compare = Compare()) {
    api::Context& ctx = data.context();

    const size_t num_workers(ctx.num_workers());
    const double size_d = static_cast<double>(size);

    // expected sample size is about 20 * sqrt(num_workers)
    const double p = 20 * sqrt(static_cast<double>(num_workers)) / size_d;

    // materialized at worker 0
    auto sample = data.BernoulliSample(p).Gather();

    std::pair<ValueType, ValueType> pivots;
    if (ctx.my_rank() == 0) {
        LOG << "got " << sample.size() << " samples (p = " << p << ")";

        // An empty sample would make `sample.size() - 1` below wrap around and
        // the pivot lookup index out of bounds. With ~20 * sqrt(num_workers)
        // expected samples this is extremely unlikely; fail loudly in debug.
        assert(!sample.empty());

        // Sort the samples
        std::sort(sample.begin(), sample.end(), compare);

        // Multiply in double: rank * sample.size() can overflow size_t for
        // very large inputs before it would be converted.
        const double base_pos =
            static_cast<double>(rank) * static_cast<double>(sample.size()) /
            size_d;
        const double offset = pow(size_d, 0.25 + delta);

        long lower_pos = static_cast<long>(floor(base_pos - offset));
        long upper_pos = static_cast<long>(floor(base_pos + offset));

        size_t lower = static_cast<size_t>(std::max(0L, lower_pos));
        size_t upper = static_cast<size_t>(
            std::min(upper_pos, static_cast<long>(sample.size() - 1)));

        // lower/upper are unsigned, so only the upper bounds need checking.
        // lower <= upper guarantees left pivot <= right pivot, which the
        // caller's three-way partition relies on.
        assert(lower < sample.size());
        assert(upper < sample.size());
        assert(lower <= upper);

        LOG << "Selected pivots at positions " << lower << " and " << upper
            << ": " << sample[lower] << " and " << sample[upper];

        pivots = std::make_pair(sample[lower], sample[upper]);
    }

    pivots = ctx.net.Broadcast(pivots);

    LOGM << "pivots: " << pivots.first << " and " << pivots.second;

    return pivots;
}
Ejemplo n.º 3
0
// Returns the element of `data` at 0-based sorted position `rank` under
// `compare` (an element equivalent to it), on every worker.
//
// Small inputs (size <= base_case_size) are gathered to worker 0, solved with
// std::nth_element and broadcast. Larger inputs are split by two sampled
// pivots L <= R into
//   left   : elem < L
//   equal  : elem equivalent to L
//   middle : L < elem <= R
//   right  : R < elem
// and the function recurses into the one part containing `rank`, or answers
// directly when `rank` hits the run equivalent to L or the last element <= R.
//
// Each recursive call receives strictly fewer elements, because both pivots
// are elements of `data` and are excluded from every part it recurses on. This
// also holds when there are many duplicate keys. (Previously, elements
// equivalent to L stayed in the middle part, so with duplicates the middle part
// could equal the whole input and recurse indefinitely.)
//
// Assumes `compare` is a strict weak ordering, as std::sort/nth_element
// require.
ValueType Select(const DIA<ValueType, InStack>& data, size_t rank,
                 const Compare& compare = Compare()) {
    api::Context& ctx = data.context();
    const size_t size = data.Size();

    // rank is unsigned, so only the upper bound can be violated
    assert(rank < size);

    if (size <= base_case_size) {
        // base case, gather all data at worker with rank 0
        ValueType result = ValueType();
        auto elements = data.Gather();

        if (ctx.my_rank() == 0) {
            assert(rank < elements.size());
            std::nth_element(elements.begin(), elements.begin() + rank,
                             elements.end(), compare);

            result = elements[rank];

            LOG << "base case: " << size << " elements remaining, result is "
                << result;
        }

        result = ctx.net.Broadcast(result);
        return result;
    }

    ValueType left_pivot, right_pivot;
    std::tie(left_pivot, right_pivot) = PickPivots(data, size, rank, compare);

    // One pass counts (elem < L, elem equivalent to L, L < elem <= R); the
    // number of elements above R follows by subtraction.
    using PartSizes = std::tuple<size_t, size_t, size_t>;
    const PartSizes part_sizes =
        data.Map(
            [&](const ValueType& elem) -> PartSizes {
                if (compare(elem, left_pivot))
                    return PartSizes { 1, 0, 0 };
                else if (!compare(left_pivot, elem))
                    return PartSizes { 0, 1, 0 };
                else if (!compare(right_pivot, elem))
                    return PartSizes { 0, 0, 1 };
                else
                    return PartSizes { 0, 0, 0 };
            })
        .Sum(
            [](const PartSizes& a, const PartSizes& b) -> PartSizes {
                return PartSizes {
                    std::get<0>(a) + std::get<0>(b),
                    std::get<1>(a) + std::get<1>(b),
                    std::get<2>(a) + std::get<2>(b)
                };
            },
            PartSizes { 0, 0, 0 });

    const size_t left_size = std::get<0>(part_sizes);
    // elements equivalent to the left pivot; >= 1 since L is taken from data
    const size_t equal_size = std::get<1>(part_sizes);
    // all elements with L <= elem <= R (the equal run plus L < elem <= R)
    const size_t middle_size = equal_size + std::get<2>(part_sizes);
    const size_t right_size = size - left_size - middle_size;

    LOGM << "left_size = " << left_size << ", equal_size = " << equal_size
         << ", middle_size = " << middle_size
         << ", right_size = " << right_size << ", rank = " << rank;

    if (rank < left_size) {
        // recurse on the left partition
        LOGM << "Recursing left, " << left_size
             << " elements remaining (rank = " << rank << ")\n";

        auto left = data.Filter(
            [&](const ValueType& elem) -> bool {
                return compare(elem, left_pivot);
            }).Collapse();
        assert(left.Size() == left_size);

        return Select(left, rank, compare);
    }
    else if (rank < left_size + equal_size) {
        // sorted positions [left_size, left_size + equal_size) all hold
        // elements equivalent to the left pivot
        LOGM << "result is left pivot: " << left_pivot;
        return left_pivot;
    }
    else if (rank == left_size + middle_size - 1) {
        // only the elements strictly greater than the right pivot are on the
        // right side, so the result is the right pivot in this case
        LOGM << "result is right pivot: " << right_pivot;
        return right_pivot;
    }
    else if (rank < left_size + middle_size) {
        // recurse on the elements strictly above the left pivot and not above
        // the right pivot; the run equivalent to the left pivot is excluded,
        // so this part is strictly smaller than `data`
        const size_t inner_size = middle_size - equal_size;
        const size_t inner_rank = rank - left_size - equal_size;

        LOGM << "Recursing middle, " << inner_size
             << " elements remaining (rank = " << inner_rank << ")\n";

        auto middle = data.Filter(
            [&](const ValueType& elem) -> bool {
                return compare(left_pivot, elem) &&
                !compare(right_pivot, elem);
            }).Collapse();
        assert(middle.Size() == inner_size);

        return Select(middle, inner_rank, compare);
    }
    else {
        // recurse on the right partition
        LOGM << "Recursing right, " << right_size
             << " elements remaining (rank = " << rank - left_size - middle_size
             << ")\n";

        auto right = data.Filter(
            [&](const ValueType& elem) -> bool {
                return compare(right_pivot, elem);
            }).Collapse();
        assert(right.Size() == right_size);

        return Select(right, rank - left_size - middle_size, compare);
    }
}