// Removes Transpose nodes whose `perm` attribute is judged a no-op by
// isNopTranspose: every use of the transpose's output is redirected to the
// transpose's input value, and the transpose node is then destroyed.
void eliminateNopTranspose(std::shared_ptr<Graph>& graph) {
  for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
    auto* n = *it;
    if (n->kind() == kTranspose && isNopTranspose(n->is(kperm))) {
      // Rewire at the Value level, exactly like the other peephole passes.
      // The previous form, n->replaceAllUsesWith(n->input()->node()), targeted
      // the *producer node* rather than the transposed value. Node-level
      // replacement presumably maps outputs by position. It therefore picks
      // the wrong value, or fails an output-count check, whenever the
      // transposed value is not the sole output of its producer.
      // NOTE(review): one example is a graph input that lives on a
      // multi-output parameter node; confirm against ir.h.
      n->output()->replaceAllUsesWith(n->input());
      it.destroyCurrent();
    }
  }
}
// The intent for this optimization pass is to catch all of the small, easy to // catch peephole optimizations you might be interested in doing. // // Right now, it does: // - Redundant 'expand' elimination // // TODO: Decide what kind of fixed point strategy we will have void PeepholeOptimize(std::shared_ptr<Graph>& graph) { for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) { auto* n = *it; // eliminate redundant expand if (n->kind() == kexpand) { if (n->is(ksize) == n->input()->type()->expect<TensorType>()->sizes()) { n->output()->replaceAllUsesWith(n->input()); it.destroyCurrent(); continue; } } } }
// Why this is here: // // Pytorch has a "packed" representation of sequences, as well as a // "padded" representation. ONNX has only one representation, // corresponding to pytorch's "padded". Therefore, we need to remove // any use of packed sequences before exporting. // // What this does: // // This code uses the observation that // RNN(PackPadded(x)) == PackPadded(RNN(x)) // and converts the first form to the second whenever possible, // "pushing" the packing operation past the RNN operation. Then, // the removeNopPacking pass removes the packing operations // entirely by pairing them with their inverse PadPacked. If the // input graph does not pair the operations, export will fail. void pushPackingPastRnn(std::shared_ptr<Graph>& graph) { for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) { auto* n = *it; if (n->kind() != kPackPadded) { continue; } if (n->outputs()[0]->uses().size() != 1) { // For now, only handle the case where there is one consumer. continue; } Node * rnn = n->outputs()[0]->uses()[0].user; if (!isRNN(rnn)) { continue; } // remove PackPadded from in front of the RNN n->outputs()[0]->replaceAllUsesWith(n->inputs()[0]); // note there can be multiple uses of the length blob. If we are // translating a multi-level RNN it will be an input to each level. n->outputs()[1]->replaceFirstUseWith(n->inputs()[1]); // and insert new PackPadded after the RNN Node * newPackPadded = graph->create(kPackPadded, 2); newPackPadded->insertAfter(rnn); // make things consume from the new PackPadded rnn->outputs()[0]->replaceAllUsesWith(newPackPadded->outputs()[0]); n->outputs()[1]->replaceAllUsesWith(newPackPadded->outputs()[1]); // setup the new PackPadded's inputs newPackPadded->addInput(rnn->outputs()[0]); newPackPadded->addInput(n->inputs()[1]); it.destroyCurrent(); } }
// Cancels PadPacked(PackPadded(...)) pairs. When a PadPacked node consumes
// output 0 of a PackPadded node as its input 0, and output 1 of that same node
// as its input 1, the PadPacked outputs are wired straight to the PackPadded
// node's inputs (0 -> 0, 1 -> 1) and the PadPacked node is destroyed.
// The PackPadded node itself is not removed by this pass.
void removeNopPacking(std::shared_ptr<Graph>& graph) {
  for (auto cursor = graph->nodes().begin(); cursor != graph->nodes().end(); ++cursor) {
    auto* unpack = *cursor;
    if (unpack->kind() != kPadPacked) {
      continue;
    }

    auto* pack = unpack->inputs()[0]->node();
    // Both of the PadPacked inputs must come from the matching outputs of a
    // single PackPadded node.
    const bool cancelsPack = pack->kind() == kPackPadded &&
                             pack->outputs()[0] == unpack->inputs()[0] &&
                             pack->outputs()[1] == unpack->inputs()[1];
    if (!cancelsPack) {
      continue;
    }

    unpack->outputs()[0]->replaceAllUsesWith(pack->inputs()[0]);
    unpack->outputs()[1]->replaceAllUsesWith(pack->inputs()[1]);
    // Detach from the PackPadded outputs before the node is destroyed.
    unpack->removeAllInputs();
    cursor.destroyCurrent();
  }
}