Stmt mutate(Stmt s) {
    // Tally predicated stores/loads in the statement and verify the
    // totals against what this checker was configured to expect.
    CountPredicatedStoreLoad counter;
    s.accept(&counter);

    if (has_store_count && counter.store_count == 0) {
        printf("There should be some predicated stores but didn't find any\n");
        exit(-1);
    }
    if (!has_store_count && counter.store_count > 0) {
        printf("There were %d predicated stores. There weren't supposed to be any stores\n", counter.store_count);
        exit(-1);
    }
    if (has_load_count && counter.load_count == 0) {
        printf("There should be some predicated loads but didn't find any\n");
        exit(-1);
    }
    if (!has_load_count && counter.load_count > 0) {
        printf("There were %d predicated loads. There weren't supposed to be any loads\n", counter.load_count);
        exit(-1);
    }
    return s;
}
string print_loop_nest(const vector<Function> &outputs) {
    // Run the scheduling part of lowering and render the resulting loop
    // nest as human-readable pseudocode.

    // Build the environment: every Function transitively called by the outputs.
    map<string, Function> env;
    for (Function f : outputs) {
        map<string, Function> more_funcs = find_transitive_calls(f);
        env.insert(more_funcs.begin(), more_funcs.end());
    }

    // Decide the order in which the functions get realized.
    vector<string> order = realization_order(outputs, env);

    // For the purposes of printing the loop nest, we don't want to
    // worry about which features are and aren't enabled, so turn every
    // device API's feature on.
    Target target = get_host_target();
    for (DeviceAPI api : all_device_apis) {
        target.set_feature(target_feature_for_device_api(DeviceAPI(api)));
    }

    // Schedule the functions.
    bool any_memoized = false;
    Stmt loop_nest = schedule_functions(outputs, order, env, target, any_memoized);

    // Render the scheduled statement as pseudocode.
    std::ostringstream out;
    PrintLoopNest printer(out, env);
    loop_nest.accept(&printer);
    return out.str();
}
std::unique_ptr<Stmt> InstantiationVisitor::clone(Stmt& s) {
    // Visit the statement; the visit methods deposit the cloned node in
    // the _stmt member, which we then take ownership of.
    std::unique_ptr<Stmt> result;
    s.accept(*this);
    swap(result, _stmt);
    // Return the local by value: NRVO / implicit move applies. The
    // original's `return std::move(result);` inhibited copy elision
    // (clang-tidy: performance-no-automatic-move / pessimizing-move).
    return result;
}
Closure Closure::make(Stmt s, const string &loop_variable, bool track_buffers, llvm::StructType *buffer_t) {
    // Build a closure for the statement, ignoring the loop variable
    // (it is defined by the enclosing loop, not captured).
    Closure closure;
    closure.buffer_t = buffer_t;
    closure.track_buffers = track_buffers;
    closure.ignore.push(loop_variable, 0);
    s.accept(&closure);
    return closure;
}
Stmt IRMutator::mutate(const Stmt &s) {
    // Dispatch to the visit methods, which leave the result in the
    // 'stmt' member; an undefined input maps to an undefined output.
    if (!s.defined()) {
        stmt = Stmt();
    } else {
        s.accept(this);
    }
    expr = Expr();
    // Move the member out rather than copying it (avoids a refcount bump;
    // 'stmt' is freshly assigned on every mutate call).
    return std::move(stmt);
}
Stmt IRMutator::mutate(Stmt s) {
    // Dispatch to the visit methods, which leave the result in the
    // 'stmt' member; an undefined input maps to an undefined output.
    if (!s.defined()) {
        stmt = Stmt();
    } else {
        s.accept(this);
    }
    expr = Expr();
    return stmt;
}
int count_host_alignment_asserts(Func f, std::map<string, int> m) {
    // Lower the func without bounds queries and count the host-alignment
    // assertions that appear in the lowered statement.
    Target target = get_jit_target_from_environment();
    target.set_feature(Target::NoBoundsQuery);
    f.compute_root();
    Stmt lowered = Internal::lower({f.function()}, f.name(), target);
    CountHostAlignmentAsserts counter(m);
    lowered.accept(&counter);
    return counter.count;
}
int count_interleaves(Func f) {
    // Lower with bounds queries and asserts disabled, then count the
    // interleave operations present in the lowered statement.
    Target target = get_jit_target_from_environment();
    target.set_feature(Target::NoBoundsQuery);
    target.set_feature(Target::NoAsserts);
    Stmt lowered = Internal::lower(f.function(), target);
    CountInterleaves counter;
    lowered.accept(&counter);
    return counter.result;
}
bool uses_branches(Func f) {
    // Lower with bounds queries and asserts disabled, then report
    // whether any branches remain in the lowered statement.
    Target target = get_jit_target_from_environment();
    target.set_feature(Target::NoBoundsQuery);
    target.set_feature(Target::NoAsserts);
    Stmt lowered = Internal::lower(f.function(), target);
    ContainsBranches detector;
    lowered.accept(&detector);
    return detector.result;
}
void IRGraphVisitor::include(const Stmt &s) {
    // Visit each node at most once: skip anything already recorded in
    // the visited set, otherwise record it and recurse.
    if (visited.count(s.ptr)) {
        return;
    }
    visited.insert(s.ptr);
    s.accept(this);
}
Stmt mutate(const Stmt &s) override {
    // Verify the statement contains exactly one loop; abort otherwise.
    CheckLoops checker;
    s.accept(&checker);
    if (checker.count != 1) {
        std::cerr << "expected one loop, found " << checker.count << std::endl;
        exit(-1);
    }
    return s;
}
Stmt skip_stages(Stmt stmt, const vector<string> &order) {
    // Walk the realization order back-to-front (the final stage is the
    // output and is never considered), skipping any stage whose result
    // might be unused.
    for (size_t i = order.size()-1; i > 0; i--) {
        const string &stage = order[i-1];
        MightBeSkippable query(stage);
        stmt.accept(&query);
        if (query.result) {
            StageSkipper skip(stage);
            stmt = skip.mutate(stmt);
        }
    }
    return stmt;
}
Stmt mutate(Stmt s) {
    // Count barriers in the statement and verify against the expected total.
    CountBarriers counter;
    s.accept(&counter);
    if (counter.count != correct) {
        printf("There were %d barriers. There were supposed to be %d\n", counter.count, correct);
        exit(-1);
    }
    return s;
}
Stmt mutate(const Stmt &s) override {
    // Count stores in the statement and verify against the expected total.
    CountStores counter;
    s.accept(&counter);
    if (counter.count != correct) {
        printf("There were %d stores. There were supposed to be %d\n", counter.count, correct);
        exit(-1);
    }
    return s;
}
// Realize this Func into the given buffer. On the first call the
// pipeline is lowered, its arguments inferred, and the result JIT
// compiled; the compiled module is cached. Subsequent calls only
// refresh the buffer addresses and re-invoke the cached function.
void Func::realize(Buffer dst) {
    if (!compiled_module.wrapped_function) {
        // First call: lower and compile.
        assert(func.defined() && "Can't realize NULL function handle");
        assert(value().defined() && "Can't realize undefined function");
        Stmt stmt = lower();

        // Infer arguments
        InferArguments infer_args;
        stmt.accept(&infer_args);

        // Append an argument representing this Func itself, plus the
        // destination buffer as the final argument value.
        Argument me(name(), true, Int(1));
        infer_args.arg_types.push_back(me);
        arg_values = infer_args.arg_values;
        arg_values.push_back(dst.raw_buffer());
        // Remember which argument slots are bound to ImageParams so they
        // can be re-bound on later calls.
        image_param_args = infer_args.image_param_args;

        Internal::log(2) << "Inferred argument list:\n";
        for (size_t i = 0; i < infer_args.arg_types.size(); i++) {
            Internal::log(2) << infer_args.arg_types[i].name << ", "
                             << infer_args.arg_types[i].type << ", "
                             << infer_args.arg_types[i].is_buffer << "\n";
        }

        StmtCompiler cg;
        cg.compile(stmt, name(), infer_args.arg_types);
        if (log::debug_level >= 3) {
            // At high debug levels, dump assembly, bitcode, and the
            // lowered statement to files for inspection.
            cg.compile_to_native(name() + ".s", true);
            cg.compile_to_bitcode(name() + ".bc");
            std::ofstream stmt_debug((name() + ".stmt").c_str());
            stmt_debug << stmt;
        }
        compiled_module = cg.compile_to_function_pointers();

        if (error_handler) compiled_module.set_error_handler(error_handler);
        else compiled_module.set_error_handler(NULL);
    } else {
        // Update the address of the buffer we're realizing into
        arg_values[arg_values.size()-1] = dst.raw_buffer();

        // update the addresses of the image param args
        for (size_t i = 0; i < image_param_args.size(); i++) {
            Buffer b = image_param_args[i].second.get_buffer();
            assert(b.defined() && "An ImageParam is not bound to a buffer");
            arg_values[image_param_args[i].first] = b.raw_buffer();
        }
    }

    // Run the compiled pipeline on the assembled argument list.
    compiled_module.wrapped_function(&(arg_values[0]));
}
// Run the front half of lowering (environment construction, deep copy,
// wrapper substitution, realization ordering, scheduling) and return the
// resulting loop nest rendered as pseudocode.
string print_loop_nest(const vector<Function> &output_funcs) {
    // Do the first part of lowering:

    // Compute an environment
    map<string, Function> env;
    for (Function f : output_funcs) {
        populate_environment(f, env);
    }

    // Create a deep-copy of the entire graph of Funcs.
    vector<Function> outputs;
    std::tie(outputs, env) = deep_copy(output_funcs, env);

    // Output functions should all be computed and stored at root.
    for (Function f: outputs) {
        Func(f).compute_root().store_root();
    }

    // Ensure that all ScheduleParams become well-defined constant Exprs.
    for (auto &f : env) {
        f.second.substitute_schedule_param_exprs();
    }

    // Substitute in wrapper Funcs
    env = wrap_func_calls(env);

    // Compute a realization order
    vector<string> order = realization_order(outputs, env);

    // Try to simplify the RHS/LHS of a function definition by propagating its
    // specializations' conditions
    simplify_specializations(env);

    // For the purposes of printing the loop nest, we don't want to
    // worry about which features are and aren't enabled.
    Target target = get_host_target();
    for (DeviceAPI api : all_device_apis) {
        target.set_feature(target_feature_for_device_api(DeviceAPI(api)));
    }

    bool any_memoized = false;
    // Schedule the functions.
    Stmt s = schedule_functions(outputs, order, env, target, any_memoized);

    // Now convert that to pseudocode
    std::ostringstream sstr;
    PrintLoopNest pln(sstr, env);
    s.accept(&pln);
    return sstr.str();
}
Stmt skip_stages(Stmt stmt, const vector<string> &order) {
    // Don't consider the last stage, because it's the output, so it's
    // never skippable.
    for (size_t i = order.size()-1; i > 0; i--) {
        const string &func_name = order[i-1];
        debug(2) << "skip_stages checking " << func_name << "\n";
        MightBeSkippable query(func_name);
        stmt.accept(&query);
        if (query.result) {
            debug(2) << "skip_stages can skip " << func_name << "\n";
            StageSkipper skip(func_name);
            stmt = skip.mutate(stmt);
        }
    }
    return stmt;
}
Stmt IRRewriter::rewrite(Stmt s) {
    if (!s.defined()) {
        s = Stmt();
    } else {
        // Visit the statement; the visit methods leave the rewritten
        // result in the 'stmt' member.
        s.accept(this);
        // Prepend any statements spilled during the rewrite.
        Stmt spilled = getSpilledStmts();
        if (spilled.defined()) {
            stmt = Block::make(spilled, stmt);
        }
        s = stmt;
    }
    // Reset per-rewrite state before returning.
    expr = Expr();
    stmt = Stmt();
    func = Func();
    return s;
}
string print_loop_nest(const vector<Function> &outputs) {
    // Partially lower the pipeline and render its loop structure as
    // pseudocode.

    // Gather every Function transitively called by the outputs.
    map<string, Function> env;
    for (Function f : outputs) {
        map<string, Function> called = find_transitive_calls(f);
        env.insert(called.begin(), called.end());
    }

    // Pick a realization order and schedule the functions.
    vector<string> order = realization_order(outputs, env);
    Stmt scheduled = schedule_functions(outputs, order, env, false);

    // Convert the scheduled statement into pseudocode.
    std::ostringstream out;
    PrintLoopNest printer(out, env);
    scheduled.accept(&printer);
    return out.str();
}
// Emit one OpenGL compute shader for the given statement: version
// header, buffer/uniform declarations for each argument, shared
// allocations, the kernel body, and finally the workgroup-size layout
// declaration (which is only known after the body has been traversed).
void CodeGen_OpenGLCompute_Dev::CodeGen_OpenGLCompute_C::add_kernel(Stmt s,
                                                                    Target target,
                                                                    const string &name,
                                                                    const vector<GPU_Argument> &args) {
    debug(2) << "Adding OpenGLCompute kernel " << name << "\n";
    cache.clear();

    // Android requires the ES dialect plus the extension pack; desktop
    // targets get core GLSL 430.
    if (target.os == Target::Android) {
        stream << "#version 310 es\n"
               << "#extension GL_ANDROID_extension_pack_es31a : require\n";
    } else {
        stream << "#version 430\n";
    }
    stream << "float float_from_bits(int x) { return intBitsToFloat(int(x)); }\n";

    // Declare each argument: buffers become SSBOs bound by position,
    // scalars become uniforms.
    for (size_t i = 0; i < args.size(); i++) {
        if (args[i].is_buffer) {
            //
            // layout(binding = 10) buffer buffer10 {
            //     vec3 data[];
            // } inBuffer;
            //
            stream << "layout(binding=" << i << ")"
                   << " buffer buffer" << i << " { "
                   << print_type(args[i].type) << " data[]; } "
                   << print_name(args[i].name) << ";\n";
        } else {
            stream << "uniform " << print_type(args[i].type)
                   << " " << print_name(args[i].name) << ";\n";
        }
    }

    // Find all the shared allocations and declare them at global scope.
    FindSharedAllocations fsa;
    s.accept(&fsa);
    for (const Allocate *op : fsa.allocs) {
        internal_assert(op->extents.size() == 1 && is_const(op->extents[0]));
        stream << "shared " << print_type(op->type) << " "
               << print_name(op->name) << "[" << op->extents[0] << "];\n";
    }

    // We'll figure out the workgroup size while traversing the stmt
    workgroup_size[0] = 0;
    workgroup_size[1] = 0;
    workgroup_size[2] = 0;

    stream << "void main()\n{\n";
    indent += 2;
    print(s);
    indent -= 2;
    stream << "}\n";

    // Declare the workgroup size.
    // NOTE(review): this is emitted after main(); presumably a later pass
    // reorders or the driver accepts trailing layout declarations — confirm.
    indent += 2;
    stream << "layout(local_size_x = " << workgroup_size[0];
    if (workgroup_size[1] > 1) {
        stream << ", local_size_y = " << workgroup_size[1];
    }
    if (workgroup_size[2] > 1) {
        stream << ", local_size_z = " << workgroup_size[2];
    }
    stream << ") in;\n// end of kernel " << name << "\n";
}
int count_producers(Stmt in, const std::string &name) {
    // Count the producer nodes for the named func in the statement.
    CountProducers producer_counter(name);
    in.accept(&producer_counter);
    return producer_counter.count;
}
// Generate an LLVM function with SPIR kernel calling convention for the
// given statement: deduce argument types (buffers become addrspace(1)
// byte pointers, plus a trailing addrspace(3) shared-memory pointer),
// populate the symbol table, build the required opencl.kernels metadata,
// codegen the body, and verify the result.
void CodeGen_SPIR_Dev::add_kernel(Stmt stmt, std::string name, const std::vector<Argument> &args) {
    // Now deduce the types of the arguments to our function
    vector<llvm::Type *> arg_types(args.size()+1);
    for (size_t i = 0; i < args.size(); i++) {
        if (args[i].is_buffer) {
            arg_types[i] = llvm_type_of(UInt(8))->getPointerTo(1); // __global = addrspace(1)
        } else {
            arg_types[i] = llvm_type_of(args[i].type);
        }
    }
    // Add local (shared) memory buffer parameter.
    arg_types[args.size()] = llvm_type_of(UInt(8))->getPointerTo(3); // __local = addrspace(3)

    // Make our function
    function_name = name;
    FunctionType *func_t = FunctionType::get(void_t, arg_types, false);
    function = llvm::Function::Create(func_t, llvm::Function::ExternalLinkage, name, module);
    function->setCallingConv(llvm::CallingConv::SPIR_KERNEL);

    // Mark the buffer args as no alias
    // (argument attribute indices are 1-based, hence i+1).
    for (size_t i = 0; i < args.size(); i++) {
        if (args[i].is_buffer) {
            function->setDoesNotAlias(i+1);
        }
    }
    // Mark the local memory as no alias (probably not necessary?)
    function->setDoesNotAlias(args.size());

    // Make the initial basic block
    entry_block = BasicBlock::Create(*context, "entry", function);
    builder->SetInsertPoint(entry_block);

    // Per-argument metadata lists required by the SPIR/OpenCL kernel
    // metadata format; one entry is appended per argument below.
    vector<Value *> kernel_arg_address_space = init_kernel_metadata(*context, "kernel_arg_addr_space");
    vector<Value *> kernel_arg_access_qual = init_kernel_metadata(*context, "kernel_arg_access_qual");
    vector<Value *> kernel_arg_type = init_kernel_metadata(*context, "kernel_arg_type");
    vector<Value *> kernel_arg_base_type = init_kernel_metadata(*context, "kernel_arg_base_type");
    vector<Value *> kernel_arg_type_qual = init_kernel_metadata(*context, "kernel_arg_type_qual");
    vector<Value *> kernel_arg_name = init_kernel_metadata(*context, "kernel_arg_name");

    // Put the arguments in the symbol table
    {
        llvm::Function::arg_iterator arg = function->arg_begin();
        for (std::vector<Argument>::const_iterator iter = args.begin();
             iter != args.end();
             ++iter, ++arg) {
            if (iter->is_buffer) {
                // HACK: codegen expects a load from foo to use base
                // address 'foo.host', so we store the device pointer
                // as foo.host in this scope.
                sym_push(iter->name + ".host", arg);
                kernel_arg_address_space.push_back(ConstantInt::get(i32, 1));
            } else {
                sym_push(iter->name, arg);
                kernel_arg_address_space.push_back(ConstantInt::get(i32, 0));
            }
            arg->setName(iter->name);
            kernel_arg_name.push_back(MDString::get(*context, iter->name));
            kernel_arg_access_qual.push_back(MDString::get(*context, "none"));
            kernel_arg_type_qual.push_back(MDString::get(*context, ""));
            // TODO: 'Type' isn't correct, but we don't have C to get the type name from...
            // This really shouldn't matter anyways. Everything SPIR needs is in the function
            // type, this metadata seems redundant.
            kernel_arg_type.push_back(MDString::get(*context, "type"));
            kernel_arg_base_type.push_back(MDString::get(*context, "type"));
        }
        // The loop leaves 'arg' pointing at the trailing shared-memory
        // parameter; record its metadata here.
        arg->setName("shared");
        kernel_arg_address_space.push_back(ConstantInt::get(i32, 3)); // __local = addrspace(3)
        kernel_arg_name.push_back(MDString::get(*context, "shared"));
        kernel_arg_access_qual.push_back(MDString::get(*context, "none"));
        kernel_arg_type.push_back(MDString::get(*context, "char*"));
        kernel_arg_base_type.push_back(MDString::get(*context, "char*"));
        kernel_arg_type_qual.push_back(MDString::get(*context, ""));
    }

    // We won't end the entry block yet, because we'll want to add
    // some allocas to it later if there are local allocations. Start
    // a new block to put all the code.
    BasicBlock *body_block = BasicBlock::Create(*context, "body", function);
    builder->SetInsertPoint(body_block);

    debug(1) << "Generating llvm bitcode...\n";
    // Ok, we have a module, function, context, and a builder
    // pointing at a brand new basic block. We're good to go.
    stmt.accept(this);

    // Now we need to end the function
    builder->CreateRetVoid();

    // Make the entry block point to the body block
    builder->SetInsertPoint(entry_block);
    builder->CreateBr(body_block);

    // Add the nvvm annotation that it is a kernel function.
    // NOTE(review): kernel_arg_base_type is built above but not included
    // in this metadata node — confirm whether that is intentional.
    Value *kernel_metadata[] = {
        function,
        MDNode::get(*context, kernel_arg_address_space),
        MDNode::get(*context, kernel_arg_access_qual),
        MDNode::get(*context, kernel_arg_type),
        MDNode::get(*context, kernel_arg_type_qual),
        MDNode::get(*context, kernel_arg_name)
    };
    MDNode *mdNode = MDNode::get(*context, kernel_metadata);
    module->getOrInsertNamedMetadata("opencl.kernels")->addOperand(mdNode);

    // Now verify the function is ok
    verifyFunction(*function);

    // Finally, verify the module is ok
    verifyModule(*module);
    debug(2) << "Done generating llvm bitcode\n";
}
// Insert checks to make sure that parameters are within their
// declared range.
//
// For each scalar parameter with a declared min/max, this substitutes a
// clamped ".constrained" variable throughout the pipeline, wraps the
// statement in LetStmts defining those clamped values, and (unless
// NoAsserts is set) prepends AssertStmts that call the matching
// halide_error_param_* runtime function on violation.
Stmt add_parameter_checks(Stmt s, const Target &t) {
    // First, find all the parameters
    FindParameters finder;
    s.accept(&finder);

    map<string, Expr> replace_with_constrained;
    vector<pair<string, Expr>> lets;

    // One record per min/max bound to assert on.
    struct ParamAssert {
        Expr condition;
        Expr value, limit_value;
        string param_name;
    };

    vector<ParamAssert> asserts;

    // Make constrained versions of the params
    for (pair<const string, Parameter> &i : finder.params) {
        Parameter param = i.second;
        if (!param.is_buffer() &&
            (param.get_min_value().defined() ||
             param.get_max_value().defined())) {
            string constrained_name = i.first + ".constrained";
            Expr constrained_var = Variable::make(param.type(), constrained_name);
            Expr constrained_value = Variable::make(param.type(), i.first, param);
            replace_with_constrained[i.first] = constrained_var;
            if (param.get_min_value().defined()) {
                ParamAssert p = {
                    constrained_value >= param.get_min_value(),
                    constrained_value, param.get_min_value(),
                    param.name()
                };
                asserts.push_back(p);
                // Clamp from below so downstream code can rely on the bound.
                constrained_value = max(constrained_value, param.get_min_value());
            }
            if (param.get_max_value().defined()) {
                ParamAssert p = {
                    constrained_value <= param.get_max_value(),
                    constrained_value, param.get_max_value(),
                    param.name()
                };
                asserts.push_back(p);
                // Clamp from above as well.
                constrained_value = min(constrained_value, param.get_max_value());
            }
            lets.push_back(make_pair(constrained_name, constrained_value));
        }
    }

    // Replace the params with their constrained version in the rest of the pipeline
    s = substitute(replace_with_constrained, s);

    // Inject the let statements
    for (size_t i = 0; i < lets.size(); i++) {
        s = LetStmt::make(lets[i].first, lets[i].second, s);
    }

    // With NoAsserts, the clamping lets remain but no checks are emitted.
    if (t.has_feature(Target::NoAsserts)) {
        asserts.clear();
    }

    // Inject the assert statements
    for (size_t i = 0; i < asserts.size(); i++) {
        ParamAssert p = asserts[i];
        // Upgrade the types to 64-bit versions for the error call
        Type wider = p.value.type().with_bits(64);
        p.limit_value = cast(wider, p.limit_value);
        p.value = cast(wider, p.value);

        // Select the runtime error routine by bound direction (LE = upper
        // bound violated means "too large") and by the widened type.
        string error_call_name = "halide_error_param";

        if (p.condition.as<LE>()) {
            error_call_name += "_too_large";
        } else {
            internal_assert(p.condition.as<GE>());
            error_call_name += "_too_small";
        }

        if (wider.is_int()) {
            error_call_name += "_i64";
        } else if (wider.is_uint()) {
            error_call_name += "_u64";
        } else {
            internal_assert(wider.is_float());
            error_call_name += "_f64";
        }

        Expr error = Call::make(Int(32), error_call_name,
                                {p.param_name, p.value, p.limit_value},
                                Call::Extern);
        s = Block::make(AssertStmt::make(p.condition, error), s);
    }

    return s;
}
// Entry point: dispatch this visitor over the given statement.
void print(Stmt ir) {
    ir.accept(this);
}
// Generate an LLVM function for the given statement as a PTX kernel:
// deduce argument types (buffers become byte pointers), set up the
// symbol table, codegen the body, attach the nvvm "kernel" annotation,
// and verify the result.
void CodeGen_PTX_Dev::add_kernel(Stmt stmt,
                                 const std::string &name,
                                 const std::vector<DeviceArgument> &args) {
    internal_assert(module != nullptr);

    debug(2) << "In CodeGen_PTX_Dev::add_kernel\n";

    // Now deduce the types of the arguments to our function
    vector<llvm::Type *> arg_types(args.size());
    for (size_t i = 0; i < args.size(); i++) {
        if (args[i].is_buffer) {
            arg_types[i] = llvm_type_of(UInt(8))->getPointerTo();
        } else {
            arg_types[i] = llvm_type_of(args[i].type);
        }
    }

    // Make our function
    FunctionType *func_t = FunctionType::get(void_t, arg_types, false);
    function = llvm::Function::Create(func_t, llvm::Function::ExternalLinkage, name, module.get());

    // Mark the buffer args as no alias
    // (argument attribute indices are 1-based, hence i+1).
    for (size_t i = 0; i < args.size(); i++) {
        if (args[i].is_buffer) {
            function->setDoesNotAlias(i+1);
        }
    }

    // Make the initial basic block
    entry_block = BasicBlock::Create(*context, "entry", function);
    builder->SetInsertPoint(entry_block);

    // Put the arguments in the symbol table
    vector<string> arg_sym_names;
    {
        size_t i = 0;
        for (auto &fn_arg : function->args()) {
            string arg_sym_name = args[i].name;
            if (args[i].is_buffer) {
                // HACK: codegen expects a load from foo to use base
                // address 'foo.host', so we store the device pointer
                // as foo.host in this scope.
                arg_sym_name += ".host";
            }
            sym_push(arg_sym_name, &fn_arg);
            fn_arg.setName(arg_sym_name);
            arg_sym_names.push_back(arg_sym_name);
            i++;
        }
    }

    // We won't end the entry block yet, because we'll want to add
    // some allocas to it later if there are local allocations. Start
    // a new block to put all the code.
    BasicBlock *body_block = BasicBlock::Create(*context, "body", function);
    builder->SetInsertPoint(body_block);

    debug(1) << "Generating llvm bitcode for kernel...\n";
    // Ok, we have a module, function, context, and a builder
    // pointing at a brand new basic block. We're good to go.
    stmt.accept(this);

    // Now we need to end the function
    builder->CreateRetVoid();

    // Make the entry block point to the body block
    builder->SetInsertPoint(entry_block);
    builder->CreateBr(body_block);

    // Add the nvvm annotation that it is a kernel function.
    llvm::Metadata *md_args[] = {
        llvm::ValueAsMetadata::get(function),
        MDString::get(*context, "kernel"),
        llvm::ValueAsMetadata::get(ConstantInt::get(i32_t, 1))
    };

    MDNode *md_node = MDNode::get(*context, md_args);
    module->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(md_node);

    // Now verify the function is ok
    verifyFunction(*function);

    // Finally, verify the module is ok
    verifyModule(*module);
    debug(2) << "Done generating llvm bitcode for PTX\n";

    // Clear the symbol table so this codegen instance can be reused.
    for (size_t i = 0; i < arg_sym_names.size(); i++) {
        sym_pop(arg_sym_names[i]);
    }
}
Closure::Closure(Stmt s, const string &loop_variable) {
    // The loop variable is defined by the enclosing loop, so it must not
    // be treated as captured state (when one is supplied).
    const bool have_loop_var = !loop_variable.empty();
    if (have_loop_var) {
        ignore.push(loop_variable, 0);
    }
    s.accept(this);
}
// Visit the statement and report whether it was found to be blocked.
// Resets the flag first so the checker can be reused across statements.
bool check(Stmt stmt) {
    isBlocked = false;
    stmt.accept(this);
    return isBlocked;
}
// Visit the statement and report whether it was found to be flattened.
// Both flags are reset first so the checker can be reused; the visit
// methods may clear isFlattened and set indexExprFound along the way.
bool check(Stmt stmt) {
    isFlattened = true;
    indexExprFound = false;
    stmt.accept(this);
    return isFlattened;
}
Closure::Closure(Stmt s, const string &loop_variable, llvm::StructType *buffer_t) {
    // Record the buffer_t struct type, then build the closure, treating
    // the loop variable as defined by the enclosing loop (not captured).
    this->buffer_t = buffer_t;
    ignore.push(loop_variable, 0);
    s.accept(this);
}
// Entry point: dispatch this printer over the given statement.
void IRPrinter::print(Stmt ir) {
    ir.accept(this);
}