// Wrap this generator's single output as an extern stage.
//
// Builds the generator's pipeline, requires that it has exactly one output
// Func, and returns a new Func defined via define_extern() that calls
// `function_name` with `function_arguments`. The new Func takes the output
// types and dimensionality of that single output. When `function_name` is
// empty, generator_name() is used instead, and that must be non-empty.
Func GeneratorBase::call_extern(std::initializer_list<ExternFuncArgument> function_arguments,
                                std::string function_name) {
    Pipeline built = build_pipeline();
    user_assert(built.outputs().size() == 1)
        << "Can only call_extern Pipelines with a single output Func\n";

    // Only the output's types and dimensions are needed to describe the
    // extern stage.
    Func sole_output = built.outputs()[0];
    Func extern_stage;

    const bool use_generator_name = function_name.empty();
    if (use_generator_name) {
        function_name = generator_name();
        user_assert(!function_name.empty())
            << "call_extern: generator_name is empty\n";
    }

    extern_stage.define_extern(function_name,
                               function_arguments,
                               sole_output.output_types(),
                               sole_output.dimensions());
    return extern_stage;
}
// Merge sort contiguous chunks of size s in a 1d func. Func merge_sort(Func input, int total_size) { std::vector<Func> stages; Func result; const int parallel_work_size = 512; Func parallel_stage("parallel_stage"); // First gather the input into a 2D array of width four where each row is sorted { assert(input.dimensions() == 1); // Use a small sorting network Expr a0 = input(4*y); Expr a1 = input(4*y+1); Expr a2 = input(4*y+2); Expr a3 = input(4*y+3); Expr b0 = min(a0, a1); Expr b1 = max(a0, a1); Expr b2 = min(a2, a3); Expr b3 = max(a2, a3); a0 = min(b0, b2); a1 = max(b0, b2); a2 = min(b1, b3); a3 = max(b1, b3); b0 = a0; b1 = min(a1, a2); b2 = max(a1, a2); b3 = a3; result(x, y) = select(x == 0, b0, select(x == 1, b1, select(x == 2, b2, b3))); result.compute_at(parallel_stage, y).bound(x, 0, 4).unroll(x); stages.push_back(result); } // Now build up to the total size, merging each pair of rows for (int chunk_size = 4; chunk_size < total_size; chunk_size *= 2) { // "result" contains the sorted halves assert(result.dimensions() == 2); // Merge pairs of rows from the partial result Func merge_rows("merge_rows"); RDom r(0, chunk_size*2); // The first dimension of merge_rows is within the chunk, and the // second dimension is the chunk index. Keeps track of two // pointers we're merging from and an output value. 
merge_rows(x, y) = Tuple(0, 0, cast(input.value().type(), 0)); Expr candidate_a = merge_rows(r-1, y)[0]; Expr candidate_b = merge_rows(r-1, y)[1]; Expr valid_a = candidate_a < chunk_size; Expr valid_b = candidate_b < chunk_size; Expr value_a = result(clamp(candidate_a, 0, chunk_size-1), 2*y); Expr value_b = result(clamp(candidate_b, 0, chunk_size-1), 2*y+1); merge_rows(r, y) = tuple_select(valid_a && ((value_a < value_b) || !valid_b), Tuple(candidate_a + 1, candidate_b, value_a), Tuple(candidate_a, candidate_b + 1, value_b)); if (chunk_size <= parallel_work_size) { merge_rows.compute_at(parallel_stage, y); } else { merge_rows.compute_root(); } if (chunk_size == parallel_work_size) { parallel_stage(x, y) = merge_rows(x, y)[2]; parallel_stage.compute_root().parallel(y); result = parallel_stage; } else { result = lambda(x, y, merge_rows(x, y)[2]); } } // Convert back to 1D return lambda(x, result(x, 0)); }
int main(int argc, char **argv) { Var x("x"), y("y"), z("z"), w("w"); ImageParam im1(Int(32), 3); assert(im1.dimensions() == 3); // im1 is a 3d imageparam Image<int> im1_val = lambda(x, y, z, x*y*z).realize(10, 10, 10); im1.set(im1_val); Image<int> im2 = lambda(x, y, x+y).realize(10, 10); assert(im2.dimensions() == 2); assert(im2(4, 6) == 10); // im2 is a 2d image Func f; f(x, _) = im1(_) + im2(x, _) + im2(_); // Equivalent to // f(x, i, j, k) = im1(i, j, k) + im2(x, i) + im2(i, j); // f(x, i, j, k) = i*j*k + x+i + i+j; Image<int> result1 = f.realize(2, 2, 2, 2); for (int k = 0; k < 2; k++) { for (int j = 0; j < 2; j++) { for (int i = 0; i < 2; i++) { for (int x = 0; x < 2; x++) { int correct = i*j*k + x+i + i+j; if (result1(x, i, j, k) != correct) { printf("result1(%d, %d, %d, %d) = %d instead of %d\n", x, i, j, k, result1(x, i, j, k), correct); return -1; } } } } } // f is a 4d function (thanks to the first arg having 3 implicit arguments assert(f.dimensions() == 4); Func g; g(_) = f(2, 2, _) + im2(Expr(1), _); f.compute_root(); // Equivalent to // g(i, j) = f(2, 2, i, j) + im2(1, i); // g(i, j) = 2*i*j + 2+2 + 2+i + 1+i assert(g.dimensions() == 2); Image<int> result2 = g.realize(10, 10); for (int j = 0; j < 10; j++) { for (int i = 0; i < 10; i++) { int correct = 2*i*j + 2+2 + 2+i + 1+i; if (result2(i, j) != correct) { printf("result2(%d, %d) = %d instead of %d\n", i, j, result2(i, j), correct); return -1; } } } // An image which ensures any transposition of unequal coordinates changes the value Image<int> im3 = lambda(x, y, z, w, (x<<24)|(y<<16)|(z<<8)|w).realize(10, 10, 10, 10); Func transpose_last_two; transpose_last_two(_, x, y) = im3(_, y, x); // Equivalent to transpose_last_two(_0, _1, x, y) = im3(_0, _1, x, y) Image<int> transposed = transpose_last_two.realize(10, 10, 10, 10); for (int i = 0; i < 10; i++) { for (int j = 0; j < 10; j++) { for (int k = 0; k < 10; k++) { for (int l = 0; l < 10; l++) { int correct = (i<<24)|(j<<16)|(l<<8)|k; if (transposed(i, 
j, k, l) != correct) { printf("transposed(%d, %d, %d, %d) = %d instead of %d\n", i, j, k, l, transposed(i, j, k, l), correct); return -1; } } } } } Func hairy_transpose; hairy_transpose(_, x, y) = im3(y, _, x) + im3(y, x, _); // Equivalent to hairy_transpose(_0, _1, x, y) = im3(y, _0, _1, x) + // im3(y, x, _0, _1) Image<int> hairy_transposed = hairy_transpose.realize(10, 10, 10, 10); for (int i = 0; i < 10; i++) { for (int j = 0; j < 10; j++) { for (int k = 0; k < 10; k++) { for (int l = 0; l < 10; l++) { int correct1 = (l<<24)|(i<<16)|(j<<8)|k; int correct2 = (l<<24)|(k<<16)|(i<<8)|j; int correct = correct1 + correct2; if (hairy_transposed(i, j, k, l) != correct) { printf("hairy_transposed(%d, %d, %d, %d) = %d instead of %d\n", i, j, k, l, hairy_transposed(i, j, k, l), correct); return -1; } } } } } Func hairy_transpose2; hairy_transpose2(_, x) = im3(_, x) + im3(x, x, _); // Equivalent to hairy_transpose2(_0, _1, _2, x) = im3(_0, _1, _2, x) + // im3(x, x, _0, _1) Image<int> hairy_transposed2 = hairy_transpose2.realize(10, 10, 10, 10); for (int i = 0; i < 10; i++) { for (int j = 0; j < 10; j++) { for (int k = 0; k < 10; k++) { for (int l = 0; l < 10; l++) { int correct1 = (i<<24)|(j<<16)|(k<<8)|l; int correct2 = (l<<24)|(l<<16)|(i<<8)|j; int correct = correct1 + correct2; if (hairy_transposed2(i, j, k, l) != correct) { printf("hairy_transposed2(%d, %d, %d, %d) = %d instead of %d\n", i, j, k, l, hairy_transposed2(i, j, k, l), correct); return -1; } } } } } printf("Success!\n"); return 0; }