/**
 * Fills `_solveReuse` with the statements that solve the system for a new
 * right-hand side `b`, using the matrix `A` and the permutation `rk_perm`
 * as they are at the time the exported routine runs.
 *
 * Emitted steps, in order:
 *   1. `_bPerm[r] = b[rk_perm[r]]` for every row r,
 *   2. `_bPerm[r] += A[r*dim+c] * _bPerm[c]` for every c < r,
 *   3. a call to `_solveTriangular` with `A` and `_bPerm`,
 *   4. the assignment `b = _bPerm`.
 *
 * @param _solveReuse       function that receives the generated statements
 * @param _solveTriangular  function called on (A, _bPerm) in step 3
 * @param _bPerm            work vector holding the permuted right-hand side
 *
 * @return RET_INVALID_OPTION (through ACADOERROR) if nRightHandSides > 0,
 *         SUCCESSFUL_RETURN otherwise.
 */
returnValue ExportGaussElim::setupSolveReuse( ExportFunction& _solveReuse, ExportFunction& _solveTriangular, ExportVariable& _bPerm )
{
    if( nRightHandSides > 0 ) {
        return ACADOERROR( RET_INVALID_OPTION );
    }

    // Step 1: gather b through the permutation vector into _bPerm.
    for( uint entry = 0; entry < dim; ++entry ) {
        _solveReuse << _bPerm.get( entry,0 ) << " = b[" << rk_perm.getFullName() << "[" << toString( entry ) << "]];\n";
    }

    // Step 2: apply the strictly-lower-triangular entries of A, row by row.
    for( uint row = 1; row < dim; ++row ) {
        uint col = 0;
        while( col < row ) {
            _solveReuse << _bPerm.get( row,0 ) << " += A[" << toString( row*dim+col ) << "]*" << _bPerm.getFullName() << "[" << toString( col ) << "];\n";
            ++col;
        }
        _solveReuse.addLinebreak();
    }
    _solveReuse.addLinebreak();

    // Steps 3 and 4: triangular solve on the work vector, then copy it back to b.
    _solveReuse.addFunctionCall( _solveTriangular, A, _bPerm );
    _solveReuse.addStatement( b == _bPerm );

    return SUCCESSFUL_RETURN;
}
/**
 * Fills `_solve` with the statements of a Gaussian elimination with row
 * pivoting on `A` and `b`, followed by a call to `_solveTriangular`.
 *
 * Two code shapes are produced:
 *   - fully unrolled statements when UNROLLING is set or dim <= 5,
 *   - nested `for` loops over i, j, k otherwise.
 *
 * In both shapes the emitted code
 *   - (REUSE only) initialises rk_perm[i] = i and mirrors every row swap in it,
 *   - picks as pivot the first row with the largest absF(.) in the current
 *     column, swapping only if it is strictly larger than the diagonal entry,
 *   - overwrites each eliminated entry A[r*dim+c] with the negated ratio
 *     -A[r*dim+c]/A[c*dim+c] and updates the remaining row and b with `+=`,
 *   - multiplies `_determinant` by each diagonal pivot and finally stores
 *     absF(_determinant); row swaps are not reflected in a sign.
 *
 * @param _solve            function that receives the generated statements
 * @param _solveTriangular  function called on (A, b) after the elimination
 * @param _swap             scalar temporary used for the row swaps
 * @param _determinant      variable receiving the product of the pivots
 * @param absF              name of the absolute-value function to emit
 *
 * @return RET_INVALID_OPTION (through ACADOERROR) if nRightHandSides > 0,
 *         SUCCESSFUL_RETURN otherwise.
 */
returnValue ExportGaussElim::setupSolve( ExportFunction& _solve, ExportFunction& _solveTriangular, ExportVariable& _swap, ExportVariable& _determinant, const string& absF )
{
    // Validate before touching _solve, so that a rejected configuration does
    // not leave indices/declarations behind in the caller's function object.
    if (nRightHandSides > 0)
        return ACADOERROR(RET_INVALID_OPTION);

    uint run1, run2, run3;
    ExportIndex i( "i" );
    _solve.addIndex( i );
    ExportIndex j( "j" );
    ExportIndex k( "k" );
    ExportVariable indexMax( "indexMax", 1, 1, INT, ACADO_LOCAL, true );
    ExportVariable intSwap( "intSwap", 1, 1, INT, ACADO_LOCAL, true );
    ExportVariable valueMax( "valueMax", 1, 1, REAL, ACADO_LOCAL, true );
    ExportVariable temp( "temp", 1, 1, REAL, ACADO_LOCAL, true );
    // NOTE(review): intSwap is only declared here when !UNROLLING, yet the
    // unrolled branch below also uses it when REUSE is set - confirm that it
    // gets declared by other means in that configuration.
    if( !UNROLLING ) {
        _solve.addIndex( j );
        _solve.addIndex( k );
        _solve.addDeclaration( indexMax );
        if( REUSE ) _solve.addDeclaration( intSwap );
        _solve.addDeclaration( valueMax );
        _solve.addDeclaration( temp );
    }

    // initialise rk_perm (the permutation vector)
    if( REUSE ) {
        ExportForLoop loop1( i,0,dim );
        loop1 << rk_perm.get( 0,i ) << " = " << i.getName() << ";\n";
        _solve.addStatement( loop1 );
    }

    _solve.addStatement( _determinant == 1 );
    if( UNROLLING || dim <= 5 ) {
        // Start the factorization:
        for( run1 = 0; run1 < (dim-1); run1++ ) {
            // Search for pivot in column run1:
            for( run2 = run1+1; run2 < dim; run2++ ) {
                // add the test (if or else if):
                stringstream test;
                if( run2 == (run1+1) ) {
                    test << "if(";
                } else {
                    test << "else if(";
                }
                // The candidate row must strictly beat the current diagonal entry ...
                test << absF << "(A[" << toString( run2*dim+run1 ) << "]) > " << absF << "(A[" << toString( run1*dim+run1 ) << "])";
                for( run3 = run1+1; run3 < dim; run3++ ) {
                    if( run3 != run2) {
                        // ... strictly beat every earlier candidate row, and be at least as
                        // large as every later one. With strict comparisons on both sides
                        // two rows sharing the largest magnitude would both fail their test
                        // and no swap would happen, leaving a smaller (possibly zero) pivot.
                        // This way the first row holding the maximum is chosen, which is
                        // also what the loop-based branch below does.
                        test << " && " << absF << "(A[" << toString( run2*dim+run1 ) << "]) "
                             << ( run3 < run2 ? ">" : ">=" )
                             << " " << absF << "(A[" << toString( run3*dim+run1 ) << "])";
                    }
                }
                test << ") {\n";
                _solve.addStatement( test.str() );

                // do the row swaps:
                // for A:
                for( run3 = 0; run3 < dim; run3++ ) {
                    _solve.addStatement( _swap == A.getSubMatrix( run1,run1+1,run3,run3+1 ) );
                    _solve.addStatement( A.getSubMatrix( run1,run1+1,run3,run3+1 ) == A.getSubMatrix( run2,run2+1,run3,run3+1 ) );
                    _solve.addStatement( A.getSubMatrix( run2,run2+1,run3,run3+1 ) == _swap );
                }
                // for b:
                _solve.addStatement( _swap == b.getRow( run1 ) );
                _solve.addStatement( b.getRow( run1 ) == b.getRow( run2 ) );
                _solve.addStatement( b.getRow( run2 ) == _swap );

                if( REUSE ) { // rk_perm also needs to be updated if it needs to be possible to reuse the factorization
                    _solve.addStatement( intSwap == rk_perm.getCol( run1 ) );
                    _solve.addStatement( rk_perm.getCol( run1 ) == rk_perm.getCol( run2 ) );
                    _solve.addStatement( rk_perm.getCol( run2 ) == intSwap );
                }

                _solve.addStatement( "}\n" );
            }
            // potentially needed row swaps are done
            _solve.addLinebreak();
            // update of the next rows:
            for( run2 = run1+1; run2 < dim; run2++ ) {
                // store the negated elimination factor in place of the eliminated entry
                _solve << "A[" << toString( run2*dim+run1 ) << "] = -A[" << toString( run2*dim+run1 ) << "]/A[" << toString( run1*dim+run1 ) << "];\n";
                _solve.addStatement( A.getSubMatrix( run2,run2+1,run1+1,dim ) += A.getSubMatrix( run2,run2+1,run1,run1+1 ) * A.getSubMatrix( run1,run1+1,run1+1,dim ) );
                _solve.addStatement( b.getRow( run2 ) += A.getSubMatrix( run2,run2+1,run1,run1+1 ) * b.getRow( run1 ) );
                _solve.addLinebreak();
            }
            _solve.addStatement( _determinant == _determinant*A.getSubMatrix(run1,run1+1,run1,run1+1) );
            _solve.addLinebreak();
        }
        // the last diagonal entry is not visited by the loop above
        _solve.addStatement( _determinant == _determinant*A.getSubMatrix(dim-1,dim,dim-1,dim) );
        _solve.addLinebreak();
    }
    else { // without UNROLLING:
        _solve << "for( i=0; i < (" << toString( dim-1 ) << "); i++ ) {\n";
        // pivot search: first row with the largest absolute value in column i
        _solve << " indexMax = i;\n";
        _solve << " valueMax = " << absF << "(A[i*" << toString( dim ) << "+i]);\n";
        _solve << " for( j=(i+1); j < " << toString( dim ) << "; j++ ) {\n";
        _solve << " temp = " << absF << "(A[j*" << toString( dim ) << "+i]);\n";
        _solve << " if( temp > valueMax ) {\n";
        _solve << " indexMax = j;\n";
        _solve << " valueMax = temp;\n";
        _solve << " }\n";
        _solve << " }\n";
        // row swap of A, b (and rk_perm when REUSE) if a better pivot was found
        _solve << " if( indexMax > i ) {\n";
        ExportForLoop loop2( k,0,dim );
        loop2 << " " << _swap.getFullName() << " = A[i*" << toString( dim ) << "+" << k.getName() << "];\n";
        loop2 << " A[i*" << toString( dim ) << "+" << k.getName() << "] = A[indexMax*" << toString( dim ) << "+" << k.getName() << "];\n";
        loop2 << " A[indexMax*" << toString( dim ) << "+" << k.getName() << "] = " << _swap.getFullName() << ";\n";
        _solve.addStatement( loop2 );
        _solve << " " << _swap.getFullName() << " = b[i];\n";
        _solve << " b[i] = b[indexMax];\n";
        _solve << " b[indexMax] = " << _swap.getFullName() << ";\n";
        if( REUSE ) {
            _solve << " " << intSwap.getFullName() << " = " << rk_perm.getFullName() << "[i];\n";
            _solve << " " << rk_perm.getFullName() << "[i] = " << rk_perm.getFullName() << "[indexMax];\n";
            _solve << " " << rk_perm.getFullName() << "[indexMax] = " << intSwap.getFullName() << ";\n";
        }
        _solve << " }\n";

        // determinant update and elimination of column i below the diagonal
        _solve << " " << _determinant.getFullName() << " *= A[i*" << toString( dim ) << "+i];\n";
        _solve << " for( j=i+1; j < " << toString( dim ) << "; j++ ) {\n";
        _solve << " A[j*" << toString( dim ) << "+i] = -A[j*" << toString( dim ) << "+i]/A[i*" << toString( dim ) << "+i];\n";
        _solve << " for( k=i+1; k < " << toString( dim ) << "; k++ ) {\n";
        _solve << " A[j*" << toString( dim ) << "+k] += A[j*" << toString( dim ) << "+i] * A[i*" << toString( dim ) << "+k];\n";
        _solve << " }\n";
        _solve << " b[j] += A[j*" << toString( dim ) << "+i] * b[i];\n";
        _solve << " }\n";
        _solve << "}\n";
        // the last diagonal entry is not visited by the emitted loop
        _solve << _determinant.getFullName() << " *= A[" << toString( (dim-1)*dim+(dim-1) ) << "];\n";
    }
    // only the magnitude of the pivot product is stored
    _solve << _determinant.getFullName() << " = " << absF << "(" << _determinant.getFullName() << ");\n";

    _solve.addFunctionCall( _solveTriangular, A, b );

    return SUCCESSFUL_RETURN;
}