/// Decide whether the JIT tree rooted at `root_node` may keep growing
/// without being evaluated.
///
/// @param root_node Root of the JIT node tree to inspect.
/// @return Always true when evalFlag() is false. Otherwise false when the
///         tree height has reached getMaxJitSize(), or when locked device
///         memory is past getMaxBytes()/getMaxBuffers() and the summed
///         getBytes() of the tree's nodes exceeds half of the locked bytes;
///         true in every other case.
bool passesJitHeuristics(Node *root_node)
{
    if (!evalFlag()) return true;

    if (root_node->getHeight() >= static_cast<int>(getMaxJitSize())) {
        return false;
    }

    // Only the locked-memory figures are consulted below; the allocation
    // figures are required out-parameters of deviceMemoryInfo.
    size_t alloc_bytes, alloc_buffers;
    size_t lock_bytes, lock_buffers;
    deviceMemoryInfo(&alloc_bytes, &alloc_buffers, &lock_bytes, &lock_buffers);

    // Check if approaching the memory limit
    if (lock_bytes > getMaxBytes() || lock_buffers > getMaxBuffers()) {
        NodeIterator<jit::Node> it(root_node);
        NodeIterator<jit::Node> end_node;
        size_t bytes = accumulate(it, end_node, size_t(0),
                                  [](const size_t prev, const Node &n) {
                                      // getBytes returns the size of the data
                                      // Array. Sub arrays will be represented
                                      // by their parent size.
                                      return prev + n.getBytes();
                                  });

        // Equivalent to `2 * bytes > lock_bytes` for unsigned integers, but
        // the doubling can never overflow.
        if (bytes > lock_bytes / 2) { return false; }
    }
    return true;
}
/// Wrap `node` in a new Array of shape `dims`, evaluating it immediately
/// when the JIT heuristics say the expression tree should not grow further.
///
/// @param dims Dimensions of the resulting Array.
/// @param node Root node of the JIT expression backing the Array.
/// @return The new Array. When evalFlag() is set, it has already been
///         eval()'d if either the tree height reached getMaxJitSize(), or
///         locked device memory is past getMaxBytes()/getMaxBuffers() and
///         the bytes reported by the tree's nodes exceed half of the
///         locked bytes.
Array<T> createNodeArray(const dim4 &dims, Node_ptr node)
{
    Array<T> out = Array<T>(dims, node);

    if (!evalFlag()) return out;

    if (node->getHeight() >= static_cast<int>(getMaxJitSize())) {
        out.eval();
        return out;
    }

    // Only the locked-memory figures are consulted below; the allocation
    // figures are required out-parameters of deviceMemoryInfo.
    size_t alloc_bytes, alloc_buffers;
    size_t lock_bytes, lock_buffers;
    deviceMemoryInfo(&alloc_bytes, &alloc_buffers, &lock_bytes, &lock_buffers);

    // Check if approaching the memory limit
    if (lock_bytes > getMaxBytes() || lock_buffers > getMaxBuffers()) {
        Node *root = node.get();
        TNJ::Node_map_t nodes_map;
        vector<TNJ::Node *> full_nodes;
        root->getNodesMap(nodes_map, full_nodes);

        // Sum in size_t: a single 32-bit `unsigned` accumulator (and a
        // `2 * bytes` product in `unsigned`) wraps once a graph references
        // a few GiB. getInfo is still fed `unsigned` lvalues, so each node
        // reports into fresh scratch counters that are then added up.
        // NOTE(review): assumes getInfo reports the node's own contribution
        // through its out-params (by adding to or assigning them) — confirm
        // against the Node implementations.
        size_t total_bytes = 0;
        for (auto &entry : nodes_map) {
            Node *graph_node = entry.first;
            unsigned node_len = 0, node_bufs = 0, node_bytes = 0;
            graph_node->getInfo(node_len, node_bufs, node_bytes);
            total_bytes += node_bytes;
        }

        // Equivalent to `2 * total_bytes > lock_bytes` for unsigned
        // integers, but the doubling can never overflow.
        if (total_bytes > lock_bytes / 2) { out.eval(); }
    }
    return out;
}