/// Lowers a function definition to LLVM IR.
///
/// The prototype is visited first and its result is used as the llvm::Function
/// to fill in. Every incoming argument is copied into a named stack slot, the
/// body is built between a fresh "entry" block and a shared "return" block, and
/// the builder's previous insertion point is restored before returning.
///
/// @param node  Function definition to lower.
/// @return Always nullptr.
const void* LLVMCodeGenerator::visit(ASTFunctionDef* node) {
    // accept() yields an untyped result here; it is reinterpreted as the function.
    auto function = (llvm::Function*)node->proto->accept(this);
    const auto& arg_names = node->proto->arg_names;

    // Give each LLVM argument the name declared in the prototype.
    size_t position = 0;
    for (auto arg = function->arg_begin(); arg != function->arg_end(); ++arg, ++position) {
        arg->setName(arg_names[position]);
    }

    // Remember where the builder was, then start emitting into the entry block.
    llvm::BasicBlock* entry_block = llvm::BasicBlock::Create(_context, "entry", function);
    const llvm::IRBuilderBase::InsertPoint saved_ip = _builder.saveIP();
    _builder.SetInsertPoint(entry_block);

    // Spill every argument into its own stack slot so it can be used as an lvalue.
    // TODO: make sure llvm can optimize the unnecessary copies out?
    position = 0;
    for (auto arg = function->arg_begin(); arg != function->arg_end(); ++arg, ++position) {
        const auto slot_name = node->arg_prefix + arg_names[position];
        llvm::AllocaInst* slot = _builder.CreateAlloca(arg->getType(), nullptr, slot_name);
        _named_values[slot_name] = slot;
        _builder.CreateStore(&*arg, slot);
    }

    // Non-void functions get a "ret" slot that the return block loads from.
    llvm::AllocaInst* return_slot = nullptr;
    const bool returns_value = node->proto->func->return_type()->type() != SLTypeTypeVoid;
    if (returns_value) {
        return_slot = _builder.CreateAlloca(_llvm_type(node->proto->func->return_type()), nullptr, "ret");
    }
    llvm::BasicBlock* return_block = llvm::BasicBlock::Create(_context, "return", function);

    // Build the body with this function's context installed, then put the old one back.
    FunctionContext body_context(node->proto->func, function, return_slot, return_block);
    auto enclosing_context = _current_function_context;
    _current_function_context = body_context;
    _build_basic_block(entry_block, node->body, return_block);
    _current_function_context = enclosing_context;

    // Terminate the return block.
    _builder.SetInsertPoint(return_block);
    if (!return_slot) {
        _builder.CreateRetVoid();
    } else {
        _builder.CreateRet(_builder.CreateLoad(return_slot));
    }

    // Hand the builder back in the state we found it.
    _builder.restoreIP(saved_ip);
    return nullptr;
}
// Appends a call to a per-address breakpoint function at the end of block `B`.
//
// The callee is named "breakpoint_<pc in hex>". If the module does not have a
// function with that name yet, one is created with external linkage, marked
// optnone/noinline, and given a body that only returns. The call passes the
// first argument of the function that contains `B`.
static void CreateInstrBreakpoint(llvm::BasicBlock *B, VA pc) {
  auto *enclosing_func = B->getParent();
  auto *module = enclosing_func->getParent();

  std::stringstream name_builder;
  name_builder << "breakpoint_" << std::hex << pc;
  const auto breakpoint_name = name_builder.str();

  auto *breakpoint_func = module->getFunction(breakpoint_name);
  if (nullptr == breakpoint_func) {
    breakpoint_func = llvm::Function::Create(LiftedFunctionType(),
                                             llvm::GlobalValue::ExternalLinkage,
                                             breakpoint_name, module);
    breakpoint_func->addFnAttr(llvm::Attribute::OptimizeNone);
    breakpoint_func->addFnAttr(llvm::Attribute::NoInline);

    // Body: a single unnamed block that returns immediately.
    auto &context = module->getContext();
    llvm::IRBuilder<> body_builder(
        llvm::BasicBlock::Create(context, "", breakpoint_func));
    body_builder.CreateRetVoid();
  }

  auto *state_ptr = &*enclosing_func->arg_begin();
  llvm::CallInst::Create(breakpoint_func, {state_ptr}, "", B);
}
// Appends a call to the register-state tracer at the end of block `B`.
//
// The tracer is obtained from ArchGetOrCreateRegStateTracer for the module
// that owns `B`; the call passes the first argument of B's parent function.
static void AddRegStateTracer(llvm::BasicBlock *B) {
  auto *enclosing_func = B->getParent();
  auto *tracer = ArchGetOrCreateRegStateTracer(enclosing_func->getParent());
  auto *state_ptr = &*enclosing_func->arg_begin();
  llvm::CallInst::Create(tracer, {state_ptr}, "", B);
}
void Kontext::createInteger() { std::vector<llvm::Type*> types(1, llvm::Type::getInt32Ty(this->getContext())); auto leType = llvm::StructType::create(this->getContext(), makeArrayRef(types), "Integer", false); std::vector<KObjectAttr> attributes(1, KObjectAttr("innerInt", nullptr)); auto o = KObject("Integer", leType, std::move(attributes)); this->addType("Integer", o); this->setObject("Integer"); std::vector<llvm::Type *> argTypes; argTypes.emplace_back(leType->getPointerTo()); argTypes.emplace_back(llvm::Type::getInt32Ty(this->getContext())); auto fType = llvm::FunctionType::get(llvm::Type::getVoidTy(this->getContext()), makeArrayRef(argTypes), false); auto func = llvm::Function::Create(fType, llvm::GlobalValue::ExternalLinkage, "_KN7IntegerC1E", this->module()); auto bBlock = llvm::BasicBlock::Create(this->getContext(), "entry", func, nullptr); this->pushBlock(bBlock); auto argIter = func->arg_begin(); auto obj = InstanciatedObject::Create("self", &(*argIter), *this, KCallArgList()); (*argIter).setName("self"); obj->store(*this); argIter++; (*argIter).setName("leValue"); auto leBlock = KBlock(); auto decl = std::make_shared<KVarDecl>("innerInt", llvm::Type::getInt32Ty(this->getContext()), &(*argIter)); decl->setInObj(); leBlock.emplaceStatement(std::move(decl)); leBlock.codeGen(*this); llvm::ReturnInst::Create(this->getContext(), _blocks.top()->returnValue, bBlock); this->popBlock(); this->createIntegerAdd(); this->createIntegerPrint(); this->popObject(); }
/// Returns the module's private "stack.prepare" helper, creating it on first use.
///
/// Signature of the generated function:
///   WordPtr stack.prepare(WordPtr base, Size* size.ptr, Size min, Size max,
///                         Size diff, BytePtr jmpBuf)
///
/// Generated body:
///  - Check:      load size; ok = (size + min >= 0, signed) and
///                (size + max <= RuntimeManager::stackSizeLimit, unsigned);
///                branch to Update when ok, otherwise to OutOfStack.
///  - Update:     store size + diff back through size.ptr and return
///                &base[size] (computed from the size loaded before the update).
///  - OutOfStack: call the eh.sjlj.longjmp intrinsic on jmpBuf; unreachable.
///
/// The builder's insertion point is protected by an InsertPointGuard.
llvm::Function* LocalStack::getStackPrepareFunc()
{
	const char* const helperName = "stack.prepare";
	if (auto existing = getModule()->getFunction(helperName))
		return existing;

	llvm::Type* paramTypes[] = {Type::WordPtr, Type::Size->getPointerTo(), Type::Size, Type::Size, Type::Size, Type::BytePtr};
	auto helperType = llvm::FunctionType::get(Type::WordPtr, paramTypes, false);
	auto helper = llvm::Function::Create(helperType, llvm::Function::PrivateLinkage, helperName, getModule());
	helper->setDoesNotThrow();
	helper->setDoesNotAccessMemory(1);
	helper->setDoesNotAlias(2);
	helper->setDoesNotCapture(2);

	auto& llvmCtx = helper->getContext();
	auto checkBlock = llvm::BasicBlock::Create(llvmCtx, "Check", helper);
	auto updateBlock = llvm::BasicBlock::Create(llvmCtx, "Update", helper);
	auto overflowBlock = llvm::BasicBlock::Create(llvmCtx, "OutOfStack", helper);

	// Pick up and name the six parameters in declaration order.
	auto param = helper->arg_begin();
	llvm::Argument* base = &*param;
	base->setName("base");
	++param;
	llvm::Argument* sizePtr = &*param;
	sizePtr->setName("size.ptr");
	++param;
	llvm::Argument* min = &*param;
	min->setName("min");
	++param;
	llvm::Argument* max = &*param;
	max->setName("max");
	++param;
	llvm::Argument* diff = &*param;
	diff->setName("diff");
	++param;
	llvm::Argument* jmpBuf = &*param;
	jmpBuf->setName("jmpBuf");

	InsertPointGuard guard{m_builder};

	// Check: validate both bounds against the current size.
	m_builder.SetInsertPoint(checkBlock);
	auto sizeAlign = getModule()->getDataLayout().getABITypeAlignment(Type::Size);
	auto curSize = m_builder.CreateAlignedLoad(sizePtr, sizeAlign, "size");
	auto lowBound = m_builder.CreateAdd(curSize, min, "size.min", false, true);
	auto highBound = m_builder.CreateAdd(curSize, max, "size.max", true, true);
	auto lowOk = m_builder.CreateICmpSGE(lowBound, m_builder.getInt64(0), "ok.min");
	auto highOk = m_builder.CreateICmpULE(highBound, m_builder.getInt64(RuntimeManager::stackSizeLimit), "ok.max");
	auto bothOk = m_builder.CreateAnd(lowOk, highOk, "ok");
	m_builder.CreateCondBr(bothOk, updateBlock, overflowBlock, Type::expectTrue);

	// Update: commit the new size and return the pre-update stack pointer.
	m_builder.SetInsertPoint(updateBlock);
	auto nextSize = m_builder.CreateNSWAdd(curSize, diff, "size.next");
	m_builder.CreateAlignedStore(nextSize, sizePtr, sizeAlign);
	auto stackPtr = m_builder.CreateGEP(base, curSize, "sp");
	m_builder.CreateRet(stackPtr);

	// OutOfStack: bail out through longjmp.
	m_builder.SetInsertPoint(overflowBlock);
	auto longjmpFn = llvm::Intrinsic::getDeclaration(getModule(), llvm::Intrinsic::eh_sjlj_longjmp);
	m_builder.CreateCall(longjmpFn, {jmpBuf});
	m_builder.CreateUnreachable();

	return helper;
}
/// Records the memory area occupied by a parameter of this function.
///
/// @param ArgumentNumber  Zero-based index of the parameter; asserted to be
///                        less than the function's argument count.
/// @param Address         Start of the area.
/// @param Size            Size of the area.
///
/// Appends an entry built from the llvm::Argument and
/// MemoryArea(Address, Size) to ParamByVals.
void FunctionState::addByValArea(unsigned ArgumentNumber,
                                 stateptr_ty Address,
                                 std::size_t Size)
{
  auto const Fn = getFunction();
  assert(ArgumentNumber < Fn->arg_size());

  auto const ArgPos = std::next(Fn->arg_begin(), ArgumentNumber);
  ParamByVals.emplace_back(&*ArgPos, MemoryArea(Address, Size));
}
bool ParseFunctionCall(FunctionEvent *Event, BinaryOperator *Bop, vector<ValueDecl*>& References, ASTContext& Ctx) { // TODO: better distinguishing between callee and/or caller Event->set_context(FunctionEvent::Callee); // Since we might care about the return value, we must instrument exiting // the function rather than entering it. Event->set_direction(FunctionEvent::Exit); Expr *LHS = Bop->getLHS(); bool LHSisICE = LHS->isIntegerConstantExpr(Ctx); Expr *RHS = Bop->getRHS(); if (!(LHSisICE ^ RHS->isIntegerConstantExpr(Ctx))) { Report("One of {LHS,RHS} must be ICE", Bop->getLocStart(), Ctx) << Bop->getSourceRange(); return false; } Expr *RetVal = (LHSisICE ? LHS : RHS); Expr *FnCall = (LHSisICE ? RHS : LHS); if (!ParseArgument(Event->mutable_expectedreturnvalue(), RetVal, References, Ctx)) return false; auto FnCallExpr = dyn_cast<CallExpr>(FnCall); if (!FnCallExpr) { Report("Not a function call", FnCall->getLocStart(), Ctx) << FnCall->getSourceRange(); return false; } auto Fn = FnCallExpr->getDirectCallee(); if (!Fn) { Report("Not a direct function call", FnCallExpr->getLocStart(), Ctx) << FnCallExpr->getSourceRange(); return false; } if (!ParseFunctionRef(Event->mutable_function(), Fn, Ctx)) return false; for (auto I = FnCallExpr->arg_begin(); I != FnCallExpr->arg_end(); ++I) { if (!ParseArgument(Event->add_argument(), I->IgnoreImplicit(), References, Ctx)) return false; } return true; }
// Handles an invoke instruction.
//  - No statically known callee: set anyUnknown and stop (the generic
//    visitInstruction handler is not run in this case).
//  - Internal callee: push it onto `used` when `used` is set.
//  - Any other callee: forward its name and the call's arguments to `interface`.
// In the last two cases the instruction is then passed to visitInstruction.
void visitInvokeInst(InvokeInst &I) {
  Function* callee = I.getCalledFunction();
  if (!callee) {
    anyUnknown = true;
    return;
  }

  if (!isInternal(callee)) {
    interface->call(callee->getName(), arg_begin(I), arg_end(I));
  } else if (used) {
    used->push(callee);
  }

  this->visitInstruction(I);
}
void Kontext::createIntegerPrint() { std::vector<llvm::Type *> argTypes; argTypes.emplace_back(this->_types["Integer"].type()->getPointerTo()); auto fType = llvm::FunctionType::get(llvm::Type::getVoidTy(this->getContext()), makeArrayRef(argTypes), false); auto func = llvm::Function::Create(fType, llvm::GlobalValue::ExternalLinkage, "_KN7Integer5printE", this->module()); auto bBlock = llvm::BasicBlock::Create(this->getContext(), "entry", func, nullptr); this->pushBlock(bBlock); auto argIter = func->arg_begin(); auto obj = InstanciatedObject::Create("self", &(*argIter), *this, KCallArgList()); (*argIter).setName("self"); obj->store(*this); const char *constValue = "%d\n"; llvm::Constant *format_const = llvm::ConstantDataArray::getString(this->getContext(), constValue); llvm::GlobalVariable *var = new llvm::GlobalVariable( *_module, llvm::ArrayType::get(llvm::IntegerType::get(this->getContext(), 8), strlen(constValue)+1), true, llvm::GlobalValue::PrivateLinkage, format_const, ".str"); llvm::Constant *zero = llvm::Constant::getNullValue(llvm::IntegerType::getInt32Ty(this->getContext())); std::vector<llvm::Constant*> indices; indices.push_back(zero); indices.push_back(zero); llvm::Constant *var_ref = llvm::ConstantExpr::getGetElementPtr( llvm::ArrayType::get(llvm::IntegerType::get(this->getContext(), 8), 4), var, indices); std::vector<llvm::Value*> args; args.push_back(var_ref); args.push_back(new llvm::LoadInst(obj->getStructElem(*this, "innerInt"), "", false, this->currentBlock())); auto call = llvm::CallInst::Create(_module->getFunction("printf"), llvm::makeArrayRef(args), "", bBlock); llvm::ReturnInst::Create(this->getContext(), bBlock); this->popBlock(); }
void Kontext::createIntegerAdd() { std::vector<llvm::Type *> argTypes; argTypes.emplace_back(this->_types["Integer"].type()->getPointerTo()); argTypes.emplace_back(this->_types["Integer"].type()->getPointerTo()); auto fType = llvm::FunctionType::get(this->_types["Integer"].type(), makeArrayRef(argTypes), false); auto func = llvm::Function::Create(fType, llvm::GlobalValue::ExternalLinkage, "_KN7IntegerplE", this->module()); auto bBlock = llvm::BasicBlock::Create(this->getContext(), "entry", func, nullptr); this->pushBlock(bBlock); auto argIter = func->arg_begin(); auto obj = InstanciatedObject::Create("self", &(*argIter), *this, KCallArgList()); (*argIter).setName("self"); obj->store(*this); argIter++; auto rhs = InstanciatedObject::Create("rhs", &(*argIter), *this, KCallArgList()); (*argIter).setName("rhs"); rhs->store(*this); auto binOp = llvm::BinaryOperator::Create( llvm::Instruction::Add, new llvm::LoadInst(obj->getStructElem(*this, "innerInt"), "", false, this->currentBlock()), new llvm::LoadInst(rhs->getStructElem(*this, "innerInt"), "", false, this->currentBlock()), "", this->_blocks.top()->block); auto type = this->type_of("Integer"); auto object = InstanciatedObject::Create("temp", type, *this, KCallArgList(1, KCallArg("Integer", binOp))); llvm::ReturnInst::Create(this->getContext(), new llvm::LoadInst(object->get(*this), "", false, this->currentBlock()), bBlock); this->popBlock(); }
/// Clones `function` into the specialization context under the name
/// `_specialized_root_function_name` and returns the clone.
///
/// Before cloning, `context.llvm_value_map` is populated so that every
/// function, global variable and argument referenced by the source body maps
/// to a counterpart belonging to the clone's side. That keeps the cloned body
/// from referring back to values of the source module.
llvm::Function* JitCodeSpecializer::_clone_root_function(SpecializationContext& context,
                                                         const llvm::Function& function) const {
  // Empty declaration that will receive the cloned body.
  const auto clone = _create_function_declaration(context, function, _specialized_root_function_name);

  // Referenced functions -> fresh declarations (only if not mapped yet).
  _visit<const llvm::Function>(function, [&](const auto& fn) {
    if (context.llvm_value_map.count(&fn)) {
      return;
    }
    context.llvm_value_map[&fn] = _create_function_declaration(context, fn, fn.getName());
  });

  // Referenced global variables -> clones (only if not mapped yet).
  _visit<const llvm::GlobalVariable>(function, [&](auto& global) {
    if (context.llvm_value_map.count(&global)) {
      return;
    }
    context.llvm_value_map[&global] = _clone_global_variable(context, global);
  });

  // Source arguments -> clone arguments, pairwise, carrying the names over.
  auto source_arg = function.arg_begin();
  auto target_arg = clone->arg_begin();
  while (source_arg != function.arg_end() && target_arg != clone->arg_end()) {
    target_arg->setName(source_arg->getName());
    context.llvm_value_map[source_arg] = target_arg;
    ++source_arg;
    ++target_arg;
  }

  // Let LLVM copy the body using the mapping built above.
  llvm::SmallVector<llvm::ReturnInst*, 8> returns;
  llvm::CloneFunctionInto(clone, &function, context.llvm_value_map, true, returns);

  return clone;
}
/// Returns the half-open range [arg_begin(), arg_end()) so the prototype's
/// arguments can be used in a range-based for loop.
iterator_range<Prototype::arg_iterator> Prototype::args() {
  using ArgRange = iterator_range<Prototype::arg_iterator>;
  auto first = arg_begin();
  auto last = arg_end();
  return ArgRange(first, last);
}
void JitCodeSpecializer::_inline_function_calls(SpecializationContext& context) const {
  // This method implements the main code fusion functionality.
  // It works as follows:
  // Throughout the fusion process, a working list of call sites (i.e., function calls) is maintained.
  // This queue is initialized with all call instructions from the initial root function.
  // Each call site is then handled in one of the following ways:
  // - Direct Calls:
  //   Direct calls explicitly encode the name of the called function. If a function implementation can be located in
  //   the JitRepository repository by that name the function is inlined (i.e., the call instruction is replaced with
  //   the function body from the repository). Calls to functions not available in the repository (e.g., calls into
  //   the C++ standard library) require a corresponding function declaration to be inserted into the module.
  //   This is necessary to avoid any unresolved function calls, which would invalidate the module and cause the
  //   just-in-time compiler to reject it. With an appropriate declaration, however, the just-in-time compiler is able
  //   to locate the machine code of these functions in the Hyrise binary when compiling and linking the module.
  // - Indirect Calls
  //   Indirect calls do not reference their called function by name. Instead, the address of the target function is
  //   either loaded from memory or computed. The code specialization unit analyzes the instructions surrounding the
  //   call and tries to infer the name of the called function. If this succeeds, the call is handled like a regular
  //   direct call. If not, no specialization of the call can be performed.
  //   In this case, the target address computed for the call at runtime will point to the correct machine code in the
  //   Hyrise binary and the pipeline will still execute successfully.
  //
  // Whenever a call instruction is replaced by the corresponding function body, the inlining algorithm reports back
  // any new call sites encountered in the process. These are pushed to the end of the working queue. Once this queue
  // is empty, the operator fusion process is completed.

  std::queue<llvm::CallSite> call_sites;

  // Initialize the call queue with all call and invoke instructions from the root function.
  _visit<llvm::CallInst>(*context.root_function,
                         [&](llvm::CallInst& inst) { call_sites.push(llvm::CallSite(&inst)); });
  _visit<llvm::InvokeInst>(*context.root_function,
                           [&](llvm::InvokeInst& inst) { call_sites.push(llvm::CallSite(&inst)); });

  while (!call_sites.empty()) {
    auto& call_site = call_sites.front();

    // Resolve indirect (virtual) function calls
    if (call_site.isIndirectCall()) {
      const auto called_value = call_site.getCalledValue();
      // Get the runtime location of the called function (i.e., the compiled machine code of the function)
      const auto called_runtime_value =
          std::dynamic_pointer_cast<const JitKnownRuntimePointer>(GetRuntimePointerForValue(called_value, context));
      if (called_runtime_value && called_runtime_value->is_valid()) {
        // LLVM implements virtual function calls via virtual function tables
        // (see https://en.wikipedia.org/wiki/Virtual_method_table).
        // The resolution of virtual calls starts with a pointer to the object that the virtual call should be performed
        // on. This object contains a pointer to its vtable (usually at offset 0). The vtable contains a list of
        // function pointers (one for each virtual function).
        // In the LLVM bitcode, a virtual call is resolved through a number of LLVM pointer operations:
        // - a load instruction to dereference the vtable pointer in the object
        // - a getelementptr instruction to select the correct virtual function in the vtable
        // - another load instruction to dereference the virtual function pointer from the table
        // When analyzing a virtual call, the code specializer works backwards starting from the pointer to the called
        // function (called_runtime_value).
        // Moving "up" one level (undoing one load operation) yields the pointer to the function pointer in the
        // vtable. The total_offset of this pointer corresponds to the index in the vtable that is used for the virtual
        // call.
        // Moving "up" another level yields the pointer to the object that the virtual function is called on. Using RTTI
        // the class name of the object can be determined.
        // These two pieces of information (the class name and the vtable index) are sufficient to unambiguously
        // identify the called virtual function and locate it in the bitcode repository.

        // Determine the vtable index and class name of the virtual call
        const auto vtable_index =
            called_runtime_value->up().total_offset() / context.module->getDataLayout().getPointerSize();
        const auto instance = reinterpret_cast<JitRTTIHelper*>(called_runtime_value->up().up().base().address());
        const auto class_name = typeid(*instance).name();

        // If the called function can be located in the repository, the virtual call is replaced by a direct call to
        // that function.
        if (const auto repo_function = _repository.get_vtable_entry(class_name, vtable_index)) {
          call_site.setCalledFunction(repo_function);
        }
      } else {
        // The virtual call could not be resolved. There is nothing we can inline so we move on.
        call_sites.pop();
        continue;
      }
    }

    auto function = call_site.getCalledFunction();
    // ignore invalid functions
    if (!function) {
      call_sites.pop();
      continue;
    }

    const auto function_name = function->getName().str();

    const auto function_has_opossum_namespace =
        boost::starts_with(function_name, "_ZNK7opossum") || boost::starts_with(function_name, "_ZN7opossum");

    // A note about "__clang_call_terminate":
    // __clang_call_terminate is generated / used internally by clang to call the std::terminate function when exception
    // handling fails. For some unknown reason this function cannot be resolved in the Hyrise binary when jit-compiling
    // bitcode that uses the function. The function is, however, present in the bitcode repository.
    // We thus always inline this function from the repository.

    // All functions that are not in the opossum:: namespace are not considered for inlining. Instead, a function
    // declaration (without a function body) is created.
    if (!function_has_opossum_namespace && function_name != "__clang_call_terminate") {
      context.llvm_value_map[function] = _create_function_declaration(context, *function, function->getName());
      call_sites.pop();
      continue;
    }

    // Determine whether the first function argument is a pointer/reference to an object, but the runtime location
    // for this object cannot be determined.
    // This is the case for member functions that are called within a loop body. These functions may be called on
    // different objects in different loop iterations.
    // If two specialization passes are performed, these functions should be inlined after loop unrolling has been
    // performed (i.e., during the second pass).
    // A call without any arguments has no first argument to inspect; dereferencing arg_begin() would be undefined
    // behaviour in that case, so such calls are treated as resolvable.
    const auto first_argument = call_site.arg_begin();
    const auto has_arguments = first_argument != call_site.arg_end();
    const auto first_argument_cannot_be_resolved =
        has_arguments && first_argument->get()->getType()->isPointerTy() &&
        !GetRuntimePointerForValue(first_argument->get(), context)->is_valid();

    if (first_argument_cannot_be_resolved && function_name != "__clang_call_terminate") {
      call_sites.pop();
      continue;
    }

    // We create a mapping from values in the source module to values in the target module.
    // This mapping is passed to LLVM's cloning function and ensures that all references to other functions, global
    // variables, and function arguments refer to values in the target module and NOT in the source module.
    // This way the module stays self-contained and valid.
    context.llvm_value_map.clear();

    // Map called functions
    _visit<const llvm::Function>(*function, [&](const auto& fn) {
      if (fn.isDeclaration() && !context.llvm_value_map.count(&fn)) {
        context.llvm_value_map[&fn] = _create_function_declaration(context, fn, fn.getName());
      }
    });

    // Map global variables
    _visit<const llvm::GlobalVariable>(*function, [&](auto& global) {
      if (!context.llvm_value_map.count(&global)) {
        context.llvm_value_map[&global] = _clone_global_variable(context, global);
      }
    });

    // Map function arguments
    auto function_arg = function->arg_begin();
    auto call_arg = call_site.arg_begin();
    for (; function_arg != function->arg_end() && call_arg != call_site.arg_end(); ++function_arg, ++call_arg) {
      context.llvm_value_map[function_arg] = call_arg->get();
    }

    // Instruct LLVM to perform the function inlining and push all new call sites to the working queue
    llvm::InlineFunctionInfo info;
    if (InlineFunction(call_site, info, nullptr, false, nullptr, context)) {
      for (const auto& new_call_site : info.InlinedCallSites) {
        call_sites.push(new_call_site);
      }
    }

    // Clear runtime_value_map to allow multiple inlining of the same function; only the entry for the root
    // function's first argument is carried over.
    auto runtime_this = context.runtime_value_map[context.root_function->arg_begin()];
    context.runtime_value_map.clear();
    context.runtime_value_map[context.root_function->arg_begin()] = runtime_this;

    call_sites.pop();
  }
}
/// Lowers a `spawn` continuation.
///
/// The kernel arguments (everything after the first SPAWN_NUM_ARGS arguments)
/// are packed into a closure that is stored in an alloca. A wrapper function
/// `<kernel>_spawn_thread(i8* closure)` is emitted that unpacks the closure and
/// calls the kernel, and the runtime's spawn_thread is called with the closure
/// pointer and the wrapper. The value returned by spawn_thread is bound to
/// parameter 1 of the return continuation.
///
/// @return The continuation found at SPAWN_ARG_RETURN.
///
/// NOTE(review): the closure is created with a plain CreateAlloca at the
/// current insertion point, whereas emit_parallel goes through emit_alloca -
/// confirm whether that difference is intentional.
Continuation* CodeGen::emit_spawn(Continuation* continuation) {
    assert(continuation->num_args() >= SPAWN_NUM_ARGS && "required arguments are missing");
    auto kernel = continuation->arg(SPAWN_ARG_BODY)->as<Global>()->init()->as_continuation();
    const size_t num_kernel_args = continuation->num_args() - SPAWN_NUM_ARGS;

    // Parameter types of the kernel function: one per forwarded argument.
    Array<llvm::Type*> par_args(num_kernel_args);
    for (size_t k = 0; k != num_kernel_args; ++k)
        par_args[k] = convert(continuation->arg(k + SPAWN_NUM_ARGS)->type());

    // Closure value: a struct holding every kernel argument, or the single
    // argument itself when the converted tuple type is not a struct.
    auto closure_type = convert(world_.tuple_type(continuation->arg_fn_type()->ops().skip_front(SPAWN_NUM_ARGS)));
    llvm::Value* closure = nullptr;
    if (!closure_type->isStructTy()) {
        closure = lookup(continuation->arg(SPAWN_NUM_ARGS));
    } else {
        closure = llvm::UndefValue::get(closure_type);
        for (size_t k = 0; k != num_kernel_args; ++k) {
            auto element = lookup(continuation->arg(k + SPAWN_NUM_ARGS));
            closure = irbuilder_.CreateInsertValue(closure, element, unsigned(k));
        }
    }

    // Materialise the closure in memory so its address can be handed over.
    auto closure_ptr = irbuilder_.CreateAlloca(closure_type, nullptr);
    irbuilder_.CreateStore(closure, closure_ptr, false);

    // Declare void wrapper(i8* closure) and ask the runtime to spawn it.
    llvm::Type* wrapper_params[] = { irbuilder_.getInt8PtrTy(0) };
    auto wrapper_type = llvm::FunctionType::get(irbuilder_.getVoidTy(), wrapper_params, false);
    auto wrapper_name = kernel->unique_name() + "_spawn_thread";
    auto wrapper = static_cast<llvm::Function*>(module_->getOrInsertFunction(wrapper_name, wrapper_type));
    auto handle = runtime_->spawn_thread(closure_ptr, wrapper);

    // Emit the wrapper body; remember the caller's block to come back to.
    auto caller_bb = irbuilder_.GetInsertBlock();
    auto wrapper_bb = llvm::BasicBlock::Create(*context_, wrapper_name, wrapper);
    irbuilder_.SetInsertPoint(wrapper_bb);

    // Reload the closure through the wrapper's only parameter and unpack it.
    auto raw_closure = wrapper->arg_begin();
    auto typed_closure = irbuilder_.CreateBitCast(&*raw_closure, llvm::PointerType::get(closure_type, 0));
    auto loaded = irbuilder_.CreateLoad(typed_closure);
    std::vector<llvm::Value*> call_args(num_kernel_args);
    if (!loaded->getType()->isStructTy()) {
        call_args[0] = loaded;
    } else {
        for (size_t k = 0; k != num_kernel_args; ++k)
            call_args[k] = irbuilder_.CreateExtractValue(loaded, { unsigned(k) });
    }

    // Call the kernel and finish the wrapper.
    auto kernel_type = llvm::FunctionType::get(irbuilder_.getVoidTy(), llvm_ref(par_args), false);
    auto kernel_func = static_cast<llvm::Function*>(module_->getOrInsertFunction(kernel->unique_name(), kernel_type));
    irbuilder_.CreateCall(kernel_func, call_args);
    irbuilder_.CreateRetVoid();

    // Back to the caller.
    irbuilder_.SetInsertPoint(caller_bb);

    // The return continuation receives the handle produced by spawn_thread.
    auto return_cont = continuation->arg(SPAWN_ARG_RETURN)->as_continuation();
    emit_result_phi(return_cont->param(1), handle);
    return return_cont;
}
/// Lowers a `parallel` continuation.
///
/// The kernel arguments (everything after the first PAR_NUM_ARGS arguments) are
/// packed into a closure stored via emit_alloca. A wrapper function
/// `<kernel>_parallel_for(i8* closure, i32 lower, i32 upper)` is emitted that
/// unpacks the closure and calls the kernel once per loop index with
/// `(index, <closure elements>)`, and the runtime's parallel_for is called with
/// the thread count, bounds, closure pointer and wrapper.
///
/// @return The continuation found at PAR_ARG_RETURN.
Continuation* CodeGen::emit_parallel(Continuation* continuation) {
    assert(continuation->num_args() >= PAR_NUM_ARGS && "required arguments are missing");

    // Fixed arguments of the parallel call.
    auto num_threads = lookup(continuation->arg(PAR_ARG_NUMTHREADS));
    auto lower = lookup(continuation->arg(PAR_ARG_LOWER));
    auto upper = lookup(continuation->arg(PAR_ARG_UPPER));
    auto kernel = continuation->arg(PAR_ARG_BODY)->as<Global>()->init()->as_continuation();

    const size_t num_kernel_args = continuation->num_args() - PAR_NUM_ARGS;
    const bool single_arg = num_kernel_args == 1;

    // Kernel parameter types: the i32 loop index followed by the forwarded arguments.
    Array<llvm::Type*> par_args(num_kernel_args + 1);
    par_args[0] = irbuilder_.getInt32Ty();
    for (size_t k = 0; k != num_kernel_args; ++k)
        par_args[k + 1] = convert(continuation->arg(k + PAR_NUM_ARGS)->type());

    // Closure value: all kernel arguments inserted into one aggregate, or the
    // argument itself when there is exactly one.
    auto closure_type = convert(world_.tuple_type(continuation->arg_fn_type()->ops().skip_front(PAR_NUM_ARGS)));
    llvm::Value* closure = llvm::UndefValue::get(closure_type);
    if (single_arg) {
        closure = lookup(continuation->arg(PAR_NUM_ARGS));
    } else {
        for (size_t k = 0; k != num_kernel_args; ++k) {
            auto element = lookup(continuation->arg(k + PAR_NUM_ARGS));
            closure = irbuilder_.CreateInsertValue(closure, element, unsigned(k));
        }
    }

    // Materialise the closure in memory so its address can be handed over.
    auto closure_ptr = emit_alloca(closure_type, "parallel_closure");
    irbuilder_.CreateStore(closure, closure_ptr, false);

    // Declare void wrapper(i8* closure, i32 lower, i32 upper) and call the runtime.
    llvm::Type* wrapper_params[] = { irbuilder_.getInt8PtrTy(0), irbuilder_.getInt32Ty(), irbuilder_.getInt32Ty() };
    auto wrapper_type = llvm::FunctionType::get(irbuilder_.getVoidTy(), wrapper_params, false);
    auto wrapper_name = kernel->unique_name() + "_parallel_for";
    auto wrapper = static_cast<llvm::Function*>(module_->getOrInsertFunction(wrapper_name, wrapper_type));
    runtime_->parallel_for(num_threads, lower, upper, closure_ptr, wrapper);

    // Emit the wrapper body; remember the caller's block to come back to.
    auto caller_bb = irbuilder_.GetInsertBlock();
    auto wrapper_bb = llvm::BasicBlock::Create(*context_, wrapper_name, wrapper);
    irbuilder_.SetInsertPoint(wrapper_bb);

    // Reload the closure through the wrapper's first parameter and unpack it
    // into slots 1..n; slot 0 is reserved for the loop index.
    auto wrapper_param = wrapper->arg_begin();
    auto typed_closure = irbuilder_.CreateBitCast(&*wrapper_param, llvm::PointerType::get(closure_type, 0));
    auto loaded = irbuilder_.CreateLoad(typed_closure);
    std::vector<llvm::Value*> call_args(num_kernel_args + 1);
    if (single_arg) {
        call_args[1] = loaded;
    } else {
        for (size_t k = 0; k != num_kernel_args; ++k)
            call_args[k + 1] = irbuilder_.CreateExtractValue(loaded, { unsigned(k) });
    }

    // Loop over the wrapper's own bounds with step 1:
    //   for (int i = lower; i < upper; ++i) body(i, <closure elements>);
    ++wrapper_param;
    auto loop_lower = &*wrapper_param;
    ++wrapper_param;
    auto loop_upper = &*wrapper_param;
    create_loop(loop_lower, loop_upper, irbuilder_.getInt32(1), wrapper, [&](llvm::Value* counter) {
        call_args[0] = counter;
        auto kernel_type = llvm::FunctionType::get(irbuilder_.getVoidTy(), llvm_ref(par_args), false);
        auto kernel_func = static_cast<llvm::Function*>(module_->getOrInsertFunction(kernel->unique_name(), kernel_type));
        irbuilder_.CreateCall(kernel_func, call_args);
    });
    irbuilder_.CreateRetVoid();

    // Back to the caller.
    irbuilder_.SetInsertPoint(caller_bb);

    return continuation->arg(PAR_ARG_RETURN)->as_continuation();
}