Пример #1
0
/// Verify the operand shapes of a fully-connected node.
///
/// The leading dimension of \p src must equal the leading dimension of
/// \p dest, and `flattenCdr(src.dims()).second` must equal the first dimension
/// of \p weights (presumably the product of all dimensions of \p src after the
/// first — confirm against flattenCdr). The second dimension of \p weights must
/// equal both the first dimension of \p bias and the second dimension of
/// \p dest.
static void verifyFullyConnected(NodeValue src, NodeValue weights,
                                 NodeValue bias, NodeValue dest) {
  // Input side: leading dims line up, and the flattened remainder of the
  // source feeds the first dimension of the weights.
  assert(dest.dims()[0] == src.dims()[0] &&
         weights.dims()[0] == flattenCdr(src.dims()).second &&
         "Mismatch on expected source dimensions");

  // Output side: weights, bias and destination agree on the output width.
  assert(weights.dims()[1] == bias.dims()[0] &&
         dest.dims()[1] == weights.dims()[1] &&
         "Inconsistent bias/weights/dest sizes.");
}
Пример #2
0
/// Verify the operands of a batch-normalization node.
///
/// \p dest must have the same type as \p src (checked via checkSameType).
/// Each of \p bias, \p scale, \p mean and \p var must have dims equal to the
/// single-element list {src.dims()[channel]}, i.e. exactly one dimension whose
/// size is the extent of \p src along axis \p channel.
static void verifyBatchNormalization(NodeValue src, NodeValue dest,
                                     NodeValue bias, NodeValue scale,
                                     NodeValue mean, NodeValue var,
                                     size_t channel) {
  checkSameType(dest, src);

  // Extent of the source along the channel axis; every per-channel parameter
  // tensor must have exactly this one dimension.
  size_t numChannels = src.dims()[channel];
  auto expectedDims = {numChannels};
  // Only read inside assert(); keep NDEBUG builds free of unused warnings.
  (void)expectedDims;

  assert(bias.getType()->dims().equals(expectedDims) && "Invalid bias dim");
  assert(scale.getType()->dims().equals(expectedDims) && "Invalid scale dim");
  assert(mean.getType()->dims().equals(expectedDims) && "Invalid mean dim");
  assert(var.getType()->dims().equals(expectedDims) && "Invalid var dim");
}
Пример #3
0
/// Verify the operands of a cross-entropy-loss node.
///
/// \p P and \p CE must have the same element type, and the leading dimension
/// of \p P must equal the leading dimension of \p labels.
/// NOTE(review): presumably P holds per-row probabilities and labels holds one
/// entry per row — confirm against the node definition.
static void verifyCrossEntropyLoss(NodeValue P, NodeValue CE,
                                   NodeValue labels) {
  // Query both operands through the NodeValue accessor, consistent with P and
  // the rest of this file, instead of reaching into the underlying node via
  // operator->.
  assert(P.getElementType() == CE.getElementType() &&
         "Mismatched element types");
  assert(P.dims()[0] == labels.dims()[0] && "Invalid shape");
}
Пример #4
0
/// Verify the operands of a softmax node.
///
/// \p src and \p dest must pass checkSameType, and their dimension lists must
/// be identical.
static void verifySoftMax(NodeValue src, NodeValue dest) {
  checkSameType(src, dest);
  assert(src.dims().equals(dest.dims()) && "Invalid shape");
}
Пример #5
0
/// Assert that operands \p A and \p B have identical dimension lists (same
/// rank and same extent along every axis). Only the shapes are compared here;
/// element types are not examined.
static void checkSameShape(NodeValue A, NodeValue B) {
  assert(A.dims().equals(B.dims()) && "Invalid shape");
}