Exemplo n.º 1
0
/**
 * Builds the call that stands in for an OpenCL kernel launch (clNDRangeKernel).
 *
 * A new lambda of type (kernel tuple, global size, local size) -> int4 is created. Its body
 * unpacks every kernel argument from the tuple, calls the kernel function registered for the
 * root variable of localKernel, and returns 0 (CL_SUCCESS). The returned expression is a call
 * to that lambda with the dereferenced kernel tuple and the global/local sizes.
 *
 * @param localKernel       address of the kernel expression; its root variable is the key used
 *                          to look up kernelFunctions and kernelTypes
 * @param work_dim          work dimension, forwarded to anythingToVec3
 * @param local_work_size   local work size, converted with anythingToVec3
 * @param global_work_size  global work size, converted with anythingToVec3
 * @return call expression of type int4 invoking the newly created wrapper lambda
 */
ExpressionPtr KernelReplacer::createKernelCallLambda(const ExpressionAddress localKernel,
		const ExpressionPtr work_dim, const ExpressionPtr local_work_size, const ExpressionPtr global_work_size) {
	NodeManager& mgr = prog->getNodeManager();
	IRBuilder builder(mgr);

	ExpressionPtr k = utils::getRootVariable(localKernel).as<ExpressionPtr>();

	// Try to find the corresponding kernel function. A single find() is used for both the
	// sanity check and the access, so the table is never touched through operator[] for a
	// key that is not registered.
	auto kernelPos = kernelFunctions.find(k);
	assert(kernelPos != kernelFunctions.end() && "Cannot find OpenCL Kernel");
	LambdaExprPtr lambda = kernelPos->second.as<LambdaExprPtr>();

	const ExpressionPtr local = anythingToVec3(work_dim, local_work_size);
	const ExpressionPtr global = anythingToVec3(work_dim, global_work_size);

	const VariableList& interface = lambda->getParameterList()->getElements();

	// The kernel call below passes the global and local size as the last two arguments, so
	// the kernel function needs at least those two parameters. Without this check the
	// unsigned subtraction would wrap around and the loop would run out of bounds.
	assert(interface.size() >= 2 && "Kernel function lacks the trailing global/local size parameters");
	const size_t numKernelArgs = interface.size() - 2;

	vector<ExpressionPtr> innerArgs;
	const core::lang::BasicGenerator& gen = builder.getLangBasic();

	// types of the members of the kernel tuple, one per kernel argument
	TypeList kernelType = kernelTypes[k];

	/* body of the newly created function which replaces clNDRangeKernel. It contains
	 *  the kernel call
	 *  return 0;
	 */
	StatementList body;

	// Kernel variable to be used inside the newly created function
	VariablePtr innerKernel = builder.variable(builder.tupleType(kernelType));
	for(size_t i = 0; i < numKernelArgs; ++i) {
		// The parameter type is taken as declared by the kernel function.
		TypePtr argTy = interface.at(i)->getType();
		TypePtr memberTy = kernelType.at(i);

		// read the i-th member of the kernel tuple ...
		ExpressionPtr tupleMemberAccess = builder.callExpr(memberTy, gen.getTupleMemberAccess(), utils::removeDoubleRef(innerKernel),
				builder.literal(gen.getUInt8(), toString(i)), builder.getTypeLiteral(memberTy));
		// ... and turn it into the argument expected by the kernel function. handleArgument
		// also receives the body (presumably so it can emit preparatory statements).
		ExpressionPtr argument = handleArgument(argTy, memberTy, tupleMemberAccess, body);

		innerArgs.push_back(argument);
	}

	const TypePtr vecTy = builder.vectorType(gen.getUInt8(), builder.concreteIntTypeParam(static_cast<size_t>(3)));

	// local and global size to be used inside the newly created function
	VariablePtr innerGlobal = builder.variable(vecTy);
	VariablePtr innerLocal = builder.variable(vecTy);

	// add global and local size to arguments
	innerArgs.push_back(innerGlobal);
	innerArgs.push_back(innerLocal);

	ExpressionPtr kernelCall = builder.callExpr(gen.getInt4(), lambda, innerArgs);
	body.push_back(kernelCall);							   // calling the kernel function
	body.push_back(builder.returnStmt(builder.intLit(0))); // return CL_SUCCESS

	// create function type for inner function: kernel tuple, global size, local size
	TypeList innerFctInterface;
	innerFctInterface.push_back(innerKernel->getType());
	innerFctInterface.push_back(vecTy);
	innerFctInterface.push_back(vecTy);

	FunctionTypePtr innerFctTy = builder.functionType(innerFctInterface, gen.getInt4());

	// collect inner function parameters
	VariableList innerFctParams;
	innerFctParams.push_back(innerKernel);
	innerFctParams.push_back(innerGlobal);
	innerFctParams.push_back(innerLocal);

	// create lambda for inner function
	LambdaExprPtr innerLambda = builder.lambdaExpr(innerFctTy, builder.parameters(innerFctParams), builder.compoundStmt(body));

	return builder.callExpr(gen.getInt4(), innerLambda, builder.deref(localKernel), global, local);
}