// Pull the trained weights for `layer` out of the binary net (netBinary),
// convert them to cv::Mat blobs in `layerParams`, and free the proto copies.
void extractBinaryLayerParams(const caffe::LayerParameter& layer, LayerParams& layerParams)
{
    const std::string& name = layer.name();

    // Find the weight-carrying layer: same name AND a non-empty blob list
    // (skips weightless duplicates of the same name).
    int idx = 0;
    while (idx != netBinary.layer_size())
    {
        const caffe::LayerParameter& candidate = netBinary.layer(idx);
        if (candidate.name() == name && candidate.blobs_size() != 0)
            break;
        ++idx;
    }
    if (idx == netBinary.layer_size())
        return;  // no trained weights for this layer

    caffe::LayerParameter* binLayer = netBinary.mutable_layer(idx);
    const int numBlobs = binLayer->blobs_size();

    // Convert every proto blob into a blob owned by layerParams.
    layerParams.blobs.resize(numBlobs);
    for (int i = 0; i < numBlobs; ++i)
        blobFromProto(binLayer->blobs(i), layerParams.blobs[i]);

    // Reclaim memory eagerly: clear_blobs() only parks the elements in the
    // repeated field's "cleared" pool, so release and delete each one to
    // actually free it.
    binLayer->clear_blobs();
    CV_Assert(numBlobs == binLayer->blobs().ClearedCount());
    for (int i = 0; i < numBlobs; ++i)
        delete binLayer->mutable_blobs()->ReleaseCleared();
}
// Configure a parsed Caffe net description for inference.
//
// @param param    Net description to modify in place (phase and per-layer
//                 compute engines are set).
// @param process  Backend selector: "cudnn" picks the cuDNN engine for
//                 Convolution/Deconvolution/ReLU layers, anything else
//                 picks the built-in Caffe engine.
// @return eWaifu2xError_FailedParseModelFile when the first (input) layer
//         does not declare exactly one 4-D input shape; eWaifu2xError_OK
//         otherwise.
Waifu2x::eWaifu2xError cNet::SetParameter(caffe::NetParameter &param, const std::string &process) const
{
    param.mutable_state()->set_phase(caffe::TEST);

    // Validate the input layer: exactly one shape with 4 dimensions.
    {
        auto input_layer = param.mutable_layer(0);
        auto mid = input_layer->mutable_input_param()->mutable_shape();
        if (mid->size() != 1 || mid->Mutable(0)->dim_size() != 4)
            return Waifu2x::eWaifu2xError_FailedParseModelFile;
    }

    const bool use_cudnn = (process == "cudnn");

    for (int i = 0; i < param.layer_size(); i++)
    {
        caffe::LayerParameter *layer_param = param.mutable_layer(i);
        const std::string& type = layer_param->type();

        // Deconvolution shares ConvolutionParameter in the Caffe proto, so
        // both layer types are configured through convolution_param.
        if (type == "Convolution" || type == "Deconvolution")
        {
            layer_param->mutable_convolution_param()->set_engine(use_cudnn ?
                caffe::ConvolutionParameter_Engine_CUDNN :
                caffe::ConvolutionParameter_Engine_CAFFE);
        }
        else if (type == "ReLU")
        {
            layer_param->mutable_relu_param()->set_engine(use_cudnn ?
                caffe::ReLUParameter_Engine_CUDNN :
                caffe::ReLUParameter_Engine_CAFFE);
        }
    }

    return Waifu2x::eWaifu2xError_OK;
}
// Translate the parsed Caffe net description (`net`) into dstNet, layer by
// layer, wiring layer inputs/outputs through the addedBlobs bookkeeping.
void populateNet(Net dstNet)
{
    CV_TRACE_FUNCTION();

    int layersSize = net.layer_size();
    layerCounter.clear();
    addedBlobs.clear();
    addedBlobs.reserve(layersSize + 1);

    //setup input layer names
    std::vector<String> netInputs(net.input_size());
    {
        for (int inNum = 0; inNum < net.input_size(); inNum++)
        {
            addedBlobs.push_back(BlobNote(net.input(inNum), 0, inNum));
            netInputs[inNum] = net.input(inNum);
        }
    }

    for (int li = 0; li < layersSize; li++)
    {
        const caffe::LayerParameter &layer = net.layer(li);
        String name = layer.name();
        String type = layer.type();
        LayerParams layerParams;

        extractLayerParams(layer, layerParams);
        extractBinaryLayerParams(layer, layerParams);

        // Disambiguate repeated layer names with a numeric suffix.
        int repetitions = layerCounter[name]++;
        if (repetitions)
            name += String("_") + toString(repetitions);

        if (type == "Input")
        {
            // Every top of an Input layer becomes an additional net input.
            for (int outNum = 0; outNum < layer.top_size(); outNum++)
            {
                addOutput(layer, 0, outNum);
                addedBlobs.back().outNum = netInputs.size();
                netInputs.push_back(addedBlobs.back().name);
            }
            continue;
        }
        else if (type == "BatchNorm")
        {
            if (!layerParams.get<bool>("use_global_stats", true))
            {
                // Training-mode BatchNorm (use_global_stats == false):
                // insert an MVN layer to normalize with batch statistics,
                // then neutralize the BatchNorm's stored mean/std so the
                // following BatchNorm only applies its scale/shift.
                CV_Assert_N(layer.bottom_size() == 1, layer.top_size() == 1);

                LayerParams mvnParams;
                mvnParams.set("eps", layerParams.get<float>("eps", 1e-5));

                // The MVN helper gets its own name-collision counter.
                std::string mvnName = name + "/mvn";
                int repetitions = layerCounter[mvnName]++;
                if (repetitions)
                    mvnName += String("_") + toString(repetitions);

                int mvnId = dstNet.addLayer(mvnName, "MVN", mvnParams);
                addInput(layer.bottom(0), mvnId, 0, dstNet);
                addOutput(layer, mvnId, 0);
                // Rewire the original BatchNorm to consume the MVN output.
                net.mutable_layer(li)->set_bottom(0, layer.top(0));
                layerParams.blobs[0].setTo(0);  // mean
                layerParams.blobs[1].setTo(1);  // std
            }
        }
        else if ("ConvolutionDepthwise" == type)
        {
            // Handled by the regular Convolution layer implementation.
            type = "Convolution";
        }

        int id = dstNet.addLayer(name, type, layerParams);

        for (int inNum = 0; inNum < layer.bottom_size(); inNum++)
            addInput(layer.bottom(inNum), id, inNum, dstNet);

        for (int outNum = 0; outNum < layer.top_size(); outNum++)
            addOutput(layer, id, outNum);
    }
    dstNet.setInputsNames(netInputs);

    addedBlobs.clear();
}