SX SX::__mul__(const SX& y) const{ // Only simplifications that do not result in extra nodes area allowed if(!isConstant() && y.isConstant()) return y.__mul__(*this); else if(node->isZero() || y->isZero()) // one of the terms is zero return 0; else if(node->isOne()) // term1 is one return y; else if(y->isOne()) // term2 is one return *this; else if(y->isMinusOne()) return -(*this); else if(node->isMinusOne()) return -y; else if(y.hasDep() && y.getOp()==OP_INV) return (*this)/y.inv(); else if(hasDep() && getOp()==OP_INV) return y/inv(); else if(isConstant() && y.hasDep() && y.getOp()==OP_MUL && y.getDep(0).isConstant() && getValue()*y.getDep(0).getValue()==1) // 5*(0.2*x) = x return y.getDep(1); else if(isConstant() && y.hasDep() && y.getOp()==OP_DIV && y.getDep(1).isConstant() && getValue()==y.getDep(1).getValue()) // 5*(x/5) = x return y.getDep(0); else if(hasDep() && getOp()==OP_DIV && getDep(1).isEquivalent(y)) // ((2/x)*x) return getDep(0); else if(y.hasDep() && y.getOp()==OP_DIV && y.getDep(1).isEquivalent(*this)) // ((2/x)*x) return y.getDep(0); else // create a new branch return BinarySX::create(OP_MUL,*this,y); }
SX SX::__add__(const SX& y) const{ // NOTE: Only simplifications that do not result in extra nodes area allowed if(node->isZero()) return y; else if(y->isZero()) // term2 is zero return *this; else if(y.hasDep() && y.getOp()==OP_NEG) // x + (-y) -> x - y return __sub__(-y); else if(hasDep() && getOp()==OP_NEG) // (-x) + y -> y - x return y.__sub__(getDep()); else if(hasDep() && getOp()==OP_MUL && y.hasDep() && y.getOp()==OP_MUL && getDep(0).isConstant() && getDep(0).getValue()==0.5 && y.getDep(0).isConstant() && y.getDep(0).getValue()==0.5 && y.getDep(1).isEquivalent(getDep(1))) // 0.5x+0.5x = x return getDep(1); else if(hasDep() && getOp()==OP_DIV && y.hasDep() && y.getOp()==OP_DIV && getDep(1).isConstant() && getDep(1).getValue()==2 && y.getDep(1).isConstant() && y.getDep(1).getValue()==2 && y.getDep(0).isEquivalent(getDep(0))) // x/2+x/2 = x return getDep(0); else if(hasDep() && getOp()==OP_SUB && getDep(1).isEquivalent(y)) return getDep(0); else if(y.hasDep() && y.getOp()==OP_SUB && isEquivalent(y.getDep(1))) return y.getDep(0); else // create a new branch return BinarySX::create(OP_ADD,*this, y); }
bool SX::isEquivalent(const SX& y, int depth) const{ if (isEqual(y)) return true; if (isConstant() && y.isConstant()) return y.getValue()==getValue(); if (depth==0) return false; if (hasDep() && y.hasDep() && getOp()==y.getOp()) { if (getDep(0).isEquivalent(y.getDep(0),depth-1) && getDep(1).isEquivalent(y.getDep(1),depth-1)) return true; return (operation_checker<CommChecker>(getOp()) && getDep(0).isEquivalent(y.getDep(1),depth-1) && getDep(1).isEquivalent(y.getDep(0),depth-1)); } return false; }
SX SX::__sub__(const SX& y) const{ // Only simplifications that do not result in extra nodes area allowed if(y->isZero()) // term2 is zero return *this; if(node->isZero()) // term1 is zero return -y; if(isEquivalent(y)) // the terms are equal return 0; else if(y.hasDep() && y.getOp()==OP_NEG) // x - (-y) -> x + y return __add__(-y); else if(hasDep() && getOp()==OP_ADD && getDep(1).isEquivalent(y)) return getDep(0); else if(hasDep() && getOp()==OP_ADD && getDep(0).isEquivalent(y)) return getDep(1); else if(y.hasDep() && y.getOp()==OP_ADD && isEquivalent(y.getDep(1))) return -y.getDep(0); else if(y.hasDep() && y.getOp()==OP_ADD && isEquivalent(y.getDep(0))) return -y.getDep(1); else // create a new branch return BinarySX::create(OP_SUB,*this,y); }
SX SX::__div__(const SX& y) const{ // Only simplifications that do not result in extra nodes area allowed if(y->isZero()) // term2 is zero return casadi_limits<SX>::nan; else if(node->isZero()) // term1 is zero return 0; else if(y->isOne()) // term2 is one return *this; else if(isEquivalent(y)) // terms are equal return 1; else if(isDoubled() && y.isEqual(2)) return node->dep(0); else if(isOp(OP_MUL) && y.isEquivalent(node->dep(0))) return node->dep(1); else if(isOp(OP_MUL) && y.isEquivalent(node->dep(1))) return node->dep(0); else if(node->isOne()) return y.inv(); else if(y.hasDep() && y.getOp()==OP_INV) return (*this)*y.inv(); else if(isDoubled() && y.isDoubled()) return node->dep(0) / y->dep(0); else if(y.isConstant() && hasDep() && getOp()==OP_DIV && getDep(1).isConstant() && y.getValue()*getDep(1).getValue()==1) // (x/5)/0.2 return getDep(0); else if(y.hasDep() && y.getOp()==OP_MUL && y.getDep(1).isEquivalent(*this)) // x/(2*x) = 1/2 return BinarySX::create(OP_DIV,1,y.getDep(0)); else if(hasDep() && getOp()==OP_NEG && getDep(0).isEquivalent(y)) // (-x)/x = -1 return -1; else if(y.hasDep() && y.getOp()==OP_NEG && y.getDep(0).isEquivalent(*this)) // x/(-x) = 1 return -1; else if(y.hasDep() && y.getOp()==OP_NEG && hasDep() && getOp()==OP_NEG && getDep(0).isEquivalent(y.getDep(0))) // (-x)/(-x) = 1 return 1; else if(isOp(OP_DIV) && y.isEquivalent(node->dep(0))) return node->dep(1).inv(); else // create a new branch return BinarySX::create(OP_DIV,*this,y); }