コード例 #1
0
  // Reads `filename` into the saveRestore utility and, only if that read
  // succeeded, restores the anInt member from the parameter named "anInt".
  //
  // Returns the result of saveRestore.ReadFile(); on failure anInt is left
  // untouched.
  bool LoadModel(std::string filename)
  {
    if (!saveRestore.ReadFile(filename))
      return false;

    anInt = saveRestore.LoadParameter(anInt, "anInt");
    return true;
  }
コード例 #2
0
// Predicts the most likely hidden-state sequence for one team/season
// observation file using a previously saved discrete HMM, and writes the
// predicted states to a CSV file next to the model.
//
// Returns 0 on success, and 1 if the model file cannot be read, the model is
// not a discrete HMM, or the observations are not one-dimensional.
int main()
{
  const string year = "2014";
  const string team = "ATL";
  const string col = "1";

  // Build every path from the same year/team/col settings so they cannot
  // drift apart. The observation file keeps the column-0 suffix that the
  // original hard-coded path used.
  // NOTE(review): confirm the input is meant to stay "col0" regardless of `col`.
  const string inputFile = "state_data/" + year + "/" + year + team +
      "_state_col0.csv";
  const string modelFile = "simulation_data/" + year + "/" + year + team +
      "sim_col" + col + ".xml";
  const string outputFile = "simulation_data/" + year + "/" + year + team +
      "sim_col" + col + "_data" + ".csv";

  // Load observations.
  mat dataSeq;
  data::Load(inputFile, dataSeq, true);

  // Load model, but first we have to determine its type.
  SaveRestoreUtility sr;
  if (!sr.ReadFile(modelFile))
  {
    std::cerr << "Could not read model file '" << modelFile << "'!"
        << std::endl;
    return 1;
  }

  string type;
  sr.LoadParameter(type, "hmm_type");

  // Only discrete HMMs are handled here. Stop explicitly instead of falling
  // through and saving an empty prediction file.
  if (type != "discrete")
  {
    std::cerr << "Unknown HMM type '" << type << "' in file '" << modelFile
        << "'; only discrete HMMs are supported!" << std::endl;
    return 1;
  }

  HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1));
  LoadHMM(hmm, sr);

  // A single column of observations is turned into a single row.
  if (dataSeq.n_cols == 1)
    dataSeq = trans(dataSeq);

  // Discrete HMMs need one-dimensional observations. Abort rather than
  // running Predict() on malformed input.
  if (dataSeq.n_rows > 1)
  {
    std::cerr << "Only one-dimensional discrete observations allowed for "
        << "discrete HMMs!" << std::endl;
    return 1;
  }

  arma::Col<size_t> sequence;
  hmm.Predict(dataSeq, sequence);

  // Save output.
  data::Save(outputFile, sequence, true);
  return 0;
}
コード例 #3
0
int main(int argc, char** argv)
{
  // Parse command line options.
  CLI::ParseCommandLine(argc, argv);

  // Set random seed.
  if (CLI::GetParam<int>("seed") != 0)
    RandomSeed((size_t) CLI::GetParam<int>("seed"));
  else
    RandomSeed((size_t) time(NULL));

  // Load observations.
  const string modelFile = CLI::GetParam<string>("model_file");
  const int length = CLI::GetParam<int>("length");
  const int startState = CLI::GetParam<int>("start_state");

  if (length <= 0)
  {
    Log::Fatal << "Invalid sequence length (" << length << "); must be greater "
        << "than or equal to 0!" << endl;
  }

  // Load model, but first we have to determine its type.
  SaveRestoreUtility sr;
  sr.ReadFile(modelFile);
  string emissionType;
  sr.LoadParameter(emissionType, "emission_type");

  mat observations;
  Col<size_t> sequence;
  if (emissionType == "DiscreteDistribution")
  {
    HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1));
    hmm.Load(sr);

    if (startState < 0 || startState >= (int) hmm.Transition().n_rows)
    {
      Log::Fatal << "Invalid start state (" << startState << "); must be "
          << "between 0 and number of states (" << hmm.Transition().n_rows
          << ")!" << endl;
    }

    hmm.Generate(size_t(length), observations, sequence, size_t(startState));
  }
  else if (emissionType == "GaussianDistribution")
  {
    HMM<GaussianDistribution> hmm(1, GaussianDistribution(1));
    hmm.Load(sr);

    if (startState < 0 || startState >= (int) hmm.Transition().n_rows)
    {
      Log::Fatal << "Invalid start state (" << startState << "); must be "
          << "between 0 and number of states (" << hmm.Transition().n_rows
          << ")!" << endl;
    }

    hmm.Generate(size_t(length), observations, sequence, size_t(startState));
  }
  else if (emissionType == "GMM")
  {
    HMM<GMM<> > hmm(1, GMM<>(1, 1));
    hmm.Load(sr);

    if (startState < 0 || startState >= (int) hmm.Transition().n_rows)
    {
      Log::Fatal << "Invalid start state (" << startState << "); must be "
          << "between 0 and number of states (" << hmm.Transition().n_rows
          << ")!" << endl;
    }

    hmm.Generate(size_t(length), observations, sequence, size_t(startState));
  }
  else
  {
    Log::Fatal << "Unknown HMM type '" << emissionType << "' in file '" << modelFile
        << "'!" << endl;
  }

  // Save observations.
  const string outputFile = CLI::GetParam<string>("output_file");
  data::Save(outputFile, observations, true);

  // Do we want to save the hidden sequence?
  const string sequenceFile = CLI::GetParam<string>("state_file");
  if (sequenceFile != "")
    data::Save(sequenceFile, sequence, true);
  
  return 0;
}
コード例 #4
0
// Computes the log-likelihood of an observation sequence under a saved HMM
// and prints it to standard output.
//
// The model file's "hmm_type" parameter selects the emission type:
// "discrete", "gaussian" or "gmm". The following are reported through
// Log::Fatal:
//   - an unreadable model file,
//   - an unknown type,
//   - observations whose dimensionality does not match the model.
int main(int argc, char** argv)
{
  // Parse command line options.
  CLI::ParseCommandLine(argc, argv);

  // Load observations.
  const string inputFile = CLI::GetParam<string>("input_file");
  const string modelFile = CLI::GetParam<string>("model_file");

  mat dataSeq;
  data::Load(inputFile, dataSeq, true);

  // Load model, but first we have to determine its type. Check the read
  // explicitly so a missing or unparsable model file is reported as such,
  // rather than as the failure of a later step.
  SaveRestoreUtility sr;
  if (!sr.ReadFile(modelFile))
    Log::Fatal << "Could not read model file '" << modelFile << "'!" << endl;

  string type;
  sr.LoadParameter(type, "hmm_type");

  double loglik = 0;
  if (type == "discrete")
  {
    HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1));

    LoadHMM(hmm, sr);

    // Verify only one row in observations; a single column is transposed
    // into a single row first.
    if (dataSeq.n_cols == 1)
      dataSeq = trans(dataSeq);

    if (dataSeq.n_rows > 1)
      Log::Fatal << "Only one-dimensional discrete observations allowed for "
          << "discrete HMMs!" << endl;

    loglik = hmm.LogLikelihood(dataSeq);
  }
  else if (type == "gaussian")
  {
    HMM<GaussianDistribution> hmm(1, GaussianDistribution(1));

    LoadHMM(hmm, sr);

    // Verify correct dimensionality against the first state's Gaussian mean.
    if (dataSeq.n_rows != hmm.Emission()[0].Mean().n_elem)
      Log::Fatal << "Observation dimensionality (" << dataSeq.n_rows << ") "
          << "does not match HMM Gaussian dimensionality ("
          << hmm.Emission()[0].Mean().n_elem << ")!" << endl;

    loglik = hmm.LogLikelihood(dataSeq);
  }
  else if (type == "gmm")
  {
    HMM<GMM<> > hmm(1, GMM<>(1, 1));

    LoadHMM(hmm, sr);

    // Verify correct dimensionality against the first state's GMM. The
    // message names the GMM; it previously said "Gaussian", copied from the
    // branch above.
    if (dataSeq.n_rows != hmm.Emission()[0].Dimensionality())
      Log::Fatal << "Observation dimensionality (" << dataSeq.n_rows << ") "
          << "does not match HMM GMM dimensionality ("
          << hmm.Emission()[0].Dimensionality() << ")!" << endl;

    loglik = hmm.LogLikelihood(dataSeq);
  }
  else
  {
    Log::Fatal << "Unknown HMM type '" << type << "' in file '" << modelFile
        << "'!" << endl;
  }

  cout << loglik << endl;
  return 0;
}
コード例 #5
0
 // Stores the anInt member in the saveRestore utility under the name "anInt",
 // then writes the utility's contents to `filename`.
 //
 // Returns whatever saveRestore.WriteFile() reports.
 bool SaveModel(std::string filename)
 {
   saveRestore.SaveParameter(anInt, "anInt");

   const bool written = saveRestore.WriteFile(filename);
   return written;
 }