// Sorts `val` in place along dimension `dim` for every batch at once.
//
// `dim` is not declared in this block; presumably it is a template parameter
// of the enclosing declaration -- confirm against the full file.
//
// Approach: tag every element with an id built by iota over the non-sorted
// extents, flatten both arrays to 1-D, sort by value, then re-sort by the id
// so elements are regrouped per batch.
//
// val          array to sort; overwritten with the result
// isAscending  direction of the value sort
void sortBatched(Array<T>& val, bool isAscending)
{
    af::dim4 origDims = val.dims();

    // Extents with the sort dimension collapsed to 1 ...
    af::dim4 batchDims = origDims;
    batchDims[dim]     = 1;

    // ... and the repeat counts that tile it back along the sort dimension.
    af::dim4 repeatDims(1);
    repeatDims[dim] = origDims[dim];

    Array<uint> batchIds = iota<uint>(batchDims, repeatDims);

    // Intermediate outputs of the first pass.
    Array<T> sortedVals   = createEmptyArray<T>(dim4());
    Array<uint> sortedIds = createEmptyArray<uint>(dim4());

    // View both arrays as flat 1-D data so a single sort covers everything.
    const auto flatCount = origDims.elements();
    val.setDataDims(flatCount);
    batchIds.setDataDims(flatCount);

    // Pass 1: order by value in the requested direction, carrying the ids.
    sort_by_key<T, uint>(sortedVals, sortedIds, val, batchIds, 0, isAscending);

    // Pass 2: order by id. This has to be ascending (true) regardless of
    // isAscending so that the indices are maintained properly.
    sort_by_key<uint, T>(batchIds, val, sortedIds, sortedVals, 0, true);
    val.eval();

    // Restore the original shape. This is correct only for dim0.
    val.setDataDims(origDims);
}
void sortByKeyBatched(Array<Tk> okey, Array<Tv> oval, const int dim, bool isAscending) { af::dim4 inDims = okey.dims(); af::dim4 tileDims(1); af::dim4 seqDims = inDims; tileDims[dim] = inDims[dim]; seqDims[dim] = 1; uint* key = memAlloc<uint>(inDims.elements()); // IOTA { af::dim4 dims = inDims; uint* out = key; af::dim4 strides(1); for(int i = 1; i < 4; i++) strides[i] = strides[i-1] * dims[i-1]; for(dim_t w = 0; w < dims[3]; w++) { dim_t offW = w * strides[3]; uint okeyW = (w % seqDims[3]) * seqDims[0] * seqDims[1] * seqDims[2]; for(dim_t z = 0; z < dims[2]; z++) { dim_t offWZ = offW + z * strides[2]; uint okeyZ = okeyW + (z % seqDims[2]) * seqDims[0] * seqDims[1]; for(dim_t y = 0; y < dims[1]; y++) { dim_t offWZY = offWZ + y * strides[1]; uint okeyY = okeyZ + (y % seqDims[1]) * seqDims[0]; for(dim_t x = 0; x < dims[0]; x++) { dim_t id = offWZY + x; out[id] = okeyY + (x % seqDims[0]); } } } } } // initialize original index locations Tk *okey_ptr = okey.get(); Tv *oval_ptr = oval.get(); typedef KeyIndexPair<Tk, Tv> CurrentTuple; size_t size = okey.elements(); size_t bytes = okey.elements() * sizeof(CurrentTuple); CurrentTuple *tupleKeyValIdx = (CurrentTuple *)memAlloc<char>(bytes); for(unsigned i = 0; i < size; i++) { tupleKeyValIdx[i] = std::make_tuple(okey_ptr[i], oval_ptr[i], key[i]); } memFree(key); // key is no longer required if(isAscending) { std::stable_sort(tupleKeyValIdx, tupleKeyValIdx + size, KIPCompareV<Tk, Tv, true>()); } else { std::stable_sort(tupleKeyValIdx, tupleKeyValIdx + size, KIPCompareV<Tk, Tv, false>()); } std::stable_sort(tupleKeyValIdx, tupleKeyValIdx + size, KIPCompareK<Tk, Tv, true>()); for(unsigned x = 0; x < okey.elements(); x++) { okey_ptr[x] = std::get<0>(tupleKeyValIdx[x]); oval_ptr[x] = std::get<1>(tupleKeyValIdx[x]); } memFree((char *)tupleKeyValIdx); return; }