/// Builds one "convolution -> pooling -> tanh(+bias)" stage as a Halide pipeline.
///
/// @param input         Source Func, read as input(x, y, layer, w).
/// @param weights       Filter Func, read as weights(fx, fy, layer, z).
/// @param bias          Bias Func, read as bias(z, 0).
/// @param filter_size   Extent of the square filter window in x and y.
/// @param input_layers  Number of input layers summed over for each output value.
/// @param pool_size     Extent of the square pooling window in x and y.
/// @return The final (scheduled) Func, indexed as (x, y, z, w).
Halide::Func convolution_layer(Halide::Func input, Halide::Func weights, Halide::Func bias,
                               int filter_size, int input_layers, int pool_size) {
    // --- Convolution -------------------------------------------------------
    Halide::Func conv;
    Halide::Var x, y, z, w;
    Halide::RDom win(0, filter_size, 0, filter_size, 0, input_layers);

    // Sum over the filter window (win.x, win.y) and every input layer (win.z).
    Halide::Expr tap = weights(win.x, win.y, win.z, z);
    Halide::Expr sample = input(x + win.x, y + win.y, win.z, w);
    conv(x, y, z, w) = 0.0f;
    conv(x, y, z, w) += tap * sample;

    // --- Max pool ----------------------------------------------------------
    Halide::Func pooled;
    Halide::RDom cell(0, pool_size, 0, pool_size);

    // Each output covers a pool_size x pool_size patch of the convolution.
    // NOTE(review): the accumulator starts at 0.0f, so the pooled value can
    // never be negative - confirm this is intended rather than a plain max.
    Halide::Expr src_x = pool_size * x + cell.x;
    Halide::Expr src_y = pool_size * y + cell.y;
    pooled(x, y, z, w) = 0.0f;
    pooled(x, y, z, w) = Halide::max(conv(src_x, src_y, z, w), pooled(x, y, z, w));

    // --- Non-linear bias ---------------------------------------------------
    Halide::Func activated;
    activated(x, y, z, w) = tanh(pooled(x, y, z, w) + bias(z, 0));

    // Schedule: parallel over w, tile x/y into VECTORS x 2 blocks, then
    // vectorize across the tile's x and unroll its y.
    Halide::Var x_inner, x_outer, y_inner, y_outer;
    activated.parallel(w)
        .tile(x, y, x_outer, y_outer, x_inner, y_inner, VECTORS, 2)
        .vectorize(x_inner)
        .unroll(y_inner);

    return activated;
}
/// Builds a fully connected layer followed by tanh as a Halide pipeline.
///
/// For every output index x the layer sums weights(i, x) * input(i, y, z)
/// over i in [0, size), adds bias(x, 0) and applies tanh.
///
/// @param input    Source Func, read as input(i, y, z).
/// @param weights  Weight Func, read as weights(i, x).
/// @param bias     Bias Func, read as bias(x, 0).
/// @param size     Number of inputs summed over per output.
/// @return The scheduled Func, indexed as (x, y, z).
Halide::Func fully_connected_layer(Halide::Func input, Halide::Func weights, Halide::Func bias,
                                   int size) {
    Halide::Func dense;
    Halide::Var x, y, z;
    Halide::RDom k(0, size);

    // Only y = 0 should be used
    dense(x, y, z) = 0.0f;

    // Weighted sum over the reduction index.
    Halide::Expr term = weights(k.x, x) * input(k.x, y, z);
    dense(x, y, z) += term;

    // Bias and activation are applied in place as a further update stage.
    dense(x, y, z) = tanh(dense(x, y, z) + bias(x, 0));

    dense.vectorize(x, VECTORS);
    return dense;
}
/// Displays a 2-D 8-bit image in this window.
///
/// The pixels are copied into a single-channel cv::Mat by a small Halide
/// pipeline that is defined on the first call and reused afterwards.
///
/// @param im  Image to show; its width() and height() size the cv::Mat.
void NamedWindow::showImage2D(Halide::Image<uint8_t> im) {
    // Pipeline objects are function-local statics so they persist across calls.
    static Halide::Func copyToMat("convertToMat2D");
    static Halide::ImageParam source(Halide::UInt(8), 2);
    static Halide::Var x, y;

    const bool needsDefinition = !copyToMat.defined();
    if (needsDefinition) {
        copyToMat(x, y) = source(x, y);
        copyToMat.vectorize(x, 4).parallel(y, 4);
    }

    source.set(im);

    const int width = im.width();
    const int height = im.height();

    // Realize straight into the Mat's pixel storage.
    cv::Mat mat(height, width, CV_8UC1, cv::Scalar(0));
    Halide::Buffer target(Halide::UInt(8), width, height, 0, 0, mat.data);
    copyToMat.realize(target);

    cv::imshow(name, mat);
}
/// Displays a 3-D float image in this window as an 8-bit, 3-channel picture.
///
/// Each sample is scaled by 255, clamped to [0, 255] and converted to uint8_t
/// by a small Halide pipeline that is defined on the first call and reused
/// afterwards. The pipeline writes directly into the cv::Mat's storage.
///
/// @param im  Image to show; width(), height() and channels() size the output
///            buffer. The cv::Mat is created as CV_8UC3 and the channel index
///            is computed as `2 - c`, so three channels are expected.
void NamedWindow::showImage3D(Halide::Image<float> im) {
    // Pipeline objects are function-local statics so they persist across calls.
    static Halide::Func convert("convertToMat3D");
    static Halide::ImageParam ip(Halide::Float(32), 3);
    static Halide::Var x, y, c;

    if (!convert.defined()) {
        // Output is indexed (c, x, y) so the channel is the innermost
        // dimension of the buffer handed to realize() below. `2 - c` reverses
        // the channel order (presumably RGB -> BGR for OpenCV - confirm).
        //
        // Clamp before the cast: a float outside [0, 255] has no well-defined
        // uint8_t conversion, so out-of-range samples must saturate instead.
        Halide::Expr scaled = Halide::clamp(ip(x, y, 2 - c) * 255, 0.0f, 255.0f);
        convert(c, x, y) = Halide::cast<uint8_t>(scaled);
        convert.vectorize(x, 4).parallel(y, 4);
    }

    ip.set(im);

    // Realize straight into the Mat's pixel storage.
    cv::Mat mat(im.height(), im.width(), CV_8UC3, cv::Scalar(0));
    convert.realize(Halide::Buffer(Halide::UInt(8), im.channels(), im.width(), im.height(), 0, mat.data));

    cv::imshow(name, mat);
}