// Makes sure `a` is an immutable 1D array of ndt::type values, converting it
// in place (by swapping in a copy) when it is not already in that form.
//
// a: the array to check. It is left untouched when it already passes the
//    fast-path test below, or when it is found to be incompatible. Otherwise
//    it is swapped with a newly created copy that has been flagged immutable.
//
// Returns true when `a` was already acceptable or has been replaced by a
// converted copy, and false when `a` is not a 1D array whose element value
// type is the "type" type.
inline static bool run(nd::array &a)
{
    const ndt::type &tp = a.get_type();

    if (a.is_immutable() && tp.get_type_id() == fixed_dim_type_id) {
        // It's immutable and "N * <something>"
        const ndt::type &et = tp.extended<fixed_dim_type>()->get_element_type();
        const fixed_dim_type_arrmeta *md =
            reinterpret_cast<const fixed_dim_type_arrmeta *>(a.get_arrmeta());
        if (et.get_type_id() == type_type_id &&
            md->stride == sizeof(ndt::type)) {
            // It also has the right type and is contiguous,
            // so no modification necessary.
            return true;
        }
    }

    // We have to make a copy, check that it's a 1D array, and that
    // it has the same array kind as the requested type.
    if (tp.get_ndim() == 1) {
        // It's a 1D array.
        //
        // Hold the element type BY VALUE. The expression below calls
        // value_type() on the temporary produced by get_type_at_dimension();
        // if value_type() hands back a reference into that temporary, a
        // `const ndt::type &` bound to it would dangle as soon as this
        // statement ends. Copying the type handle avoids that regardless of
        // how the two calls are declared.
        const ndt::type et = tp.get_type_at_dimension(NULL, 1).value_type();
        if (et.get_type_id() == type_type_id) {
            // It also has the same array type as requested
            nd::array tmp = nd::empty(a.get_dim_size(), ndt::make_type());
            tmp.vals() = a;
            tmp.flag_as_immutable();
            // `tp` refers into the old array, which now lives in `tmp`;
            // it is not used again after the swap.
            a.swap(tmp);
            return true;
        }
    }

    // It's not compatible, so return false
    return false;
}
void dynd::lift_reduction_arrfunc(arrfunc_type_data *out_ar, const nd::arrfunc& elwise_reduction_arr, const ndt::type& lifted_arr_type, const nd::arrfunc& dst_initialization_arr, bool keepdims, intptr_t reduction_ndim, const bool *reduction_dimflags, bool associative, bool commutative, bool right_associative, const nd::array& reduction_identity) { // Validate the input elwise_reduction arrfunc if (elwise_reduction_arr.is_null()) { throw runtime_error("lift_reduction_arrfunc: 'elwise_reduction' may not be empty"); } const arrfunc_type_data *elwise_reduction = elwise_reduction_arr.get(); if (elwise_reduction->get_param_count() != 1 && !(elwise_reduction->get_param_count() == 2 && elwise_reduction->get_param_type(0) == elwise_reduction->get_param_type(1) && elwise_reduction->get_param_type(0) == elwise_reduction->get_return_type())) { stringstream ss; ss << "lift_reduction_arrfunc: 'elwise_reduction' must contain a" " unary operation ckernel or a binary expr ckernel with all " "equal types, its prototype is " << elwise_reduction->func_proto; throw invalid_argument(ss.str()); } lifted_reduction_arrfunc_data *self = new lifted_reduction_arrfunc_data; *out_ar->get_data_as<lifted_reduction_arrfunc_data *>() = self; out_ar->free_func = &delete_lifted_reduction_arrfunc_data; self->child_elwise_reduction = elwise_reduction_arr; self->child_dst_initialization = dst_initialization_arr; if (!reduction_identity.is_null()) { if (reduction_identity.is_immutable() && reduction_identity.get_type() == elwise_reduction->get_return_type()) { self->reduction_identity = reduction_identity; } else { self->reduction_identity = nd::empty(elwise_reduction->get_return_type()); self->reduction_identity.vals() = reduction_identity; self->reduction_identity.flag_as_immutable(); } } // Figure out the result type ndt::type lifted_dst_type = elwise_reduction->get_return_type(); for (intptr_t i = reduction_ndim - 1; i >= 0; --i) { if (reduction_dimflags[i]) { if (keepdims) { lifted_dst_type = 
ndt::make_strided_dim(lifted_dst_type); } } else { ndt::type subtype = lifted_arr_type.get_type_at_dimension(NULL, i); switch (subtype.get_type_id()) { case strided_dim_type_id: case cfixed_dim_type_id: lifted_dst_type = ndt::make_strided_dim(lifted_dst_type); break; case var_dim_type_id: lifted_dst_type = ndt::make_var_dim(lifted_dst_type); break; default: { stringstream ss; ss << "lift_reduction_arrfunc: don't know how to process "; ss << "dimension of type " << subtype; throw type_error(ss.str()); } } } } self->data_types[0] = lifted_dst_type; self->data_types[1] = lifted_arr_type; self->reduction_ndim = reduction_ndim; self->associative = associative; self->commutative = commutative; self->right_associative = right_associative; self->reduction_dimflags.init(reduction_ndim); memcpy(self->reduction_dimflags.get(), reduction_dimflags, sizeof(bool) * reduction_ndim); out_ar->instantiate = &instantiate_lifted_reduction_arrfunc_data; out_ar->func_proto = ndt::make_funcproto(lifted_arr_type, lifted_dst_type); }