/*
 * strassen_mm: multiply via square, padded work buffers.
 *
 * Picks a square side length from the largest of m, n and k. Per the helper's
 * name, it is presumably the next power of two above that dimension; confirm
 * against getNumberLargerThanXAndIsPowerOfTwo. It then copies A, B and C into
 * side x side buffers with pad_matrix and runs strassen_mm_worker on them.
 * The m x n region of the result is copied back into C with matrix_copyTo.
 *
 * Parameters (operand shapes inferred from the pad_matrix arguments):
 *   m, n, k   - A is m x k, B is k x n, C is m x n
 *   A, incRowA - left operand and its row stride
 *   B, incRowB - right operand and its row stride
 *   C, incRowC - destination and its row stride. C's current contents are
 *                also padded and passed to the worker. NOTE(review): whether
 *                the worker overwrites or accumulates into C is not visible here.
 */
void strassen_mm(const unsigned int m, const unsigned int n, const unsigned int k,
                 const Dtype *A, const int incRowA,
                 const Dtype *B, const int incRowB,
                 Dtype *C, const int incRowC)
{
    const int largest_dim = maxThree(m, n, k);
    const int side = getNumberLargerThanXAndIsPowerOfTwo(largest_dim);

    /* Square work copies of each operand, all side x side with stride side. */
    Dtype *padA = pad_matrix(A, m, k, incRowA, side, side);
    Dtype *padB = pad_matrix(B, k, n, incRowB, side, side);
    Dtype *padC = pad_matrix(C, m, n, incRowC, side, side);

    strassen_mm_worker(side, side, side,
                       padA, side,
                       padB, side,
                       padC, side);

    /* Copy only the caller's m x n window back out of the padded result. */
    matrix_copyTo(padC, side, side, side, C, m, n, incRowC);

    /* Release the padded work buffers. */
    remove_matrix(padA);
    remove_matrix(padB);
    remove_matrix(padC);
}
/**
 * NAME: strassen_preprocess
 * INPUT: MATRIX* mOrig1, MATRIX* mOrig2, MATRIX* mNew1, MATRIX* mNew2
 * USAGE: Checks whether mOrig1 and mOrig2 can be multiplied, which requires
 *        mOrig1->numCols == mOrig2->numRows. If so, pads both matrices and
 *        stores the padded matrices into mNew1 and mNew2.
 *
 * NOTES: Assumes all inputs are already initialized (non-NULL).
 *        On a dimension mismatch, an error message is written to stderr and
 *        mNew1/mNew2 are released with free().
 *        The pointers are passed by value and the function returns void. The
 *        caller's mNew1/mNew2 therefore still hold the freed addresses
 *        afterwards and must not be used or freed again.
 *        NOTE(review): consider returning a status so callers can detect this
 *        failure. Also confirm that plain free() fully releases a MATRIX:
 *        any separately allocated element storage is not released here.
 */
void strassen_preprocess(MATRIX* mOrig1, MATRIX* mOrig2, MATRIX* mNew1, MATRIX* mNew2)
{
    /* Multiplication requires the inner dimensions to agree. */
    if (mOrig1->numCols != mOrig2->numRows) {
        /* Diagnostics go to stderr, newline-terminated, so they do not merge
         * into regular stdout output. */
        fprintf(stderr, "Error: Matrices cannot be multiplied\n");
        free(mNew1);
        free(mNew2);
        /* No point nulling mNew1/mNew2 here: they are local copies, and the
         * caller's pointers would be unaffected. */
        return;
    }

    /* Pad each operand into its caller-provided output matrix. */
    pad_matrix(mOrig1, mNew1);
    pad_matrix(mOrig2, mNew2);
}