// Returns the value held by this Variable, creating it on first access.
//
// Only Variables for which IsConstant() or IsParameter() is true have a value; for any
// other Variable LogicError is invoked. When no value has been stored yet, one is created
// from the stored value initializer on the stored initialization device, after which both
// of those fields are reset to nullptr.
NDArrayViewPtr Variable::Value() const
{
    const bool isValueHolder = IsConstant() || IsParameter();
    if (!isValueHolder)
        LogicError("Only Variables of kind Parameter and Constant have a Value!");

    // Value already present: nothing to initialize.
    if (m_dataFields->m_value != nullptr)
        return m_dataFields->m_value;

    assert(m_dataFields->m_valueInitializer);
    assert(m_dataFields->m_valueInitializationDevice);

    const auto elementType = GetDataType();
    if (elementType == DataType::Float)
    {
        m_dataFields->m_value = CreateValueFromParameterInitializer<float>(Shape(), *m_dataFields->m_valueInitializer, *m_dataFields->m_valueInitializationDevice);
    }
    else if (elementType == DataType::Double)
    {
        m_dataFields->m_value = CreateValueFromParameterInitializer<double>(Shape(), *m_dataFields->m_valueInitializer, *m_dataFields->m_valueInitializationDevice);
    }
    else
    {
        LogicError("Unsupported DataType %s", DataTypeName(elementType));
    }

    // The initializer and its device are no longer needed once the value exists.
    m_dataFields->m_valueInitializer = nullptr;
    m_dataFields->m_valueInitializationDevice = nullptr;

    assert(m_dataFields->m_value != nullptr);
    return m_dataFields->m_value;
}
void NDArrayView::CopyFrom(const NDArrayView& source) { if (source.Shape() != Shape()) InvalidArgument("NDArrayView::CopyFrom: The 'source' view's shape must be same as the shape of this NDArrayView"); if (IsReadOnly()) RuntimeError("NDArrayView::CopyFrom: Cannot modify contents of a readonly NDArrayView"); switch (m_dataType) { case DataType::Float: { auto sourceMatrix = source.GetMatrix<float>(); auto destMatrix = GetWritableMatrix<float>(); destMatrix->AssignValuesOf(*sourceMatrix); break; } case DataType::Double: { auto sourceMatrix = source.GetMatrix<double>(); auto destMatrix = GetWritableMatrix<double>(); destMatrix->AssignValuesOf(*sourceMatrix); break; } default: LogicError("Unsupported DataType %s", DataTypeName(m_dataType)); break; } }
// Creates a new NDArrayView with the same data type, storage format, shape and device as
// this view and fills it from this view's matrix via AssignValuesOf.
//
// readOnly: stored as the m_isReadOnly flag of the returned view.
// Invokes LogicError when m_dataType is neither Float nor Double.
NDArrayViewPtr NDArrayView::DeepClone(bool readOnly/* = false*/) const
{
    NDArrayViewPtr newView(new NDArrayView(this->GetDataType(), this->GetStorageFormat(), this->Shape(), this->Device()), [](_ReferenceCounter* ptr) { delete ptr; });

    switch (m_dataType)
    {
    case DataType::Float:
    {
        auto newMatrix = newView->GetWritableMatrix<float>();
        auto thisMatrix = GetMatrix<float>();
        newMatrix->AssignValuesOf(*thisMatrix);
        break;
    }
    case DataType::Double:
    {
        auto newMatrix = newView->GetWritableMatrix<double>();
        auto thisMatrix = GetMatrix<double>();
        newMatrix->AssignValuesOf(*thisMatrix);
        break;
    }
    default:
        LogicError("Unsupported DataType %s", DataTypeName(m_dataType));
        break;
    }

    // The read-only flag is applied only after the copy, which needs a writable matrix.
    newView->m_isReadOnly = readOnly;

    // newView already owns the new object together with its deleter; return that handle
    // rather than wrapping the same object in a second smart pointer with a second deleter.
    return newView;
}
void PackedValue::Unpack() const { if (m_packedDataLayout && (m_packedDataLayout->GetNumTimeSteps() != 1) && (m_packedDataLayout->GetNumSequences() != 1) && Internal::IsAutomaticUnpackingOfPackedValuesDisabled()) LogicError("PackedValue::Unpack: Automatic unpacking of PackedValue objects is disabled"); if (m_isPacked) { ValuePtr valueObject; auto dataType = m_packedData->GetDataType(); switch (dataType) { case DataType::Float: valueObject = CompositeFunction::GetValueObjectFromCNTKImplMatrixAndMBLayout(m_sampleShape, *(m_packedData->GetMatrix<float>()), m_packedDataLayout, m_isReadOnly); break; case DataType::Double: valueObject = CompositeFunction::GetValueObjectFromCNTKImplMatrixAndMBLayout(m_sampleShape, *(m_packedData->GetMatrix<double>()), m_packedDataLayout, m_isReadOnly); break; default: LogicError("Unsupported DataType %s", DataTypeName(dataType)); } m_data = valueObject->Data(); m_mask = valueObject->Mask(); m_packedData = nullptr; m_packedDataLayout = nullptr; m_isPacked = false; if (m_unpackedShape != m_data->Shape()) LogicError("The computed unpacked shape of the PackedValue object does not match the actual Data NDArrayView's shape after unpacking"); } }
// Deletes the tensor view object owned by this NDArrayView, using the element type that
// corresponds to m_dataType.
//
// NOTE(review): LogicError is reachable from this destructor (unsupported data type);
// if LogicError throws, that happens during destruction - confirm this is intended.
NDArrayView::~NDArrayView()
{
    if (m_dataType == DataType::Float)
    {
        delete GetTensorView<float>();
    }
    else if (m_dataType == DataType::Double)
    {
        delete GetTensorView<double>();
    }
    else
    {
        LogicError("Unsupported DataType %s", DataTypeName(m_dataType));
    }
}
// Dispatches on the runtime data type to the matching typed AllocateTensorView<T> overload
// (shape, storage format, device) and returns its result as an untyped pointer.
//
// Invokes LogicError for any data type other than Float or Double.
// NOTE(review): no return statement follows LogicError; presumably it never returns - confirm.
static void* AllocateTensorView(CNTK::DataType dataType, CNTK::StorageFormat storageType, const NDShape& viewShape, const DeviceDescriptor& device)
{
    if (dataType == DataType::Float)
        return AllocateTensorView<float>(viewShape, storageType, device);

    if (dataType == DataType::Double)
        return AllocateTensorView<double>(viewShape, storageType, device);

    LogicError("Unsupported DataType %s", DataTypeName(dataType));
}
// Dispatches on the runtime data type to the matching typed AllocateTensorView<T> overload
// that takes a caller-supplied data buffer and its size in bytes, and returns the result
// as an untyped pointer.
//
// Invokes LogicError for any data type other than Float or Double.
// NOTE(review): no return statement follows LogicError; presumably it never returns - confirm.
static void* AllocateTensorView(CNTK::DataType dataType, const NDShape& viewShape, const DeviceDescriptor& device, void* dataBuffer, size_t bufferSizeInBytes)
{
    if (dataType == DataType::Float)
        return AllocateTensorView<float>(viewShape, device, dataBuffer, bufferSizeInBytes);

    if (dataType == DataType::Double)
        return AllocateTensorView<double>(viewShape, device, dataBuffer, bufferSizeInBytes);

    LogicError("Unsupported DataType %s", DataTypeName(dataType));
}
// Creates a new NDArrayView built around a copy-constructed TensorView of this view's
// tensor view, with this view's data type, device, storage format and shape.
//
// readOnly: the returned view is read-only when this view is read-only or readOnly is true.
// Invokes LogicError when m_dataType is neither Float nor Double.
NDArrayViewPtr NDArrayView::Alias(bool readOnly/* = false*/) const
{
    void* copiedTensorView = nullptr;
    if (m_dataType == DataType::Float)
        copiedTensorView = new TensorView<float>(*(GetTensorView<float>()));
    else if (m_dataType == DataType::Double)
        copiedTensorView = new TensorView<double>(*(GetTensorView<double>()));
    else
        LogicError("Unsupported DataType %s", DataTypeName(m_dataType));

    const bool aliasIsReadOnly = IsReadOnly() || readOnly;
    return NDArrayViewPtr(
        new NDArrayView(GetDataType(), Device(), GetStorageFormat(), Shape(), aliasIsReadOnly, copiedTensorView),
        [](_ReferenceCounter* ptr) { delete ptr; });
}
// Returns a read-only pointer to the data of this view's underlying matrix.
//
// ElementType must correspond to m_dataType, otherwise LogicError is invoked.
// InvalidArgument is invoked when IsSparse() is true.
// NOTE(review): the template declaration introducing ElementType is not visible in this
// block - confirm it precedes this definition in the file.
const ElementType* NDArrayView::DataBuffer() const
{
    if (AsDataType<ElementType>() != m_dataType)
        LogicError("The specified ElementType %s does not match the DataType %s", typeid(ElementType).name(), DataTypeName(m_dataType));

    if (IsSparse())
        InvalidArgument("DataBuffer/WritableDataBuffer methods can only be called for NDArrayView objects with dense storage format");

    // First make sure that the underlying matrix is on the right device
    auto matrix = GetMatrix<ElementType>();
    matrix->TransferToDeviceIfNotThere(AsCNTKImplDeviceId(m_device), true);
    return matrix->Data();
}
// Returns the stored tensor view pointer typed as a read-only TensorView<ElementType>.
//
// ElementType must correspond to m_dataType, otherwise LogicError is invoked.
// NOTE(review): the template declaration introducing ElementType is not visible in this
// block - confirm it precedes this definition in the file.
const TensorView<ElementType>* NDArrayView::GetTensorView() const
{
    if (AsDataType<ElementType>() != m_dataType)
        LogicError("NDArrayView::GetTensorView: The specified ElementType %s does not match the DataType %s", typeid(ElementType).name(), DataTypeName(m_dataType));

    // m_tensorView is stored without its element type; the check above guards this cast.
    return static_cast<const TensorView<ElementType>*>(m_tensorView);
}