static void pyobject_as_irange_array(intptr_t& out_size, shortvector<irange>& out_indices, PyObject *subscript) { if (!PyTuple_Check(subscript)) { // A single subscript out_size = 1; out_indices.init(1); out_indices[0] = pyobject_as_irange(subscript); } else { out_size = PyTuple_GET_SIZE(subscript); // Tuple of subscripts out_indices.init(out_size); for (Py_ssize_t i = 0; i < out_size; ++i) { out_indices[i] = pyobject_as_irange(PyTuple_GET_ITEM(subscript, i)); } } }
/**
 * Computes the shape obtained by broadcasting all the input arrays
 * together, along with an axis permutation for the result.
 *
 * Inputs with fewer dimensions are aligned to the trailing (rightmost)
 * dimensions of the result. For each output dimension the sizes are merged
 * as follows:
 *   - a size of 1 yields to any other size;
 *   - a negative size marks a variable-sized dimension. Merging a
 *     variable-sized dimension with a fixed size N > 0 is recorded as -N,
 *     and -1 is used when no fixed size is recorded;
 *   - two fixed sizes which differ (and are not 1) cannot be broadcast.
 *
 * \param ninputs        The number of arrays in 'inputs'. May be zero, in
 *                       which case the result has zero dimensions.
 * \param inputs         The arrays whose shapes are broadcast together.
 * \param out_undim      Receives the number of broadcast dimensions, which
 *                       is the largest get_ndim() among the inputs.
 * \param out_shape      Re-initialized to 'out_undim' elements and filled
 *                       with the broadcast shape.
 * \param out_axis_perm  Re-initialized to 'out_undim' elements and filled
 *                       with the axis permutation (currently always C order,
 *                       i.e. undim-1, undim-2, ..., 0).
 *
 * Throws broadcast_error(ninputs, inputs) if the shapes are incompatible.
 */
void dynd::broadcast_input_shapes(intptr_t ninputs, const nd::array* inputs,
                intptr_t& out_undim, dimvector& out_shape, shortvector<int>& out_axis_perm)
{
    // Get the number of broadcast dimensions. The maximum is seeded with
    // zero instead of inputs[0].get_ndim(), so that inputs[0] is never
    // touched when ninputs == 0. The loop visits every input, so the
    // result is the same whenever there is at least one input.
    intptr_t undim = 0;
    for (intptr_t i = 0; i < ninputs; ++i) {
        intptr_t candidate_undim = inputs[i].get_ndim();
        if (candidate_undim > undim) {
            undim = candidate_undim;
        }
    }

    out_undim = undim;
    out_shape.init(undim);
    out_axis_perm.init(undim);
    intptr_t *shape = out_shape.get();

    // Fill in the broadcast shape, starting from all ones since
    // a size of 1 broadcasts against anything
    for (intptr_t k = 0; k < undim; ++k) {
        shape[k] = 1;
    }
    dimvector tmpshape(undim);
    for (intptr_t i = 0; i < ninputs; ++i) {
        intptr_t input_undim = inputs[i].get_ndim();
        inputs[i].get_shape(tmpshape.get());
        // Inputs with fewer dimensions line up with the trailing dimensions
        intptr_t dimdelta = undim - input_undim;
        for (intptr_t k = dimdelta; k < undim; ++k) {
            intptr_t size = tmpshape[k - dimdelta];
            intptr_t itershape_size = shape[k];
            if (itershape_size == 1) {
                // Nothing but ones seen so far, take the input's size as is
                shape[k] = size;
            } else if (size < 0) {
                // A negative shape value means variable-sized
                if (itershape_size > 0) {
                    // Remember the fixed size seen so far in negated form
                    shape[k] = -itershape_size;
                } else {
                    shape[k] = -1;
                }
            } else if (itershape_size >= 0) {
                // Both sizes are fixed: they must match unless the input's is 1
                if (size != 1 && itershape_size != size) {
                    throw broadcast_error(ninputs, inputs);
                }
            } else { // itershape_size < 0
                if (itershape_size == -1 && size > 0) {
                    // Variable-sized with no fixed size recorded yet
                    shape[k] = -size;
                } else if (size > 1 && itershape_size != -size) {
                    // Conflicts with the fixed size recorded earlier
                    throw broadcast_error(ninputs, inputs);
                }
            }
        }
    }

    // Fill in the axis permutation. A single loop covers every case: for
    // undim == 1 it writes axis_perm[0] = 0, and for undim == 0 it does nothing.
    // TODO: keeporder behavior, currently always C order
    int *axis_perm = out_axis_perm.get();
    for (intptr_t i = 0; i < undim; ++i) {
        axis_perm[i] = int(undim - i - 1);
    }
}