Exemplo n.º 1
0
/**
 * Emits the body of a function definition.
 *
 * The prototype is visited first to obtain the llvm::Function. Every incoming
 * argument is named and then copied into its own stack slot, which is
 * registered in _named_values under `arg_prefix + name`. The body is built
 * between a fresh "entry" block and a shared "return" block; the return block
 * either loads and returns the "ret" slot or returns void.
 *
 * The builder's insertion point and the current function context are put back
 * to their previous values before returning.
 *
 * @return always nullptr
 */
const void* LLVMCodeGenerator::visit(ASTFunctionDef* node) {
	auto function = (llvm::Function*)node->proto->accept(this);
	auto& arg_names = node->proto->arg_names;

	// give every LLVM argument the name declared in the prototype
	size_t index = 0;
	for (auto arg = function->arg_begin(); arg != function->arg_end(); ++arg, ++index) {
		arg->setName(arg_names[index]);
	}

	// open the entry block, remembering where the builder was before
	llvm::BasicBlock* entry_block = llvm::BasicBlock::Create(_context, "entry", function);
	llvm::IRBuilderBase::InsertPoint saved_ip = _builder.saveIP();
	_builder.SetInsertPoint(entry_block);

	// spill the arguments to stack slots so that they can be used as lvalues
	// TODO: make sure llvm can optimize the unnecessary copies out?
	index = 0;
	for (auto arg = function->arg_begin(); arg != function->arg_end(); ++arg, ++index) {
		const auto slot_name = node->arg_prefix + arg_names[index];
		llvm::AllocaInst* slot = _builder.CreateAlloca(arg->getType(), nullptr, slot_name);
		_named_values[slot_name] = slot;
		_builder.CreateStore(&*arg, slot);
	}

	// non-void functions get a slot that holds the value to be returned
	llvm::AllocaInst* return_alloca = nullptr;
	const bool returns_value = node->proto->func->return_type()->type() != SLTypeTypeVoid;
	if (returns_value) {
		return_alloca = _builder.CreateAlloca(_llvm_type(node->proto->func->return_type()), nullptr, "ret");
	}

	llvm::BasicBlock* return_block = llvm::BasicBlock::Create(_context, "return", function);

	// build the body with this function's context installed, then put the old one back
	FunctionContext context(node->proto->func, function, return_alloca, return_block);
	auto outer_function_context = _current_function_context;
	_current_function_context = context;

	_build_basic_block(entry_block, node->body, return_block);

	_current_function_context = outer_function_context;

	// emit the single return instruction of the function
	_builder.SetInsertPoint(return_block);
	if (!return_alloca) {
		_builder.CreateRetVoid();
	} else {
		_builder.CreateRet(_builder.CreateLoad(return_alloca));
	}

	// hand the builder back to whoever was emitting code before us
	_builder.restoreIP(saved_ip);

	return nullptr;
}
Exemplo n.º 2
0
// Appends to block `B` a call to a function named `breakpoint_<pc>` (with `pc`
// printed in hexadecimal). The function is created on first use in the module
// that owns `B`: it has external linkage, is marked OptimizeNone and NoInline,
// and its body only returns. The call receives the first argument of the
// function containing `B`.
static void CreateInstrBreakpoint(llvm::BasicBlock *B, VA pc) {
  auto *parent_func = B->getParent();
  auto *module = parent_func->getParent();

  // Build the per-address function name.
  std::stringstream name_stream;
  name_stream << "breakpoint_" << std::hex << pc;
  const auto breakpoint_name = name_stream.str();

  auto *breakpoint_func = module->getFunction(breakpoint_name);
  if (breakpoint_func == nullptr) {
    // First request for this address: define the stub.
    breakpoint_func = llvm::Function::Create(
        LiftedFunctionType(), llvm::GlobalValue::ExternalLinkage,
        breakpoint_name, module);

    breakpoint_func->addFnAttr(llvm::Attribute::OptimizeNone);
    breakpoint_func->addFnAttr(llvm::Attribute::NoInline);

    auto &context = module->getContext();
    auto *stub_body = llvm::BasicBlock::Create(context, "", breakpoint_func);
    llvm::IRBuilder<> stub_builder(stub_body);
    stub_builder.CreateRetVoid();
  }

  // Forward the enclosing function's first argument to the stub.
  auto state_ptr = &*parent_func->arg_begin();
  llvm::CallInst::Create(breakpoint_func, {state_ptr}, "", B);
}
Exemplo n.º 3
0
// Appends to block `B` a call to the function returned by
// ArchGetOrCreateRegStateTracer for the module that owns `B`. The call is
// passed the first argument of the function containing `B`.
static void AddRegStateTracer(llvm::BasicBlock *B) {
  auto *parent_func = B->getParent();
  auto *tracer = ArchGetOrCreateRegStateTracer(parent_func->getParent());

  auto state_ptr = &*parent_func->arg_begin();
  llvm::CallInst::Create(tracer, {state_ptr}, "", B);
}
Exemplo n.º 4
0
void Kontext::createInteger() {
	std::vector<llvm::Type*> types(1, llvm::Type::getInt32Ty(this->getContext()));
	auto leType = llvm::StructType::create(this->getContext(), makeArrayRef(types), "Integer", false);

    std::vector<KObjectAttr> attributes(1, KObjectAttr("innerInt", nullptr));
    auto o = KObject("Integer", leType, std::move(attributes));
    this->addType("Integer", o);

    this->setObject("Integer");

    std::vector<llvm::Type *> argTypes;

    argTypes.emplace_back(leType->getPointerTo());
    argTypes.emplace_back(llvm::Type::getInt32Ty(this->getContext()));

    auto fType = llvm::FunctionType::get(llvm::Type::getVoidTy(this->getContext()), makeArrayRef(argTypes), false);
    auto func = llvm::Function::Create(fType,
            llvm::GlobalValue::ExternalLinkage,
            "_KN7IntegerC1E",
            this->module());

    auto bBlock = llvm::BasicBlock::Create(this->getContext(),
                    "entry",
                    func,
                    nullptr);

    this->pushBlock(bBlock);

	auto argIter = func->arg_begin();    

    auto obj = InstanciatedObject::Create("self", &(*argIter), *this, KCallArgList());
    (*argIter).setName("self");
    obj->store(*this);
	argIter++;

	(*argIter).setName("leValue");


    auto leBlock = KBlock();
    auto decl = std::make_shared<KVarDecl>("innerInt", llvm::Type::getInt32Ty(this->getContext()), &(*argIter));
    decl->setInObj();
    leBlock.emplaceStatement(std::move(decl));

    leBlock.codeGen(*this);

    llvm::ReturnInst::Create(this->getContext(),
                _blocks.top()->returnValue,
                bBlock);

    this->popBlock();

    this->createIntegerAdd();
    this->createIntegerPrint();

	this->popObject();
}
Exemplo n.º 5
0
/// Returns the module's private "stack.prepare" helper, emitting it on first use.
///
/// Parameters of the helper, in order: base, size.ptr, min, max, diff, jmpBuf.
/// The generated code loads the size through size.ptr and checks that
/// `size + min >= 0` (signed) and `size + max <= RuntimeManager::stackSizeLimit`
/// (unsigned). When both hold it stores `size + diff` back through size.ptr and
/// returns a pointer to element `size` of `base`. Otherwise it calls the
/// eh_sjlj_longjmp intrinsic with jmpBuf.
llvm::Function* LocalStack::getStackPrepareFunc()
{
	static const auto c_funcName = "stack.prepare";
	auto existing = getModule()->getFunction(c_funcName);
	if (existing)
		return existing;

	llvm::Type* argsTys[] = {Type::WordPtr, Type::Size->getPointerTo(), Type::Size, Type::Size, Type::Size, Type::BytePtr};
	auto funcTy = llvm::FunctionType::get(Type::WordPtr, argsTys, false);
	auto func = llvm::Function::Create(funcTy, llvm::Function::PrivateLinkage, c_funcName, getModule());
	func->setDoesNotThrow();
	func->setDoesNotAccessMemory(1);
	func->setDoesNotAlias(2);
	func->setDoesNotCapture(2);

	auto checkBB = llvm::BasicBlock::Create(func->getContext(), "Check", func);
	auto updateBB = llvm::BasicBlock::Create(func->getContext(), "Update", func);
	auto outOfStackBB = llvm::BasicBlock::Create(func->getContext(), "OutOfStack", func);

	// Pull the parameters off one at a time, naming each as it is taken.
	auto argIt = func->arg_begin();
	auto takeArg = [&argIt](const char* _name)
	{
		llvm::Argument* arg = &(*argIt);
		++argIt;
		arg->setName(_name);
		return arg;
	};
	llvm::Argument* base = takeArg("base");
	llvm::Argument* sizePtr = takeArg("size.ptr");
	llvm::Argument* min = takeArg("min");
	llvm::Argument* max = takeArg("max");
	llvm::Argument* diff = takeArg("diff");
	llvm::Argument* jmpBuf = takeArg("jmpBuf");

	InsertPointGuard guard{m_builder};

	// Check: both the lower and the upper bound must hold.
	m_builder.SetInsertPoint(checkBB);
	auto sizeAlignment = getModule()->getDataLayout().getABITypeAlignment(Type::Size);
	auto size = m_builder.CreateAlignedLoad(sizePtr, sizeAlignment, "size");
	auto minSize = m_builder.CreateAdd(size, min, "size.min", false, true);
	auto maxSize = m_builder.CreateAdd(size, max, "size.max", true, true);
	auto minOk = m_builder.CreateICmpSGE(minSize, m_builder.getInt64(0), "ok.min");
	auto maxOk = m_builder.CreateICmpULE(maxSize, m_builder.getInt64(RuntimeManager::stackSizeLimit), "ok.max");
	auto ok = m_builder.CreateAnd(minOk, maxOk, "ok");
	m_builder.CreateCondBr(ok, updateBB, outOfStackBB, Type::expectTrue);

	// Update: publish the new size and return the pointer computed from the old one.
	m_builder.SetInsertPoint(updateBB);
	auto newSize = m_builder.CreateNSWAdd(size, diff, "size.next");
	m_builder.CreateAlignedStore(newSize, sizePtr, sizeAlignment);
	auto sp = m_builder.CreateGEP(base, size, "sp");
	m_builder.CreateRet(sp);

	// OutOfStack: leave through the jump buffer.
	m_builder.SetInsertPoint(outOfStackBB);
	auto longjmp = llvm::Intrinsic::getDeclaration(getModule(), llvm::Intrinsic::eh_sjlj_longjmp);
	m_builder.CreateCall(longjmp, {jmpBuf});
	m_builder.CreateUnreachable();

	return func;
}
Exemplo n.º 6
0
/// Records a memory area for one of this function's arguments.
///
/// \param ArgumentNumber zero-based index of the argument; asserted to be
///        smaller than the function's argument count.
/// \param Address start of the area.
/// \param Size size of the area.
void FunctionState::addByValArea(unsigned ArgumentNumber,
                                 stateptr_ty Address,
                                 std::size_t Size)
{
  auto const Fn = getFunction();
  assert(ArgumentNumber < Fn->arg_size());

  // Step from the first argument to the requested one.
  auto const ArgIt = std::next(Fn->arg_begin(), ArgumentNumber);

  ParamByVals.emplace_back(&*ArgIt, MemoryArea(Address, Size));
}
Exemplo n.º 7
0
/// Fills in `Event` from a binary operator that relates a function call to an
/// integer constant expression.
///
/// Exactly one operand of `Bop` must be an integer constant expression; it is
/// parsed as the expected return value. The other operand must be a direct
/// function call; its callee and each of its arguments (with implicit nodes
/// ignored) are parsed into the event.
///
/// The event's context and direction are set before any validation happens.
///
/// \returns true on success; false as soon as any check or sub-parse fails
///          (the local checks report a diagnostic first).
bool ParseFunctionCall(FunctionEvent *Event, BinaryOperator *Bop,
                       vector<ValueDecl*>& References,
                       ASTContext& Ctx) {

  // TODO: better distinguishing between callee and/or caller
  Event->set_context(FunctionEvent::Callee);

  // The return value may matter, so the instrumentation point is the exit of
  // the function rather than its entry.
  Event->set_direction(FunctionEvent::Exit);

  Expr *LHS = Bop->getLHS();
  Expr *RHS = Bop->getRHS();

  const bool LHSisICE = LHS->isIntegerConstantExpr(Ctx);
  const bool RHSisICE = RHS->isIntegerConstantExpr(Ctx);

  // Both-constant and neither-constant are equally unusable.
  if (LHSisICE == RHSisICE) {
    Report("One of {LHS,RHS} must be ICE", Bop->getLocStart(), Ctx)
      << Bop->getSourceRange();
    return false;
  }

  // The constant side is the expected value; the other side is the call.
  Expr *RetVal = RHS;
  Expr *FnCall = LHS;
  if (LHSisICE) {
    RetVal = LHS;
    FnCall = RHS;
  }

  const bool ParsedRetVal = ParseArgument(Event->mutable_expectedreturnvalue(),
                                          RetVal, References, Ctx);
  if (!ParsedRetVal)
    return false;

  auto Call = dyn_cast<CallExpr>(FnCall);
  if (!Call) {
    Report("Not a function call", FnCall->getLocStart(), Ctx)
      << FnCall->getSourceRange();
    return false;
  }

  auto Callee = Call->getDirectCallee();
  if (!Callee) {
    Report("Not a direct function call", Call->getLocStart(), Ctx)
      << Call->getSourceRange();
    return false;
  }

  if (!ParseFunctionRef(Event->mutable_function(), Callee, Ctx))
    return false;

  for (auto Arg = Call->arg_begin(); Arg != Call->arg_end(); ++Arg) {
    const bool ParsedArg = ParseArgument(Event->add_argument(),
                                         Arg->IgnoreImplicit(), References,
                                         Ctx);
    if (!ParsedArg)
      return false;
  }

  return true;
}
Exemplo n.º 8
0
 // Handles an invoke instruction.
 // - No statically known callee: sets anyUnknown and stops.
 // - Internal callee: pushed onto `used` when `used` is set.
 // - Any other callee: reported to `interface` with its name and arguments.
 // In the last two cases the instruction is then passed on to visitInstruction.
 void
 visitInvokeInst(InvokeInst &I)
 {
   Function* const callee = I.getCalledFunction();
   if (!callee) {
     anyUnknown = true;
     return;
   }

   if (!isInternal(callee)) {
     interface->call(callee->getName(), arg_begin(I), arg_end(I));
   } else if (used != NULL) {
     used->push(callee);
   }

   this->visitInstruction(I);
 }
Exemplo n.º 9
0
void Kontext::createIntegerPrint() {
    std::vector<llvm::Type *> argTypes;

    argTypes.emplace_back(this->_types["Integer"].type()->getPointerTo());

    auto fType = llvm::FunctionType::get(llvm::Type::getVoidTy(this->getContext()), makeArrayRef(argTypes), false);
    auto func = llvm::Function::Create(fType,
            llvm::GlobalValue::ExternalLinkage,
            "_KN7Integer5printE",
            this->module());

    auto bBlock = llvm::BasicBlock::Create(this->getContext(),
                    "entry",
                    func,
                    nullptr);

    this->pushBlock(bBlock);

    auto argIter = func->arg_begin();

    auto obj = InstanciatedObject::Create("self", &(*argIter), *this, KCallArgList());
    (*argIter).setName("self");
    obj->store(*this);

    const char *constValue = "%d\n";
    llvm::Constant *format_const = llvm::ConstantDataArray::getString(this->getContext(), constValue);
    llvm::GlobalVariable *var =
        new llvm::GlobalVariable(
            *_module, llvm::ArrayType::get(llvm::IntegerType::get(this->getContext(), 8), strlen(constValue)+1),
            true, llvm::GlobalValue::PrivateLinkage, format_const, ".str");
    llvm::Constant *zero =
        llvm::Constant::getNullValue(llvm::IntegerType::getInt32Ty(this->getContext()));

    std::vector<llvm::Constant*> indices;
    indices.push_back(zero);
    indices.push_back(zero);
    llvm::Constant *var_ref = llvm::ConstantExpr::getGetElementPtr(
    llvm::ArrayType::get(llvm::IntegerType::get(this->getContext(), 8), 4),
        var, indices);

    std::vector<llvm::Value*> args;
    args.push_back(var_ref);
    args.push_back(new llvm::LoadInst(obj->getStructElem(*this, "innerInt"), "", false, this->currentBlock()));

    auto call = llvm::CallInst::Create(_module->getFunction("printf"), llvm::makeArrayRef(args), "", bBlock);
    llvm::ReturnInst::Create(this->getContext(), bBlock);
    this->popBlock();
}
Exemplo n.º 10
0
/// Emits `_KN7IntegerplE(Integer* self, Integer* rhs)`, which returns a new
/// Integer built from the sum of the "innerInt" members of both operands.
void Kontext::createIntegerAdd() {
    auto integerType = this->_types["Integer"].type();

    std::vector<llvm::Type *> argTypes;
    argTypes.emplace_back(integerType->getPointerTo());
    argTypes.emplace_back(integerType->getPointerTo());

    auto fType = llvm::FunctionType::get(integerType, makeArrayRef(argTypes), false);
    auto func = llvm::Function::Create(fType,
            llvm::GlobalValue::ExternalLinkage,
            "_KN7IntegerplE",
            this->module());

    auto bBlock = llvm::BasicBlock::Create(this->getContext(),
                    "entry",
                    func,
                    nullptr);

    this->pushBlock(bBlock);

    auto argIter = func->arg_begin();

    auto obj = InstanciatedObject::Create("self", &(*argIter), *this, KCallArgList());
    (*argIter).setName("self");
    obj->store(*this);
    ++argIter;

    auto rhs = InstanciatedObject::Create("rhs", &(*argIter), *this, KCallArgList());
    (*argIter).setName("rhs");
    rhs->store(*this);

    // Each load is appended to the current block as a side effect of its
    // construction. Building them as separate statements fixes their order
    // (left operand first); as call arguments the order would be unspecified
    // and the emitted IR would depend on the compiler.
    auto lhsValue = new llvm::LoadInst(obj->getStructElem(*this, "innerInt"), "", false, this->currentBlock());
    auto rhsValue = new llvm::LoadInst(rhs->getStructElem(*this, "innerInt"), "", false, this->currentBlock());

    auto binOp = llvm::BinaryOperator::Create(llvm::Instruction::Add,
                                              lhsValue,
                                              rhsValue,
                                              "", this->_blocks.top()->block);

    // Wrap the sum in a temporary Integer and return it by value.
    auto type = this->type_of("Integer");
    auto object = InstanciatedObject::Create("temp", type, *this, KCallArgList(1, KCallArg("Integer", binOp)));

    llvm::ReturnInst::Create(this->getContext(),
                new llvm::LoadInst(object->get(*this), "", false, this->currentBlock()),
                bBlock);

    this->popBlock();
}
Exemplo n.º 11
0
// Clones `function` into a new function named _specialized_root_function_name
// and returns the clone.
//
// Before cloning, context.llvm_value_map is filled with a replacement for
// every function, global variable and argument that `function` refers to, so
// that the cloned body only refers to the replacements.
llvm::Function* JitCodeSpecializer::_clone_root_function(SpecializationContext& context,
                                                         const llvm::Function& function) const {
  // The declaration that will receive the cloned body
  const auto cloned_function = _create_function_declaration(context, function, _specialized_root_function_name);

  auto& value_map = context.llvm_value_map;

  // Referenced functions: one declaration each, created on first sight
  _visit<const llvm::Function>(function, [&](const auto& fn) {
    if (value_map.count(&fn)) return;
    value_map[&fn] = _create_function_declaration(context, fn, fn.getName());
  });

  // Referenced global variables: one clone each, created on first sight
  _visit<const llvm::GlobalVariable>(function, [&](auto& global) {
    if (value_map.count(&global)) return;
    value_map[&global] = _clone_global_variable(context, global);
  });

  // Arguments: pair them up positionally and carry the names over
  auto source_arg = function.arg_begin();
  auto target_arg = cloned_function->arg_begin();
  while (source_arg != function.arg_end() && target_arg != cloned_function->arg_end()) {
    target_arg->setName(source_arg->getName());
    value_map[&*source_arg] = &*target_arg;
    ++source_arg;
    ++target_arg;
  }

  // Let LLVM copy the body, rewriting references through the value map
  llvm::SmallVector<llvm::ReturnInst*, 8> returns;
  llvm::CloneFunctionInto(cloned_function, &function, value_map, true, returns);

  return cloned_function;
}
Exemplo n.º 12
0
// Returns the range running from arg_begin() to arg_end() as one object.
iterator_range<Prototype::arg_iterator> Prototype::args() {
  using range_type = iterator_range<Prototype::arg_iterator>;
  return range_type(arg_begin(), arg_end());
}
Exemplo n.º 13
0
void JitCodeSpecializer::_inline_function_calls(SpecializationContext& context) const {
  // This method implements the main code fusion functionality.
  // It works as follows:
  // Throughout the fusion process, a working list of call sites (i.e., function calls) is maintained.
  // This queue is initialized with all call instructions from the initial root function.
  // Each call site is then handled in one of the following ways:
  // - Direct Calls:
  //   Direct calls explicitly encode the name of the called function. If a function implementation can be located in
  //   the JitRepository repository by that name the function is inlined (i.e., the call instruction is replaced with
  //   the function body from the repository). Calls to functions not available in the repository (e.g., calls into the
  //   C++ standard library) require a corresponding function declaration to be inserted into the module.
  //   This is necessary to avoid any unresolved function calls, which would invalidate the module and cause the
  //   just-in-time compiler to reject it. With an appropriate declaration, however, the just-in-time compiler is able
  //   to locate the machine code of these functions in the Hyrise binary when compiling and linking the module.
  // - Indirect Calls
  //   Indirect calls do not reference their called function by name. Instead, the address of the target function is
  //   either loaded from memory or computed. The code specialization unit analyzes the instructions surrounding the
  //   call and tries to infer the name of the called function. If this succeeds, the call is handled like a regular
  //   direct call. If not, no specialization of the call can be performed.
  //   In this case, the target address computed for the call at runtime will point to the correct machine code in the
  //   Hyrise binary and the pipeline will still execute successfully.
  //
  // Whenever a call instruction is replaced by the corresponding function body, the inlining algorithm reports back
  // any new call sites encountered in the process. These are pushed to the end of the working queue. Once this queue
  // is empty, the operator fusion process is completed.

  std::queue<llvm::CallSite> call_sites;

  // Initialize the call queue with all call and invoke instructions from the root function.
  _visit<llvm::CallInst>(*context.root_function, [&](llvm::CallInst& inst) { call_sites.push(llvm::CallSite(&inst)); });
  _visit<llvm::InvokeInst>(*context.root_function,
                           [&](llvm::InvokeInst& inst) { call_sites.push(llvm::CallSite(&inst)); });

  while (!call_sites.empty()) {
    auto& call_site = call_sites.front();

    // Resolve indirect (virtual) function calls
    if (call_site.isIndirectCall()) {
      const auto called_value = call_site.getCalledValue();
      // Get the runtime location of the called function (i.e., the compiled machine code of the function)
      const auto called_runtime_value =
          std::dynamic_pointer_cast<const JitKnownRuntimePointer>(GetRuntimePointerForValue(called_value, context));
      if (called_runtime_value && called_runtime_value->is_valid()) {
        // LLVM implements virtual function calls via virtual function tables
        // (see https://en.wikipedia.org/wiki/Virtual_method_table).
        // The resolution of virtual calls starts with a pointer to the object that the virtual call should be performed
        // on. This object contains a pointer to its vtable (usually at offset 0). The vtable contains a list of
        // function pointers (one for each virtual function).
        // In the LLVM bitcode, a virtual call is resolved through a number of LLVM pointer operations:
        // - a load instruction to dereference the vtable pointer in the object
        // - a getelementptr instruction to select the correct virtual function in the vtable
        // - another load instruction to dereference the virtual function pointer from the table
        // When analyzing a virtual call, the code specializer works backwards starting from the pointer to the called
        // function (called_runtime_value).
        // Moving "up" one level (undoing one load operation) yields the pointer to the function pointer in the
        // vtable. The total_offset of this pointer corresponds to the index in the vtable that is used for the virtual
        // call.
        // Moving "up" another level yields the pointer to the object that the virtual function is called on. Using RTTI
        // the class name of the object can be determined.
        // These two pieces of information (the class name and the vtable index) are sufficient to unambiguously
        // identify the called virtual function and locate it in the bitcode repository.

        // Determine the vtable index and class name of the virtual call
        const auto vtable_index =
            called_runtime_value->up().total_offset() / context.module->getDataLayout().getPointerSize();
        const auto instance = reinterpret_cast<JitRTTIHelper*>(called_runtime_value->up().up().base().address());
        const auto class_name = typeid(*instance).name();

        // If the called function can be located in the repository, the virtual call is replaced by a direct call to
        // that function.
        if (const auto repo_function = _repository.get_vtable_entry(class_name, vtable_index)) {
          call_site.setCalledFunction(repo_function);
        }
      } else {
        // The virtual call could not be resolved. There is nothing we can inline so we move on.
        call_sites.pop();
        continue;
      }
    }

    auto function = call_site.getCalledFunction();
    // ignore invalid functions
    if (!function) {
      call_sites.pop();
      continue;
    }

    const auto function_name = function->getName().str();

    const auto function_has_opossum_namespace =
        boost::starts_with(function_name, "_ZNK7opossum") || boost::starts_with(function_name, "_ZN7opossum");

    // A note about "__clang_call_terminate":
    // __clang_call_terminate is generated / used internally by clang to call the std::terminate function when exception
    // handling fails. For some unknown reason this function cannot be resolved in the Hyrise binary when jit-compiling
    // bitcode that uses the function. The function is, however, present in the bitcode repository.
    // We thus always inline this function from the repository.

    // All functions that are not in the opossum:: namespace are not considered for inlining. Instead, a function
    // declaration (without a function body) is created.
    if (!function_has_opossum_namespace && function_name != "__clang_call_terminate") {
      context.llvm_value_map[function] = _create_function_declaration(context, *function, function->getName());
      call_sites.pop();
      continue;
    }

    // Determine whether the first function argument is a pointer/reference to an object, but the runtime location
    // for this object cannot be determined.
    // This is the case for member functions that are called within a loop body. These functions may be called on
    // different objects in different loop iterations.
    // If two specialization passes are performed, these functions should be inlined after loop unrolling has been
    // performed (i.e., during the second pass).
    // A call without any arguments has no first argument to inspect: arg_begin() equals arg_end() and must not be
    // dereferenced. Such a call is never treated as unresolvable here.
    auto first_argument_cannot_be_resolved = false;
    if (call_site.arg_begin() != call_site.arg_end()) {
      auto first_argument = call_site.arg_begin();
      first_argument_cannot_be_resolved = first_argument->get()->getType()->isPointerTy() &&
                                          !GetRuntimePointerForValue(first_argument->get(), context)->is_valid();
    }

    if (first_argument_cannot_be_resolved && function_name != "__clang_call_terminate") {
      call_sites.pop();
      continue;
    }

    // We create a mapping from values in the source module to values in the target module.
    // This mapping is passed to LLVM's cloning function and ensures that all references to other functions, global
    // variables, and function arguments refer to values in the target module and NOT in the source module.
    // This way the module stays self-contained and valid.
    context.llvm_value_map.clear();

    // Map called functions
    _visit<const llvm::Function>(*function, [&](const auto& fn) {
      if (fn.isDeclaration() && !context.llvm_value_map.count(&fn)) {
        context.llvm_value_map[&fn] = _create_function_declaration(context, fn, fn.getName());
      }
    });

    // Map global variables
    _visit<const llvm::GlobalVariable>(*function, [&](auto& global) {
      if (!context.llvm_value_map.count(&global)) {
        context.llvm_value_map[&global] = _clone_global_variable(context, global);
      }
    });

    // Map function arguments
    auto function_arg = function->arg_begin();
    auto call_arg = call_site.arg_begin();
    for (; function_arg != function->arg_end() && call_arg != call_site.arg_end(); ++function_arg, ++call_arg) {
      context.llvm_value_map[function_arg] = call_arg->get();
    }

    // Instruct LLVM to perform the function inlining and push all new call sites to the working queue
    llvm::InlineFunctionInfo info;
    if (InlineFunction(call_site, info, nullptr, false, nullptr, context)) {
      for (const auto& new_call_site : info.InlinedCallSites) {
        call_sites.push(new_call_site);
      }
    }

    // clear runtime_value_map to allow multiple inlining of same function
    auto runtime_this = context.runtime_value_map[context.root_function->arg_begin()];
    context.runtime_value_map.clear();
    context.runtime_value_map[context.root_function->arg_begin()] = runtime_this;

    call_sites.pop();
  }
}
Exemplo n.º 14
0
// Emits a call to the runtime's spawn_thread for the kernel named by the
// SPAWN_ARG_BODY argument.
//
// All arguments after the first SPAWN_NUM_ARGS are packed into a closure that
// is stored on the stack and handed to the runtime together with a generated
// wrapper `<kernel>_spawn_thread(i8*)`. The wrapper unpacks the closure and
// calls the kernel. The value produced by spawn_thread is bound to parameter 1
// of the SPAWN_ARG_RETURN continuation, which is returned.
Continuation* CodeGen::emit_spawn(Continuation* continuation) {
    assert(continuation->num_args() >= SPAWN_NUM_ARGS && "required arguments are missing");
    auto kernel = continuation->arg(SPAWN_ARG_BODY)->as<Global>()->init()->as_continuation();
    const size_t num_kernel_args = continuation->num_args() - SPAWN_NUM_ARGS;

    // parameter types of the kernel function
    Array<llvm::Type*> kernel_param_types(num_kernel_args);
    for (size_t i = 0; i != num_kernel_args; ++i) {
        kernel_param_types[i] = convert(continuation->arg(i + SPAWN_NUM_ARGS)->type());
    }

    // gather the kernel arguments into one closure value
    auto closure_type = convert(world_.tuple_type(continuation->arg_fn_type()->ops().skip_front(SPAWN_NUM_ARGS)));
    llvm::Value* closure = nullptr;
    if (!closure_type->isStructTy()) {
        // not a struct: the closure is the first kernel argument itself
        closure = lookup(continuation->arg(SPAWN_NUM_ARGS));
    } else {
        closure = llvm::UndefValue::get(closure_type);
        for (size_t i = 0; i != num_kernel_args; ++i) {
            auto element = lookup(continuation->arg(i + SPAWN_NUM_ARGS));
            closure = irbuilder_.CreateInsertValue(closure, element, unsigned(i));
        }
    }

    // spill the closure so that its address can be passed around
    auto closure_ptr = irbuilder_.CreateAlloca(closure_type, nullptr);
    irbuilder_.CreateStore(closure, closure_ptr, false);

    // declare wrapper(void* closure) and hand it to the runtime
    llvm::Type* wrapper_param_types[] = { irbuilder_.getInt8PtrTy(0) };
    auto wrapper_type = llvm::FunctionType::get(irbuilder_.getVoidTy(), wrapper_param_types, false);
    auto wrapper_name = kernel->unique_name() + "_spawn_thread";
    auto wrapper = (llvm::Function*)module_->getOrInsertFunction(wrapper_name, wrapper_type);
    auto handle = runtime_->spawn_thread(closure_ptr, wrapper);

    // from here on, emit into the wrapper
    auto caller_bb = irbuilder_.GetInsertBlock();
    auto wrapper_bb = llvm::BasicBlock::Create(*context_, wrapper_name, wrapper);
    irbuilder_.SetInsertPoint(wrapper_bb);

    // reload the closure and split it back into the kernel arguments
    auto wrapper_param = wrapper->arg_begin();
    auto typed_closure_ptr = irbuilder_.CreateBitCast(&*wrapper_param, llvm::PointerType::get(closure_type, 0));
    auto loaded_closure = irbuilder_.CreateLoad(typed_closure_ptr);
    std::vector<llvm::Value*> kernel_args(num_kernel_args);
    if (!loaded_closure->getType()->isStructTy()) {
        kernel_args[0] = loaded_closure;
    } else {
        for (size_t i = 0; i != num_kernel_args; ++i) {
            kernel_args[i] = irbuilder_.CreateExtractValue(loaded_closure, { unsigned(i) });
        }
    }

    // call the kernel and finish the wrapper
    auto kernel_type = llvm::FunctionType::get(irbuilder_.getVoidTy(), llvm_ref(kernel_param_types), false);
    auto kernel_func = (llvm::Function*)module_->getOrInsertFunction(kernel->unique_name(), kernel_type);
    irbuilder_.CreateCall(kernel_func, kernel_args);
    irbuilder_.CreateRetVoid();

    // back to the block we were emitting into
    irbuilder_.SetInsertPoint(caller_bb);

    // the return continuation receives the handle as its parameter 1
    auto return_cont = continuation->arg(SPAWN_ARG_RETURN)->as_continuation();
    emit_result_phi(return_cont->param(1), handle);
    return return_cont;
}
Exemplo n.º 15
0
// Emits a call to the runtime's parallel_for for the kernel named by the
// PAR_ARG_BODY argument.
//
// All arguments after the first PAR_NUM_ARGS are packed into a closure that is
// stored in a stack slot and handed to the runtime together with a generated
// wrapper `<kernel>_parallel_for(i8* closure, i32 lower, i32 upper)`. The
// wrapper unpacks the closure and calls the kernel from a loop built by
// create_loop over [lower, upper) with step 1, passing the loop counter as the
// kernel's first argument. Returns the PAR_ARG_RETURN continuation.
Continuation* CodeGen::emit_parallel(Continuation* continuation) {
    assert(continuation->num_args() >= PAR_NUM_ARGS && "required arguments are missing");

    // fixed arguments
    auto num_threads = lookup(continuation->arg(PAR_ARG_NUMTHREADS));
    auto lower = lookup(continuation->arg(PAR_ARG_LOWER));
    auto upper = lookup(continuation->arg(PAR_ARG_UPPER));
    auto kernel = continuation->arg(PAR_ARG_BODY)->as<Global>()->init()->as_continuation();

    const size_t num_kernel_args = continuation->num_args() - PAR_NUM_ARGS;

    // parameter types of the kernel function: loop index first, then the kernel arguments
    Array<llvm::Type*> kernel_param_types(num_kernel_args + 1);
    kernel_param_types[0] = irbuilder_.getInt32Ty();
    for (size_t i = 0; i != num_kernel_args; ++i) {
        kernel_param_types[i + 1] = convert(continuation->arg(i + PAR_NUM_ARGS)->type());
    }

    // gather the kernel arguments into one closure value
    auto closure_type = convert(world_.tuple_type(continuation->arg_fn_type()->ops().skip_front(PAR_NUM_ARGS)));
    llvm::Value* closure = llvm::UndefValue::get(closure_type);
    if (num_kernel_args == 1) {
        // a single argument is used as the closure directly
        closure = lookup(continuation->arg(PAR_NUM_ARGS));
    } else {
        for (size_t i = 0; i != num_kernel_args; ++i) {
            auto element = lookup(continuation->arg(i + PAR_NUM_ARGS));
            closure = irbuilder_.CreateInsertValue(closure, element, unsigned(i));
        }
    }

    // spill the closure so that its address can be passed around
    auto closure_ptr = emit_alloca(closure_type, "parallel_closure");
    irbuilder_.CreateStore(closure, closure_ptr, false);

    // declare wrapper(void* closure, int lower, int upper) and hand it to the runtime
    llvm::Type* wrapper_param_types[] = { irbuilder_.getInt8PtrTy(0), irbuilder_.getInt32Ty(), irbuilder_.getInt32Ty() };
    auto wrapper_type = llvm::FunctionType::get(irbuilder_.getVoidTy(), wrapper_param_types, false);
    auto wrapper_name = kernel->unique_name() + "_parallel_for";
    auto wrapper = (llvm::Function*)module_->getOrInsertFunction(wrapper_name, wrapper_type);
    runtime_->parallel_for(num_threads, lower, upper, closure_ptr, wrapper);

    // from here on, emit into the wrapper
    auto caller_bb = irbuilder_.GetInsertBlock();
    auto wrapper_bb = llvm::BasicBlock::Create(*context_, wrapper_name, wrapper);
    irbuilder_.SetInsertPoint(wrapper_bb);

    // reload the closure and split it back into the kernel arguments (slot 0 is the loop index)
    auto wrapper_param = wrapper->arg_begin();
    auto typed_closure_ptr = irbuilder_.CreateBitCast(&*wrapper_param, llvm::PointerType::get(closure_type, 0));
    auto loaded_closure = irbuilder_.CreateLoad(typed_closure_ptr);
    std::vector<llvm::Value*> kernel_args(num_kernel_args + 1);
    if (num_kernel_args == 1) {
        kernel_args[1] = loaded_closure;
    } else {
        for (size_t i = 0; i != num_kernel_args; ++i) {
            kernel_args[i + 1] = irbuilder_.CreateExtractValue(loaded_closure, { unsigned(i) });
        }
    }

    // the remaining two wrapper parameters are the loop bounds
    ++wrapper_param;
    auto wrapper_lower = &*wrapper_param;
    ++wrapper_param;
    auto wrapper_upper = &*wrapper_param;

    // for (int i = lower; i < upper; ++i)
    //   body(i, <closure_elems>);
    create_loop(wrapper_lower, wrapper_upper, irbuilder_.getInt32(1), wrapper, [&](llvm::Value* counter) {
        kernel_args[0] = counter; // loop index
        auto kernel_type = llvm::FunctionType::get(irbuilder_.getVoidTy(), llvm_ref(kernel_param_types), false);
        auto kernel_func = (llvm::Function*)module_->getOrInsertFunction(kernel->unique_name(), kernel_type);
        irbuilder_.CreateCall(kernel_func, kernel_args);
    });
    irbuilder_.CreateRetVoid();

    // back to the block we were emitting into
    irbuilder_.SetInsertPoint(caller_bb);

    return continuation->arg(PAR_ARG_RETURN)->as_continuation();
}