Exemplo n.º 1
0
 /// Registers callback `f` for `target` under a freshly allocated entry ID.
 ///
 /// The current value of `next_entry_id` becomes this entry's ID and the
 /// counter is advanced with `succ()`. The callback is stored in `entries`
 /// under that ID, and the pair (rankOf(target), ID) is added to
 /// `prioritizedQ`.
 ///
 /// @param target node whose rank (via `rankOf`) keys the queue entry
 /// @param f      callback stored with the entry; it is moved into `entries`
 void transaction_impl::prioritized(SODIUM_SHARED_PTR<node> target,
                                    std::function<void(transaction_impl*)> f)
 {
     typedef pair<entryID, prioritized_entry> entry_record;
     typedef pair<rank_t, entryID> queue_slot;

     // Claim the current ID for this entry, then advance the counter.
     entryID claimed = next_entry_id;
     next_entry_id = next_entry_id.succ();

     // Store the callback first; the queue slot refers to it by ID only.
     entries.insert(entry_record(claimed, prioritized_entry(target, std::move(f))));
     prioritizedQ.insert(queue_slot(rankOf(target), claimed));
 }
Exemplo n.º 2
0
 void transaction_impl::check_regen() {
     if (to_regen) {
         to_regen = false;
         prioritizedQ.clear();
         for (std::map<entryID, prioritized_entry>::iterator it = entries.begin(); it != entries.end(); ++it)
             prioritizedQ.insert(pair<rank_t, entryID>(rankOf(it->second.target), it->first));
     }
 }
        /**
         * Runs a legacy boolean reduction on the first input of `block`.
         *
         * The op number comes from the block when it is non-negative, otherwise
         * from this op's `_opNum`. Two calling conventions are handled:
         *
         *  - `block.width() == 1`: reduction axes are read from `block.getAxis()`.
         *  - otherwise: reduction axes are read from the second input variable.
         *
         * In both cases the reduction is executed either as a scalar reduction
         * (`execReduceBoolScalar`) or along the given dimensions using a TAD
         * built from the input shape (`execReduceBool`).
         *
         * @param block execution context supplying inputs, outputs and arguments
         * @return Status::OK() once the reduction has been dispatched
         */
        Nd4jStatus LegacyReduceBoolOp::validateAndExecute(Context &block) {
            auto x = INPUT_VARIABLE(0);


            int opNum = block.opNum() < 0 ? this->_opNum : block.opNum();
            // Log under this op's own name (the message previously named the Float variant).
            nd4j_debug("Executing LegacyReduceBoolOp: [%i]\n", opNum);

            auto axis = *block.getAxis();

            bool allAxes = false;

            if (block.width() == 1) {
                // Axes are supplied through the block itself.
                auto z = OUTPUT_VARIABLE(0);

                if (axis.size() == x->rankOf())
                    allAxes = true;

                // No axes, a single MAX_INT axis, or as many axes as the input
                // has dimensions: take the scalar path.
                if ((axis.empty()) ||
                    (axis.size() == 1 && axis[0] == MAX_INT) || allAxes) {
                    // scalar
                    NativeOpExcutioner::execReduceBoolScalar(opNum, x->getBuffer(), x->getShapeInfo(), block.getTArguments()->data(), z->buffer(), z->shapeInfo());
                } else {
                    // TAD
                    std::vector<int> dims(axis);

                    // Map negative axes onto their non-negative equivalents.
                    for (auto &dim : dims)
                        if (dim < 0)
                            dim += x->rankOf();

                    std::sort(dims.begin(), dims.end());

                    REQUIRE_TRUE(dims.size() > 0, 0, "Some dimensions required for reduction!");

                    shape::TAD tad;
                    tad.init(x->getShapeInfo(), dims.data(), dims.size());
                    tad.createTadOnlyShapeInfo();
                    tad.createOffsets();

                    NativeOpExcutioner::execReduceBool(opNum, x->getBuffer(), x->getShapeInfo(), block.getTArguments()->data(), z->getBuffer(), z->getShapeInfo(), dims.data(), (int) dims.size(), tad.tadOnlyShapeInfo, tad.tadOffsets);
                }

                STORE_RESULT(*z);
            } else {
                // Axes are supplied as the second input variable.
                auto indices = INPUT_VARIABLE(1);
                if (indices->lengthOf() == x->rankOf())
                    allAxes = true;

                //indices->printIndexedBuffer("indices");

                std::vector<int> axis(indices->lengthOf());
                for (int e = 0; e < indices->lengthOf(); e++) {
                    // lol otherwise we segfault on macOS
                    // (keep the element in a plain local before storing it)
                    int f = indices->e<int>(e);
                    // Map negative axes onto their non-negative equivalents.
                    if (f < 0)
                        f += x->rankOf();
                    axis[e] = f;
                }

                // Scalar path: a single MAX_INT integer argument, or as many
                // indices as the input has dimensions.
                if ((block.getIArguments()->size() == 1 && INT_ARG(0) == MAX_INT) || allAxes) {
                    auto z = OUTPUT_VARIABLE(0);

                    auto b = x->getBuffer();
                    auto s = x->shapeInfo();
                    // Pass nullptr rather than the data pointer when no T arguments exist.
                    auto e = block.numT() > 0 ? block.getTArguments()->data() : nullptr;

                    //x->printIndexedBuffer("x");

                    // scalar
                    NativeOpExcutioner::execReduceBoolScalar(opNum, b, s, e, z->buffer(), z->shapeInfo());
                } else {
                    // TAD
                    if (indices->lengthOf() > 1)
                        std::sort(axis.begin(), axis.end());

                    REQUIRE_TRUE(axis.size() > 0, 0, "Some dimensions required for reduction!");

                    shape::TAD tad;
                    tad.init(x->getShapeInfo(), axis.data(), axis.size());
                    tad.createTadOnlyShapeInfo();
                    tad.createOffsets();

                    auto z = OUTPUT_VARIABLE(0);

                    NativeOpExcutioner::execReduceBool(opNum, x->getBuffer(), x->getShapeInfo(), block.getTArguments()->data(), z->getBuffer(), z->getShapeInfo(), axis.data(), (int) axis.size(), tad.tadOnlyShapeInfo, tad.tadOffsets);
                }
            }

            return Status::OK();
        }