Пример #1
0
 // Builds a binary node over two operand nodes.
 //
 //   out_type_str, name_str  forwarded unchanged to the Node base constructor
 //   op_str                  kept in m_op_str
 //   lhs, rhs                operand nodes, kept in m_lhs / m_rhs; both are
 //                           dereferenced here, so neither may be null
 //   op                      kept in m_op
 //
 // The third base-constructor argument is one more than the larger of the
 // two operands' getHeight() values (presumably this node's height in the
 // expression tree -- confirm against Node).
 BinaryNode(const char *out_type_str,
            const char *name_str,
            const char *op_str,
            Node_ptr lhs,
            Node_ptr rhs,
            int op)
     : Node(out_type_str,
            name_str,
            std::max(lhs->getHeight(), rhs->getHeight()) + 1)
     , m_op_str(op_str)
     , m_lhs(lhs)
     , m_rhs(rhs)
     , m_op(op)
 {}
Пример #2
0
 // Builds a unary node wrapping a single operand node.
 //
 //   out_type_str, name_str  forwarded unchanged to the Node base constructor
 //   op_str                  kept in m_op_str
 //   child                   operand node; dereferenced here, so it may not
 //                           be null. Handed to the base as the sole entry
 //                           of a brace-initialized fourth argument
 //                           (presumably the child list -- confirm against
 //                           Node).
 //   op                      kept in m_op
 //
 // The third base-constructor argument is child->getHeight() + 1.
 UnaryNode(const char *out_type_str,
           const char *name_str,
           const char *op_str,
           Node_ptr child,
           int op)
     : Node(out_type_str,
            name_str,
            child->getHeight() + 1,
            {{child}})
     , m_op_str(op_str)
     , m_op(op)
 {}
Пример #3
0
    // Wraps a JIT node in an Array<T> of the given dimensions and, when
    // evalFlag() is set, decides whether the array should be evaluated
    // right away instead of staying a lazy node.
    //
    // The array is evaluated immediately when either
    //   * the node's height has reached getMaxJitSize(), or
    //   * the locked memory reported by deviceMemoryInfo() exceeds
    //     getMaxBytes() / getMaxBuffers() AND twice the byte count gathered
    //     from the tree via getInfo() is larger than the locked bytes.
    //
    // dims  dimensions handed to the Array<T> constructor
    // node  root of the JIT tree; dereferenced whenever evalFlag() is true,
    //       so it may not be null in that case
    //
    // Returns the array, evaluated or not.
    Array<T> createNodeArray(const dim4 &dims, Node_ptr node)
    {
        Array<T> out = Array<T>(dims, node);

        if (!evalFlag()) {
            return out;
        }

        // Tree is already as tall as the JIT is allowed to get.
        if (node->getHeight() >= static_cast<int>(getMaxJitSize())) {
            out.eval();
            return out;
        }

        size_t alloc_bytes   = 0;
        size_t alloc_buffers = 0;
        size_t lock_bytes    = 0;
        size_t lock_buffers  = 0;
        deviceMemoryInfo(&alloc_bytes, &alloc_buffers,
                         &lock_bytes, &lock_buffers);

        // Check if approaching the memory limit
        const bool near_limit = lock_bytes > getMaxBytes() ||
                                lock_buffers > getMaxBuffers();
        if (!near_limit) {
            return out;
        }

        // Collect every node reachable from the root and let each one
        // report into the shared counters (presumably getInfo accumulates
        // into its arguments -- confirm against Node::getInfo).
        unsigned length    = 0;
        unsigned buf_count = 0;
        unsigned bytes     = 0;

        Node *root = node.get();
        JIT::Node_map_t nodes_map;
        std::vector<JIT::Node *> full_nodes;
        std::vector<JIT::Node_ids> full_ids;
        root->getNodesMap(nodes_map, full_nodes, full_ids);

        for (JIT::Node *jit_node : full_nodes) {
            jit_node->getInfo(length, buf_count, bytes);
        }

        // Widen before doubling: `2 * bytes` in 32-bit unsigned arithmetic
        // wraps around once bytes reaches 2 GiB, which would make the
        // comparison against the size_t lock_bytes spuriously false.
        if (2 * static_cast<size_t>(bytes) > lock_bytes) {
            out.eval();
        }

        return out;
    }
Пример #4
0
 // Builds a unary node on top of `child`.
 //
 // child is dereferenced here, so it may not be null. The base Node is
 // constructed with child->getHeight() + 1, the child is kept in m_child,
 // and m_val starts out as 0.
 UnaryNode(Node_ptr child)
     : Node(child->getHeight() + 1)
     , m_child(child)
     , m_val(0)
 {}
Пример #5
0
 // Builds a typed unary node on top of `child`.
 //
 // child is dereferenced here, so it may not be null. The TNode<To> base
 // receives 0, child->getHeight() + 1 and a brace-initialized argument
 // holding the child (presumably initial value, height and child list --
 // confirm against TNode). m_child keeps a raw, non-owning TNode<Ti> view
 // of the same object.
 //
 // NOTE(review): the raw pointer is obtained with reinterpret_cast, so
 // nothing checks that `child` really refers to a TNode<Ti>; callers must
 // guarantee it. If TNode<Ti> derives from the pointee type of Node_ptr, a
 // static_cast would express the downcast more safely -- confirm.
 UnaryNode(Node_ptr child)
     : TNode<To>(0, child->getHeight() + 1, {{child}})
     , m_child(reinterpret_cast<TNode<Ti> *>(child.get()))
 {}