static util::StringVector getReferenceLabels(network::Bundle& bundle, const model::Model& model) { if(bundle.contains("referenceLabels")) { util::StringVector labels; auto& referenceLabels = bundle["referenceLabels"].get<matrix::LabelVector>(); for(auto& referenceLabel : referenceLabels) { labels.push_back(std::string()); for(auto& grapheme : referenceLabel) { labels.back() += model.getOutputLabel(grapheme); } } return labels; } auto& referenceActivations = bundle["referenceActivations"].get<matrix::MatrixVector>().front(); return convertActivationsToLabels(std::move(referenceActivations), model); }
static util::StringVector convertActivationsToGraphemeLabels(matrix::Matrix&& activations, const model::Model& model) { assert(activations.size().size() >= 3); if(activations.size().size() > 3) { activations = reshape(activations, {activations.size().front(), activations.size()[1], activations.size().product() / (activations.size().front() * activations.size()[1])}); } size_t timesteps = activations.size()[2]; size_t miniBatchSize = activations.size()[1]; size_t graphemes = activations.size()[0]; util::StringVector labels; for(size_t miniBatch = 0; miniBatch < miniBatchSize; ++miniBatch) { std::string currentLabel; size_t currentGrapheme = graphemes; for(size_t timestep = 0; timestep < timesteps; ++timestep) { size_t maxGrapheme = 0; double maxValue = std::numeric_limits<double>::min(); for(size_t grapheme = 0; grapheme < graphemes; ++grapheme) { if(activations(grapheme, miniBatch, timestep) >= maxValue) { maxValue = activations(grapheme, miniBatch, timestep); maxGrapheme = grapheme; } } if(maxGrapheme != currentGrapheme) { currentGrapheme = maxGrapheme; auto newGrapheme = model.getOutputLabel(maxGrapheme); currentLabel.insert(currentLabel.end(), newGrapheme.begin(), newGrapheme.end()); } } labels.push_back(currentLabel); } return labels; }
static util::StringVector convertActivationsToLabels(matrix::Matrix&& activations, const model::Model& model) { if(model.hasAttribute("UsesGraphemes") && model.getAttribute<bool>("UsesGraphemes")) { return convertActivationsToGraphemeLabels(std::move(activations), model); } if(activations.size().size() > 2) { activations = reshape(activations, {activations.size().front(), activations.size().product() / activations.size().front()}); } size_t samples = activations.size()[1]; size_t columns = activations.size()[0]; util::StringVector labels; for(size_t sample = 0; sample < samples; ++sample) { size_t maxColumn = 0; double maxValue = 0.0f; for(size_t column = 0; column < columns; ++column) { if(activations(column, sample) >= maxValue) { maxValue = activations(column, sample); maxColumn = column; } } labels.push_back(model.getOutputLabel(maxColumn)); } return labels; }