Пример #1
0
/// JIT-growth heuristic: reports whether the expression tree rooted at
/// `root_node` may stay unevaluated (true) or should be evaluated now (false).
///
/// @param root_node  Root of the JIT tree to inspect. It is dereferenced
///                   whenever evalFlag() is set, so it must not be null then.
/// @return false when the tree height reaches getMaxJitSize(), or when the
///         locked device memory exceeds getMaxBytes()/getMaxBuffers() and the
///         tree's data accounts for more than half of the locked bytes;
///         true otherwise (always true when evalFlag() is false).
bool passesJitHeuristics(Node *root_node) {
    // Heuristics disabled: never request an early evaluation.
    if (!evalFlag()) { return true; }

    // Depth cap reached: the tree must be evaluated.
    const int height_limit = static_cast<int>(getMaxJitSize());
    if (root_node->getHeight() >= height_limit) { return false; }

    // Only the locked figures are consulted; the alloc_* outputs exist because
    // deviceMemoryInfo() requires all four out-parameters.
    size_t alloc_bytes   = 0;
    size_t alloc_buffers = 0;
    size_t lock_bytes    = 0;
    size_t lock_buffers  = 0;
    deviceMemoryInfo(&alloc_bytes, &alloc_buffers, &lock_bytes, &lock_buffers);

    const bool over_limit =
        lock_bytes > getMaxBytes() || lock_buffers > getMaxBuffers();
    if (!over_limit) { return true; }

    // Sum getBytes() over every node visited by NodeIterator from the root.
    // getBytes() reports the size of the node's data Array; sub-arrays are
    // represented by their parent's size.
    NodeIterator<jit::Node> first(root_node);
    NodeIterator<jit::Node> last;
    const size_t tree_bytes = accumulate(
        first, last, size_t(0),
        [](const size_t running, const Node &current) {
            return running + current.getBytes();
        });

    // Stay lazy only while the tree is at most half of the locked memory.
    return 2 * tree_bytes <= lock_bytes;
}
Пример #2
0
/// Wraps a JIT node in an Array<T> of shape `dims` and, when lazy-evaluation
/// heuristics are enabled (evalFlag()), forces evaluation if:
///   * the node's height reaches getMaxJitSize(), or
///   * locked device memory exceeds getMaxBytes()/getMaxBuffers() and the
///     bytes reported for the tree exceed half of the locked bytes.
///
/// @param dims  Dimensions of the resulting array.
/// @param node  Root of the JIT expression tree backing the array.
/// @return The (possibly already evaluated) array.
Array<T>
createNodeArray(const dim4 &dims, Node_ptr node)
{
    Array<T> out(dims, node);

    if (!evalFlag()) { return out; }

    // Depth cap reached: evaluate immediately.
    if (node->getHeight() >= static_cast<int>(getMaxJitSize())) {
        out.eval();
        return out;
    }

    // Only the locked figures are consulted; the alloc_* outputs exist because
    // deviceMemoryInfo() requires all four out-parameters.
    size_t alloc_bytes   = 0;
    size_t alloc_buffers = 0;
    size_t lock_bytes    = 0;
    size_t lock_buffers  = 0;
    deviceMemoryInfo(&alloc_bytes, &alloc_buffers, &lock_bytes, &lock_buffers);

    // Check if approaching the memory limit.
    const bool near_limit =
        lock_bytes > getMaxBytes() || lock_buffers > getMaxBuffers();
    if (!near_limit) { return out; }

    TNJ::Node_map_t nodes_map;
    vector<TNJ::Node *> full_nodes;
    node->getNodesMap(nodes_map, full_nodes);

    // NOTE(review): the counters stay `unsigned`, presumably because getInfo()
    // takes them by `unsigned &`. getInfo() appears to add each node's
    // contribution into them (confirm); if so, trees over 4 GiB can still
    // wrap `bytes` itself.
    unsigned length    = 0;
    unsigned buf_count = 0;
    unsigned bytes     = 0;
    for (auto &entry : nodes_map) {
        Node *tree_node = entry.first;  // distinct name: don't shadow `node`
        tree_node->getInfo(length, buf_count, bytes);
    }

    // Equivalent to `2 * bytes > lock_bytes`, but without the multiplication.
    // In `unsigned` arithmetic, `2 * bytes` wraps once bytes reaches 2^31 and
    // silently skipped the eval for large trees.
    if (bytes > lock_bytes / 2) { out.eval(); }

    return out;
}