Example #1
0
// One recursive step of a 2x2-block matrix multiplication of A and B with the
// result written into C.
//
// Each level splits A, B and C into 2x2 sub-blocks, recursively forms eight
// block products M1..M8 (one per A-block/B-block pair, see the per-product
// comments below), hands them pairwise to M_Add1..M_Add4 together with the
// matching block of C, and finally calls DynamicPeeling on the full A, B, C.
// When steps_left reaches 0 the product is delegated to MatMul instead.
//
// Three build configurations are visible in this function:
//   * _PARALLEL_ undefined: M1..M8 are constructed without the memory manager
//     and the eight products run one after another.
//   * _PARALLEL_ defined but neither _BFS_PAR_ nor _HYBRID_PAR_: M1..M8 use
//     memory from mem_mngr; no OpenMP task pragmas are compiled.
//   * _PARALLEL_ == _BFS_PAR_ or _HYBRID_PAR_: every product is wrapped in an
//     OpenMP task, followed by an optional taskwait.
//
// Parameters:
//   locker      - forwarded to every recursive call; in the BFS/HYBRID builds
//                 it is decremented after each sub-product and passed to
//                 SwitchToDFS.
//   mem_mngr    - supplies the storage behind M1..M8 when _PARALLEL_ is defined.
//   A, B        - operands. Both are modified: their multipliers are folded
//                 into C and then reset to 1.
//   C           - destination matrix.
//   total_steps - forwarded unchanged to every recursive call; together with
//                 steps_left it gives the current depth
//                 (total_steps - steps_left).
//   steps_left  - remaining recursion levels; 0 selects the MatMul base case.
//   start_index - index of this call, used for memory lookup and task
//                 decisions; child k is called with (start_index + k - 1) * 8.
//   x           - forwarded unchanged to the recursive calls and to M_Add1..4
//                 (its meaning is not visible in this function).
//   num_threads - used by the task-launch/wait decisions, SwitchToDFS and the
//                 MKL thread setting at the end.
//   beta        - forwarded to M_Add1..4 and DynamicPeeling; the recursive
//                 calls always pass Scalar(0.0) instead.
void FastMatmulRecursive(LockAndCounter& locker, MemoryManager<Scalar>& mem_mngr, Matrix<Scalar>& A, Matrix<Scalar>& B, Matrix<Scalar>& C, int total_steps, int steps_left, int start_index, double x, int num_threads, Scalar beta) {
    // Update multipliers: fold the multipliers of A and B into C, then reset
    // A and B to a multiplier of 1 (this mutates the caller's A and B).
    C.UpdateMultiplier(A.multiplier());
    C.UpdateMultiplier(B.multiplier());
    A.set_multiplier(Scalar(1.0));
    B.set_multiplier(Scalar(1.0));
    // Base case for recursion: no levels left, so multiply directly.
    if (steps_left == 0) {
        MatMul(A, B, C);
        return;
    }

    // Sub-blocks of A, B and C.
    // NOTE(review): Subblock(2, 2, i, j) presumably selects block (i, j),
    // 1-based, of a 2x2 partition -- confirm against Matrix::Subblock.
    Matrix<Scalar> A11 = A.Subblock(2, 2, 1, 1);
    Matrix<Scalar> A12 = A.Subblock(2, 2, 1, 2);
    Matrix<Scalar> A21 = A.Subblock(2, 2, 2, 1);
    Matrix<Scalar> A22 = A.Subblock(2, 2, 2, 2);
    Matrix<Scalar> B11 = B.Subblock(2, 2, 1, 1);
    Matrix<Scalar> B12 = B.Subblock(2, 2, 1, 2);
    Matrix<Scalar> B21 = B.Subblock(2, 2, 2, 1);
    Matrix<Scalar> B22 = B.Subblock(2, 2, 2, 2);
    Matrix<Scalar> C11 = C.Subblock(2, 2, 1, 1);
    Matrix<Scalar> C12 = C.Subblock(2, 2, 1, 2);
    Matrix<Scalar> C21 = C.Subblock(2, 2, 2, 1);
    Matrix<Scalar> C22 = C.Subblock(2, 2, 2, 2);


    // Matrices to store the results of multiplications.
    // All eight are sized from C11 and start with C's multiplier.
#ifdef _PARALLEL_
    // Parallel builds: storage comes from the memory manager, looked up by
    // (start_index, product number, current depth).
    // NOTE(review): the extra C11.m() argument is presumably the stride /
    // leading dimension, and `M` is defined outside this function -- confirm.
    Matrix<Scalar> M1(mem_mngr.GetMem(start_index, 1, total_steps - steps_left, M), C11.m(), C11.m(), C11.n(), C.multiplier());
    Matrix<Scalar> M2(mem_mngr.GetMem(start_index, 2, total_steps - steps_left, M), C11.m(), C11.m(), C11.n(), C.multiplier());
    Matrix<Scalar> M3(mem_mngr.GetMem(start_index, 3, total_steps - steps_left, M), C11.m(), C11.m(), C11.n(), C.multiplier());
    Matrix<Scalar> M4(mem_mngr.GetMem(start_index, 4, total_steps - steps_left, M), C11.m(), C11.m(), C11.n(), C.multiplier());
    Matrix<Scalar> M5(mem_mngr.GetMem(start_index, 5, total_steps - steps_left, M), C11.m(), C11.m(), C11.n(), C.multiplier());
    Matrix<Scalar> M6(mem_mngr.GetMem(start_index, 6, total_steps - steps_left, M), C11.m(), C11.m(), C11.n(), C.multiplier());
    Matrix<Scalar> M7(mem_mngr.GetMem(start_index, 7, total_steps - steps_left, M), C11.m(), C11.m(), C11.n(), C.multiplier());
    Matrix<Scalar> M8(mem_mngr.GetMem(start_index, 8, total_steps - steps_left, M), C11.m(), C11.m(), C11.n(), C.multiplier());
#else
    // Serial build: the temporaries are constructed without the memory manager.
    Matrix<Scalar> M1(C11.m(), C11.n(), C.multiplier());
    Matrix<Scalar> M2(C11.m(), C11.n(), C.multiplier());
    Matrix<Scalar> M3(C11.m(), C11.n(), C.multiplier());
    Matrix<Scalar> M4(C11.m(), C11.n(), C.multiplier());
    Matrix<Scalar> M5(C11.m(), C11.n(), C.multiplier());
    Matrix<Scalar> M6(C11.m(), C11.n(), C.multiplier());
    Matrix<Scalar> M7(C11.m(), C11.n(), C.multiplier());
    Matrix<Scalar> M8(C11.m(), C11.n(), C.multiplier());
#endif
    // Per-product flags that feed the `if` clause of the omp task pragmas
    // below. Under OpenMP a false `if` expression makes the task undeferred
    // (executed immediately by the encountering thread), so despite the
    // variable name a TRUE value is what permits a deferred task.
#if defined(_PARALLEL_) && (_PARALLEL_ == _BFS_PAR_ || _PARALLEL_ == _HYBRID_PAR_)
    bool sequential1 = should_launch_task(8, total_steps, steps_left, start_index, 1, num_threads);
    bool sequential2 = should_launch_task(8, total_steps, steps_left, start_index, 2, num_threads);
    bool sequential3 = should_launch_task(8, total_steps, steps_left, start_index, 3, num_threads);
    bool sequential4 = should_launch_task(8, total_steps, steps_left, start_index, 4, num_threads);
    bool sequential5 = should_launch_task(8, total_steps, steps_left, start_index, 5, num_threads);
    bool sequential6 = should_launch_task(8, total_steps, steps_left, start_index, 6, num_threads);
    bool sequential7 = should_launch_task(8, total_steps, steps_left, start_index, 7, num_threads);
    bool sequential8 = should_launch_task(8, total_steps, steps_left, start_index, 8, num_threads);
#else
    // No task pragma is compiled in this configuration, so these flags are
    // never read.
    bool sequential1 = false;
    bool sequential2 = false;
    bool sequential3 = false;
    bool sequential4 = false;
    bool sequential5 = false;
    bool sequential6 = false;
    bool sequential7 = false;
    bool sequential8 = false;
#endif



    // Every product below follows the same pattern:
    //   1. (BFS/HYBRID only) open an omp task sharing mem_mngr and locker;
    //   2. update the temporary's multiplier twice with 1 -- presumably one
    //      call per operand coefficient in the formula comment (TODO confirm);
    //   3. recurse with one level fewer, child index (start_index + k - 1) * 8
    //      and beta = 0;
    //   4. (BFS/HYBRID only) decrement the locker, close the task, and, if
    //      should_task_wait says so, wait for outstanding child tasks
    //      (followed by SwitchToDFS in the HYBRID build).
    // The `#ifndef _PARALLEL_` / `#endif` pairs after each recursive call are
    // empty sections: nothing is compiled there in any configuration.

    // M1 = (1 * A11) * (1 * B11)
#if defined(_PARALLEL_) && (_PARALLEL_ == _BFS_PAR_ || _PARALLEL_ == _HYBRID_PAR_)
    # pragma omp task if(sequential1) shared(mem_mngr, locker) untied
    {
#endif
        M1.UpdateMultiplier(Scalar(1));
        M1.UpdateMultiplier(Scalar(1));
        FastMatmulRecursive(locker, mem_mngr, A11, B11, M1, total_steps, steps_left - 1, (start_index + 1 - 1) * 8, x, num_threads, Scalar(0.0));
#ifndef _PARALLEL_
#endif
#if defined(_PARALLEL_) && (_PARALLEL_ == _BFS_PAR_ || _PARALLEL_ == _HYBRID_PAR_)
        locker.Decrement();
    }
    if (should_task_wait(8, total_steps, steps_left, start_index, 1, num_threads)) {
        # pragma omp taskwait
# if defined(_PARALLEL_) && (_PARALLEL_ == _HYBRID_PAR_)
        SwitchToDFS(locker, num_threads);
# endif
    }
#endif

    // M2 = (1 * A12) * (1 * B21)
#if defined(_PARALLEL_) && (_PARALLEL_ == _BFS_PAR_ || _PARALLEL_ == _HYBRID_PAR_)
    # pragma omp task if(sequential2) shared(mem_mngr, locker) untied
    {
#endif
        M2.UpdateMultiplier(Scalar(1));
        M2.UpdateMultiplier(Scalar(1));
        FastMatmulRecursive(locker, mem_mngr, A12, B21, M2, total_steps, steps_left - 1, (start_index + 2 - 1) * 8, x, num_threads, Scalar(0.0));
#ifndef _PARALLEL_
#endif
#if defined(_PARALLEL_) && (_PARALLEL_ == _BFS_PAR_ || _PARALLEL_ == _HYBRID_PAR_)
        locker.Decrement();
    }
    if (should_task_wait(8, total_steps, steps_left, start_index, 2, num_threads)) {
        # pragma omp taskwait
# if defined(_PARALLEL_) && (_PARALLEL_ == _HYBRID_PAR_)
        SwitchToDFS(locker, num_threads);
# endif
    }
#endif

    // M3 = (1 * A11) * (1 * B12)
#if defined(_PARALLEL_) && (_PARALLEL_ == _BFS_PAR_ || _PARALLEL_ == _HYBRID_PAR_)
    # pragma omp task if(sequential3) shared(mem_mngr, locker) untied
    {
#endif
        M3.UpdateMultiplier(Scalar(1));
        M3.UpdateMultiplier(Scalar(1));
        FastMatmulRecursive(locker, mem_mngr, A11, B12, M3, total_steps, steps_left - 1, (start_index + 3 - 1) * 8, x, num_threads, Scalar(0.0));
#ifndef _PARALLEL_
#endif
#if defined(_PARALLEL_) && (_PARALLEL_ == _BFS_PAR_ || _PARALLEL_ == _HYBRID_PAR_)
        locker.Decrement();
    }
    if (should_task_wait(8, total_steps, steps_left, start_index, 3, num_threads)) {
        # pragma omp taskwait
# if defined(_PARALLEL_) && (_PARALLEL_ == _HYBRID_PAR_)
        SwitchToDFS(locker, num_threads);
# endif
    }
#endif

    // M4 = (1 * A12) * (1 * B22)
#if defined(_PARALLEL_) && (_PARALLEL_ == _BFS_PAR_ || _PARALLEL_ == _HYBRID_PAR_)
    # pragma omp task if(sequential4) shared(mem_mngr, locker) untied
    {
#endif
        M4.UpdateMultiplier(Scalar(1));
        M4.UpdateMultiplier(Scalar(1));
        FastMatmulRecursive(locker, mem_mngr, A12, B22, M4, total_steps, steps_left - 1, (start_index + 4 - 1) * 8, x, num_threads, Scalar(0.0));
#ifndef _PARALLEL_
#endif
#if defined(_PARALLEL_) && (_PARALLEL_ == _BFS_PAR_ || _PARALLEL_ == _HYBRID_PAR_)
        locker.Decrement();
    }
    if (should_task_wait(8, total_steps, steps_left, start_index, 4, num_threads)) {
        # pragma omp taskwait
# if defined(_PARALLEL_) && (_PARALLEL_ == _HYBRID_PAR_)
        SwitchToDFS(locker, num_threads);
# endif
    }
#endif

    // M5 = (1 * A21) * (1 * B11)
#if defined(_PARALLEL_) && (_PARALLEL_ == _BFS_PAR_ || _PARALLEL_ == _HYBRID_PAR_)
    # pragma omp task if(sequential5) shared(mem_mngr, locker) untied
    {
#endif
        M5.UpdateMultiplier(Scalar(1));
        M5.UpdateMultiplier(Scalar(1));
        FastMatmulRecursive(locker, mem_mngr, A21, B11, M5, total_steps, steps_left - 1, (start_index + 5 - 1) * 8, x, num_threads, Scalar(0.0));
#ifndef _PARALLEL_
#endif
#if defined(_PARALLEL_) && (_PARALLEL_ == _BFS_PAR_ || _PARALLEL_ == _HYBRID_PAR_)
        locker.Decrement();
    }
    if (should_task_wait(8, total_steps, steps_left, start_index, 5, num_threads)) {
        # pragma omp taskwait
# if defined(_PARALLEL_) && (_PARALLEL_ == _HYBRID_PAR_)
        SwitchToDFS(locker, num_threads);
# endif
    }
#endif

    // M6 = (1 * A22) * (1 * B21)
#if defined(_PARALLEL_) && (_PARALLEL_ == _BFS_PAR_ || _PARALLEL_ == _HYBRID_PAR_)
    # pragma omp task if(sequential6) shared(mem_mngr, locker) untied
    {
#endif
        M6.UpdateMultiplier(Scalar(1));
        M6.UpdateMultiplier(Scalar(1));
        FastMatmulRecursive(locker, mem_mngr, A22, B21, M6, total_steps, steps_left - 1, (start_index + 6 - 1) * 8, x, num_threads, Scalar(0.0));
#ifndef _PARALLEL_
#endif
#if defined(_PARALLEL_) && (_PARALLEL_ == _BFS_PAR_ || _PARALLEL_ == _HYBRID_PAR_)
        locker.Decrement();
    }
    if (should_task_wait(8, total_steps, steps_left, start_index, 6, num_threads)) {
        # pragma omp taskwait
# if defined(_PARALLEL_) && (_PARALLEL_ == _HYBRID_PAR_)
        SwitchToDFS(locker, num_threads);
# endif
    }
#endif

    // M7 = (1 * A21) * (1 * B12)
#if defined(_PARALLEL_) && (_PARALLEL_ == _BFS_PAR_ || _PARALLEL_ == _HYBRID_PAR_)
    # pragma omp task if(sequential7) shared(mem_mngr, locker) untied
    {
#endif
        M7.UpdateMultiplier(Scalar(1));
        M7.UpdateMultiplier(Scalar(1));
        FastMatmulRecursive(locker, mem_mngr, A21, B12, M7, total_steps, steps_left - 1, (start_index + 7 - 1) * 8, x, num_threads, Scalar(0.0));
#ifndef _PARALLEL_
#endif
#if defined(_PARALLEL_) && (_PARALLEL_ == _BFS_PAR_ || _PARALLEL_ == _HYBRID_PAR_)
        locker.Decrement();
    }
    if (should_task_wait(8, total_steps, steps_left, start_index, 7, num_threads)) {
        # pragma omp taskwait
# if defined(_PARALLEL_) && (_PARALLEL_ == _HYBRID_PAR_)
        SwitchToDFS(locker, num_threads);
# endif
    }
#endif

    // M8 = (1 * A22) * (1 * B22)
    // Unlike products 1-7, the wait after this last product is not followed
    // by SwitchToDFS in the HYBRID build.
#if defined(_PARALLEL_) && (_PARALLEL_ == _BFS_PAR_ || _PARALLEL_ == _HYBRID_PAR_)
    # pragma omp task if(sequential8) shared(mem_mngr, locker) untied
    {
#endif
        M8.UpdateMultiplier(Scalar(1));
        M8.UpdateMultiplier(Scalar(1));
        FastMatmulRecursive(locker, mem_mngr, A22, B22, M8, total_steps, steps_left - 1, (start_index + 8 - 1) * 8, x, num_threads, Scalar(0.0));
#ifndef _PARALLEL_
#endif
#if defined(_PARALLEL_) && (_PARALLEL_ == _BFS_PAR_ || _PARALLEL_ == _HYBRID_PAR_)
        locker.Decrement();
    }
    if (should_task_wait(8, total_steps, steps_left, start_index, 8, num_threads)) {
        # pragma omp taskwait
    }
#endif

    // Combine the products into the four blocks of C. Each M_AddK receives
    // the two products that share a target block, e.g. (M1, M2) -> C11.
    // NOTE(review): presumably CIJ = beta * CIJ + Ma + Mb; the exact formula
    // and the role of `x` and the `false` flag live in M_AddK -- confirm.
    M_Add1(M1, M2, C11, x, false, beta);
    M_Add2(M3, M4, C12, x, false, beta);
    M_Add3(M5, M6, C21, x, false, beta);
    M_Add4(M7, M8, C22, x, false, beta);

    // Handle edge cases with dynamic peeling
#if defined(_PARALLEL_) && (_PARALLEL_ == _BFS_PAR_ || _PARALLEL_ == _HYBRID_PAR_)
    // total_steps == steps_left holds for a call that has not descended yet
    // (presumably the outermost call). There, set the MKL thread count for
    // this thread to num_threads and turn off MKL's dynamic thread
    // adjustment before the peeling step runs.
    if (total_steps == steps_left) {
        mkl_set_num_threads_local(num_threads);
        mkl_set_dynamic(0);
    }
#endif
    // NOTE(review): the (2, 2, 2) arguments presumably name the block
    // dimensions of this scheme so DynamicPeeling can process the rows and
    // columns left over when sizes are not divisible by 2 -- confirm.
    DynamicPeeling(A, B, C, 2, 2, 2, beta);
}