// Underlying implemention of parameter request void ParameterLoaderBase::ProcessRequest() { zmq::socket_t *socket = ZMQUtil::CreateSocket(); std::vector<Table*>& cache = Multiverso::double_buffer_->IOBuffer(); for (int i = 0; i < cache.size(); ++i) { cache[i]->Clear(); } int src_rank = Multiverso::ProcessRank(); int num_server = Multiverso::TotalServerCount(); std::vector<MsgPack*> send_list(num_server, nullptr); std::vector<int> send_ret_size(num_server, 0); int num_send_msg = 0; for (auto tuple : requests_) { integer_t table = tuple.table; integer_t row = tuple.row; integer_t col = tuple.col; if (row >= 0 && requests_.find({ table, -1, -1 }) != requests_.end() || col >= 0 && requests_.find({ table, row, -1 }) != requests_.end()) { continue; } int dst_rank, last_rank; if (row == -1) { dst_rank = 0; last_rank = num_server - 1; } else { dst_rank = (table + row) % num_server; last_rank = dst_rank; } while (dst_rank <= last_rank) { if (send_list[dst_rank] == nullptr) { send_list[dst_rank] = new MsgPack(MsgType::Get, MsgArrow::Worker2Server, src_rank, dst_rank); send_ret_size[dst_rank] = 0; } if (send_ret_size[dst_rank] + 3 * sizeof(integer_t) > kMaxMsgSize) { send_list[dst_rank]->Send(socket); ++num_send_msg; delete send_list[dst_rank]; send_list[dst_rank] = new MsgPack(MsgType::Get, MsgArrow::Worker2Server, src_rank, dst_rank); send_ret_size[dst_rank] = 0; } zmq::message_t* msg = new zmq::message_t(3 * sizeof(integer_t)); integer_t* buffer = static_cast<integer_t*>(msg->data()); buffer[0] = table; buffer[1] = row; buffer[2] = col; send_list[dst_rank]->Push(msg); send_ret_size[dst_rank] += 3 * sizeof(integer_t); ++dst_rank; } } for (int i = 0; i < num_server; ++i) { if (send_ret_size[i] > 0) { send_list[i]->Send(socket); ++num_send_msg; delete send_list[i]; } } // we expect each ReplyGet msg contains a over tag. while (num_send_msg > 0) { MsgPack reply(socket); for (int i = 1; i < reply.Size() - 1; ++i) { zmq::message_t* msg = reply.GetMsg(i); integer_t *buffer = static_cast<integer_t*>(msg->data()); integer_t table = buffer[0]; integer_t row = buffer[1]; cache[table]->GetRow(row)->BatchAdd(buffer + 2); } zmq::message_t* msg = reply.GetMsg(reply.Size() - 1); bool over = (static_cast<integer_t*>(msg->data())[0] == 1); if (over) { --num_send_msg; } } delete socket; }
void Aggregator::Send(int id, zmq::socket_t* socket) { int src_rank = Multiverso::ProcessRank(); int num_server = Multiverso::TotalServerCount(); std::vector<MsgPack*> send_list(num_server, nullptr); std::vector<int> send_ret_size(num_server, 0); for (int table_id = 0; table_id < tables_.size(); ++table_id) { Table* table = tables_[table_id]; TableIterator iter(*table); for (; iter.HasNext(); iter.Next()) { integer_t row_id = iter.RowId(); if (row_id % num_threads_ != id) { continue; } RowBase* row = iter.Row(); int dst_rank = (table_id + row_id) % num_server; if (send_list[dst_rank] == nullptr) { send_list[dst_rank] = new MsgPack(MsgType::Add, MsgArrow::Worker2Server, src_rank, dst_rank); send_ret_size[dst_rank] = 0; } // Format: table_id, row_id, number // col_1, col2, ..., col_n, val_1, val_2, ..., val_n; int msg_size = sizeof(integer_t)* 3 + row->NonzeroSize() * (table->ElementSize() + sizeof(integer_t)); if (msg_size > kMaxMsgSize) { // TODO(feiga): we currently assume the row serialized size // not ecceed kMaxMsgSize. should solve the issue later. Log::Error("Row size exceed the max size of message\n"); } if (send_ret_size[dst_rank] + msg_size > kMaxMsgSize) { send_list[dst_rank]->Send(socket); delete send_list[dst_rank]; send_list[dst_rank] = new MsgPack(MsgType::Add, MsgArrow::Worker2Server, src_rank, dst_rank); send_ret_size[dst_rank] = 0; } zmq::message_t* msg = new zmq::message_t(msg_size); integer_t* buffer = static_cast<integer_t*>(msg->data()); buffer[0] = table_id; buffer[1] = row_id; row->Serialize(buffer + 2); send_list[dst_rank]->Push(msg); send_ret_size[dst_rank] += msg_size; } } for (int i = 0; i < num_server; ++i) { if (send_ret_size[i] > 0) { send_list[i]->Send(socket); delete send_list[i]; } } }