Ejemplo n.º 1
0
/// @brief Produce the reference label strings for the data held in a bundle.
///
/// When the bundle carries a "referenceLabels" entry, every label in it is
/// turned into a string by concatenating the model's output label for each of
/// its elements. Otherwise the first matrix of the "referenceActivations"
/// entry is decoded with convertActivationsToLabels().
///
/// @param bundle Bundle queried for "referenceLabels" / "referenceActivations".
/// @param model  Model used to map label indices to their output label text.
/// @return One string per reference label (or per decoded activation sample).
static util::StringVector getReferenceLabels(network::Bundle& bundle,
    const model::Model& model)
{
    // No explicit labels stored: fall back to decoding the activations.
    if(!bundle.contains("referenceLabels"))
    {
        auto& activations =
            bundle["referenceActivations"].get<matrix::MatrixVector>().front();

        // Passed as an rvalue, so the callee may modify or consume the matrix
        // that lives inside the bundle.
        return convertActivationsToLabels(std::move(activations), model);
    }

    auto& storedLabels = bundle["referenceLabels"].get<matrix::LabelVector>();

    util::StringVector result;

    for(auto& storedLabel : storedLabels)
    {
        // Assemble the text for this label before appending it to the result.
        std::string text;

        for(auto& symbol : storedLabel)
        {
            text += model.getOutputLabel(symbol);
        }

        result.push_back(std::move(text));
    }

    return result;
}
Ejemplo n.º 2
0
/// @brief Decode activations into one label string per mini-batch entry.
///
/// The matrix is indexed as (grapheme, miniBatch, timestep). Any dimensions
/// beyond the third are folded into the timestep dimension first. For each
/// mini-batch entry the highest-scoring grapheme is picked at every timestep
/// (the highest index wins ties); a grapheme's output label is appended only
/// when it differs from the grapheme picked at the previous timestep, so
/// consecutive repeats collapse into a single occurrence.
///
/// @param activations Matrix with at least three dimensions (asserted). It may
///                    be reshaped in place.
/// @param model       Model used to map a grapheme index to its label text.
/// @return One decoded string per mini-batch entry, in mini-batch order.
static util::StringVector convertActivationsToGraphemeLabels(matrix::Matrix&& activations,
    const model::Model& model)
{
    assert(activations.size().size() >= 3);

    // Fold every dimension past the second into a single trailing dimension.
    if(activations.size().size() > 3)
    {
        activations = reshape(activations,
            {activations.size().front(),
             activations.size()[1],
             activations.size().product() / (activations.size().front() * activations.size()[1])});
    }

    const size_t timesteps     = activations.size()[2];
    const size_t miniBatchSize = activations.size()[1];
    const size_t graphemes     = activations.size()[0];

    util::StringVector labels;

    for(size_t miniBatch = 0; miniBatch < miniBatchSize; ++miniBatch)
    {
        std::string currentLabel;

        // `graphemes` is one past the last valid index, so it acts as a
        // "nothing emitted yet" sentinel and the first timestep always emits.
        size_t currentGrapheme = graphemes;

        for(size_t timestep = 0; timestep < timesteps; ++timestep)
        {
            size_t maxGrapheme = 0;

            // lowest() rather than min(): numeric_limits<double>::min() is the
            // smallest *positive* normal value, which made every non-positive
            // activation (e.g. log-probabilities) lose and index 0 win.
            double maxValue = std::numeric_limits<double>::lowest();

            for(size_t grapheme = 0; grapheme < graphemes; ++grapheme)
            {
                const double value = activations(grapheme, miniBatch, timestep);

                // >= keeps the original tie-breaking: the highest index wins.
                if(value >= maxValue)
                {
                    maxValue    = value;
                    maxGrapheme = grapheme;
                }
            }

            // Emit only when the winning grapheme changes between timesteps.
            if(maxGrapheme != currentGrapheme)
            {
                currentGrapheme = maxGrapheme;
                currentLabel += model.getOutputLabel(maxGrapheme);
            }
        }

        labels.push_back(currentLabel);
    }

    return labels;
}
Ejemplo n.º 3
0
/// @brief Decode activations into one output label per sample.
///
/// If the model has a true "UsesGraphemes" attribute, decoding is delegated to
/// convertActivationsToGraphemeLabels(). Otherwise the matrix is viewed as
/// (column, sample) - dimensions past the first are folded together - and each
/// sample is mapped to the output label of its highest-valued column. The
/// highest index wins ties. NaN entries are never selected; if a sample has no
/// comparable entry at all, column 0 is used.
///
/// @param activations Activation matrix; it may be reshaped in place.
/// @param model       Model supplying the attribute and the output labels.
/// @return One label string per sample, in sample order.
static util::StringVector convertActivationsToLabels(matrix::Matrix&& activations,
    const model::Model& model)
{
    if(model.hasAttribute("UsesGraphemes") && model.getAttribute<bool>("UsesGraphemes"))
    {
        return convertActivationsToGraphemeLabels(std::move(activations), model);
    }

    // Flatten everything past the first dimension into a single sample axis.
    if(activations.size().size() > 2)
    {
        activations = reshape(activations,
            {activations.size().front(),
             activations.size().product() / activations.size().front()});
    }

    const size_t samples = activations.size()[1];
    const size_t columns = activations.size()[0];

    util::StringVector labels;

    for(size_t sample = 0; sample < samples; ++sample)
    {
        size_t maxColumn = 0;
        double maxValue  = 0.0;
        bool   haveMax   = false;

        for(size_t column = 0; column < columns; ++column)
        {
            const double value = activations(column, sample);

            // NaN compares false against everything, so it is skipped.
            if(value != value)
            {
                continue;
            }

            // Seed the running maximum from the first comparable value rather
            // than from 0: a zero seed made column 0 win for any sample whose
            // activations are all negative. >= keeps last-wins tie-breaking.
            if(!haveMax || value >= maxValue)
            {
                maxValue  = value;
                maxColumn = column;
                haveMax   = true;
            }
        }

        labels.push_back(model.getOutputLabel(maxColumn));
    }

    return labels;
}