// // Subroutine of explain // A step of explanation for x and y // void Egraph::expExplainAlongPath ( Enode * x, Enode * y ) { Enode * v = expHighestNode( x ); Enode * to = expHighestNode( y ); while ( v != to ) { Enode * p = v->getExpParent( ); assert( p != NULL ); Enode * r = v->getExpReason( ); // If it is not a congruence edge if ( r != NULL ) { if ( !isDup1( r ) ) { assert( r->isTerm( ) ); explanation.push_back( r ); storeDup1( r ); } } // Otherwise it is a congruence edge // This means that the edge is linking nodes // like (v)f(a1,...,an) (p)f(b1,...,bn), and that // a1,...,an = b1,...bn. For each pair ai,bi // we have therefore to compute the reasons else { assert( v->getCar( ) == p->getCar( ) ); assert( v->getArity( ) == p->getArity( ) ); expEnqueueArguments( v, p ); } #ifdef PRODUCE_PROOF if ( config.produce_inter > 0 && config.logic != QF_AX ) { cgraph.addCNode( v ); cgraph.addCNode( p ); cgraph.addCEdge( v, p, r ); } #endif expUnion( v, p ); v = expHighestNode( p ); } }
double eval_enode(Enode * const e, unordered_map<Enode*, double> const & var_map) { if (e->isVar()) { auto const it = var_map.find(e); if (it == var_map.cend()) { throw runtime_error("variable not found"); } else { // Variable is found in var_map return it->second; } } else if (e->isConstant()) { double const v = e->getValue(); return v; } else if (e->isSymb()) { throw runtime_error("eval_enode: Symb"); } else if (e->isNumb()) { throw runtime_error("eval_enode: Numb"); } else if (e->isTerm()) { assert(e->getArity() >= 1); enodeid_t id = e->getCar()->getId(); double ret = 0.0; Enode * tmp = e; switch (id) { case ENODE_ID_PLUS: ret = eval_enode(tmp->get1st(), var_map); tmp = tmp->getCdr()->getCdr(); // e is pointing to the 2nd arg while (!tmp->isEnil()) { ret = ret + eval_enode(tmp->getCar(), var_map); tmp = tmp->getCdr(); } return ret; case ENODE_ID_MINUS: ret = eval_enode(tmp->get1st(), var_map); tmp = tmp->getCdr()->getCdr(); // e is pointing to the 2nd arg while (!tmp->isEnil()) { ret = ret - eval_enode(tmp->getCar(), var_map); tmp = tmp->getCdr(); } return ret; case ENODE_ID_UMINUS: ret = eval_enode(tmp->get1st(), var_map); assert(tmp->getArity() == 1); return (- ret); case ENODE_ID_TIMES: ret = eval_enode(tmp->get1st(), var_map); tmp = tmp->getCdr()->getCdr(); // e is pointing to the 2nd arg while (!tmp->isEnil()) { ret = ret * eval_enode(tmp->getCar(), var_map); tmp = tmp->getCdr(); } return ret; case ENODE_ID_DIV: ret = eval_enode(tmp->get1st(), var_map); tmp = tmp->getCdr()->getCdr(); // e is pointing to the 2nd arg while (!tmp->isEnil()) { ret = ret / eval_enode(tmp->getCar(), var_map); tmp = tmp->getCdr(); } return ret; case ENODE_ID_ACOS: assert(e->getArity() == 1); return acos(eval_enode(e->get1st(), var_map)); case ENODE_ID_ASIN: assert(e->getArity() == 1); return asin(eval_enode(e->get1st(), var_map)); case ENODE_ID_ATAN: assert(e->getArity() == 1); return atan(eval_enode(e->get1st(), var_map)); case ENODE_ID_ATAN2: assert(e->getArity() == 2); return atan2(eval_enode(e->get1st(), var_map), eval_enode(e->get2nd(), var_map)); case ENODE_ID_MIN: assert(e->getArity() == 2); return fmin(eval_enode(e->get1st(), var_map), eval_enode(e->get2nd(), var_map)); case ENODE_ID_MAX: assert(e->getArity() == 2); return fmax(eval_enode(e->get1st(), var_map), eval_enode(e->get2nd(), var_map)); case ENODE_ID_MATAN: assert(e->getArity() == 1); throw runtime_error("eval_enode: MATAN"); case ENODE_ID_SAFESQRT: assert(e->getArity() == 1); throw runtime_error("eval_enode: SAFESQRT"); case ENODE_ID_SQRT: assert(e->getArity() == 1); return sqrt(eval_enode(e->get1st(), var_map)); case ENODE_ID_EXP: assert(e->getArity() == 1); return exp(eval_enode(e->get1st(), var_map)); case ENODE_ID_LOG: assert(e->getArity() == 1); return log(eval_enode(e->get1st(), var_map)); case ENODE_ID_POW: assert(e->getArity() == 2); return pow(eval_enode(e->get1st(), var_map), eval_enode(e->get2nd(), var_map)); case ENODE_ID_ABS: assert(e->getArity() == 1); return fabs(eval_enode(e->get1st(), var_map)); case ENODE_ID_SIN: assert(e->getArity() == 1); return sin(eval_enode(e->get1st(), var_map)); case ENODE_ID_COS: assert(e->getArity() == 1); return cos(eval_enode(e->get1st(), var_map)); case ENODE_ID_TAN: assert(e->getArity() == 1); return tan(eval_enode(e->get1st(), var_map)); case ENODE_ID_SINH: assert(e->getArity() == 1); return sinh(eval_enode(e->get1st(), var_map)); case ENODE_ID_COSH: assert(e->getArity() == 1); return cosh(eval_enode(e->get1st(), var_map)); case ENODE_ID_TANH: assert(e->getArity() == 1); return tanh(eval_enode(e->get1st(), var_map)); default: throw runtime_error("eval_enode: Unknown Term"); } } else if (e->isList()) { throw runtime_error("eval_enode: List"); } else if (e->isDef()) { throw runtime_error("eval_enode: Def"); } else if (e->isEnil()) { throw runtime_error("eval_enode: Nil"); } else { throw runtime_error("eval_enode: unknown case"); } throw runtime_error("Not implemented yet: eval_enode"); }
double deriv_enode(Enode * const e, Enode * const v, unordered_map<Enode*, double> const & var_map) { if (e == v) { return 1.0; } if (e->isVar()) { auto const it = var_map.find(e); if (it == var_map.cend()) { throw runtime_error("variable not found"); } else { // Variable is found in var_map return 0.0; } } else if (e->isConstant()) { return 0.0; } else if (e->isSymb()) { throw runtime_error("eval_enode: Symb"); } else if (e->isNumb()) { throw runtime_error("eval_enode: Numb"); } else if (e->isTerm()) { assert(e->getArity() >= 1); enodeid_t id = e->getCar()->getId(); double ret = 0.0; Enode * tmp = e; switch (id) { case ENODE_ID_PLUS: ret = deriv_enode(tmp->get1st(), v, var_map); tmp = tmp->getCdr()->getCdr(); // e is pointing to the 2nd arg while (!tmp->isEnil()) { ret = ret + deriv_enode(tmp->getCar(), v, var_map); tmp = tmp->getCdr(); } return ret; case ENODE_ID_MINUS: ret = deriv_enode(tmp->get1st(), v, var_map); tmp = tmp->getCdr()->getCdr(); // e is pointing to the 2nd arg while (!tmp->isEnil()) { ret = ret - deriv_enode(tmp->getCar(), v, var_map); tmp = tmp->getCdr(); } return ret; case ENODE_ID_UMINUS: ret = deriv_enode(tmp->get1st(), v, var_map); assert(tmp->getArity() == 1); return (- ret); case ENODE_ID_TIMES: { // (f * g)' = f' * g + f * g' if (tmp->getArity() != 2) { throw runtime_error("deriv_enode: only support arity = 2 case for multiplication"); } double const f = eval_enode(e->get1st(), var_map); double const f_ = deriv_enode(e->get1st(), v, var_map); double const g = eval_enode(e->get2nd(), var_map); double const g_ = deriv_enode(e->get2nd(), v, var_map); return f_ * g + f * g_; } case ENODE_ID_DIV: { // (f / g)' = (f' * g - f * g') / g^2 if (tmp->getArity() != 2) { throw runtime_error("deriv_enode: only support arity = 2 case for division"); } double const f = eval_enode(e->get1st(), var_map); double const f_ = deriv_enode(e->get1st(), v, var_map); double const g = eval_enode(e->get2nd(), var_map); double const g_ = deriv_enode(e->get2nd(), v, var_map); return (f_ * g - f * g_) / (g * g); } case ENODE_ID_ACOS: { // (acos f)' = -(1 / sqrt(1 - f^2)) f' assert(e->getArity() == 1); double const f = eval_enode(e->get1st(), var_map); double const f_ = deriv_enode(e->get1st(), v, var_map); return - (1 / sqrt(1 - f * f)) * f_; } case ENODE_ID_ASIN: { // (asin f)' = (1 / sqrt(1 - f^2)) f' assert(e->getArity() == 1); double const f = eval_enode(e->get1st(), var_map); double const f_ = deriv_enode(e->get1st(), v, var_map); return 1 / sqrt(1 - f * f) * f_; } case ENODE_ID_ATAN: { // (atan f)' = (1 / (1 + f^2)) * f' assert(e->getArity() == 1); double const f = eval_enode(e->get1st(), var_map); double const f_ = deriv_enode(e->get1st(), v, var_map); return 1 / (1 + f * f) * f_; } case ENODE_ID_ATAN2: { // atan2(x,y)' = -y / (x^2 + y^2) dx + x / (x^2 + y^2) dy // = (-y dx + x dy) / (x^2 + y^2) assert(e->getArity() == 2); double const f = eval_enode(e->get1st(), var_map); double const f_ = deriv_enode(e->get1st(), v, var_map); double const g = eval_enode(e->get2nd(), var_map); double const g_ = deriv_enode(e->get2nd(), v, var_map); return (-g * f_ + f * g_) / (f * f + g * g); } case ENODE_ID_MIN: assert(e->getArity() == 2); throw runtime_error("deriv_enode: no support for min"); case ENODE_ID_MAX: assert(e->getArity() == 2); throw runtime_error("deriv_enode: no support for max"); case ENODE_ID_MATAN: assert(e->getArity() == 1); throw runtime_error("deriv_enode: no support for matan"); case ENODE_ID_SAFESQRT: assert(e->getArity() == 1); throw runtime_error("deriv_enode: no support for safesqrt"); case ENODE_ID_SQRT: { // (sqrt(f))' = 1/2 * 1/(sqrt(f)) * f' assert(e->getArity() == 1); double const f = eval_enode(e->get1st(), var_map); double const f_ = deriv_enode(e->get1st(), v, var_map); return 0.5 * 1 / sqrt(f) * f_; } case ENODE_ID_EXP: { // (exp f)' = (exp f) * f' assert(e->getArity() == 1); double const f = eval_enode(e->get1st(), var_map); double const f_ = deriv_enode(e->get1st(), v, var_map); return exp(f) * f_; } case ENODE_ID_LOG: { // (log f)' = f' / f assert(e->getArity() == 1); double const f = eval_enode(e->get1st(), var_map); double const f_ = deriv_enode(e->get1st(), v, var_map); return f_ / f; } case ENODE_ID_POW: { // (f^g)' = f^g (f' * g / f + g' * ln g) assert(e->getArity() == 2); double const f = eval_enode(e->get1st(), var_map); double const f_ = deriv_enode(e->get1st(), v, var_map); double const g = eval_enode(e->get2nd(), var_map); double const g_ = deriv_enode(e->get2nd(), v, var_map); return pow(f, g) * (f_ * g / f + g_ * log(g)); } case ENODE_ID_ABS: { assert(e->getArity() == 1); double const f = eval_enode(e->get1st(), var_map); double const f_ = deriv_enode(e->get1st(), v, var_map); if (f > 0) { return f_; } else { return - f_; } } case ENODE_ID_SIN: { // (sin f)' = (cos f) * f' assert(e->getArity() == 1); double const f = eval_enode(e->get1st(), var_map); double const f_ = deriv_enode(e->get1st(), v, var_map); return cos(f) * f_; } case ENODE_ID_COS: { // (cos f)' = - (sin f) * f' assert(e->getArity() == 1); double const f = eval_enode(e->get1st(), var_map); double const f_ = deriv_enode(e->get1st(), v, var_map); return - sin(f) * f_; } case ENODE_ID_TAN: { // (tan f)' = (1 + tan^2 f) * f' assert(e->getArity() == 1); double const f = eval_enode(e->get1st(), var_map); double const f_ = deriv_enode(e->get1st(), v, var_map); return (1 + tan(f) * tan(f)) * f_; } case ENODE_ID_SINH: { // (sinh f)' = (e^f + e^(-f))/2 * f' // = cosh(f) * f' assert(e->getArity() == 1); double const f = eval_enode(e->get1st(), var_map); double const f_ = deriv_enode(e->get1st(), v, var_map); return cosh(f) * f_; } case ENODE_ID_COSH: { // (cosh f)' = (e^f - e^(-f))/2 * f' // = sinh(f) * f' assert(e->getArity() == 1); double const f = eval_enode(e->get1st(), var_map); double const f_ = deriv_enode(e->get1st(), v, var_map); return sinh(f) * f_; } case ENODE_ID_TANH: { // (tanh f)' = (sech^2 f) * f' // = (1 - tanh(f) ^ 2) * f' assert(e->getArity() == 1); double const f = eval_enode(e->get1st(), var_map); double const f_ = deriv_enode(e->get1st(), v, var_map); return (1 - tanh(f) * tanh(f)) * f_; } default: throw runtime_error("deriv_enode: Unknown Term"); } } else if (e->isList()) { throw runtime_error("deriv_enode: List"); } else if (e->isDef()) { throw runtime_error("deriv_enode: Def"); } else if (e->isEnil()) { throw runtime_error("deriv_enode: Nil"); } else { throw runtime_error("deriv_enode: unknown case"); } throw runtime_error("Not implemented yet: deriv_enode"); }
// Translate an Enode e into ibex::ExprNode. // Note: As a side-effect, update var_map : string -> ibex::Variable // Note: Use subst map (Enode ->ibex::Interval) ExprNode const * translate_enode_to_exprnode(map<string, Variable const> & var_map, Enode * const e, unordered_map<Enode*, ibex::Interval> const & subst) { // TODO(soonhok): for the simple case such as 0 <= x or x <= 10. // Handle it as a domain specification instead of constraints. if (e->isVar()) { auto const subst_it = subst.find(e); if (subst_it != subst.cend()) { auto const i = subst_it->second; return &ExprConstant::new_scalar(i); } string const & var_name = e->getCar()->getNameFull(); auto const it = var_map.find(var_name); if (it == var_map.cend()) { // The variable is new, we need to make one. Variable v(var_name.c_str()); // double const lb = e->getLowerBound(); // double const ub = e->getUpperBound(); var_map.emplace(var_name, v); return v.symbol; } else { // Variable is found in var_map Variable const & v = it->second; return v.symbol; } } else if (e->isConstant()) { double const lb = e->getValueLowerBound(); double const ub = e->getValueUpperBound(); return &ExprConstant::new_scalar(ibex::Interval(lb, ub)); } else if (e->isSymb()) { throw logic_error("translateEnodeExprNode: Symb"); } else if (e->isNumb()) { throw logic_error("translateEnodeExprNode: Numb"); } else if (e->isTerm()) { assert(e->getArity() >= 1); enodeid_t id = e->getCar()->getId(); ExprNode const * ret = nullptr; Enode * tmp = e; switch (id) { case ENODE_ID_PLUS: ret = translate_enode_to_exprnode(var_map, tmp->get1st(), subst); tmp = tmp->getCdr()->getCdr(); // e is pointing to the 2nd arg while (!tmp->isEnil()) { ret = &(*ret + *translate_enode_to_exprnode(var_map, tmp->getCar(), subst)); tmp = tmp->getCdr(); } return ret; case ENODE_ID_MINUS: ret = translate_enode_to_exprnode(var_map, tmp->get1st(), subst); tmp = tmp->getCdr()->getCdr(); // e is pointing to the 2nd arg while (!tmp->isEnil()) { ret = &(*ret - *translate_enode_to_exprnode(var_map, tmp->getCar(), subst)); tmp = tmp->getCdr(); } return ret; case ENODE_ID_UMINUS: ret = translate_enode_to_exprnode(var_map, tmp->get1st(), subst); assert(tmp->getArity() == 1); return &(- *ret); case ENODE_ID_TIMES: ret = translate_enode_to_exprnode(var_map, tmp->get1st(), subst); tmp = tmp->getCdr()->getCdr(); // e is pointing to the 2nd arg while (!tmp->isEnil()) { ret = &(*ret * *translate_enode_to_exprnode(var_map, tmp->getCar(), subst)); tmp = tmp->getCdr(); } return ret; case ENODE_ID_DIV: ret = translate_enode_to_exprnode(var_map, tmp->get1st(), subst); tmp = tmp->getCdr()->getCdr(); // e is pointing to the 2nd arg while (!tmp->isEnil()) { ret = &(*ret / *translate_enode_to_exprnode(var_map, tmp->getCar(), subst)); tmp = tmp->getCdr(); } return ret; case ENODE_ID_ACOS: assert(e->getArity() == 1); return &acos(*translate_enode_to_exprnode(var_map, e->get1st(), subst)); case ENODE_ID_ASIN: assert(e->getArity() == 1); return &asin(*translate_enode_to_exprnode(var_map, e->get1st(), subst)); case ENODE_ID_ATAN: assert(e->getArity() == 1); return &atan(*translate_enode_to_exprnode(var_map, e->get1st(), subst)); case ENODE_ID_ATAN2: assert(e->getArity() == 2); return &atan2(*translate_enode_to_exprnode(var_map, e->get1st(), subst), *translate_enode_to_exprnode(var_map, e->get2nd(), subst)); case ENODE_ID_MIN: assert(e->getArity() == 2); return &min(*translate_enode_to_exprnode(var_map, e->get1st(), subst), *translate_enode_to_exprnode(var_map, e->get2nd(), subst)); case ENODE_ID_MAX: assert(e->getArity() == 2); return &max(*translate_enode_to_exprnode(var_map, e->get1st(), subst), *translate_enode_to_exprnode(var_map, e->get2nd(), subst)); case ENODE_ID_MATAN: // TODO(soonhok): MATAN throw logic_error("translateEnodeExprNode: MATAN"); case ENODE_ID_SAFESQRT: // TODO(soonhok): SAFESQRT throw logic_error("translateEnodeExprNode: SAFESQRT"); case ENODE_ID_SQRT: assert(e->getArity() == 1); return &sqrt(*translate_enode_to_exprnode(var_map, e->get1st(), subst)); case ENODE_ID_EXP: assert(e->getArity() == 1); return &exp(*translate_enode_to_exprnode(var_map, e->get1st(), subst)); case ENODE_ID_LOG: assert(e->getArity() == 1); return &log(*translate_enode_to_exprnode(var_map, e->get1st(), subst)); case ENODE_ID_POW: { assert(e->getArity() == 2); bool is_1st_constant = false; bool is_1st_int = false; bool is_2nd_constant = false; bool is_2nd_int = false; double dbl_1st = 0.0; int int_1st = 0; double dbl_2nd = 0.0; int int_2nd = 0; if (e->get1st()->isConstant()) { dbl_1st = e->get1st()->getValue(); is_1st_constant = true; double tmp; if (modf(dbl_1st, &tmp) == 0.0) { is_1st_int = true; int_1st = static_cast<int>(tmp); } } if (e->get2nd()->isConstant()) { dbl_2nd = e->get2nd()->getValue(); is_2nd_constant = true; double tmp; if (modf(dbl_2nd, &tmp) == 0.0) { is_2nd_int = true; int_2nd = static_cast<int>(tmp); } } if (is_1st_constant && is_2nd_constant) { // Both of them are constant, just compute and return a number return &ExprConstant::new_scalar(pow(dbl_1st, dbl_2nd)); } // Now, either of them is non-constant. if (is_1st_int) { return &pow(int_1st, *translate_enode_to_exprnode(var_map, e->get2nd(), subst)); } if (is_1st_constant) { return &pow(dbl_1st, *translate_enode_to_exprnode(var_map, e->get2nd(), subst)); } if (is_2nd_int) { return &pow(*translate_enode_to_exprnode(var_map, e->get1st(), subst), int_2nd); } if (is_2nd_constant) { return &pow(*translate_enode_to_exprnode(var_map, e->get1st(), subst), dbl_2nd); } return &pow(*translate_enode_to_exprnode(var_map, e->get1st(), subst), *translate_enode_to_exprnode(var_map, e->get2nd(), subst)); } case ENODE_ID_ABS: assert(e->getArity() == 1); return &abs(*translate_enode_to_exprnode(var_map, e->get1st(), subst)); case ENODE_ID_SIN: assert(e->getArity() == 1); return &sin(*translate_enode_to_exprnode(var_map, e->get1st(), subst)); case ENODE_ID_COS: assert(e->getArity() == 1); return &cos(*translate_enode_to_exprnode(var_map, e->get1st(), subst)); case ENODE_ID_TAN: assert(e->getArity() == 1); return &tan(*translate_enode_to_exprnode(var_map, e->get1st(), subst)); case ENODE_ID_SINH: assert(e->getArity() == 1); return &sinh(*translate_enode_to_exprnode(var_map, e->get1st(), subst)); case ENODE_ID_COSH: assert(e->getArity() == 1); return &cosh(*translate_enode_to_exprnode(var_map, e->get1st(), subst)); case ENODE_ID_TANH: assert(e->getArity() == 1); return &tanh(*translate_enode_to_exprnode(var_map, e->get1st(), subst)); default: throw logic_error("translateEnodeExprNode: Unknown Term"); } } else if (e->isList()) { throw logic_error("translateEnodeExprNode: List"); } else if (e->isDef()) { throw logic_error("translateEnodeExprNode: Def"); } else if (e->isEnil()) { throw logic_error("translateEnodeExprNode: Nil"); } else { throw logic_error("translateEnodeExprNode: unknown case"); } throw logic_error("Not implemented yet: translateEnodeExprNode"); }