// Creates a heap-allocated copy of `value` on the CPU device and returns an owning raw pointer
// (the caller is responsible for deleting it).
// If the copy fails (NDArrayView::CopyFrom throws, e.g. for a data type it does not support),
// the freshly allocated view is released before the exception propagates, so nothing leaks.
NDArrayView* CreateDataPtr<NDArrayView>(const NDArrayView& value)
{
    // TODO: replace this copy with an alias to value.
    NDArrayView* viewPtr = new NDArrayView(value.GetDataType(), value.Shape(), DeviceDescriptor::CPUDevice());
    try
    {
        viewPtr->CopyFrom(value);
    }
    catch (...)
    {
        // Without this the allocation above would be orphaned on any CopyFrom failure.
        delete viewPtr;
        throw;
    }
    return viewPtr;
}
static void WriteInt8Data(const NDArrayView& src, io::CodedOutputStream& output) { // Write raw bytes. auto size = src.Shape().TotalSize(); const int8_t* buffer = src.DataBuffer<int8_t>(); output.WriteRaw(buffer, size); }
void NDArrayView::CopyFrom(const NDArrayView& source) { if (source.Shape() != Shape()) InvalidArgument("NDArrayView::CopyFrom: The 'source' view's shape must be same as the shape of this NDArrayView"); if (IsReadOnly()) RuntimeError("NDArrayView::CopyFrom: Cannot modify contents of a readonly NDArrayView"); switch (m_dataType) { case DataType::Float: { auto sourceMatrix = source.GetMatrix<float>(); auto destMatrix = GetWritableMatrix<float>(); destMatrix->AssignValuesOf(*sourceMatrix); break; } case DataType::Double: { auto sourceMatrix = source.GetMatrix<double>(); auto destMatrix = GetWritableMatrix<double>(); destMatrix->AssignValuesOf(*sourceMatrix); break; } default: LogicError("Unsupported DataType %s", DataTypeName(m_dataType)); break; } }
// Serializes the int16 elements of `src`, one WriteVarint32SignExtended call per element.
// Each element is first passed through Encode<int16_t, int16_t> (project helper; presumably an
// identity/bit-level encoding for this type pair — NOTE(review): confirm).
static void WriteInt16Data(const NDArrayView& src, io::CodedOutputStream& output)
{
    const auto size = src.Shape().TotalSize();
    const int16_t* buffer = src.DataBuffer<int16_t>();
    // size_t index: the previous `auto i = 0` deduced int, which mis-compares against the
    // unsigned element count and overflows (UB) for tensors with more than INT_MAX elements.
    for (size_t i = 0; i < size; ++i)
    {
        const int16_t value = buffer[i];
        output.WriteVarint32SignExtended(Encode<int16_t, int16_t>(value));
    }
}
static void CopyData(const NDArrayView& src, RepeatedField<DstT>* dst) { auto size = src.Shape().TotalSize(); dst->Resize((int)size, DstT()); const SrcT* buffer = src.DataBuffer<SrcT>(); if (std::is_same<SrcT, DstT>::value) memcpy(dst->mutable_data(), buffer, (int)size * sizeof(DstT)); else for (size_t i = 0; i < size; i++) dst->mutable_data()[i] = (DstT)buffer[i]; }
// Fills `dst` element by element from `input`: each value is read as SrcT and stored as DstT.
// Returns false as soon as a read fails (elements already stored stay written); true once every
// element of `dst` has been filled.
static bool ReadData(RenewableCodedStream& input, NDArrayView& dst)
{
    const auto size = dst.Shape().TotalSize();
    DstT* buffer = dst.WritableDataBuffer<DstT>();
    // size_t index: an int index (`auto i = 0`) mis-compares against the unsigned count and
    // overflows (UB) beyond INT_MAX elements.
    for (size_t i = 0; i < size; ++i)
    {
        SrcT value;
        if (!input.Read<SrcT>(&value))
            return false;
        buffer[i] = static_cast<DstT>(value);
    }
    return true;
}
static void WriteData(const NDArrayView& src, io::CodedOutputStream& output) { auto size = src.Shape().TotalSize(); const T* buffer = src.DataBuffer<T>(); auto tSize = sizeof(T); for (auto i = 0; i < size; i++) { auto value = buffer[i]; if (tSize <= sizeof(uint32)) { output.WriteLittleEndian32(Encode<T, uint32>((float)value)); } else { output.WriteLittleEndian64(Encode<T, uint64>(value)); } } }
// Reads the raw int8 payload of `dst` (one byte per element) from a zero-copy stream.
// Returns true once exactly TotalSize() bytes have been copied into `dst`.
// Returns false if the stream ends before that, or if a chunk holds more bytes than are still
// needed (trailing data in the same chunk is rejected, as before); `dst` may then be partially
// written.
// ZeroCopyInputStream::Next makes no promise about chunk sizes, so the payload is now accepted
// when it is split across several chunks (previously it had to arrive in the first non-empty
// chunk). An empty tensor succeeds without consuming any input (previously it always failed).
static bool ReadInt8Data(io::ZeroCopyInputStream& input, NDArrayView& dst)
{
    const size_t size = dst.Shape().TotalSize();
    int8_t* buffer = nullptr; // acquired only once there is data to store
    size_t copied = 0;
    while (copied < size)
    {
        const void* chunk = nullptr;
        int chunkSize = 0; // initialized: Next() need not write it when it returns false
        if (!input.Next(&chunk, &chunkSize))
            return false;
        if (chunkSize <= 0)
            continue; // empty chunks are legal; keep reading

        const size_t chunkBytes = static_cast<size_t>(chunkSize);
        if (chunkBytes > size - copied)
            return false;

        if (buffer == nullptr)
            buffer = dst.WritableDataBuffer<int8_t>();
        memcpy(buffer + copied, chunk, chunkBytes);
        copied += chunkBytes;
    }
    return true;
}