bool TGlslOutputTraverser::traverseBinary( bool preVisit, TIntermBinary *node, TIntermTraverser *it ) { TString op = "??"; TGlslOutputTraverser* goit = static_cast<TGlslOutputTraverser*>(it); GlslFunction *current = goit->current; std::stringstream& out = current->getActiveOutput(); bool infix = true; bool assign = false; bool needsParens = true; switch (node->getOp()) { case EOpAssign: op = "="; infix = true; needsParens = false; break; case EOpAddAssign: op = "+="; infix = true; needsParens = false; break; case EOpSubAssign: op = "-="; infix = true; needsParens = false; break; case EOpMulAssign: op = "*="; infix = true; needsParens = false; break; case EOpVectorTimesMatrixAssign: op = "*="; infix = true; needsParens = false; break; case EOpVectorTimesScalarAssign: op = "*="; infix = true; needsParens = false; break; case EOpMatrixTimesScalarAssign: op = "*="; infix = true; needsParens = false; break; case EOpMatrixTimesMatrixAssign: op = "*="; infix = true; needsParens = false; break; case EOpDivAssign: op = "/="; infix = true; needsParens = false; break; case EOpModAssign: op = "%="; infix = true; needsParens = false; break; case EOpAndAssign: op = "&="; infix = true; needsParens = false; break; case EOpInclusiveOrAssign: op = "|="; infix = true; needsParens = false; break; case EOpExclusiveOrAssign: op = "^="; infix = true; needsParens = false; break; case EOpLeftShiftAssign: op = "<<="; infix = true; needsParens = false; break; case EOpRightShiftAssign: op = "??="; infix = true; needsParens = false; break; case EOpIndexDirect: { TIntermTyped *left = node->getLeft(); TIntermTyped *right = node->getRight(); assert( left && right); current->beginStatement(); if (Check2DMatrixIndex (goit, out, left, right)) return false; if (left->isMatrix() && !left->isArray()) { if (right->getAsConstant()) { current->addLibFunction (EOpMatrixIndex); out << "xll_matrixindex ("; left->traverse(goit); out << ", "; right->traverse(goit); out << ")"; return false; } else { 
current->addLibFunction (EOpTranspose); current->addLibFunction (EOpMatrixIndex); current->addLibFunction (EOpMatrixIndexDynamic); out << "xll_matrixindexdynamic ("; left->traverse(goit); out << ", "; right->traverse(goit); out << ")"; return false; } } left->traverse(goit); // Special code for handling a vector component select (this improves readability) if (left->isVector() && !left->isArray() && right->getAsConstant()) { char swiz[] = "xyzw"; goit->visitConstantUnion = TGlslOutputTraverser::traverseImmediateConstant; goit->generatingCode = false; right->traverse(goit); assert( goit->indexList.size() == 1); assert( goit->indexList[0] < 4); out << "." << swiz[goit->indexList[0]]; goit->indexList.clear(); goit->visitConstantUnion = TGlslOutputTraverser::traverseConstantUnion; goit->generatingCode = true; } else { out << "["; right->traverse(goit); out << "]"; } return false; } case EOpIndexIndirect: { TIntermTyped *left = node->getLeft(); TIntermTyped *right = node->getRight(); current->beginStatement(); if (Check2DMatrixIndex (goit, out, left, right)) return false; if (left && right && left->isMatrix() && !left->isArray()) { if (right->getAsConstant()) { current->addLibFunction (EOpMatrixIndex); out << "xll_matrixindex ("; left->traverse(goit); out << ", "; right->traverse(goit); out << ")"; return false; } else { current->addLibFunction (EOpTranspose); current->addLibFunction (EOpMatrixIndex); current->addLibFunction (EOpMatrixIndexDynamic); out << "xll_matrixindexdynamic ("; left->traverse(goit); out << ", "; right->traverse(goit); out << ")"; return false; } } if (left) left->traverse(goit); out << "["; if (right) right->traverse(goit); out << "]"; return false; } case EOpIndexDirectStruct: { current->beginStatement(); GlslStruct *s = goit->createStructFromType(node->getLeft()->getTypePointer()); if (node->getLeft()) node->getLeft()->traverse(goit); // The right child is always an offset into the struct, switch to get an // immediate constant, and put it back 
afterwords goit->visitConstantUnion = TGlslOutputTraverser::traverseImmediateConstant; goit->generatingCode = false; if (node->getRight()) { node->getRight()->traverse(goit); assert( goit->indexList.size() == 1); assert( goit->indexList[0] < s->memberCount()); out << "." << s->getMember(goit->indexList[0]).name; } goit->indexList.clear(); goit->visitConstantUnion = TGlslOutputTraverser::traverseConstantUnion; goit->generatingCode = true; } return false; case EOpVectorSwizzle: current->beginStatement(); if (node->getLeft()) node->getLeft()->traverse(goit); goit->visitConstantUnion = TGlslOutputTraverser::traverseImmediateConstant; goit->generatingCode = false; if (node->getRight()) { node->getRight()->traverse(goit); assert( goit->indexList.size() <= 4); out << '.'; const char fields[] = "xyzw"; for (int ii = 0; ii < (int)goit->indexList.size(); ii++) { int val = goit->indexList[ii]; assert( val >= 0); assert( val < 4); out << fields[val]; } } goit->indexList.clear(); goit->visitConstantUnion = TGlslOutputTraverser::traverseConstantUnion; goit->generatingCode = true; return false; case EOpMatrixSwizzle: // This presently only works for swizzles as rhs operators if (node->getRight()) { goit->visitConstantUnion = TGlslOutputTraverser::traverseImmediateConstant; goit->generatingCode = false; node->getRight()->traverse(goit); goit->visitConstantUnion = TGlslOutputTraverser::traverseConstantUnion; goit->generatingCode = true; std::vector<int> elements = goit->indexList; goit->indexList.clear(); if (elements.size() > 4 || elements.size() < 1) { goit->infoSink.info << "Matrix swizzle operations can must contain at least 1 and at most 4 element selectors."; return true; } unsigned column[4] = {0}, row[4] = {0}; for (unsigned i = 0; i != elements.size(); ++i) { unsigned val = elements[i]; column[i] = val % 4; row[i] = val / 4; } bool sameColumn = true; for (unsigned i = 1; i != elements.size(); ++i) sameColumn &= column[i] == column[i-1]; static const char* fields = "xyzw"; 
if (sameColumn) { //select column, then swizzle row if (node->getLeft()) node->getLeft()->traverse(goit); out << "[" << column[0] << "]."; for (unsigned i = 0; i < elements.size(); ++i) out << fields[row[i]]; } else { // Insert constructor, and dereference individually // Might need to account for different types here assert( elements.size() != 1); //should have hit same collumn case out << "vec" << elements.size() << "("; if (node->getLeft()) node->getLeft()->traverse(goit); out << "[" << column[0] << "]."; out << fields[row[0]]; for (unsigned i = 1; i < elements.size(); ++i) { out << ", "; if (node->getLeft()) node->getLeft()->traverse(goit); out << "[" << column[i] << "]."; out << fields[row[i]]; } out << ")"; } } return false; case EOpAdd: op = "+"; infix = true; break; case EOpSub: op = "-"; infix = true; break; case EOpMul: op = "*"; infix = true; break; case EOpDiv: op = "/"; infix = true; break; case EOpMod: op = "mod"; infix = false; break; case EOpRightShift: op = "<<"; infix = true; break; case EOpLeftShift: op = ">>"; infix = true; break; case EOpAnd: op = "&"; infix = true; break; case EOpInclusiveOr: op = "|"; infix = true; break; case EOpExclusiveOr: op = "^"; infix = true; break; case EOpEqual: writeComparison ( "==", "equal", node, goit ); return false; case EOpNotEqual: writeComparison ( "!=", "notEqual", node, goit ); return false; case EOpLessThan: writeComparison ( "<", "lessThan", node, goit ); return false; case EOpGreaterThan: writeComparison ( ">", "greaterThan", node, goit ); return false; case EOpLessThanEqual: writeComparison ( "<=", "lessThanEqual", node, goit ); return false; case EOpGreaterThanEqual: writeComparison ( ">=", "greaterThanEqual", node, goit ); return false; case EOpVectorTimesScalar: op = "*"; infix = true; break; case EOpVectorTimesMatrix: op = "*"; infix = true; break; case EOpMatrixTimesVector: op = "*"; infix = true; break; case EOpMatrixTimesScalar: op = "*"; infix = true; break; case EOpMatrixTimesMatrix: op = "*"; 
infix = true; break; case EOpLogicalOr: op = "||"; infix = true; break; case EOpLogicalXor: op = "^^"; infix = true; break; case EOpLogicalAnd: op = "&&"; infix = true; break; default: assert(0); } current->beginStatement(); if (infix) { // special case for swizzled matrix assignment if (node->getOp() == EOpAssign && node->getLeft() && node->getRight()) { TIntermBinary* lval = node->getLeft()->getAsBinaryNode(); if (lval && lval->getOp() == EOpMatrixSwizzle) { static const char* vec_swizzles = "xyzw"; TIntermTyped* rval = node->getRight(); TIntermTyped* lexp = lval->getLeft(); goit->visitConstantUnion = TGlslOutputTraverser::traverseImmediateConstant; goit->generatingCode = false; lval->getRight()->traverse(goit); goit->visitConstantUnion = TGlslOutputTraverser::traverseConstantUnion; goit->generatingCode = true; std::vector<int> swizzles = goit->indexList; goit->indexList.clear(); char temp_rval[128]; unsigned n_swizzles = swizzles.size(); if (n_swizzles > 1) { snprintf(temp_rval, 128, "xlat_swiztemp%d", goit->swizzleAssignTempCounter++); current->beginStatement(); out << "vec" << n_swizzles << " " << temp_rval << " = "; rval->traverse(goit); current->endStatement(); } for (unsigned i = 0; i != n_swizzles; ++i) { unsigned col = swizzles[i] / 4; unsigned row = swizzles[i] % 4; current->beginStatement(); lexp->traverse(goit); out << "[" << row << "][" << col << "] = "; if (n_swizzles > 1) out << temp_rval << "." 
<< vec_swizzles[i]; else rval->traverse(goit); current->endStatement(); } return false; } } if (needsParens) out << '('; if (node->getLeft()) node->getLeft()->traverse(goit); out << ' ' << op << ' '; if (node->getRight()) node->getRight()->traverse(goit); if (needsParens) out << ')'; } else { if (assign) { // Need to traverse the left child twice to allow for the assign and the op // This is OK, because we know it is an lvalue if (node->getLeft()) node->getLeft()->traverse(goit); out << " = " << op << '('; if (node->getLeft()) node->getLeft()->traverse(goit); out << ", "; if (node->getRight()) node->getRight()->traverse(goit); out << ')'; } else { out << op << '('; if (node->getLeft()) node->getLeft()->traverse(goit); out << ", "; if (node->getRight()) node->getRight()->traverse(goit); out << ')'; } } return false; }
// Rebuilds the argument list of a vector/matrix constructor aggregate.
//
// Every original argument is first passed to createTempVariable(), and the
// argument list is then refilled with symbols named after the returned
// temporary:
//   - scalar arguments become one symbol each;
//   - vector arguments become per-component index nodes when scalarizeVector
//     is set, otherwise a single symbol;
//   - matrix arguments become per-element index nodes (walked column by
//     column) when scalarizeMatrix is set, otherwise a single symbol.
// Scalarized arguments never contribute more components than the constructor
// still needs.
void ScalarizeVecAndMatConstructorArgs::scalarizeArgs(
    TIntermAggregate *aggregate, bool scalarizeVector, bool scalarizeMatrix)
{
    ASSERT(aggregate);

    // Number of scalar components the constructed type still needs.
    int remaining = 0;
    switch (aggregate->getOp())
    {
      case EOpConstructVec2:
      case EOpConstructBVec2:
      case EOpConstructIVec2:
        remaining = 2;
        break;
      case EOpConstructVec3:
      case EOpConstructBVec3:
      case EOpConstructIVec3:
        remaining = 3;
        break;
      case EOpConstructVec4:
      case EOpConstructBVec4:
      case EOpConstructIVec4:
      case EOpConstructMat2:
        remaining = 4;
        break;
      case EOpConstructMat2x3:
      case EOpConstructMat3x2:
        remaining = 6;
        break;
      case EOpConstructMat2x4:
      case EOpConstructMat4x2:
        remaining = 8;
        break;
      case EOpConstructMat3:
        remaining = 9;
        break;
      case EOpConstructMat3x4:
      case EOpConstructMat4x3:
        remaining = 12;
        break;
      case EOpConstructMat4:
        remaining = 16;
        break;
      default:
        break;
    }

    // Work from a copy of the old arguments while the live list is rebuilt.
    TIntermSequence *args = aggregate->getSequence();
    TIntermSequence oldArgs(*args);
    args->clear();

    for (size_t argIndex = 0; argIndex < oldArgs.size(); ++argIndex)
    {
        ASSERT(remaining > 0);
        TIntermTyped *arg = oldArgs[argIndex]->getAsTyped();
        ASSERT(arg);
        TString tempName = createTempVariable(arg);

        if (arg->isScalar())
        {
            TIntermSymbol *symbol = new TIntermSymbol(-1, tempName, arg->getType());
            args->push_back(symbol);
            --remaining;
            continue;
        }

        if (arg->isVector())
        {
            if (!scalarizeVector)
            {
                // Keep the vector whole; it uses up all of its components.
                TIntermSymbol *symbol = new TIntermSymbol(-1, tempName, arg->getType());
                args->push_back(symbol);
                remaining -= arg->getNominalSize();
                continue;
            }

            const int count = std::min(remaining, arg->getNominalSize());
            remaining -= count;
            for (int component = 0; component < count; ++component)
            {
                TIntermSymbol *symbol = new TIntermSymbol(-1, tempName, arg->getType());
                TIntermBinary *indexed = ConstructVectorIndexBinaryNode(symbol, component);
                args->push_back(indexed);
            }
            continue;
        }

        ASSERT(arg->isMatrix());
        if (!scalarizeMatrix)
        {
            // Keep the matrix whole; it uses up all of its elements.
            TIntermSymbol *symbol = new TIntermSymbol(-1, tempName, arg->getType());
            args->push_back(symbol);
            remaining -= arg->getCols() * arg->getRows();
            continue;
        }

        const int count = std::min(remaining, arg->getCols() * arg->getRows());
        remaining -= count;

        // Walk down each column before moving on to the next one.
        int col = 0;
        int row = 0;
        for (int emitted = 0; emitted < count; ++emitted)
        {
            TIntermSymbol *symbol = new TIntermSymbol(-1, tempName, arg->getType());
            TIntermBinary *indexed = ConstructMatrixIndexBinaryNode(symbol, col, row);
            args->push_back(indexed);

            ++row;
            if (row >= arg->getRows())
            {
                row = 0;
                ++col;
            }
        }
    }
}