int Switch::eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w, void* mem) const { // Input and output buffers const SXElem** arg1 = arg + n_in_; SXElem** res1 = res + n_out_; // Extra memory needed for chaining if_else calls std::vector<SXElem> w_extra(nnz_out()); std::vector<SXElem*> res_tempv(n_out_); SXElem** res_temp = get_ptr(res_tempv); for (casadi_int k=0; k<f_.size()+1; ++k) { // Local work vector SXElem* wl = w; // Local work vector SXElem* wll = get_ptr(w_extra); if (k==0) { // For the default case, redirect the temporary results to res copy_n(res, n_out_, res_temp); } else { // For the other cases, store the temporary results for (casadi_int i=0; i<n_out_; ++i) { res_temp[i] = wll; wll += nnz_out(i); } } copy_n(arg+1, n_in_-1, arg1); copy_n(res_temp, n_out_, res1); const Function& fk = k==0 ? f_def_ : f_[k-1]; // Project arguments with different sparsity for (casadi_int i=0; i<n_in_-1; ++i) { if (arg1[i]) { const Sparsity& f_sp = fk.sparsity_in(i); const Sparsity& sp = sparsity_in_[i+1]; if (f_sp!=sp) { SXElem *t = wl; wl += f_sp.nnz(); // t is non-const casadi_project(arg1[i], sp, t, f_sp, wl); arg1[i] = t; } } } // Temporary memory for results with different sparsity for (casadi_int i=0; i<n_out_; ++i) { if (res1[i]) { const Sparsity& f_sp = fk.sparsity_out(i); const Sparsity& sp = sparsity_out_[i]; if (f_sp!=sp) { res1[i] = wl; wl += f_sp.nnz();} } } // Evaluate the corresponding function if (fk(arg1, res1, iw, wl, 0)) return 1; // Project results with different sparsity for (casadi_int i=0; i<n_out_; ++i) { if (res1[i]) { const Sparsity& f_sp = fk.sparsity_out(i); const Sparsity& sp = sparsity_out_[i]; if (f_sp!=sp) casadi_project(res1[i], f_sp, res_temp[i], sp, wl); } } if (k>0) { // output the temporary results via an if_else SXElem cond = k-1==arg[0][0]; for (casadi_int i=0; i<n_out_; ++i) { if (res[i]) { for (casadi_int j=0; j<nnz_out(i); ++j) { res[i][j] = if_else(cond, res_temp[i][j], res[i][j]); } } } } } return 0; }
/** \brief Number of nodes in the algorithm
 *
 * Returns the number of entries in algorithm_ minus the output nonzero count
 * reported by nnz_out().
 * NOTE(review): presumably the subtraction discounts instructions that only
 * copy results into the outputs - confirm against how algorithm_ is built.
 */
virtual int n_nodes() const {
  const auto n_instructions = algorithm_.size();
  const auto n_output_nonzeros = nnz_out();
  return n_instructions - n_output_nonzeros;
}