void Sgemm(float alpha, const CMatrix& A, const CMatrix& B, float beta, CMatrix& C) { if( A.GetColCount() != B.GetRowCount() ) { throw string("In Sgemm():执行相乘的两个矩阵维数不满足相乘的条件!"); } if(C.GetRowCount() < A.GetRowCount() || C.GetColCount() < B.GetColCount()) { throw string("目标矩阵不能容纳乘积矩阵\n"); } for(int i=0; i < A.GetRowCount(); i++) { //printf("Row = %d/%d\n", i, m_nRow); for(int j=0; j < B.GetColCount(); j++) { VALTYPE sum = 0; for(int m=0; m < A.GetColCount(); m++) { sum += alpha * A.m_pTMatrix(i,m) * B.m_pTMatrix(m,j); //printf("%f\t%f\n", A.m_pTMatrix(i,m), B.m_pTMatrix(m,j)); } C.m_pTMatrix(i,j) = sum + beta * C.m_pTMatrix(i,j); //printf("---\n"); } } }
/**
 * Concatenates two matrices side by side, producing [cMatrixA | cMatrixB].
 *
 * @param cMatrixA left part; occupies result columns [0, cMatrixA.GetColCount())
 * @param cMatrixB right part; its columns follow immediately after cMatrixA's
 * @return a new matrix with cMatrixA's row count and
 *         cMatrixA.GetColCount() + cMatrixB.GetColCount() columns
 * @throws string if the two matrices have different row counts
 */
CMatrix MergeMatrix(CMatrix& cMatrixA, CMatrix& cMatrixB)
{
    // Precondition: horizontal concatenation needs equal row counts.
    if (cMatrixA.GetRowCount() != cMatrixB.GetRowCount())
    {
        throw string("参与合并的两个矩阵的行数不相等!");
    }

    CMatrix cMatrix(cMatrixA.GetRowCount(),
                    cMatrixA.GetColCount() + cMatrixB.GetColCount());

    // Reuse the block-copy helper: A is placed at column 0, B right after it.
    cMatrixA.CopyTo(cMatrix, 0, 0);
    cMatrixB.CopyTo(cMatrix, 0, cMatrixA.GetColCount());

    return cMatrix;
}
/**
 * Copy constructor: builds a matrix with the same dimensions and element
 * values as cMatrixB.
 *
 * The element storage is first constructed with cMatrixB's dimensions and
 * then assigned from cMatrixB's storage.
 * NOTE(review): constructing m_pTMatrix directly from cMatrixB.m_pTMatrix
 * would avoid the extra sizing step, but only if the storage type's copy
 * constructor performs a deep copy -- confirm before changing.
 *
 * @param cMatrixB matrix to copy
 */
CMatrix::CMatrix(const CMatrix& cMatrixB)
    : m_pTMatrix(cMatrixB.GetRowCount(), cMatrixB.GetColCount())
{
    // Dimensions first, then the element data.
    this->m_nRow = cMatrixB.m_nRow;
    this->m_nCol = cMatrixB.m_nCol;

    this->m_pTMatrix = cMatrixB.m_pTMatrix;
}
/**
 * Scalar-minus-matrix: returns R where R(i,j) = nValue - cMatrixB(i,j).
 *
 * @param nValue   scalar from which every element is subtracted
 * @param cMatrixB matrix whose elements are subtracted from nValue
 * @return a new matrix constructed with cMatrixB's dimensions
 */
CMatrix operator - (double nValue, const CMatrix& cMatrixB)
{
    CMatrix result(cMatrixB.GetRowCount(), cMatrixB.GetColCount());

    const int nRows = result.GetRowCount();
    const int nCols = result.GetColCount();

    for (int row = 0; row < nRows; ++row)
    {
        for (int col = 0; col < nCols; ++col)
        {
            result.m_pTMatrix(row, col) = nValue - cMatrixB.m_pTMatrix(row, col);
        }
    }

    return result;
}
/**
 * Copies this matrix into `matrix` so that this matrix's element (0,0) lands
 * at (startRow, startCol) of the destination.
 *
 * @param matrix   destination; must hold this matrix at the given offset
 * @param startRow row offset in the destination (must be >= 0)
 * @param startCol column offset in the destination (must be >= 0)
 * @throws string if an offset is negative or the copied block would extend
 *                past the destination's row or column count
 */
void CMatrix::CopyTo(CMatrix& matrix, int startRow, int startCol)
{
    // Negative offsets used to pass the bounds test and then index the
    // destination with negative coordinates; reject them up front.
    if (startRow < 0 || startCol < 0 ||
        startCol + m_nCol > matrix.GetColCount() ||
        startRow + m_nRow > matrix.GetRowCount())
    {
        throw string("目标矩阵不能容纳源矩阵");
    }

    for (int i = 0; i < m_nRow; i++)
    {
        for (int j = 0; j < m_nCol; j++)
        {
            matrix.m_pTMatrix(startRow + i, startCol + j) = m_pTMatrix(i, j);
        }
    }
}