/*
 * Multiply two NN x NN row-major matrices: c = a * b.
 *
 * strassen() works on row-pointer matrices (float **), so the flat inputs
 * are first copied into freshly allocated row-pointer matrices. The product
 * is computed with a leaf size of 32 and then copied back into c. All three
 * temporaries are released before returning.
 * NOTE(review): the meaning of the -1 argument to allocate_real_matrix() is
 * not visible here — confirm.
 */
void matmul_strassen(volatile float * a, volatile float * b, volatile float * c, int NN)
{
    const int leaf = 32; /* block size handed to strassen() as its leaf size */

    float **ma = allocate_real_matrix(NN, -1);
    float **mb = allocate_real_matrix(NN, -1);
    float **mc = allocate_real_matrix(NN, -1);

    /* Flat row-major -> row-pointer. */
    for (int row = 0; row < NN; row++) {
        const int base = row * NN;
        for (int col = 0; col < NN; col++) {
            ma[row][col] = a[base + col];
            mb[row][col] = b[base + col];
        }
    }

    strassen(ma, mb, mc, NN, leaf);

    /* Row-pointer -> flat row-major. */
    for (int row = 0; row < NN; row++) {
        const int base = row * NN;
        for (int col = 0; col < NN; col++)
            c[base + col] = mc[row][col];
    }

    ma = free_real_matrix(ma, NN);
    mb = free_real_matrix(mb, NN);
    mc = free_real_matrix(mc, NN);
}
/*
 * Benchmark driver for strassen().
 *
 * Modes:
 *   - at least three arguments and argv[1] == "-r":
 *       n = atoi(argv[2]) (rounded with normalize_power_of_two()), and A and B
 *       are filled randomly with 0s and 1s from a seed produced by mix();
 *   - exactly one argument that is not "-r":
 *       old_n is read with xscanf() and A and B are filled with fill_matr().
 *   - anything else calls usage().
 *
 * Prints "<old_n> <seconds>s" for the strassen() call. The product is then
 * printed unless argv[3] is "--scale-test".
 *
 * NOTE(review): usage() is assumed not to return. If it does return, A, B, C
 * and n are used uninitialized below — confirm.
 */
int main(int argc, char **argv)
{
    /* Variables required by the algorithm. */
    unsigned int old_n, n;
    unsigned short int **A;
    unsigned short int **B;
    unsigned short int **C;
    /* Variables required to calculate run time. */
    clock_t cs, ce;

    /* Check presence of parameters. */
    if (argc > 3) {
        if (strcmp(argv[1], "-r") == 0) {
            /* Get seed from mix function for filling the matrices. */
            unsigned long int seed = mix(clock(), time(NULL), getpid());
            old_n = atoi(argv[2]);
            n = normalize_power_of_two(old_n);
            A = init_matr(&n);
            B = init_matr(&n);
            C = init_matr(&n);
            srand(seed);
            /* Fill A and B randomly with 0s and 1s. */
            randomize(A, &n);
            randomize(B, &n);
        } else {
            usage();
        }
    } else if (argc == 2) {
        /* If only the flag is passed and not the number, exit with an error. */
        if (strcmp(argv[1], "-r") == 0)
            usage();
        xscanf("%u", &old_n);
        /* Normalize n to the closest power of two. */
        n = normalize_power_of_two(old_n);
        A = init_matr(&n);
        B = init_matr(&n);
        C = init_matr(&n);
        fill_matr(A, &n, &old_n);
        fill_matr(B, &n, &old_n);
    } else {
        usage();
    }

    cs = clock();
    /* NOTE(review): this overwrites the matrix allocated by init_matr() above
       without releasing it, which looks like a leak — confirm whether C
       needs to be pre-allocated at all. */
    C = strassen(A, B, n);
    ce = clock();
    printf("%u %.20fs\n", old_n, (float) (ce - cs) / CLOCKS_PER_SEC);

    /* argv[3] exists only when argc > 3; in the stdin mode (argc == 2) reading
       it would be out of bounds, and there is no --scale-test flag anyway. */
    if (argc <= 3 || strcmp(argv[3], "--scale-test") != 0)
        print_matr(C, &old_n);

    clear_matr(A, &n);
    clear_matr(B, &n);
    clear_matr(C, &n);
    return EXIT_SUCCESS;
}
/*
 * Multiply the chain m[0] * m[1] * ... * m[n_mat - 1] following a split plan.
 *
 * order[0] = k is the number of matrices in the left factor. The plan for
 * the left sub-chain (k matrices) starts at order + 1, and the plan for the
 * right sub-chain (n_mat - k matrices) starts at order + k. The two partial
 * products are combined with strassen(left, right, s[0], s[k], s[n_mat]).
 * NOTE(review): s is presumably the dimension vector (m[i] is s[i] x s[i+1])
 * — confirm.
 *
 * A single matrix is returned as-is.
 * NOTE(review): intermediate products returned by strassen() are never
 * freed here — confirm who owns them.
 */
int* orderedmult(int** m, int* s, int* order, int n_mat)
{
    if (n_mat == 1)
        return m[0];

    const int split = order[0];
    int *left = orderedmult(m, s, order + 1, split);
    int *right = orderedmult(m + split, s + split, order + split, n_mat - split);
    return strassen(left, right, s[0], s[split], s[n_mat]);
}
// Multiplies two n x n matrices x and y with Strassen's algorithm and
// returns the product z = x * y. Element indices run from 0 to n-1.
//
// n must be even at every level of the recursion (i.e. a power of two
// >= 2): each call splits its operands into four n/2 x n/2 quadrants, and
// the recursion stops at n == 2.
//
// The formulas below treat x = [a b; c d] and y = [e f; g h], so that
// z = [r s; t u] = [ae+bg af+bh; ce+dg cf+dh].
// NOTE(review): confirm that construct_xy fills a..h in that layout.
matrix strassen(matrix x, matrix y, int n)
{
    // Recursion end: a 2x2 product is computed directly. This is checked
    // before any quadrant is allocated, so the base case does no wasted work.
    if (n == 2)
        return (x * y);

    const int half = n / 2;
    matrix a(half, half), b(half, half), c(half, half), d(half, half);
    matrix e(half, half), f(half, half), g(half, half), h(half, half);
    matrix r(half, half), s(half, half), t(half, half), u(half, half);

    // p[1]..p[7] hold the seven intermediate products. p[0] is unused so the
    // indices match the P1..P7 names. Every slot gets the quadrant size.
    matrix p[8];
    for (int i = 0; i < 8; ++i) {
        p[i].m_row = half;
        p[i].m_col = half;
    }

    // Initialise the eight quadrants a..h of x and y.
    construct_xy(a, b, c, d, e, f, g, h, x, y, n);

    // Seven recursive half-size products (instead of the naive eight).
    p[1] = strassen(a, f - h, half);      // P1 = af - ah = a(f - h)
    p[2] = strassen(a + b, h, half);      // P2 = ah + bh = (a + b)h
    p[3] = strassen(c + d, e, half);      // P3 = ce + de = (c + d)e
    p[4] = strassen(d, g - e, half);      // P4 = dg - de = d(g - e)
    p[5] = strassen(a + d, e + h, half);  // P5 = ae + ah + de + dh = (a + d)(e + h)
    p[6] = strassen(b - d, g + h, half);  // P6 = bg + bh - dg - dh = (b - d)(g + h)
    p[7] = strassen(a - c, e + f, half);  // P7 = ae + af - ce - cf = (a - c)(e + f)

    // The four quadrants r, s, t, u of z.
    r = p[5] + p[4] - p[2] + p[6];  // ae + bg
    s = p[1] + p[2];                // af + bh
    t = p[3] + p[4];                // ce + dg
    u = p[1] + p[5] - p[3] - p[7];  // cf + dh

    // Assemble r, s, t, u into the n x n result.
    matrix z(n, n);
    construct_z(r, s, t, u, z, n);
    return (z);
}
int main()
{
    /*
     * Benchmark: ordinary vs. Strassen multiplication of two 1024 x 1024
     * matrices. Recorded timings for different values of k, the size at which
     * the Strassen recursion switches to ordinary multiplication (presumably
     * seconds, as printed below):
     *   k = 16  : 3.532507
     *   k = 32  : 3.174222
     *   k = 64  : 3.139104
     *   k = 128 : 3.255240
     *   k = 256 : 3.902180
     *   k = 512 : 5.226510
     */
    const int size = 1024;
    const int cutoff = 16;

    int **lhs = matrix_new(size);
    int **rhs = matrix_new(size);
    matrix_fill(lhs, size);
    matrix_fill(rhs, size);

    printf("Starting:\n");

    clock_t started = clock();
    int **naive = matrix_mult(lhs, rhs, size);
    clock_t ticks = clock() - started;
    printf("Обычное умножение - %f\n", ((double)ticks/CLOCKS_PER_SEC));

    printf("Начало перемножения методом Штрассена \n");
    started = clock();
    int **fast = strassen(lhs, rhs, size, cutoff);
    ticks = clock() - started;
    printf("Умножение методом Штрассена - %f\n", ((double)ticks/CLOCKS_PER_SEC));

    /* The products are only timed, never inspected. */
    (void)naive;
    (void)fast;
    return 0;
}
int main(int argc,char** argv){ if(argc < 4 || argc > 5) return 1; int verbose = 0; Matrices m; int i = 0; for( i = 0 ; i < argc ; i++){ if(strcmp(argv[i],"-p") == 0) verbose = 1; if(strcmp(argv[i],"-f") == 0){ if( argc - i < 3){ printf("You have to specify two files\n"); return 1; } else{ readFiles(argv[i+1],argv[i+2],&m); i+=2; } } } #ifdef conventionnel //multiplyMatrix(&m); strassen(&m); #endif #ifdef conventionnelT multiplyMatrixT(&m); #endif if(verbose){ for( i = 0 ; i < m.size*m.size; i++){ printf("%d\t",m.result[i]); if(i%m.size == m.size - 1) printf("\n"); } } free(m.a); free(m.b); free(m.result); return 0; }
// Interactive demo: reads the order n and two n x n matrices from stdin,
// multiplies them with strassen(), and prints both operands and the result.
// Returns 1 if the order cannot be read or is not positive.
int main()
{
    int n, **a, **b, **ree;
    cout << "\n\t\tSTRASSEN'S MATRIX MULTIPLICATION USING D AND C";
    cout << "\n\t\t...............................";
    cout << "\nEnter the order of matrix:";
    // Stop instead of passing an indeterminate or non-positive n to readmat().
    if (!(cin >> n) || n <= 0)
        return 1;
    cout << "\nEnter the first matrix element:";
    a = readmat(n);
    cout << "\nEnter the second matrix";
    b = readmat(n);
    ree = strassen(a, b, n);
    cout << "\nFirst matrix\n";
    dismat(a, n);
    cout << "\nSecond matrix\n";
    dismat(b, n);
    cout << "\nResult matrix\n";
    dismat(ree, n);
    // NOTE(review): if readmat()/strassen() allocate each row separately,
    // this frees only the row-pointer array; a and b are not freed at all.
    free(ree);
    getch();
    return 0;
}
/* convenience function for Matrix Matrix Multiplication
   Computes C = A * B. B is first transposed into the buffer BT. If
   A.dimRows (after being reset to LD - PADDING) is below TRUNCATION_POINT,
   naive() is used; otherwise strassen() runs with four scratch matrices
   P, PS, S and T. Each scratch matrix (and BT) gets its own
   aligned_alloc'd buffer of LD * LD doubles, and all are freed before
   returning.
   @param A input matrix (side effect: A.dimRows is overwritten with LD - PADDING)
   @param B input matrix (COLUMN MAJOR)
   @param C output matrix
   NOTE(review): B is passed through transpose() before use — confirm
   whether B is really column major on entry.
   NOTE(review): LD is not computed in this function (the rounding-up
   code below is commented out), so it must be defined elsewhere — confirm.
   NOTE(review): aligned_alloc results are not checked for NULL.
*/
inline void MMM(Matrix& A, Matrix& B, Matrix& C){
    /* Disabled: round the largest dimension up to the next power of two.
    int LD = A.dimRows; if (LD<A.dimCols) LD = A.dimCols; if (LD<B.dimCols) LD = B.dimCols; LD--; LD |= LD >> 1; LD |= LD >> 2; LD |= LD >> 4; LD |= LD >> 8; LD |= LD >> 16; LD++;*/
    // Buffer for the transposed copy of B; passed as both of BT's first two
    // constructor pointers and released below through BT.data.
    double* temp = (double*) aligned_alloc(ALIGNMENT, sizeof(double) * LD * LD);
    Matrix BT(temp, temp, B.getDimN(), B.getDimM(), 0, 0);
    // Scratch matrices for strassen(). P, PS and S receive their buffer as
    // the first constructor pointer (freed via .data).
    Matrix P( (double*) aligned_alloc(ALIGNMENT, sizeof(double) * LD * LD), nullptr, A.getDimM(), A.getDimM(), 0, 0);
    Matrix PS( (double*) aligned_alloc(ALIGNMENT, sizeof(double) * LD * LD), nullptr, A.getDimM(), A.getDimM(), 0, 0);
    Matrix S( (double*) aligned_alloc(ALIGNMENT, sizeof(double) * LD * LD), nullptr, A.getDimM(), A.getDimM(), 0, 0);
    // T receives its buffer as the second constructor pointer (freed via .dataT).
    Matrix T(nullptr, (double*) aligned_alloc(ALIGNMENT, sizeof(double) * LD * LD), A.getDimM(), A.getDimM(), 0, 0);
    A.dimRows = LD - PADDING;
    transpose(B, BT);
    // Small problems go to the classical kernel.
    if (A.dimRows<TRUNCATION_POINT){
        naive(A, BT, C);
    } else {
        strassen(A, BT, C, P, PS, S, T);
    }
    free(BT.data);
    free(P.data);
    free(PS.data);
    free(S.data);
    free(T.dataT);
}
/*
 * Reads n followed by two n x n integer matrices from stdin and multiplies
 * them with strassen(). Prints the product, then the number of page faults
 * (major + minor) the process had incurred by the end of the multiplication.
 *
 * The matrices are stored 1-based in rows/columns 1..n of 100 x 100 arrays,
 * so n must be between 1 and 99. Returns 1 on invalid input.
 */
int main()
{
    int n;
    int i, j;
    int arr1[100][100];
    int arr2[100][100];
    int arr3[100][100];
    struct rusage usage;

    /* Indices run up to n, so n >= 100 would write past the arrays. */
    if (scanf("%d", &n) != 1 || n < 1 || n > 99) {
        fprintf(stderr, "n must be an integer between 1 and 99\n");
        return 1;
    }

    for (i = 1; i <= n; i++) {
        for (j = 1; j <= n; j++) {
            if (scanf("%d", &arr1[i][j]) != 1)
                return 1;
        }
    }
    for (i = 1; i <= n; i++) {
        for (j = 1; j <= n; j++) {
            if (scanf("%d", &arr2[i][j]) != 1)
                return 1;
        }
    }

    strassen(arr1, arr2, arr3, n);

    /* Sample resource usage after the multiplication, so the reported page
       faults include the work being measured. */
    getrusage(RUSAGE_SELF, &usage);

    for (i = 1; i <= n; i++) {
        for (j = 1; j <= n; j++) {
            printf("%d ", arr3[i][j]);
        }
        printf("\n");
    }
    printf(" the number of page faults are %ld ", usage.ru_majflt + usage.ru_minflt);
    return 0;
}
// Multiplies two R numeric matrices with strassen() and returns the product
// as a new NumericMatrix with nrow(xs) rows and ncol(ys) columns.
// Raises an R error (via stop) when ncol(xs) != nrow(ys).
NumericMatrix strassenMM(SEXP xs, SEXP ys) {
    double * X = as<double*>(xs);
    double * Y = as<double*>(ys);
    // Dimension xs & ys
    NumericMatrix xx(xs);
    NumericMatrix yy(ys);
    int n = xx.nrow(), m = yy.ncol();
    if (xx.ncol() != yy.nrow()) {
        stop("invalid dimenstion of matrices");
    }
    int p = xx.ncol();
    // Result buffer. The trailing "()" value-initialises it to 0.0, so
    // strassen() never reads indeterminate values if it accumulates into Z.
    double * Z = new double[m * n]();
    // Matrix production
    strassen(X, Y, Z, n, p, m);
    NumericMatrix zz = wrap(Z, n, m);
    // zz is an R object with its own storage, so the scratch buffer can go.
    delete[] Z;
    return zz;
}
/*
 * Multiply-accumulate on square sub-blocks of 100 x 100 arrays:
 *   c[cr+i][cc+j] += sum over k of a[ar+i][ac+k] * b[br+k][bc+j],  0 <= i,j,k < n.
 * An even n is split into four n/2 x n/2 quadrants (eight recursive
 * products). An odd n, including 1, is handled with a direct triple loop.
 * Works on index offsets, so no temporary arrays are placed on the stack.
 */
static void dc_mult_acc(int a[100][100], int ar, int ac,
                        int b[100][100], int br, int bc,
                        int c[100][100], int cr, int cc, int n)
{
    if (n % 2 != 0) {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                int sum = 0;
                for (int k = 0; k < n; k++)
                    sum += a[ar + i][ac + k] * b[br + k][bc + j];
                c[cr + i][cc + j] += sum;
            }
        }
        return;
    }

    const int h = n / 2;
    /* C11 += A11*B11 + A12*B21 */
    dc_mult_acc(a, ar, ac, b, br, bc, c, cr, cc, h);
    dc_mult_acc(a, ar, ac + h, b, br + h, bc, c, cr, cc, h);
    /* C12 += A11*B12 + A12*B22 */
    dc_mult_acc(a, ar, ac, b, br, bc + h, c, cr, cc + h, h);
    dc_mult_acc(a, ar, ac + h, b, br + h, bc + h, c, cr, cc + h, h);
    /* C21 += A21*B11 + A22*B21 */
    dc_mult_acc(a, ar + h, ac, b, br, bc, c, cr + h, cc, h);
    dc_mult_acc(a, ar + h, ac + h, b, br + h, bc, c, cr + h, cc, h);
    /* C22 += A21*B12 + A22*B22 */
    dc_mult_acc(a, ar + h, ac, b, br, bc + h, c, cr + h, cc + h, h);
    dc_mult_acc(a, ar + h, ac + h, b, br + h, bc + h, c, cr + h, cc + h, h);
}

/*
 * Computes c = a * b for n x n integer matrices stored 1-based: rows and
 * columns 1..n of the 100 x 100 arrays are used and index 0 is ignored.
 *
 * Despite its name, this is the classical divide-and-conquer scheme with
 * eight half-size products per level, not Strassen's seven-product scheme.
 * Any 1 <= n <= 99 is supported. For any other n, c is left untouched.
 * Only c[1..n][1..n] is written. c must not alias a or b.
 */
void strassen(int a[100][100],int b[100][100],int c[100][100],int n)
{
    if (n < 1 || n > 99)
        return;

    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= n; j++)
            c[i][j] = 0;

    dc_mult_acc(a, 1, 1, b, 1, 1, c, 1, 1, n);
}
void strassen(int a[][num], int b[][num], int c[][num], int size) { int p1[size/2][size/2], p2[size/2][size/2], p3[size/2][size/2], p4[size/2][size/2], p5[size/2][size/2], p6[size/2][size/2], p7[size/2][size/2]; int temp1[size/2][size/2], temp2[size/2][size/2]; int q1, q2, q3, q4, q5, q6, q7, i, j; if(size >= 2) { //give recursive calls //p1 for(i = 0; i < size / 2; i++) { for(j = 0; j < size / 2; j++) { temp1[i][j] = a[i][j] + a[i + size / 2][j + size / 2]; } } for(i = 0; i < size / 2; i++) { for(j = 0; j < size / 2; j++) { temp2[i][j] = b[i][j] + b[i + size / 2][j + size / 2]; } } num = size / 2; strassen(temp1, temp2, p1, size / 2); //p2 for(i = 0; i < size / 2; i++) { for(j = 0; j < size / 2; j++) { temp1[i][j] = a[i + size / 2][j] + a[i + size / 2][j + size / 2]; } } for(i = 0; i < size / 2; i++) { for(j = 0; j < size / 2; j++) { temp2[i][j] = b[i][j]; } } num = size / 2; strassen(temp1, temp2, p2, size / 2); //p3 for(i = 0; i < size / 2; i++) { for(j = 0; j < size / 2; j++) { temp1[i][j] = a[i][j]; } } for(i = 0; i < size / 2; i++) { for(j = 0; j < size / 2; j++) { temp2[i][j] = b[i][j + size / 2] - b[i + size / 2][j + size / 2]; } } num = size / 2; strassen(temp1, temp2, p3, size / 2); //p4 for(i = 0; i < size / 2; i++) { for(j = 0; j < size / 2; j++) { temp1[i][j] = a[i + size / 2][j + size / 2]; } } for(i = 0; i < size / 2; i++) { for(j = 0; j < size / 2; j++) { temp2[i][j] = b[i + size / 2][j] - b[i][j]; } } num = size / 2; strassen(temp1, temp2, p4, size / 2); //p5 for(i = 0; i < size / 2; i++) { for(j = 0; j < size / 2; j++) { temp1[i][j] = a[i][j] + a[i][j + size / 2]; } } for(i = 0; i < size / 2; i++) { for(j = 0; j < size / 2; j++) { temp2[i][j] = b[i + size / 2][j + size / 2]; } } num = size / 2; strassen(temp1, temp2, p5, size / 2); //p6 for(i = 0; i < size / 2; i++) { for(j = 0; j < size / 2; j++) { temp1[i][j] = a[i + size / 2][j] - a[i][j]; } }num = size / 2; for(i = 0; i < size / 2; i++) { for(j = 0; j < size / 2; j++) { temp2[i][j] = b[i][j] + 
b[i][j + size / 2]; } } num = size / 2; strassen(temp1, temp2, p6, size / 2); //p7 for(i = 0; i < size / 2; i++) { for(j = 0; j < size / 2; j++) { temp1[i][j] = a[i][j + size / 2] - a[i + size / 2][j + size / 2]; } } for(i = 0; i < size / 2; i++) { for(j = 0; j < size / 2; j++) { temp2[i][j] = b[i + size / 2][j] + b[i + size / 2][j + size / 2]; } } num = size / 2; strassen(temp1, temp2, p7, size / 2); //c11 for(i = 0; i < size / 2; i++) { for(j = 0; j < size / 2; j++) { c[i][j] = p1[i][j] + p4[i][j] - p5[i][j] + p7[i][j]; } } //c12 for(i = 0; i < size / 2; i++) { for(j = 0; j < size / 2; j++) { c[i][j + size / 2] = p3[i][j] + p5[i][j]; } } //c21 for(i = 0; i < size / 2; i++) { for(j = 0; j < size / 2; j++) { c[i + size / 2][j] = p2[i][j] + p4[i][j]; } } //c22 for(i = 0; i < size / 2; i++) { for(j = 0; j < size / 2; j++) { c[i + size / 2][j + size / 2] = p1[i][j] + p3[i][j] - p2[i][j] + p6[i][j]; } } } else if(size == 1) { c[0][0] = a[0][0] * b[0][0]; } }
/*
 * Reads a size and two square integer matrices of that size from stdin and
 * multiplies them with strassen(). Prints the zero-padded operands and the
 * top-left temp x temp block of the product.
 *
 * The matrices are stored in num x num VLAs after num = padding(num). Rows
 * and columns beyond the user's size (temp) are filled with 0.
 * NOTE(review): padding() is assumed to round num up to the size strassen()
 * requires (a power of two) — confirm.
 * strassen() reassigns the global num; the VLA dimensions here were fixed at
 * their declaration, and the final print loop only uses temp.
 * A non-positive size makes the program return 0 without doing anything.
 * scanf() return values are not checked.
 */
int main() {
    int i, j, temp;
    printf("Enter the size of nxn matrix:\n");
    scanf("%d", &num);
    temp = num;   /* the user's size, kept before padding */
    if(num <= 0) return 0;
    num = padding(num);
    int a[num][num], b[num][num], c[num][num];
    printf("Enter matrix a:\n");
    //accept inputs for a and b from the user; the padding rows/columns are zeroed
    for(i = 0; i < temp; i++) {
        for(j = 0; j < temp; j++) {
            scanf("%d", &a[i][j]);
        }
        for(j = temp; j < num; j++) {
            a[i][j] = 0;
        }
    }
    for(i = temp; i < num; i++)
        for(j = 0; j < num; j++)
            a[i][j] = 0;
    printf("\nEnter matrix b:\n");
    for(i = 0; i < temp; i++) {
        for(j = 0; j < temp; j++) {
            scanf("%d", &b[i][j]);
        }
        for(j = temp; j < num; j++) {
            b[i][j] = 0;
        }
    }
    for(i = temp; i < num; i++)
        for(j = 0; j < num; j++)
            b[i][j] = 0;
    printf("Matrix a:\n");
    //printing the actual (padded) matrices used for strassen's multiplication
    for(i = 0; i < num; i++) {
        for(j = 0; j < num; j++) {
            printf("%d ", a[i][j]);
        }
        printf("\n");
    }
    printf("\nMatrix b:\n");
    for(i = 0; i < num; i++) {
        for(j = 0; j < num; j++) {
            printf("%d ", b[i][j]);
        }
        printf("\n");
    }
    strassen(a, b, c, num);
    printf("\nMatrix c is:\n");
    /* Only the user's temp x temp block of the padded product is shown. */
    for(i = 0; i < temp; i++) {
        for(j = 0; j < temp; j++) {
            printf("%d ", c[i][j]);
        }
        printf("\n");
    }
    return 0;
    // NOTE(review): the closing brace of main is not part of this chunk.
/*
 * Strassen multiplication (Winograd form: 7 products) C = A * B for
 * size x size matrices stored as row-pointer arrays (X[i][j]).
 *
 * Recursion stops when size == FIXEDSIZE, where matrixMultiplicationTiled()
 * does the work.
 * NOTE(review): size must be FIXEDSIZE * 2^k; any other size never reaches
 * the base case.
 *
 * The helper semantics below are inferred from how their results are
 * named in this function — NOTE(review): confirm against their definitions:
 *   addNew(x, y, z)   z = x + y
 *   subNew(x, y, z)   z = x - y
 *   subRight(x, y)    y = x - y
 *   addLeft(x, y)     x = x + y
 * The helpers receive *X (the first row pointer), and buffers are released
 * with free(X[0]); free(X). This presumes getArray() stores each matrix
 * contiguously starting at X[0].
 *
 * To save memory, buffers are reused once their previous content is dead.
 * The alias names (S3 = A21, T1 = A11, P2 = B11, ...) say which value each
 * buffer currently holds.
 */
void strassen(my_type **A, my_type **B, my_type **C, size_t size) {
    /* Leaf: hand the block to the tiled classical kernel. */
    if (size == FIXEDSIZE) {
        matrixMultiplicationTiled(A, B, C, size);
        return;
    }
    /* Alternative leaf kernels, currently disabled: */
    // if (size == FIXEDSIZE) {
    //     matrixMultiplicationFixed(A, B, C);
    //     return;
    // }
    // if (size == FIXEDSIZE) {
    //     asmMul(*A, *B, *C);
    //     return;
    // }
    // if (size == FIXEDSIZE) {
    //     asmMul32(*A, *B, *C);
    //     return;
    // }
    size_t mid = size / 2;

    /* Copy the four quadrants of A and B into separate mid x mid buffers. */
    my_type **A11 = getArray(mid);
    my_type **A12 = getArray(mid);
    my_type **A21 = getArray(mid);
    my_type **A22 = getArray(mid);
    my_type **B11 = getArray(mid);
    my_type **B12 = getArray(mid);
    my_type **B21 = getArray(mid);
    my_type **B22 = getArray(mid);
    for (size_t i = 0; i < mid; i++) {
        for (size_t j = 0; j < mid; j++) {
            A11[i][j] = A[i][j];
            A12[i][j] = A[i][j + mid];
            A21[i][j] = A[i + mid][j];
            A22[i][j] = A[i + mid][j + mid];
            B11[i][j] = B[i][j];
            B12[i][j] = B[i][j + mid];
            B21[i][j] = B[i + mid][j];
            B22[i][j] = B[i + mid][j + mid];
        }
    }

    /* Operand sums:
       S1 = A21 + A22, S2 = S1 - A11, S3 = A11 - A21, S4 = A12 - S2
       T1 = B12 - B11, T2 = B22 - T1, T3 = B22 - B12, T4 = T2 - B21 */
    my_type **S1 = getArray(mid);
    my_type **S2 = getArray(mid);
    addNew(*A21, *A22, *S1, mid);
    subNew(*S1, *A11, *S2, mid);
    subRight(*A11, *A21, mid);
    my_type **S3 = A21;              /* A21 now holds S3 */
    /* P1 = A11 * B11 is computed now, before A11 is reused for T1. */
    my_type **P1 = getArray(mid);
    strassen(A11, B11, P1, mid);
    my_type **T1 = A11;              /* A11 now holds T1 */
    subNew(*B12, *B11, *T1, mid);
    my_type **T2 = getArray(mid);
    subNew(*B22, *T1, *T2, mid);
    subRight(*B22, *B12, mid);
    my_type **T3 = B12;              /* B12 now holds T3 */
    my_type **S4 = getArray(mid);
    my_type **T4 = getArray(mid);
    subNew(*A12, *S2, *S4, mid);
    subNew(*T2, *B21, *T4, mid);

    /* Remaining products, each written into a buffer whose old content is
       no longer needed:
       P2 = A12*B21 -> B11, P3 = S4*B22 -> B21, P4 = A22*T4 -> A12,
       P5 = S1*T1  -> S4,  P6 = S2*T2  -> B22, P7 = S3*T3  -> T4 */
    strassen(A12, B21, B11, mid);
    my_type **P2 = B11;
    strassen(S4, B22, B21, mid);
    my_type **P3 = B21;
    strassen(A22, T4, A12, mid);
    free(A22[0]);
    free(A22);
    my_type **P4 = A12;
    strassen(S1, T1, S4, mid);
    free(T1[0]);
    free(T1);
    free(S1[0]);
    free(S1);
    my_type **P5 = S4;
    strassen(S2, T2, B22, mid);
    free(S2[0]);
    free(S2);
    free(T2[0]);
    free(T2);
    my_type **P6 = B22;
    strassen(S3, T3, T4, mid);
    my_type **P7 = T4;
    free(S3[0]);
    free(S3);
    free(T3[0]);
    free(T3);
    /* Released so far: A22, T1 (A11), S1, S2, T2, S3 (A21), T3 (B12).
       P6 and P3 are released below once folded into the sums. */

    /* Combination:
       U1 = P1 + P2, U2 = P1 + P6, U3 = U2 + P7, U4 = U2 + P5,
       U5 = U4 + P3, U6 = U3 - P4, U7 = U3 + P5
       C11 = U1, C12 = U5, C21 = U6, C22 = U7 */
    addLeft(*P2, *P1, mid);
    my_type **U1 = P2;
    addLeft(*P1, *P6, mid);
    free(P6[0]);
    free(P6);
    my_type **U2 = P1;
    addLeft(*P7, *U2, mid);
    my_type **U3 = P7;
    addLeft(*U2, *P5, mid);
    my_type **U4 = U2;
    addLeft(*U4, *P3, mid);
    free(P3[0]);
    free(P3);
    my_type **U5 = U4;
    subRight(*U3, *P4, mid);
    my_type **U6 = P4;
    addLeft(*U3, *P5, mid);
    my_type **U7 = U3;
    free(P5[0]);
    free(P5);

    /* Scatter the four result quadrants into C. */
    for (size_t i = 0; i < mid; i++) {
        for (size_t j = 0; j < mid; j++) {
            C[i][j] = U1[i][j];
            C[i][j + mid] = U5[i][j];
            C[i + mid][j] = U6[i][j];
            C[i + mid][j + mid] = U7[i][j];
        }
    }
    free(U1[0]);
    free(U1);
    free(U5[0]);
    free(U5);
    free(U6[0]);
    free(U6);
    free(U7[0]);
    free(U7);
    return;
}
/*
 * Strassen multiplication c = a * b for tam x tam matrices of doubles stored
 * as row-pointer arrays (x[i][j]). c is overwritten.
 *
 * The following sizes are multiplied directly with a triple loop:
 *   - sizes up to BREAK;
 *   - odd sizes, since halving them would drop the last row and column.
 * Larger even sizes are split into four quadrants, and c is assembled from
 * seven recursive products p1..p7. Temporaries are obtained with
 * allocate_real_matrix(newTam, 0) and released with free_real_matrix().
 */
void strassen(double **a, double **b, double **c, int tam)
{
    /* Base case. */
    if (tam <= BREAK || tam < 2 || tam % 2 != 0) {
        if (tam == 1) {
            c[0][0] = a[0][0] * b[0][0];
            return;
        }
        int i, j, k;
        for (i = 0; i < tam; i++) {
            /* The i-k-j loop accumulates, so clear row i first: the result
               must not depend on what c held on entry (the 1 x 1 case and
               the recursive case both overwrite c). */
            for (j = 0; j < tam; j++) {
                c[i][j] = 0.0;
            }
            for (k = 0; k < tam; k++) {
                for (j = 0; j < tam; j++) {
                    c[i][j] += a[i][k] * b[k][j];
                }
            }
        }
        return;
    }

    // other cases are treated here:
    int newTam = tam / 2;
    double **a11, **a12, **a21, **a22;
    double **b11, **b12, **b21, **b22;
    double **p1, **p2, **p3, **p4, **p5, **p6, **p7;

    // memory allocation:
    a11 = allocate_real_matrix(newTam, 0);
    a12 = allocate_real_matrix(newTam, 0);
    a21 = allocate_real_matrix(newTam, 0);
    a22 = allocate_real_matrix(newTam, 0);
    b11 = allocate_real_matrix(newTam, 0);
    b12 = allocate_real_matrix(newTam, 0);
    b21 = allocate_real_matrix(newTam, 0);
    b22 = allocate_real_matrix(newTam, 0);
    p1 = allocate_real_matrix(newTam, 0);
    p2 = allocate_real_matrix(newTam, 0);
    p3 = allocate_real_matrix(newTam, 0);
    p4 = allocate_real_matrix(newTam, 0);
    p5 = allocate_real_matrix(newTam, 0);
    p6 = allocate_real_matrix(newTam, 0);
    p7 = allocate_real_matrix(newTam, 0);
    double **aResult = allocate_real_matrix(newTam, 0);
    double **bResult = allocate_real_matrix(newTam, 0);

    int i, j;

    // dividing the matrices in 4 sub-matrices:
    for (i = 0; i < newTam; i++) {
        for (j = 0; j < newTam; j++) {
            a11[i][j] = a[i][j];
            a12[i][j] = a[i][j + newTam];
            a21[i][j] = a[i + newTam][j];
            a22[i][j] = a[i + newTam][j + newTam];

            b11[i][j] = b[i][j];
            b12[i][j] = b[i][j + newTam];
            b21[i][j] = b[i + newTam][j];
            b22[i][j] = b[i + newTam][j + newTam];
        }
    }

    // Calculating p1 to p7:
    sum(a11, a22, aResult, newTam);            // a11 + a22
    sum(b11, b22, bResult, newTam);            // b11 + b22
    strassen(aResult, bResult, p1, newTam);    // p1 = (a11+a22) * (b11+b22)

    sum(a21, a22, aResult, newTam);            // a21 + a22
    strassen(aResult, b11, p2, newTam);        // p2 = (a21+a22) * (b11)

    subtract(b12, b22, bResult, newTam);       // b12 - b22
    strassen(a11, bResult, p3, newTam);        // p3 = (a11) * (b12 - b22)

    subtract(b21, b11, bResult, newTam);       // b21 - b11
    strassen(a22, bResult, p4, newTam);        // p4 = (a22) * (b21 - b11)

    sum(a11, a12, aResult, newTam);            // a11 + a12
    strassen(aResult, b22, p5, newTam);        // p5 = (a11+a12) * (b22)

    subtract(a21, a11, aResult, newTam);       // a21 - a11
    sum(b11, b12, bResult, newTam);            // b11 + b12
    strassen(aResult, bResult, p6, newTam);    // p6 = (a21-a11) * (b11+b12)

    subtract(a12, a22, aResult, newTam);       // a12 - a22
    sum(b21, b22, bResult, newTam);            // b21 + b22
    strassen(aResult, bResult, p7, newTam);    // p7 = (a12-a22) * (b21+b22)

    // Assemble c directly from the products, with no per-quadrant temporaries:
    //   c11 = ((p1 + p4) + p7) - p5    c12 = p3 + p5
    //   c21 = p2 + p4                  c22 = ((p1 + p3) + p6) - p2
    for (i = 0; i < newTam; i++) {
        for (j = 0; j < newTam; j++) {
            c[i][j] = ((p1[i][j] + p4[i][j]) + p7[i][j]) - p5[i][j];
            c[i][j + newTam] = p3[i][j] + p5[i][j];
            c[i + newTam][j] = p2[i][j] + p4[i][j];
            c[i + newTam][j + newTam] = ((p1[i][j] + p3[i][j]) + p6[i][j]) - p2[i][j];
        }
    }

    // deallocating memory (free):
    a11 = free_real_matrix(a11, newTam);
    a12 = free_real_matrix(a12, newTam);
    a21 = free_real_matrix(a21, newTam);
    a22 = free_real_matrix(a22, newTam);

    b11 = free_real_matrix(b11, newTam);
    b12 = free_real_matrix(b12, newTam);
    b21 = free_real_matrix(b21, newTam);
    b22 = free_real_matrix(b22, newTam);

    p1 = free_real_matrix(p1, newTam);
    p2 = free_real_matrix(p2, newTam);
    p3 = free_real_matrix(p3, newTam);
    p4 = free_real_matrix(p4, newTam);
    p5 = free_real_matrix(p5, newTam);
    p6 = free_real_matrix(p6, newTam);
    p7 = free_real_matrix(p7, newTam);
    aResult = free_real_matrix(aResult, newTam);
    bResult = free_real_matrix(bResult, newTam);
} // end of Strassen function
/*
 * Returns a view of one quadrant of m. row_off and col_off are 0 or half,
 * and select the block that starts that many rows/columns into m. Only the
 * inclusive bounds change; the view shares m.arr.
 */
static matrix strassen_quadrant(matrix m, int row_off, int col_off, int half)
{
    matrix q = m;
    q.lrow = m.lrow + row_off;
    q.urow = q.lrow + half - 1;
    q.lcol = m.lcol + col_off;
    q.ucol = q.lcol + half - 1;
    return q;
}

/* Only works for matrices of size 2^n.
 *
 * Computes c = a * b with Strassen's seven-product recursion. a, b and c are
 * views into arrays: arr plus inclusive bounds lrow..urow and lcol..ucol.
 * The quadrants are views into the same storage, so the result is written
 * straight into c.arr. Every element of c is overwritten.
 */
void strassen(matrix a, matrix b, matrix c)
{
    if (a.urow == a.lrow) {
        /* 1 x 1: assign rather than accumulate, matching the recursive case,
           which writes each quadrant of c through sum(). Otherwise whatever
           c (or a fresh make_matrix() buffer) held would leak into the
           result. */
        c.arr[c.lrow][c.lcol] = a.arr[a.lrow][a.lcol] * b.arr[b.lrow][b.lcol];
        return;
    }

    int newsize = (a.urow - a.lrow + 1) / 2;

    matrix a11 = strassen_quadrant(a, 0, 0, newsize);
    matrix a12 = strassen_quadrant(a, 0, newsize, newsize);
    matrix a21 = strassen_quadrant(a, newsize, 0, newsize);
    matrix a22 = strassen_quadrant(a, newsize, newsize, newsize);

    matrix b11 = strassen_quadrant(b, 0, 0, newsize);
    matrix b12 = strassen_quadrant(b, 0, newsize, newsize);
    matrix b21 = strassen_quadrant(b, newsize, 0, newsize);
    matrix b22 = strassen_quadrant(b, newsize, newsize, newsize);

    matrix c11 = strassen_quadrant(c, 0, 0, newsize);
    matrix c12 = strassen_quadrant(c, 0, newsize, newsize);
    matrix c21 = strassen_quadrant(c, newsize, 0, newsize);
    matrix c22 = strassen_quadrant(c, newsize, newsize, newsize);

    matrix m1 = make_matrix(newsize, newsize);
    matrix m2 = make_matrix(newsize, newsize);
    matrix m3 = make_matrix(newsize, newsize);
    matrix m4 = make_matrix(newsize, newsize);
    matrix m5 = make_matrix(newsize, newsize);
    matrix m6 = make_matrix(newsize, newsize);
    matrix m7 = make_matrix(newsize, newsize);
    matrix tmp1 = make_matrix(newsize, newsize);
    matrix tmp2 = make_matrix(newsize, newsize);

    /* m1 = (a11 + a22)(b11 + b22) */
    sum(a11, a22, tmp1);
    sum(b11, b22, tmp2);
    strassen(tmp1, tmp2, m1);

    /* m2 = (a21 + a22)b11 */
    sum(a21, a22, tmp1);
    strassen(tmp1, b11, m2);

    /* m3 = a11(b12 - b22) */
    diff(b12, b22, tmp2);
    strassen(a11, tmp2, m3);

    /* m4 = a22(b21 - b11) */
    diff(b21, b11, tmp2);
    strassen(a22, tmp2, m4);

    /* m5 = (a11 + a12)b22 */
    sum(a11, a12, tmp1);
    strassen(tmp1, b22, m5);

    /* m6 = (a21 - a11)(b11 + b12) */
    diff(a21, a11, tmp1);
    sum(b11, b12, tmp2);
    strassen(tmp1, tmp2, m6);

    /* m7 = (a12 - a22)(b21 + b22) */
    diff(a12, a22, tmp1);
    sum(b21, b22, tmp2);
    strassen(tmp1, tmp2, m7);

    /* c11 = (m1 + m4) + (m7 - m5) */
    sum(m1, m4, tmp1);
    diff(m7, m5, tmp2);
    sum(tmp1, tmp2, c11);

    /* c12 = m3 + m5 */
    sum(m3, m5, c12);

    /* c21 = m2 + m4 */
    sum(m2, m4, c21);

    /* c22 = (m1 - m2) + (m3 + m6) */
    diff(m1, m2, tmp1);
    sum(m3, m6, tmp2);
    sum(tmp1, tmp2, c22);

    /* release the temporaries */
    free_matrix(m1);
    free_matrix(m2);
    free_matrix(m3);
    free_matrix(m4);
    free_matrix(m5);
    free_matrix(m6);
    free_matrix(m7);
    free_matrix(tmp1);
    free_matrix(tmp2);
}
/* Strassen implementation of Matrix Matrix Multiplication (7 products,
   Winograd-style operand sums S1..S4 / T1..T4), vectorised with __m256d.

   Products computed (operands are dim2 x dim2 quadrants):
     P1 = A1*B1, P2 = A2*B3, P3 = S1*T1, P4 = S2*T2,
     P5 = S3*T3, P6 = S4*B4, P7 = A4*T4
   Result quadrants:
     C1 = P1 + P2
     C2 = P1 + P4 + P3 + P6
     C3 = P1 + P4 + P5 + P7
     C4 = P1 + P4 + P5 + P3
   T4 is computed as B3 - T2 (the negation of the usual Winograd T4), which
   is why P7 enters C3 with a plus sign.

   The inner loops advance j by 4, one __m256d (4 doubles) at a time.
   NOTE(review): this presumes dim2 is a multiple of 4 — confirm.
   dim2 < TRUNCATION_POINT uses naive(); otherwise each product recurses.
   The recursive calls borrow C1, C2, C3 as scratch (P, Ps, S), which is
   fine because C is assembled only afterwards. They also borrow B2 as the
   column-major scratch T, so the B2 quadrant of B is overwritten.

   @param A input row major matrix
   @param B input COLUMN MAJOR matrix (its B2 quadrant is clobbered when recursing)
   @param C output row major matrix
   @param P temporary row major matrix (holds P1..P4 in its quadrants)
   @param Ps temporary row major matrix (holds P5..P7 in three quadrants)
   @param S temporary row major matrix
   @param T temporary COLUMN MAJOR matrix
*/
void strassen(Matrix& A, Matrix& B, Matrix& C, Matrix& P, Matrix& Ps, Matrix& S, Matrix& T){
    //get matrix dimensions and calculate size of blocks
    int dim = A.getDimM(); //equal for all matrix dimensions
    int dim2 = dim * 0.5;
    //std::cout << dim << "\t" << dim2 << std::endl;

    //get blocks (1 = top-left, 2 = top-right, 3 = bottom-left, 4 = bottom-right,
    //judging by the getSubMatrix offsets)
    Matrix A1 = A.getSubMatrix(0, 0, dim2, dim2);
    Matrix A2 = A.getSubMatrix(0, dim2, dim2, dim2);
    Matrix A3 = A.getSubMatrix(dim2, 0, dim2, dim2);
    Matrix A4 = A.getSubMatrix(dim2, dim2, dim2, dim2);

    Matrix B1 = B.getSubMatrix(0, 0, dim2, dim2);
    Matrix B2 = B.getSubMatrix(0, dim2, dim2, dim2);
    Matrix B3 = B.getSubMatrix(dim2, 0, dim2, dim2);
    Matrix B4 = B.getSubMatrix(dim2, dim2, dim2, dim2);

    Matrix C1 = C.getSubMatrix(0, 0, dim2, dim2);
    Matrix C2 = C.getSubMatrix(0, dim2, dim2, dim2);
    Matrix C3 = C.getSubMatrix(dim2, 0, dim2, dim2);
    Matrix C4 = C.getSubMatrix(dim2, dim2, dim2, dim2);

    Matrix S1 = S.getSubMatrix(0, 0, dim2, dim2);
    Matrix S2 = S.getSubMatrix(0, dim2, dim2, dim2);
    Matrix S3 = S.getSubMatrix(dim2, 0, dim2, dim2);
    Matrix S4 = S.getSubMatrix(dim2, dim2, dim2, dim2);

    Matrix T1 = T.getSubMatrix(0, 0, dim2, dim2);
    Matrix T2 = T.getSubMatrix(0, dim2, dim2, dim2);
    Matrix T3 = T.getSubMatrix(dim2, 0, dim2, dim2);
    Matrix T4 = T.getSubMatrix(dim2, dim2, dim2, dim2);

    Matrix P1 = P.getSubMatrix(0, 0, dim2, dim2);
    Matrix P2 = P.getSubMatrix(0, dim2, dim2, dim2);
    Matrix P3 = P.getSubMatrix(dim2, 0, dim2, dim2);
    Matrix P4 = P.getSubMatrix(dim2, dim2, dim2, dim2);
    Matrix P5 = Ps.getSubMatrix(0, 0, dim2, dim2);
    Matrix P6 = Ps.getSubMatrix(0, dim2, dim2, dim2);
    Matrix P7 = Ps.getSubMatrix(dim2, 0, dim2, dim2);
    //Matrix P8 = Ps.getSubMatrix(dim2, dim2, dim2, dim2);
    //std::cout << "submatrices" << std::endl;

    //compute temporary S and T matrices, 4 doubles per step
    for (int i=0; i<dim2; ++i){
        for (int j=0; j<dim2; j+=4){
            //std::cout << "S" << std::endl;
            __m256d* pA1 = A1.get(i, j);
            __m256d* pA2 = A2.get(i, j);
            __m256d* pA3 = A3.get(i, j);
            __m256d* pA4 = A4.get(i, j);
            // Taken before S2 is set; presumably points into S2's storage,
            // so the S4 line below reads the freshly written S2.
            __m256d* pS2 = S2.get(i, j);
            S1.set(i, j, (*pA3)+(*pA4));
            S2.set(i, j, (*pA3)+(*pA4)-(*pA1));
            S3.set(i, j, (*pA1)-(*pA3));
            S4.set(i, j, (*pA2)-(*pS2));
            // These increments have no effect: the pointers are not used again.
            pA1++;
            pA2++;
            pA3++;
            pA4++;
            pS2++;
            //S1(i, j) = A3(i, j) + A4(i, j);
            //S2(i, j) = S1(i, j) - A1(i, j);
            //S3(i, j) = A1(i ,j) - A3(i, j);
            //S4(i, j) = A2(i, j) - S2(i, j);

            //std::cout << "T" << std::endl;
            // B and T are column major, hence the transposed (j, i) accessors.
            __m256d* pB1 = B1.getT(j, i);
            __m256d* pB2 = B2.getT(j, i);
            __m256d* pB3 = B3.getT(j, i);
            __m256d* pB4 = B4.getT(j, i);
            __m256d* pT2 = T2.getT(j, i);
            T1.setT(j, i, (*pB2)-(*pB1));
            T2.setT(j, i, (*pB4)-((*pB2)-(*pB1)));
            T3.setT(j, i, (*pB4)-(*pB2));
            T4.setT(j, i, (*pB3)-(*pT2));
            // As above, these increments have no effect.
            pB1++;
            pB2++;
            pB3++;
            pB4++;
            pT2++;
            //T1.T(j, i) = B2.T(j, i) - B1.T(j, i);
            //T2.T(j, i) = B4.T(j, i) - T1.T(j, i);
            //T3.T(j, i) = B4.T(j ,i) - B2.T(j, i);
            //T4.T(j, i) = B3.T(j, i) - T2.T(j, i);
        }
    }

    //calculate products
    if (dim2<TRUNCATION_POINT) {
        naive(A1, B1, P1);
        naive(A2, B3, P2);
        naive(S1, T1, P3);
        naive(S2, T2, P4);
        naive(S3, T3, P5);
        naive(S4, B4, P6);
        naive(A4, T4, P7);
    } else {
        // Scratch for the recursion: C1, C2, C3 (not yet assembled) and B2
        // (no longer read once T1..T3 exist).
        strassen(A1, B1, P1, C1, C2, C3, B2);
        strassen(A2, B3, P2, C1, C2, C3, B2);
        strassen(S1, T1, P3, C1, C2, C3, B2);
        strassen(S2, T2, P4, C1, C2, C3, B2);
        strassen(S3, T3, P5, C1, C2, C3, B2);
        strassen(S4, B4, P6, C1, C2, C3, B2);
        strassen(A4, T4, P7, C1, C2, C3, B2);
    }

    //assemble final matrix
    for (int i=0; i<dim2; ++i){
        for (int j=0; j<dim2; j+=4){
            __m256d* pP1 = P1.get(i, j);
            __m256d* pP2 = P2.get(i, j);
            __m256d* pP3 = P3.get(i, j);
            __m256d* pP4 = P4.get(i, j);
            __m256d* pP5 = P5.get(i, j);
            __m256d* pP6 = P6.get(i, j);
            __m256d* pP7 = P7.get(i, j);
            C1.set(i, j, (*pP1) + (*pP2));
            C3.set(i, j, (*pP1) + (*pP4) + (*pP5) + (*pP7));
            C4.set(i, j, (*pP1) + (*pP4) + (*pP5) + (*pP3));
            C2.set(i, j, (*pP1) + (*pP4) + (*pP3) + (*pP6));
            //C1(i, j) = P1(i, j) + P2(i, j);
            //C3(i, j) = P1(i, j) + P4(i, j) + P5(i, j) + P7(i, j);
            //C4(i, j) = P1(i, j) + P4(i, j) + P5(i, j) + P3(i, j);
            //C2(i, j) = P1(i, j) + P3(i, j) + P4(i, j) + P6(i, j);
            // These increments have no effect: the pointers are not used again.
            pP1++;
            pP2++;
            pP3++;
            pP4++;
            pP5++;
            pP6++;
            pP7++;
        }
    }
}
/* out[i] = x[i] + y[i] over the n * n entries of a flat matrix. */
static void strassen_flat_add(const int *x, const int *y, int *out, int n)
{
    const size_t count = (size_t)n * (size_t)n;
    for (size_t i = 0; i < count; i++)
        out[i] = x[i] + y[i];
}

/* out[i] = x[i] - y[i] over the n * n entries of a flat matrix. */
static void strassen_flat_sub(const int *x, const int *y, int *out, int n)
{
    const size_t count = (size_t)n * (size_t)n;
    for (size_t i = 0; i < count; i++)
        out[i] = x[i] - y[i];
}

/* c = a * b for n x n row-major matrices, using the schoolbook triple loop. */
static void strassen_flat_naive(const int *a, const int *b, int *c, int n)
{
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            int s = 0;
            for (int k = 0; k < n; k++)
                s += a[(size_t)i * n + k] * b[(size_t)k * n + j];
            c[(size_t)i * n + j] = s;
        }
    }
}

/* Sizes at or below this (and odd sizes) use the direct product. */
enum { STRASSEN_FLAT_CUTOFF = 16 };

/*
 * c = a * b for n x n row-major matrices, using Strassen's seven products.
 * c must not alias a or b. If scratch memory cannot be allocated, it falls
 * back to the direct product, so a result is always produced.
 */
static void strassen_flat(const int *a, const int *b, int *c, int n)
{
    if (n <= STRASSEN_FLAT_CUTOFF || n % 2 != 0) {
        strassen_flat_naive(a, b, c, n);
        return;
    }

    const int h = n / 2;
    const size_t q = (size_t)h * (size_t)h;
    const size_t hn = (size_t)h * (size_t)n;
    /* One allocation: 8 operand quadrants, 7 products, 2 scratch sums. */
    int *buf = malloc(17 * q * sizeof *buf);
    if (buf == NULL) {
        strassen_flat_naive(a, b, c, n);
        return;
    }
    int *a11 = buf,     *a12 = a11 + q, *a21 = a12 + q, *a22 = a21 + q;
    int *b11 = a22 + q, *b12 = b11 + q, *b21 = b12 + q, *b22 = b21 + q;
    int *p1 = b22 + q, *p2 = p1 + q, *p3 = p2 + q, *p4 = p3 + q;
    int *p5 = p4 + q, *p6 = p5 + q, *p7 = p6 + q;
    int *ta = p7 + q, *tb = ta + q;

    /* Split a and b into their four quadrants. */
    for (int i = 0; i < h; i++) {
        for (int j = 0; j < h; j++) {
            const size_t src = (size_t)i * n + j;
            const size_t dst = (size_t)i * h + j;
            a11[dst] = a[src];
            a12[dst] = a[src + h];
            a21[dst] = a[src + hn];
            a22[dst] = a[src + hn + h];
            b11[dst] = b[src];
            b12[dst] = b[src + h];
            b21[dst] = b[src + hn];
            b22[dst] = b[src + hn + h];
        }
    }

    /* p1 = (a11 + a22)(b11 + b22) */
    strassen_flat_add(a11, a22, ta, h);
    strassen_flat_add(b11, b22, tb, h);
    strassen_flat(ta, tb, p1, h);
    /* p2 = (a21 + a22) b11 */
    strassen_flat_add(a21, a22, ta, h);
    strassen_flat(ta, b11, p2, h);
    /* p3 = a11 (b12 - b22) */
    strassen_flat_sub(b12, b22, tb, h);
    strassen_flat(a11, tb, p3, h);
    /* p4 = a22 (b21 - b11) */
    strassen_flat_sub(b21, b11, tb, h);
    strassen_flat(a22, tb, p4, h);
    /* p5 = (a11 + a12) b22 */
    strassen_flat_add(a11, a12, ta, h);
    strassen_flat(ta, b22, p5, h);
    /* p6 = (a21 - a11)(b11 + b12) */
    strassen_flat_sub(a21, a11, ta, h);
    strassen_flat_add(b11, b12, tb, h);
    strassen_flat(ta, tb, p6, h);
    /* p7 = (a12 - a22)(b21 + b22) */
    strassen_flat_sub(a12, a22, ta, h);
    strassen_flat_add(b21, b22, tb, h);
    strassen_flat(ta, tb, p7, h);

    /* c11 = p1 + p4 - p5 + p7, c12 = p3 + p5,
       c21 = p2 + p4,           c22 = p1 - p2 + p3 + p6 */
    for (int i = 0; i < h; i++) {
        for (int j = 0; j < h; j++) {
            const size_t k = (size_t)i * h + j;
            const size_t top = (size_t)i * n + j;
            const size_t bottom = top + hn;
            c[top] = p1[k] + p4[k] - p5[k] + p7[k];
            c[top + h] = p3[k] + p5[k];
            c[bottom] = p2[k] + p4[k];
            c[bottom + h] = p1[k] - p2[k] + p3[k] + p6[k];
        }
    }

    free(buf);
}

/*
 * Computes m->result = m->a * m->b for m->size x m->size row-major int
 * matrices. Does nothing when m or one of its buffers is NULL, or when
 * m->size is not positive.
 * NOTE(review): m->result is assumed to already point to
 * m->size * m->size ints that the caller owns and frees — confirm where
 * it is allocated.
 */
void strassen(Matrices *m)
{
    if (m == NULL || m->a == NULL || m->b == NULL || m->result == NULL || m->size <= 0)
        return;
    strassen_flat(m->a, m->b, m->result, m->size);
}
// Strassen multiplication of two square matrices; returns a * b.
// Throws the string literal "size mismatch\n" if the sizes differ.
// n is expected to be a power of two, so that every level splits evenly.
// partition() yields four quadrants. Judging by the formulas below, index
// 0 = top-left, 1 = top-right, 2 = bottom-left, 3 = bottom-right.
// NOTE(review): confirm against partition().
// Inner scopes release the operand copies and the products as soon as they
// are no longer needed.
Matrix strassen(Matrix& a, Matrix& b){
    if (a.size != b.size)
        throw "size mismatch\n";

    const int n = a.size;
    Matrix c;
    c.resize(n);

    if (n == 1) {
        c[0][0] = a[0][0] * b[0][0];
        return c;
    }

    const int half = n / 2;
    Matrix Q[4];  // result quadrants
    {
        Matrix P[7];  // the seven products
        {
            Matrix A[4], B[4], S[10];
            partition(a, A);
            partition(b, B);

            // Operand sums S1..S10.
            S[0] = B[1] - B[3];
            S[1] = A[0] + A[1];
            S[2] = A[2] + A[3];
            S[3] = B[2] - B[0];
            S[4] = A[0] + A[3];
            S[5] = B[0] + B[3];
            S[6] = A[1] - A[3];
            S[7] = B[2] + B[3];
            S[8] = A[0] - A[2];
            S[9] = B[0] + B[1];

            P[0] = strassen(A[0], S[0]);
            P[1] = strassen(S[1], B[3]);
            P[2] = strassen(S[2], B[0]);
            P[3] = strassen(A[3], S[3]);
            P[4] = strassen(S[4], S[5]);
            P[5] = strassen(S[6], S[7]);
            P[6] = strassen(S[8], S[9]);
        }
        Q[0] = P[4] + P[3] - P[1] + P[5];
        Q[1] = P[0] + P[1];
        Q[2] = P[2] + P[3];
        Q[3] = P[4] + P[0] - P[2] - P[6];
    }

    // Quadrant q lands at row offset (q / 2) * half, column offset (q % 2) * half.
    for (int q = 0; q < 4; ++q) {
        const int row0 = (q / 2) * half;
        const int col0 = (q % 2) * half;
        for (int i = 0; i < half; ++i)
            for (int j = 0; j < half; ++j)
                c[row0 + i][col0 + j] = Q[q][i][j];
    }
    return c;
}
int main (int argn, char** argv) { #if 1 int m, n, _n, o; int * A, * B, * C, * D; double t; if (NULL == (A = read_matrix(&m,&n))) { printf("Could not read 1st matrix.\n"); return 1; } if (NULL == (B = read_matrix(&_n,&o))) { printf("Could not read 2nd matrix.\n"); return 1; } else if (_n!=n) { printf("Matrix sizes not compatible.\n"); return 1; } #if PRINT print_matrix(A, m, n); print_matrix(B, n, o); printf("Strassen...\n"); #endif t=clock(); if (NULL == (C = strassen(A,B,m,n,o))) { printf("Multiplication failed.\n"); return 1; } #if PRINT printf("Time : %.3fs\n",(clock()-t)/CLOCKS_PER_SEC); print_matrix(C, m, o); #else printf((MULT_NAIVE)?"%.4f,":"%.4f\n",(clock()-t)/CLOCKS_PER_SEC); #endif #if PRINT printf("Naive mult...\n"); #endif #if MULT_NAIVE==1 t=clock(); if (NULL == (D = naive_mult(A,B,m,n,o))) { printf("Multiplication failed.\n"); return 2; } #if PRINT printf("Time : %.3fs\n",(clock()-t)/CLOCKS_PER_SEC); print_matrix(D, m, o); if (matrix_equal(C,D,m,o)) printf("1 : Okay\n"); else printf("0 : Not Okay\n"); #else printf("%.4f\n",(clock()-t)/CLOCKS_PER_SEC); #endif #endif free(A); free(B); free(C); free(D); A=B=C=D=0; #endif return 0; }