/// Runs one reconstruction pass of @p net over @p im, writing the result back into @p im.
///
/// Makes sure the reusable output staging buffer (mOutputBlock) is large enough for the
/// requested crop/batch geometry, growing it when necessary, then delegates the work to
/// cNet::ReconstructImage with @p im as both source and destination.
///
/// @param net        network to run
/// @param crop_w     crop width forwarded to the network
/// @param crop_h     crop height forwarded to the network
/// @param use_tta    forwarded to ReconstructImage
/// @param batch_size number of crops per batch
/// @param im         image that ReconstructImage reads from and writes to
/// @return the status reported by ReconstructImage (eWaifu2xError_OK on success)
Waifu2x::eWaifu2xError Waifu2x::ProcessNet(std::shared_ptr<cNet> net, const int crop_w, const int crop_h, const bool use_tta, const int batch_size, cv::Mat &im)
{
	// Scope guard constructed from the process mode and GPU number; presumably selects the
	// CUDA device for the duration of this call — TODO confirm.
	CudaDeviceSet devset(mProcess, mGPUNo);

	// The CUDA path passes this value to cudaHostAlloc, whose size argument is a byte count,
	// so it is treated as a byte count on both paths (and in mOutputBlockSize).
	const auto OutputMemorySize = net->GetOutputMemorySize(crop_w, crop_h, OuterPadding, batch_size);
	if (OutputMemorySize > mOutputBlockSize)
	{
		// Drop the recorded capacity before releasing the old buffer. If the allocation below
		// fails and leaves this function early, a later call with a smaller request must not
		// believe the released buffer is still usable.
		mOutputBlockSize = 0;

		if (mIsCuda)
		{
			CUDA_HOST_SAFE_FREE(mOutputBlock);
			CUDA_CHECK_WAIFU2X(cudaHostAlloc(&mOutputBlock, OutputMemorySize, cudaHostAllocDefault));
		}
		else
		{
			SAFE_DELETE_WAIFU2X(mOutputBlock);
			// Convert the byte count into a float count (rounded up) so the CPU buffer has the
			// same capacity as the CUDA one instead of sizeof(float) times more.
			mOutputBlock = new float[(OutputMemorySize + sizeof(float) - 1) / sizeof(float)];
		}

		mOutputBlockSize = OutputMemorySize;
	}

	return net->ReconstructImage(use_tta, crop_w, crop_h, OuterPadding, batch_size, mOutputBlock, im, im);
}
/// Initializes the converter.
///
/// Stores the configuration, performs the one-time process-wide Caffe/glog
/// initialization, resolves the model directory, builds the networks required by
/// @p Mode and allocates the input/dummy/output blocks. Returns immediately with
/// eWaifu2xError_OK if a previous call already completed successfully.
///
/// @param argc        argument count; only argv[0] is used
/// @param argv        argument vector; argv[0] is forwarded to caffe::GlobalInit and is
///                    used to locate a relative model directory next to the executable
/// @param Mode        networks built here: "noise", "noise_scale" and "auto_scale" build the
///                    noise net; "scale", "noise_scale" and "auto_scale" build the scale net
/// @param NoiseLevel  selects the "noise<NoiseLevel>_model.json" parameter file
/// @param ModelDir    model directory; a relative path is looked up under the current
///                    directory first, then under the directory of an absolute argv[0]
/// @param Process     "cpu" selects Caffe CPU mode; anything else selects GPU mode
///                    ("gpu" is renamed to "cudnn")
/// @param CropSize    crop size
/// @param BatchSize   number of blocks per batch
/// @return eWaifu2xError_OK on success; eWaifu2xError_FailedOpenModelFile if the model
///         directory does not exist; the ConstractNet status if building a network fails;
///         eWaifu2xError_InvalidParameter if any exception escapes.
Waifu2x::eWaifu2xError Waifu2x::init(int argc, char** argv, const std::string &Mode, const int NoiseLevel, const std::string &ModelDir, const std::string &Process, const int CropSize, const int BatchSize)
{
	if (is_inited)
		return eWaifu2xError_OK;

	try
	{
		mode = Mode;
		noise_level = NoiseLevel;
		model_dir = ModelDir;
		process = Process;
		crop_size = CropSize;
		batch_size = BatchSize;

		inner_padding = layer_num;
		outer_padding = 1;

		output_size = crop_size - offset * 2;
		input_block_size = crop_size + (inner_padding + outer_padding) * 2;
		original_width_height = 128 + layer_num * 2;

		output_block_size = crop_size + (inner_padding + outer_padding - layer_num) * 2;

		std::call_once(waifu2x_once_flag, [argc, argv]()
		{
			assert(argc >= 1);

			int tmpargc = 1;
			char* tmpargvv[] = { argv[0] };
			char** tmpargv = tmpargvv;
			// Initialize glog etc.; only argv[0] is forwarded to Caffe.
			caffe::GlobalInit(&tmpargc, &tmpargv);
		});

		// "gpu" is an alias for the cuDNN backend; no availability check is performed here.
		if (process == "gpu")
			process = "cudnn";

		boost::filesystem::path mode_dir_path(model_dir);
		if (!mode_dir_path.is_absolute()) // If model_dir is relative, turn it into an absolute path.
		{
			// First, look for it under the current directory.
			mode_dir_path = boost::filesystem::absolute(model_dir);
			if (!boost::filesystem::exists(mode_dir_path) && argc >= 1)
			{
				// Not found: infer the executable's folder from argv[0] and look under it.
				boost::filesystem::path a0(argv[0]);
				if (a0.is_absolute())
					mode_dir_path = a0.parent_path() / model_dir;
			}
		}

		if (!boost::filesystem::exists(mode_dir_path))
			return eWaifu2xError_FailedOpenModelFile;

		if (process == "cpu")
		{
			caffe::Caffe::set_mode(caffe::Caffe::CPU);
			isCuda = false;
		}
		else
		{
			caffe::Caffe::set_mode(caffe::Caffe::GPU);
			isCuda = true;
		}

		if (mode == "noise" || mode == "noise_scale" || mode == "auto_scale")
		{
			const std::string model_path = (mode_dir_path / "srcnn.prototxt").string();
			const std::string param_path = (mode_dir_path / ("noise" + std::to_string(noise_level) + "_model.json")).string();

			const Waifu2x::eWaifu2xError ret = ConstractNet(net_noise, model_path, param_path, process);
			if (ret != eWaifu2xError_OK)
				return ret;
		}

		if (mode == "scale" || mode == "noise_scale" || mode == "auto_scale")
		{
			const std::string model_path = (mode_dir_path / "srcnn.prototxt").string();
			const std::string param_path = (mode_dir_path / "scale2.0x_model.json").string();

			const Waifu2x::eWaifu2xError ret = ConstractNet(net_scale, model_path, param_path, process);
			if (ret != eWaifu2xError_OK)
				return ret;
		}

		const int input_block_plane_size = input_block_size * input_block_size * input_plane;
		const int output_block_plane_size = output_block_size * output_block_size * input_plane;

		// Element (float) counts of the blocks: batch_size planes each. Computed in size_t
		// so the products and the fill loop below do not mix signed and unsigned types.
		const size_t input_block_count = static_cast<size_t>(input_block_plane_size) * batch_size;
		const size_t output_block_count = static_cast<size_t>(output_block_plane_size) * batch_size;

		// NOTE(review): if one of the allocations below fails, the blocks already allocated
		// are not released here and is_inited stays false, so a retry allocates them again.
		if (isCuda)
		{
			CUDA_CHECK_WAIFU2X(cudaHostAlloc(&input_block, sizeof(float) * input_block_count, cudaHostAllocWriteCombined));
			CUDA_CHECK_WAIFU2X(cudaHostAlloc(&dummy_data, sizeof(float) * input_block_count, cudaHostAllocWriteCombined));
			CUDA_CHECK_WAIFU2X(cudaHostAlloc(&output_block, sizeof(float) * output_block_count, cudaHostAllocDefault));
		}
		else
		{
			input_block = new float[input_block_count];
			dummy_data = new float[input_block_count];
			output_block = new float[output_block_count];
		}

		// dummy_data starts out zero-filled.
		for (size_t i = 0; i < input_block_count; i++)
			dummy_data[i] = 0.0f;

		is_inited = true;
	}
	catch (...)
	{
		return eWaifu2xError_InvalidParameter;
	}

	return eWaifu2xError_OK;
}