Example #1
// Remove Transpose nodes whose permutation is the identity; such a
// transpose returns its input unchanged.
void eliminateNopTranspose(std::shared_ptr<Graph>& graph) {
  for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
    auto* n = *it;
    if (n->kind() == kTranspose) {
      if (isNopTranspose(n->is(kperm))) {
        // rewire consumers onto the transpose's input, then delete the node
        n->output()->replaceAllUsesWith(n->input());
        it.destroyCurrent();
        continue;
      }
    }
  }
}
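
For reference, here is a minimal sketch of the isNopTranspose predicate used above (an assumption: the real helper may be written differently). A transpose is a no-op exactly when its perm attribute is the identity permutation:

#include <cstddef>
#include <cstdint>
#include <vector>

static bool isNopTranspose(const std::vector<int64_t>& perm) {
  for (size_t i = 0; i < perm.size(); i++) {
    if (perm[i] != static_cast<int64_t>(i)) {
      return false;  // some axis actually moves
    }
  }
  return true;  // identity permutation: the transpose changes nothing
}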
Example #2
// The intent of this optimization pass is to catch all of the small,
// easy-to-catch peephole optimizations you might be interested in doing.
//
// Right now, it does:
//    - Redundant 'expand' elimination
//
// TODO: Decide what kind of fixed point strategy we will have
void PeepholeOptimize(std::shared_ptr<Graph>& graph) {
  for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
    auto* n = *it;

    // eliminate redundant expand
    if (n->kind() == kexpand) {
      if (n->is(ksize) == n->input()->type()->expect<TensorType>()->sizes()) {
        n->output()->replaceAllUsesWith(n->input());
        it.destroyCurrent();
        continue;
      }
    }
  }
}
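
To see why the sizes check in this pass suffices: expanding a tensor to its own sizes is the identity, so consumers of the expand's output can read the input directly. A minimal ATen sketch of that invariant (illustrative only, not part of the pass):

#include <ATen/ATen.h>

void expandIsNopDemo() {
  at::Tensor x = at::rand({2, 3});
  at::Tensor y = x.expand({2, 3});  // expanding x to its own sizes
  AT_ASSERT(y.sizes() == x.sizes());
  AT_ASSERT(y.equal(x));  // same values: the expand did no work
}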
Example #3
// Why this is here:
//
//   PyTorch has a "packed" representation of sequences, as well as a
//   "padded" representation. ONNX has only one representation,
//   corresponding to PyTorch's "padded". Therefore, we need to remove
//   any use of packed sequences before exporting.
//
// What this does:
//
//   This code uses the observation that
//     RNN(PackPadded(x)) == PackPadded(RNN(x))
//   and converts the first form to the second whenever possible,
//   "pushing" the packing operation past the RNN operation. Then,
//   the removeNopPacking pass removes the packing operations
//   entirely by pairing them with their inverse PadPacked. If the
//   input graph does not pair the operations, export will fail.
void pushPackingPastRnn(std::shared_ptr<Graph>& graph) {
  for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
    auto* n = *it;

    if (n->kind() != kPackPadded) {
      continue;
    }
    if (n->outputs()[0]->uses().size() != 1) {
      // For now, only handle the case where there is one consumer.
      continue;
    }
    Node* rnn = n->outputs()[0]->uses()[0].user;
    if (!isRNN(rnn)) {
      continue;
    }

    // remove PackPadded from in front of the RNN
    n->outputs()[0]->replaceAllUsesWith(n->inputs()[0]);

    // note there can be multiple uses of the length blob. If we are
    // translating a multi-level RNN, it will be an input to each level.
    n->outputs()[1]->replaceFirstUseWith(n->inputs()[1]);

    // and insert new PackPadded after the RNN
    Node* newPackPadded = graph->create(kPackPadded, 2);
    newPackPadded->insertAfter(rnn);

    // make things consume from the new PackPadded
    rnn->outputs()[0]->replaceAllUsesWith(newPackPadded->outputs()[0]);
    n->outputs()[1]->replaceAllUsesWith(newPackPadded->outputs()[1]);

    // set up the new PackPadded's inputs
    newPackPadded->addInput(rnn->outputs()[0]);
    newPackPadded->addInput(n->inputs()[1]);

    it.destroyCurrent();
  }
}
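
For reference, a minimal sketch of the isRNN predicate used above (an assumption: the real check may also inspect the node's attributes). It tests whether the node is one of the recurrent ops that PackPadded can be pushed past; the symbols kRNN, kLSTM, and kGRU here are hypothetical, named to match the kPackPadded style:

static bool isRNN(const Node* node) {
  auto kind = node->kind();
  return kind == kRNN || kind == kLSTM || kind == kGRU;
}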
Example #4
// Removes PadPacked(PackPadded(x, lengths)) pairs: padding a sequence
// that was just packed recovers the original (x, lengths), so the two
// nodes cancel out.
void removeNopPacking(std::shared_ptr<Graph>& graph) {
  for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
    auto* n = *it;

    if (n->kind() != kPadPacked) {
      continue;
    }
    Node* input = n->inputs()[0]->node();
    if (input->kind() != kPackPadded) {
      continue;
    }
    // only fire when the PadPacked consumes the PackPadded's outputs
    // in order (data first, then lengths)
    if (input->outputs()[0] != n->inputs()[0]) {
      continue;
    }
    if (input->outputs()[1] != n->inputs()[1]) {
      continue;
    }
    n->outputs()[0]->replaceAllUsesWith(input->inputs()[0]);
    n->outputs()[1]->replaceAllUsesWith(input->inputs()[1]);

    n->removeAllInputs();
    it.destroyCurrent();
  }
}
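
A hypothetical driver (an assumption, not shown in the original source) illustrating how the two passes compose for export: packing is first pushed past each RNN, and the resulting inverse pairs are then erased:

void preparePackedSequencesForExport(std::shared_ptr<Graph>& graph) {
  pushPackingPastRnn(graph);  // RNN(PackPadded(x)) -> PackPadded(RNN(x))
  removeNopPacking(graph);    // erase PadPacked(PackPadded(...)) pairs
}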