/// Builds one "convolution -> max pool -> tanh(x + bias)" stage as a Halide pipeline.
///
/// @param input         Func read as input(x, y, layer, w); the third coordinate is
///                      reduced over [0, input_layers).
/// @param weights       Func read as weights(fx, fy, layer, z) with fx, fy in
///                      [0, filter_size) and layer in [0, input_layers).
/// @param bias          Func read as bias(z, 0).
/// @param filter_size   Side length of the square convolution window.
/// @param input_layers  Number of input layers summed over by the convolution.
/// @param pool_size     Side length of the square, non-overlapping max-pool window.
/// @return The scheduled output Func, indexed as (x, y, z, w).
///
/// NOTE(review): z looks like the output feature map and w the sample/batch index;
/// confirm against the callers. VECTORS is defined outside this function.
Halide::Func convolution_layer(Halide::Func input, Halide::Func weights, Halide::Func bias,
                               int filter_size, int input_layers, int pool_size) {
    // Convolution: sum of weights * input over the filter window and all input layers.
    Halide::Func convolution;
    Halide::Var x, y, z, w;
    Halide::RDom r(0, filter_size, 0, filter_size, 0, input_layers);
    convolution(x, y, z, w) = 0.0f;
    convolution(x, y, z, w) += weights(r.x, r.y, r.z, z) * input(x + r.x, y + r.y, r.z, w);

    // Max pool over pool_size x pool_size windows of the convolution output.
    // The running maximum is seeded with the smallest float32 value rather than 0.0f:
    // a 0.0f seed would clamp every window whose values are all negative to zero
    // instead of returning the window's real maximum.
    Halide::Func subsample;
    Halide::RDom s(0, pool_size, 0, pool_size);
    subsample(x, y, z, w) = Halide::Float(32).min();
    subsample(x, y, z, w) = Halide::max(
        convolution(pool_size * x + s.x, pool_size * y + s.y, z, w),
        subsample(x, y, z, w));

    // Non-linear bias: add the per-z bias and squash with tanh.
    Halide::Func biased;
    biased(x, y, z, w) = tanh(subsample(x, y, z, w) + bias(z, 0));

    // Schedule: parallelise over w, tile (x, y) into VECTORS x 2 blocks,
    // vectorise across the inner x extent and unroll the two inner rows.
    Halide::Var x_inner, x_outer, y_inner, y_outer;
    biased.parallel(w);
    biased.tile(x, y, x_outer, y_outer, x_inner, y_inner, VECTORS, 2);
    biased.vectorize(x_inner);
    biased.unroll(y_inner);

    return biased;
}
int resize_with_halide() { Halide::ImageParam input {Halide::type_of<uint8_t>(), 3}; //_/_/_/ load a source image and repeat its edges Halide::Func src_image {}; src_image = Halide::BoundaryConditions::repeat_edge(input); //_/_/_/ describe algorithm Halide::Param<float> src_rows {}; Halide::Param<float> src_cols {}; Halide::Param<float> dst_rows {}; Halide::Param<float> dst_cols {}; // const float sc = 500.0f/4999;//static_cast<float>(src_cols.get()) / dst_cols.get(); // const float sr = 350.0f/3499;//static_cast<float>(src_rows.get()) / dst_rows.get(); const auto sc = src_cols / dst_cols; const auto sr = src_rows / dst_rows; Halide::Var i {}; Halide::Var j {}; Halide::Var c {}; auto fj = j * sr; auto cj0 = Halide::cast<int>(fj); auto cj1 = cj0 + 1; auto dj = fj - cj0; auto fi = i * sc; auto ci0 = Halide::cast<int>(fi); auto ci1 = ci0 + 1; auto di = fi - ci0; const auto c0 = (1.0f - dj) * (1.0f - di); const auto c1 = (1.0f - dj) * di; const auto c2 = dj * (1.0f - di); const auto c3 = dj * di; const auto& src_pixel0 = src_image(ci0, cj0, c); const auto& src_pixel1 = src_image(ci1, cj0, c); const auto& src_pixel2 = src_image(ci0, cj1, c); const auto& src_pixel3 = src_image(ci1, cj1, c); Halide::Func resize {}; resize(i, j, c) = Halide::saturating_cast<uint8_t>(c0 * src_pixel0 + c1 * src_pixel1 + c2 * src_pixel2 + c3 * src_pixel3); //_/_/_/ describe scheduling Halide::Var i_inner, j_inner; auto x_vector_size = 64; resize.compute_root(); resize.tile(i, j, i_inner, j_inner, x_vector_size, 4).vectorize(i_inner, 16).parallel(j); //_/_/_/ save a static library const auto path = "/Users/kumada/Projects/cct_blog/halide/sample_4/sample_4/resize"; resize.compile_to_static_library( path, {input, src_rows, src_cols, dst_rows, dst_cols}, "resize"); return 1; }