예제 #1
0
파일: Agent.cpp 프로젝트: caomw/BBRL
// ===========================================================================
//	File-local helpers
// ===========================================================================
namespace
{
//   Build the agent described after the '--base_agent' flag.
//
//   Every argument following '--base_agent' is forwarded to
//   'Agent::parse()' behind a synthetic '--agent' flag, so that
//   '--base_agent <Class> <options...>' is parsed exactly like
//   '--agent <Class> <options...>' (from parameters only, never from file).
//
//   Throws 'parsing::ParsingException("--base_agent")' when the flag is
//   absent; any exception raised by the nested parsing is propagated.
//   The caller owns the returned agent.
Agent* parseBaseAgent(int argc, char* argv[])
{
     //   Locate '--base_agent'
     int flagPos = -1;
     for (int i = 1; i < argc; ++i)
     {
          if (string(argv[i]) == "--base_agent") { flagPos = i; break; }
     }
     if (flagPos < 0) { throw parsing::ParsingException("--base_agent"); }


     //   Build the argument list of the base agent.
     //   'agentFlag' and 'argvBase' both live until the nested parsing is
     //   over, and are released automatically (even on exception).
     char agentFlag[] = "--agent";

     vector<char*> argvBase;
     argvBase.reserve(2 + (argc - flagPos - 1));
     argvBase.push_back(argv[0]);
     argvBase.push_back(agentFlag);
     for (int j = flagPos + 1; j < argc; ++j) { argvBase.push_back(argv[j]); }


     //   Parse the base agent
     return Agent::parse(
               static_cast<int>(argvBase.size()), &argvBase[0], false);
}
}


// ===========================================================================
//	Public static methods
// ===========================================================================
//   Build an agent from the command line.
//
//   The class of the agent is given by '--agent'. If 'fromFile' is true and
//   '--agent_file' designates a readable file, the agent is deserialized
//   from it. If that attempt raises a 'ParsingException' (flag missing, file
//   unreadable, ...) and 'fromParameters' is true, the agent is built from
//   the class-specific command line options instead.
//
//   Throws 'parsing::ParsingException' when a required option is missing or
//   when the class name is not one of the classes handled below.
//   The caller owns the returned agent.
Agent* Agent::parse(int argc, char* argv[],
                    bool fromFile,
                    bool fromParameters) throw (parsing::ParsingException)
{
     //   Get 'agentClassName'
     string agentClassName = parsing::getValue(argc, argv, "--agent");


     //   'agentFile' provided
     try
     {
          if (!fromFile) { throw parsing::ParsingException("--agent"); }

          string agentFile = parsing::getValue(argc, argv, "--agent_file");

          ifstream is(agentFile.c_str());
          if (is.fail()) // Unable to open the file
               throw parsing::ParsingException("--agent_file");

          //   NOTE(review): yields a null pointer if the created instance
          //   is not an 'Agent' -- callers should check the result.
          return dynamic_cast<Agent*>(
                    Serializable::createInstance(agentClassName, is));
     }


     //   'agentFile' not provided (or unusable)
     catch (const parsing::ParsingException&)
     {
          if (!fromParameters) { throw parsing::ParsingException("--agent"); }


          //   Get 'agent'
          if (agentClassName == RandomAgent::toString())
               return new RandomAgent();

          if (agentClassName == OptimalAgent::toString())
               return new OptimalAgent();

          if (agentClassName == EAgent::toString())
          {
               //   Get 'epsilon'
               string tmp = parsing::getValue(argc, argv, "--epsilon");
               double epsilon = atof(tmp.c_str());


               //   Get 'base agent'
               Agent* baseAgent = parseBaseAgent(argc, argv);


               //   Return
               return new EAgent(epsilon, baseAgent);
          }

          if (agentClassName == EGreedyAgent::toString())
          {
               //   Get 'epsilon'
               string tmp = parsing::getValue(argc, argv, "--epsilon");
               double epsilon = atof(tmp.c_str());


               //   Return
               return new EGreedyAgent(epsilon);
          }

          if (agentClassName == SoftMaxAgent::toString())
          {
               //   Get 'tau'
               string tmp = parsing::getValue(argc, argv, "--tau");
               double tau = atof(tmp.c_str());


               //   Return
               return new SoftMaxAgent(tau);
          }

          if (agentClassName == VDBEEGreedyAgent::toString())
          {
               //   Get 'sigma'
               string tmp = parsing::getValue(argc, argv, "--sigma");
               double sigma = atof(tmp.c_str());


               //   Get 'delta'
               tmp = parsing::getValue(argc, argv, "--delta");
               double delta = atof(tmp.c_str());


               //   Get 'iniEpsilon'
               tmp = parsing::getValue(argc, argv, "--ini_epsilon");
               double iniEpsilon = atof(tmp.c_str());


               //   Return
               return new VDBEEGreedyAgent(sigma, delta, iniEpsilon);
          }

          if (agentClassName == FormulaAgent::toString())
          {
               //   Get the formula string
               string fStr = parsing::getValue(argc, argv, "--formula");


               //   Get 'varNameList'
               //   (the first value of '--variables' is the number of names
               //    that follow; it is dropped from the list)
               string tmp = parsing::getValue(argc, argv, "--variables");
               unsigned int nVar = atoi(tmp.c_str());

               vector<string> varNameList
                         = parsing::getValues(argc, argv,
                                              "--variables", (nVar + 1));
               varNameList.erase(varNameList.begin());


               //   Get 'f'
               //   (built last, so that a failure while parsing the
               //    variables cannot leak it)
               utils::formula::Formula* f = new utils::formula::Formula(fStr);


               //   Return
               return new FormulaAgent(f, varNameList);
          }

          if (agentClassName == BAMCPAgent::toString())
          {
               //   Get 'K'
               string tmp = parsing::getValue(argc, argv, "--K");
               unsigned int K = atoi(tmp.c_str());

               //   'D' is optional
               try
               {
                    //   Get 'D'
                    tmp = parsing::getValue(argc, argv, "--D");
                    unsigned int D = atoi(tmp.c_str());


                    //   Return
                    return new BAMCPAgent(K, D);
               }

               catch (const parsing::ParsingException&)
               {
                    return new BAMCPAgent(K);
               }
          }

          if (agentClassName == BFS3Agent::toString())
          {
               //   Get 'K'
               string tmp = parsing::getValue(argc, argv, "--K");
               unsigned int K = atoi(tmp.c_str());


               //   Get 'C'
               tmp = parsing::getValue(argc, argv, "--C");
               unsigned int C = atoi(tmp.c_str());


               //   'D' is optional
               try
               {
                    //   Get 'D'
                    tmp = parsing::getValue(argc, argv, "--D");
                    unsigned int D = atoi(tmp.c_str());


                    //   Return
                    return new BFS3Agent(K, C, D);
               }

               catch (const parsing::ParsingException&)
               {
                    return new BFS3Agent(K, C);
               }
          }

          if (agentClassName == SBOSSAgent::toString())
          {
               //   Get 'epsilon'
               string tmp = parsing::getValue(argc, argv, "--epsilon");
               double epsilon = atof(tmp.c_str());


               //   Get 'delta'
               tmp = parsing::getValue(argc, argv, "--delta");
               double delta = atof(tmp.c_str());


               //   Return
               return new SBOSSAgent(epsilon, delta);
          }

          if (agentClassName == BEBAgent::toString())
          {
               //   Get 'beta'
               string tmp = parsing::getValue(argc, argv, "--beta");
               double beta = atof(tmp.c_str());


               //   Return
               return new BEBAgent(beta);
          }

          if (agentClassName == OPPSDSAgent::toString())
          {
               //   Get 'nDraws'
               string tmp = parsing::getValue(argc, argv, "--n_draws");
               unsigned int nDraws = atoi(tmp.c_str());


               //   Get 'c'
               tmp = parsing::getValue(argc, argv, "--c");
               double c = atof(tmp.c_str());


               //   Get 'formulaVector'
               FormulaVector* formulaVector = FormulaVector::parse(argc, argv);
               assert(formulaVector);


               //   Get 'varNameList'
               //   (the first value of '--variables' is the number of names
               //    that follow; it is dropped from the list)
               tmp = parsing::getValue(argc, argv, "--variables");
               unsigned int nVar = atoi(tmp.c_str());

               vector<string> varNameList
                         = parsing::getValues(argc, argv,
                                              "--variables", (nVar + 1));
               varNameList.erase(varNameList.begin());


               //   Get 'gamma'
               tmp = parsing::getValue(argc, argv, "--discount_factor");
               double gamma = atof(tmp.c_str());


               //   Get 'T'
               tmp = parsing::getValue(argc, argv, "--horizon_limit");
               unsigned int T = atoi(tmp.c_str());


               //   Build 'strategyList'
                    //   Get 'mdpDistrib'
                    //   NOTE(review): 'mdpDistrib' is only checked, it is
                    //   neither used nor released below -- confirm whether
                    //   it can be deleted here.
               MDPDistribution* mdpDistrib = MDPDistribution::parse(argc, argv);
               assert(mdpDistrib);

                    //   Generate and store the strategies
               vector<Agent*> strategyList;
               for (unsigned int i = 0; i < formulaVector->size(); ++i)
                    strategyList.push_back(
                              new FormulaAgent((*formulaVector)[i],
                                               varNameList));

                    //   Free 'formulaVector'
               delete formulaVector;


               //   Return
               return new OPPSDSAgent(nDraws, c, strategyList, gamma, T);
          }

          if (agentClassName == OPPSCSAgent::toString())
          {
               //   Get 'n'
               string tmp = parsing::getValue(argc, argv, "--n_eval");
               unsigned int nEval = atoi(tmp.c_str());


               //   Get 'K'
               tmp = parsing::getValue(argc, argv, "--K");
               unsigned int K = atoi(tmp.c_str());


               //   Get 'agentFactory'
               AgentFactory* agentFactory =
                         AgentFactory::parse(argc, argv, true, false);
               assert(agentFactory);


               //   Get 'gamma'
               tmp = parsing::getValue(argc, argv, "--discount_factor");
               double gamma = atof(tmp.c_str());


               //   Get 'T'
               tmp = parsing::getValue(argc, argv, "--horizon_limit");
               unsigned int T = atoi(tmp.c_str());


               //   Check if 'k', 'hMax' & delta are specified
               bool hasSmallK = parsing::hasFlag(argc, argv, "--k");
               bool hasHMax   = parsing::hasFlag(argc, argv, "--h_max");
               bool hasDelta  = parsing::hasFlag(argc, argv, "--delta");


               //   Case 1:   'k', 'hMax' & 'delta' are specified
               if (hasSmallK && hasHMax && hasDelta)
               {
                    tmp = parsing::getValue(argc, argv, "--k");
                    unsigned int k = atoi(tmp.c_str());

                    tmp = parsing::getValue(argc, argv, "--h_max");
                    unsigned int hMax = atoi(tmp.c_str());

                    tmp = parsing::getValue(argc, argv, "--delta");
                    double delta = atof(tmp.c_str());

                    return new OPPSCSAgent(nEval, K, k, hMax, delta,
                                           agentFactory, gamma, T);
               }


               //  Case 2:    'k', 'hMax' & 'delta' are not specified
               else if (!hasSmallK && !hasHMax && !hasDelta)
                    return new OPPSCSAgent(nEval, K, agentFactory, gamma, T);


               //   Case 3:   Among 'k', 'hMax' & 'delta', at least one is
               //             specified and one is not specified
               else
               {
                    tmp = "Cannot define 'k', 'h_max' or 'delta' without ";
                    tmp += "defining the others!\n";
                    throw parsing::ParsingException(tmp);
               }
          }

          if (agentClassName == ANNAgent::toString())
          {
               //   Get 'hiddenLayers'
               //   (the first value of '--hidden_layers' is the number of
               //    layers; the following values are the layer sizes)
               string tmp = parsing::getValue(
                    argc, argv, "--hidden_layers");
               unsigned int nLayers = atoi(tmp.c_str());

               vector<string> values = parsing::getValues(
                    argc, argv, "--hidden_layers", nLayers + 1);

               vector<unsigned int> hiddenLayers;
               for (unsigned int i = 1; i < values.size(); ++i)
                    hiddenLayers.push_back(atoi(values[i].c_str()));


               //   Get 'learningRate'
               tmp = parsing::getValue(argc, argv, "--learning_rate");
               double learningRate = atof(tmp.c_str());


               //   Get 'decreasingLearningRate'
               bool decreasingLearningRate =
                    parsing::hasFlag(argc, argv, "--decreasing_learning_rate");


               //   Get 'maxEpoch'
               tmp = parsing::getValue(argc, argv, "--max_epoch");
               unsigned int maxEpoch = atoi(tmp.c_str());


               //   Get 'epochRange'
               tmp = parsing::getValue(argc, argv, "--epoch_range");
               unsigned int epochRange = atoi(tmp.c_str());


               //   Get 'nbOfMDPs'
               tmp = parsing::getValue(argc, argv, "--n_mdps");
               unsigned int nbOfMDPs = atoi(tmp.c_str());


               //   Get 'simGamma'
               tmp = parsing::getValue(argc, argv, "--discount_factor");
               double simGamma = atof(tmp.c_str());


               //   Get 'T'
               tmp = parsing::getValue(argc, argv, "--horizon_limit");
               unsigned int T = atoi(tmp.c_str());


               //   Get 'SLModelFileName'
               string SLModelFileName =
                         parsing::getValue(argc, argv, "--model_file");


               //   Get 'base agent'
               //   (parsed last, so that a missing option above cannot
               //    leak it)
               Agent* baseAgent = parseBaseAgent(argc, argv);

               //   A base agent which is itself a 'SLAgent' is rejected
               if (dynamic_cast<SLAgent*>(baseAgent) != 0)
               {
                    delete baseAgent;
                    throw parsing::ParsingException( "--base_agent");
               }


               //   Return
               return new ANNAgent(
                         hiddenLayers, learningRate, decreasingLearningRate,
                         maxEpoch, epochRange,
                         baseAgent, nbOfMDPs, simGamma, T, SLModelFileName);
          }
     }

     //   Unknown 'agentClassName'
     throw parsing::ParsingException("--agent");
}
int main(int argc, char* argv[]) {
  Matrix G;
  Matrix Y;
  Matrix Cov;

  LoadMatrix("input.mt.g", G);
  LoadMatrix("input.mt.y", Y);
  LoadMatrix("input.mt.cov", Cov);
  Cov.SetColumnLabel(0, "c1");
  Cov.SetColumnLabel(1, "c2");
  Y.SetColumnLabel(0, "y1");
  Y.SetColumnLabel(1, "y2");
  Y.SetColumnLabel(2, "y3");

  FormulaVector tests;
  {
    const char* tp1[] = {"y1"};
    const char* tc1[] = {"c1"};
    std::vector<std::string> p1(tp1, tp1 + 1);
    std::vector<std::string> c1(tc1, tc1 + 1);
    tests.add(p1, c1);
  }

  {
    const char* tp1[] = {"y2"};
    const char* tc1[] = {"c2"};
    std::vector<std::string> p1(tp1, tp1 + 1);
    std::vector<std::string> c1(tc1, tc1 + 1);
    tests.add(p1, c1);
  }

  {
    const char* tp1[] = {"y2"};
    const char* tc1[] = {"c1", "c2"};
    std::vector<std::string> p1(tp1, tp1 + 1);
    std::vector<std::string> c1(tc1, tc1 + 2);
    tests.add(p1, c1);
  }

  {
    const char* tp1[] = {"y1"};
    const char* tc1[] = {"1"};
    std::vector<std::string> p1(tp1, tp1 + 1);
    std::vector<std::string> c1(tc1, tc1 + 1);
    tests.add(p1, c1);
  }

  AccurateTimer t;
  {
    FastMultipleTraitLinearRegressionScoreTest mt(1024);

    bool ret = mt.FitNullModel(Cov, Y, tests);
    if (ret == false) {
      printf("Fit null model failed!\n");
      exit(1);
    }

    ret = mt.AddGenotype(G);
    if (ret == false) {
      printf("Add covariate failed!\n");
      exit(1);
    }
    ret = mt.TestCovariateBlock();
    if (ret == false) {
      printf("Test covariate block failed!\n");
      exit(1);
    }
    const Vector& u = mt.GetU(0);
    printf("u\t");
    Print(u);
    printf("\n");

    const Vector& v = mt.GetV(0);
    printf("v\t");
    Print(v);
    printf("\n");

    const Vector& pval = mt.GetPvalue(0);
    printf("pval\t");
    Print(pval);
    printf("\n");
  }
  return 0;
}
/**
 * Fit the null model of every test in @param tests.
 *
 * For each test, the phenotype column and the covariate columns it names are
 * extracted from @param pheno and @param cov, rows with missing values are
 * removed, and the per-test quantities (Z'Z inverse, Z'Y, sigma2) are
 * computed. Tests sharing the same covariate names and the same missingness
 * pattern are then merged into groups, and the per-group matrices are
 * prepared.
 *
 * @param cov    covariate matrix (columns are looked up by name)
 * @param pheno  phenotype matrix (columns are looked up by name)
 * @param tests  list of (phenotype, covariates) models; the covariate name
 *               "1" is skipped when collecting covariate columns
 * @return true on success; false when a test names no phenotype or an
 *         unknown phenotype, or when no sample is left for a test after
 *         removing rows with missing values
 */
bool MultipleTraitLinearRegressionScoreTest::FitNullModel(
    Matrix& cov, Matrix& pheno, const FormulaVector& tests) {
  MultipleTraitLinearRegressionScoreTestInternal& w = *this->work;
  // set some values
  w.N = pheno.rows;
  w.T = pheno.cols;
  w.C = cov.cols;
  w.M = -1;

  w.Y.resize(tests.size());
  w.Z.resize(tests.size());
  w.ZZinv.resize(tests.size());
  w.hasCovariate.resize(tests.size());
  w.missingIndex.resize(tests.size());
  w.Uyz.resize(tests.size());
  w.Ugz.resize(tests.size());
  w.Uyg.resize(tests.size());
  w.sigma2.resize(tests.size());
  w.nTest = tests.size();
  ustat.Dimension(blockSize, tests.size());
  vstat.Dimension(blockSize, tests.size());
  pvalue.Dimension(blockSize, tests.size());

  // create dict (key: phenotype/cov name, val: index)
  std::map<std::string, int> phenoDict;
  std::map<std::string, int> covDict;
  makeColNameToDict(pheno, &phenoDict);
  makeColNameToDict(cov, &covDict);

  // create Y, Z
  std::vector<std::string> phenoName;
  std::vector<std::string> covName;
  std::vector<int> phenoCol;
  std::vector<int> covCol;
  std::vector<std::vector<std::string> > allCovName;
  // arrange Y, Z according to missing pattern for each trait
  for (int i = 0; i < w.nTest; ++i) {
    // Look the phenotype up without operator[]: operator[] would silently
    // insert an unknown name and map it to column 0.
    phenoName = tests.getPhenotype(i);
    if (phenoName.empty()) {
      fprintf(stderr, "Test %d does not specify any phenotype!\n", i);
      return false;
    }
    const std::map<std::string, int>::const_iterator phenoIt =
        phenoDict.find(phenoName[0]);
    if (phenoIt == phenoDict.end()) {
      fprintf(stderr, "Cannot find phenotype [ %s ] for test %d!\n",
              phenoName[0].c_str(), i);
      return false;
    }
    phenoCol.clear();
    phenoCol.push_back(phenoIt->second);

    covName = tests.getCovariate(i);
    allCovName.push_back(covName);
    covCol.clear();
    for (size_t j = 0; j != covName.size(); ++j) {
      // "1" does not refer to a column of `cov`
      if (covName[j] == "1") {
        continue;
      }
      assert(covDict.count(covName[j]));
      covCol.push_back(covDict[covName[j]]);
    }

    w.hasCovariate[i] = covCol.size() > 0;
    makeMatrix(pheno, phenoCol, &w.Y[i]);
    if (w.hasCovariate[i]) {
      makeMatrix(cov, covCol, &w.Z[i]);
    }

    // create index to indicate missingness:
    // a row is missing when either its phenotype or (if any) one of its
    // covariates is missing
    w.missingIndex[i].resize(w.N);
    for (int j = 0; j < w.N; ++j) {
      if (hasMissingInRow(w.Y[i], j)) {
        w.missingIndex[i][j] = true;
        continue;
      } else {
        if (w.hasCovariate[i] && hasMissingInRow(w.Z[i], j)) {
          w.missingIndex[i][j] = true;
          continue;
        }
      }
      w.missingIndex[i][j] = false;
    }
    removeRow(w.missingIndex[i], &w.Y[i]);
    removeRow(w.missingIndex[i], &w.Z[i]);
    if (w.Y[i].rows() == 0) {
      fprintf(stderr, "Due to missingness, there is no sample to test!\n");
      // must be `false`: any non-zero integer would convert to `true`
      return false;
    }

    // center and scale Y, Z
    scale(&w.Y[i]);
    scale(&w.Z[i]);

    // calculate Uzy, inv(Z'Z)
    if (w.hasCovariate[i]) {
      w.ZZinv[i].noalias() =
          (w.Z[i].transpose() * w.Z[i])
              .ldlt()
              .solve(EMat::Identity(w.Z[i].cols(), w.Z[i].cols()));
      w.Uyz[i].noalias() = w.Z[i].transpose() * w.Y[i];
      w.sigma2[i] = (w.Y[i].transpose() * w.Y[i] -
                     w.Uyz[i].transpose() * w.ZZinv[i] * w.Uyz[i])(0, 0) /
                    w.Y[i].rows();
    } else {
      w.sigma2[i] = w.Y[i].col(0).squaredNorm() / w.Y[i].rows();
    }

  }  // end for i

  // Make groups based on model covariates and missing patterns of (Y, Z)
  // Detail:
  // For test: 1, 2, 3, ..., nTest, a possible grouping is:
  // (1, 3), (2), (4, 5) ...
  // =>
  // test_1 => group 0, offset 0
  // test_2 => group 1, offset 0
  // test_3 => group 0, offset 1
  //
  // For each test, we will use its specific
  // [covar_name_1, covar_name_2, ...., missing_pattern], as the value to
  // distinguish groups
  std::map<std::vector<std::string>, int> groupDict;
  // start from an empty grouping, so that fitting again does not append to
  // the groups of a previous fit
  group.clear();
  groupSize = 0;
  for (int i = 0; i < w.nTest; ++i) {
    std::vector<std::string> key = allCovName[i];
    key.push_back(toString(w.missingIndex[i]));
    if (0 == groupDict.count(key)) {
      groupDict[key] = groupSize;
      group.resize(groupSize + 1);
      group[groupSize].push_back(i);
      groupSize++;
    } else {
      group[groupDict[key]].push_back(i);
    }
  }
  // fprintf(stderr, "total %d missingness group\n", groupSize);

  w.G.resize(groupSize);
  w.groupedY.resize(groupSize);
  w.groupedZ.resize(groupSize);
  w.groupedUyz.resize(groupSize);
  w.groupedZZinv.resize(groupSize);
  w.groupedL.resize(groupSize);
  w.ustat.resize(groupSize);
  w.vstat.resize(groupSize);
  w.groupedHasCovariate.resize(groupSize);
  for (int i = 0; i < groupSize; ++i) {
    // one column of groupedY per test of the group; all tests of a group
    // share the covariates and missingness of the group's first test
    const int nc = group[i].size();
    const int nr = w.Y[group[i][0]].rows();
    w.groupedY[i].resize(nr, nc);
    for (int j = 0; j < nc; ++j) {
      w.groupedY[i].col(j) = w.Y[group[i][j]];
    }
    // initialize G
    w.G[i].resize(nr, blockSize);
    w.ustat[i].resize(blockSize, nc);
    w.vstat[i].resize(blockSize, 1);
    w.groupedZ[i] = w.Z[group[i][0]];
    w.groupedHasCovariate[i] = w.hasCovariate[group[i][0]];
    if (w.groupedHasCovariate[i]) {
      w.groupedUyz[i] = w.groupedZ[i].transpose() * w.groupedY[i];
    }
    w.groupedZZinv[i] = w.ZZinv[group[i][0]];
    Eigen::LLT<Eigen::MatrixXf> lltOfA(w.groupedZZinv[i]);
    // L * L' = A
    w.groupedL[i] = lltOfA.matrixL();

    // fprintf(stderr, "i = %d, group has covar = %s\n", i,
    //         w.groupedHasCovariate[i] ? "true" : "false");
  }
  // clean up memory: the per-test matrices are no longer needed once the
  // grouped ones are built
  w.Y.clear();
  w.Z.clear();
  w.Uyz.clear();
  w.hasCovariate.clear();
  w.ZZinv.clear();

  return true;
}