예제 #1
0
 /**
  * Ensures ``a`` is an immutable, contiguous one-dimensional array whose
  * elements are ``ndt::type`` values (element type id ``type_type_id``).
  *
  * - If ``a`` is already immutable, has a ``fixed_dim`` type whose element
  *   type is ``type``, and its stride equals ``sizeof(ndt::type)``, it is
  *   left untouched.
  * - Otherwise, if ``a`` is one-dimensional and its element value type is
  *   ``type``, its values are copied into a newly allocated array, which is
  *   flagged immutable and swapped into ``a``.
  * - Otherwise ``a`` is not modified.
  *
  * \param a  The array to check; may be replaced in place by a copy.
  * \returns  true if ``a`` now satisfies the requirements, false if it is
  *           not compatible.
  */
 inline static bool run(nd::array &a)
 {
   const ndt::type &tp = a.get_type();
   if (a.is_immutable() && tp.get_type_id() == fixed_dim_type_id) {
     // It's immutable and "N * <something>"
     const ndt::type &et = tp.extended<fixed_dim_type>()->get_element_type();
     const fixed_dim_type_arrmeta *md =
         reinterpret_cast<const fixed_dim_type_arrmeta *>(a.get_arrmeta());
     if (et.get_type_id() == type_type_id &&
         md->stride == sizeof(ndt::type)) {
       // It also has the right element type and is densely packed,
       // so no modification is necessary.
       return true;
     }
   }
   // We have to make a copy: check that it's a 1D array whose element
   // value type is the requested one.
   if (tp.get_ndim() == 1) {
     // Hold the element type by value. get_type_at_dimension() yields a
     // temporary, and value_type() returns a const reference that may refer
     // to that temporary itself; binding it to a reference would dangle once
     // this statement ends.
     const ndt::type et = tp.get_type_at_dimension(NULL, 1).value_type();
     if (et.get_type_id() == type_type_id) {
       // Copy (converting as needed) into a fresh contiguous array of types,
       // freeze it, and hand it back through ``a``.
       nd::array tmp = nd::empty(a.get_dim_size(), ndt::make_type());
       tmp.vals() = a;
       tmp.flag_as_immutable();
       a.swap(tmp);
       return true;
     }
   }
   // It's not compatible, so return false
   return false;
 }
/**
 * Lifts an element-wise reduction arrfunc into an arrfunc that reduces over
 * selected dimensions of an input array of type ``lifted_arr_type``.
 *
 * All validation and conversion happen before ``*out_ar`` is touched. If
 * this function throws while validating or converting, ``out_ar`` is left
 * unmodified.
 *
 * \param out_ar                  Receives the lifted arrfunc: its data pointer,
 *                                free_func, instantiate and func_proto are set.
 * \param elwise_reduction_arr    Must be non-null. Either a unary arrfunc, or a
 *                                binary arrfunc whose two parameter types and
 *                                return type are all equal.
 * \param lifted_arr_type         Type of the array the lifted arrfunc accepts
 *                                (becomes the single parameter type).
 * \param dst_initialization_arr  Stored in the lifted arrfunc's data as the
 *                                destination-initialization child.
 * \param keepdims                If true, each reduced dimension is kept in the
 *                                result type as a strided dimension; otherwise
 *                                it is dropped.
 * \param reduction_ndim          Number of leading dimensions of
 *                                ``lifted_arr_type`` described by
 *                                ``reduction_dimflags``.
 * \param reduction_dimflags      ``reduction_ndim`` flags; true marks the
 *                                corresponding dimension as reduced.
 * \param associative             Stored in the lifted arrfunc's data.
 * \param commutative             Stored in the lifted arrfunc's data.
 * \param right_associative       Stored in the lifted arrfunc's data.
 * \param reduction_identity      Optional (may be null). If non-null, it is
 *                                stored as an immutable array of the reduction's
 *                                return type, copied/converted when it is not
 *                                already immutable with exactly that type.
 *
 * \throws runtime_error     if ``elwise_reduction_arr`` is null.
 * \throws invalid_argument  if ``elwise_reduction_arr`` has an unsupported
 *                           prototype.
 * \throws type_error        if a non-reduced dimension is not a strided,
 *                           cfixed or var dimension.
 */
void dynd::lift_reduction_arrfunc(arrfunc_type_data *out_ar,
                const nd::arrfunc& elwise_reduction_arr,
                const ndt::type& lifted_arr_type,
                const nd::arrfunc& dst_initialization_arr,
                bool keepdims,
                intptr_t reduction_ndim,
                const bool *reduction_dimflags,
                bool associative,
                bool commutative,
                bool right_associative,
                const nd::array& reduction_identity)
{
    // Validate the input elwise_reduction arrfunc
    if (elwise_reduction_arr.is_null()) {
        throw runtime_error("lift_reduction_arrfunc: 'elwise_reduction' may not be empty");
    }
    const arrfunc_type_data *elwise_reduction = elwise_reduction_arr.get();
    if (elwise_reduction->get_param_count() != 1 &&
            !(elwise_reduction->get_param_count() == 2 &&
              elwise_reduction->get_param_type(0) ==
                  elwise_reduction->get_param_type(1) &&
              elwise_reduction->get_param_type(0) ==
                  elwise_reduction->get_return_type())) {
        stringstream ss;
        ss << "lift_reduction_arrfunc: 'elwise_reduction' must contain a"
              " unary operation ckernel or a binary expr ckernel with all "
              "equal types, its prototype is " << elwise_reduction->func_proto;
        throw invalid_argument(ss.str());
    }

    // Figure out the result type. This is done before touching *out_ar so
    // that an unsupported dimension type leaves out_ar unmodified.
    // Dimensions are processed innermost-first, wrapping the element type.
    ndt::type lifted_dst_type = elwise_reduction->get_return_type();
    for (intptr_t i = reduction_ndim - 1; i >= 0; --i) {
        if (reduction_dimflags[i]) {
            if (keepdims) {
                lifted_dst_type = ndt::make_strided_dim(lifted_dst_type);
            }
        } else {
            ndt::type subtype = lifted_arr_type.get_type_at_dimension(NULL, i);
            switch (subtype.get_type_id()) {
                case strided_dim_type_id:
                case cfixed_dim_type_id:
                    lifted_dst_type = ndt::make_strided_dim(lifted_dst_type);
                    break;
                case var_dim_type_id:
                    lifted_dst_type = ndt::make_var_dim(lifted_dst_type);
                    break;
                default: {
                    stringstream ss;
                    ss << "lift_reduction_arrfunc: don't know how to process ";
                    ss << "dimension of type " << subtype;
                    throw type_error(ss.str());
                }
            }
        }
    }

    // Prepare the stored identity up front as well: the value assignment may
    // need a conversion, which can fail. A null input stays a null array.
    nd::array identity;
    if (!reduction_identity.is_null()) {
        if (reduction_identity.is_immutable() &&
                reduction_identity.get_type() == elwise_reduction->get_return_type()) {
            identity = reduction_identity;
        } else {
            identity = nd::empty(elwise_reduction->get_return_type());
            identity.vals() = reduction_identity;
            identity.flag_as_immutable();
        }
    }

    ndt::type lifted_proto = ndt::make_funcproto(lifted_arr_type, lifted_dst_type);

    // Validation and conversions are done; populate out_ar. The data is
    // attached together with its free_func immediately after allocation so
    // that out_ar owns it from that point on.
    lifted_reduction_arrfunc_data *self = new lifted_reduction_arrfunc_data;
    *out_ar->get_data_as<lifted_reduction_arrfunc_data *>() = self;
    out_ar->free_func = &delete_lifted_reduction_arrfunc_data;
    self->child_elwise_reduction = elwise_reduction_arr;
    self->child_dst_initialization = dst_initialization_arr;
    self->reduction_identity = identity;
    self->data_types[0] = lifted_dst_type;
    self->data_types[1] = lifted_arr_type;
    self->reduction_ndim = reduction_ndim;
    self->associative = associative;
    self->commutative = commutative;
    self->right_associative = right_associative;
    self->reduction_dimflags.init(reduction_ndim);
    // memcpy requires valid pointers even for a zero-length copy, and
    // reduction_dimflags may be null when there are no dimensions.
    if (reduction_ndim > 0) {
        memcpy(self->reduction_dimflags.get(), reduction_dimflags, sizeof(bool) * reduction_ndim);
    }

    out_ar->instantiate = &instantiate_lifted_reduction_arrfunc_data;
    out_ar->func_proto = lifted_proto;
}