Exemplo n.º 1
0
    // Runs the request benchmark and prints one RESULT line on rank 0.
    //
    // Only the first local worker on each host (local_worker_id() == 0)
    // builds a dispatcher from the host's net group and performs
    // outer_repeats_ rounds. Each round reseeds rnd_ with the same fixed
    // seed, keeps calling MaybeStartRequest() until limit_active_ requests
    // are active or remaining_requests_ reaches zero, and then runs
    // dispatcher_->Loop().
    //
    // NOTE(review): the start loop only terminates if MaybeStartRequest()
    // eventually returns true often enough or decrements remaining_requests_;
    // presumably it does the latter — confirm against its definition.
    //
    // Every worker reaches the ctx.net.AllReduce call below, which takes the
    // maximum measured time. Workers that skipped the benchmark contribute
    // the value of a never-started timer (presumably 0).
    void Test(api::Context& ctx) {

        common::StatsTimerStopped t;

        // only work with first thread on this host.
        if (ctx.local_worker_id() == 0)
        {
            mem::Manager mem_manager(nullptr, "Dispatcher");

            group_ = &ctx.net.group();
            std::unique_ptr<net::Dispatcher> dispatcher =
                group_->ConstructDispatcher();
            // non-owning alias used by the rest of the benchmark; the
            // unique_ptr above remains the owner.
            dispatcher_ = dispatcher.get();

            t.Start();

            for (size_t outer = 0; outer < outer_repeats_; ++outer)
            {
                // identical random sequence in every outer round
                rnd_ = std::default_random_engine(123456);

                active_ = 0;
                remaining_requests_ = num_requests_;

                // prime up to limit_active_ concurrent requests
                while (active_ < limit_active_ && remaining_requests_ > 0)
                {
                    if (MaybeStartRequest()) {
                        ++active_;
                    }
                }

                dispatcher_->Loop();
            }

            t.Stop();

            // must clean up dispatcher prior to using group for other things.
            // Clear the non-owning dispatcher_ member first so it never points
            // at a destroyed object, then destroy the dispatcher explicitly
            // rather than relying on the end of this scope.
            dispatcher_ = nullptr;
            dispatcher.reset();
        }

        size_t time = t.Microseconds();
        // calculate maximum time.
        time = ctx.net.AllReduce(time, common::maximum<size_t>());

        // NOTE(review): group_ is only assigned on local worker 0; rank 0 is
        // presumably local worker 0 of its host — confirm.
        if (ctx.my_rank() == 0) {
            std::cout
                << "RESULT"
                << " operation=" << "rblocks"
                << " hosts=" << group_->num_hosts()
                << " requests=" << num_requests_
                << " block_size=" << block_size_
                << " limit_active=" << limit_active_
                << " time[us]=" << time
                << " time_per_op[us]="
                << static_cast<double>(time) / num_requests_
                << " total_bytes=" << block_size_ * num_requests_
                << " total_bandwidth[MiB/s]="
                << CalcMiBs(block_size_ * num_requests_, time)
                << std::endl;
        }
    }
Exemplo n.º 2
0
// Measures point-to-point transfer time between pairs of hosts.
//
// Only the first local worker on each host runs the benchmark; every other
// worker returns immediately. Each of the outer_repeats_ timed repeats runs
// inner_repeats_ sweeps over the group's 1-factor schedule. In every round,
// the host with the lower rank of a pair calls Sender() then Receiver(),
// and the higher-ranked host does the reverse. A host whose peer is itself
// in that round does not participate and just advances counter_ by
// 2 * block_count_ (presumably to stay in step with hosts that did exchange
// blocks).
//
// After each repeat, the maximum time over all hosts is obtained with
// group.AllReduce, and rank 0 prints a RESULT line. Finally, bandwidth_ is
// reduced onto the root and printed by rank 0.
//
// NOTE(review): bandwidth_ and counter_ are presumably also updated by
// Sender()/Receiver(), which are not visible here — confirm.
void Bandwidth::Test(api::Context& ctx) {

    // only work with first thread on this host.
    if (ctx.local_worker_id() != 0) return;

    net::Group& group = ctx.net.group();

    bandwidth_ = AggMatrix(group.num_hosts());

    // data block to send or receive (block_size_ must be nonzero here)
    block_count_ = data_size_ / block_size_;
    data_block_.resize(block_size_ / sizeof(size_t), 42u);

    for (size_t outer_repeat = 0;
         outer_repeat < outer_repeats_; ++outer_repeat) {

        // counter_ is never reset in this method, so it keeps growing across
        // outer repeats. Remember where this repeat starts, so the per-repeat
        // time is normalised by this repeat's work only and not by the
        // running total.
        const size_t counter_begin = counter_;

        common::StatsTimerStopped timer;

        timer.Start();
        for (size_t inner_repeat = 0;
             inner_repeat < inner_repeats_; inner_repeat++) {
            // perform 1-factor ping pongs (without barriers)
            for (size_t round = 0; round < group.OneFactorSize(); ++round) {

                size_t peer = group.OneFactorPeer(round);

                sLOG0 << "round" << round
                      << "me" << ctx.host_rank() << "peer_id" << peer;

                if (ctx.host_rank() < peer) {
                    Sender(ctx, peer, inner_repeat);
                    Receiver(ctx, peer);
                }
                else if (ctx.host_rank() > peer) {
                    Receiver(ctx, peer);
                    Sender(ctx, peer, inner_repeat);
                }
                else {
                    // not participating in this round
                    counter_ += 2 * block_count_;
                }
            }
        }
        timer.Stop();

        size_t time = timer.Microseconds();
        // calculate maximum time (AllReduce updates `time` in place).
        group.AllReduce(time, common::maximum<size_t>());

        // amount counter_ advanced during this outer repeat only
        const size_t repeat_counter = counter_ - counter_begin;

        if (ctx.my_rank() == 0) {
            std::cout
                << "RESULT"
                << " benchmark=" << benchmark
                << " hosts=" << ctx.num_hosts()
                << " outer_repeat=" << outer_repeat
                << " inner_repeats=" << inner_repeats_
                << " time[us]=" << time
                << " time_per_ping_pong[us]="
                << static_cast<double>(time)
                / static_cast<double>(repeat_counter)
                << std::endl;
        }
    }

    // reduce (add) matrix to root.
    group.Reduce(bandwidth_);

    // print matrix
    if (ctx.my_rank() == 0)
        PrintMatrix(bandwidth_);
}