// Multiplies a 4-D left-hand side (M x K x d2 x d3) by a 2-D right-hand side
// (K x N) in a single matmul call, then checks every (i, j) slice of that
// result against matmul applied to the matching 2-D slice of the left-hand
// side. The largest absolute difference per slice must stay below 1E-3.
TEST(MatrixMultiply, RhsBroadcastBatched) {
    const int M  = 512;
    const int K  = 512;
    const int N  = 10;
    const int D2 = 2;
    const int D3 = 3;

    // Extents tried along dims 2 and 3: a single slice, then the full extent.
    const int dim2Sizes[] = {1, D2};
    const int dim3Sizes[] = {1, D3};
    const int numSizes    = 2;

    for (int s3 = 0; s3 < numSizes; ++s3) {
        const int d3 = dim3Sizes[s3];

        for (int s2 = 0; s2 < numSizes; ++s2) {
            const int d2 = dim2Sizes[s2];

            array a = randu(M, K, d2, d3);
            array b = randu(K, N);
            array c = matmul(a, b);

            // Walk the trailing dims and redo each product one slice at a time.
            for (int idx3 = 0; idx3 < d3; ++idx3) {
                for (int idx2 = 0; idx2 < d2; ++idx2) {
                    array lhsSlice = a(span, span, idx2, idx3);
                    array c_ij     = c(span, span, idx2, idx3);
                    array res      = matmul(lhsSlice, b);

                    EXPECT_LT(max<float>(abs(c_ij - res)), 1E-3)
                        << " for d2 = " << d2 << " for d3 = " << d3;
                }
            }
        }
    }
}
// Loads two operands and a list of expected results from TestFile, runs a
// fixed set of matmul calls (one per expected result) and compares every
// computed element with the expected data using exact equality. On the first
// mismatch both sequences are printed to stdout and the test fails.
//
// NOTE(review): T and isBVector are used here but not declared inside this
// block; presumably they are template parameters of the enclosing
// declaration -- confirm.
void cppMatMulCheck(string TestFile) {
    // NOTE(review): presumably bails out when T is a double-precision type the
    // active device cannot handle -- confirm against noDoubleTests.
    if (noDoubleTests<T>()) return;

    vector<dim4> numDims;
    vector<vector<T> > hData;
    vector<vector<T> > tests;
    readTests<T, T, int>(TestFile, numDims, hData, tests);

    array a(numDims[0], &hData[0].front());
    array b(numDims[1], &hData[1].front());

    // Shapes of a and b with their first two dimensions exchanged.
    dim4 atdims = numDims[0];
    {
        const dim_t firstDim = atdims[0];
        atdims[0]            = atdims[1];
        atdims[1]            = firstDim;
    }
    dim4 btdims = numDims[1];
    {
        const dim_t firstDim = btdims[0];
        btdims[0]            = btdims[1];
        btdims[1]            = firstDim;
    }

    // moddims only assigns the swapped shape to the same elements.
    array aT = moddims(a, atdims.ndims(), atdims.get());
    array bT = moddims(b, btdims.ndims(), btdims.get());

    // One slot per expected result; the branches below fill slots 0..4 or 0..3.
    vector<array> out(tests.size());
    if (isBVector) {
        out[0] = matmul(aT, b, AF_MAT_NONE, AF_MAT_NONE);
        out[1] = matmul(bT, a, AF_MAT_NONE, AF_MAT_NONE);
        out[2] = matmul(b, a, AF_MAT_TRANS, AF_MAT_NONE);
        out[3] = matmul(bT, aT, AF_MAT_NONE, AF_MAT_TRANS);
        out[4] = matmul(b, aT, AF_MAT_TRANS, AF_MAT_TRANS);
    } else {
        out[0] = matmul(a, b, AF_MAT_NONE, AF_MAT_NONE);
        out[1] = matmul(a, bT, AF_MAT_NONE, AF_MAT_TRANS);
        out[2] = matmul(a, bT, AF_MAT_TRANS, AF_MAT_NONE);
        out[3] = matmul(aT, bT, AF_MAT_TRANS, AF_MAT_TRANS);
    }

    for (size_t testIdx = 0; testIdx < tests.size(); testIdx++) {
        // Copy the computed result back to host memory for comparison.
        const dim_t elems = out[testIdx].elements();
        vector<T> h_out(elems);
        out[testIdx].host(static_cast<void*>(&h_out.front()));

        // Compares h_out.size() elements against the start of the expected data.
        const bool matches =
            equal(h_out.begin(), h_out.end(), tests[testIdx].begin());
        if (matches) continue;

        cout << "Failed test " << testIdx << "\nCalculated: " << endl;
        copy(h_out.begin(), h_out.end(), ostream_iterator<T>(cout, ", "));
        cout << "Expected: " << endl;
        copy(tests[testIdx].begin(), tests[testIdx].end(),
             ostream_iterator<T>(cout, ", "));
        FAIL();
    }
}
// Checks that matmulTT(A, B) agrees with matmul on explicitly transposed
// operands when B is a single row indexed out of a larger array.
// NOTE(review): the test name presumably refers to issue 1882 in the project's
// tracker -- confirm.
TEST(MatrixMultiply, ISSUE_1882) {
    const int rows = 2;
    const int cols = 3;

    array lhs    = randu(rows, cols);
    array source = randu(cols, rows);

    // Row 0 of source, obtained through indexing rather than created directly.
    array rowOfSource = source(0, span);

    array res1 = matmul(lhs.T(), rowOfSource.T());
    array res2 = matmulTT(lhs, rowOfSource);

    ASSERT_ARRAYS_NEAR(res1, res2, 1E-5);
}