예제 #1
0
파일: sort.cpp 프로젝트: 9prady9/arrayfire
void sortBatched(Array<T>& val, bool isAscending) {
    // Sorts `val` in place, independently within each slice taken along
    // dimension `dim`. `dim` is not a parameter of this function; presumably
    // it is a template parameter declared just above this definition (confirm).
    //
    // Strategy: tag every element with a batch id, sort the whole array as
    // one flat vector by value, then sort again by batch id so that each
    // slice's elements are grouped together while keeping their value order.
    af::dim4 origDims = val.dims();

    // `tile` is 1 everywhere except along the sort dimension, where it keeps
    // the input extent. `seq` is the input shape with that dimension collapsed
    // to 1.
    af::dim4 tile(1);
    af::dim4 seq = origDims;
    tile[dim] = origDims[dim];
    seq[dim]  = 1;

    // Presumably yields one id per element, shared by every element of the
    // same slice along `dim` (an iota over `seq`, repeated by `tile`).
    // Confirm against the iota implementation.
    Array<uint> batchIdx = iota<uint>(seq, tile);

    // Receivers for the first (by-value) pass.
    Array<uint> sortedIdx = createEmptyArray<uint>(dim4());
    Array<T>    sortedVal = createEmptyArray<T>(dim4());

    // View both arrays as flat 1-D vectors so a single sort spans every batch.
    auto nElems = origDims.elements();
    val.setDataDims(nElems);
    batchIdx.setDataDims(nElems);

    // Pass 1: flat sort by value in the requested direction, carrying the
    // batch ids along. From usage, the first two arguments appear to be the
    // outputs and the next two the inputs.
    sort_by_key<T, uint>(sortedVal, sortedIdx, val, batchIdx, 0, isAscending);

    // Pass 2: sort by batch id, carrying the values back into `val`. This pass
    // must always be ascending so the batch ids (indices) stay in order.
    // NOTE(review): presumably relies on sort_by_key being stable, so the
    // pass-1 value order survives within each batch. Confirm.
    sort_by_key<uint, T>(batchIdx, val, sortedIdx, sortedVal, 0, true);
    val.eval();

    // Restore the caller-visible shape. This is correct only for dim == 0;
    // for other dims each slice is now contiguous in memory, which differs
    // from the original layout.
    val.setDataDims(origDims);
}
예제 #2
0
void sortByKeyBatched(Array<Tk> okey, Array<Tv> oval, const int dim, bool isAscending)
{
    // Sorts the keys in `okey`, moving the matching values in `oval` along with
    // them, independently within each slice taken along dimension `dim`.
    // Results are written back through okey.get() / oval.get().
    //
    // Assumes both arrays are dense (column-major, unit inner stride), since
    // elements are addressed by flat linear index.
    //
    // Strategy: pack (key, value, batch id) for every element, stable-sort by
    // key, then stable-sort by batch id. Each slice ends up as a contiguous
    // run, with its keys in the requested order.
    af::dim4 inDims = okey.dims();

    // seqDims is the input shape with the sort dimension collapsed to 1. Each
    // element's batch id is its linear index in that collapsed shape, so all
    // elements of one slice along `dim` share the same id.
    af::dim4 seqDims = inDims;
    seqDims[dim] = 1;

    // Dense column-major strides of the input.
    af::dim4 strides(1);
    for(int i = 1; i < 4; i++)
        strides[i] = strides[i-1] * inDims[i-1];

    Tk *okey_ptr = okey.get();
    Tv *oval_ptr = oval.get();

    using CurrentTuple = KeyIndexPair<Tk, Tv>;
    size_t size  = okey.elements();
    size_t bytes = size * sizeof(CurrentTuple);
    // NOTE(review): the tuples are assigned into raw memAlloc'd storage without
    // being constructed first, and this buffer leaks if anything below throws.
    // Consider placement-new or an owning container.
    CurrentTuple *tupleKeyValIdx = (CurrentTuple *)memAlloc<char>(bytes);

    // Pack (key, value, batch id) for every element. The batch id is computed
    // inline rather than staged in a separate scratch buffer. Its arithmetic
    // (including truncation to uint) matches the previous standalone iota.
    for(dim_t w = 0; w < inDims[3]; w++) {
        dim_t offW = w * strides[3];
        uint  idW  = (w % seqDims[3]) * seqDims[0] * seqDims[1] * seqDims[2];
        for(dim_t z = 0; z < inDims[2]; z++) {
            dim_t offWZ = offW + z * strides[2];
            uint  idWZ  = idW + (z % seqDims[2]) * seqDims[0] * seqDims[1];
            for(dim_t y = 0; y < inDims[1]; y++) {
                dim_t offWZY = offWZ + y * strides[1];
                uint  idWZY  = idWZ + (y % seqDims[1]) * seqDims[0];
                for(dim_t x = 0; x < inDims[0]; x++) {
                    dim_t id      = offWZY + x;
                    uint  batchId = idWZY + (x % seqDims[0]);
                    tupleKeyValIdx[id] = std::make_tuple(okey_ptr[id], oval_ptr[id], batchId);
                }
            }
        }
    }

    // Pass 1: order all elements by key in the requested direction.
    // KIPCompareV presumably compares element 0 (the key). Confirm.
    if(isAscending) {
        std::stable_sort(tupleKeyValIdx, tupleKeyValIdx + size, KIPCompareV<Tk, Tv, true>());
    } else {
        std::stable_sort(tupleKeyValIdx, tupleKeyValIdx + size, KIPCompareV<Tk, Tv, false>());
    }

    // Pass 2: group by batch id. KIPCompareK presumably compares element 2
    // (the batch id). Because std::stable_sort is stable, the pass-1 key order
    // is preserved inside each batch.
    std::stable_sort(tupleKeyValIdx, tupleKeyValIdx + size, KIPCompareK<Tk, Tv, true>());

    // Write back in sorted order. size_t counters avoid the wrap-around that
    // a 32-bit `unsigned` index would hit on very large arrays.
    // NOTE(review): each slice is now a contiguous run. That matches the
    // original layout only when dim == 0; confirm callers reorder otherwise.
    for(size_t i = 0; i < size; i++) {
        okey_ptr[i] = std::get<0>(tupleKeyValIdx[i]);
        oval_ptr[i] = std::get<1>(tupleKeyValIdx[i]);
    }

    memFree((char *)tupleKeyValIdx);
}