Ejemplo n.º 1
0
// Builds this generator's pipeline and returns a new Func that is defined as
// an extern call with the same output types and dimensionality as the
// pipeline's single output.
//
// function_arguments: arguments forwarded to define_extern.
// function_name:      name passed to define_extern; when empty,
//                     generator_name() is used instead.
//
// Fails a user_assert if the pipeline does not have exactly one output, or if
// no name was supplied and generator_name() is empty as well.
Func GeneratorBase::call_extern(std::initializer_list<ExternFuncArgument> function_arguments,
                                std::string function_name){
    Pipeline pipeline = build_pipeline();
    user_assert(pipeline.outputs().size() == 1)
        << "Can only call_extern Pipelines with a single output Func\n";

    // The sole output supplies the signature of the extern definition.
    Func pipeline_output = pipeline.outputs()[0];
    Func extern_wrapper;

    // Fall back to the generator's own name when the caller gave none.
    if (function_name.empty()) {
        function_name = generator_name();
        user_assert(!function_name.empty())
            << "call_extern: generator_name is empty\n";
    }

    extern_wrapper.define_extern(function_name,
                                 function_arguments,
                                 pipeline_output.output_types(),
                                 pipeline_output.dimensions());
    return extern_wrapper;
}
Ejemplo n.º 2
0
// Sort a 1D Func with a bottom-up merge sort.
//
// The input is first gathered into rows of four elements, each row sorted by
// a small sorting network. Pairs of adjacent rows are then merged repeatedly,
// doubling the row width on every pass, while the row width is still below
// total_size.
//
// input:      a one-dimensional Func (asserted below).
// total_size: bound for the merge passes; the row width starts at 4 and
//             doubles for as long as it is less than this value.
// Returns a 1D Func that reads row 0 of the final 2D result.
//
// The Vars x and y are not declared here; they come from the enclosing scope.
//
// NOTE(review): parallel_stage only receives a definition when chunk_size
// reaches parallel_work_size (512), which needs total_size > 512. For smaller
// sizes the compute_at calls below name a Func that is never defined --
// confirm callers only pass larger sizes.
// NOTE(review): presumably total_size is expected to be a power of two -- confirm.
Func merge_sort(Func input, int total_size) {
    Func result;

    // Row width at which the merge pass becomes the root-level parallel stage.
    const int parallel_work_size = 512;

    Func parallel_stage("parallel_stage");

    // First gather the input into a 2D array of width four where each row is sorted
    {
        assert(input.dimensions() == 1);
        // Use a small sorting network
        Expr a0 = input(4*y);
        Expr a1 = input(4*y+1);
        Expr a2 = input(4*y+2);
        Expr a3 = input(4*y+3);

        // Sort each adjacent pair.
        Expr b0 = min(a0, a1);
        Expr b1 = max(a0, a1);
        Expr b2 = min(a2, a3);
        Expr b3 = max(a2, a3);

        // Compare the two minima and the two maxima: a0 is now the overall
        // minimum and a3 the overall maximum.
        a0 = min(b0, b2);
        a1 = max(b0, b2);
        a2 = min(b1, b3);
        a3 = max(b1, b3);

        // Order the two middle elements.
        b0 = a0;
        b1 = min(a1, a2);
        b2 = max(a1, a2);
        b3 = a3;

        result(x, y) = select(x == 0, b0,
                              select(x == 1, b1,
                                  select(x == 2, b2, b3)));

        // Compute each sorted row inside parallel_stage's y loop; x takes
        // exactly four values, so it is bounded and unrolled.
        result.compute_at(parallel_stage, y).bound(x, 0, 4).unroll(x);
    }

    // Now build up to the total size, merging each pair of rows
    for (int chunk_size = 4; chunk_size < total_size; chunk_size *= 2) {
        // "result" contains the sorted halves
        assert(result.dimensions() == 2);

        // Merge pairs of rows from the partial result
        Func merge_rows("merge_rows");
        RDom r(0, chunk_size*2);

        // The first dimension of merge_rows is within the chunk, and the
        // second dimension is the chunk index.  Keeps track of two
        // pointers we're merging from and an output value:
        //   [0] next index to take from row 2*y,
        //   [1] next index to take from row 2*y+1,
        //   [2] the value emitted at this position.
        merge_rows(x, y) = Tuple(0, 0, cast(input.value().type(), 0));

        // The update reads position r-1, so at r == 0 it picks up the pure
        // definition above (both indices zero).
        Expr candidate_a = merge_rows(r-1, y)[0];
        Expr candidate_b = merge_rows(r-1, y)[1];
        Expr valid_a = candidate_a < chunk_size;
        Expr valid_b = candidate_b < chunk_size;
        // Clamp so that the load stays inside the row once a side is exhausted.
        Expr value_a = result(clamp(candidate_a, 0, chunk_size-1), 2*y);
        Expr value_b = result(clamp(candidate_b, 0, chunk_size-1), 2*y+1);
        // Take from row 2*y when it still has elements and either its value is
        // smaller or row 2*y+1 is exhausted; otherwise take from row 2*y+1.
        merge_rows(r, y) = tuple_select(valid_a && ((value_a < value_b) || !valid_b),
                                        Tuple(candidate_a + 1, candidate_b, value_a),
                                        Tuple(candidate_a, candidate_b + 1, value_b));


        // Passes up to the parallel work size are computed inside
        // parallel_stage's y loop; wider passes are computed at root.
        if (chunk_size <= parallel_work_size) {
            merge_rows.compute_at(parallel_stage, y);
        } else {
            merge_rows.compute_root();
        }

        if (chunk_size == parallel_work_size) {
            // This pass becomes the named, root-level stage parallelized over rows.
            parallel_stage(x, y) = merge_rows(x, y)[2];
            parallel_stage.compute_root().parallel(y);
            result = parallel_stage;
        } else {
            // Expose only the merged values to the next pass.
            result = lambda(x, y, merge_rows(x, y)[2]);
        }
    }

    // Convert back to 1D
    return lambda(x, result(x, 0));
}
Ejemplo n.º 3
0
int main(int argc, char **argv) {

    Var x("x"), y("y"), z("z"), w("w");

    ImageParam im1(Int(32), 3);
    assert(im1.dimensions() == 3);
    // im1 is a 3d imageparam

    Image<int> im1_val = lambda(x, y, z, x*y*z).realize(10, 10, 10);
    im1.set(im1_val);

    Image<int> im2 = lambda(x, y, x+y).realize(10, 10);
    assert(im2.dimensions() == 2);
    assert(im2(4, 6) == 10);
    // im2 is a 2d image

    Func f;
    f(x, _) = im1(_) + im2(x, _) + im2(_);
    // Equivalent to
    // f(x, i, j, k) = im1(i, j, k) + im2(x, i) + im2(i, j);
    // f(x, i, j, k) = i*j*k + x+i + i+j;

    Image<int> result1 = f.realize(2, 2, 2, 2);
    for (int k = 0; k < 2; k++) {
        for (int j = 0; j < 2; j++) {
            for (int i = 0; i < 2; i++) {
                for (int x = 0; x < 2; x++) {
                    int correct = i*j*k + x+i + i+j;
                    if (result1(x, i, j, k) != correct) {
                        printf("result1(%d, %d, %d, %d) = %d instead of %d\n",
                               x, i, j, k, result1(x, i, j, k), correct);
                        return -1;
                    }
                }
            }
        }
    }

    // f is a 4d function (thanks to the first arg having 3 implicit arguments
    assert(f.dimensions() == 4);

    Func g;
    g(_) = f(2, 2, _) + im2(Expr(1), _);
    f.compute_root();
    // Equivalent to
    // g(i, j) = f(2, 2, i, j) + im2(1, i);
    // g(i, j) = 2*i*j + 2+2 + 2+i + 1+i

    assert(g.dimensions() == 2);

    Image<int> result2 = g.realize(10, 10);
    for (int j = 0; j < 10; j++) {
        for (int i = 0; i < 10; i++) {
            int correct = 2*i*j + 2+2 + 2+i + 1+i;
            if (result2(i, j) != correct) {
                printf("result2(%d, %d) = %d instead of %d\n",
                       i, j, result2(i, j), correct);
                return -1;
            }
        }
    }

    // An image which ensures any transposition of unequal coordinates changes the value
    Image<int> im3 = lambda(x, y, z, w, (x<<24)|(y<<16)|(z<<8)|w).realize(10, 10, 10, 10);

    Func transpose_last_two;
    transpose_last_two(_, x, y) = im3(_, y, x);
    // Equivalent to transpose_last_two(_0, _1, x, y) = im3(_0, _1, x, y)

    Image<int> transposed = transpose_last_two.realize(10, 10, 10, 10);

    for (int i = 0; i < 10; i++) {
        for (int j = 0; j < 10; j++) {
            for (int k = 0; k < 10; k++) {
                 for (int l = 0; l < 10; l++) {
                   int correct = (i<<24)|(j<<16)|(l<<8)|k;
                   if (transposed(i, j, k, l) != correct) {
                         printf("transposed(%d, %d, %d, %d) = %d instead of %d\n",
                                i, j, k, l, transposed(i, j, k, l), correct);
                         return -1;
                     }
                 }
            }
        }
    }

    Func hairy_transpose;
    hairy_transpose(_, x, y) = im3(y, _, x) + im3(y, x, _);
    // Equivalent to hairy_transpose(_0, _1, x, y) = im3(y, _0, _1, x) +
    // im3(y, x, _0, _1)

    Image<int> hairy_transposed = hairy_transpose.realize(10, 10, 10, 10);

    for (int i = 0; i < 10; i++) {
        for (int j = 0; j < 10; j++) {
            for (int k = 0; k < 10; k++) {
                 for (int l = 0; l < 10; l++) {
                   int correct1 = (l<<24)|(i<<16)|(j<<8)|k;
                   int correct2 = (l<<24)|(k<<16)|(i<<8)|j;
                   int correct = correct1 + correct2;
                   if (hairy_transposed(i, j, k, l) != correct) {
                         printf("hairy_transposed(%d, %d, %d, %d) = %d instead of %d\n",
                                i, j, k, l, hairy_transposed(i, j, k, l), correct);
                         return -1;
                     }
                 }
            }
        }
    }

    Func hairy_transpose2;
    hairy_transpose2(_, x) = im3(_, x) + im3(x, x, _);
    // Equivalent to hairy_transpose2(_0, _1, _2, x) = im3(_0, _1, _2, x) +
    // im3(x, x, _0, _1)

    Image<int> hairy_transposed2 = hairy_transpose2.realize(10, 10, 10, 10);

    for (int i = 0; i < 10; i++) {
        for (int j = 0; j < 10; j++) {
            for (int k = 0; k < 10; k++) {
                 for (int l = 0; l < 10; l++) {
                   int correct1 = (i<<24)|(j<<16)|(k<<8)|l;
                   int correct2 = (l<<24)|(l<<16)|(i<<8)|j;
                   int correct = correct1 + correct2;
                   if (hairy_transposed2(i, j, k, l) != correct) {
                         printf("hairy_transposed2(%d, %d, %d, %d) = %d instead of %d\n",
                                i, j, k, l, hairy_transposed2(i, j, k, l), correct);
                         return -1;
                     }
                 }
            }
        }
    }

    printf("Success!\n");
    return 0;
}