// Visitor case for binary-operator nodes.
// Returns true if the searched-for expression `exp` equals this node itself;
// otherwise visits the left and right operands via apply() and returns true if
// either visit does. Both operands are always visited (no short-circuit), in
// left-then-right order.
bool contains::operator()(BinOp v) const
{
    if (exp == Expression(v))
        return true;

    Expression leftOperand = v.left();
    Expression rightOperand = v.right();

    const bool foundInLeft = apply(leftOperand);
    const bool foundInRight = apply(rightOperand);
    return foundInLeft || foundInRight;
}
// Builds an output Array<To> of shape `odims` whose node is a
// common::BinaryNode combining the nodes of `lhs` and `rhs` with operator `op`.
// The BinaryNode is given the output type names, the operator's name (from
// BinOp<To, Ti, op>) and the numeric operator id.
Array<To> createBinaryNode(const Array<Ti> &lhs, const Array<Ti> &rhs,
                           const af::dim4 &odims)
{
    BinOp<To, Ti, op> opTraits;
    common::Node_ptr leftNode = lhs.getNode();
    common::Node_ptr rightNode = rhs.getNode();

    const int opId = static_cast<int>(op);
    common::BinaryNode *binaryNode =
        new common::BinaryNode(getFullName<To>(), shortname<To>(true),
                               opTraits.name(), leftNode, rightNode, opId);

    // Ownership of the freshly allocated node passes to the Node_ptr.
    return createNodeArray<To>(odims, common::Node_ptr(binaryNode));
}
// Builds an output Array<To> of shape `odims` whose node is a JIT::BinaryNode
// combining the nodes of `lhs` and `rhs` with operator `op`.
// The BinaryNode is given the output type names, the operator's name (from
// BinOp<To, Ti, op>) and the numeric operator id.
Array<To> createBinaryNode(const Array<Ti> &lhs, const Array<Ti> &rhs,
                           const af::dim4 &odims)
{
    BinOp<To, Ti, op> bop;
    JIT::Node_ptr lhs_node = lhs.getNode();
    JIT::Node_ptr rhs_node = rhs.getNode();

    JIT::BinaryNode *node =
        new JIT::BinaryNode(dtype_traits<To>::getName(), shortname<To>(true),
                            bop.name(), lhs_node, rhs_node,
                            static_cast<int>(op));

    // The previous reinterpret_cast<JIT::Node *> bypassed type checking and is
    // incorrect if BinaryNode's Node subobject is not at offset 0. Passing the
    // derived pointer lets the compiler perform (and verify) the upcast.
    // NOTE(review): assumes BinaryNode publicly derives from JIT::Node.
    return createNodeArray<To>(odims, JIT::Node_ptr(node));
}
void *calc(int x, int y, int z, int w) { m_val = m_op.eval(*(Ti *)m_lhs->calc(x, y, z, w), *(Ti *)m_rhs->calc(x, y, z, w)); return (void *)&m_val; }
// Performs one local simplification step on the signal expression `sig`.
//
// The signal is matched, in this order, against:
//   - an xtended primitive,
//   - a binary operator,
//   - a one-sample delay,
//   - a fixed delay,
//   - an int cast,
//   - a float cast,
//   - a 2-way select,
//   - a 3-way select.
// Each case folds constant or trivial forms where it can. A signal that
// matches none of them, or cannot be folded, is returned unchanged.
static Tree simplification (Tree sig)
{
    assert(sig);

    // Primitive (xtended) elements compute their own output from the
    // signal's branches.
    if (xtended* prim = static_cast<xtended*>(getUserData(sig))) {
        vector<Tree> operands;
        for (int k = 0; k < sig->arity(); k++) {
            operands.push_back(sig->branch(k));
        }
        if (prim == gPowPrim) {
            // pow results are routed through normalizeAddTerm (original note:
            // "to avoid negative power to further normalization").
            return normalizeAddTerm(prim->computeSigOutput(operands));
        }
        return prim->computeSigOutput(operands);
    }

    int opcode;
    Tree arg1, arg2, arg3, arg4;

    if (isSigBinOp(sig, &opcode, arg1, arg2)) {
        BinOp* binop = gBinOpTable[opcode];
        Node lhs = arg1->node();
        Node rhs = arg2->node();

        if (isNum(lhs) && isNum(rhs)) return tree(binop->compute(lhs, rhs));  // constant fold
        if (binop->isLeftNeutral(lhs)) return arg2;
        if (binop->isRightNeutral(rhs)) return arg1;
        return normalizeAddTerm(sig);
    }

    if (isSigDelay1(sig, arg1)) {
        return normalizeDelay1Term(arg1);
    }

    if (isSigFixDelay(sig, arg1, arg2)) {
        return normalizeFixedDelayTerm(arg1, arg2);
    }

    if (isSigIntCast(sig, arg1)) {
        Node inner = arg1->node();
        Tree nested;
        int ival;
        double dval;

        if (isInt(inner, &ival)) return arg1;                 // already an int
        if (isDouble(inner, &dval)) return tree(int(dval));   // fold the cast
        if (isSigIntCast(arg1, nested)) return arg1;          // int(int(x)) == int(x)
        return sig;
    }

    if (isSigFloatCast(sig, arg1)) {
        Node inner = arg1->node();
        Tree nested;
        int ival;
        double dval;

        if (isInt(inner, &ival)) return tree(double(ival));   // fold the cast
        if (isDouble(inner, &dval)) return arg1;              // already a double
        if (isSigFloatCast(arg1, nested)) return arg1;        // float(float(x)) == float(x)
        return sig;
    }

    if (isSigSelect2(sig, arg1, arg2, arg3)) {
        Node selector = arg1->node();

        if (isZero(selector)) return arg2;
        if (isNum(selector)) return arg3;   // any other constant selects the second branch
        if (arg2 == arg3) return arg2;      // both branches identical
        return sig;
    }

    if (isSigSelect3(sig, arg1, arg2, arg3, arg4)) {
        Node selector = arg1->node();

        if (isZero(selector)) return arg2;
        if (isOne(selector)) return arg3;
        if (isNum(selector)) return arg4;
        // Last two branches identical: reduce to a 2-way select and simplify that.
        if (arg3 == arg4) return simplification(sigSelect2(arg1, arg2, arg3));
        return sig;
    }

    return sig;
}
/// @brief Computes the result type of expression @p p, recording it on the node
///        via SetTypeInfo for operators and calls, and reporting type errors
///        through AddError.
///
/// @param p         Expression node to type-check.
/// @param doNotUse  Optional symbol that must not be referenced anywhere inside
///                  @p p (used to reject a symbol appearing in its own
///                  definition). It is propagated into every sub-expression.
/// @return The expression's type, or nullptr when an error was reported or an
///         identifier's type is not (yet) known.
TypeInfo const* SecondPass::GetResultTypeOf(Expr* p, Symbol const* doNotUse = nullptr){
    NodeType nt = p->GetNodeType();
    switch (nt){
    case NodeType::IntLit:
    case NodeType::FloatLit:
    case NodeType::BoolLit:
    case NodeType::StringLit:
        // Literal types were assigned when the node was created.
        return p->GetTypeInfo();

    case NodeType::Ident: {
        Ident* n = static_cast<Ident*>(p);
        if (n->GetTypeInfo() == nullptr)
            return nullptr;
        if (n->GetSymbol()->type == Symbol::FUNCTION){
            AddError(n->GetToken(), "Symbol '%s' is a function, but used like a variable.", n->GetName().c_str());
            return nullptr;
        }
        if (doNotUse != nullptr && n->GetSymbol() == doNotUse){
            AddError(n->GetToken(), "Symbol '%s' is used in its definition.", n->GetName().c_str());
            return nullptr;
        }
        return p->GetTypeInfo();
    }

    case NodeType::UnOp: {
        UnOp* n = static_cast<UnOp*>(p);
        UnOp::Types ntype = n->GetType();
        // Propagate doNotUse so e.g. `x = -x` is still caught.
        TypeInfo const* ti = GetResultTypeOf(n->GetExpr(), doNotUse);
        if (ti == nullptr)
            return nullptr;
        if (ntype == UnOp::Neg && ti->name != "int" && ti->name != "float"){
            AddError(n->GetExpr()->GetToken(), "Type mismatch: Expected numeric type but found '%s' in operation '%s'.", ti->name.c_str(), UnOp::GetTypeAsString(ntype));
            return nullptr;
        }
        if (ntype == UnOp::Not && ti->name != "bool"){
            AddError(n->GetExpr()->GetToken(), "Type mismatch: Expected 'bool' type but found '%s' in operation '%s'.", ti->name.c_str(), UnOp::GetTypeAsString(ntype));
            return nullptr;
        }
        // Negation keeps the operand's numeric type; other unary operators yield bool.
        if (ntype == UnOp::Neg)
            p->SetTypeInfo(ti);
        else
            p->SetTypeInfo(m_typeTable.Get("bool"));
        return p->GetTypeInfo();
    }

    case NodeType::BinOp: {
        BinOp* n = static_cast<BinOp*>(p);
        TypeInfo const* lt = GetResultTypeOf(n->GetLeft(), doNotUse);
        TypeInfo const* rt = GetResultTypeOf(n->GetRight(), doNotUse);
        if (lt == nullptr || rt == nullptr)
            return nullptr;

        switch (n->GetType()){
        case BinOp::Equal:
        case BinOp::Unequal: {
            if (lt != rt){
                AddError(n->GetToken(), "Type mismatch: Expected same types but found '%s' and '%s' in operation '%s'.", lt->name.c_str(), rt->name.c_str(), BinOp::GetTypeAsString(n->GetType()));
                return nullptr;
            }
            p->SetTypeInfo(m_typeTable.Get("bool"));
            return p->GetTypeInfo();
        }
        case BinOp::LThan:
        case BinOp::LThanEq:
        case BinOp::GThan:
        case BinOp::GThanEq: {
            // Operands must be the same numeric type. (The previous
            // `name != "int" || name != "float"` was always true.)
            if (lt != rt || (lt->name != "int" && lt->name != "float")){
                AddError(n->GetToken(), "Type mismatch: Expected matching numeric types but found '%s' and '%s' in operation '%s'.", lt->name.c_str(), rt->name.c_str(), BinOp::GetTypeAsString(n->GetType()));
                return nullptr;
            }
            p->SetTypeInfo(m_typeTable.Get("bool"));
            return p->GetTypeInfo();
        }
        case BinOp::And:
        case BinOp::Or:
        case BinOp::Xor: {
            if (lt->name != "bool" || rt->name != "bool"){
                AddError(n->GetToken(), "Type mismatch: Expected 'bool' types but found '%s' and '%s' in operation '%s'.", lt->name.c_str(), rt->name.c_str(), BinOp::GetTypeAsString(n->GetType()));
                return nullptr;
            }
            p->SetTypeInfo(m_typeTable.Get("bool"));
            return p->GetTypeInfo();
        }
        case BinOp::Sub:
        case BinOp::Mul:
        case BinOp::Div: {
            if (lt != rt || (lt->name != "int" && lt->name != "float")){
                AddError(n->GetToken(), "Type mismatch: Expected matching numeric types but found '%s' and '%s' in operation '%s'.", lt->name.c_str(), rt->name.c_str(), BinOp::GetTypeAsString(n->GetType()));
                return nullptr;
            }
            p->SetTypeInfo(lt);
            return p->GetTypeInfo();
        }
        case BinOp::Mod: {
            if (lt != rt || lt->name != "int"){
                AddError(n->GetToken(), "Type mismatch: Expected 'int' types but found '%s' and '%s' in operation '%s'.", lt->name.c_str(), rt->name.c_str(), BinOp::GetTypeAsString(n->GetType()));
                return nullptr;
            }
            p->SetTypeInfo(lt);
            return p->GetTypeInfo();
        }
        case BinOp::Add: {
            // A string on either side makes this a concatenation with any type.
            if (lt->name == "string" || rt->name == "string"){
                p->SetTypeInfo(m_typeTable.Get("string"));
                return p->GetTypeInfo();
            }
            if (lt != rt || (lt->name != "int" && lt->name != "float")){
                AddError(n->GetToken(), "Type mismatch: Expected matching numeric types or strings but found '%s' and '%s' in operation '%s'.", lt->name.c_str(), rt->name.c_str(), BinOp::GetTypeAsString(n->GetType()));
                return nullptr;
            }
            p->SetTypeInfo(lt);
            return p->GetTypeInfo();
        }
        case BinOp::Subscript: {
            // The subscripted (left) operand must be an array; report its own
            // type and token.
            if (!lt->isArray){
                AddError(n->GetLeft()->GetToken(), "Type mismatch: Expected array type but found '%s' in operation '%s'.", lt->name.c_str(), BinOp::GetTypeAsString(n->GetType()));
                return nullptr;
            }
            if (rt->name != "int"){
                AddError(n->GetRight()->GetToken(), "Type mismatch: Expected 'int' type but found '%s' in operation '%s'.", rt->name.c_str(), BinOp::GetTypeAsString(n->GetType()));
                return nullptr;
            }
            auto simpleType = m_typeTable.Get(lt->GetSimpleTypeName());
            _ASSERT(simpleType != nullptr);
            p->SetTypeInfo(simpleType);
            return p->GetTypeInfo();
        }
        default:
            __debugbreak();
            return nullptr;
        }
    }

    case NodeType::FuncCallExpr: {
        FuncCallExpr* n = static_cast<FuncCallExpr*>(p);
        Symbol const* calleeSym = n->GetCallee()->GetSymbol();
        if (calleeSym->type != Symbol::FUNCTION){
            AddError(n->GetCallee()->GetToken(), "Symbol '%s' is not a function, but used as one.", calleeSym->name.c_str());
            return nullptr;
        }
        // The call's type is the callee's declared type, even if arguments are wrong.
        p->SetTypeInfo(calleeSym->GetTypeInfo());

        const auto& args = n->GetArgs()->GetChildren();
        const auto& params = calleeSym->GetParamList()->GetChildren();
        if (args.size() != params.size()){
            // Cast: passing size_t to %d through varargs is undefined behavior.
            AddError(n->GetCallee()->GetToken(), "Function requires %d arguments, but %d supplied.", static_cast<int>(params.size()), static_cast<int>(args.size()));
            return p->GetTypeInfo();
        }
        for (size_t i = 0; i < args.size(); ++i){
            TypeInfo const* actual = GetResultTypeOf(args[i].get(), doNotUse);
            if (actual == nullptr)
                continue;  // no usable argument type (already reported, or unknown); avoid null dereference
            TypeInfo const* required = params[i]->GetType()->GetTypeInfo();
            if (actual != required){
                AddError(args[i]->GetToken(), "%dth argument is type '%s' but must be of type '%s'.", static_cast<int>(i + 1), actual->name.c_str(), required->name.c_str());
            }
        }
        return p->GetTypeInfo();
    }

    default:
        __debugbreak();
        return nullptr;
    }
}