VariableSpace& VariableProxy::operator=(const VariableSpace& other) {
            // Assignment is a stub: no state is taken from `other`. Assigning from a
            // different object only logs that the operation is missing; self-assignment
            // is silent. In both cases the proxy itself is returned unchanged.
            const bool selfAssignment = (this == &other);

            if (!selfAssignment) {
                nd4j_printf("VariableProxy = not implemented\n","");
            }

            return *this;
        }  
// Example no. 2
// 0
        // Looks up the variable identified by the pair `p` in this context's variable space.
        // Returns whatever the variable space holds for `p`; if the space reports that no
        // such variable exists, the request is logged and the string literal "Bad variable"
        // is thrown (as a const char*).
        Variable<T>* Context<T>::variable(std::pair<int,int>& p) {
            if (_variableSpace->hasVariable(p)) {
                return _variableSpace->getVariable(p);
            }

            // unknown variable: report which node asked for which [id:index] pair
            nd4j_printf("Node %i; Non-existent variable requested: [%i:%i]\n", this->_nodeId, p.first, p.second);
            throw "Bad variable";
        }
// Example no. 3
// 0
        /**
         * Returns the variable bound to the idx-th entry of this context's inputs.
         *
         * When the environment is in debug+verbose mode and the variable has an
         * NDArray attached, the array's shape and mean value are logged.
         *
         * @param idx zero-based position in the inputs list
         * @return the variable resolved through variable() for that input
         * @throws const char* "Bad index" when idx is negative or not smaller than
         *         the number of inputs; variable() may additionally throw when the
         *         referenced variable does not exist
         */
        Variable<T>* Context<T>::getVariable(int idx) {
            // Keep the count as int: it is compared against a signed index and is
            // printed with %i below, so it must not be passed as an unsigned size.
            const int numInputs = static_cast<int>(this->_inputs.size());

            if (idx < 0 || idx >= numInputs) {
                nd4j_printf("Node %i; Variable [%i] requested, but only %i inputs available\n", this->_nodeId, idx, numInputs);
                throw "Bad index";
            }

            // variable() takes a non-const reference, hence the named local copy
            auto p = this->_inputs[idx];

            auto v = variable(p);

            if (Environment::getInstance()->isDebugAndVerbose() && v != nullptr &&  v->getNDArray() != nullptr) {
                std::string shape_ = ShapeUtils<T>::shapeAsString(*(v->getNDArray()));
                nd4j_printf("Debug info for node_%i input[%i]; shape: %s; mean value: [%f]\n", this->_nodeId, idx, shape_.c_str(), (float) v->getNDArray()->meanNumber());
            }

            return v;
        }
        // Returns the graph stored under `graphId`.
        // If hasGraph() reports that nothing is stored for that id, the id is logged
        // and std::runtime_error("Bad argument") is thrown.
        Graph* GraphHolder::pullGraph(Nd4jLong graphId) {
            if (this->hasGraph(graphId)) {
                return _graphF[graphId];
            }

            nd4j_printf("GraphHolder doesn't have graph stored for [%lld]\n", graphId);
            throw std::runtime_error("Bad argument");
        }
        // Resolves a variable by symbolic name.
        // _current is consulted first and _backed second; the first one that reports
        // having the symbol supplies the result. If neither has it, the symbol is logged
        // and std::runtime_error("Bad arguments") is thrown.
        nd4j::graph::Variable *VariableProxy::getVariable(std::string *symbol) {
            if (_current->hasVariable(symbol)) {
                return _current->getVariable(symbol);
            } else if (_backed->hasVariable(symbol)) {
                return _backed->getVariable(symbol);
            }

            nd4j_printf("Unable to get Variable to proxy: [%s]\n", symbol->c_str());
            throw std::runtime_error("Bad arguments");
        }
        // Resolves a variable by its [id:index] pair.
        // _current is consulted first; _backed is only queried when _current does not
        // have the pair. If neither has it, the pair is logged and
        // std::runtime_error("Bad arguments") is thrown.
        nd4j::graph::Variable *VariableProxy::getVariable(std::pair<int,int>& pair) {
            const bool inCurrent = _current->hasVariable(pair);
            if (inCurrent) {
                return _current->getVariable(pair);
            }

            const bool inBacked = _backed->hasVariable(pair);
            if (!inBacked) {
                nd4j_printf("Unable to get Variable to proxy: [%i:%i]\n", pair.first, pair.second);
                throw std::runtime_error("Bad arguments");
            }

            return _backed->getVariable(pair);
        }
        // Resolves a variable by node id and output index.
        // Lookup order is _current, then _backed. If neither reports having the
        // variable, the request is logged and std::runtime_error("Bad arguments") is thrown.
        nd4j::graph::Variable *VariableProxy::getVariable(int id, int idx) {
            if (!_current->hasVariable(id, idx)) {
                if (!_backed->hasVariable(id, idx)) {
                    nd4j_printf("Unable to get Variable to proxy: [%i:%i]\n", id, idx);
                    throw std::runtime_error("Bad arguments");
                }

                return _backed->getVariable(id, idx);
            }

            return _current->getVariable(id, idx);
        }
/**
 * Copies each sub-array of `input` taken along its last dimension onto the main
 * diagonal of the corresponding sub-array of `output` taken along its last two
 * dimensions. Only the (e, e) positions of `output` are written here.
 *
 * @param input  source array; its last dimension supplies the diagonal values
 * @param output destination array, modified in place
 * @return Status::OK() on success, ND4J_STATUS_VALIDATION when the number of
 *         sub-arrays of `output` and `input` differ
 */
static int _matrixDiag(const NDArray* input, NDArray* output) {

    auto listOut  = output->allTensorsAlongDimension({output->rankOf() - 2, output->rankOf() - 1});
    auto listDiag = input->allTensorsAlongDimension({input->rankOf() - 1});

    if (listOut->size() != listDiag->size()) {
        nd4j_printf("matrix_diag: Input matrix has wrong shape.", "");
        // Both lists are released with delete on the success path below, so they
        // must be released on this early exit as well, otherwise they leak.
        delete listOut;
        delete listDiag;
        return ND4J_STATUS_VALIDATION;
    }
    int lastDimension = input->sizeAt(-1);
    // TODO: tune this properly
    // condition is hold: listOut->size() == listDiag->size()
#pragma omp parallel for if(listOut->size() > Environment::getInstance()->elementwiseThreshold()) schedule(static)
    for(int i = 0; i < listOut->size(); ++i) {
        for (int e = 0; e < lastDimension; e++) {
            listOut->at(i)->p(e, e, listDiag->at(i)->e<T>(e));
        }
    }

    delete listOut;
    delete listDiag;

    return Status::OK();
}
// Example no. 9
// 0
        /**
         * Builds a graph node from its FlatBuffers description.
         *
         * Copies id, op number/type, name, scope, scalar, inputs, extra params and
         * dimensions from `node`. For the legacy op types listed below, and for
         * OpType_CUSTOM, it also creates a ContextPrototype, fills its arguments and
         * attaches the matching operation to the node.
         *
         * @param node serialized node; may be nullptr, in which case only the
         *             defaults assigned at the top of the constructor are set
         * @throws std::runtime_error when a LOGIC node with opNum 100 (Enter) has no
         *         FrameID in extraInteger, or when a CUSTOM operation cannot be
         *         found in the OpRegistrator
         */
        nd4j::graph::Node::Node(const nd4j::graph::FlatNode *node) {
            _hasExternalInputs = false;
            _hasExternalOutputs = false;
            _hasInternalInputs = false;
            _hasInternalOutputs = false;
            _extraParams = nullptr;
            _dim = nullptr;
            _dataType = nd4j::DataType::FLOAT32; // float as default

            // empty dynamic node, tests probably: there is nothing to read, and the
            // flat node must not be dereferenced
            if (node == nullptr)
                return;

            if (node->scope_id() != 0)
                this->_scope_id = node->scope_id();

            if (node->scope_name() != nullptr && node->scope_name()->size() > 0)
                this->_scope_name = node->scope_name()->str();

            if (node->scalar() != nullptr) {
                // fromFlatArray returns a heap object: copy it into the member, then release it
                auto scalar = nd4j::graph::FlatUtils::fromFlatArray(node->scalar());
                _scalar = *scalar;
                delete scalar;
            }

            this->_id = node->id();
            //this->_dataType = DataTypeUtils::fromFlatDataType(node->dataType());
            this->_opNum = node->opNum();
            this->_opType = node->opType();

            if (node->name() != nullptr && node->name()->c_str() != nullptr) {
                this->_name = node->name()->str();
            }

            // paired inputs take precedence over plain input ids
            if (node->inputPaired() != nullptr && node->inputPaired()->size() > 0) {
                for (int e = 0; e < (int) node->inputPaired()->size(); e++) {
                    auto pair = node->inputPaired()->Get(e);
                    pickInput(pair->first(), pair->second());
                }
            } else if (node->input() != nullptr && node->input()->size() > 0) {
                for (int e = 0; e < (int) node->input()->size(); e++)
                    pickInput(node->input()->Get(e));
            } else {
                if (this->opType() != OpType_LOGIC) {
                    if (this->_name.size() > 0) {
                        nd4j_debug("Node [%i:<%s>] has no inputs defined\n", this->_id, this->_name.c_str());
                    } else {
                        nd4j_debug("Node [%i:<noname>] has no inputs defined\n", this->_id);
                    }
                }
            }

            /*
            if (node->output() != nullptr)
                for (int e = 0; e < (int) node->output()->size(); e++) {
                    auto oid = node->output()->Get(e);
                    if (oid != this->_id && oid != 0) {
                        nd4j_verbose("Picking output: %i\n", node->output()->Get(e));
                        pickOutput(oid);
                    }
                }
            */


            if (node->extraParams() != nullptr && node->extraParams()->size() > 0) {
                _extraParams = new double[node->extraParams()->size()];
                for (int e = 0; e < (int) node->extraParams()->size(); e++) {
                    _extraParams[e] = static_cast<double>(node->extraParams()->Get(e));
                }
            }

            if (node->dimensions() != nullptr && node->dimensions()->size() > 0) {
                // dimensions are kept twice: in the _dimensions container and in the raw _dim array
                _dim = new int[node->dimensions()->size()];
                for (int e = 0; e < (int) node->dimensions()->size(); e++) {
                    _dimensions.emplace_back(node->dimensions()->Get(e));
                    _dim[e] = node->dimensions()->Get(e);
                }
            }

            if (this->opType() == OpType_LOGIC && this->opNum() == 100L) {
                // extraInteger may be absent altogether; treat that the same as an empty vector
                if (node->extraInteger() == nullptr || node->extraInteger()->size() < 1) {
                    nd4j_printf("Node_%i is type of Enter, but has no FrameID defined\n", this->id());
                    throw std::runtime_error("Enter node must have FrameID specified");
                }

                this->setFrameId(node->extraInteger()->Get(0));
            }


            // these ops allow in-place execution by design
            if (_opType == OpType_BROADCAST ||
                _opType == OpType_BROADCAST_BOOL ||
                _opType == OpType_INDEX_REDUCE ||
                _opType == OpType_SUMMARYSTATS ||
                _opType == OpType_REDUCE_BOOL ||
                _opType == OpType_REDUCE_SAME ||
                _opType == OpType_REDUCE_FLOAT ||
                _opType == OpType_REDUCE_3 ||
                _opType == OpType_TRANSFORM_STRICT ||
                _opType == OpType_TRANSFORM_SAME ||
                _opType == OpType_TRANSFORM_FLOAT ||
                _opType == OpType_TRANSFORM_BOOL ||
                _opType == OpType_RANDOM ||
                _opType == OpType_PAIRWISE ||
                _opType == OpType_PAIRWISE_BOOL ||
                _opType == OpType_SCALAR_BOOL ||
                _opType == OpType_SCALAR) {

                if (_output.size() <= 1) {
                    _isInplace = true;
                }

                if (node->input() != nullptr && node->input()->size() > 0) {
                    this->_isDeductable = true;

                    auto block = new ContextPrototype(nullptr, this->id(), false);


                    for (auto v: _dimensions)
                        block->getAxis()->emplace_back(v);

                    if (node->extraParams() != nullptr && node->extraParams()->size() > 0)
                        for (int e = 0; e < (int) node->extraParams()->size(); e++) {
                            block->getTArguments()->emplace_back(static_cast<double>(node->extraParams()->Get(e)));
                        }

                    if (node->extraBools() != nullptr && node->extraBools()->size() > 0)
                        for (int e = 0; e < (int) node->extraBools()->size(); e++) {
                            block->getBArguments()->push_back(node->extraBools()->Get(e));
                        }

                    if (node->extraInteger() != nullptr && node->extraInteger()->size() > 0)
                        for (int e = 0; e < (int) node->extraInteger()->size(); e++) {
                            block->getIArguments()->emplace_back(node->extraInteger()->Get(e));
                        }

                    this->setContextPrototype(block);
                    this->setCustomOp(Node::buildOpByType(_opType, (int) node->input()->size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), (int) _opNum, &_scalar));
                    block->setOpDescriptor(this->getCustomOp()->getOpDescriptor());
                } else if (node->inputPaired() != nullptr && node->inputPaired()->size() > 0) {
                    this->_isDeductable = true;

                    auto block = new ContextPrototype(nullptr, this->id(), false);

                    for (int e = 0; e < this->input()->size(); e++) {
                        block->inputs()->emplace_back(this->input()->at(e));
                    }

                    // there's no other IArgs in legacy options, actually
                    for (auto v: _dimensions)
                        block->getAxis()->emplace_back(v);

                    if (node->extraParams() != nullptr && node->extraParams()->size() > 0)
                        for (int e = 0; e < (int) node->extraParams()->size(); e++) {
                            block->getTArguments()->emplace_back(static_cast<double>(node->extraParams()->Get(e)));
                        }

                    if (node->extraBools() != nullptr && node->extraBools()->size() > 0)
                        for (int e = 0; e < (int) node->extraBools()->size(); e++) {
                            block->getBArguments()->push_back(node->extraBools()->Get(e));
                        }

                    if (node->extraInteger() != nullptr && node->extraInteger()->size() > 0)
                        for (int e = 0; e < (int) node->extraInteger()->size(); e++) {
                            block->getIArguments()->emplace_back(node->extraInteger()->Get(e));
                        }

                    this->setContextPrototype(block);

                    this->setCustomOp(Node::buildOpByType(_opType, (int) node->inputPaired()->size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), (int) _opNum, &_scalar));
                    block->setOpDescriptor(this->getCustomOp()->getOpDescriptor());
                }
            } else if (this->_opType == OpType_CUSTOM) {
                // custom ops are resolved through the registrator by op number
                auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(this->opNum());
                if (op == nullptr) {
                    nd4j_verbose("Can't find operation: %lld\n", this->opNum());
                    throw std::runtime_error("Can't find requested operation");
                }

                auto block = new ContextPrototype(nullptr, this->id());

                for (int e = 0; e < this->input()->size(); e++) {
                    block->inputs()->emplace_back(this->input()->at(e));
                }

                if (node->extraInteger() != nullptr)
                    for (uint32_t e = 0; e < node->extraInteger()->size(); e++) {
                        auto v = node->extraInteger()->Get(e);
                        // FIXME: remove this static_cast, iArgs should be Nd4jLong
                        block->getIArguments()->emplace_back(static_cast<int>(v));
                    }

                if (node->extraParams() != nullptr)
                    for (uint32_t e = 0; e < node->extraParams()->size(); e++)
                        block->getTArguments()->emplace_back(static_cast<double>(node->extraParams()->Get(e)));

                if (node->extraBools() != nullptr && node->extraBools()->size() > 0)
                    for (int e = 0; e < (int) node->extraBools()->size(); e++) {
                        block->getBArguments()->push_back(node->extraBools()->Get(e));
                    }

                for (auto v: _dimensions)
                    block->getAxis()->emplace_back(v);

                this->setContextPrototype(block);
                this->setCustomOp(op);
                block->setOpDescriptor(this->getCustomOp()->getOpDescriptor());
            }
        }