void Lower::lower_taskwait_with_dependences(const Nodecl::OpenMP::Taskwait& node)
    {
        // A taskwait with dependences is rewritten as a task whose environment
        // is the taskwait's own environment plus an if(0) clause and a
        // TaskIsTaskwait marker, and whose body is a single empty statement.
        // That task then goes through the regular task lowering.

        // Spell the "false" condition of the if clause in the source language:
        // an unsigned 0 in C/C++, a 4-byte zero boolean literal otherwise (Fortran).
        Nodecl::NodeclBase false_condition;
        if (!(IS_C_LANGUAGE || IS_CXX_LANGUAGE))
        {
            false_condition = Nodecl::BooleanLiteral::make(
                    TL::Type::get_bool_type(),
                    const_value_get_zero(/* bytes */ 4, /* sign */ 0));
        }
        else
        {
            false_condition = const_value_to_nodecl(const_value_get_unsigned_int(0));
        }

        Nodecl::List task_env = node.get_environment().as<Nodecl::List>();
        task_env.append(Nodecl::OpenMP::If::make(false_condition));
        task_env.append(Nodecl::OpenMP::TaskIsTaskwait::make());

        Nodecl::List empty_body = Nodecl::List::make(Nodecl::EmptyStatement::make());
        Nodecl::OpenMP::Task if0_task =
            Nodecl::OpenMP::Task::make(task_env, empty_body, node.get_locus());

        // Replace the taskwait node in place, then lower it as a task.
        node.replace(if0_task);
        lower_task(node.as<Nodecl::OpenMP::Task>());
    }
// Example #2
// 0
        // Vectorizes a reduction over `scalar_symbol` using the vector
        // accumulator `vector_symbol`.
        //
        // - pre_nodecls receives the ObjectInit of `vector_symbol`, whose value
        //   is `reduction_initializer` promoted to the vector type.
        // - post_nodecls receives the statement that reduces the vector lanes
        //   back into `scalar_symbol`.
        //
        // Only the "+" and "-" reduction operators are supported. Any other
        // operator is reported as an internal error instead of being silently
        // dropped, since dropping it would leave the scalar without the
        // partial results.
        // `reduction_type` is currently not used.
        void VectorizerVectorReduction::vectorize_reduction(const TL::Symbol& scalar_symbol,
                TL::Symbol& vector_symbol,
                const Nodecl::NodeclBase& reduction_initializer,
                const std::string& reduction_name,
                const TL::Type& reduction_type,
                Nodecl::List& pre_nodecls,
                Nodecl::List& post_nodecls)
        {
            // Step1: ADD REDUCTION SYMBOLS
            // Every lane of the vector accumulator starts at the initializer value
            vector_symbol.set_value(Nodecl::VectorPromotion::make(
                        reduction_initializer.shallow_copy(),
                        vector_symbol.get_type()));

            // Add new ObjectInit with the initialization
            Nodecl::ObjectInit reduction_object_init =
                Nodecl::ObjectInit::make(vector_symbol);

            pre_nodecls.append(reduction_object_init);


            // Step2: ADD VECTOR REDUCTION INSTRUCTIONS
            Nodecl::NodeclBase vector_reduction;
            if (reduction_name == "+")
            {
                vector_reduction = Nodecl::VectorReductionAdd::make(
                        scalar_symbol.make_nodecl(true),
                        vector_symbol.make_nodecl(true),
                        scalar_symbol.get_type());
            }
            else if (reduction_name == "-")
            {
                vector_reduction = Nodecl::VectorReductionMinus::make(
                        scalar_symbol.make_nodecl(true),
                        vector_symbol.make_nodecl(true),
                        scalar_symbol.get_type());
            }
            else
            {
                // Previously this fell through silently, losing the reduction result
                internal_error("Vectorizer: unsupported reduction operator '%s' on '%s'",
                        reduction_name.c_str(), scalar_symbol.get_name().c_str());
            }

            Nodecl::ExpressionStatement post_reduction_stmt =
                Nodecl::ExpressionStatement::make(vector_reduction);

            post_nodecls.append(post_reduction_stmt);
        }
// Example #3
// 0
void AutoScopeVisitor::visit( const Nodecl::OpenMP::Task& n )
{
    // Retrieve the results of the Auto-Scoping process to the user
    _analysis_info->print_auto_scoping_results( n );

    // Modify the Nodecl with the new variables' scope
    Analysis::Utils::AutoScopedVariables autosc_vars = _analysis_info->get_auto_scoped_variables( n );
    Analysis::Utils::ext_sym_set private_ext_syms, firstprivate_ext_syms, race_ext_syms,
             shared_ext_syms, undef_ext_syms;
    Nodecl::NodeclBase user_private_vars, user_firstprivate_vars, user_shared_vars;

    // Get actual environment
    Nodecl::List environ = n.get_environment().as<Nodecl::List>();
    for( Nodecl::List::iterator it = environ.begin( ); it != environ.end( ); )
    {
        if( it->is<Nodecl::OpenMP::Auto>( ) )
        {
            it = environ.erase( it );
        }
        else
        {
            if( it->is<Nodecl::OpenMP::Private>( ) )
            {
                user_private_vars = it->as<Nodecl::OpenMP::Private>( );
            }
            if( it->is<Nodecl::OpenMP::Firstprivate>( ) )
            {
                user_firstprivate_vars = it->as<Nodecl::OpenMP::Firstprivate>( );
            }
            if( it->is<Nodecl::OpenMP::Shared>( ) )
            {
                user_shared_vars = it->as<Nodecl::OpenMP::Shared>( );
            }
            ++it;
        }
    }

    // Remove user-scoped variables from auto-scoped variables and reset environment
    private_ext_syms = autosc_vars.get_private_vars( );
    if( !private_ext_syms.empty( ) )
    {
        ObjectList<Nodecl::NodeclBase> autosc_private_vars;
        for( Analysis::Utils::ext_sym_set::iterator it = private_ext_syms.begin( ); it != private_ext_syms.end( ); ++it )
        {
            autosc_private_vars.insert( it->get_nodecl( ) );
        }
        ObjectList<Nodecl::NodeclBase> purged_autosc_private_vars;
        for( ObjectList<Nodecl::NodeclBase>::iterator it = autosc_private_vars.begin( );
                it != autosc_private_vars.end( ); ++it )
        {
            if( !Nodecl::Utils::nodecl_is_in_nodecl_list( *it, user_firstprivate_vars.as<Nodecl::List>( ) )
                    && !Nodecl::Utils::nodecl_is_in_nodecl_list( *it, user_private_vars.as<Nodecl::List>( ) )
                    && !Nodecl::Utils::nodecl_is_in_nodecl_list( *it, user_shared_vars.as<Nodecl::List>( ) ) )
            {
                purged_autosc_private_vars.insert( it->shallow_copy() );
            }
        }
        if( !purged_autosc_private_vars.empty( ) )
        {
            Nodecl::OpenMP::Private private_node =
                Nodecl::OpenMP::Private::make( Nodecl::List::make( purged_autosc_private_vars ),
                                               n.get_locus( ) );
            environ.append( private_node );
        }
    }

    firstprivate_ext_syms = autosc_vars.get_firstprivate_vars( );
    if( !firstprivate_ext_syms.empty( ) )
    {
        ObjectList<Nodecl::NodeclBase> autosc_firstprivate_vars;
        for( Analysis::Utils::ext_sym_set::iterator it = firstprivate_ext_syms.begin( ); it != firstprivate_ext_syms.end( ); ++it )
        {
            autosc_firstprivate_vars.insert( it->get_nodecl( ) );
        }
        ObjectList<Nodecl::NodeclBase> purged_autosc_firstprivate_vars;
        for( ObjectList<Nodecl::NodeclBase>::iterator it = autosc_firstprivate_vars.begin( );
                it != autosc_firstprivate_vars.end( ); ++it )
        {
            if( !Nodecl::Utils::nodecl_is_in_nodecl_list( *it, user_firstprivate_vars.as<Nodecl::List>( ) )
                    && !Nodecl::Utils::nodecl_is_in_nodecl_list( *it, user_private_vars.as<Nodecl::List>( ) )
                    && !Nodecl::Utils::nodecl_is_in_nodecl_list( *it, user_shared_vars.as<Nodecl::List>( ) ) )
            {
                purged_autosc_firstprivate_vars.insert( it->shallow_copy() );
            }
        }
        if( !purged_autosc_firstprivate_vars.empty( ) )
        {
            Nodecl::OpenMP::Firstprivate firstprivate_node =
                Nodecl::OpenMP::Firstprivate::make( Nodecl::List::make( purged_autosc_firstprivate_vars ),
                                                    n.get_locus( ) );
            environ.append( firstprivate_node );
        }
    }

    shared_ext_syms = autosc_vars.get_shared_vars( );
    if( !shared_ext_syms.empty( ) )
    {
        ObjectList<Nodecl::NodeclBase> autosc_shared_vars;
        for( Analysis::Utils::ext_sym_set::iterator it = shared_ext_syms.begin( ); it != shared_ext_syms.end( ); ++it )
        {
            autosc_shared_vars.insert( it->get_nodecl( ) );
        }
        ObjectList<Nodecl::NodeclBase> purged_autosc_shared_vars;
        for( ObjectList<Nodecl::NodeclBase>::iterator it = autosc_shared_vars.begin( );
                it != autosc_shared_vars.end( ); ++it )
        {
            if( !Nodecl::Utils::nodecl_is_in_nodecl_list( *it, user_firstprivate_vars.as<Nodecl::List>( ) )
                    && !Nodecl::Utils::nodecl_is_in_nodecl_list( *it, user_private_vars.as<Nodecl::List>( ) )
                    && !Nodecl::Utils::nodecl_is_in_nodecl_list( *it, user_shared_vars.as<Nodecl::List>( ) ) )
            {
                purged_autosc_shared_vars.insert( it->shallow_copy() );
            }
        }
        if( !purged_autosc_shared_vars.empty( ) )
        {
            Nodecl::OpenMP::Shared shared_node =
                Nodecl::OpenMP::Shared::make( Nodecl::List::make( purged_autosc_shared_vars ),
                                              n.get_locus( ) );
            environ.append( shared_node );
        }
    }
}
// Builds the OpenCL outline for a task.
//
// Appends to the top level:
//  - an "unpacked" function, named <outline>_unpacked in C/C++ and
//    <outline>_unpack in Fortran. Its body holds, in order:
//      * the initial statements from new_function_symbol_unpacked;
//      * the NDRange kernel-launch code from generate_ndrange_code (only when
//        the task is a function task with an ndrange clause);
//      * the final statements from new_function_symbol_unpacked.
//  - the outline function <outline>(args). It takes the argument structure
//    by lvalue reference, forwards its fields to the unpacked function, and
//    emits cleanup code.
// Here <outline> is ocl_outline_name(info._outline_name).
//
// In Fortran it additionally:
//  - creates a "<outline>_forward" function symbol, which the outline calls;
//  - emits ancillary C code through add_forward_function_code_to_extra_c_code.
//
// Out parameters:
//  - symbol_map: set to a newly heap-allocated SimpleSymbolMap. Ownership is
//    not visible here; presumably the caller releases it (TODO confirm).
//  - output_statements: set to info._task_statements. In Fortran it is also
//    passed to duplicate_internal_subprograms.
//  - outline_placeholder: reassigned at the end to a statement placeholder
//    parsed in the context of the unpacked function body.
//
// NOTE(review): several TL::Source objects are streamed into another Source
// before they are filled in, and are only filled in before that Source is
// parsed:
//  - extra_declarations ("IMPLICIT NONE");
//  - outline_function_addr;
//  - instrument_before / instrument_after.
// This presumably relies on Source keeping its chunks by reference. Confirm
// before reordering any of these statements.
void DeviceOpenCL::create_outline(CreateOutlineInfo &info,
        Nodecl::NodeclBase &outline_placeholder,
        Nodecl::NodeclBase &output_statements,
        Nodecl::Utils::SimpleSymbolMap* &symbol_map)
{
    // Unpack DTO
    Lowering *lowering = info._lowering;
    // Binds a temporary; its lifetime is extended to this scope
    const std::string& outline_name = ocl_outline_name(info._outline_name);
    const Nodecl::NodeclBase& task_statements = info._task_statements;
    const Nodecl::NodeclBase& original_statements = info._original_statements;
    const TL::Symbol& called_task = info._called_task;
    // A valid called task means this is a function task, not an inline task
    bool is_function_task = info._called_task.is_valid();
    TL::ObjectList<OutlineDataItem*> data_items = info._data_items;

    // Record on the lowering phase that an OpenCL task was seen
    lowering->seen_opencl_task = true;

    // Handed back to the caller through the out parameter
    symbol_map = new Nodecl::Utils::SimpleSymbolMap();

    output_statements = task_statements;

    ERROR_CONDITION(called_task.is_valid() && !called_task.is_function(),
            "The '%s' symbol is not a function", called_task.get_name().c_str());

    // Tasks inside nested functions (C) or internal subprograms (Fortran) are rejected
    TL::Symbol current_function = original_statements.retrieve_context().get_related_symbol();
    if (current_function.is_nested_function())
    {
        if (IS_C_LANGUAGE || IS_CXX_LANGUAGE)
            fatal_printf_at(original_statements.get_locus(), "nested functions are not supported\n");


        if (IS_FORTRAN_LANGUAGE)
            fatal_printf_at(original_statements.get_locus(), "internal subprograms are not supported\n");
    }


    Source extra_declarations;
    Source final_statements, initial_statements;

    // *** Unpacked (and forward in Fortran) function ***
    TL::Symbol unpacked_function, forward_function;
    if (IS_FORTRAN_LANGUAGE)
    {
        forward_function = new_function_symbol_forward(
                current_function,
                outline_name + "_forward",
                info);
        unpacked_function = new_function_symbol_unpacked(
                current_function,
                outline_name + "_unpack",
                info,
                // out
                symbol_map,
                initial_statements,
                final_statements);
    }
    else
    {
        unpacked_function = new_function_symbol_unpacked(
                current_function,
                outline_name + "_unpacked",
                info,
                // out
                symbol_map,
                initial_statements,
                final_statements);
    }


    Nodecl::NodeclBase unpacked_function_code, unpacked_function_body;
    SymbolUtils::build_empty_body_for_function(unpacked_function,
            unpacked_function_code,
            unpacked_function_body);

    if (IS_FORTRAN_LANGUAGE)
    {
        // Now get all the needed internal functions and replicate them in the outline
        // NOTE(review): the list `l` built here is never used, because the setter
        // below is commented out. The same internal functions are walked again
        // further down and duplicated via duplicate_internal_subprograms.
        // Whether this deep_copy has needed side effects on *symbol_map or on the
        // scope is not visible from here; confirm before removing.
        Nodecl::Utils::Fortran::InternalFunctions internal_functions;
        internal_functions.walk(info._original_statements);

        Nodecl::List l;
        for (TL::ObjectList<Nodecl::NodeclBase>::iterator
                it2 = internal_functions.function_codes.begin();
                it2 != internal_functions.function_codes.end();
                it2++)
        {
            l.append(
                    Nodecl::Utils::deep_copy(*it2, unpacked_function.get_related_scope(), *symbol_map)
                    );
        }

        // unpacked_function_code.as<Nodecl::FunctionCode>().set_internal_functions(l);
    }

    Nodecl::Utils::append_to_top_level_nodecl(unpacked_function_code);

    //Get the argument of the 'file' clause, if not present, use the files passed in the command line (if any)
    // With a file clause, kernel_files becomes the full input path of the single
    // OpenCL translation unit whose basename matches the clause. Without one,
    // it is the comma-separated list of every OpenCL input file.
    std::string file_clause_arg = info._target_info.get_file();
    std::string kernel_files = file_clause_arg;
    if (!file_clause_arg.empty())
    {
        bool found = false;
        for (int i = 0; i < ::compilation_process.num_translation_units; ++i)
        {
            compilation_file_process_t* file_process = ::compilation_process.translation_units[i];
            translation_unit_t* current_translation_unit = file_process->translation_unit;
            const char* extension = get_extension_filename(current_translation_unit->input_filename);
            // NOTE(review): the lookup result is dereferenced without a NULL check
            struct extensions_table_t* current_extension = fileextensions_lookup(extension, strlen(extension));

            if (current_extension->source_language == SOURCE_SUBLANGUAGE_OPENCL)
            {
                const char* basename = give_basename(current_translation_unit->input_filename);
                if (file_clause_arg == std::string(basename))
                {
                    ERROR_CONDITION(found, "%s: error: more than one OpenCL file in the command line matches clause file(%s)\n",
                            original_statements.get_locus_str().c_str(), basename);

                    found = true;
                    kernel_files = std::string(current_translation_unit->input_filename);
                }
            }
        }

        // A missing match is only an error (not fatal) and can be disabled
        if (!found && !_disable_opencl_file_check)
        {
            error_printf_at(original_statements.get_locus(),
                    "no OpenCL file in the command line matches clause file(%s)\n",
                    file_clause_arg.c_str());
        }
    }
    else
    {
        int ocl_files = 0;
        for (int i = 0; i < ::compilation_process.num_translation_units; ++i)
        {
            compilation_file_process_t* file_process = ::compilation_process.translation_units[i];
            translation_unit_t* current_translation_unit = file_process->translation_unit;
            const char* extension = get_extension_filename(current_translation_unit->input_filename);
            struct extensions_table_t* current_extension = fileextensions_lookup(extension, strlen(extension));

            if (current_extension->source_language == SOURCE_SUBLANGUAGE_OPENCL)
            {
                if (ocl_files > 0)
                    kernel_files += ",";

                kernel_files += std::string(current_translation_unit->input_filename);
                ocl_files++;
            }
        }

        // NOTE(review): called_task may be invalid here for an inline task
        if (ocl_files == 0)
        {
            fatal_printf_at(original_statements.get_locus(),
                    "no OpenCL file specified for kernel '%s'\n",
                    called_task.get_name().c_str());
        }
    }

    // Get the name of the kernel
    std::string kernel_name = info._target_info.get_name();
    if (kernel_name.empty())
    {
        // If the clause name is not present, use the name of the called task
        // NOTE(review): for an inline task called_task is invalid; kernel_name is
        // only consumed below when called_task is valid.
        kernel_name = called_task.get_name();
    }

    // Kernel-launch code is only generated for function tasks with an ndrange clause
    Source ndrange_code;
    if (called_task.is_valid()
            && info._target_info.get_ndrange().size() > 0)
    {
        Nodecl::Utils::SimpleSymbolMap param_to_args_map =
            info._target_info.get_param_arg_map();

        generate_ndrange_code(called_task,
                unpacked_function,
                info._target_info,
                kernel_files,
                kernel_name,
                info._data_items,
                &param_to_args_map,
                symbol_map,
                ndrange_code);
    }


    // Body of the unpacked function (braced in C/C++)
    Source unpacked_source;
    if (!IS_FORTRAN_LANGUAGE)
    {
        unpacked_source
            << "{";
    }

    unpacked_source
        << extra_declarations
        << initial_statements
        << ndrange_code
        //<< statement_placeholder(outline_placeholder)
        << final_statements
        ;

    if (!IS_FORTRAN_LANGUAGE)
    {
        unpacked_source
            << "}";
    }
    
    // Fortran may require more symbols
    if (IS_FORTRAN_LANGUAGE)
    {
        // Insert extra symbols
        TL::Scope unpacked_function_scope = unpacked_function_body.retrieve_context();

        Nodecl::Utils::Fortran::ExtraDeclsVisitor fun_visitor(symbol_map,
                unpacked_function_scope,
                current_function);
        if (is_function_task)
        {
            fun_visitor.insert_extra_symbol(info._called_task);
        }
        fun_visitor.insert_extra_symbols(task_statements);

        Nodecl::Utils::Fortran::append_used_modules(
                original_statements.retrieve_context(),
                unpacked_function_scope);

        if (is_function_task)
        {
            Nodecl::Utils::Fortran::append_used_modules(
                    info._called_task.get_related_scope(),
                    unpacked_function_scope);
        }

        // Add also used types
        add_used_types(data_items, unpacked_function.get_related_scope());

        // Now get all the needed internal functions and replicate them in the outline
        Nodecl::Utils::Fortran::InternalFunctions internal_functions;
        internal_functions.walk(info._original_statements);

        duplicate_internal_subprograms(internal_functions.function_codes,
                unpacked_function.get_related_scope(),
                symbol_map,
                output_statements);

        // Appended after extra_declarations was already streamed into unpacked_source
        extra_declarations
            << "IMPLICIT NONE\n";
    }
    else if (IS_CXX_LANGUAGE)
    {
        // Declare non-member unpacked functions before the enclosing top-level location
        if (!unpacked_function.is_member())
        {
            Nodecl::NodeclBase nodecl_decl = Nodecl::CxxDecl::make(
                    /* optative context */ nodecl_null(),
                    unpacked_function,
                    original_statements.get_locus());
            Nodecl::Utils::prepend_to_enclosing_top_level_location(original_statements, nodecl_decl);
        }
    }

    // Parse the unpacked body in its function's context and install it
    Nodecl::NodeclBase new_unpacked_body = unpacked_source.parse_statement(unpacked_function_body);
    unpacked_function_body.replace(new_unpacked_body);


    // **** Outline function *****
    // Single parameter "args": an lvalue reference to the arguments structure
    ObjectList<std::string> structure_name;
    structure_name.append("args");
    ObjectList<TL::Type> structure_type;
    structure_type.append(
            TL::Type(get_user_defined_type(info._arguments_struct.get_internal_symbol())).get_lvalue_reference_to()
            );

    TL::Symbol outline_function = SymbolUtils::new_function_symbol(
            current_function,
            outline_name,
            TL::Type::get_void_type(),
            structure_name,
            structure_type);

    Nodecl::NodeclBase outline_function_code, outline_function_body;
    SymbolUtils::build_empty_body_for_function(outline_function,
            outline_function_code,
            outline_function_body);
    Nodecl::Utils::append_to_top_level_nodecl(outline_function_code);

    // Prepare arguments for the call to the unpack (or forward in Fortran)
    TL::Scope outline_function_scope(outline_function_body.retrieve_context());
    TL::Symbol structure_symbol = outline_function_scope.get_symbol_from_name("args");
    ERROR_CONDITION(!structure_symbol.is_valid(), "Argument of outline function not found", 0);

    // unpacked_arguments: comma-separated call arguments built from the fields of "args"
    // cleanup_code: Fortran deallocations emitted after the call
    Source unpacked_arguments, cleanup_code;

    for (TL::ObjectList<OutlineDataItem*>::iterator it = data_items.begin();
            it != data_items.end();
            it++)
    {
        // For inline (non-function) tasks the C++ 'this' item is not passed as an argument
        if (!is_function_task
                && (*it)->get_is_cxx_this())
            continue;

        switch ((*it)->get_sharing())
        {
            case OutlineDataItem::SHARING_PRIVATE:
                {
                    // Do nothing
                    break;
                }
            case OutlineDataItem::SHARING_SHARED:
            case OutlineDataItem::SHARING_CAPTURE:
            case OutlineDataItem::SHARING_CAPTURE_ADDRESS:
                {
                    TL::Type param_type = (*it)->get_in_outline_type();

                    Source argument;
                    if (IS_C_LANGUAGE || IS_CXX_LANGUAGE)
                    {
                        // Normal shared items are passed by reference from a pointer,
                        // derreference here
                        if ((*it)->get_sharing() == OutlineDataItem::SHARING_SHARED
                                && !(IS_CXX_LANGUAGE && (*it)->get_symbol().get_name() == "this"))
                        {
                            if (!param_type.no_ref().depends_on_nonconstant_values())
                            {
                                argument << "*(args." << (*it)->get_field_name() << ")";
                            }
                            else
                            {
                                // Types depending on non-constant values (e.g. VLAs) need an explicit cast
                                TL::Type ptr_type = (*it)->get_in_outline_type().references_to().get_pointer_to();
                                TL::Type cast_type = rewrite_type_of_vla_in_outline(ptr_type, data_items, structure_symbol);

                                argument << "*((" << as_type(cast_type) << ")args." << (*it)->get_field_name() << ")";
                            }
                        }
                        // Any other parameter is bound to the storage of the struct
                        else
                        {
                            if (!param_type.no_ref().depends_on_nonconstant_values())
                            {
                                argument << "args." << (*it)->get_field_name();
                            }
                            else
                            {
                                TL::Type cast_type = rewrite_type_of_vla_in_outline(param_type, data_items, structure_symbol);
                                argument << "(" << as_type(cast_type) << ")args." << (*it)->get_field_name();
                            }
                        }

                        if (IS_CXX_LANGUAGE
                                && (*it)->get_allocation_policy() == OutlineDataItem::ALLOCATION_POLICY_TASK_MUST_DESTROY)
                        {
                            internal_error("Not yet implemented: call the destructor", 0);
                        }
                    }
                    else if (IS_FORTRAN_LANGUAGE)
                    {
                        argument << "args % " << (*it)->get_field_name();

                        if ((*it)->get_allocation_policy()
                                & OutlineDataItem::ALLOCATION_POLICY_TASK_MUST_DEALLOCATE_ALLOCATABLE)
                        {
                            cleanup_code
                                << "IF (ALLOCATED(args % " << (*it)->get_field_name() << ")) THEN\n"
                                <<      "DEALLOCATE(args % " << (*it)->get_field_name() << ")\n"
                                << "ENDIF\n"
                                ;
                        }
                    }
                    else
                    {
                        internal_error("running error", 0);
                    }

                    unpacked_arguments.append_with_separator(argument, ", ");
                    break;
                }
            case OutlineDataItem::SHARING_REDUCTION:
                {
                    // // Pass the original reduced variable as if it were a shared
                    Source argument;
                    if (IS_C_LANGUAGE || IS_CXX_LANGUAGE)
                    {
                        argument << "*(args." << (*it)->get_field_name() << ")";
                    }
                    else if (IS_FORTRAN_LANGUAGE)
                    {
                        argument << "args % " << (*it)->get_field_name();
                    }
                    unpacked_arguments.append_with_separator(argument, ", ");
                    break;
                }
            default:
                {
                    internal_error("Unexpected data sharing kind", 0);
                }
        }
    }

    // instrument_before/instrument_after are filled in further down, after
    // being streamed into outline_src
    Source outline_src,
           instrument_before,
           instrument_after;

    if (IS_C_LANGUAGE || IS_CXX_LANGUAGE)
    {
        // Non-static member functions are called through the captured 'this'
        Source unpacked_function_call;
        if (IS_CXX_LANGUAGE
                && !is_function_task
                && current_function.is_member()
                && !current_function.is_static())
        {
            unpacked_function_call << "args.this_->";
        }

        unpacked_function_call << unpacked_function.get_qualified_name_for_expression(
                /* in_dependent_context */
                (current_function.get_type().is_template_specialized_type()
                 && current_function.get_type().is_dependent())
                ) << "(" << unpacked_arguments << ");";

        outline_src
            << "{"
            <<      instrument_before
            <<      unpacked_function_call
            <<      instrument_after
            <<      cleanup_code
            << "}"
            ;

        if (IS_CXX_LANGUAGE)
        {
            if (!outline_function.is_member())
            {
                Nodecl::NodeclBase nodecl_decl = Nodecl::CxxDecl::make(
                        /* optative context */ nodecl_null(),
                        outline_function,
                        original_statements.get_locus());
                Nodecl::Utils::prepend_to_enclosing_top_level_location(original_statements, nodecl_decl);
            }
        }
    }
    else if (IS_FORTRAN_LANGUAGE)
    {
        // In Fortran the outline calls <outline>_forward, passing LOC() of the
        // unpacked function followed by the unpacked arguments
        Source outline_function_addr;

        outline_src
            << instrument_before << "\n"
            << "CALL " << outline_name << "_forward(" << outline_function_addr << unpacked_arguments << ")\n"
            << instrument_after << "\n"
            << cleanup_code
            ;

        // Filled after being streamed into outline_src above
        outline_function_addr << "LOC(" << unpacked_function.get_name() << ")";
        if (!unpacked_arguments.empty())
        {
            outline_function_addr << ", ";
        }

        // Copy USEd information to the outline and forward functions
        TL::Symbol *functions[] = { &outline_function, &forward_function, NULL };

        for (int i = 0; functions[i] != NULL; i++)
        {
            TL::Symbol &function(*functions[i]);

            Nodecl::Utils::Fortran::append_used_modules(original_statements.retrieve_context(),
                    function.get_related_scope());

            add_used_types(data_items, function.get_related_scope());
        }

        // Generate ancillary code in C
        add_forward_function_code_to_extra_c_code(outline_name, data_items, outline_function_body);
    }
    else
    {
        internal_error("Code unreachable", 0);
    }

    // Fills instrument_before/instrument_after, already streamed into outline_src
    if (instrumentation_enabled())
    {
        get_instrumentation_code(
                info._called_task,
                outline_function,
                outline_function_body,
                info._task_label,
                original_statements.get_locus(),
                instrument_before,
                instrument_after);
    }

    Nodecl::NodeclBase new_outline_body = outline_src.parse_statement(outline_function_body);
    outline_function_body.replace(new_outline_body);

    // Nodecl::Utils::prepend_to_enclosing_top_level_location(original_statements, outline_function_code);
    //
     //Dummy function call placeholder
     // NOTE(review): parsed in the context of the unpacked function body and
     // assigned to the out parameter, overwriting whatever the placeholder held
     Source unpacked_ndr_code;
     unpacked_ndr_code << statement_placeholder(outline_placeholder);
     Nodecl::NodeclBase new_unpacked_ndr_code = unpacked_ndr_code.parse_statement(unpacked_function_body);
     outline_placeholder=new_unpacked_ndr_code;
}
// Example #5
// 0
    // Handles the 'collapse(N)' clause of a loop pragma after the construct
    // has been parsed.
    //
    // - No collapse clause: nothing to do.
    // - Wrong arity, a non-integer-constant argument, or N <= 0: reports an
    //   error and leaves the code untouched.
    // - N == 1: the clause is a no-op and is simply removed.
    // - N > 1: the loop nest is collapsed with HLT::LoopCollapse, then:
    //     * the helper symbols it needs are added to a firstprivate clause;
    //     * the collapse clause is removed;
    //     * the pragma is re-attached to the resulting for statement;
    //     * the construct is replaced by the transformed code.
    void loop_hlt_handler_post(TL::PragmaCustomStatement construct)
    {
        TL::PragmaCustomLine pragma_line = construct.get_pragma_line();
        TL::PragmaCustomClause collapse = pragma_line.get_clause("collapse");
        if (!collapse.is_defined())
            return;

        TL::ObjectList<Nodecl::NodeclBase> expr_list = collapse.get_arguments_as_expressions(construct);
        if (expr_list.size() != 1)
        {
            error_printf_at(construct.get_locus(), "'collapse' clause needs exactly one argument\n");
            return;
        }

        Nodecl::NodeclBase expr = expr_list[0];
        if (!expr.is_constant()
                || !is_any_int_type(expr.get_type().get_internal_type()))
        {
            error_printf_at(construct.get_locus(),
                    "'collapse' clause requires an integer constant expression\n");
            return;
        }

        int collapse_factor = const_value_cast_to_signed_int(expr.get_constant());

        if (collapse_factor <= 0)
        {
            error_printf_at(
                    construct.get_locus(),
                    "Non-positive factor (%d) is not allowed in the 'collapse' clause\n",
                    collapse_factor);
            return;
        }

        if (collapse_factor == 1)
        {
            // Removing the collapse clause from the pragma
            pragma_line.remove_clause("collapse");
            return;
        }

        // collapse_factor > 1
        Nodecl::NodeclBase loop = get_statement_from_pragma(construct);

        HLT::LoopCollapse loop_collapse;
        loop_collapse.set_loop(loop);
        loop_collapse.set_pragma_context(construct.retrieve_context());
        loop_collapse.set_collapse_factor(collapse_factor);

        loop_collapse.collapse();

        Nodecl::NodeclBase transformed_code = loop_collapse.get_whole_transformation();
        TL::ObjectList<TL::Symbol> capture_symbols = loop_collapse.get_omp_capture_symbols();

        // We may need to add some symbols that are used to implement the collapse clause to the pragma.
        // Only emit the clause when there is something to capture: an empty
        // 'firstprivate()' list is not valid OpenMP.
        if (!capture_symbols.empty())
        {
            std::string names;
            for (TL::ObjectList<TL::Symbol>::iterator it = capture_symbols.begin();
                    it != capture_symbols.end();
                    ++it)
            {
                if (it != capture_symbols.begin())
                    names += ",";
                names += it->get_name();
            }
            Nodecl::List clauses = pragma_line.get_clauses().as<Nodecl::List>();
            clauses.append(Nodecl::PragmaCustomClause::make(Nodecl::List::make(Nodecl::PragmaClauseArg::make(names)), "firstprivate"));
        }

        // Removing the collapse clause from the pragma
        pragma_line.remove_clause("collapse");

        // Create a new pragma over the new for stmt
        ERROR_CONDITION(!transformed_code.is<Nodecl::Context>(), "Unexpected node\n", 0);
        Nodecl::NodeclBase compound_statement =
            transformed_code.as<Nodecl::Context>().get_in_context().as<Nodecl::List>().front();

        ERROR_CONDITION(!compound_statement.is<Nodecl::CompoundStatement>(), "Unexpected node\n", 0);
        Nodecl::Context context_for_stmt =
            compound_statement.as<Nodecl::CompoundStatement>().get_statements()
            .as<Nodecl::List>().find_first<Nodecl::Context>();
        ERROR_CONDITION(context_for_stmt.is_null(), "Unexpected node\n", 0);

        // Detach the collapsed loop so it can be wrapped by the new pragma
        Nodecl::Utils::remove_from_enclosing_list(context_for_stmt);

        Nodecl::List stmt_list =
            compound_statement.as<Nodecl::CompoundStatement>().get_statements().as<Nodecl::List>();
        ERROR_CONDITION(stmt_list.is_null(), "Unreachable code\n", 0);

        Nodecl::PragmaCustomStatement new_pragma =
            Nodecl::PragmaCustomStatement::make(pragma_line,
                    Nodecl::List::make(context_for_stmt),
                    construct.get_text(),
                    construct.get_locus());

        stmt_list.append(new_pragma);

        construct.replace(transformed_code);
    }