static void verifyConvolution(NodeValue src, NodeValue dest, NodeValue filter, NodeValue bias, size_t kernel, size_t stride, size_t pad, size_t group) { assert(src.getElementType() == dest.getElementType() && "Invalid Type"); assert(src.getElementType() == filter.getElementType() && "Invalid Type"); assert(src.getElementType() == bias.getElementType() && "Invalid Type"); ShapeNHWC idim(src.getType()->dims()); ShapeNHWC odim(dest.getType()->dims()); assert(idim.w >= kernel && idim.h >= kernel && "buffer too small for selected stride"); assert(idim.c % group == 0 && "channels number must be divisible by groups"); auto outSz = calculateConvOutputDims(idim.h, idim.w, kernel, stride, pad); (void)outSz; assert(odim.n == idim.n && odim.h == outSz.first && odim.w == outSz.second && odim.c % group == 0 && "Invalid output dimensions"); auto filterDims = {odim.c, kernel, kernel, idim.c / group}; assert(filter.getType()->dims().equals(filterDims) && "Invalid filter dims"); (void)filterDims; auto biasDims = {odim.c}; assert(bias.getType()->dims().equals(biasDims) && "Invalid bias dims"); (void)biasDims; }
/// Verify that a pooling operation reading \p src and writing \p dest with the
/// given \p kernel, \p stride and \p pad is well formed.
///
/// Checks, via debug-only asserts:
///   - the input is at least \p kernel wide and high;
///   - dest has shape {N, outH, outW, C}, where N and C come from src and
///     outH/outW come from calculateConvOutputDims.
static void verifyPool(NodeValue src, NodeValue dest, size_t kernel,
                       size_t stride, size_t pad) {
  ShapeNHWC idim = ShapeNHWC(src.getType()->dims());
  ShapeNHWC odim = ShapeNHWC(dest.getType()->dims());
  (void)odim; // Only consumed by asserts; silences NDEBUG unused warnings.
  // The input must be large enough to hold at least one kernel window.
  assert(idim.w >= kernel && idim.h >= kernel &&
         "buffer too small for selected kernel");

  auto outSz = calculateConvOutputDims(idim.h, idim.w, kernel, stride, pad);
  // Pooling keeps the batch and channel dimensions and only shrinks H and W.
  ShapeNHWC exp(idim.n, outSz.first, outSz.second, idim.c);
  (void)exp;
  assert(exp == odim && "Unexpected output dimensions");
}
/// Verify that a batch-normalization node is consistent.
///
/// \p dest must have exactly the same type as \p src. Each per-channel
/// parameter (\p bias, \p scale, \p mean, \p var) must have dims equal to the
/// single-element list {C}, where C is the size of dimension \p channel of
/// \p src. All checks are debug-only asserts.
static void verifyBatchNormalization(NodeValue src, NodeValue dest,
                                     NodeValue bias, NodeValue scale,
                                     NodeValue mean, NodeValue var,
                                     size_t channel) {
  checkSameType(dest, src);

  // C: the extent of the channel axis in the input tensor.
  const size_t numChannels = src.dims()[channel];

  // Expected dims for every per-channel parameter tensor.
  auto exp = {numChannels};
  (void)exp; // Only consumed by asserts; silences NDEBUG unused warnings.

  assert(bias.getType()->dims().equals(exp) && "Invalid bias dim");
  assert(scale.getType()->dims().equals(exp) && "Invalid scale dim");
  assert(mean.getType()->dims().equals(exp) && "Invalid mean dim");
  assert(var.getType()->dims().equals(exp) && "Invalid var dim");
}
/// Rebind the operand slot this use refers to so that it points at \p other
/// (its node and result number).
///
/// The type-compatibility assert fires only when both \p other and the
/// current operand are non-null.
void NodeUse::setOperand(NodeValue &other) {
  const bool bothPresent = other && site_->getNode();
  if (bothPresent) {
    assert(site_->getType() == other.getType() && "Setting operand to a node with a different type");
  }

  auto *newNode = other.getNode();
  auto newResNo = other.getResNo();
  site_->setOperand(newNode, newResNo);
}
/// Redirect every user of this value (node_ together with this result number)
/// to \p v instead.
///
/// Users of node_ that reference a different result number are left
/// untouched. If \p v has a non-null node, its type must match this value's
/// type; this is checked with a debug-only assert.
void NodeValue::replaceAllUsesOfWith(NodeValue v) {
  // A null replacement skips the type check.
  if (v.getNode()) {
    assert(getType() == v.getType() && "Replacing value with the wrong type");
  }
  // Snapshot the user list before rewriting anything. NOTE(review): setOperand
  // presumably unregisters the use from node_'s user list, so iterating the
  // live list while mutating it would be unsafe — confirm against Node.
  auto &users = node_->getUsers();
  llvm::SmallVector<NodeUse, 4> usersVec(users.begin(), users.end());
  for (auto &U : usersVec) {
    NodeValue *site = U.get();
    assert(site->getNode() == node_ && "Invalid user");
    // Users are tracked per node, so keep only those that reference this
    // particular result of node_.
    if (site->getResNo() == getResNo()) {
      site->setOperand(v.getNode(), v.getResNo());
    }
  }
}
/// Check that the type of the first operand matches the type of the second /// operand. static void checkSameType(NodeValue A, NodeValue B) { assert(A.getType() == B.getType() && "Invalid type"); }