bool AnalysisVisitor::getDimension(SymbolicDimension & dim, ast::Exp & arg, bool & safe, SymbolicDimension & out) { switch (arg.getType()) { case ast::Exp::COLONVAR : { out = dim; safe = true; arg.getDecorator().setDollarInfo(argIndices.top()); return true; } case ast::Exp::DOLLARVAR : // a($) { out = SymbolicDimension(getGVN(), 1.); safe = true; arg.getDecorator().setDollarInfo(argIndices.top()); return true; } case ast::Exp::DOUBLEEXP : // a(12) or a([1 2]) { ast::DoubleExp & de = static_cast<ast::DoubleExp &>(arg); if (types::InternalType * const pIT = de.getConstant()) { if (pIT->isDouble()) { types::Double * const pDbl = static_cast<types::Double *>(pIT); if (pDbl->isEmpty()) { out = SymbolicDimension(getGVN(), 0.); safe = true; return true; } const double * real = pDbl->getReal(); const int size = pDbl->getSize(); int64_t max; if (tools::asInteger(real[0], max)) { int64_t min = max; if (!pDbl->isComplex()) { for (int i = 0; i < size; ++i) { int64_t _real; if (tools::asInteger(real[i], _real)) { if (_real < min) { min = _real; } else if (_real > max) { max = _real; } } else { return false; } } out = SymbolicDimension(getGVN(), size); safe = (min >= 1) && getCM().check(ConstraintManager::GREATER, dim.getValue(), getGVN().getValue(max)); return true; } else { const double * imag = pDbl->getImg(); int i; for (i = 0; i < size; ++i) { if (imag[i]) { break; } int64_t _real; if (tools::asInteger(real[i], _real)) { if (_real < min) { min = _real; } else if (_real > max) { max = _real; } } } if (i == size) { out = SymbolicDimension(getGVN(), size); safe = (min >= 1) && getCM().check(ConstraintManager::GREATER, dim.getValue(), getGVN().getValue(max)); return true; } else { return false; } } } else { return false; } } else if (pIT->isImplicitList()) { types::ImplicitList * const pIL = static_cast<types::ImplicitList *>(pIT); double start, step, end; if (AnalysisVisitor::asDouble(pIL->getStart(), start) && AnalysisVisitor::asDouble(pIL->getStep(), step) && AnalysisVisitor::asDouble(pIL->getEnd(), end)) { double single; const int type = ForList64::checkList(start, end, step, single); switch (type) { case 0 : { out = SymbolicDimension(getGVN(), 0.); safe = true; return true; } case 1 : { out = SymbolicDimension(getGVN(), 1.); safe = false; return true; } case 2 : { const uint64_t N = ForList64::size(start, end, step); uint64_t max, min; if (step > 0) { min = start; max = (uint64_t)(start + (N - 1) * step); } else { max = start; min = (uint64_t)(start + (N - 1) * step); } out = SymbolicDimension(getGVN(), N); safe = (min >= 1) && getCM().check(ConstraintManager::GREATER, dim.getValue(), getGVN().getValue((int64_t)max)); return true; } } } } } else { out = SymbolicDimension(getGVN(), 1.); safe = (de.getValue() >= 1) && getCM().check(ConstraintManager::GREATER, dim.getValue(), getGVN().getValue(de.getValue())); return true; } return false; } case ast::Exp::BOOLEXP : // a(a > 1) => a([%f %t %t]) => a([2 3]) { ast::BoolExp & be = static_cast<ast::BoolExp &>(arg); if (types::InternalType * const pIT = be.getConstant()) { if (pIT->isBool()) { types::Bool * const pBool = static_cast<types::Bool *>(pIT); const int size = pBool->getSize(); const int * data = pBool->get(); int64_t max = -1; int64_t count = 0; for (int i = 0; i < size; ++i) { if (data[i]) { ++count; max = i; } } out = SymbolicDimension(getGVN(), count); safe = getCM().check(ConstraintManager::GREATER, dim.getValue(), getGVN().getValue(max)); return true; } } else { if (be.getValue()) { out = SymbolicDimension(getGVN(), int64_t(1)); } else { out = SymbolicDimension(getGVN(), int64_t(0)); } safe = true; return true; } return false; } case ast::Exp::LISTEXP : { ast::ListExp & le = static_cast<ast::ListExp &>(arg); SymbolicList sl; if (SymbolicList::get(*this, le, sl)) { if (sl.isSymbolic()) { sl.evalDollar(getGVN(), dim.getValue()); } TIType typ; if (sl.getType(getGVN(), typ)) { out = SymbolicDimension(getGVN(), typ.cols.getValue()); safe = false;//getCM().check(ConstraintManager::GREATER, dim.getValue(), getGVN().getValue(max)); return true; } } return false; } default : { arg.accept(*this); Result & _res = getResult(); SymbolicRange & range = _res.getRange(); if (range.isValid()) { //std::wcerr << *range.getStart()->poly << ":" << *range.getEnd()->poly << ",," << *dim.getValue()->poly << std::endl; safe = getCM().check(ConstraintManager::VALID_RANGE, range.getStart(), range.getEnd(), getGVN().getValue(int64_t(1)), dim.getValue()); out = _res.getType().rows * _res.getType().cols; return true; } if (GVN::Value * const v = _res.getConstant().getGVNValue()) { GVN::Value * w = v; if (GVN::Value * const dollar = getGVN().getExistingValue(symbol::Symbol(L"$"))) { if (GVN::Value * const x = SymbolicList::evalDollar(getGVN(), v, dollar, dim.getValue())) { w = x; } } bool b = getCM().check(ConstraintManager::GREATER, dim.getValue(), w); if (b) { safe = getCM().check(ConstraintManager::STRICT_POSITIVE, w); } else { safe = false; } out = SymbolicDimension(getGVN(), 1); return true; } // To use with find // e.g. a(find(a > 0)): find(a > 0) return a matrix where the max index is rc(a) so the extraction is safe if (_res.getType().ismatrix() && _res.getType().type != TIType::BOOLEAN) { out = _res.getType().rows * _res.getType().cols; SymbolicDimension & maxIndex = _res.getMaxIndex(); if (maxIndex.isValid()) { safe = getCM().check(ConstraintManager::GREATER, dim.getValue(), maxIndex.getValue()); } else { safe = false; } return true; } return false; } } }
bool MatrixAnalyzer::analyze(AnalysisVisitor & visitor, const unsigned int lhs, ast::CallExp & e) { if (lhs > 1) { return false; } const ast::exps_t args = e.getArgs(); const unsigned int size = args.size(); if (size != 2 && size != 3) { return false; } if (size == 2) { return analyze2Args(visitor, args, e); } ast::Exp * first = args[0]; ast::Exp * second = args[1]; ast::Exp * third = args[2]; first->accept(visitor); Result R1 = visitor.getResult(); if (!R1.getType().ismatrix()) { return false; } second->accept(visitor); Result R2 = visitor.getResult(); third->accept(visitor); Result & R3 = visitor.getResult(); double val; SymbolicDimension rows; SymbolicDimension cols; if (R2.getConstant().getDblValue(val)) { const int nrows = tools::cast<int>(val); if (nrows <= 0) { return false; } else { rows = SymbolicDimension(visitor.getGVN(), nrows); } } else if (GVN::Value * gvnValue = R2.getConstant().getGVNValue()) { if (gvnValue->poly->isConstant() && gvnValue->poly->constant <= 0) { return false; } rows.setValue(gvnValue); rows.setGVN(&visitor.getGVN()); } else { return false; } if (R3.getConstant().getDblValue(val)) { const int ncols = tools::cast<int>(val); if (ncols <= 0) { return false; } else { cols = SymbolicDimension(visitor.getGVN(), ncols); } } else if (GVN::Value * gvnValue = R3.getConstant().getGVNValue()) { if (gvnValue->poly->isConstant() && gvnValue->poly->constant <= 0) { return false; } cols.setValue(gvnValue); cols.setGVN(&visitor.getGVN()); } else { return false; } const TIType & type = R1.getType(); SymbolicDimension prod1 = type.rows * type.cols; SymbolicDimension prod2 = rows * cols; bool res = visitor.getCM().check(ConstraintManager::EQUAL, prod1.getValue(), prod2.getValue()); if (res) { res = visitor.getCM().check(ConstraintManager::POSITIVE, rows.getValue()); if (!res) { return false; } } else { return false; } TIType resT(visitor.getGVN(), R1.getType().type, rows, cols); int tempId; if (R1.getTempId() != -1) { tempId = R1.getTempId(); } else { tempId = visitor.getDM().getTmpId(resT, false); } Result & _res = e.getDecorator().setResult(Result(resT, tempId)); visitor.setResult(_res); return true; }
bool SizeAnalyzer::analyze(AnalysisVisitor & visitor, const unsigned int lhs, ast::CallExp & e) { if (lhs > 2) { return false; } const ast::exps_t args = e.getArgs(); enum Kind { ROWS, COLS, ROWSTIMESCOLS, ROWSCOLS, ONE, BOTH, DUNNO } kind = DUNNO; const std::size_t size = args.size(); if (size == 0 || size >= 3) { return false; } ast::Exp * first = *args.begin(); if (!first) { return false; } first->accept(visitor); Result & res = visitor.getResult(); if (!res.getType().ismatrix()) { visitor.getDM().releaseTmp(res.getTempId(), first); return false; } switch (size) { case 1: if (lhs == 1) { kind = BOTH; } else if (lhs == 2) { kind = ROWSCOLS; } break; case 2: { ast::Exp * second = *std::next(args.begin()); if (second && lhs == 1) { if (second->isStringExp()) { const std::wstring & arg2 = static_cast<ast::StringExp *>(second)->getValue(); if (arg2 == L"r") { kind = ROWS; } else if (arg2 == L"c") { kind = COLS; } else if (arg2 == L"*") { kind = ROWSTIMESCOLS; } else { visitor.getDM().releaseTmp(res.getTempId(), first); return false; } } else if (second->isDoubleExp()) { const double arg2 = static_cast<ast::DoubleExp *>(second)->getValue(); if (arg2 == 1) { kind = ROWS; } else if (arg2 == 2) { kind = COLS; } else if (arg2 >= 3) { // TODO: we should handle hypermatrix kind = ONE; } else { visitor.getDM().releaseTmp(res.getTempId(), first); return false; } } } else { visitor.getDM().releaseTmp(res.getTempId(), first); return false; } break; } default: visitor.getDM().releaseTmp(res.getTempId(), first); return false; } TIType type(visitor.getGVN(), TIType::DOUBLE); switch (kind) { case ROWS: { SymbolicDimension & rows = res.getType().rows; Result & _res = e.getDecorator().setResult(type); _res.getConstant() = rows.getValue(); e.getDecorator().setCall(new SizeCall(SizeCall::R)); visitor.setResult(_res); break; } case COLS: { SymbolicDimension & cols = res.getType().cols; Result & _res = e.getDecorator().setResult(type); _res.getConstant() = cols.getValue(); e.getDecorator().setCall(new SizeCall(SizeCall::C)); visitor.setResult(_res); break; } case ROWSTIMESCOLS: { SymbolicDimension & rows = res.getType().rows; SymbolicDimension & cols = res.getType().cols; SymbolicDimension prod = rows * cols; Result & _res = e.getDecorator().setResult(type); _res.getConstant() = prod.getValue(); e.getDecorator().setCall(new SizeCall(SizeCall::RC)); visitor.setResult(_res); break; } case ROWSCOLS: { SymbolicDimension & rows = res.getType().rows; SymbolicDimension & cols = res.getType().cols; std::vector<Result> & mlhs = visitor.getLHSContainer(); mlhs.clear(); mlhs.reserve(2); mlhs.emplace_back(type); mlhs.back().getConstant() = rows.getValue(); mlhs.emplace_back(type); mlhs.back().getConstant() = cols.getValue(); e.getDecorator().setCall(new SizeCall(SizeCall::R_C)); break; } case ONE: { Result & _res = e.getDecorator().setResult(type); _res.getConstant() = new types::Double(1); e.getDecorator().setCall(new SizeCall(SizeCall::ONE)); visitor.setResult(_res); break; } case BOTH: { TIType _type(visitor.getGVN(), TIType::DOUBLE, 1, 2); Result & _res = e.getDecorator().setResult(_type); e.getDecorator().setCall(new SizeCall(SizeCall::BOTH)); visitor.setResult(_res); break; } default: return false; } return true; }