Пример #1
0
  // Generate a random observation sequence (and its hidden state sequence)
  // from the given HMM, then hand the results back through the "output" and
  // "state" matrix parameters when those were requested.
  //
  // hmm: the model to sample from.
  // The second argument is unused.
  //
  // Both "start_state" and "length" are validated while still signed, so a
  // negative value is reported as such instead of wrapping around to a huge
  // unsigned number.  This relies on Log::Fatal aborting execution, exactly
  // as the start-state range check always has.
  static void Apply(HMMType& hmm, void* /* extraInfo */)
  {
    mat observations;
    Row<size_t> sequence;

    // Load the parameters as the signed values the user actually typed.
    const int startStateParam = CLI::GetParam<int>("start_state");
    const int lengthParam = CLI::GetParam<int>("length");

    // A negative length would otherwise be cast to an enormous size_t.
    if (lengthParam < 0)
      Log::Fatal << "Invalid sequence length (" << lengthParam << "); must be "
          << "non-negative!" << endl;

    Log::Info << "Generating sequence of length " << lengthParam << "..."
        << endl;

    // Reject negative start states explicitly, then range-check against the
    // number of rows of the transition matrix.
    if (startStateParam < 0 ||
        static_cast<size_t>(startStateParam) >= hmm.Transition().n_rows)
      Log::Fatal << "Invalid start state (" << startStateParam << "); must be "
          << "between 0 and number of states (" << hmm.Transition().n_rows
          << ")!" << endl;

    // Both values are known to be non-negative here, so the casts are safe.
    const size_t startState = static_cast<size_t>(startStateParam);
    const size_t length = static_cast<size_t>(lengthParam);

    hmm.Generate(length, observations, sequence, startState);

    // Now save the output.
    if (CLI::HasParam("output"))
      CLI::GetParam<mat>("output") = std::move(observations);

    // Do we want to save the hidden sequence?
    if (CLI::HasParam("state"))
      CLI::GetParam<Mat<size_t>>("state") = std::move(sequence);
  }
Пример #2
0
  // Generate a random observation sequence (and its hidden state sequence)
  // from the given HMM and save them to the files named by --output_file and
  // --state_file, when those parameters were passed.
  //
  // hmm: the model to sample from.
  // The second argument is unused.
  //
  // Both "start_state" and "length" are validated while still signed, so a
  // negative value is reported as such instead of wrapping around to a huge
  // unsigned number.  This relies on Log::Fatal aborting execution, exactly
  // as the start-state range check always has.
  static void Apply(HMMType& hmm, void* /* extraInfo */)
  {
    mat observations;
    Row<size_t> sequence;

    // Load the parameters as the signed values the user actually typed.
    const int startStateParam = CLI::GetParam<int>("start_state");
    const int lengthParam = CLI::GetParam<int>("length");
    const string outputFile = CLI::GetParam<string>("output_file");
    const string sequenceFile = CLI::GetParam<string>("state_file");

    // A negative length would otherwise be cast to an enormous size_t.
    if (lengthParam < 0)
      Log::Fatal << "Invalid sequence length (" << lengthParam << "); must be "
          << "non-negative!" << endl;

    Log::Info << "Generating sequence of length " << lengthParam << "..."
        << endl;

    // Reject negative start states explicitly, then range-check against the
    // number of rows of the transition matrix.
    if (startStateParam < 0 ||
        static_cast<size_t>(startStateParam) >= hmm.Transition().n_rows)
      Log::Fatal << "Invalid start state (" << startStateParam << "); must be "
          << "between 0 and number of states (" << hmm.Transition().n_rows
          << ")!" << endl;

    // Both values are known to be non-negative here, so the casts are safe.
    const size_t startState = static_cast<size_t>(startStateParam);
    const size_t length = static_cast<size_t>(lengthParam);

    hmm.Generate(length, observations, sequence, startState);

    // Now save the output.
    if (CLI::HasParam("output_file"))
      data::Save(outputFile, observations, true);

    // Do we want to save the hidden sequence?
    if (CLI::HasParam("state_file"))
      data::Save(sequenceFile, sequence, true);

    // Let the user know when the generated data went nowhere.
    if (outputFile == "" && sequenceFile == "")
      Log::Warn << "Neither --output_file nor --state_file are specified; no "
          << "output will be saved." << endl;
  }
Пример #3
0
  // Load an observation sequence from --input_file, run hmm.Predict() on it,
  // and write the resulting state sequence to --output_file.
  //
  // hmm: the model used for prediction; the dimensionality of its first
  //      emission distribution is taken as the expected observation
  //      dimensionality.
  // The second argument is unused.
  static void Apply(HMMType& hmm, void* /* extraInfo */)
  {
    // Read the observations.
    mat observed;
    data::Load(CLI::GetParam<string>("input_file"), observed, true);

    // Dimensionality the model expects each observation (column) to have.
    const auto modelDim = hmm.Emission()[0].Dimensionality();

    // A single-column input for a one-dimensional model is most likely a
    // sequence that was stored as a column; flip it into a row.
    const bool looksTransposed = (observed.n_cols == 1) && (modelDim == 1);
    if (looksTransposed)
    {
      Log::Info << "Data sequence appears to be transposed; correcting."
          << endl;
      observed = observed.t();
    }

    // The observations must now match the model's dimensionality.
    if (observed.n_rows != modelDim)
    {
      Log::Fatal << "Observation dimensionality (" << observed.n_rows << ") "
          << "does not match HMM Gaussian dimensionality ("
          << modelDim << ")!" << endl;
    }

    // Compute the predicted state sequence.
    arma::Col<size_t> states;
    hmm.Predict(observed, states);

    // Write the result out.
    data::Save(CLI::GetParam<string>("output_file"), states, true);
  }
Пример #4
0
  // Train the given HMM on the observation sequences in *trainSeqPtr and,
  // if --output_model_file was passed, save the trained model.
  //
  // hmm:         the model to train; its tolerance is overridden only when
  //              --tolerance was passed explicitly.
  // trainSeqPtr: the observation sequences, one matrix per sequence; the
  //              number of rows of each must equal the dimensionality of the
  //              HMM's first emission distribution.
  //
  // When --labels_file is given, training is supervised.  With --batch the
  // labels file is a list of label files (one path per line, matched to the
  // observation sequences in order); otherwise it is a single label file for
  // the first observation sequence.  Every label sequence must be
  // one-dimensional, have as many entries as its observation sequence has
  // columns, and contain only values below the number of hidden states.
  // Without --labels_file, training is unsupervised.
  //
  // Error handling relies on Log::Fatal aborting execution.
  static void Apply(HMMType& hmm, vector<mat>* trainSeqPtr)
  {
    const bool batch = CLI::HasParam("batch");
    const double tolerance = CLI::GetParam<double>("tolerance");

    // Do we need to replace the tolerance?
    if (CLI::HasParam("tolerance"))
      hmm.Tolerance() = tolerance;

    const string labelsFile = CLI::GetParam<string>("labels_file");

    // Verify that the dimensionality of our observations is the same as the
    // dimensionality of our HMM's emissions.
    vector<mat>& trainSeq = *trainSeqPtr;
    for (size_t i = 0; i < trainSeq.size(); ++i)
      if (trainSeq[i].n_rows != hmm.Emission()[0].Dimensionality())
        Log::Fatal << "Dimensionality of training sequence " << i << " ("
            << trainSeq[i].n_rows << ") is not equal to the dimensionality of "
            << "the HMM (" << hmm.Emission()[0].Dimensionality() << ")!"
            << endl;

    vector<arma::Row<size_t>> labelSeq; // May be empty.
    if (labelsFile != "")
    {
      const size_t numStates = hmm.Transition().n_cols;

      // Load one label file, validate it against the observation sequence it
      // belongs to (the one at index labelSeq.size()), and append it to
      // labelSeq.  Shared by the batch and single-file paths so both apply
      // exactly the same checks.
      auto addLabels = [&](const std::string& file)
      {
        Mat<size_t> label;
        data::Load(file, label, true); // Fatal on failure.

        // Ensure that matrix only has one row.
        if (label.n_cols == 1)
          label = trans(label);

        if (label.n_rows > 1)
          Log::Fatal << "Invalid labels; must be one-dimensional." << endl;

        // There must be an observation sequence to pair these labels with;
        // without this check the lookup below would read out of bounds.
        const size_t index = labelSeq.size();
        if (index >= trainSeq.size())
          Log::Fatal << "Label sequence " << index << " has no matching "
              << "observation sequence (only " << trainSeq.size()
              << " given)!" << endl;

        // Verify the same number of observations as the data.
        if (label.n_elem != trainSeq[index].n_cols)
          Log::Fatal << "Label sequence " << index << " does not have"
              << " the same number of points as observation sequence "
              << index << "!" << endl;

        // Check all of the labels.
        for (size_t i = 0; i < label.n_cols; ++i)
        {
          if (label[i] >= numStates)
          {
            Log::Fatal << "HMM has " << numStates << " hidden "
                << "states, but label on line " << i << " of '" << file
                << "' is " << label[i] << " (should be between 0 and "
                << (numStates - 1) << ")!" << endl;
          }
        }

        labelSeq.push_back(label.row(0));
      };

      // Do we have multiple label files to load?
      if (batch)
      {
        // Open read-only: a plain fstream defaults to read+write mode and
        // would fail on a list file that is not writable.
        std::ifstream f(labelsFile);

        if (!f.is_open())
          Log::Fatal << "Could not open '" << labelsFile << "' for reading."
              << endl;

        // Read one path per line.  Looping on getline() itself (rather than
        // on eof()) keeps the final entry even when the file does not end
        // with a newline, and std::string imposes no line-length limit.
        std::string line;
        while (std::getline(f, line))
        {
          // Tolerate Windows line endings and blank lines.
          if (!line.empty() && line.back() == '\r')
            line.pop_back();
          if (line.empty())
            continue;

          Log::Info << "Adding training sequence labels from '" << line
              << "'." << endl;

          addLabels(line);
        }
      }
      else
      {
        addLabels(labelsFile);
      }

      // Now perform the training with labels.
      hmm.Train(trainSeq, labelSeq);
    }
    else
    {
      // Perform unsupervised training.
      hmm.Train(trainSeq);
    }

    // Save the model.
    if (CLI::HasParam("output_model_file"))
    {
      const string modelFile = CLI::GetParam<string>("output_model_file");
      SaveHMM(hmm, modelFile);
    }
  }