Example No. 1
0
int main() {
    // GLSL codegen is what is under test, so force the OpenGL feature onto the JIT target.
    const Target target = get_jit_target_from_environment().with_feature(Target::OpenGL);

    // Input image: every channel of pixel (x, y) holds x + y.
    const int width = 10, height = 10, channels = 3;
    Buffer<float> input(width, height, channels);
    input.fill([](int px, int py, int pc) {
        return px + py;
    });

    // g(x, y, c) is the sum of the input's three channels at (x, y), for every c.
    Var x, y, c;
    RDom r(0, 3, "r");
    Func g;
    g(x, y, c) = sum(input(x, y, r));

    // Run g as a single GLSL pass over (x, y, c), with c restricted to min 0, extent 3.
    g.bound(c, 0, 3).glsl(x, y, c);

    Buffer<float> result = g.realize(width, height, channels, target);
    result.copy_to_host();

    // Each output value must equal 3 * (x + y).
    const bool matches = Testing::check_result<float>(result, 1e-6, [](int px, int py, int pc) {
        return 3.0f * (px + py);
    });
    if (!matches) {
        return 1;
    }

    printf("Success!\n");
    return 0;
}
Example No. 2
0
int main() {
    // Refuse to run unless the environment's JIT target already enables OpenGL.
    const Target &target = get_jit_target_from_environment();
    if (!(target.features & Target::OpenGL))  {
        fprintf(stderr,"ERROR: This test must be run with an OpenGL target, e.g. by setting HL_JIT_TARGET=host-opengl.\n");
        return 1;
    }

    Func f;
    Var x, y, c;

    // Channel 0 encodes the pixel position as 10*x + y; channels 1 and 2 are constants.
    f(x, y, c) = cast<uint8_t>(select(c == 0, 10*x + y,
                                      c == 1, 127, 12));
    f.bound(c, 0, 3).glsl(x, y, c);

    Image<uint8_t> out(10, 10, 3);
    f.realize(out);
    out.copy_to_host();

    // Scan rows top to bottom and report the first pixel that disagrees.
    for (int row = 0; row < out.height(); row++) {
        for (int col = 0; col < out.width(); col++) {
            const bool pixel_ok = out(col, row, 0) == 10*col + row &&
                                  out(col, row, 1) == 127 &&
                                  out(col, row, 2) == 12;
            if (pixel_ok) {
                continue;
            }
            fprintf(stderr, "Incorrect pixel (%d, %d, %d) at x=%d y=%d.\n",
                    out(col, row, 0), out(col, row, 1), out(col, row, 2),
                    col, row);
            return 1;
        }
    }

    printf("Success!\n");
    return 0;
}
Example No. 3
0
int main() {

    // Refuse to run unless the environment's JIT target already enables OpenGL.
    const Target &target = get_jit_target_from_environment();
    if (!target.has_feature(Target::OpenGL))  {
        fprintf(stderr,"ERROR: This test must be run with an OpenGL target, e.g. by setting HL_JIT_TARGET=host-opengl.\n");
        return 1;
    }

    // Many of these values are far above 1.0f on purpose. If OpenGL stored the
    // texture in a non-float format, the copy could clamp them, and the comparison
    // below would catch it.
    Image<float> input(255, 255, 3);
    for (int row = 0; row < input.height(); row++) {
        for (int col = 0; col < input.width(); col++) {
            for (int ch = 0; ch < 3; ch++) {
                input(col, row, ch) = (10 * col + row + ch);
            }
        }
    }

    // g is an identity copy of the input, executed as a GLSL kernel.
    Var x, y, c;
    Func g;
    g(x, y, c) = input(x, y, c);
    g.bound(c, 0, 3);
    g.glsl(x, y, c);

    Image<float> out(255, 255, 3);
    g.realize(out);
    out.copy_to_host();

    // The round trip must be exact for all three channels.
    for (int row = 0; row < out.height(); row++) {
        for (int col = 0; col < out.width(); col++) {
            const bool same = out(col, row, 0) == input(col, row, 0) &&
                              out(col, row, 1) == input(col, row, 1) &&
                              out(col, row, 2) == input(col, row, 2);
            if (same) {
                continue;
            }
            fprintf(stderr, "Incorrect pixel (%g,%g,%g) != (%g,%g,%g) at x=%d y=%d.\n",
                    out(col, row, 0), out(col, row, 1), out(col, row, 2),
                    input(col, row, 0), input(col, row, 1), input(col, row, 2),
                    col, row);
            return 1;
        }
    }

    printf("Success!\n");
    return 0;
}
Example No. 4
0
int main() {
    // GLSL codegen is what is under test, so force the OpenGL feature onto the JIT target.
    const Target target = get_jit_target_from_environment().with_feature(Target::OpenGL);

    // Input image: every channel of pixel (x, y) holds x + y.
    const int width = 10, height = 10, channels = 4;
    Buffer<float> input(width, height, channels);
    for (int ch = 0; ch < input.channels(); ch++) {
        for (int row = 0; row < input.height(); row++) {
            for (int col = 0; col < input.width(); col++) {
                input(col, row, ch) = float(col + row);
            }
        }
    }

    // Sum `taps` horizontally adjacent samples (x clamped to the image), divide by the
    // sum of the reduction indices 0 + 1 + ... + (taps - 1), then scale by 255.
    const int taps = 5;
    Var x, y, c;
    RDom r(0, taps, "r");
    Func g;
    Expr coordx = clamp(x + r, 0, input.width() - 1);
    g(x, y, c) = cast<float>( sum(input(coordx, y, c)) / sum(r) * 255.0f );

    // A single GLSL pass over (x, y, c), with c restricted to min 0, extent `channels`.
    g.bound(c, 0, channels).glsl(x, y, c);

    Buffer<float> result = g.realize(width, height, channels, target);
    result.copy_to_host();

    // Reference: sum(r) over [0, taps) is taps*(taps-1)/2, i.e. 10 for 5 taps.
    const float divisor = taps * (taps - 1) / 2.0f;
    for (int ch = 0; ch < result.channels(); ch++) {
        for (int row = 0; row < result.height(); row++) {
            for (int col = 0; col < result.width(); col++) {
                float acc = 0.0f;
                for (int k = 0; k < taps; k++) {
                    acc += input(std::min(col + k, input.width() - 1), row, ch);
                }
                const float expected = acc / divisor * 255.0f;
                if (fabs(result(col, row, ch) - expected) > 1e-3) {
                    fprintf(stderr, "result(%d, %d, %d) = %f instead of %f\n",
                            col, row, ch, result(col, row, ch), expected);
                    return 1;
                }
            }
        }
    }

    printf("Success!\n");
    return 0;
}
int main(int argc, char **argv) {
    Var x, y;

    // f combines three extern calls: one with constant arguments, one depending only
    // on y, and one depending on x.
    Func f;
    f(x, y) = my_func(0, Expr(0)) + my_func(1, y) + my_func(2, x);

    // Give both loops a known nonzero extent: LLVM will not hoist loop-invariant work
    // out of a loop that might run zero times, since that could be wasted work.
    f.bound(x, 0, 32).bound(y, 0, 32);

    Image<int> im = f.realize(32, 32);

    // Verify the values first.
    for (int i = 0; i < 32; i++) {
        for (int j = 0; j < 32; j++) {
            const int expected = i + j;
            if (im(i, j) == expected) {
                continue;
            }
            printf("im[%d, %d] = %d instead of %d\n", i, j, im(i, j), expected);
            return -1;
        }
    }

    // Then verify how often each extern was invoked.
    const bool counters_ok = call_counter[0] == 1 &&
                             call_counter[1] == 32 &&
                             call_counter[2] == 32*32;
    if (!counters_ok) {
        printf("Call counters were %d %d %d instead of %d %d %d\n",
               call_counter[0], call_counter[1], call_counter[2],
               1, 32, 32*32);
        return -1;
    }

    // Calls are not lifted out of parallel loops: each iteration of the parallel y
    // loop invokes the extern independently. A serial do_par_for is installed so the
    // shared counter is not raced on.
    Func g;
    g(x, y) = my_func(3, Expr(0));
    g.parallel(y);
    g.set_custom_do_par_for(&not_really_parallel_for);
    g.realize(32, 32);

    if (call_counter[3] == 32) {
        printf("Success!\n");
        return 0;
    }
    printf("Call counter for parallel call was %d instead of %d\n",
           call_counter[3], 32);
    return -1;
}
Example No. 6
0
int main() {
    // GLSL codegen is what is under test, so force the OpenGL feature onto the JIT target.
    const Target target = get_jit_target_from_environment().with_feature(Target::OpenGL);

    // Input image: every channel of pixel (x, y) holds x + y.
    const int width = 10, height = 10, channels = 3;
    Image<float> input(width, height, channels);
    for (int ch = 0; ch < input.channels(); ch++) {
        for (int row = 0; row < input.height(); row++) {
            for (int col = 0; col < input.width(); col++) {
                input(col, row, ch) = col + row;
            }
        }
    }

    // g(x, y, c) is the sum of the input's three channels at (x, y), for every c.
    Var x, y, c;
    RDom r(0, 3, "r");
    Func g;
    g(x, y, c) = sum(input(x, y, r));

    // A single GLSL pass over (x, y, c), with c restricted to min 0, extent 3.
    g.bound(c, 0, 3).glsl(x, y, c);

    Image<float> result = g.realize(width, height, channels, target);
    result.copy_to_host();

    // Each output value must equal 3 * (x + y).
    for (int ch = 0; ch < result.channels(); ch++) {
        for (int row = 0; row < result.height(); row++) {
            for (int col = 0; col < result.width(); col++) {
                const float expected = 3.0f * (col + row);
                const float actual = result(col, row, ch);
                if (fabs(actual - expected) > 1e-6) {
                    fprintf(stderr, "result(%d, %d, %d) = %f instead of %f\n",
                            col, row, ch, actual, expected);
                    return 1;
                }
            }
        }
    }

    printf("Success!\n");
    return 0;
}
Example No. 7
0
int main() {

    // Refuse to run unless the environment's JIT target already enables OpenGL.
    const Target &target = get_jit_target_from_environment();
    if (!(target.features & Target::OpenGL))  {
        fprintf(stderr,"ERROR: This test must be run with an OpenGL target, e.g. by setting HL_JIT_TARGET=host-opengl.\n");
        return 1;
    }

    // uint8 input; 10*x + y + c intentionally exceeds 255 for larger x and is
    // truncated on assignment.
    Image<uint8_t> input(255, 10, 3);
    for (int row = 0; row < input.height(); row++) {
        for (int col = 0; col < input.width(); col++) {
            for (int ch = 0; ch < 3; ch++) {
                input(col, row, ch) = 10*col + row + ch;
            }
        }
    }

    // g is an identity copy of the input, executed as a GLSL kernel.
    Var x, y, c;
    Func g;
    g(x, y, c) = input(x, y, c);
    g.bound(c, 0, 3);
    g.glsl(x, y, c);

    Image<uint8_t> out(255, 10, 3);
    g.realize(out);
    out.copy_to_host();

    // The round trip must be exact for all three channels.
    for (int row = 0; row < out.height(); row++) {
        for (int col = 0; col < out.width(); col++) {
            const bool same = out(col, row, 0) == input(col, row, 0) &&
                              out(col, row, 1) == input(col, row, 1) &&
                              out(col, row, 2) == input(col, row, 2);
            if (same) {
                continue;
            }
            fprintf(stderr, "Incorrect pixel (%d,%d,%d) != (%d,%d,%d) at x=%d y=%d.\n",
                    out(col, row, 0), out(col, row, 1), out(col, row, 2),
                    input(col, row, 0), input(col, row, 1), input(col, row, 2),
                    col, row);
            return 1;
        }
    }

    printf("Success!\n");
    return 0;
}
Example No. 8
0
// Builds the camera pipeline over `raw` and applies one of three schedules.
// Stages, in order: hot_pixel_suppression -> deinterleave -> demosaic -> color_correct
// (given matrix_3200, matrix_7000 and color_temp) -> apply_curve (given result_type,
// gamma and contrast).
// NOTE(review): this function assigns to and returns `processed`, and reads `tx`, `ty`,
// `c`, `x`, `y` and `schedule`. None of these are declared here, so they are presumably
// file-scope globals; confirm against the rest of the file.
Func process(Func raw, Type result_type,
             ImageParam matrix_3200, ImageParam matrix_7000, Param<float> color_temp,
             Param<float> gamma, Param<float> contrast) {

    // Inner tile coordinates created by processed.tile() below.
    Var xi, yi;

    Func denoised = hot_pixel_suppression(raw);
    Func deinterleaved = deinterleave(denoised);
    Func demosaiced = demosaic(deinterleaved);
    Func corrected = color_correct(demosaiced, matrix_3200, matrix_7000, color_temp);
    Func curved = apply_curve(corrected, result_type, gamma, contrast);

    processed(tx, ty, c) = curved(tx, ty, c);

    // Schedule
    processed.bound(c, 0, 3); // bound color loop to min 0, extent 3 (channels 0..2)
    if (schedule == 0) {
        // Compute in chunks over tiles, vectorized by 8
        denoised.compute_at(processed, tx).vectorize(x, 8);
        deinterleaved.compute_at(processed, tx).vectorize(x, 8).reorder(c, x, y).unroll(c);
        corrected.compute_at(processed, tx).vectorize(x, 4).reorder(c, x, y).unroll(c);
        // 32x32 tiles, with the color loop inside each tile; tile rows run in parallel.
        processed.tile(tx, ty, xi, yi, 32, 32).reorder(xi, yi, c, tx, ty);
        processed.parallel(ty);
    } else if (schedule == 1) {
        // Same as above, but don't vectorize (sse is bad at interleaved 16-bit ops)
        denoised.compute_at(processed, tx);
        deinterleaved.compute_at(processed, tx);
        corrected.compute_at(processed, tx);
        // Larger 128x128 tiles than schedule 0.
        processed.tile(tx, ty, xi, yi, 128, 128).reorder(xi, yi, c, tx, ty);
        processed.parallel(ty);
    } else {
        // Any other value: compute every stage at root, with no tiling or parallelism.
        denoised.compute_root();
        deinterleaved.compute_root();
        corrected.compute_root();
        processed.compute_root();
    }

    return processed;
}
Example No. 9
0
int main(int argc, char **argv) {
    Param<float> time;

    const float pi = 3.1415926536;

    Var x, y, c;
    Func result;

    Expr kx, ky;
    Expr xx, yy;
    kx = x / 150.0f;
    ky = y / 150.0f;

    xx = kx + sin(time/3.0f);
    yy = ky + sin(time/2.0f);

    Expr angle;
    angle = 2 * pi * sin(time/20.0f);
    kx = kx * cos(angle) - ky * sin(angle);
    ky = kx * sin(angle) + ky * cos(angle);

    Expr v = 0.0f;
    v += sin((ky + time) / 2.0f);
    v += sin((kx + ky + time) / 2.0f);
    v += sin(sqrt(xx * xx + yy * yy + 1.0f) + time);

    result(x, y, c) = cast<uint8_t>(
        select(c == 0, 32,
               select(c == 1, cos(pi * v),
                      sin(pi * v)) * 80 + (255 - 80)));

    result.output_buffer().set_stride(0, 4);
    result.bound(c, 0, 4);
    result.glsl(x, y, c);

    result.compile_to_file("halide_gl_filter", {time}, "halide_gl_filter");

    return 0;
}
Example No. 10
0
int main() {
    Func f;
    Var x, y, c;

    Expr e = 0;

    // Max with integer arguments requires Halide to introduce an implicit
    // cast to float.
    e = select(x == 0, max(y, 5), e);
    // But using float directly should also work.
    e = select(x == 1, cast<int>(min(cast<float>(y), 5.0f)), e);

    e = select(x == 2, y % 3, e);
    e = select(x == 3, cast<int>(127*sin(y) + 128), e);
    e = select(x == 4, y / 2, e);

    f(x, y, c) = cast<uint8_t>(e);

    Image<uint8_t> out(10, 10, 1);
    f.bound(c, 0, 1);
    f.glsl(x, y, c);
    f.realize(out);

    out.copy_to_host();

    for (int y = 0; y < out.height(); y++) {
        CHECK_EQ(out(0, y, 0), std::max(y, 5));
        CHECK_EQ(out(1, y, 0), std::min(y, 5));
        CHECK_EQ(out(2, y, 0), y % 3);
        CHECK_EQ(out(3, y, 0), static_cast<int>(127*std::sin(y) + 128));
        CHECK_EQ(out(4, y, 0), y / 2);
    }

    printf("Success!\n");
    return 0;
}
Example No. 11
0
// Benchmarks the Halide bitonic and merge sorts against std::sort on N random ints,
// then checks that both Halide results match the std::sort reference.
// Returns -1 on the first mismatch, 0 otherwise.
int main(int argc, char **argv) {

    const int N = 1 << 10;

    // Random keys masked to 20 bits.
    Image<int> data(N);
    for (int i = 0; i < N; i++) {
        data(i) = rand() & 0xfffff;
    }
    Func input = lambda(x, data(x));

    printf("Bitonic sort...\n");
    Func f = bitonic_sort(input, N);
    f.bound(x, 0, N);
    f.compile_jit();
    printf("Running...\n");
    Image<int> bitonic_sorted(N);
    // The first run is excluded from the timed region.
    f.realize(bitonic_sorted);
    double t1 = current_time();
    for (int i = 0; i < 10; i++) {
        f.realize(bitonic_sorted);
    }
    double t2 = current_time();

    printf("Merge sort...\n");
    f = merge_sort(input, N);
    f.bound(x, 0, N);
    f.compile_jit();
    printf("Running...\n");
    Image<int> merge_sorted(N);
    f.realize(merge_sorted);
    double t3 = current_time();
    for (int i = 0; i < 10; i++) {
        f.realize(merge_sorted);
    }
    double t4 = current_time();

    // Reference ordering: a copy of the data sorted with std::sort.
    Image<int> correct(N);
    for (int i = 0; i < N; i++) {
        correct(i) = data(i);
    }
    printf("std::sort...\n");
    double t5 = current_time();
    // Form the end pointer by pointer arithmetic: evaluating correct(N) would access
    // an element one past the end of the image.
    std::sort(&correct(0), &correct(0) + N);
    double t6 = current_time();

    // Halide timings are averaged over the 10 timed runs.
    printf("Times:\n"
           "bitonic sort: %f \n"
           "merge sort: %f \n"
           "std::sort %f\n",
           (t2-t1)/10, (t4-t3)/10, t6-t5);

    if (N <= 100) {
        for (int i = 0; i < N; i++) {
            printf("%8d %8d %8d\n",
                   correct(i), bitonic_sorted(i), merge_sorted(i));
        }
    }

    for (int i = 0; i < N; i++) {
        if (bitonic_sorted(i) != correct(i)) {
            printf("bitonic sort failed: %d -> %d instead of %d\n", i, bitonic_sorted(i), correct(i));
            return -1;
        }
        if (merge_sorted(i) != correct(i)) {
            printf("merge sort failed: %d -> %d instead of %d\n", i, merge_sorted(i), correct(i));
            return -1;
        }
    }

    return 0;
}
Example No. 12
0
// Merge sort the first total_size elements of a 1-D Func.
// The input is first gathered into rows of 4 elements, each sorted by a small sorting
// network. Pairs of sorted rows are then merged repeatedly (chunk size 4, 8, 16, ...)
// until a single sorted row remains, which is returned as a 1-D Func.
// NOTE(review): total_size appears to be assumed a power of two >= 4; confirm with callers.
Func merge_sort(Func input, int total_size) {
    Func result;

    // Merge stages with chunk size <= parallel_work_size are computed per row inside
    // parallel_stage, which is parallelized over rows. Larger stages are compute_root.
    int parallel_work_size = 512;

    // Clamp the threshold to the largest chunk size the merge loop below actually
    // reaches. Otherwise, for small total_size (< 1024), chunk_size never equals the
    // threshold: parallel_stage would never be defined, yet the merge stages would be
    // scheduled compute_at(parallel_stage, y).
    int largest_chunk = 4;
    while (largest_chunk * 2 < total_size) {
        largest_chunk *= 2;
    }
    if (largest_chunk < parallel_work_size) {
        parallel_work_size = largest_chunk;
    }

    // First gather the input into a 2D array of width four where each row is sorted
    {
        assert(input.dimensions() == 1);
        // Use a small sorting network
        Expr a0 = input(4*y);
        Expr a1 = input(4*y+1);
        Expr a2 = input(4*y+2);
        Expr a3 = input(4*y+3);

        // Sort each pair.
        Expr b0 = min(a0, a1);
        Expr b1 = max(a0, a1);
        Expr b2 = min(a2, a3);
        Expr b3 = max(a2, a3);

        // Compare across the halves, which splits the smaller two from the larger two.
        a0 = min(b0, b3);
        a1 = min(b1, b2);
        a2 = max(b1, b2);
        a3 = max(b0, b3);

        // Finish by sorting within each half.
        b0 = min(a0, a1);
        b1 = max(a0, a1);
        b2 = min(a2, a3);
        b3 = max(a2, a3);

        result(x, y) = select(x == 0, b0,
                              select(x == 1, b1,
                                     select(x == 2, b2, b3)));

        result.bound(x, 0, 4).unroll(x);
    }

    Func parallel_stage("parallel_stage");

    // Now build up to the total size, merging each pair of rows
    for (int chunk_size = 4; chunk_size < total_size; chunk_size *= 2) {
        // "result" contains the sorted halves
        assert(result.dimensions() == 2);

        // Merge pairs of rows from the partial result
        Func merge_rows("merge_rows");
        RDom r(0, chunk_size*2);

        // The first dimension of merge_rows is within the chunk, and the
        // second dimension is the chunk index.  Keeps track of two
        // pointers we're merging from and an output value.
        merge_rows(x, y) = Tuple(0, 0, cast(input.value().type(), 0));

        // Take from row 2*y when its head is smaller, or when row 2*y+1 is exhausted.
        Expr candidate_a = merge_rows(r-1, y)[0];
        Expr candidate_b = merge_rows(r-1, y)[1];
        Expr valid_a = candidate_a < chunk_size;
        Expr valid_b = candidate_b < chunk_size;
        Expr value_a = result(clamp(candidate_a, 0, chunk_size-1), 2*y);
        Expr value_b = result(clamp(candidate_b, 0, chunk_size-1), 2*y+1);
        merge_rows(r, y) = tuple_select(valid_a && ((value_a < value_b) || !valid_b),
                                        Tuple(candidate_a + 1, candidate_b, value_a),
                                        Tuple(candidate_a, candidate_b + 1, value_b));


        if (chunk_size <= parallel_work_size) {
            merge_rows.compute_at(parallel_stage, y);
        } else {
            merge_rows.compute_root();
        }

        if (chunk_size == parallel_work_size) {
            parallel_stage(x, y) = merge_rows(x, y)[2];
            parallel_stage.compute_root().parallel(y);
            result = parallel_stage;
        } else {
            result = lambda(x, y, merge_rows(x, y)[2]);
        }
    }

    // Convert back to 1D
    return lambda(x, result(x, 0));
}
Example No. 13
0
// Benchmarks the Halide bitonic and merge sorts against std::sort on N random ints,
// then checks that both Halide results match the std::sort reference.
// Returns -1 on the first mismatch, 0 otherwise.
int main(int argc, char **argv) {

    const int N = 1 << 10;

    // Random keys masked to 20 bits.
    Buffer<int> data(N);
    for (int i = 0; i < N; i++) {
        data(i) = rand() & 0xfffff;
    }
    Func input = lambda(x, data(x));

    printf("Bitonic sort...\n");
    Func f = bitonic_sort(input, N);
    f.bound(x, 0, N);
    f.compile_jit();
    printf("Running...\n");
    Buffer<int> bitonic_sorted(N);
    // The first run happens before benchmark() is invoked.
    f.realize(bitonic_sorted);
    double t_bitonic = benchmark(1, 10, [&]() {
        f.realize(bitonic_sorted);
    });

    printf("Merge sort...\n");
    f = merge_sort(input, N);
    f.bound(x, 0, N);
    f.compile_jit();
    printf("Running...\n");
    Buffer<int> merge_sorted(N);
    f.realize(merge_sorted);
    double t_merge = benchmark(1, 10, [&]() {
        f.realize(merge_sorted);
    });

    // Reference ordering: a copy of the data sorted with std::sort.
    Buffer<int> correct(N);
    for (int i = 0; i < N; i++) {
        correct(i) = data(i);
    }
    printf("std::sort...\n");
    double t_std = benchmark(1, 1, [&]() {
        // Form the end pointer by pointer arithmetic: evaluating correct(N) would
        // access an element one past the end of the buffer.
        std::sort(&correct(0), &correct(0) + N);
    });

    // The *1e3 scaling matches the "ms" labels; benchmark() presumably reports
    // seconds. TODO confirm.
    printf("Times:\n"
           "bitonic sort: %fms \n"
           "merge sort: %fms \n"
           "std::sort %fms\n",
           t_bitonic * 1e3, t_merge * 1e3, t_std * 1e3);

    if (N <= 100) {
        for (int i = 0; i < N; i++) {
            printf("%8d %8d %8d\n",
                   correct(i), bitonic_sorted(i), merge_sorted(i));
        }
    }

    for (int i = 0; i < N; i++) {
        if (bitonic_sorted(i) != correct(i)) {
            printf("bitonic sort failed: %d -> %d instead of %d\n", i, bitonic_sorted(i), correct(i));
            return -1;
        }
        if (merge_sorted(i) != correct(i)) {
            printf("merge sort failed: %d -> %d instead of %d\n", i, merge_sorted(i), correct(i));
            return -1;
        }
    }

    return 0;
}
// Tests that selects are simplified away (uses_branches() reports no branches) and that
// results stay correct in a range of situations. Each braced scope is an independent
// case. Returns -1 on the first hard failure, 0 otherwise.
int main(int argc, char **argv) {
    Var x, y, c;

    {
        Func f;
        f(x, y, c) = 1 + select(c < 1, x,
                                c == 1, y,
                                x + y);
        f.reorder(c, x, y).vectorize(x, 4);

        // The select in c should go away.
        if (uses_branches(f)) {
            printf("There weren't supposed to be branches!\n");
            return -1;
        }

        Image<int> f_result = f.realize(10, 10, 3);
        for (int y = 0; y < f_result.height(); y++) {
            for (int x = 0; x < f_result.width(); x++) {
                for (int c = 0; c < f_result.channels(); c++) {
                    int correct = 1 + (c == 0 ? x : (c == 1 ? y : x + y));
                    if (f_result(x, y, c) != correct) {
                        printf("f_result(%d, %d, %d) = %d instead of %d\n",
                               x, y, c, f_result(x, y, c), correct);
                        return -1;
                    }
                }
            }
        }
    }

    {
        Func g;
        g(x, y, c) = (select(c > 1, 2*x,
                             c == 1, x - y,
                             y)
                      + select(c < 1, x,
                               c == 1, y,
                               x + y));
        g.vectorize(x, 4);

        // Pin the output bounds so the c range is known at compile time.
        g.output_buffer()
            .set_min(0,0).set_min(1,0).set_min(2,0)
            .set_extent(0,10).set_extent(1,10).set_extent(2,3);

        // The select in c should go away.
        if (uses_branches(g)) {
            printf("There weren't supposed to be branches!\n");
            return -1;
        }

        Image<int> g_result = g.realize(10, 10, 3);
        for (int y = 0; y < g_result.height(); y++) {
            for (int x = 0; x < g_result.width(); x++) {
                for (int c = 0; c < g_result.channels(); c++) {
                    int correct = (c > 1 ? 2*x : (c == 1 ? x - y : y))
                        + (c < 1 ? x : (c == 1 ? y : x + y));
                    if (g_result(x, y, c) != correct) {
                        printf("g_result(%d, %d, %d) = %d instead of %d\n",
                               x, y, c, g_result(x, y, c), correct);
                        return -1;
                    }
                }
            }
        }
    }

    {
        // An RDom with a conditional
        Func f, sum_scan;
        f(x) = x*17 + 3;
        f.compute_root();

        RDom r(0, 100);
        sum_scan(x) = undef<int>();
        sum_scan(r) = select(r == 0, f(r), f(r) + sum_scan(max(0, r-1)));

        if (uses_branches(sum_scan)) {
            printf("There weren't supposed to be branches!\n");
            return -1;
        }

        Image<int> result = sum_scan.realize(100);

        // Running prefix sum of f.
        int correct = 0;
        for (int x = 0; x < 100; x++) {
            correct += x*17 + 3;
            if (result(x) != correct) {
                printf("sum scan result(%d) = %d instead of %d\n",
                       x, result(x), correct);
                return -1;
            }
        }
    }

    // Sliding window optimizations inject a select in a let expr. See if it gets simplified.
    {
        Func f, g;
        f(x) = x*x*17;
        g(x) = f(x-1) + f(x+1);
        f.store_root().compute_at(g, x);

        if (uses_branches(g)) {
            printf("There weren't supposed to be branches!\n");
            return -1;
        }

        Image<int> result = g.realize(100);

        for (int x = 0; x < 100; x++) {
            int correct = (x-1)*(x-1)*17 + (x+1)*(x+1)*17;
            if (result(x) != correct) {
                printf("sliding window result(%d) = %d instead of %d\n",
                       x, result(x), correct);
                return -1;
            }
        }

    }

    // Check it still works when unrolling (and doesn't change the order of evaluation).
    {
        Func f;
        f(x) = select(x > 3, x*3, x*17) + count(x);
        f.bound(x, 0, 100).unroll(x, 2);

        Image<int> result = f.realize(100);

        for (int x = 0; x < 100; x++) {
            int correct = x > 3 ? x*3 : x*17;
            correct += x;
            if (result(x) != correct) {
                printf("Unrolled result(%d) = %d instead of %d\n",
                       x, result(x), correct);
                // Known failure: reported but deliberately not fatal, so the
                // remaining cases still run.
                break;
                //return -1;
            }
        }
    }

    // Skip stages introduces conditional allocations, check that we handle them correctly.
    {
        Func f, g;
        f(x) = x*3;
        g(x, c) = select(c == 0, f(x), x*5);
        f.compute_at(g, c);

        Image<int> result = g.realize(100, 3);
        for (int c = 0; c < 3; c++) {
            for (int x = 0; x < 100; x++) {
                int correct = c == 0 ? x*3 : x*5;
                if (result(x, c) != correct) {
                    printf("conditional alloc result(%d, %d) = %d instead of %d\n",
                           x, c, result(x, c), correct);
                    // Fail like the other cases, rather than printing the mismatch
                    // and still reporting success.
                    return -1;
                }
            }
        }
    }

    // Test that we can deal with undefined values.
    {
        Func result("result");

        RDom rv(0, 50, 0, 50);

        result(x, y) = 0;
        result(rv.x, rv.y) = select(rv.y < 10, 100, undef<int>());

        // Only compilation is checked here.
        result.compile_jit();
    }

    // Check for combinatorial explosion when there are lots of selects
    {
        Func f;
        Expr e = 0;
        for (int i = 19; i >= 0; i--) {
            e = select(x <= i, i*i, e);
        }
        f(x) = e;

        Image<int> result = f.realize(100);

        for (int x = 0; x < 100; x++) {
            int correct = x < 20 ? x*x : 0;
            if (result(x) != correct) {
                printf("lots of selects result(%d) = %d instead of %d\n",
                       x, result(x), correct);
                return -1;
            }
        }
    }

    // Check recursive merging of branches does not change result.
    {
        ImageParam input(UInt(8), 4, "input");

        Expr ch = input.extent(2);
        Func f("f");
        f(x, y, c) = select(ch == 1,
                            select(c < 3, input(x, y, 0, 0), 255),
                            select(c < ch, input(x, y, min(c, ch-1), 0), 255));

        f.bound(c, 0, 4);

        int xn = 16;
        int yn = 8;
        int cn = 3;
        int wn = 2;

        Image<uint8_t> in(xn, yn, cn, wn);
        for (int x = 0; x < xn; ++x) {
            for (int y = 0; y < yn; ++y) {
                for (int c = 0; c < cn; ++c) {
                    for (int w = 0; w < wn; ++w) {
                        in(x, y, c, w) = x + y + c + w;
                    }
                }
            }
        }
        input.set(in);

        Image<uint8_t> f_result = f.realize(xn, yn, 4);
        for (int x = 0; x < f_result.width(); x++) {
            for (int y = 0; y < f_result.height(); y++) {
                for (int c = 0; c < f_result.channels(); c++) {
                    int correct = (c < cn ? x + y + c : 255);
                    if (f_result(x, y, c) != correct) {
                        printf("f_result(%d, %d, %d) = %d instead of %d\n",
                               x, y, c, f_result(x, y, c), correct);
                        return -1;
                    }
                }
            }
        }
    }

    printf("Success!\n");
    return 0;
}