예제 #1
0
파일: Variable.cpp 프로젝트: rlugojr/CNTK
    void Variable::SetValue(const NDArrayViewPtr& value)
    {
        if (!IsParameter())
            LogicError("Variable::SetValue can be only invoked on a Parameter variable!");
        else if (GetDataType() != value->GetDataType()) 
            LogicError("Variable::SetValue: 'source' and 'destination' have different data types!");
        else if (Shape() != value->Shape() && (AsTensorShape(Shape()) != AsTensorShape(value->Shape())))
            LogicError("Variable::SetValue: 'source' and 'destination' have different shapes!");

        bool alreadySet = false;
        if (m_dataFields->m_initValueFlag)
        {
            // In the case of lazy initialization, try to avoid the redundant call to the initializer. 
            std::call_once(*m_dataFields->m_initValueFlag, [=, &value, &alreadySet] {
                // If the variable hasn't been initialized yet, clone the content of the supplied value and delete the initializer.
                m_dataFields->m_value = value->DeepClone(*m_dataFields->m_valueInitializationDevice, false);
                m_dataFields->m_valueInitializer = nullptr;
                m_dataFields->m_valueInitializationDevice = nullptr;
                alreadySet = true;
            });
        }

        assert(m_dataFields->m_value != nullptr);
        if (!alreadySet)
        {
            // alreadySet is false, the lambda above wasn't called and the variable has been initialized before,
            // get a pointer to its value and simply copy the content of the supplied value.
            m_dataFields->m_value->CopyFrom(*value);
        }
    }
예제 #2
0
void TestNDArrayView(size_t numAxes, const DeviceDescriptor& device)
{
    srand(1);

    size_t maxDimSize = 15;
    NDShape viewShape(numAxes);
    for (size_t i = 0; i < numAxes; ++i)
        viewShape[i] = (rand() % maxDimSize) + 1;

    // Create a NDArrayView over a std::array
    std::array<ElementType, 1> arrayData = { 3 };
    auto arrayDataView = MakeSharedObject<NDArrayView>(NDShape({}), arrayData);
    if (arrayDataView->template DataBuffer<ElementType>() != arrayData.data())
        throw std::runtime_error("The DataBuffer of the NDArrayView does not match the original buffer it was created over");

    std::vector<ElementType> data(viewShape.TotalSize());
    ElementType scale = 19.0;
    ElementType offset = -4.0;
    for (size_t i = 0; i < viewShape.TotalSize(); ++i)
        data[i] = offset + ((((ElementType)rand()) / RAND_MAX) * scale);

    auto cpuDataView = MakeSharedObject<NDArrayView>(viewShape, data);
    if (cpuDataView->template DataBuffer<ElementType>() != data.data())
        throw std::runtime_error("The DataBuffer of the NDArrayView does not match the original buffer it was created over");

    NDArrayViewPtr dataView;
    if ((device.Type() == DeviceKind::CPU))
        dataView = cpuDataView;
    else
    {
        dataView = MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), viewShape, device);
        dataView->CopyFrom(*cpuDataView);
    }

    if (dataView->Device() != device)
        throw std::runtime_error("Device of NDArrayView does not match 'device' it was created on");

    // Test clone
    auto clonedView = dataView->DeepClone(false);
    ElementType* first = nullptr;
    const ElementType* second = cpuDataView->template DataBuffer<ElementType>();
    NDArrayViewPtr temp1CpuDataView, temp2CpuDataView;
    if ((device.Type() == DeviceKind::CPU))
    {
        if (dataView->DataBuffer<ElementType>() != data.data())
            throw std::runtime_error("The DataBuffer of the NDArrayView does not match the original buffer it was created over");

        first = clonedView->WritableDataBuffer<ElementType>();
    }
    else
    {
        temp1CpuDataView = MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), viewShape, DeviceDescriptor::CPUDevice());
        temp1CpuDataView->CopyFrom(*clonedView);

        first = temp1CpuDataView->WritableDataBuffer<ElementType>();
    }

    for (size_t i = 0; i < viewShape.TotalSize(); ++i)
    {
        if (first[i] != second[i])
            throw std::runtime_error("The contents of the clone do not match expected");
    }

    first[0] += 1;
    if ((device.Type() != DeviceKind::CPU))
        clonedView->CopyFrom(*temp1CpuDataView);

    if ((device.Type() == DeviceKind::CPU))
    {
        first = clonedView->WritableDataBuffer<ElementType>();
        second = dataView->DataBuffer<ElementType>();
    }
    else
    {
        temp1CpuDataView = MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), viewShape, DeviceDescriptor::CPUDevice());
        temp1CpuDataView->CopyFrom(*clonedView);
        first = temp1CpuDataView->WritableDataBuffer<ElementType>();

        temp2CpuDataView = MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), viewShape, DeviceDescriptor::CPUDevice());
        temp2CpuDataView->CopyFrom(*dataView);
        second = temp2CpuDataView->DataBuffer<ElementType>();
    }

    if (first[0] != (second[0] + 1))
        throw std::runtime_error("The clonedView's contents do not match expected");

    // Test alias
    auto aliasView = clonedView->Alias(true);
    const ElementType* aliasViewBuffer = aliasView->DataBuffer<ElementType>();
    const ElementType* clonedDataBuffer = clonedView->DataBuffer<ElementType>();
    if (aliasViewBuffer != clonedDataBuffer)
        throw std::runtime_error("The buffers underlying the alias view and the view it is an alias of are different!");

    clonedView->CopyFrom(*dataView);
    if (aliasViewBuffer != clonedDataBuffer)
        throw std::runtime_error("The buffers underlying the alias view and the view it is an alias of are different!");

    // Test readonliness
    auto errorMsg = "Was incorrectly able to get a writable buffer pointer from a readonly view";

    // Should not be able to get the WritableDataBuffer for a read-only view
    VerifyException([&aliasView]() {
        ElementType* aliasViewBuffer = aliasView->WritableDataBuffer<ElementType>();
        aliasViewBuffer;
    }, errorMsg);

    // Should not be able to copy into a read-only view
    VerifyException([&aliasView, &dataView]() {
        aliasView->CopyFrom(*dataView);
    }, errorMsg);
}