Ejemplo n.º 1
0
/// Sanity-checks the operands of a convolution. Every check is an assert, so
/// violations are only diagnosed in builds where asserts are enabled.
///
/// \param src     input value; its dims are interpreted through ShapeNHWC.
/// \param dest    output value; its dims are interpreted through ShapeNHWC.
/// \param filter  must have dims {odim.c, kernel, kernel, idim.c / group}.
/// \param bias    must have dims {odim.c}.
/// \param kernel  kernel edge length, applied to both height and width.
/// \param stride  forwarded to calculateConvOutputDims().
/// \param pad     forwarded to calculateConvOutputDims().
/// \param group   number of groups; must be non-zero and must divide both the
///                input and the output channel counts.
static void verifyConvolution(NodeValue src, NodeValue dest, NodeValue filter,
                              NodeValue bias, size_t kernel, size_t stride,
                              size_t pad, size_t group) {
  // All four operands must share one element type.
  assert(src.getElementType() == dest.getElementType() && "Invalid Type");
  assert(src.getElementType() == filter.getElementType() && "Invalid Type");
  assert(src.getElementType() == bias.getElementType() && "Invalid Type");

  ShapeNHWC idim(src.getType()->dims());
  ShapeNHWC odim(dest.getType()->dims());

  // NOTE(review): the message mentions the stride, but the condition compares
  // the input extent against the kernel size.
  assert(idim.w >= kernel && idim.h >= kernel &&
         "buffer too small for selected stride");

  // `group` is used as a divisor below (`%` and `/`). Reject zero up front so
  // the failure is a readable assertion rather than a division by zero inside
  // one of the later checks.
  assert(group != 0 && "group must be non-zero");
  assert(idim.c % group == 0 && "channels number must be divisible by groups");

  // The destination must have exactly the spatial size the helper computes
  // for this input/kernel/stride/pad combination.
  auto outSz = calculateConvOutputDims(idim.h, idim.w, kernel, stride, pad);
  (void)outSz; // Only read inside asserts.
  assert(odim.n == idim.n && odim.h == outSz.first && odim.w == outSz.second &&
         odim.c % group == 0 && "Invalid output dimensions");

  // Each filter spans kernel x kernel over the channels of a single group.
  auto filterDims = {odim.c, kernel, kernel, idim.c / group};
  assert(filter.getType()->dims().equals(filterDims) && "Invalid filter dims");
  (void)filterDims; // Only read inside asserts.

  // One bias element per output channel.
  auto biasDims = {odim.c};
  assert(bias.getType()->dims().equals(biasDims) && "Invalid bias dims");
  (void)biasDims; // Only read inside asserts.
}
Ejemplo n.º 2
0
/// Sanity-checks the operands of a cross-entropy-loss computation. Both checks
/// are asserts, so they are only active in builds where asserts are enabled.
///
/// \param P       value whose element type must match \p CE and whose leading
///                dimension must match that of \p labels.
/// \param CE      value whose element type is compared against \p P.
/// \param labels  value whose leading dimension is compared against \p P.
static void verifyCrossEntropyLoss(NodeValue P, NodeValue CE,
                                   NodeValue labels) {
  // Query CE through the NodeValue accessor, the same way P is queried, so
  // both sides of the comparison go through the same API instead of reaching
  // through operator-> for one of them.
  assert(P.getElementType() == CE.getElementType() && "Invalid type");
  assert(P.dims()[0] == labels.dims()[0] && "Invalid shape");
}
Ejemplo n.º 3
0
/// Asserts that the element type of \p A is \p expectedType. The check is an
/// assert, so it is only active in builds where asserts are enabled.
static void checkType(NodeValue A, ElemKind expectedType) {
  assert(expectedType == A.getElementType() && "Invalid type");
}