/**
 * Generate a sequence from the given HMM and save the results to the files
 * named on the command line.
 *
 * Reads the CLI parameters "start_state", "length", "output_file" and
 * "state_file".  The observations are saved only if "output_file" was passed,
 * and the hidden state sequence only if "state_file" was passed; a warning is
 * printed if neither file name is set.
 *
 * @param hmm Model to generate the sequence from; its transition matrix's row
 *     count is used as the number of states when validating "start_state".
 * @param extraInfo Unused.
 */
static void Apply(HMMType& hmm, void* /* extraInfo */)
{
  mat observations;
  Row<size_t> sequence;

  // Load the parameters.  Both numeric parameters arrive as signed ints, so
  // they are checked for negative values *before* being converted to size_t;
  // otherwise a negative input would wrap around to an enormous unsigned
  // value.
  const int startStateParam = CLI::GetParam<int>("start_state");
  const int lengthParam = CLI::GetParam<int>("length");
  const string outputFile = CLI::GetParam<string>("output_file");
  const string sequenceFile = CLI::GetParam<string>("state_file");

  // NOTE(review): like the start-state check below, this relies on Log::Fatal
  // stopping execution once the message is flushed -- confirm.
  if (lengthParam < 0)
    Log::Fatal << "Invalid sequence length (" << lengthParam << "); must be "
        << "non-negative!" << endl;

  const size_t length = static_cast<size_t>(lengthParam);

  Log::Info << "Generating sequence of length " << length << "..." << endl;

  // The signed value is printed so that a negative start state is reported as
  // the user typed it rather than as a wrapped unsigned number.
  if (startStateParam < 0 ||
      static_cast<size_t>(startStateParam) >= hmm.Transition().n_rows)
    Log::Fatal << "Invalid start state (" << startStateParam << "); must be "
        << "between 0 and number of states (" << hmm.Transition().n_rows
        << ")!" << endl;

  const size_t startState = static_cast<size_t>(startStateParam);

  hmm.Generate(length, observations, sequence, startState);

  // Now save the output.
  if (CLI::HasParam("output_file"))
    data::Save(outputFile, observations, true);

  // Do we want to save the hidden sequence?
  if (CLI::HasParam("state_file"))
    data::Save(sequenceFile, sequence, true);

  if (outputFile == "" && sequenceFile == "")
    Log::Warn << "Neither --output_file nor --state_file are specified; no "
        << "output will be saved." << endl;
}
/**
 * Generate a sequence from the given HMM and hand the results back through
 * the CLI output parameters.
 *
 * Reads the CLI parameters "start_state" and "length".  The observations are
 * moved into the "output" parameter if it was passed, and the hidden state
 * sequence into the "state" parameter if it was passed.
 *
 * @param hmm Model to generate the sequence from; its transition matrix's row
 *     count is used as the number of states when validating "start_state".
 * @param extraInfo Unused.
 */
static void Apply(HMMType& hmm, void* /* extraInfo */)
{
  mat observations;
  Row<size_t> sequence;

  // Load the parameters.  Both arrive as signed ints, so they are checked for
  // negative values *before* being converted to size_t; otherwise a negative
  // input would wrap around to an enormous unsigned value.
  const int startStateParam = CLI::GetParam<int>("start_state");
  const int lengthParam = CLI::GetParam<int>("length");

  // NOTE(review): like the start-state check below, this relies on Log::Fatal
  // stopping execution once the message is flushed -- confirm.
  if (lengthParam < 0)
    Log::Fatal << "Invalid sequence length (" << lengthParam << "); must be "
        << "non-negative!" << endl;

  const size_t length = static_cast<size_t>(lengthParam);

  Log::Info << "Generating sequence of length " << length << "..." << endl;

  // The signed value is printed so that a negative start state is reported as
  // the user typed it rather than as a wrapped unsigned number.
  if (startStateParam < 0 ||
      static_cast<size_t>(startStateParam) >= hmm.Transition().n_rows)
    Log::Fatal << "Invalid start state (" << startStateParam << "); must be "
        << "between 0 and number of states (" << hmm.Transition().n_rows
        << ")!" << endl;

  const size_t startState = static_cast<size_t>(startStateParam);

  hmm.Generate(length, observations, sequence, startState);

  // Now save the output.
  if (CLI::HasParam("output"))
    CLI::GetParam<mat>("output") = std::move(observations);

  // Do we want to save the hidden sequence?
  if (CLI::HasParam("state"))
    CLI::GetParam<Mat<size_t>>("state") = std::move(sequence);
}
static void Apply(HMMType& hmm, vector<mat>* trainSeqPtr) { const bool batch = CLI::HasParam("batch"); const double tolerance = CLI::GetParam<double>("tolerance"); // Do we need to replace the tolerance? if (CLI::HasParam("tolerance")) hmm.Tolerance() = tolerance; const string labelsFile = CLI::GetParam<string>("labels_file"); // Verify that the dimensionality of our observations is the same as the // dimensionality of our HMM's emissions. vector<mat>& trainSeq = *trainSeqPtr; for (size_t i = 0; i < trainSeq.size(); ++i) if (trainSeq[i].n_rows != hmm.Emission()[0].Dimensionality()) Log::Fatal << "Dimensionality of training sequence " << i << " (" << trainSeq[i].n_rows << ") is not equal to the dimensionality of " << "the HMM (" << hmm.Emission()[0].Dimensionality() << ")!" << endl; vector<arma::Row<size_t>> labelSeq; // May be empty. if (labelsFile != "") { // Do we have multiple label files to load? char lineBuf[1024]; if (batch) { fstream f(labelsFile); if (!f.is_open()) Log::Fatal << "Could not open '" << labelsFile << "' for reading." << endl; // Now read each line in. f.getline(lineBuf, 1024, '\n'); while (!f.eof()) { Log::Info << "Adding training sequence labels from '" << lineBuf << "'." << endl; // Now read the matrix. Mat<size_t> label; data::Load(lineBuf, label, true); // Fatal on failure. // Ensure that matrix only has one row. if (label.n_cols == 1) label = trans(label); if (label.n_rows > 1) Log::Fatal << "Invalid labels; must be one-dimensional." << endl; // Check all of the labels. for (size_t i = 0; i < label.n_cols; ++i) { if (label[i] >= hmm.Transition().n_cols) { Log::Fatal << "HMM has " << hmm.Transition().n_cols << " hidden " << "states, but label on line " << i << " of '" << lineBuf << "' is " << label[i] << " (should be between 0 and " << (hmm.Transition().n_cols - 1) << ")!" 
<< endl; } } labelSeq.push_back(label.row(0)); f.getline(lineBuf, 1024, '\n'); } f.close(); } else { Mat<size_t> label; data::Load(labelsFile, label, true); // Ensure that matrix only has one row. if (label.n_cols == 1) label = trans(label); if (label.n_rows > 1) Log::Fatal << "Invalid labels; must be one-dimensional." << endl; // Verify the same number of observations as the data. if (label.n_elem != trainSeq[labelSeq.size()].n_cols) Log::Fatal << "Label sequence " << labelSeq.size() << " does not have" << " the same number of points as observation sequence " << labelSeq.size() << "!" << endl; // Check all of the labels. for (size_t i = 0; i < label.n_cols; ++i) { if (label[i] >= hmm.Transition().n_cols) { Log::Fatal << "HMM has " << hmm.Transition().n_cols << " hidden " << "states, but label on line " << i << " of '" << labelsFile << "' is " << label[i] << " (should be between 0 and " << (hmm.Transition().n_cols - 1) << ")!" << endl; } } labelSeq.push_back(label.row(0)); } // Now perform the training with labels. hmm.Train(trainSeq, labelSeq); } else { // Perform unsupervised training. hmm.Train(trainSeq); } // Save the model. if (CLI::HasParam("output_model_file")) { const string modelFile = CLI::GetParam<string>("output_model_file"); SaveHMM(hmm, modelFile); } }