// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
void RecursionElimination::GenerateAccumulatorUpdate(TailCallInfo& info) {
    // Emits the instructions that update the accumulators for the tail call
    // 'info.TailCall'. The new values are registered as incoming operands
    // of the accumulators for the block that contains the call.
    //
    // 'info.Additive' holds the operands added to the call result.
    // 'info.Multiplicative' holds the operands the call result is multiplied by.
    // At least one of the two lists must be non-empty.
    //
    // NOTE(review): when only one of the two lists is non-empty, the other
    // accumulator (if it exists) receives no operand for 'blockRef' here.
    // Confirm that the caller supplies one.
    DebugValidator::IsTrue(info.Additive.Count() > 0 ||
                           info.Multiplicative.Count() > 0);
    
    auto parentBlock = info.TailCall->ParentBlock();
    auto& refs = funct_->ParentUnit()->References();
    auto blockRef = refs.GetBlockRef(parentBlock);
    
    // The additive accumulator is computed as follows:
    // (a + b + ...) + f(...)
    // aAccum += (a + b + ...) * mAccum
    if(info.Additive.Count() > 0) {
        auto sumOp = info.Additive[0];

        for(int i = 1; i < info.Additive.Count(); i++) {
            sumOp = CreateAdd(sumOp, info.Additive[i], info.TailCall);
        }

        // Scale the sum by the multiplicative accumulator, if there is one.
        // A missing multiplicative accumulator acts as a factor of 1.
        // RewriteOtherReturns handles the missing case the same way, so
        // it must not be dereferenced unconditionally here.
        auto scaledOp = sumOp;

        if(multiplicativeAccumulator_) {
            auto mulAccumOp = multiplicativeAccumulator_->ResultOp();
            scaledOp = CreateMul(sumOp, mulAccumOp, info.TailCall);
        }

        // Add the whole result to the existing additive accumulator.
        auto addAccumOp = additiveAccumulator_->ResultOp();
        auto addOp = CreateAdd(addAccumOp, scaledOp, info.TailCall);
        additiveAccumulator_->AddOperand(addOp, blockRef);
    }

    // The multiplicative accumulator is computed as follows:
    // (a * b * ...) * f(...)
    // mAccum = (a * b * ...) * mAccum
    if(info.Multiplicative.Count() > 0) {
        auto productOp = info.Multiplicative[0];

        for(int i = 1; i < info.Multiplicative.Count(); i++) {
            productOp = CreateMul(productOp, info.Multiplicative[i],
                                  info.TailCall);
        }

        // Multiply the product with the multiplicative accumulator.
        auto mulAccumOp = multiplicativeAccumulator_->ResultOp();
        auto mulOp = CreateMul(productOp, mulAccumOp, info.TailCall);
        multiplicativeAccumulator_->AddOperand(mulOp, blockRef);
    }
}
// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
void RecursionElimination::RewriteOtherReturns() {
    // Rewrites each return instruction in 'returnInstrs_' so that it folds
    // in the accumulated values:
    //     ret a  ->  ret (aAccum + mAccum * a)
    // An accumulator that was never created is simply skipped. An absent
    // 'mAccum' acts as 1 and an absent 'aAccum' acts as 0.

    // If the function has no return value (it's 'void')
    // then we have nothing to do.
    if(funct_->IsVoid()) {
        return;
    }
    
    for(int i = 0; i < returnInstrs_.Count(); i++) {
        auto returnInstr = returnInstrs_[i];
        auto resultOp = returnInstr->ReturnedOp();
        DebugValidator::IsNotNull(resultOp);

        if(multiplicativeAccumulator_) {
            // a * mAccum
            auto mAccum = multiplicativeAccumulator_->ResultOp();
            resultOp = CreateMul(resultOp, mAccum, returnInstr);
        }

        if(additiveAccumulator_) {
            // (mAccum * a) + aAccum
            auto aAccum = additiveAccumulator_->ResultOp();
            resultOp = CreateAdd(resultOp, aAccum, returnInstr);
        }

        returnInstr->SetReturnedOp(resultOp);
    }
}
// Generates the body of the projection function for a column-map page.
// The emitted code copies the projected fields of the element at position
// `idx` of the column-map page `page` into the record buffer `dest`.
// `page`, `idx` and `dest` are the generated function's parameters,
// accessed via getParam.
//
// The projected source field indices come from query->projectionBegin(),
// one entry per destination field, in destination-field order. The
// fixed-size destination fields come first, followed by the
// variable-size ones.
//
// The emitted function returns (as an i32) destRecord.staticSize() plus
// the number of variable-size bytes copied into `dest`.
void LLVMColumnMapProjectionBuilder::build(ScanQuery* query) {
    auto& srcRecord = mContext.record();
    auto& destRecord = query->record();

    // -> auto mainPage = reinterpret_cast<const ColumnMapMainPage*>(page);
    auto mainPage = CreateBitCast(getParam(page), mMainPageStructTy->getPointerTo());

    // -> auto count = static_cast<uint64_t>(mainPage->count);
    // `count` is used below as the per-column stride when addressing columns.
    auto count = CreateInBoundsGEP(mainPage, { getInt64(0), getInt32(0) });
    count = CreateZExt(CreateAlignedLoad(count, 4u), getInt64Ty());

    // -> auto index = static_cast<uint64_t>(idx);
    auto index = CreateZExt(getParam(idx), getInt64Ty());

    // Copy the null byte of every nullable projected field. The address
    // computation below implies that the source header is laid out
    // column-wise: one byte per element and `count` bytes per source null
    // index.
    if (destRecord.headerSize() != 0u) {
        // -> auto headerOffset = static_cast<uint64_t>(mainPage->headerOffset);
        auto headerOffset = CreateInBoundsGEP(mainPage, { getInt64(0), getInt32(1) });
        headerOffset = CreateZExt(CreateAlignedLoad(headerOffset, 4u), getInt64Ty());

        // -> auto headerData = page + headerOffset + idx;
        auto headerData = CreateAdd(headerOffset, index);
        headerData = CreateInBoundsGEP(getParam(page), headerData);

        auto i = query->projectionBegin();
        for (decltype(destRecord.fieldCount()) destFieldIdx = 0u; destFieldIdx < destRecord.fieldCount();
                ++i, ++destFieldIdx) {
            auto srcFieldIdx = *i;
            auto& srcMeta = srcRecord.getFieldMeta(srcFieldIdx);
            auto& destMeta = destRecord.getFieldMeta(destFieldIdx);
            auto& field = destMeta.field;
            // Non-nullable fields have no null byte to copy. `continue` still
            // runs the loop increment, so `i` stays in step with destFieldIdx.
            if (field.isNotNull()) {
                continue;
            }

            // -> auto srcData = headerData + page->count * srcNullIdx
            auto srcData = headerData;
            if (srcMeta.nullIdx != 0) {
                srcData = CreateInBoundsGEP(headerData, createConstMul(count, srcMeta.nullIdx));
            }

            // -> auto nullValue = *srcData;
            auto nullValue = CreateAlignedLoad(srcData, 1u);

            // -> auto destData = dest + destNullIdx;
            auto destData = getParam(dest);
            if (destMeta.nullIdx != 0) {
                destData = CreateInBoundsGEP(destData, getInt64(destMeta.nullIdx));
            }

            // -> *destData = nullValue;
            CreateAlignedStore(nullValue, destData, 1u);
        }
    }

    // This projection iterator is shared by the fixed-size and the
    // variable-size sections below. After the fixed-size loop it points at
    // the first projected variable-size field.
    auto i = query->projectionBegin();
    if (destRecord.fixedSizeFieldCount() != 0) {
        // -> auto fixedOffset = static_cast<uint64_t>(mainPage->fixedOffset);
        auto fixedOffset = CreateInBoundsGEP(mainPage, { getInt64(0), getInt32(2) });
        fixedOffset = CreateZExt(CreateAlignedLoad(fixedOffset, 4u), getInt64Ty());

        // -> auto fixedData = page + fixedOffset;
        auto fixedData = CreateInBoundsGEP(getParam(page), fixedOffset);

        // Each source fixed-size column starts at
        // fixedData + count * srcMeta.offset. The value at `index` is copied
        // to dest + destMeta.offset.
        for (decltype(destRecord.fixedSizeFieldCount()) destFieldIdx = 0u;
                destFieldIdx < destRecord.fixedSizeFieldCount(); ++i, ++destFieldIdx) {
            auto srcFieldIdx = *i;
            auto& srcMeta = mContext.fixedMetaData()[srcFieldIdx];
            auto& destMeta = destRecord.getFieldMeta(destFieldIdx);
            auto& field = destMeta.field;
            LOG_ASSERT(field.isFixedSized(), "Field must be fixed size");

            auto fieldAlignment = field.alignOf();
            auto fieldPtrType = getFieldPtrTy(field.type());

            // -> auto srcData = reinterpret_cast<const T*>(fixedData + srcMeta.offset) + index;
            // (the emitted code scales srcMeta.offset by `count`)
            auto srcData = fixedData;
            if (srcMeta.offset != 0) {
                srcData = CreateInBoundsGEP(srcData, createConstMul(count, srcMeta.offset));
            }
            srcData = CreateBitCast(srcData, fieldPtrType);
            srcData = CreateInBoundsGEP(srcData, index);

            // -> auto value = *srcData;
            auto value = CreateAlignedLoad(srcData, fieldAlignment);

            // -> auto destData = reinterpret_cast<const T*>(dest + destMeta.offset);
            auto destData = getParam(dest);
            if (destMeta.offset != 0) {
                destData = CreateInBoundsGEP(destData, getInt64(destMeta.offset));
            }
            destData = CreateBitCast(destData, fieldPtrType);

            // -> *destData = value;
            CreateAlignedStore(value, destData, fieldAlignment);
        }
    }

    // -> auto destHeapOffset = destRecord.staticSize();
    llvm::Value* destHeapOffset = getInt32(destRecord.staticSize());

    // Variable-size fields. `dest` holds an array of uint32 offsets starting
    // at destRecord.variableOffset(). The array has one initial entry (the
    // heap start, destHeapOffset) followed by one end offset per projected
    // variable-size field.
    //
    // The source heap entries are addressed column-wise (`count` entries per
    // source field). Each run of consecutive projected source fields is
    // copied with a single memcpy.
    if (destRecord.varSizeFieldCount() != 0) {
        auto srcFieldIdx = srcRecord.fixedSizeFieldCount();
        decltype(destRecord.varSizeFieldCount()) destFieldIdx = 0;

        // -> auto variableOffset = static_cast<uint64_t>(mainPage->variableOffset);
        auto variableOffset = CreateInBoundsGEP(mainPage, { getInt64(0), getInt32(3) });
        variableOffset = CreateZExt(CreateAlignedLoad(variableOffset, 4u), getInt64Ty());

        // -> auto variableData = reinterpret_cast<const ColumnMapHeapEntry*>(page + variableOffset) + idx;
        auto variableData = CreateInBoundsGEP(getParam(page), variableOffset);
        variableData = CreateBitCast(variableData, mHeapEntryStructTy->getPointerTo());
        variableData = CreateInBoundsGEP(variableData, index);

        // -> auto srcData = variableData;
        auto srcData = variableData;

        // -> auto destData = reinterpret_cast<uint32_t*>(dest + destRecord.variableOffset());
        auto destData = getParam(dest);
        if (destRecord.variableOffset() != 0) {
            destData = CreateInBoundsGEP(destData, getInt64(destRecord.variableOffset()));
        }
        destData = CreateBitCast(destData, getInt32PtrTy());

        // -> *destData = destHeapOffset;
        CreateAlignedStore(destHeapOffset, destData, 4u);

        // Each outer iteration copies one run of consecutive projected
        // source fields.
        do {
            // Skip source fields that are not part of the projection.
            if (*i != srcFieldIdx) {
                auto step = *i - srcFieldIdx;
                // -> srcData += count * (*i - srcFieldIdx);
                srcData = CreateInBoundsGEP(srcData, createConstMul(count, step));
                srcFieldIdx = *i;
            }

            // -> auto srcHeapOffset = srcData->offset;
            auto srcHeapOffset = CreateInBoundsGEP(srcData, { getInt64(0), getInt32(0) });
            srcHeapOffset = CreateAlignedLoad(srcHeapOffset, 8u);

            // -> auto offsetCorrection = srcHeapOffset - destHeapOffset;
            // Subtracting this from a source offset in the current run yields
            // the corresponding destination offset.
            auto offsetCorrection = CreateSub(srcHeapOffset, destHeapOffset);

            llvm::Value* offset;
            do {
                ++i;

                // Step to offset of the following field (or to the last field of the previous element) to get the end
                // offset
                ++srcFieldIdx;
                if (srcFieldIdx == srcRecord.fieldCount()) {
                    // -> srcData = variableData - 1;
                    // NOTE(review): plain (non-inbounds) GEP, unlike the
                    // surrounding code. Presumably this is because the -1
                    // step may leave the entry array. Confirm.
                    srcData = CreateGEP(variableData, getInt64(-1));
                } else {
                    // -> srcData += count;
                    srcData = CreateInBoundsGEP(srcData, count);
                }

                // -> auto offset = srcData->offset - offsetCorrection;
                offset = CreateInBoundsGEP(srcData, { getInt64(0), getInt32(0) });
                offset = CreateAlignedLoad(offset, 8u);
                offset = CreateSub(offset, offsetCorrection);

                // -> ++destData;
                ++destFieldIdx;
                destData = CreateInBoundsGEP(destData, getInt64(1));

                // -> *destData = offset;
                CreateAlignedStore(offset, destData, 4u);
                // The run continues while the next projected field is the
                // immediately following source field. The destFieldIdx check
                // comes first, so `i` is not dereferenced after the last
                // projected field.
            } while (destFieldIdx < destRecord.varSizeFieldCount() && *i == srcFieldIdx);

            // -> auto srcHeap = page + static_cast<uint64_t>(srcHeapOffset);
            auto srcHeap = CreateInBoundsGEP(getParam(page), CreateZExt(srcHeapOffset, getInt64Ty()));

            // -> auto destHeap = dest + static_cast<uint64_t>(destHeapOffset);
            auto destHeap = CreateInBoundsGEP(getParam(dest), CreateZExt(destHeapOffset, getInt64Ty()));

            // -> auto length = offset - destHeapOffset;
            auto length = CreateSub(offset, destHeapOffset);

            // -> memcpy(destHeap, srcHeap, length);
            CreateMemCpy(destHeap, srcHeap, length, 1u);

            // -> destHeapOffset = offset;
            destHeapOffset = offset;
        } while (destFieldIdx < destRecord.varSizeFieldCount());
    }

    // -> return destHeapOffset;
    CreateRet(destHeapOffset);
}