//
// Folds an operation on this constant node into a new constant node, so that
// no run-time code has to be generated for it.
//
//   op           - the operator to apply.
//   constantNode - right-hand operand of a binary operation; pass 0 to fold a
//                  unary operation on this node instead.
//   infoSink     - receives divide-by-zero warnings and internal-error messages.
//
// Returns a newly created TIntermConstantUnion carrying the folded value (its
// line is copied from this node), or 0 when the operation cannot be folded.
//
TIntermTyped* TIntermConstantUnion::fold(TOperator op, TIntermTyped* constantNode, TInfoSink& infoSink)
{
    ConstantUnion *unionArray = getUnionArrayPointer();
    if (!unionArray)
        return 0;   // nothing to fold from

    size_t objectSize = getType().getObjectSize();

    if (constantNode) {  // binary operations
        TIntermConstantUnion *node = constantNode->getAsConstantUnion();
        if (!node)
            return 0;   // right operand is not a constant union; cannot fold

        ConstantUnion *rightUnionArray = node->getUnionArrayPointer();
        if (!rightUnionArray)
            return 0;

        TType returnType = getType();

        // When exactly one operand is a single component, replicate it so that
        // both operands can be walked component by component below.
        if (constantNode->getType().getObjectSize() == 1 && objectSize > 1) {
            // right operand is the single-component one, e.g. vec4(2,3,4,5) + 1.2
            rightUnionArray = new ConstantUnion[objectSize];
            for (size_t i = 0; i < objectSize; ++i)
                rightUnionArray[i] = *node->getUnionArrayPointer();
        } else if (constantNode->getType().getObjectSize() > 1 && objectSize == 1) {
            // left operand is the single-component one, e.g. 1.2 + vec4(2,3,4,5)
            unionArray = new ConstantUnion[constantNode->getType().getObjectSize()];
            for (size_t i = 0; i < constantNode->getType().getObjectSize(); ++i)
                unionArray[i] = *getUnionArrayPointer();
            returnType = node->getType();
            objectSize = constantNode->getType().getObjectSize();
        }

        ConstantUnion* tempConstArray = 0;
        TIntermConstantUnion *tempNode;

        // For EOpEqual / EOpNotEqual: set as soon as the operands are found to differ.
        bool boolNodeFlag = false;
        switch(op) {
            case EOpAdd:
                tempConstArray = new ConstantUnion[objectSize];
                {// support MSVC++6.0
                    for (size_t i = 0; i < objectSize; i++)
                        tempConstArray[i] = unionArray[i] + rightUnionArray[i];
                }
                break;
            case EOpSub:
                tempConstArray = new ConstantUnion[objectSize];
                {// support MSVC++6.0
                    for (size_t i = 0; i < objectSize; i++)
                        tempConstArray[i] = unionArray[i] - rightUnionArray[i];
                }
                break;

            case EOpMul:
            case EOpVectorTimesScalar:
            case EOpMatrixTimesScalar:
                tempConstArray = new ConstantUnion[objectSize];
                {// support MSVC++6.0
                    for (size_t i = 0; i < objectSize; i++)
                        tempConstArray[i] = unionArray[i] * rightUnionArray[i];
                }
                break;
            case EOpMatrixTimesMatrix:
                if (getType().getBasicType() != EbtFloat || node->getBasicType() != EbtFloat) {
                    infoSink.info.message(EPrefixInternalError, getLine(), "Constant Folding cannot be done for matrix multiply");
                    return 0;
                }
                {// support MSVC++6.0
                    // Both operands are indexed as size x size matrices stored
                    // column after column (element = column * size + row).
                    int size = getNominalSize();
                    tempConstArray = new ConstantUnion[size*size];
                    for (int row = 0; row < size; row++) {
                        for (int column = 0; column < size; column++) {
                            tempConstArray[size * column + row].setFConst(0.0f);
                            for (int i = 0; i < size; i++) {
                                tempConstArray[size * column + row].setFConst(tempConstArray[size * column + row].getFConst() + unionArray[i * size + row].getFConst() * (rightUnionArray[column * size + i].getFConst()));
                            }
                        }
                    }
                }
                break;
            case EOpDiv:
                tempConstArray = new ConstantUnion[objectSize];
                {// support MSVC++6.0
                    for (size_t i = 0; i < objectSize; i++) {
                        switch (getType().getBasicType()) {
                            case EbtFloat:
                                if (rightUnionArray[i] == 0.0f) {
                                    // Division by zero: warn and saturate with the sign of the dividend.
                                    infoSink.info.message(EPrefixWarning, getLine(), "Divide by zero error during constant folding");
                                    tempConstArray[i].setFConst(unionArray[i].getFConst() < 0 ? -FLT_MAX : FLT_MAX);
                                } else
                                    tempConstArray[i].setFConst(unionArray[i].getFConst() / rightUnionArray[i].getFConst());
                                break;

                            case EbtInt:
                                if (rightUnionArray[i] == 0) {
                                    // Division by zero: warn and substitute INT_MAX.
                                    infoSink.info.message(EPrefixWarning, getLine(), "Divide by zero error during constant folding");
                                    tempConstArray[i].setIConst(INT_MAX);
                                } else
                                    tempConstArray[i].setIConst(unionArray[i].getIConst() / rightUnionArray[i].getIConst());
                                break;
                            default:
                                infoSink.info.message(EPrefixInternalError, getLine(), "Constant folding cannot be done for \"/\"");
                                return 0;
                        }
                    }
                }
                break;

            case EOpMatrixTimesVector:
                if (node->getBasicType() != EbtFloat) {
                    infoSink.info.message(EPrefixInternalError, getLine(), "Constant Folding cannot be done for matrix times vector");
                    return 0;
                }
                tempConstArray = new ConstantUnion[getNominalSize()];

                {// support MSVC++6.0
                    for (int size = getNominalSize(), i = 0; i < size; i++) {
                        tempConstArray[i].setFConst(0.0f);
                        for (int j = 0; j < size; j++) {
                            tempConstArray[i].setFConst(tempConstArray[i].getFConst() + ((unionArray[j*size + i].getFConst()) * rightUnionArray[j].getFConst()));
                        }
                    }
                }

                // The result takes the type of the right-hand (vector) operand.
                tempNode = new TIntermConstantUnion(tempConstArray, node->getType());
                tempNode->setLine(getLine());

                return tempNode;

            case EOpVectorTimesMatrix:
                if (getType().getBasicType() != EbtFloat) {
                    infoSink.info.message(EPrefixInternalError, getLine(), "Constant Folding cannot be done for vector times matrix");
                    return 0;
                }

                tempConstArray = new ConstantUnion[getNominalSize()];
                {// support MSVC++6.0
                    for (int size = getNominalSize(), i = 0; i < size; i++) {
                        tempConstArray[i].setFConst(0.0f);
                        for (int j = 0; j < size; j++) {
                            tempConstArray[i].setFConst(tempConstArray[i].getFConst() + ((unionArray[j].getFConst()) * rightUnionArray[i*size + j].getFConst()));
                        }
                    }
                }
                break;

            case EOpLogicalAnd: // this code is written for possible future use, will not get executed currently
                tempConstArray = new ConstantUnion[objectSize];
                {// support MSVC++6.0
                    for (size_t i = 0; i < objectSize; i++)
                        tempConstArray[i] = unionArray[i] && rightUnionArray[i];
                }
                break;

            case EOpLogicalOr: // this code is written for possible future use, will not get executed currently
                tempConstArray = new ConstantUnion[objectSize];
                {// support MSVC++6.0
                    for (size_t i = 0; i < objectSize; i++)
                        tempConstArray[i] = unionArray[i] || rightUnionArray[i];
                }
                break;

            case EOpLogicalXor:
                tempConstArray = new ConstantUnion[objectSize];
                {// support MSVC++6.0
                    for (size_t i = 0; i < objectSize; i++) {
                        switch (getType().getBasicType()) {
                            // xor is true exactly when the two operands differ
                            case EbtBool: tempConstArray[i].setBConst((unionArray[i] == rightUnionArray[i]) ? false : true); break;
                            default: assert(false && "Default missing");
                        }
                    }
                }
                break;

            case EOpLessThan:
                assert(objectSize == 1);
                tempConstArray = new ConstantUnion[1];
                tempConstArray->setBConst(*unionArray < *rightUnionArray);
                returnType = TType(EbtBool, EbpUndefined, EvqConst);
                break;
            case EOpGreaterThan:
                assert(objectSize == 1);
                tempConstArray = new ConstantUnion[1];
                tempConstArray->setBConst(*unionArray > *rightUnionArray);
                returnType = TType(EbtBool, EbpUndefined, EvqConst);
                break;
            case EOpLessThanEqual:
                {
                    // computed as !(left > right)
                    assert(objectSize == 1);
                    ConstantUnion constant;
                    constant.setBConst(*unionArray > *rightUnionArray);
                    tempConstArray = new ConstantUnion[1];
                    tempConstArray->setBConst(!constant.getBConst());
                    returnType = TType(EbtBool, EbpUndefined, EvqConst);
                    break;
                }
            case EOpGreaterThanEqual:
                {
                    // computed as !(left < right)
                    assert(objectSize == 1);
                    ConstantUnion constant;
                    constant.setBConst(*unionArray < *rightUnionArray);
                    tempConstArray = new ConstantUnion[1];
                    tempConstArray->setBConst(!constant.getBConst());
                    returnType = TType(EbtBool, EbpUndefined, EvqConst);
                    break;
                }

            case EOpEqual:
                if (getType().getBasicType() == EbtStruct) {
                    if (!CompareStructure(node->getType(), node->getUnionArrayPointer(), unionArray))
                        boolNodeFlag = true;
                } else {
                    for (size_t i = 0; i < objectSize; i++) {
                        if (unionArray[i] != rightUnionArray[i]) {
                            boolNodeFlag = true;
                            break;  // break out of for loop
                        }
                    }
                }

                // Equal only when no difference was found.
                tempConstArray = new ConstantUnion[1];
                tempConstArray->setBConst(!boolNodeFlag);

                tempNode = new TIntermConstantUnion(tempConstArray, TType(EbtBool, EbpUndefined, EvqConst));
                tempNode->setLine(getLine());

                return tempNode;

            case EOpNotEqual:
                if (getType().getBasicType() == EbtStruct) {
                    if (!CompareStructure(node->getType(), node->getUnionArrayPointer(), unionArray))
                        boolNodeFlag = true;
                } else {
                    // "!=" on a multi-component value is true as soon as ANY
                    // component differs; it is the exact negation of "==".
                    for (size_t i = 0; i < objectSize; i++) {
                        if (unionArray[i] != rightUnionArray[i]) {
                            boolNodeFlag = true;
                            break;  // break out of for loop
                        }
                    }
                }

                // Not-equal exactly when a difference was found.
                tempConstArray = new ConstantUnion[1];
                tempConstArray->setBConst(boolNodeFlag);

                tempNode = new TIntermConstantUnion(tempConstArray, TType(EbtBool, EbpUndefined, EvqConst));
                tempNode->setLine(getLine());

                return tempNode;

            default:
                infoSink.info.message(EPrefixInternalError, getLine(), "Invalid operator for constant folding");
                return 0;
        }
        tempNode = new TIntermConstantUnion(tempConstArray, returnType);
        tempNode->setLine(getLine());

        return tempNode;
    } else {
        //
        // Do unary operations: the operator is applied to every component.
        //
        TIntermConstantUnion *newNode = 0;
        ConstantUnion* tempConstArray = new ConstantUnion[objectSize];
        for (size_t i = 0; i < objectSize; i++) {
            switch(op) {
                case EOpNegative:
                    switch (getType().getBasicType()) {
                        case EbtFloat: tempConstArray[i].setFConst(-unionArray[i].getFConst()); break;
                        case EbtInt:   tempConstArray[i].setIConst(-unionArray[i].getIConst()); break;
                        default:
                            infoSink.info.message(EPrefixInternalError, getLine(), "Unary operation not folded into constant");
                            return 0;
                    }
                    break;
                case EOpLogicalNot: // this code is written for possible future use, will not get executed currently
                    switch (getType().getBasicType()) {
                        case EbtBool:  tempConstArray[i].setBConst(!unionArray[i].getBConst()); break;
                        default:
                            infoSink.info.message(EPrefixInternalError, getLine(), "Unary operation not folded into constant");
                            return 0;
                    }
                    break;
                default:
                    // any other unary operator is not folded
                    return 0;
            }
        }
        newNode = new TIntermConstantUnion(tempConstArray, getType());
        newNode->setLine(getLine());
        return newNode;
    }
}
// Example #2
0
//
// The fold functions see if an operation on a constant can be done in place,
// without generating run-time code.
//
// Returns the node to keep using, which may or may not be the node passed in.
//
TIntermTyped *TIntermConstantUnion::fold(
    TOperator op, TIntermTyped *constantNode, TInfoSink &infoSink)
{
    ConstantUnion *unionArray = getUnionArrayPointer();

    if (!unionArray)
        return NULL;

    size_t objectSize = getType().getObjectSize();

    if (constantNode)
    {
        // binary operations
        TIntermConstantUnion *node = constantNode->getAsConstantUnion();
        ConstantUnion *rightUnionArray = node->getUnionArrayPointer();
        TType returnType = getType();

        if (!rightUnionArray)
            return NULL;

        // for a case like float f = 1.2 + vec4(2,3,4,5);
        if (constantNode->getType().getObjectSize() == 1 && objectSize > 1)
        {
            rightUnionArray = new ConstantUnion[objectSize];
            for (size_t i = 0; i < objectSize; ++i)
            {
                rightUnionArray[i] = *node->getUnionArrayPointer();
            }
            returnType = getType();
        }
        else if (constantNode->getType().getObjectSize() > 1 && objectSize == 1)
        {
            // for a case like float f = vec4(2,3,4,5) + 1.2;
            unionArray = new ConstantUnion[constantNode->getType().getObjectSize()];
            for (size_t i = 0; i < constantNode->getType().getObjectSize(); ++i)
            {
                unionArray[i] = *getUnionArrayPointer();
            }
            returnType = node->getType();
            objectSize = constantNode->getType().getObjectSize();
        }

        ConstantUnion *tempConstArray = NULL;
        TIntermConstantUnion *tempNode;

        bool boolNodeFlag = false;
        switch(op)
        {
          case EOpAdd:
            tempConstArray = new ConstantUnion[objectSize];
            for (size_t i = 0; i < objectSize; i++)
                tempConstArray[i] = unionArray[i] + rightUnionArray[i];
            break;
          case EOpSub:
            tempConstArray = new ConstantUnion[objectSize];
            for (size_t i = 0; i < objectSize; i++)
                tempConstArray[i] = unionArray[i] - rightUnionArray[i];
            break;

          case EOpMul:
          case EOpVectorTimesScalar:
          case EOpMatrixTimesScalar:
            tempConstArray = new ConstantUnion[objectSize];
            for (size_t i = 0; i < objectSize; i++)
                tempConstArray[i] = unionArray[i] * rightUnionArray[i];
            break;

          case EOpMatrixTimesMatrix:
            {
                if (getType().getBasicType() != EbtFloat ||
                    node->getBasicType() != EbtFloat)
                {
                    infoSink.info.message(
                        EPrefixInternalError, getLine(),
                        "Constant Folding cannot be done for matrix multiply");
                    return NULL;
                }

                const int leftCols = getCols();
                const int leftRows = getRows();
                const int rightCols = constantNode->getType().getCols();
                const int rightRows = constantNode->getType().getRows();
                const int resultCols = rightCols;
                const int resultRows = leftRows;

                tempConstArray = new ConstantUnion[resultCols*resultRows];
                for (int row = 0; row < resultRows; row++)
                {
                    for (int column = 0; column < resultCols; column++)
                    {
                        tempConstArray[resultRows * column + row].setFConst(0.0f);
                        for (int i = 0; i < leftCols; i++)
                        {
                            tempConstArray[resultRows * column + row].setFConst(
                                tempConstArray[resultRows * column + row].getFConst() +
                                unionArray[i * leftRows + row].getFConst() *
                                rightUnionArray[column * rightRows + i].getFConst());
                        }
                    }
                }

                // update return type for matrix product
                returnType.setPrimarySize(resultCols);
                returnType.setSecondarySize(resultRows);
            }
            break;

          case EOpDiv:
          case EOpIMod:
            {
                tempConstArray = new ConstantUnion[objectSize];
                for (size_t i = 0; i < objectSize; i++)
                {
                    switch (getType().getBasicType())
                    {
                      case EbtFloat:
                        if (rightUnionArray[i] == 0.0f)
                        {
                            infoSink.info.message(
                                EPrefixWarning, getLine(),
                                "Divide by zero error during constant folding");
                            tempConstArray[i].setFConst(
                                unionArray[i].getFConst() < 0 ? -FLT_MAX : FLT_MAX);
                        }
                        else
                        {
                            ASSERT(op == EOpDiv);
                            tempConstArray[i].setFConst(
                                unionArray[i].getFConst() /
                                rightUnionArray[i].getFConst());
                        }
                        break;

                      case EbtInt:
                        if (rightUnionArray[i] == 0)
                        {
                            infoSink.info.message(
                                EPrefixWarning, getLine(),
                                "Divide by zero error during constant folding");
                            tempConstArray[i].setIConst(INT_MAX);
                        }
                        else
                        {
                            if (op == EOpDiv)
                            {
                                tempConstArray[i].setIConst(
                                    unionArray[i].getIConst() /
                                    rightUnionArray[i].getIConst());
                            }
                            else
                            {
                                ASSERT(op == EOpIMod);
                                tempConstArray[i].setIConst(
                                    unionArray[i].getIConst() %
                                    rightUnionArray[i].getIConst());
                            }
                        }
                        break;

                      case EbtUInt:
                        if (rightUnionArray[i] == 0)
                        {
                            infoSink.info.message(
                                EPrefixWarning, getLine(),
                                "Divide by zero error during constant folding");
                            tempConstArray[i].setUConst(UINT_MAX);
                        }
                        else
                        {
                            if (op == EOpDiv)
                            {
                                tempConstArray[i].setUConst(
                                    unionArray[i].getUConst() /
                                    rightUnionArray[i].getUConst());
                            }
                            else
                            {
                                ASSERT(op == EOpIMod);
                                tempConstArray[i].setUConst(
                                    unionArray[i].getUConst() %
                                    rightUnionArray[i].getUConst());
                            }
                        }
                        break;

                      default:
                        infoSink.info.message(
                            EPrefixInternalError, getLine(),
                            "Constant folding cannot be done for \"/\"");
                        return NULL;
                    }
                }
            }
            break;

          case EOpMatrixTimesVector:
            {
                if (node->getBasicType() != EbtFloat)
                {
                    infoSink.info.message(
                        EPrefixInternalError, getLine(),
                        "Constant Folding cannot be done for matrix times vector");
                    return NULL;
                }

                const int matrixCols = getCols();
                const int matrixRows = getRows();

                tempConstArray = new ConstantUnion[matrixRows];

                for (int matrixRow = 0; matrixRow < matrixRows; matrixRow++)
                {
                    tempConstArray[matrixRow].setFConst(0.0f);
                    for (int col = 0; col < matrixCols; col++)
                    {
                        tempConstArray[matrixRow].setFConst(
                            tempConstArray[matrixRow].getFConst() +
                            unionArray[col * matrixRows + matrixRow].getFConst() *
                            rightUnionArray[col].getFConst());
                    }
                }

                returnType = node->getType();
                returnType.setPrimarySize(matrixRows);

                tempNode = new TIntermConstantUnion(tempConstArray, returnType);
                tempNode->setLine(getLine());

                return tempNode;
            }

          case EOpVectorTimesMatrix:
            {
                if (getType().getBasicType() != EbtFloat)
                {
                    infoSink.info.message(
                        EPrefixInternalError, getLine(),
                        "Constant Folding cannot be done for vector times matrix");
                    return NULL;
                }

                const int matrixCols = constantNode->getType().getCols();
                const int matrixRows = constantNode->getType().getRows();

                tempConstArray = new ConstantUnion[matrixCols];

                for (int matrixCol = 0; matrixCol < matrixCols; matrixCol++)
                {
                    tempConstArray[matrixCol].setFConst(0.0f);
                    for (int matrixRow = 0; matrixRow < matrixRows; matrixRow++)
                    {
                        tempConstArray[matrixCol].setFConst(
                            tempConstArray[matrixCol].getFConst() +
                            unionArray[matrixRow].getFConst() *
                            rightUnionArray[matrixCol * matrixRows + matrixRow].getFConst());
                    }
                }

                returnType.setPrimarySize(matrixCols);
            }
            break;

          case EOpLogicalAnd:
            // this code is written for possible future use,
            // will not get executed currently
            {
                tempConstArray = new ConstantUnion[objectSize];
                for (size_t i = 0; i < objectSize; i++)
                {
                    tempConstArray[i] = unionArray[i] && rightUnionArray[i];
                }
            }
            break;

          case EOpLogicalOr:
            // this code is written for possible future use,
            // will not get executed currently
            {
                tempConstArray = new ConstantUnion[objectSize];
                for (size_t i = 0; i < objectSize; i++)
                {
                    tempConstArray[i] = unionArray[i] || rightUnionArray[i];
                }
            }
            break;

          case EOpLogicalXor:
            {
                tempConstArray = new ConstantUnion[objectSize];
                for (size_t i = 0; i < objectSize; i++)
                {
                    switch (getType().getBasicType())
                    {
                      case EbtBool:
                        tempConstArray[i].setBConst(
                            unionArray[i] == rightUnionArray[i] ? false : true);
                        break;
                      default:
                        UNREACHABLE();
                        break;
                    }
                }
            }
            break;

          case EOpBitwiseAnd:
            tempConstArray = new ConstantUnion[objectSize];
            for (size_t i = 0; i < objectSize; i++)
                tempConstArray[i] = unionArray[i] & rightUnionArray[i];
            break;
          case EOpBitwiseXor:
            tempConstArray = new ConstantUnion[objectSize];
            for (size_t i = 0; i < objectSize; i++)
                tempConstArray[i] = unionArray[i] ^ rightUnionArray[i];
            break;
          case EOpBitwiseOr:
            tempConstArray = new ConstantUnion[objectSize];
            for (size_t i = 0; i < objectSize; i++)
                tempConstArray[i] = unionArray[i] | rightUnionArray[i];
            break;
          case EOpBitShiftLeft:
            tempConstArray = new ConstantUnion[objectSize];
            for (size_t i = 0; i < objectSize; i++)
                tempConstArray[i] = unionArray[i] << rightUnionArray[i];
            break;
          case EOpBitShiftRight:
            tempConstArray = new ConstantUnion[objectSize];
            for (size_t i = 0; i < objectSize; i++)
                tempConstArray[i] = unionArray[i] >> rightUnionArray[i];
            break;

          case EOpLessThan:
            ASSERT(objectSize == 1);
            tempConstArray = new ConstantUnion[1];
            tempConstArray->setBConst(*unionArray < *rightUnionArray);
            returnType = TType(EbtBool, EbpUndefined, EvqConst);
            break;

          case EOpGreaterThan:
            ASSERT(objectSize == 1);
            tempConstArray = new ConstantUnion[1];
            tempConstArray->setBConst(*unionArray > *rightUnionArray);
            returnType = TType(EbtBool, EbpUndefined, EvqConst);
            break;

          case EOpLessThanEqual:
            {
                ASSERT(objectSize == 1);
                ConstantUnion constant;
                constant.setBConst(*unionArray > *rightUnionArray);
                tempConstArray = new ConstantUnion[1];
                tempConstArray->setBConst(!constant.getBConst());
                returnType = TType(EbtBool, EbpUndefined, EvqConst);
                break;
            }

          case EOpGreaterThanEqual:
            {
                ASSERT(objectSize == 1);
                ConstantUnion constant;
                constant.setBConst(*unionArray < *rightUnionArray);
                tempConstArray = new ConstantUnion[1];
                tempConstArray->setBConst(!constant.getBConst());
                returnType = TType(EbtBool, EbpUndefined, EvqConst);
                break;
            }

          case EOpEqual:
            if (getType().getBasicType() == EbtStruct)
            {
                if (!CompareStructure(node->getType(),
                                      node->getUnionArrayPointer(),
                                      unionArray))
                {
                    boolNodeFlag = true;
                }
            }
            else
            {
                for (size_t i = 0; i < objectSize; i++)
                {
                    if (unionArray[i] != rightUnionArray[i])
                    {
                        boolNodeFlag = true;
                        break;  // break out of for loop
                    }
                }
            }

            tempConstArray = new ConstantUnion[1];
            if (!boolNodeFlag)
            {
                tempConstArray->setBConst(true);
            }
            else
            {
                tempConstArray->setBConst(false);
            }

            tempNode = new TIntermConstantUnion(
                tempConstArray, TType(EbtBool, EbpUndefined, EvqConst));
            tempNode->setLine(getLine());

            return tempNode;

          case EOpNotEqual:
            if (getType().getBasicType() == EbtStruct)
            {
                if (CompareStructure(node->getType(),
                                     node->getUnionArrayPointer(),
                                     unionArray))
                {
                    boolNodeFlag = true;
                }
            }
            else
            {
                for (size_t i = 0; i < objectSize; i++)
                {
                    if (unionArray[i] == rightUnionArray[i])
                    {
                        boolNodeFlag = true;
                        break;  // break out of for loop
                    }
                }
            }

            tempConstArray = new ConstantUnion[1];
            if (!boolNodeFlag)
            {
                tempConstArray->setBConst(true);
            }
            else
            {
                tempConstArray->setBConst(false);
            }

            tempNode = new TIntermConstantUnion(
                tempConstArray, TType(EbtBool, EbpUndefined, EvqConst));
            tempNode->setLine(getLine());

            return tempNode;

          default:
            infoSink.info.message(
                EPrefixInternalError, getLine(),
                "Invalid operator for constant folding");
            return NULL;
        }
        tempNode = new TIntermConstantUnion(tempConstArray, returnType);
        tempNode->setLine(getLine());

        return tempNode;
    }
    else
    {