void test_gemm(const std::string &name, int M, int N, int K, float alpha, float beta) { Halide::Buffer<int32_t> sizes(3); Halide::Buffer<float> params(3); Halide::Buffer<float> A(K, M); Halide::Buffer<float> B(N, K); Halide::Buffer<float> C(N, M); Halide::Buffer<float> C_ref(N, M); sizes(0) = M; sizes(1) = N; sizes(2) = K; params(0) = alpha; params(1) = beta; for (int i = 0; i < K; i++) for (int j = 0; j < M; j++) A(i, j) = std::rand() % 10 - 5; for (int i = 0; i < N; i++) for (int j = 0; j < K; j++) B(i, j) = std::rand() % 10 - 5; for (int i = 0; i < N; i++) for (int j = 0; j < M; j++) C(i, j) = std::rand() % 10 - 5; for (int i = 0; i < N; i++) { for (int j = 0; j < M; j++) { C_ref(i, j) = beta * C(i, j); for (int k = 0; k < K; k++) { C_ref(i, j) += alpha * A(k, j) * B(i, k); } } } test_162(sizes.raw_buffer(), params.raw_buffer(), A.raw_buffer(), B.raw_buffer(), C.raw_buffer()); compare_buffers(name, C, C_ref); }
/*
 * Compares C against C_ref element by element over an m-by-n index range and
 * prints a line for each first index i at which some element differs by more
 * than TOLERANCE.
 *
 * Only the first offending j for a given i is reported: after printing, the
 * scan moves on to the next i.
 *
 * ldc, ldc_ref  not used directly in this body; presumably consumed by the
 *               C( i, j ) / C_ref( i, j ) accessors defined elsewhere in the
 *               file -- TODO confirm.
 * m, n          extents of the first and second index respectively.
 * C, C_ref      the computed and the reference data.
 */
void computeError( int ldc, int ldc_ref, int m, int n, double *C, double *C_ref )
{
    int row, col;

    row = 0;
    while ( row < m ) {
        for ( col = 0; col < n; col ++ ) {
            /* Written as a negated '>' so a NaN difference is skipped, as before. */
            if ( !( fabs( C( row, col ) - C_ref( row, col ) ) > TOLERANCE ) ) {
                continue;
            }
            printf( "C[ %d ][ %d ] != C_ref, %E, %E\n", row, col, C( row, col ), C_ref( row, col ) );
            break;
        }
        row ++;
    }
}
void test_gemm(const std::string &name, int M, int N, int K, float alpha, float beta, int rowsA, int colsA, int rowsB, int colsB, int rowsC, int colsC, int offsetA, int offsetB, int offsetC, bool transposeA, bool transposeB) { Halide::Buffer<int32_t> sizes(12); Halide::Buffer<float> params(2); Halide::Buffer<bool> transposes(2); Halide::Buffer<float> A(colsA, rowsA); Halide::Buffer<float> B(colsB, rowsB); Halide::Buffer<float> C(colsC, rowsC); Halide::Buffer<float> C_ref(colsC, rowsC); sizes(0) = M; sizes(1) = N; sizes(2) = K; sizes(3) = rowsA; sizes(4) = colsA; sizes(5) = rowsB; sizes(6) = colsB; sizes(7) = rowsC; sizes(8) = colsC; sizes(9) = offsetA; sizes(10) = offsetB; sizes(11) = offsetC; params(0) = alpha; params(1) = beta; transposes(0) = transposeA; transposes(1) = transposeB; for (int i = 0; i < rowsA; i++) for (int j = 0; j < colsA; j++) A(j, i) = std::rand() % 10 - 5; for (int i = 0; i < rowsB; i++) for (int j = 0; j < colsB; j++) B(j, i) = std::rand() % 10 - 5; for (int i = 0; i < rowsC; i++) for (int j = 0; j < colsC; j++) C(j, i) = C_ref(j, i) = std::rand() % 10 - 5; for (int i = 0; i < M; i++) { for (int j = 0; j < N; j++) { C_ref(j + offsetC % colsC, i + offsetC / colsC) *= beta; for (int k = 0; k < K; k++) { float a = transposeA ? A(i + offsetA % colsA, k + offsetA / colsA) : A(k + offsetA % colsA, i + offsetA / colsA); float b = transposeB ? B(k + offsetB % colsB, j + offsetB / colsB) : B(j + offsetB % colsB, k + offsetB / colsB); C_ref(j + offsetC % colsC, i + offsetC / colsC) += alpha * a * b; } } } test_164(sizes.raw_buffer(), params.raw_buffer(), transposes.raw_buffer(), A.raw_buffer(), B.raw_buffer(), C.raw_buffer()); compare_buffers(name, C, C_ref); }
void test_bl_dgemm( int m, int n, int k ) { int i, j, p, nx; double *A, *B, *C, *C_ref; double tmp, error, flops; double ref_beg, ref_time, bl_dgemm_beg, bl_dgemm_time; int nrepeats; int lda, ldb, ldc, ldc_ref; double ref_rectime, bl_dgemm_rectime; A = (double*)malloc( sizeof(double) * m * k ); B = (double*)malloc( sizeof(double) * k * n ); lda = m; ldb = k; //ldc = ( ( m - 1 ) / DGEMM_MR + 1 ) * DGEMM_MR; ldc = m; ldc_ref = m; C = bl_malloc_aligned( ldc, n + 4, sizeof(double) ); C_ref = (double*)malloc( sizeof(double) * m * n ); nrepeats = 3; srand (time(NULL)); // Randonly generate points in [ 0, 1 ]. for ( p = 0; p < k; p ++ ) { for ( i = 0; i < m; i ++ ) { A( i, p ) = (double)( drand48() ); } } for ( j = 0; j < n; j ++ ) { for ( p = 0; p < k; p ++ ) { B( p, j ) = (double)( drand48() ); } } for ( j = 0; j < n; j ++ ) { for ( i = 0; i < m; i ++ ) { C_ref( i, j ) = (double)( 0.0 ); C( i, j ) = (double)( 0.0 ); } } for ( i = 0; i < nrepeats; i ++ ) { bl_dgemm_beg = bl_clock(); { bl_dgemm( m, n, k, A, lda, B, ldb, C, ldc ); } bl_dgemm_time = bl_clock() - bl_dgemm_beg; if ( i == 0 ) { bl_dgemm_rectime = bl_dgemm_time; } else { bl_dgemm_rectime = bl_dgemm_time < bl_dgemm_rectime ? bl_dgemm_time : bl_dgemm_rectime; } } for ( i = 0; i < nrepeats; i ++ ) { ref_beg = bl_clock(); { bl_dgemm_ref( m, n, k, A, lda, B, ldb, C_ref, ldc_ref ); } ref_time = bl_clock() - ref_beg; if ( i == 0 ) { ref_rectime = ref_time; } else { ref_rectime = ref_time < ref_rectime ? ref_time : ref_rectime; } } computeError( ldc, ldc_ref, m, n, C, C_ref ); // Compute overall floating point operations. flops = ( m * n / ( 1000.0 * 1000.0 * 1000.0 ) ) * ( 2 * k ); printf( "%5d\t %5d\t %5d\t %5.2lf\t %5.2lf\n", m, n, k, flops / bl_dgemm_rectime, flops / ref_rectime ); free( A ); free( B ); free( C ); free( C_ref ); }