Пример #1
0
// TODO: This class should go away eventually.
// TODO: The composition of packer + randomizer + different deserializers in a generic manner is done in the CompositeDataReader.
// TODO: Currently preserving this for backward compatibility with current configs.
//
// Assembles the reading pipeline described by 'config':
//   1. a TextParser deserializer, instantiated for float or double depending on the configured element type;
//   2. an optional ChunkCache wrapped around the deserializer when the config asks to keep data in memory;
//   3. a randomizer: BlockRandomizer when the randomization window is positive, NoRandomizer otherwise;
//   4. a packer: FramePacker in frame mode, SequencePacker otherwise.
//
// provider - memory provider stored in m_provider and handed to the packer.
// config   - reader configuration; also queried directly for the general "verbosity" parameter.
//
// A std::runtime_error raised while building the pipeline is caught and passed on to RuntimeError
// together with the path of the input file.
CNTKTextFormatReader::CNTKTextFormatReader(MemoryProviderPtr provider,
    const ConfigParameters& config) :
    m_provider(provider)
{
    TextConfigHelper configHelper(config);

    try
    {
        // make_shared avoids the naked 'new' and performs a single allocation for object + control block.
        if (configHelper.GetElementType() == ElementType::tfloat)
        {
            m_deserializer = std::make_shared<TextParser<float>>(configHelper);
        }
        else
        {
            m_deserializer = std::make_shared<TextParser<double>>(configHelper);
        }

        if (configHelper.ShouldKeepDataInMemory())
        {
            // The cache wraps the parser created above and takes its place as the deserializer.
            m_deserializer = std::make_shared<ChunkCache>(m_deserializer);
        }

        size_t window = configHelper.GetRandomizationWindow();
        if (window > 0)
        {
            // Verbosity is a general config parameter, not specific to the text format reader.
            int verbosity = config(L"verbosity", 0);
            m_randomizer = std::make_shared<BlockRandomizer>(verbosity, window, m_deserializer);
        }
        else
        {
            m_randomizer = std::make_shared<NoRandomizer>(m_deserializer);
        }

        if (configHelper.IsInFrameMode())
        {
            m_packer = std::make_shared<FramePacker>(
                m_provider,
                m_randomizer,
                GetStreamDescriptions());
        }
        else
        {
            m_packer = std::make_shared<SequencePacker>(
                m_provider,
                m_randomizer,
                GetStreamDescriptions());
        }
    }
    catch (const std::runtime_error& e)
    {
        // Add the file path so the user can tell which input caused the failure.
        RuntimeError("CNTKTextFormatReader: While reading '%ls': %s", configHelper.GetFilePath().c_str(), e.what());
    }
}
Пример #2
0
// Prepares the reader for a new epoch.
//
// config            - epoch configuration; its m_totalEpochSizeInSamples must be non-zero,
//                     otherwise RuntimeError is called.
// inputDescriptions - maps a requested stream name to a device id. A negative id selects a
//                     HeapMemoryProvider, a non-negative id a CudaMemoryProvider for that device.
//
// When the requested inputs differ from the ones remembered in m_requiredInputs, the memory
// providers are rebuilt, one per stream returned by GetStreamDescriptions(). Afterwards the
// sequence enumerator is started and the packer receives the configuration and the providers.
void ReaderBase::StartEpoch(const EpochConfiguration& config, const std::map<std::wstring, int>& inputDescriptions)
{
    if (config.m_totalEpochSizeInSamples == 0)
    {
        RuntimeError("Epoch size cannot be 0.");
    }

    // Let's check that streams requested for this epoch match the ones for previous epochs.
    // If not, update them.
    const auto streams = GetStreamDescriptions();
    if (inputDescriptions.size() != m_requiredInputs.size()
        || !std::equal(inputDescriptions.begin(), inputDescriptions.end(), m_requiredInputs.begin()))
    {
        m_requiredInputs = inputDescriptions;

        // Reallocating memory providers.
        m_memoryProviders.resize(streams.size());
        for (size_t i = 0; i < streams.size(); ++i)
        {
            // Single lookup: the iterator is reused for both the presence check and the device id.
            const auto requested = m_requiredInputs.find(streams[i].m_name);

            // TODO: In case when the network requires less inputs,
            // we should not even have them.
            // Streams that were not requested, or were requested with a negative device id,
            // get a heap provider; the rest get a CUDA provider for the requested device.
            const bool useCuda = requested != m_requiredInputs.end() && requested->second >= 0;
            if (useCuda)
                m_memoryProviders[i] = std::make_shared<CudaMemoryProvider>(requested->second);
            else
                m_memoryProviders[i] = std::make_shared<HeapMemoryProvider>();
        }
    }

    m_sequenceEnumerator->StartEpoch(config);
    m_packer->SetConfiguration(config, m_memoryProviders);
}