void expand(const SXMatrix& ex2, SXMatrix &ww, SXMatrix& tt){ casadi_assert(ex2.scalar()); SX ex = ex2.toScalar(); // Terms, weights and indices of the nodes that are already expanded std::vector<std::vector<SXNode*> > terms; std::vector<std::vector<double> > weights; std::map<SXNode*,int> indices; // Stack of nodes that are not yet expanded std::stack<SXNode*> to_be_expanded; to_be_expanded.push(ex.get()); while(!to_be_expanded.empty()){ // as long as there are nodes to be expanded // Check if the last element on the stack is already expanded if (indices.find(to_be_expanded.top()) != indices.end()){ // Remove from stack to_be_expanded.pop(); continue; } // Weights and terms std::vector<double> w; // weights std::vector<SXNode*> f; // terms if(to_be_expanded.top()->isConstant()){ // constant nodes are seen as multiples of one w.push_back(to_be_expanded.top()->getValue()); f.push_back(casadi_limits<SX>::one.get()); } else if(to_be_expanded.top()->isSymbolic()){ // symbolic nodes have weight one and itself as factor w.push_back(1); f.push_back(to_be_expanded.top()); } else { // binary node casadi_assert(to_be_expanded.top()->hasDep()); // make sure that the node is binary // Check if addition, subtracton or multiplication SXNode* node = to_be_expanded.top(); // If we have a binary node that we can factorize if(node->getOp() == OP_ADD || node->getOp() == OP_SUB || (node->getOp() == OP_MUL && (node->dep(0)->isConstant() || node->dep(1)->isConstant()))){ // Make sure that both children are factorized, if not - add to stack if (indices.find(node->dep(0).get()) == indices.end()){ to_be_expanded.push(node->dep(0).get()); continue; } if (indices.find(node->dep(1).get()) == indices.end()){ to_be_expanded.push(node->dep(1).get()); continue; } // Get indices of children int ind1 = indices[node->dep(0).get()]; int ind2 = indices[node->dep(1).get()]; // If multiplication if(node->getOp() == OP_MUL){ double fac; if(node->dep(0)->isConstant()){ // Multiplication where the first factor is a constant fac = node->dep(0)->getValue(); f = terms[ind2]; w = weights[ind2]; } else { // Multiplication where the second factor is a constant fac = node->dep(1)->getValue(); f = terms[ind1]; w = weights[ind1]; } for(int i=0; i<w.size(); ++i) w[i] *= fac; } else { // if addition or subtraction if(node->getOp() == OP_ADD){ // Addition: join both sums f = terms[ind1]; f.insert(f.end(), terms[ind2].begin(), terms[ind2].end()); w = weights[ind1]; w.insert(w.end(), weights[ind2].begin(), weights[ind2].end()); } else { // Subtraction: join both sums with negative weights for second term f = terms[ind1]; f.insert(f.end(), terms[ind2].begin(), terms[ind2].end()); w = weights[ind1]; w.reserve(f.size()); for(int i=0; i<weights[ind2].size(); ++i) w.push_back(-weights[ind2][i]); } // Eliminate multiple elements std::vector<double> w_new; w_new.reserve(w.size()); // weights std::vector<SXNode*> f_new; f_new.reserve(f.size()); // terms std::map<SXNode*,int> f_ind; // index in f_new for(int i=0; i<w.size(); i++){ // Try to locate the node std::map<SXNode*,int>::iterator it = f_ind.find(f[i]); if(it == f_ind.end()){ // if the term wasn't found w_new.push_back(w[i]); f_new.push_back(f[i]); f_ind[f[i]] = f_new.size()-1; } else { // if the term already exists w_new[it->second] += w[i]; // just add the weight } } w = w_new; f = f_new; } } else { // if we have a binary node that we cannot factorize // By default, w.push_back(1); f.push_back(node); } } // Save factorization of the node weights.push_back(w); terms.push_back(f); indices[to_be_expanded.top()] = terms.size()-1; // Remove node from stack to_be_expanded.pop(); } // Save expansion to output int thisind = indices[ex.get()]; ww = SXMatrix(weights[thisind]); vector<SX> termsv(terms[thisind].size()); for(int i=0; i<termsv.size(); ++i) termsv[i] = SX::create(terms[thisind][i]); tt = SXMatrix(termsv); }