Exemplo n.º 1
0
    // Creates a new NDArrayView aliasing this one. The alias gets its own TensorView object,
    // copy-constructed from this view's TensorView of the matching element type. It has the
    // same data type, device, storage format and shape as this view.
    //
    // readOnly: when true, the alias is read-only even if this view is writable; a read-only
    //           source view always yields a read-only alias.
    // Returns:  a shared pointer owning the new alias, released with `delete`.
    // Errors:   calls LogicError (presumably throws) for data types other than Float/Double.
    NDArrayViewPtr NDArrayView::Alias(bool readOnly/* = false*/) const
    {
        void* tensorView = nullptr;
        // tensorView is untyped, so remember how to free it in case building the alias fails.
        void (*deleteTensorView)(void*) = nullptr;
        switch (m_dataType)
        {
        case DataType::Float:
            tensorView = new TensorView<float>(*(GetTensorView<float>()));
            deleteTensorView = [](void* p) { delete static_cast<TensorView<float>*>(p); };
            break;
        case DataType::Double:
            tensorView = new TensorView<double>(*(GetTensorView<double>()));
            deleteTensorView = [](void* p) { delete static_cast<TensorView<double>*>(p); };
            break;
        default:
            LogicError("Unsupported DataType %s", DataTypeName(m_dataType));
            break;
        }

        NDArrayView* aliasView = nullptr;
        try
        {
            aliasView = new NDArrayView(GetDataType(), Device(), GetStorageFormat(), Shape(), IsReadOnly() || readOnly, tensorView);
        }
        catch (...)
        {
            // The alias was never created, so nothing else will free tensorView; release it
            // before propagating. NOTE(review): assumes the NDArrayView constructor does not
            // free tensorView itself when it throws — confirm.
            if (deleteTensorView != nullptr)
                deleteTensorView(tensorView);
            throw;
        }

        return NDArrayViewPtr(aliasView, [](_ReferenceCounter* ptr) { delete ptr; });
    }
Exemplo n.º 2
0
// Generates code that visits every row in the sorter one row at a time. It
// wraps the caller's row-wise IterateCallback in an adapter and drives it via
// VectorizedIterate, which hands out ranges of row indexes.
void Sorter::Iterate(CodeGen &codegen, llvm::Value *sorter_ptr,
                     Sorter::IterateCallback &callback) const {
  // Adapter: receives index ranges [start_index, end_index) from the vectorized
  // iteration and forwards each row's column values to `callback`.
  struct TaatIterateCallback : VectorizedIterateCallback {
    const UpdateableStorage &storage;
    Sorter::IterateCallback &callback;

    TaatIterateCallback(const UpdateableStorage &s, Sorter::IterateCallback &c)
        : storage(s), callback(c) {}

    // Builds a loop over row indexes in [start_index, end_index). The loop is
    // entered only if start_index < end_index (unsigned compare). For each
    // row it loads every column and passes the values to the wrapped callback.
    void ProcessEntries(CodeGen &codegen, llvm::Value *start_index,
                        llvm::Value *end_index,
                        SorterAccess &access) const override {
      // The column count comes from the storage format, not the row; read it once.
      const auto num_columns = storage.GetNumElements();

      lang::Loop loop(codegen,
                      codegen->CreateICmpULT(start_index, end_index),
                      {{"start", start_index}});
      {
        llvm::Value *curr_index = loop.GetLoopVar(0);

        // Parse the row
        std::vector<codegen::Value> vals;
        vals.reserve(num_columns);
        auto &row = access.GetRow(curr_index);
        for (uint32_t i = 0; i < num_columns; i++) {
          vals.emplace_back(row.LoadColumn(codegen, i));
        }

        // Call the actual callback
        callback.ProcessEntry(codegen, vals);

        // Advance to the next row and continue while still below end_index.
        curr_index = codegen->CreateAdd(curr_index, codegen.Const32(1));
        loop.LoopEnd(codegen->CreateICmpULT(curr_index, end_index),
                     {curr_index});
      }
    }
  };

  // Do a vectorized iteration using our callback adapter. NOTE(review): the
  // meaning of the literal 0 argument is defined by VectorizedIterate — confirm.
  TaatIterateCallback taat_cb(GetStorageFormat(), callback);
  VectorizedIterate(codegen, sorter_ptr, Vector::kDefaultVectorSize, 0,
                    taat_cb);
}