예제 #1
0
 /** \brief Check whether this expression equals \a ex.
  *
  * Expressions that share the very same node are always equal. Otherwise a
  * structural comparison through SXNode::isEqual is performed, but only when
  * \a depth is positive; with depth <= 0 distinct nodes compare unequal.
  */
 bool SXElement::isEqual(const SXElement& ex, int depth) const {
   // Same underlying node: trivially equal
   if (node == ex.get()) return true;

   // Distinct nodes: structural comparison only if a positive depth is allowed
   return depth > 0 && node->isEqual(ex.get(), depth);
 }
예제 #2
0
 /** \brief  Create a binary expression
  *
  * Constant folding: when both operands are constants, the operation is
  * evaluated numerically right away with casadi_math<double>::fun and the
  * double result is returned (implicitly converted to SXElement). Otherwise
  * a new BinarySX node holding op, dep0 and dep1 is allocated and wrapped
  * through SXElement::create.
  *
  * \param op    operation code passed to casadi_math<double>::fun / BinarySX
  * \param dep0  first operand
  * \param dep1  second operand
  */
 inline static SXElement create(unsigned char op, const SXElement& dep0, const SXElement& dep1) {
   const bool foldable = dep0.isConstant() && dep1.isConstant();
   if (!foldable) {
     // At least one operand is symbolic: build a new expression node
     return SXElement::create(new BinarySX(op, dep0, dep1));
   }

   // Both operands are numeric: fold the operation now
   double lhs = dep0.getValue();
   double rhs = dep1.getValue();
   double folded;
   casadi_math<double>::fun(op, lhs, rhs, folded);
   return folded;
 }
예제 #3
0
  /// Evaluate the algorithm symbolically, producing SXElement results.
  ///
  /// Makes a single forward pass over algorithm_. The work vector w is
  /// indexed by the AlgEl fields i0/i1/i2:
  ///  - OP_INPUT     stores arg[i1][i2] into w[i0], or 0 when arg[i1] is null
  ///  - OP_OUTPUT    copies w[i1] into res[i0][i2]; skipped when res[i0] is null
  ///  - OP_CONST     stores the next entry of constants_ into w[i0]
  ///  - OP_PARAMETER stores the next entry of free_vars_ into w[i0]
  ///  - any other op is evaluated on w[i1], w[i2] into a temporary, which is
  ///    then passed to assignIfDuplicate together with the next entry of
  ///    operations_ so an identical stored expression can be reused.
  ///
  /// \param arg  input nonzeros; individual arg[i] may be null
  /// \param res  output nonzeros; individual res[i] may be null
  /// \param iw   integer work vector (not used by this implementation)
  /// \param w    SXElement work vector
  void SXFunctionInternal::evalSX(const SXElement** arg, SXElement** res,
                                  int* iw, SXElement* w) {
    if (verbose()) userOut() << "SXFunctionInternal::evalSX begin" << endl;

    // Iterator to the binary operations
    vector<SXElement>::const_iterator b_it=operations_.begin();

    // Iterator to stack of constants
    vector<SXElement>::const_iterator c_it = constants_.begin();

    // Iterator to free variables
    vector<SXElement>::const_iterator p_it = free_vars_.begin();

    // Depth used when checking whether a freshly built expression duplicates
    // the stored one. NOTE: a higher depth could possibly give more savings
    const int depth = 2;

    // Evaluate algorithm
    if (verbose()) {
      userOut() << "SXFunctionInternal::evalSX evaluating algorithm forward" << endl;
    }
    for (vector<AlgEl>::const_iterator it = algorithm_.begin(); it!=algorithm_.end(); ++it) {
      switch (it->op) {
      case OP_INPUT:
        // A null input pointer is read as zero
        w[it->i0] = arg[it->i1]==0 ? 0 : arg[it->i1][it->i2];
        break;
      case OP_OUTPUT:
        // A null output pointer means this output is skipped
        if (res[it->i0]!=0) res[it->i0][it->i2] = w[it->i1];
        break;
      case OP_CONST:
        w[it->i0] = *c_it++;
        break;
      case OP_PARAMETER:
        w[it->i0] = *p_it++;
        break;
      default:
        {
          // Evaluate the function to a temporary value
          // (as it might overwrite the children in the work vector)
          SXElement f;
          switch (it->op) {
            CASADI_MATH_FUN_BUILTIN(w[it->i1], w[it->i2], f)
          }

          // If this new expression is identical to the expression used
          // to define the algorithm, then reuse
          f.assignIfDuplicate(*b_it++, depth);

          // Finally save the function value
          w[it->i0] = f;
        }
      }
    }
    if (verbose()) userOut() << "SXFunctionInternal::evalSX end" << endl;
  }
예제 #4
0
  /** \brief Division with on-the-fly simplification.
   *
   * When CasadiOptions::simplification_on_the_fly is off, an OP_DIV node is
   * always created. Otherwise the rewrite rules below are tried in order and
   * the first match wins, so the order of the chain matters. Structural
   * comparisons use SXNode::eq_depth_. Some rules call back into operator*,
   * operator/ or inv(), which may apply their own simplifications.
   */
  SXElement SXElement::__div__(const SXElement& y) const {
    // Only simplifications that do not result in extra nodes are allowed

    if (!CasadiOptions::simplification_on_the_fly) return BinarySX::create(OP_DIV, *this, y);

    if (y->isZero()) // term2 is zero: x/0 -> nan (checked before term1, so 0/0 -> nan as well)
      return casadi_limits<SXElement>::nan;
    else if (node->isZero()) // term1 is zero: 0/y -> 0
      return 0;
    else if (y->isOne()) // term2 is one: x/1 -> x
      return *this;
    else if (y->isMinusOne()) // x/(-1) -> -x
      return -(*this);
    else if (isEqual(y, SXNode::eq_depth_)) // terms are equal: x/x -> 1
      return 1;
    else if (isDoubled() && y.isEqual(2)) // (2z)/2 -> z; presumably isDoubled() means this == z+z with z = dep(0) - confirm
      return node->dep(0);
    else if (isOp(OP_MUL) && y.isEqual(node->dep(0), SXNode::eq_depth_)) // (a*b)/a -> b
      return node->dep(1);
    else if (isOp(OP_MUL) && y.isEqual(node->dep(1), SXNode::eq_depth_)) // (a*b)/b -> a
      return node->dep(0);
    else if (node->isOne()) // 1/y -> inverse of y
      return y.inv();
    else if (y.hasDep() && y.getOp()==OP_INV) // x/(1/z) -> x * y.inv(); presumably y.inv() simplifies back to z
      return (*this)*y.inv();
    else if (isDoubled() && y.isDoubled()) // (2a)/(2b) -> a/b
      return node->dep(0) / y->dep(0);
    else if (y.isConstant() && hasDep() && getOp()==OP_DIV && getDep(1).isConstant() &&
            y.getValue()*getDep(1).getValue()==1) // (x/5)/0.2 = x  (exact floating-point product test)
      return getDep(0);
    else if (y.hasDep() && y.getOp()==OP_MUL &&
            y.getDep(1).isEqual(*this, SXNode::eq_depth_)) // x/(2*x) = 1/2
      return BinarySX::create(OP_DIV, 1, y.getDep(0));
    else if (hasDep() && getOp()==OP_NEG &&
            getDep(0).isEqual(y, SXNode::eq_depth_))      // (-x)/x = -1
      return -1;
    else if (y.hasDep() && y.getOp()==OP_NEG &&
            y.getDep(0).isEqual(*this, SXNode::eq_depth_))      // x/(-x) = -1
      return -1;
    else if (y.hasDep() && y.getOp()==OP_NEG && hasDep() &&
            getOp()==OP_NEG && getDep(0).isEqual(y.getDep(0), SXNode::eq_depth_))  // (-x)/(-x) = 1
      return 1;
    else if (isOp(OP_DIV) && y.isEqual(node->dep(0), SXNode::eq_depth_)) // (a/b)/a -> 1/b
      return node->dep(1).inv();
    else // create a new branch
      return BinarySX::create(OP_DIV, *this, y);
  }
예제 #5
0
  /** \brief Multiplication with on-the-fly simplification.
   *
   * When CasadiOptions::simplification_on_the_fly is off, an OP_MUL node is
   * always created. Otherwise the rewrite rules below are tried in order and
   * the first match wins. Structural comparisons use SXNode::eq_depth_.
   * Some rules delegate to sq(), operator/ or the swapped product, which
   * apply their own simplifications.
   */
  SXElement SXElement::__mul__(const SXElement& y) const {

    if (!CasadiOptions::simplification_on_the_fly) return BinarySX::create(OP_MUL, *this, y);

    // Only simplifications that do not result in extra nodes are allowed
    if (y.isEqual(*this, SXNode::eq_depth_)) // x*x -> sq(x)
      return sq();
    else if (!isConstant() && y.isConstant()) // x*c -> c*x: move the constant to the left operand
      return y.__mul__(*this);
    else if (node->isZero() || y->isZero()) // one of the terms is zero
      return 0;
    else if (node->isOne()) // term1 is one
      return y;
    else if (y->isOne()) // term2 is one
      return *this;
    else if (y->isMinusOne()) // x*(-1) -> -x
      return -(*this);
    else if (node->isMinusOne()) // (-1)*y -> -y
      return -y;
    else if (y.hasDep() && y.getOp()==OP_INV) // x*(1/z) -> x / y.inv(); presumably y.inv() simplifies back to z
      return (*this)/y.inv();
    else if (hasDep() && getOp()==OP_INV) // (1/z)*y -> y / inv(); presumably inv() simplifies back to z
      return y/inv();
    else if (isConstant() && y.hasDep() && y.getOp()==OP_MUL && y.getDep(0).isConstant() &&
            getValue()*y.getDep(0).getValue()==1) // 5*(0.2*x) = x
      return y.getDep(1);
    else if (isConstant() && y.hasDep() && y.getOp()==OP_DIV && y.getDep(1).isConstant() &&
            getValue()==y.getDep(1).getValue()) // 5*(x/5) = x
      return y.getDep(0);
    else if (hasDep() && getOp()==OP_DIV && getDep(1).isEqual(y, SXNode::eq_depth_)) // (2/x)*x = 2
      return getDep(0);
    else if (y.hasDep() && y.getOp()==OP_DIV &&
            y.getDep(1).isEqual(*this, SXNode::eq_depth_)) // x*(2/x) = 2
      return y.getDep(0);
    else     // create a new branch
      return BinarySX::create(OP_MUL, *this, y);
  }
예제 #6
0
  /** \brief Subtraction with on-the-fly simplification.
   *
   * With CasadiOptions::simplification_on_the_fly disabled an OP_SUB node is
   * always created. Otherwise these rewrites are tried in order (first match
   * wins), none of which introduces extra nodes:
   *   x - 0 -> x,   0 - y -> -y,   x - x -> 0,
   *   x - y with y a negation -> x + (-y),
   *   (a+b) - b -> a,   (a+b) - a -> b,
   *   b - (a+b) -> -a,  a - (a+b) -> -b.
   * Structural equality is checked up to SXNode::eq_depth_.
   */
  SXElement SXElement::__sub__(const SXElement& y) const {
    if (!CasadiOptions::simplification_on_the_fly)
      return BinarySX::create(OP_SUB, *this, y);

    // Trivial operands
    if (y->isZero()) return *this;
    if (node->isZero()) return -y;
    if (isEqual(y, SXNode::eq_depth_)) return 0;

    // Subtracting a negation becomes an addition
    if (y.hasDep() && y.getOp()==OP_NEG) return __add__(-y);

    // Left operand is a sum containing y: drop that term
    if (hasDep() && getOp()==OP_ADD) {
      if (getDep(1).isEqual(y, SXNode::eq_depth_)) return getDep(0);
      if (getDep(0).isEqual(y, SXNode::eq_depth_)) return getDep(1);
    }

    // Right operand is a sum containing this: negate the remaining term
    if (y.hasDep() && y.getOp()==OP_ADD) {
      if (isEqual(y.getDep(1), SXNode::eq_depth_)) return -y.getDep(0);
      if (isEqual(y.getDep(0), SXNode::eq_depth_)) return -y.getDep(1);
    }

    // No rewrite applies: create a new branch
    return BinarySX::create(OP_SUB, *this, y);
  }
예제 #7
0
  /** \brief Addition with on-the-fly simplification.
   *
   * With CasadiOptions::simplification_on_the_fly disabled an OP_ADD node is
   * always created. Otherwise these rewrites are tried in order (first match
   * wins), none of which introduces extra nodes:
   *   0 + y -> y,   x + 0 -> x,
   *   x + y with y a negation -> x - (-y),   (-z) + y -> y - z,
   *   0.5*z + 0.5*z -> z,   z/2 + z/2 -> z,
   *   (a - y) + y -> a,   x + (a - x) -> a.
   * Structural equality is checked up to SXNode::eq_depth_.
   */
  SXElement SXElement::__add__(const SXElement& y) const {
    if (!CasadiOptions::simplification_on_the_fly)
      return BinarySX::create(OP_ADD, *this, y);

    // Additive identity
    if (node->isZero()) return y;
    if (y->isZero()) return *this;

    // Adding a negation becomes a subtraction
    if (y.hasDep() && y.getOp()==OP_NEG) return __sub__(-y);
    if (hasDep() && getOp()==OP_NEG) return y.__sub__(getDep());

    const bool lhs_is_op = hasDep();
    const bool rhs_is_op = y.hasDep();

    // Two identical halves recombine into the whole
    if (lhs_is_op && rhs_is_op) {
      const bool halves_by_mul =
          getOp()==OP_MUL && y.getOp()==OP_MUL &&
          getDep(0).isConstant() && getDep(0).getValue()==0.5 &&
          y.getDep(0).isConstant() && y.getDep(0).getValue()==0.5 &&
          y.getDep(1).isEqual(getDep(1), SXNode::eq_depth_);
      if (halves_by_mul) return getDep(1);

      const bool halves_by_div =
          getOp()==OP_DIV && y.getOp()==OP_DIV &&
          getDep(1).isConstant() && getDep(1).getValue()==2 &&
          y.getDep(1).isConstant() && y.getDep(1).getValue()==2 &&
          y.getDep(0).isEqual(getDep(0), SXNode::eq_depth_);
      if (halves_by_div) return getDep(0);
    }

    // Adding back a subtracted term cancels the subtraction
    if (lhs_is_op && getOp()==OP_SUB && getDep(1).isEqual(y, SXNode::eq_depth_))
      return getDep(0);
    if (rhs_is_op && y.getOp()==OP_SUB && isEqual(y.getDep(1), SXNode::eq_depth_))
      return y.getDep(0);

    // No rewrite applies: create a new branch
    return BinarySX::create(OP_ADD, *this, y);
  }