Exemplo n.º 1
0
// Trains an LDA model with EM on the corpus loaded from "lda_data", logs the
// fitted alpha, then runs inference on the held-out split and writes the
// resulting gamma and phi tables to the files "gamma" and "phi".
void App() {
  // Seed the Mersenne Twister from the wall clock. time() stores into a
  // time_t; the former `long t1; time(&t1);` only compiled where time_t
  // happens to be `long`.
  time_t t1 = time(nullptr);
  seedMT(t1);

  // EM / variational-inference settings.
  float em_converged = 1e-4;
  int em_max_iter = 20;
  int em_estimate_alpha = 1; // 1 = estimate alpha, 0 = keep the given value
  int var_max_iter = 30;
  double var_converged = 1e-6;
  double initial_alpha = 0.1;
  int n_topic = 30;
  LDA lda;
  lda.Init(em_converged, em_max_iter, em_estimate_alpha, var_max_iter,
                         var_converged, initial_alpha, n_topic);

  Corpus cor;
  //Str data = "../../data/ap.dat";
  Str data = "lda_data";
  cor.LoadData(data);

  // NOTE(review): p is presumably the fraction of documents that goes to
  // `train`; confirm against SplitData.
  Corpus train;
  Corpus test;
  double p = 0.8;
  SplitData(cor, p, &train, &test);

  // "seeded" is forwarded to RunEM as its type argument (presumably the
  // initialisation strategy — TODO confirm).
  Str type = "seeded";
  LdaModel m;
  lda.RunEM(type, train, test, &m);

  LOG(INFO) << m.alpha;

  // Infer on the test split and dump the results.
  VVReal gamma;
  VVVReal phi;
  lda.Infer(test, m, &gamma, &phi);
  WriteStrToFile(Join(gamma, " ", "\n"), "gamma");
  WriteStrToFile(Join(phi, " ", "\n", "\n\n"), "phi");
}
Exemplo n.º 2
0
/*
 * Command-line driver.
 *   lda est <initial alpha> <k> <settings> <data> <random/seeded/*> <directory>
 *   lda inf <settings> <model> <data> <name>
 * Each sub-command is only dispatched when all of its arguments are present;
 * otherwise (including for an unknown sub-command) the usage text is printed.
 * Always exits with status 0.
 */
int main(int argc, char* argv[]) {
	corpus* corpus;
	int show_usage = 1;
	/* time() stores into a time_t; a `long` is not portable here. */
	time_t t1 = time(NULL);

	seedMT(t1);
		// seedMT(4357U);

	/* "est" needs argv[1..7]: checking argc first prevents reading past argv. */
	if (argc >= 8 && strcmp(argv[1], "est")==0)
	{
		INITIAL_ALPHA = atof(argv[2]);
		NTOPICS = atoi(argv[3]);
		read_settings(argv[4]);
		corpus = read_data(argv[5]);
		make_directory(argv[7]);
		run_em(argv[6], argv[7], corpus);
		show_usage = 0;
	}
	/* "inf" needs argv[1..5]. */
	else if (argc >= 6 && strcmp(argv[1], "inf")==0)
	{
		read_settings(argv[2]);
		corpus = read_data(argv[4]);
		infer(argv[3], argv[5], corpus);
		show_usage = 0;
	}

	if (show_usage)
	{
		printf("usage : lda est [initial alpha] [k] [settings] [data] [random/seeded/*] [directory]\n");
		printf("        lda inf [settings] [model] [data] [name]\n");
	}
	return(0);
}
/*
 * Range check for randomMT(): draws 100 rounds of 2,000,002 values, divides
 * each by (2^32 - 1) and stops at the first quotient >= 1.0, printing
 * "Problem" and returning 0. If none is found it prints "Success" and
 * returns 1.
 */
int main(void) {
    /* you can seed with any uint32, but the best are odds in 0..(2^32 - 1) */
    seedMT(4357);
    uint32 limit = pow(2, 32) - 1;

    /* Disabled example: print the first 2,002 random numbers seven to a line
     *   for (j = 0; j < 2002; j++)
     *       printf(" %10lu%s", (unsigned long) randomMT(), (j % 7) == 6 ? "\n" : "");
     */

    int pass = 0;
    while (pass < 100) {
        int draw;
        for (draw = 0; draw < 2000002; ++draw) {
            double scaled = (double)randomMT() / (double)limit;
            if (!(scaled < 1.0)) {
                printf("Problem");
                return 0;
            }
            /* printf(" %f%s", scaled, (draw % 7) == 6 ? "\n" : ""); */
        }
        ++pass;
    }

    printf("Success");
    return 1;
}
Exemplo n.º 4
0
/*
 * Seed the random number generator used by this program.
 * With USE_MTWIST nonzero the Mersenne Twister is seeded via seedMT();
 * otherwise the C library generator is seeded via srand(). When DO_DEBUG is
 * nonzero the seed is also written to debug_file.
 */
void set_seed(unsigned long the_seed)
{
  /* use the_seed to seed the generator */

#if DO_DEBUG
  /* %lu matches the unsigned long argument (%ld was a format mismatch). */
  fprintf(debug_file,"set_seed: s= %lu\n",the_seed);
  fflush(debug_file);
#endif

#if USE_MTWIST
  seedMT(the_seed);
#else
  srand(the_seed);
#endif
} /* end set_seed */
Exemplo n.º 5
0
/*
 * ddnsd entry point: drops privileges, seeds the Mersenne Twister, reads a
 * single ddns request from stdin and dispatches it by request type.
 * argv[1] is the data directory; usage() is called when it is missing.
 */
int main(int argc, char **argv)
{
  struct ddnsrequest p = { 0 };
  stralloc username = {0};
  uint32 ttl;
  unsigned long long entropy;

  VERSIONINFO;

  /* chroot() to $ROOT and switch to $UID:$GID */
  droprootordie("ddnsd: ");

  // XXX should this be tightened ?
  umask(024);

  /* Seed some entropy into the MT. The sources are mixed with unsigned
     (wrap-around) arithmetic. Multiplying them as signed long long could
     overflow, which is undefined behaviour, and a single zero factor
     (e.g. clock() early in the process) zeroed the whole seed. */
  entropy = (unsigned long long) getpid();
  entropy = entropy * 0x9E3779B97F4A7C15ULL ^ (unsigned long long) time(0);
  entropy = entropy * 0x9E3779B97F4A7C15ULL ^ (unsigned long long) getppid();
  entropy = entropy * 0x9E3779B97F4A7C15ULL ^ (unsigned long long) random();
  entropy = entropy * 0x9E3779B97F4A7C15ULL ^ (unsigned long long) clock();
  /* fold the high half in, in case seedMT only keeps the low 32 bits */
  seedMT(entropy ^ (entropy >> 32));

  /* guard against argc == 0, where argv[1] would be out of bounds */
  datadir = (argc > 1) ? argv[1] : NULL;
  if (!datadir) usage();

  /* read one ddns packet from stdin which should be
     connected to the tcpstream */
  ddnsd_recive(&p, &ttl, &username);

  switch(p.type)
    {
    case DDNS_T_SETENTRY:
      ddnsd_setentry(&p, &ttl, &username);
      break;
    case DDNS_T_RENEWENTRY:
      ddnsd_renewentry(&p, &ttl, &username);
      break;
    case DDNS_T_KILLENTRY:
      ddnsd_killentry(&p, &ttl, &username);
      break;
    default:
      ddnsd_send_err(p.uid, DDNS_T_EPROTERROR, "unsupported type/command");
    }

  return 0;
}
Exemplo n.º 6
0
// Opens /dev/urandom (stored in `urandom`), seeds the Mersenne Twister, and
// registers the Lua "crypto" library: md5sum, tigersum, hmac (unless
// WITHOUT_OPENSSL), a "sauerecc" sub-table of ECC helpers, and the key /
// challenge / hmac classes. Registers shutdown_crypto to run on shutdown.
void open_crypto(lua_State * L)
{
    urandom = fopen("/dev/urandom","r");
    if(!urandom)
    {
        fprintf(stderr, "Crypto module warning: couldn't open /dev/urandom for reading -- will be using randomMT() instead!\n");
    }

    // Seed the MT unconditionally. It is the fallback source when
    // /dev/urandom is unavailable, so it must be seeded in that case too.
    // Previously it was seeded only when /dev/urandom opened successfully,
    // i.e. never in the case the warning above refers to.
    seedMT((unsigned int)getnanoseconds());

    static luaL_Reg functions[] = {
        {"md5sum", md5sum},
        {"tigersum", tigersum},
    #ifndef WITHOUT_OPENSSL
        {"hmac", hmac::create_object},
    #endif
        {NULL, NULL}
    };

    luaL_register(L, "crypto", functions);

    static luaL_Reg ecc_functions[] = {
        {"generate_key_pair", ecc::generate_key_pair},
        {"key", ecc::create_key},
        {"answer_challenge", ecc::answer_challenge},
        {NULL, NULL}
    };

    // Build the ECC table and store it as crypto.sauerecc; then pop the
    // "crypto" table that luaL_register left on the stack.
    lua_newtable(L);
    luaL_register(L, NULL, ecc_functions);
    lua_setfield(L, -2, "sauerecc");
    lua_pop(L, 1);

    ecc::key::register_class(L);
    ecc::challenge::register_class(L);

    #ifndef WITHOUT_OPENSSL
    hmac::register_class(L);
    #endif

    lua::on_shutdown(L, shutdown_crypto);
}
Exemplo n.º 7
0
// Runs variational EM for the relational topic model: loads the network and
// corpus named by FLAGS_net_path / FLAGS_cor_path, re-reads the network as
// `test`, and fits an RTM with FLAGS_topic_num topics and FLAGS_alpha.
void App() {
  // Seed the Mersenne Twister from the wall clock. time() stores into a
  // time_t; the former `long t1; time(&t1);` only compiled where time_t
  // happens to be `long`.
  time_t t1 = time(nullptr);
  seedMT(t1);

  // EM / variational-inference settings.
  float em_converged = 1e-4;
  int em_max_iter = 40;
  int em_estimate_alpha = 1; // 1 = estimate alpha, 0 = keep the given value
  int var_max_iter = 50;
  double var_converged = 1e-6;
  double initial_alpha = 0.1;
  VarRTM var;
  var.Init(em_converged, em_max_iter, em_estimate_alpha, var_max_iter,
                         var_converged, initial_alpha, FLAGS_topic_num);
  var.Load(FLAGS_net_path, FLAGS_cor_path);

  // NOTE(review): the evaluation matrix is read from the same file as the
  // training network (FLAGS_net_path) — confirm this is intended.
  SpMat test;
  ReadData(FLAGS_net_path, &test);
  RTM rtm(FLAGS_topic_num, FLAGS_alpha);
  var.RunEM(test, &rtm);
}
Exemplo n.º 8
0
/*
 * Entry point of the offline model learner.
 *
 * Takes exactly one argument, a log file, which is used as the playback
 * source (no recording). After seeding the Mersenne Twister with SEED, the
 * SLAM process is run by calling Slam() directly on this thread, not by
 * spawning a separate thread.
 * NOTE(review): if SLAM is ever moved to its own thread, map reads during
 * map updates should be guarded (semaphores or similar).
 */
int main (int argc, char *argv[])
{
  /* Disabled: carmen_warn("Random seed: %d\n", carmen_randomize(&argc, &argv)); */

  /* The log file to replay is the only accepted argument. */
  if (argc != 2) {
    carmen_die("Usage: model_learner <logfile>\n");
  }

  /* Nothing is recorded; the given log is played back. */
  RECORDING = (char*) "";
  PLAYBACK = argv[1];

  carmen_warn("********** World Initialization ***********\n");
  seedMT(SEED);

  Slam(NULL);
  return 0;
}
Exemplo n.º 9
0
// Restarts the server: disconnects, re-initializes the peer on `port` (bound
// to forceHostAddress), caps incoming connections at AllowedPlayers, and
// draws fresh `seed` / `nextSeed` values from the Mersenne Twister.
// Returns the result of RakPeer::Initialize.
// Note: `depreciated` is accepted for signature compatibility but unused.
bool RakServer::Start( unsigned short AllowedPlayers, unsigned int depreciated, int threadSleepTimer, unsigned short port, const char *forceHostAddress )
{
	RakPeer::Disconnect( 30 );

	const bool init = RakPeer::Initialize( AllowedPlayers, port, threadSleepTimer, forceHostAddress );
	RakPeer::SetMaximumIncomingConnections( AllowedPlayers );

	// Seed the generator from the current RakNet time, then draw both seeds.
	// An even draw is decremented by one so each seed ends up odd.
	long now = RakNet::GetTime();
	seedMT( now );

	seed = randomMT();
	seed = ( seed % 2 == 0 ) ? seed - 1 : seed;

	nextSeed = randomMT();
	nextSeed = ( nextSeed % 2 == 0 ) ? nextSeed - 1 : nextSeed;

	return init;
}
Exemplo n.º 10
0
// Restarts the server: disconnects, re-initializes the peer on `port`, caps
// incoming connections at AllowedPlayers, and draws fresh `seed` / `nextSeed`
// values from the Mersenne Twister. Returns the result of RakPeer::Initialize.
// Note: `connectionValidationInteger` is accepted but unused in this body.
bool RakServer::Start( unsigned short AllowedPlayers, unsigned long connectionValidationInteger, int threadSleepTimer, unsigned short port )
{
	RakPeer::Disconnect( 30L );

	const bool init = RakPeer::Initialize( AllowedPlayers, port, threadSleepTimer );
	RakPeer::SetMaximumIncomingConnections( AllowedPlayers );

	// Seed the generator from the current RakNet time, then draw both seeds.
	// An even draw is decremented by one so each seed ends up odd.
	long now = RakNet::GetTime();
	seedMT( now );

	seed = randomMT();
	seed = ( seed % 2 == 0 ) ? seed - 1 : seed;

	nextSeed = randomMT();
	nextSeed = ( nextSeed % 2 == 0 ) ? nextSeed - 1 : nextSeed;

	return init;
}
Exemplo n.º 11
0
    void mexFunction(int nlhs, mxArray *plhs[], int nrhs,
                     const mxArray *prhs[])
    {
      double *srwp, *srdp, *probs, *Z, *WS, *DS, *ZIN;
      double ALPHA,BETA;
      mwIndex *irwp, *jcwp, *irdp, *jcdp;
      int *z,*d,*w, *order, *wp, *dp, *ztot;
      int W,T,D,NN,SEED,OUTPUT, nzmax, nzmaxwp, nzmaxdp, ntokens;
      int i,j,c,n,nt,wi,di, startcond;

      /* Check for proper number of arguments. */
      if (nrhs < 8) {
          mexPrintf("1");
        mexErrMsgTxt("At least 8 input arguments required");
      } else if (nlhs < 3) {
          mexPrintf("2");
        mexErrMsgTxt("3 output arguments required");
      }

      startcond = 0;
      if (nrhs == 9) {
          mexPrintf("3");
          startcond = 1;
      }

      /* process the input arguments */
      if (mxIsDouble( prhs[ 0 ] ) != 1) mexErrMsgTxt("WS input vector must be a double precision matrix");
      if (mxIsDouble( prhs[ 1 ] ) != 1) mexErrMsgTxt("DS input vector must be a double precision matrix");

      // pointer to word indices
      WS = mxGetPr( prhs[ 0 ] );

      // pointer to document indices
      DS = mxGetPr( prhs[ 1 ] );

      // get the number of tokens
      ntokens = mxGetM( prhs[ 0 ] ) * mxGetN( prhs[ 0 ] );


      if (ntokens == 0){
          mexPrintf("4");
          mexErrMsgTxt("WS vector is empty");
      }
      if (ntokens != ( mxGetM( prhs[ 1 ] ) * mxGetN( prhs[ 1 ] ))) mexErrMsgTxt("WS and DS vectors should have same number of entries");

      T    = (int) mxGetScalar(prhs[2]);
      if (T<=0){
          mexPrintf("5");
          mexErrMsgTxt("Number of topics must be greater than zero");
      }

      NN    = (int) mxGetScalar(prhs[3]);
      if (NN<0) mexErrMsgTxt("Number of iterations must be positive");

      ALPHA = (double) mxGetScalar(prhs[4]);
      if (ALPHA<=0) mexErrMsgTxt("ALPHA must be greater than zero");

      BETA = (double) mxGetScalar(prhs[5]);
      if (BETA<=0) mexErrMsgTxt("BETA must be greater than zero");

      SEED = (int) mxGetScalar(prhs[6]);

      OUTPUT = (int) mxGetScalar(prhs[7]);
        mexPrintf("ok");
      if (startcond == 1) {
          mexPrintf("6");
          ZIN = mxGetPr( prhs[ 8 ] );
          if (ntokens != ( mxGetM( prhs[ 8 ] ) * mxGetN( prhs[ 8 ] ))) mexErrMsgTxt("WS and ZIN vectors should have same number of entries");
      }

      // seeding
      seedMT( 1 + SEED * 2 ); // seeding only works on uneven numbers



    /* allocate memory */
      z  = (int *) mxCalloc( ntokens , sizeof( int ));

      if (startcond == 1) {
         for (i=0; i<ntokens; i++) z[ i ] = (int) ZIN[ i ] - 1;
      }

      d  = (int *) mxCalloc( ntokens , sizeof( int ));
      w  = (int *) mxCalloc( ntokens , sizeof( int ));
      order  = (int *) mxCalloc( ntokens , sizeof( int ));
      ztot  = (int *) mxCalloc( T , sizeof( int ));
      probs  = (double *) mxCalloc( T , sizeof( double ));

      // copy over the word and document indices into internal format
      for (i=0; i<ntokens; i++) {
         w[ i ] = (int) WS[ i ] - 1;
         d[ i ] = (int) DS[ i ] - 1;
      }

      n = ntokens;

      W = 0;
      D = 0;
      for (i=0; i<n; i++) {
         if (w[ i ] > W) W = w[ i ];
         if (d[ i ] > D) D = d[ i ];
      }
      W = W + 1;
      D = D + 1;

      wp  = (int *) mxCalloc( T*W , sizeof( int ));
      dp  = (int *) mxCalloc( T*D , sizeof( int ));

      //mexPrintf( "N=%d  T=%d W=%d D=%d\n" , ntokens , T , W , D );

      if (OUTPUT==2) {
          mexPrintf("7");
          mexPrintf( "Running LDA Gibbs Sampler Version 1.0\n" );
          if (startcond==1) mexPrintf( "Starting from previous state ZIN\n" );
          mexPrintf( "Arguments:\n" );
          mexPrintf( "\tNumber of words      W = %d\n"    , W );
          mexPrintf( "\tNumber of docs       D = %d\n"    , D );
          mexPrintf( "\tNumber of topics     T = %d\n"    , T );
          mexPrintf( "\tNumber of iterations N = %d\n"    , NN );
          mexPrintf( "\tHyperparameter   ALPHA = %4.4f\n" , ALPHA );
          mexPrintf( "\tHyperparameter    BETA = %4.4f\n" , BETA );
          mexPrintf( "\tSeed number            = %d\n"    , SEED );
          mexPrintf( "\tNumber of tokens       = %d\n"    , ntokens );
          mexPrintf( "Internal Memory Allocation\n" );
          mexPrintf( "\tw,d,z,order indices combined = %d bytes\n" , 4 * sizeof( int) * ntokens );
          mexPrintf( "\twp (full) matrix = %d bytes\n" , sizeof( int ) * W * T  );
          mexPrintf( "\tdp (full) matrix = %d bytes\n" , sizeof( int ) * D * T  );
          //mexPrintf( "Checking: sizeof(int)=%d sizeof(long)=%d sizeof(double)=%d\n" , sizeof(int) , sizeof(long) , sizeof(double));
      }

      /* run the model */
      GibbsSamplerLDA( ALPHA, BETA, W, T, D, NN, OUTPUT, n, z, d, w, wp, dp, ztot, order, probs, startcond );

      /* convert the full wp matrix into a sparse matrix */
      nzmaxwp = 0;
      for (i=0; i<W; i++) {
         for (j=0; j<T; j++)
             nzmaxwp += (int) ( *( wp + j + i*T )) > 0;
      }
      /*if (OUTPUT==2) {
          mexPrintf( "Constructing sparse output matrix wp\n" );
          mexPrintf( "Number of nonzero entries for WP = %d\n" , nzmaxwp );
      }*/

      // MAKE THE WP SPARSE MATRIX
      plhs[0] = mxCreateSparse( W,T,nzmaxwp,mxREAL);
      srwp  = mxGetPr(plhs[0]);
      irwp = mxGetIr(plhs[0]);
      jcwp = mxGetJc(plhs[0]);
      n = 0;
      for (j=0; j<T; j++) {
          *( jcwp + j ) = n;
          for (i=0; i<W; i++) {
             c = (int) *( wp + i*T + j );
             if (c >0) {
                 *( srwp + n ) = c;
                 *( irwp + n ) = i;
                 n++;
             }
          }
      }
      *( jcwp + T ) = n;

      // MAKE THE DP SPARSE MATRIX
      nzmaxdp = 0;
      for (i=0; i<D; i++) {
          for (j=0; j<T; j++)
              nzmaxdp += (int) ( *( dp + j + i*T )) > 0;
      }
      /*if (OUTPUT==2) {
          mexPrintf( "Constructing sparse output matrix dp\n" );
          mexPrintf( "Number of nonzero entries for DP = %d\n" , nzmaxdp );
      } */
      plhs[1] = mxCreateSparse( D,T,nzmaxdp,mxREAL);
      srdp  = mxGetPr(plhs[1]);
      irdp = mxGetIr(plhs[1]);
      jcdp = mxGetJc(plhs[1]);
      n = 0;
      for (j=0; j<T; j++) {
          *( jcdp + j ) = n;
          for (i=0; i<D; i++) {
              c = (int) *( dp + i*T + j );
              if (c >0) {
                  *( srdp + n ) = c;
                  *( irdp + n ) = i;
                  n++;
              }
          }
      }
      *( jcdp + T ) = n;

      plhs[ 2 ] = mxCreateDoubleMatrix( 1,ntokens , mxREAL );
      Z = mxGetPr( plhs[ 2 ] );
      for (i=0; i<ntokens; i++) Z[ i ] = (double) z[ i ] + 1;
    }
/*
 * MEX gateway: Gibbs sampler for the "special word" retrieval model.
 * Every query word i of document d is assigned one of three routes
 * (0 -> PQ0, 1 -> PQ1, 2 -> PQ2). Routes are sampled with weights
 * (route count in d + GAMMA[route]) * P(word | route).
 *
 * Inputs : PQ0, PQ1 (NQ x D), PQ2 (1 x NQ), GAMMA (1 x 3),
 *          BURNIN (>= 1), NS (>= 1), LAG (>= 1), SEED, OUTPUT.
 * Outputs: X  (D x 3)       route counts accumulated over samples, normalised per doc
 *          XD (D x 3 x NQ)  per-word route counts accumulated over samples, normalised
 *          PQ (NQ x D)      X[d,0]*PQ0 + X[d,1]*PQ1 + X[d,2]*PQ2
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs,
                 const mxArray *prhs[])
{
  double *PQ0, *PQ1, *PQ2, *GAMMA, *X, *XD, *PQ;
  double probs[ 3 ];
  double totweight, pwz, pwd, pwc, proute0, proute1, proute2, prob, rn, max;
  mwSize dims[ 3 ];   // mxCreateNumericArray takes const mwSize*, not int*
  int *x, *xtot;
  int NQ, D, BURNIN, S, NS, NN, LAG, SEED, OUTPUT, i, iter, d, route, r, oldroute, newroute;

  // Check for proper number of arguments.
  if (nrhs < 9) {
    mexErrMsgTxt("At least 9 input arguments required");
  } else if (nlhs < 3) {
    mexErrMsgTxt("3 output arguments required");
  }

  // process the input arguments
  if (mxIsDouble( prhs[ 0 ] ) != 1) mexErrMsgTxt("PQ0 input matrix must be a double precision matrix");
  if (mxIsDouble( prhs[ 1 ] ) != 1) mexErrMsgTxt("PQ1 input matrix must be a double precision matrix");
  if (mxIsDouble( prhs[ 2 ] ) != 1) mexErrMsgTxt("PQ2 input matrix must be a double precision matrix");
  if (mxIsDouble( prhs[ 3 ] ) != 1) mexErrMsgTxt("GAMMA input vector must be a double precision matrix");

  // pointer to PQ0 matrix
  PQ0 = mxGetPr( prhs[ 0 ] );
  NQ = mxGetM( prhs[ 0 ] );
  D  = mxGetN( prhs[ 0 ] );

  // pointer to PQ1 matrix
  PQ1 = mxGetPr( prhs[ 1 ] );
  if ( mxGetM( prhs[ 1 ] ) != NQ ) mexErrMsgTxt("PQ1 matrix should have same dimensions as PQ0 matrix");
  if ( mxGetN( prhs[ 1 ] ) != D )  mexErrMsgTxt("PQ1 matrix should have same dimensions as PQ0 matrix");

  // pointer to PQ2 matrix
  PQ2 = mxGetPr( prhs[ 2 ] );
  if ( mxGetM( prhs[ 2 ] ) != 1 )  mexErrMsgTxt("PQ2 matrix should be a row vector");
  if ( mxGetN( prhs[ 2 ] ) != NQ )  mexErrMsgTxt("PQ2 matrix should same number of query words as PQ0 matrix");

  // pointer to gamma hyperparameter vector
  GAMMA = mxGetPr( prhs[ 3 ] );
  if ( mxGetM( prhs[ 3 ] ) != 1 ) mexErrMsgTxt("GAMMA input vector must be a row vector");
  if ( mxGetN( prhs[ 3 ] ) != 3 ) mexErrMsgTxt("GAMMA input vector must have three entries");

  BURNIN    = (int) mxGetScalar(prhs[4]);
  if (BURNIN<1) mexErrMsgTxt("Number of burnin iterations must be positive");

  NS    = (int) mxGetScalar(prhs[5]);
  if (NS<1) mexErrMsgTxt("Number of samples must be greater than zero");

  LAG    = (int) mxGetScalar(prhs[6]);
  if (LAG<1) mexErrMsgTxt("Lag must be greater than zero");

  // seeding
  SEED = (int) mxGetScalar(prhs[7]);
  seedMT( 1 + SEED * 2 ); // seeding only works on uneven numbers

  OUTPUT = (int) mxGetScalar(prhs[8]);

  // create output matrices (zero-initialised by MATLAB)
  plhs[ 0 ] = mxCreateDoubleMatrix( D , 3, mxREAL );
  X = mxGetPr( plhs[ 0 ] );

  dims[ 0 ] = D;
  dims[ 1 ] = 3;
  dims[ 2 ] = NQ;

  plhs[ 1 ] = mxCreateNumericArray( 3,dims,mxDOUBLE_CLASS,mxREAL);
  XD = mxGetPr( plhs[ 1 ] );

  plhs[ 2 ] = mxCreateDoubleMatrix( NQ , D, mxREAL );
  PQ = mxGetPr( plhs[ 2 ] );

  if (OUTPUT==2)
  {
      mexPrintf( "Running Special Word Retrieval Gibbs Sampler Version 1.0\n" );
      mexPrintf( "Arguments:\n" );
      mexPrintf( "\tNumber of docs          D = %d\n" , D );
      mexPrintf( "\tNumber of query words  NQ = %d\n" , NQ );
      mexPrintf( "\tHyperparameter     GAMMA0 = %4.4f\n" , GAMMA[0] );
      mexPrintf( "\tHyperparameter     GAMMA1 = %4.4f\n" , GAMMA[1] );
      mexPrintf( "\tHyperparameter     GAMMA2 = %4.4f\n" , GAMMA[2] );
      mexPrintf( "\tSeed number          SEED = %d\n" , SEED );
      mexPrintf( "\tBurnin             BURNIN = %d\n" , BURNIN );
      mexPrintf( "\tNumber of samples      NS = %d\n" , NS );
      mexPrintf( "\tLag between samples   LAG = %d\n" , LAG );
  }

  // allocate memory for sampling arrays:
  // x[i + d*NQ] = route of word i in doc d; xtot[r + d*3] = count of route r in doc d
  x  = (int *) mxCalloc( NQ * D , sizeof( int ));
  xtot = (int *) mxCalloc( 3 * D , sizeof( int ));

  // initialize the sampler
  for (d=0; d<D; d++)
  {
      for (i=0; i<NQ; i++)
      {
          // pick a random route between 0 and 2
          route = (int) ( (double) randomMT() * (double) 3 / (double) (4294967296.0 + 1.0) );

          // assign this query word to this route
          x[ i + d * NQ ] = route;

          // update total
          xtot[ route + d * 3 ] += 1;
      }
  }

  // RUN THE GIBBS SAMPLER
  NN = BURNIN + NS * LAG - LAG + 1;
  S = 0;
  for (iter=0; iter<NN; iter++)
  {
      if (OUTPUT >=1)
      {
          if ((iter % 10)==0) mexPrintf( "\tIteration %d of %d\n" , iter , NN );
          if ((iter % 10)==0) mexEvalString("drawnow;");
      }

      // do an iteration
      for (d=0; d<D; d++) // loop over all docs
      {
          for (i=0; i<NQ; i++) // loop over all words in query
          {
              oldroute = x[ i + d * NQ ];
              xtot[ oldroute + d * 3 ]--; // subtract this from the counts

              pwz = PQ0[ i + d * NQ ];
              pwd = PQ1[ i + d * NQ ];
              pwc = PQ2[ i ];

              probs[ 0 ] = ((double) xtot[ 0 + d * 3 ] + (double) GAMMA[ 0 ] ) * pwz;
              probs[ 1 ] = ((double) xtot[ 1 + d * 3 ] + (double) GAMMA[ 1 ] ) * pwd;
              probs[ 2 ] = ((double) xtot[ 2 + d * 3 ] + (double) GAMMA[ 2 ] ) * pwc;

              totweight = probs[ 0 ] + probs[ 1 ] + probs[ 2 ];

              // sample a route from this distribution. The newroute < 2 bound
              // keeps the index inside probs[] / xtot even when negative or
              // NaN weights would otherwise walk past the last entry.
              rn = (double) totweight * (double) randomMT() / (double) 4294967296.0;
              max = probs[0];
              newroute = 0;
              while (rn>max && newroute<2) {
                  newroute++;
                  max += probs[newroute];
              }

              x[ i + d * NQ ] = newroute;
              xtot[ newroute + d * 3 ]++; // add this to the counts
          }
      }

      // after burn-in, record a sample every LAG iterations
      if ((iter >= BURNIN) && (((iter - BURNIN) % LAG) == 0))
      {
          S++;

          // update the return variables with the new counts
          for (d=0; d<D; d++)
          {
              X[ d + 0 * D ] += (double) xtot[ 0 + d * 3 ] + GAMMA[ 0 ];
              X[ d + 1 * D ] += (double) xtot[ 1 + d * 3 ] + GAMMA[ 1 ];
              X[ d + 2 * D ] += (double) xtot[ 2 + d * 3 ] + GAMMA[ 2 ];

              for (i=0; i<NQ; i++)
              {
                  route = x[ i + d * NQ ]; // current route assignment

                  for (r=0; r<3; r++)
                  {
                      if (r==route)
                      {
                          XD[ d + r*D + i*D*3 ] += (double) 1 + GAMMA[ r ];  // D x 3 x NQ
                      } else
                      {
                          XD[ d + r*D + i*D*3 ] += (double) GAMMA[ r ];  // D x 3 x NQ
                      }
                  }
              }
          }
      }
  }

  // NOW CALCULATE PROBABILITY DISTRIBUTIONS
  for (d=0; d<D; d++)
  {
     totweight = 0;
     for (route=0; route<3; route++) totweight += X[ d + route * D ];
     for (route=0; route<3; route++) X[ d + route * D ] /= totweight;

     for (i=0; i<NQ; i++)
     {
        totweight = 0;
        for (route=0; route<3; route++) totweight += XD[ d + route*D + i*D*3 ];
        for (route=0; route<3; route++) XD[ d + route*D + i*D*3 ] /= totweight;
     }
  }

  // NOW CALCULATE RETRIEVAL PROBABILITY
  for (d=0; d<D; d++)
  {
     // the following is correct with a model that has lambda outside the plate -- one route probability for each query
     proute0 = X[ d + 0 * D ];
     proute1 = X[ d + 1 * D ];
     proute2 = X[ d + 2 * D ];

     for (i=0; i<NQ; i++)
     {
        // the following is correct with a model that has lambda inside the plate -- a separate probability for each word
        //proute0 = XD[ d + 0 * D + i * D * 3 ];
        //proute1 = XD[ d + 1 * D + i * D * 3 ];
        //proute2 = XD[ d + 2 * D + i * D * 3 ];

        pwz = PQ0[ i + d * NQ ];
        pwd = PQ1[ i + d * NQ ];
        pwc = PQ2[ i ];

        prob = proute0 * pwz + proute1 * pwd + proute2 * pwc;

        PQ[ i + d * NQ ] = prob;
     }
  }

  // release scratch buffers (MATLAB would also reclaim them at exit)
  mxFree( x );
  mxFree( xtot );
}
Exemplo n.º 13
0
//
// Start of main program.
// IF things are set to read from a robot's sensors and not a data log, then this would be the best place
// to actually put in controls for the robot's behaviors and actions. The main SLAM process runs in a
// separate thread spawned from this function.
//
// Options:
//   -R          record to "current.log"
//   -r <file>   record to <file>
//   -P          play back "current.log"
//   -p <file>   play back <file>
// With no playback source, InitializeRobot() is called to use the live robot.
//
int main (int argc, char *argv[])
{
  //char command[256], tempString[20];
  int x;
  //int y;
  //double maxDist, tempDist, tempAngle;
  int WANDER, EXPLORE, DIRECT_COMMAND;
  pthread_t slam_thread;

  RECORDING = "";
  PLAYBACK = "";
  for (x = 1; x < argc; x++) {
    if (!strncmp(argv[x], "-R", 2))
      RECORDING = "current.log";
    if (!strncmp(argv[x], "-r", 2)) {
      // the file name is the next argument; refuse to read past argv
      if (x + 1 >= argc) {
        fprintf(stderr, "Option -r requires a file name\n");
        return -1;
      }
      x++;
      RECORDING = argv[x];
    }
    else if (!strncmp(argv[x], "-p", 2)) {
      if (x + 1 >= argc) {
        fprintf(stderr, "Option -p requires a file name\n");
        return -1;
      }
      x++;
      PLAYBACK = argv[x];
    }
    else if (!strncmp(argv[x], "-P", 2))
      PLAYBACK = "current.log";
  }

  fprintf(stderr, "********** Localization Example *************\n");
  // Test for an empty string by content. `PLAYBACK == ""` compared pointer
  // addresses and only worked if the compiler merged identical literals.
  if (PLAYBACK[0] == '\0')
    if (InitializeRobot(argc, argv) == -1)
      return -1;

  fprintf(stderr, "********** World Initialization ***********\n");

  seedMT(SEED);
  // Spawn off a separate thread to do SLAM
  //
  // Should use semaphores or similar to prevent reading of the map
  // during updates to the map.
  //
  continueSlam = 1;
  pthread_create(&slam_thread, (pthread_attr_t *) NULL, Slam, &x);

  fprintf(stderr, "*********** Main Loop (Movement) **********\n");


  // This is the spot where code should be inserted to control the robot. You can go ahead and assume
  // that the robot is localizing and mapping.
  WANDER = 0;
  EXPLORE = 0;
  DIRECT_COMMAND = 0;
  RotationSpeed = 0.0;
  TranslationSpeed = 0.0;

  // Some very crude commands designed to give manual control over our ATRV Jr
  // Removed now for convenience and efficiency, since we're running from data logs right now.
  /*
  while (1) {
    // Was there a character pressed?
    //    gets(command);
    scanf("%s", command);
    if (command != NULL) {
      if ((PLAYBACK_COMPLETE) || (strncmp(command, "quit", 4) == 0)) {
	if (PLAYBACK[0] == '\0') {
	  //stop the robot
	  TranslationSpeed = 0.0;
	  RotationSpeed = 0.0;
	  Drive(TranslationSpeed, RotationSpeed);
	}

	// kill the other thread
	continueSlam = 0;
	pthread_join(slam_thread, NULL);

        return 0;
      }

      else if (strncmp(command, "speed", 5) == 0) {
        strncpy(tempString, index(command, ' '), 10);
        TranslationSpeed = atof(tempString);
	if (TranslationSpeed > 0.5)
	  TranslationSpeed = 0.5;
	RotationSpeed = 0;
	DIRECT_COMMAND = 1;
      }

      else if (strncmp(command, "turn", 4) == 0) {
        strncpy(tempString, index(command, ' '), 10);
        RotationSpeed = atof(tempString);
	if (RotationSpeed > 0.6)
	  RotationSpeed = 0.6;
	TranslationSpeed = 0;
	DIRECT_COMMAND = 1;
      }

      else if (strncmp(command, "stop", 4) == 0) {
	TranslationSpeed = 0.0;
	RotationSpeed = 0.0;
	Drive(TranslationSpeed, RotationSpeed);
	DIRECT_COMMAND = 0;
      }

      else if (strncmp(command, "print", 5) == 0) {
	y = 0;
	for (x = 0; x < cur_particles_used; x++)
	  if (particle[x].probability > particle[y].probability)
	    y = x;
	PrintMap(MAP_PATH_NAME, particle[y].parent, FALSE, -1, -1, -1);
      }

      else if (strncmp(command, "particles", 9) == 0) {
	y = 0;
	for (x = 0; x < cur_particles_used; x++)
	  if (particle[x].probability > particle[y].probability)
	    y = x;
	PrintMap(PARTICLES_PATH_NAME, particle[y].parent, TRUE, -1, -1, -1);
      }

      else if (strncmp(command, "overlay", 7) == 0) {
	y = 0;
	for (x = 0; x < cur_particles_used; x++)
	  if (particle[x].probability > particle[y].probability)
	    y = x;
	PrintMap(MAP_PATH_NAME, particle[y].parent, FALSE, particle[y].x, particle[y].y, particle[y].theta);
      }

      else if (strncmp(command, "centerx ", 8) == 0) {
        strncpy(tempString, index(command, ' '), 10);
        scat_center_x = atof(tempString);
      }

      else if (strncmp(command, "centery ", 8) == 0) {
        strncpy(tempString, index(command, ' '), 10);
        scat_center_y = atof(tempString);
      }

      else if (strncmp(command, "radius ", 7) == 0) {
        strncpy(tempString, index(command, ' '), 10);
        scat_radius = atof(tempString);
      }

      //else {
      //fprintf(stderr, "I don't understand you.\n");
      //}
    }

    if ((DIRECT_COMMAND == 1) && (PLAYBACK[0] == '\0')) {
      Drive(TranslationSpeed, RotationSpeed);
    }

    else if (PLAYBACK[0] == '\0') {
      // stop the robot
      TranslationSpeed = 0.0;
      RotationSpeed = 0.0;
      Drive(TranslationSpeed, RotationSpeed);
    }

  }
  */

  pthread_join(slam_thread, NULL);
  return 0;
}
Exemplo n.º 14
0
int main(int argc, char * args[]) {
	UNUSED(argc);
	UNUSED(args);
	init_bob_rand();
    seedMT(4357U);
	srand(time(NULL));
	s_lcrand(time(NULL));
	s_rand_qpr(0, 0);
	int selected_mode = 999999;
	while(selected_mode>10){
		printf("Select random function by pressing number key\n1:Windows random \t2:LFSR \t3:Mersenne twister"
		"\t4BobByrtle \t5Linear Congruential\t6Random with array \t7xorShift \t8Quadratic Resides"
		"\t9Concatenate 16 \t10Tausworth\n");
		selected_mode = get_int_input();
	}
    SDL_Init(SDL_INIT_VIDEO);

    SDL_Window * sdlWindow = SDL_CreateWindow("Visualize PRNG",
        SDL_WINDOWPOS_UNDEFINED,
        SDL_WINDOWPOS_UNDEFINED,
        640, 480,
        SDL_WINDOW_OPENGL);


    if (sdlWindow == NULL) {
        printf("Could not create window: %s\n", SDL_GetError());
        return 1;
    }

    SDL_Renderer * sdlRenderer = SDL_CreateRenderer(sdlWindow, -1, 0);

    if (sdlRenderer == NULL) {
        printf("Could not create renderer: %s\n", SDL_GetError());
        return 1;
    }

    SDL_RenderClear(sdlRenderer);
    SDL_RenderPresent(sdlRenderer);

    SDL_Texture * sdlTexture = SDL_CreateTexture(sdlRenderer,
        SDL_PIXELFORMAT_RGBA8888,
        SDL_TEXTUREACCESS_STREAMING,
        640, 480);
    if (sdlTexture == NULL) {
        printf("Could not create texture: %s\n", SDL_GetError());
        return 1;
    }
	SDL_SetTextureBlendMode(sdlTexture, SDL_BLENDMODE_BLEND);
	
	
	uint8 * data = (uint8 * ) malloc(sizeof(uint8) * 640 * 480);
    uint32 * pixels = (uint32 * ) malloc(sizeof(uint32) * 640 * 480);
    memset(pixels, 0, 640 * 480 * sizeof(uint32));
	memset(data, 0, 640 * 480 * sizeof(uint8));
	int cnt=5999999;
	uint32 * itable =createdata(cnt,selected_mode);
	
	CameraData camData;
	camData.xshift=0;
	camData.yshift=0;
	camData.zoom=1;
	camData.phata=0;
	SDL_PixelFormat* format = SDL_AllocFormat(SDL_PIXELFORMAT_RGBA8888);

	
	
	int arr_size = 640 * 480;
	bool keep_running = true;
    while (keep_running) {
		memset(pixels, 0, 640 * 480 * sizeof(uint32));
		update_data(data,640,480,itable,cnt,&camData);
		for(int i=0;i<arr_size;++i){
			uint8 data_value = data[i];
			if(data_value){
				pixels[i] = SDL_MapRGBA( format, 0xFF, 0xFF, 0xFF, data_value );

			}
		}
		SDL_UpdateTexture(sdlTexture, NULL, pixels, 640 * sizeof(Uint32));
        SDL_RenderClear(sdlRenderer);
        SDL_RenderCopy(sdlRenderer, sdlTexture, NULL, NULL);
        SDL_RenderPresent(sdlRenderer);
        SDL_Delay(20);
		keep_running = handle_events(&camData);
	}
    SDL_DestroyWindow(sdlWindow);

    SDL_Quit();

    free(pixels);
    SDL_DestroyTexture(sdlTexture);
    SDL_DestroyRenderer(sdlRenderer);
    return 0;
}
int DroppedConnectionConvertTest::RunTest(DataStructures::List<RakString> params,bool isVerbose,bool noPauses)
{

	RakPeerInterface *server;
	RakPeerInterface *clients[NUMBER_OF_CLIENTS];
	unsigned index, connectionCount;
	SystemAddress serverID;
	Packet *p;
	unsigned short numberOfSystems;
	unsigned short numberOfSystems2;
	int sender;

	// Buffer for input (an ugly hack to keep *nix happy)
	//	char buff[256];

	// Used to refer to systems.  We already know the IP
	unsigned short serverPort = 20000;
	serverID.binaryAddress=inet_addr("127.0.0.1");
	serverID.port=serverPort;

	server=RakPeerInterface::GetInstance();
	destroyList.Clear(false,_FILE_AND_LINE_);
	destroyList.Push(server,_FILE_AND_LINE_);
	//	server->InitializeSecurity(0,0,0,0);
	SocketDescriptor socketDescriptor(serverPort,0);
	server->Startup(NUMBER_OF_CLIENTS, &socketDescriptor, 1);
	server->SetMaximumIncomingConnections(NUMBER_OF_CLIENTS);
	server->SetTimeoutTime(2000,UNASSIGNED_SYSTEM_ADDRESS);

	for (index=0; index < NUMBER_OF_CLIENTS; index++)
	{
		clients[index]=RakPeerInterface::GetInstance();
		destroyList.Push(clients[index],_FILE_AND_LINE_);
		SocketDescriptor socketDescriptor2(serverPort+1+index,0);
		clients[index]->Startup(1, &socketDescriptor2, 1);
		if (clients[index]->Connect("127.0.0.1", serverPort, 0, 0)!=CONNECTION_ATTEMPT_STARTED)
		{
			DebugTools::ShowError("Connect function failed.",!noPauses && isVerbose,__LINE__,__FILE__);
			return 2;

		}
		clients[index]->SetTimeoutTime(5000,UNASSIGNED_SYSTEM_ADDRESS);

		RakSleep(1000);
		if (isVerbose)
			printf("%i. ", index);
	}

	TimeMS entryTime=GetTimeMS();//Loop entry time

	int seed = 12345;
	if (isVerbose)
		printf("Using seed %i\n", seed);
	seedMT(seed);//specify seed to keep execution path the same.

	int randomTest;

	bool dropTest=false;
	RakTimer timeoutWaitTimer(1000);

	while (GetTimeMS()-entryTime<30000)//run for 30 seconds.
	{
		// User input

		randomTest=randomMT() %4;

		if(dropTest)
		{

			server->GetConnectionList(0, &numberOfSystems);
			numberOfSystems2=numberOfSystems;

			connectionCount=0;
			for (index=0; index < NUMBER_OF_CLIENTS; index++)
			{
				clients[index]->GetConnectionList(0, &numberOfSystems);
				if (numberOfSystems>1)
				{
					if (isVerbose)
					{
						printf("Client %i has %i connections\n", index, numberOfSystems);
						DebugTools::ShowError("Client has more than one connection",!noPauses && isVerbose,__LINE__,__FILE__);
						return 1;
					}

				}
				if (numberOfSystems==1)
				{
					connectionCount++;
				}
			}

			if (connectionCount!=numberOfSystems2)
			{
				if (isVerbose)
					DebugTools::ShowError("Timeout on dropped clients not detected",!noPauses && isVerbose,__LINE__,__FILE__);
				return 3;
			}

		}
		dropTest=false;

		switch(randomTest)
		{

		case 0:
			{
				index = randomMT() % NUMBER_OF_CLIENTS;

				clients[index]->GetConnectionList(0, &numberOfSystems);
				clients[index]->CloseConnection(serverID, false,0);
				if (numberOfSystems==0)
				{
					if (isVerbose)
						printf("Client %i silently closing inactive connection.\n",index);
				}
				else
				{
					if (isVerbose)
						printf("Client %i silently closing active connection.\n",index);
				}
			}

			break;
		case 1:
			{
				index = randomMT() % NUMBER_OF_CLIENTS;

				clients[index]->GetConnectionList(0, &numberOfSystems);

				if(!CommonFunctions::ConnectionStateMatchesOptions (clients[index],serverID,true,true,true,true) )//Are we connected or is there a pending operation ?
				{
					if (clients[index]->Connect("127.0.0.1", serverPort, 0, 0)!=CONNECTION_ATTEMPT_STARTED)
					{

						DebugTools::ShowError("Connect function failed.",!noPauses && isVerbose,__LINE__,__FILE__);
						return 2;

					}
				}
				if (numberOfSystems==0)
				{
					if (isVerbose)
						printf("Client %i connecting to same existing connection.\n",index);

				}
				else
				{
					if (isVerbose)
						printf("Client %i connecting to closed connection.\n",index);
				}
			}

			break;
		case 2:
			{

				if (isVerbose)
					printf("Randomly connecting and disconnecting each client\n");
				for (index=0; index < NUMBER_OF_CLIENTS; index++)
				{
					if (NUMBER_OF_CLIENTS==1 || (randomMT()%2)==0)
					{
						if (clients[index]->IsActive())
						{

							int randomTest2=randomMT() %2;
							if (randomTest2)
								clients[index]->CloseConnection(serverID, false, 0);
							else
								clients[index]->CloseConnection(serverID, true, 0);
						}
					}
					else
					{
						if(!CommonFunctions::ConnectionStateMatchesOptions (clients[index],serverID,true,true,true,true) )//Are we connected or is there a pending operation ?
						{
							if (clients[index]->Connect("127.0.0.1", serverPort, 0, 0)!=CONNECTION_ATTEMPT_STARTED)
							{
								DebugTools::ShowError("Connect function failed.",!noPauses && isVerbose,__LINE__,__FILE__);
								return 2;

							}
						}
					}
				}
			}
			break;

		case 3:
			{

				if (isVerbose)
					printf("Testing if clients dropped after timeout.\n");
				timeoutWaitTimer.Start();
						//Wait half the timeout time, the other half after receive so we don't drop all connections only missing ones, Active ait so the threads run on linux
				while (!timeoutWaitTimer.IsExpired())
				{
				RakSleep(50);
				}
				dropTest=true;

			}
			break;
		default:
			// Ignore anything else
			break;
		}

		server->GetConnectionList(0, &numberOfSystems);
		numberOfSystems2=numberOfSystems;
		if (isVerbose)
			printf("The server thinks %i clients are connected.\n", numberOfSystems);
		connectionCount=0;
		for (index=0; index < NUMBER_OF_CLIENTS; index++)
		{
			clients[index]->GetConnectionList(0, &numberOfSystems);
			if (numberOfSystems>1)
			{
				if (isVerbose)
				{
					printf("Client %i has %i connections\n", index, numberOfSystems);
					DebugTools::ShowError("Client has more than one connection",!noPauses && isVerbose,__LINE__,__FILE__);
					return 1;
				}

			}
			if (numberOfSystems==1)
			{
				connectionCount++;
			}
		}

		if (isVerbose)
			printf("%i clients are actually connected.\n", connectionCount);
		if (isVerbose)
			printf("server->NumberOfConnections==%i.\n", server->NumberOfConnections());

		//}

		// Parse messages

		while (1)
		{
			p = server->Receive();
			sender=NUMBER_OF_CLIENTS;
			if (p==0)
			{
				for (index=0; index < NUMBER_OF_CLIENTS; index++)
				{
					p = clients[index]->Receive();
					if (p!=0)
					{
						sender=index;
						break;						
					}
				}
			}

			if (p)
			{
				switch (p->data[0])
				{
				case ID_CONNECTION_REQUEST_ACCEPTED:
					if (isVerbose)
						printf("%i: %ID_CONNECTION_REQUEST_ACCEPTED from %i.\n",sender, p->systemAddress.port);
					break;
				case ID_DISCONNECTION_NOTIFICATION:
					// Connection lost normally
					if (isVerbose)
						printf("%i: ID_DISCONNECTION_NOTIFICATION from %i.\n",sender, p->systemAddress.port);
					break;

				case ID_NEW_INCOMING_CONNECTION:
					// Somebody connected.  We have their IP now
					if (isVerbose)
						printf("%i: ID_NEW_INCOMING_CONNECTION from %i.\n",sender, p->systemAddress.port);
					break;


				case ID_CONNECTION_LOST:
					// Couldn't deliver a reliable packet - i.e. the other system was abnormally
					// terminated
					if (isVerbose)
						printf("%i: ID_CONNECTION_LOST from %i.\n",sender, p->systemAddress.port);
					break;

				case ID_NO_FREE_INCOMING_CONNECTIONS:
					if (isVerbose)
						printf("%i: ID_NO_FREE_INCOMING_CONNECTIONS from %i.\n",sender, p->systemAddress.port);
					break;

				default:
					// Ignore anything else
					break;
				}
			}
			else
				break;

			if (sender==NUMBER_OF_CLIENTS)
				server->DeallocatePacket(p);
			else
				clients[sender]->DeallocatePacket(p);
		}
		if (dropTest)
		{
			//Trigger the timeout if no recieve
			timeoutWaitTimer.Start();
			while (!timeoutWaitTimer.IsExpired())
			{
			RakSleep(50);
			}
		}
		// 11/29/05 - No longer necessary since I added the keepalive
		/*
		// Have everyone send a reliable packet so dropped connections are noticed.
		ch=255;
		server->Send((char*)&ch, 1, HIGH_PRIORITY, RELIABLE, 0, UNASSIGNED_SYSTEM_ADDRESS, true);

		for (index=0; index < NUMBER_OF_CLIENTS; index++)
		clients[index]->Send((char*)&ch, 1, HIGH_PRIORITY, RELIABLE, 0, UNASSIGNED_SYSTEM_ADDRESS, true);
		*/

		// Sleep so this loop doesn't take up all the CPU time

		RakSleep(10);

	}

	return 0;
}
Exemplo n.º 16
0
void mexFunction(int nlhs, mxArray *plhs[], int nrhs,
                 const mxArray *prhs[])
{
  double *srwp, *srdp, *srmp, *probs, *d, *w, *ZIN, *XIN, *Z, *X;
  double ALPHA,BETA, GAMMA;
  mwIndex *irwp, *jcwp, *irdp, *jcdp, *irmp, *jcmp;
  int *z,*x, *wp, *mp, *dp, *ztot, *mtot, *stot, *sp, *first, *second, *third;
  int W,T,S,SINPUT,D,NN,SEED,OUTPUT, nzmax, nzmaxwp,nzmaxmp, nzmaxdp, ntokens;
  int NWS,NDS,i,j,c,n,wi,di,i1,i2,i3,i4,S2,S3, startcond;
  
  // Syntax
  //   [ WP , DP , MP , Z , X ] = GibbsSamplerHMMLDA( WS , DS , T , S , N , ALPHA , BETA , GAMMA , SEED , OUTPUT , ZIN , XIN )

  /* Check for proper number of arguments. */
  if (nrhs < 10) {
    mexErrMsgTxt("At least 10 input arguments required");
  } else if (nlhs != 5) {
    mexErrMsgTxt("5 output arguments required");
  }
  
  startcond = 0;
  if (nrhs > 10) startcond = 1;
  
  if (sizeof( int ) != sizeof( mxINT32_CLASS ))
    mexErrMsgTxt("Problem with internal integer representation -- contact programmer" );
  
  w   = mxGetPr( prhs[0] );
  NWS = mxGetN( prhs[0] );
  
  d   = mxGetPr( prhs[1] );
  NDS = mxGetN( prhs[1] );
  
  if ((mxIsSparse( prhs[ 0 ] ) == 1) || (mxIsDouble( prhs[ 0 ] ) != 1) || (mxGetM( prhs[0]) != 1))
      mexErrMsgTxt("WS input stream must be a one-dimensional, non-sparse, double precision vector of word indices");
 
  if ((mxIsSparse( prhs[ 1 ] ) == 1) || (mxIsDouble( prhs[ 1 ] ) != 1) || (mxGetM( prhs[1]) != 1))
      mexErrMsgTxt("DS input stream must be a one-dimensional, non-sparse, double precision vector of document indices");
  
  if (NWS != NDS)
      mexErrMsgTxt("WS and DS input streams must have equal dimensions" );
    
  ntokens = NWS;
  
  T    = (int) mxGetScalar(prhs[2]);
  if (T<=0) mexErrMsgTxt("Number of topics must be greater than zero");
  
  SINPUT    = (int) mxGetScalar(prhs[3]);
  if (SINPUT<=0) mexErrMsgTxt("Number of syntactic states must be greater than zero");
  
  // need one syntactic state for sentence start 
  // need one syntactic state for topic state
  S = SINPUT + 2;
  
  NN    = (int) mxGetScalar(prhs[4]);
  if (NN<0) mexErrMsgTxt("Number of iterations must be greater than zero");
  
  ALPHA = (double) mxGetScalar(prhs[5]);
  if (ALPHA<=0) mexErrMsgTxt("ALPHA must be greater than zero");
  
  BETA = (double) mxGetScalar(prhs[6]);
  if (BETA<=0) mexErrMsgTxt("BETA must be greater than zero");
  
  GAMMA = (double) mxGetScalar(prhs[7]);
  if (GAMMA<=0) mexErrMsgTxt("GAMMA must be greater than zero");
  
  SEED = (int) mxGetScalar(prhs[8]);
  // set the seed of the random number generator
  
  OUTPUT = (int) mxGetScalar(prhs[9]);
  
  if (startcond==1) {
      if ((mxGetN( prhs[ 10 ] )*mxGetM( prhs[ 10 ] )) != ntokens) mexErrMsgTxt( "The ZIN vector should have have the same number of tokens as WS" );
      if ((mxGetN( prhs[ 11 ] )*mxGetM( prhs[ 11 ] )) != ntokens) mexErrMsgTxt( "The ZIN vector should have have the same number of tokens as WS" );
      
      ZIN = mxGetPr( prhs[ 10 ]);
      XIN = mxGetPr( prhs[ 11 ]);
  }
  
  // seeding
  seedMT( 1 + SEED * 2 ); // seeding only works on uneven numbers
  
  W       = 0;
  D       = 0;
  for (i=0; i<ntokens; i++) {
      // in Matlab WI=1...W word occurence, WI=0 sentence marker  
      wi = (int) w[ i ] - 1; // word indices are not zero based in Matlab
      di = (int) d[ i ] - 1; // document indices are not zero based in Matlab
      
      //mexPrintf( "i=%d wi=%d di=%d\n" , i , wi , di );
      
      if (wi > W) W = wi;
      if (di > D) D = di;
      
      if (wi < -1) mexErrMsgTxt("Unrecognized code in word stream (<-2)");
      if (di <  0) mexErrMsgTxt("Unrecognized code in document stream (<0)");
  }
  W = W + 1;
  D = D + 1;
 
  if ( wi != -1 )
       mexErrMsgTxt("Word stream should end with sentence marker");
  
  n = ntokens;
  //mexPrintf( "W=%d D=%d n=%d w[n-1]=%d\n" , W , D , n , (int) w[ n-1 ] );
  
  
  // allocate memory 
  z  = (int *) mxCalloc( ntokens , sizeof( int ));
  x  = (int *) mxCalloc( ntokens , sizeof( int ));
  
  if (startcond==1) {
     for (i=0; i<ntokens; i++) {
        z[ i ] = (int) ZIN[ i ] - 1;
        x[ i ] = (int) XIN[ i ] - 1;
     }
  }
  
  first   = (int *) mxCalloc( ntokens , sizeof( int ));
  second  = (int *) mxCalloc( ntokens , sizeof( int ));
  third   = (int *) mxCalloc( ntokens , sizeof( int ));
  mp  = (int *) mxCalloc( W*S , sizeof( int ));
  wp  = (int *) mxCalloc( T*W , sizeof( int ));
  dp  = (int *) mxCalloc( T*D , sizeof( int ));
  ztot  = (int *) mxCalloc( T , sizeof( int ));
  mtot  = (int *) mxCalloc( S , sizeof( int ));
  stot  = (int *) mxCalloc( S * S * S , sizeof( int ));
  sp    = (int *) mxCalloc( S * S * S * S , sizeof( int ));
  probs  = (double *) mxCalloc( T+S , sizeof( double ));
  
  
  //for (i=0; i<n; i++) mexPrintf( "i=%4d w[i]=%3d d[i]=%d\n" , i , w[i] , d[i] , z[i] );    
  
  if (OUTPUT==2) {
      mexPrintf( "Running HMM-LDA Gibbs Sampler Version 1.0\n" );
      if (startcond==1) mexPrintf( "Starting sampler from previous state\n" );
      mexPrintf( "Arguments:\n" );
      mexPrintf( "\tNumber of words            W = %d\n" , W );
      mexPrintf( "\tNumber of docs             D = %d\n" , D );
      mexPrintf( "\tNumber of topics           T = %d\n" , T );
      mexPrintf( "\tNumber of syntactic states S = %d\n" , S );
      mexPrintf( "\tNumber of iterations       N = %d\n" , NN );
      mexPrintf( "\tHyperparameter         ALPHA = %4.4f\n" , ALPHA );
      mexPrintf( "\tHyperparameter          BETA = %4.4f\n" , BETA );
      mexPrintf( "\tHyperparameter         GAMMA = %4.4f\n" , GAMMA );
      mexPrintf( "\tSeed number             SEED = %d\n" , SEED );
      mexPrintf( "Properties of WS stream\n" );
      mexPrintf( "\tNumber of tokens       NZ = %d\n" , ntokens );
      mexPrintf( "Internal Memory Allocation\n" );
      mexPrintf( "\tz,x,first,second,third indices combined = %d bytes\n" , 5 * sizeof( int) * ntokens );
      mexPrintf( "\twp   matrix = %d bytes\n" , sizeof( int ) * W * T  );
      mexPrintf( "\tmp   matrix = %d bytes\n" , sizeof( int ) * W * S  );
      mexPrintf( "\tdp   matrix = %d bytes\n" , sizeof( int ) * D * T  );
      mexPrintf( "\tstot matrix = %d bytes\n" , sizeof( int ) * S * S * S );
      mexPrintf( "\tsp   matrix = %d bytes\n" , sizeof( int ) * S * S * S * S );
      //mexPrintf( "Checking: sizeof(int)=%d sizeof(long)=%d sizeof(double)=%d\n" , sizeof(int) , sizeof(long) , sizeof(double));
  }
  
  // run the model 
  GibbsSamplerHMMLDA( ALPHA, BETA, GAMMA, W, T, S , D, NN, OUTPUT, n, z, x, d, w, wp, mp, dp, ztot, mtot, stot , sp , first,second,third, probs, startcond );
  
  // convert the full wp matrix into a sparse matrix 
  nzmaxwp = 0;
  for (i=0; i<W; i++) {
     for (j=0; j<T; j++)
         nzmaxwp += (int) ( *( wp + j + i*T )) > 0;
  }
  if (OUTPUT==2) mexPrintf( "Constructing sparse output matrix WP  nnz=%d\n" , nzmaxwp );

  plhs[0] = mxCreateSparse( W,T,nzmaxwp,mxREAL);
  srwp  = mxGetPr(plhs[0]);
  irwp = mxGetIr(plhs[0]);
  jcwp = mxGetJc(plhs[0]); 
  n = 0;
  for (j=0; j<T; j++) {
      *( jcwp + j ) = n;
      for (i=0; i<W; i++) {
         c = (int) *( wp + i*T + j );
         if (c >0) {
             *( srwp + n ) = c;
             *( irwp + n ) = i;
             n++;
         }
      }    
  }
  
  *( jcwp + T ) = n;    
   
  // processing DP as sparse output matrix 
  nzmaxdp = 0;
  for (i=0; i<D; i++) {
      for (j=0; j<T; j++)
          nzmaxdp += (int) ( *( dp + j + i*T )) > 0;
  }  
  if (OUTPUT==2) mexPrintf( "Constructing sparse output matrix DP  nnz=%d\n" , nzmaxdp );
  
  plhs[1] = mxCreateSparse( D,T,nzmaxdp,mxREAL);
  srdp  = mxGetPr(plhs[1]);
  irdp = mxGetIr(plhs[1]);
  jcdp = mxGetJc(plhs[1]);
  
  n = 0;
  for (j=0; j<T; j++) {
      *( jcdp + j ) = n;
      for (i=0; i<D; i++) {
          c = (int) *( dp + i*T + j );
          if (c >0) {
              *( srdp + n ) = c;
              *( irdp + n ) = i;
              n++;
          }
      }
  }
  *( jcdp + T ) = n;
 
  // processing MP as sparse output matrix 
  nzmaxmp = 0;
  for (i=0; i<W; i++) {
      for (j=0; j<S; j++) {
          nzmaxmp += (int) ( mp[ j + i*S ] > 0 );
      }    
  }  
  if (OUTPUT==2) mexPrintf( "Constructing sparse output matrix MP nnz=%d\n" , nzmaxmp );
  
  plhs[2] = mxCreateSparse( W,SINPUT,nzmaxmp,mxREAL);
  srmp  = mxGetPr(plhs[2]);
  irmp = mxGetIr(plhs[2]);
  jcmp = mxGetJc(plhs[2]);
  
  n = 0;
  for (j=0; j<SINPUT; j++) {
      *( jcmp + j ) = n;
      for (i=0; i<W; i++) {
          c = mp[ i*S + (j+2) ];
          if (c >0) {
              *( srmp + n ) = c;
              *( irmp + n ) = i;
              n++;
          }
      }
  }
  *( jcmp + SINPUT ) = n;
  
  
  
  /* ---------------------------------------------
     create the topic assignment vector
   -----------------------------------------------*/
  plhs[ 3 ] = mxCreateDoubleMatrix( ntokens , 1 , mxREAL );
  Z = mxGetPr( plhs[ 3 ] );
  for (i=0; i<ntokens; i++) Z[ i ] = (double) z[ i ] + 1;

   /* ---------------------------------------------
     create the HMM state assignment vector
   -----------------------------------------------*/
  plhs[ 4 ] = mxCreateDoubleMatrix( ntokens , 1 , mxREAL );
  X = mxGetPr( plhs[ 4 ] );
  for (i=0; i<ntokens; i++) X[ i ] = (double) x[ i ] + 1;
}
/* MEX gateway for FBP_voc.
 *
 * Inputs (prhs): 0 DW (D x W sparse double), 1 PHI (J rows, W columns),
 *   2 threshold in [0,1], 3 number of iterations, 4 ALPHA, 5 BETA, 6 SEED,
 *   7 OUTPUT, 8 optional MUIN (J x nzmax) used to initialise mu.
 * Outputs (plhs): 0 THETA (J x D); 1 optional MU (J x nzmax).
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) 
{
	double *mu, *MUIN, *phi, *theta, *sr;
	double ALPHA, BETA, threshold;
	mwIndex *ir, *jc;
	int W, J, D, NN, SEED, OUTPUT, nzmax, i, startcond;

	/* Check for proper number of arguments. */
	if (nrhs < 8) {
		mexErrMsgTxt("At least 8 input arguments required");
	} else if (nlhs < 1) {
		mexErrMsgTxt("At least 1 output arguments required");
	}

	startcond = 0;
	if (nrhs == 9) startcond = 1;

	/* read sparse array DW */
	if (mxIsDouble(prhs[0]) != 1) mexErrMsgTxt("DW must be a double precision matrix");
	sr = mxGetPr(prhs[0]);
	ir = mxGetIr(prhs[0]);
	jc = mxGetJc(prhs[0]);
	/* mxGetIr/mxGetJc return NULL for a full matrix; ir/jc are handed to FBP_voc. */
	if (ir == NULL || jc == NULL) mexErrMsgTxt("DW must be a sparse matrix");
	nzmax = (int) mxGetNzmax(prhs[0]);
	D = (int) mxGetM(prhs[0]);
	W = (int) mxGetN(prhs[0]);

	phi = mxGetPr(prhs[1]);
	J = (int) mxGetM(prhs[1]);
	if (J<=0) mexErrMsgTxt("Number of topics must be greater than zero");
	if ((int) mxGetN(prhs[1]) != W) mexErrMsgTxt("Vocabulary mismatches");

	threshold = (double) mxGetScalar(prhs[2]);
	if (threshold<0 || threshold>1) mexErrMsgTxt("Threshold should be between 0 and 1.");

	NN = (int) mxGetScalar(prhs[3]);
	if (NN<0) mexErrMsgTxt("Number of iterations must be positive");

	ALPHA = (double) mxGetScalar(prhs[4]);
	if (ALPHA<0) mexErrMsgTxt("ALPHA must be greater than zero");

	BETA = (double) mxGetScalar(prhs[5]);
	if (BETA<0) mexErrMsgTxt("BETA must be greater than zero");

	SEED = (int) mxGetScalar(prhs[6]);

	OUTPUT = (int) mxGetScalar(prhs[7]);

	if (startcond == 1) {
		MUIN = mxGetPr(prhs[8]);
		if (nzmax != (mxGetN(prhs[8]))) mexErrMsgTxt("DW and MUIN mismatch");
		if (J != (mxGetM(prhs[8]))) mexErrMsgTxt("J and MUIN mismatch");
	}

	/* seeding */
	seedMT(1 + SEED*2); // seeding only works on uneven numbers

	/* allocate memory */
	mu  = dvec(J*nzmax);
	if (startcond == 1) {
		for (i=0; i<J*nzmax; i++) mu[i] = (double) MUIN[i];   
	}
	
	theta = dvec(J*D);

	/* run the learning algorithm */
	FBP_voc(ALPHA, BETA, W, J, D, NN, OUTPUT, nzmax, sr, ir, jc, phi, theta, mu, threshold, startcond);

	/* output
	 * NOTE(review): mxSetPr requires mxMalloc/mxCalloc memory; this assumes dvec
	 * allocates that way (TODO confirm). mxSetPr also does not free the buffer
	 * that mxCreateDoubleMatrix allocated. */
	plhs[0] = mxCreateDoubleMatrix(J, D, mxREAL);
	mxSetPr(plhs[0], theta);

	/* MU is optional. plhs only has room for the nlhs requested outputs, so
	 * writing plhs[1] when nlhs == 1 would go past the end of the array. */
	if (nlhs > 1) {
		plhs[1] = mxCreateDoubleMatrix(J, nzmax, mxREAL);
		mxSetPr(plhs[1], mu);
	}
}
Exemplo n.º 18
0
//-----------------------------------------------------------------------------
// Name: InitDeviceObjects()
// Desc: Initialize scene objects: detect alpha-blend support for the
//       reflection effect, initialise the input system (mouse + keyboard),
//       seed the random generator, set the camera projection and switch the
//       manager to the main-menu state.
//       Returns S_OK on success, E_FAIL if the input system fails to initialise.
//-----------------------------------------------------------------------------
HRESULT CMyD3DApplication::InitDeviceObjects()
{
/*
    HRESULT hr;
    // Create textures
    if( FAILED( D3DUtil_CreateTexture( m_pd3dDevice, _T("Ground2.bmp"),
                                       &m_pGroundTexture ) ) )
        return D3DAPPERR_MEDIANOTFOUND;

    if( FAILED( D3DUtil_CreateTexture( m_pd3dDevice, _T("Particle.bmp"),
                                       &m_pParticleTexture ) ) )
        return D3DAPPERR_MEDIANOTFOUND;

    // Set up the fonts and textures
    m_pFont->InitDeviceObjects( m_pd3dDevice );
    m_pFontSmall->InitDeviceObjects( m_pd3dDevice );
*/
    // Check if we can do the reflection effect
    m_bCanDoAlphaBlend = (m_d3dCaps.SrcBlendCaps & D3DPBLENDCAPS_SRCALPHA) &&
                         (m_d3dCaps.DestBlendCaps & D3DPBLENDCAPS_INVSRCALPHA);

    if( m_bCanDoAlphaBlend )
        m_bDrawReflection = TRUE;
/*
    // Create ground object
    {
        // Create vertex buffer for ground object
        hr = m_pd3dDevice->CreateVertexBuffer( m_dwNumGroundVertices*sizeof(COLORVERTEX),
                                               D3DUSAGE_WRITEONLY, D3DFVF_COLORVERTEX,
                                               D3DPOOL_MANAGED, &m_pGroundVB );
        if( FAILED(hr) )
            return E_FAIL;

        // Fill vertex buffer
        COLORVERTEX* pVertices;
        if( FAILED( m_pGroundVB->Lock( 0, 0, (BYTE**)&pVertices, NULL ) ) )
            return hr;

        // Fill in vertices
        for( DWORD zz = 0; zz <= GROUND_GRIDSIZE; zz++ )
        {
            for( DWORD xx = 0; xx <= GROUND_GRIDSIZE; xx++ )
            {
                pVertices->v.x   = GROUND_WIDTH * (xx/(FLOAT)GROUND_GRIDSIZE-0.5f);
                pVertices->v.y   = 0.0f;
                pVertices->v.z   = GROUND_HEIGHT * (zz/(FLOAT)GROUND_GRIDSIZE-0.5f);
                pVertices->color = GROUND_COLOR;
                pVertices->tu    = xx*GROUND_TILE/(FLOAT)GROUND_GRIDSIZE;
                pVertices->tv    = zz*GROUND_TILE/(FLOAT)GROUND_GRIDSIZE;
                pVertices++;
            }
        }

        m_pGroundVB->Unlock();

        // Create the index buffer
        WORD* pIndices;
        hr = m_pd3dDevice->CreateIndexBuffer( m_dwNumGroundIndices*sizeof(WORD),
                                              D3DUSAGE_WRITEONLY,
                                              D3DFMT_INDEX16, D3DPOOL_MANAGED,
                                              &m_pGroundIB );
        if( FAILED(hr) )
            return E_FAIL;

        // Fill the index buffer
        m_pGroundIB->Lock( 0, m_dwNumGroundIndices*sizeof(WORD), (BYTE**)&pIndices, 0 );
        if( FAILED(hr) )
            return E_FAIL;

        // Fill in indices
        for( DWORD z = 0; z < GROUND_GRIDSIZE; z++ )
        {
            for( DWORD x = 0; x < GROUND_GRIDSIZE; x++ )
            {
                DWORD vtx = x + z * (GROUND_GRIDSIZE+1);
                *pIndices++ = (WORD)( vtx + 1 );
                *pIndices++ = (WORD)( vtx + 0 );
                *pIndices++ = (WORD)( vtx + 0 + (GROUND_GRIDSIZE+1) );
                *pIndices++ = (WORD)( vtx + 1 );
                *pIndices++ = (WORD)( vtx + 0 + (GROUND_GRIDSIZE+1) );
                *pIndices++ = (WORD)( vtx + 1 + (GROUND_GRIDSIZE+1) );
            }
        }

        m_pGroundIB->Unlock();
    }
*/

    // Input-system failures must produce a failing HRESULT. They previously
    // returned FALSE, which is 0 == S_OK, so callers testing FAILED(hr) saw success.
    if (!inputSystem->Init (m_hWnd,m_hInstance,(int)manager->GetScreenWidth(),(int)manager->GetScreenHeight()))
        return E_FAIL;
    if (!inputSystem->InitMouse (1))
        return E_FAIL;
    if (!inputSystem->InitKeyboard ())
        return E_FAIL;

    // Center the mouse
    inputSystem->m_posX = manager->GetScreenWidth()/2;
    inputSystem->m_posY = manager->GetScreenHeight()/2;
    seedMT(timeGetTime());
    GetTime::Instance()->Init();

    FLOAT fAspect = ((FLOAT)(int)manager->GetScreenWidth()) / (int)manager->GetScreenHeight();
    user->camera.SetProjParams(D3DX_PI/4, fAspect, 0.1f, VIEW_DISTANCE);

    manager->SetState(mainMenuState);

    return S_OK;
}
/* MEX gateway for asiBP.
 *
 * Inputs (prhs): 0 WD (W x D sparse double), 1 J (number of topics),
 *   2 number of iterations, 3 ALPHA, 4 BETA, 5 SEED, 6 OUTPUT,
 *   7 optional MUIN (J x nzmax) used to initialise mu.
 * Outputs (plhs): 0 PHI (J x W), 1 THETA (J x D), 2 optional MU (J x nzmax).
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) 
{
	double *mu, *MUIN, *phi, *theta, *sr;
	double ALPHA, BETA;
	mwIndex *ir, *jc;
	int W, J, D, NN, SEED, OUTPUT, nzmax, i, startcond;

	/* Check for proper number of arguments. */
	if (nrhs < 7) {
		mexErrMsgTxt("At least 7 input arguments required");
	} else if (nlhs < 2) {
		mexErrMsgTxt("At least 2 output arguments required");
	}

	startcond = 0;
	if (nrhs == 8) startcond = 1;

	/* dealing with sparse array WD */
	if (mxIsDouble(prhs[0]) != 1) mexErrMsgTxt("WD must be a double precision matrix");
	sr = mxGetPr(prhs[0]);
	ir = mxGetIr(prhs[0]);
	jc = mxGetJc(prhs[0]);
	/* mxGetIr/mxGetJc return NULL for a full matrix; ir/jc are handed to asiBP. */
	if (ir == NULL || jc == NULL) mexErrMsgTxt("WD must be a sparse matrix");
	nzmax = (int) mxGetNzmax(prhs[0]);
	W = (int) mxGetM(prhs[0]);
	D = (int) mxGetN(prhs[0]);

	J = (int) mxGetScalar(prhs[1]);
	if (J<=0) mexErrMsgTxt("Number of topics must be greater than zero");

	NN = (int) mxGetScalar(prhs[2]);
	if (NN<0) mexErrMsgTxt("Number of iterations must be positive");

	ALPHA = (double) mxGetScalar(prhs[3]);
	if (ALPHA<0) mexErrMsgTxt("ALPHA must be greater than zero");

	BETA = (double) mxGetScalar(prhs[4]);
	if (BETA<0) mexErrMsgTxt("BETA must be greater than zero");

	SEED = (int) mxGetScalar(prhs[5]);

	OUTPUT = (int) mxGetScalar(prhs[6]);

	if (startcond == 1) {
		MUIN = mxGetPr(prhs[7]);
		if (nzmax != (mxGetN(prhs[7]))) mexErrMsgTxt("WD and MUIN mismatch");
		if (J != (mxGetM(prhs[7]))) mexErrMsgTxt("J and MUIN mismatch");
	}

	/* seeding */
	seedMT(1 + SEED*2); // seeding only works on uneven numbers

	/* allocate memory */
	mu  = dvec(J*nzmax);
	if (startcond == 1) {
		for (i=0; i<J*nzmax; i++) mu[i] = (double) MUIN[i];   
	}

	phi = dvec(J*W);
	theta = dvec(J*D);

	/* run the learning algorithm */
	asiBP(ALPHA, BETA, W, J, D, NN, OUTPUT, sr, ir, jc, phi, theta, mu, startcond);

	/* output
	 * NOTE(review): mxSetPr requires mxMalloc/mxCalloc memory; this assumes dvec
	 * allocates that way (TODO confirm). mxSetPr also does not free the buffer
	 * that mxCreateDoubleMatrix allocated. */
	plhs[0] = mxCreateDoubleMatrix(J, W, mxREAL);
	mxSetPr(plhs[0], phi);

	plhs[1] = mxCreateDoubleMatrix(J, D, mxREAL );
	mxSetPr(plhs[1], theta);

	/* MU is optional. plhs only has room for the nlhs requested outputs, so
	 * writing plhs[2] when nlhs == 2 would go past the end of the array. */
	if (nlhs > 2) {
		plhs[2] = mxCreateDoubleMatrix(J, nzmax, mxREAL );
		mxSetPr(plhs[2], mu);
	}
}
Exemplo n.º 20
0
// Application entry point.
// Optionally starts GDI+, seeds both the CRT rand() and the Mersenne Twister
// from the clock, lowers the process to idle priority and creates the
// "StrangeWorld4" window. It then runs the message loop until WM_QUIT.
// While paused (gCurrentState == ePAUSE) the loop blocks in GetMessage.
// Otherwise it calls Tick() once per iteration and drains at most one message.
// Returns 0 on a normal exit and 1 if the window could not be created or
// GetMessage reported an error.
int WINAPI WinMain(HINSTANCE hInst, HINSTANCE prev, LPSTR cmd, int show)
{
#if defined( USE_GDIPLUS )
    Gdiplus::GdiplusStartupInput gdiplusStartupInput;
    ULONG_PTR           gdiplusToken;
    // Initialize GDI+.
    Gdiplus::GdiplusStartup(&gdiplusToken, &gdiplusStartupInput, NULL);
#endif

    MathAccel::init();
    srand( (unsigned int)time( 0 ) );
    seedMT( (unsigned int)time( 0 ) );

    SetPriorityClass(GetCurrentProcess(), IDLE_PRIORITY_CLASS);

    theMainInstance = hInst;

    WNDCLASS wc;
    wc.style            = CS_HREDRAW | CS_VREDRAW;
    wc.lpfnWndProc      = (WNDPROC)MainWndProc;
    wc.cbClsExtra       = 0;
    wc.cbWndExtra       = 0;
    wc.hInstance        = theMainInstance;
    wc.hIcon            = NULL;
    wc.hCursor          = LoadCursor(NULL, IDC_ARROW);
    wc.hbrBackground    = NULL;
    wc.lpszMenuName     = NULL;
    wc.lpszClassName    = L"StrangeWorld4";

    RegisterClass(&wc);

    RECT wndRect;
    GetDefaultWindowSize(&wndRect);
    theMainWindow = CreateWindow( 
        L"StrangeWorld4", APPLICATION_NAME, 
        WS_VISIBLE | WS_OVERLAPPEDWINDOW,
        100, 100, wndRect.right, wndRect.bottom,
        NULL, NULL, theMainInstance, NULL);

    int exitCode = 0;
    if ( theMainWindow == NULL )
    {
        // Without a window no WM_QUIT is ever posted, so the loop below would never end.
        exitCode = 1;
    }
    else
    {
        MSG msg = {0};
        do
        {
            if ( gCurrentState == ePAUSE )
            {
                // GetMessage returns -1 on error. Treating that as "got a message" would
                // dispatch garbage and spin forever.
                const int got = GetMessage( &msg, 0, 0, 0 );
                if ( got == -1 )
                {
                    exitCode = 1;
                    break;
                }
                if ( got )
                {
                    TranslateMessage( &msg );
                    DispatchMessage( &msg );
                }
            }
            else
            {
                Tick();
                if ( PeekMessage( &msg, 0, 0, 0, PM_REMOVE ) )
                {
                    TranslateMessage( &msg );
                    DispatchMessage( &msg );
                }
            }
        }
        while( msg.message != WM_QUIT );
    }

#if defined( USE_GDIPLUS )
    Gdiplus::GdiplusShutdown(gdiplusToken);
#endif
    return exitCode;
}
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    double *srphi, *srtheta, *srad, *sr, *probs, *WS, *DS, *ZIN, *XIN, *Z, *X;
    double ALPHA, BETA;
    mwIndex *irphi, *jcphi, *irtheta, *jctheta, *irad, *jcad, *ir, *jc;
    int *z, *d, *w, *x, *order, *phi, *theta, *phitot, *atot;
    int W, J, D, A, NN, SEED, OUTPUT, nzmax, nzmaxad, nzmaxphi, nzmaxtheta, ntokens;
    int i, j, k, c, n, nt, wi, ci, ntcount, di;
    int i_start, i_end, a, nauthors, startcond;

    // Syntax
    //   [ PHI , THETA , Z , X ] = ATMGS( WD, AD , J , N , ALPHA , BETA , SEED , OUTPUT )

    // Syntax
    //   [ PHI , THETA , Z , X ] = ATMGS( WD, AD , J , N , ALPHA , BETA , SEED , OUTPUT , ZIN , XIN )


    /* Check for proper number of arguments. */
    if (nrhs < 8) {
        mexErrMsgTxt("At least 8 input arguments required");
    } else if (nlhs < 1) {
        mexErrMsgTxt("At least 1 output arguments required");
    }

    startcond = 0;
    if (nrhs > 8) startcond = 1;

    /* dealing with sparse array WD */
    if (mxIsDouble(prhs[0]) != 1) mexErrMsgTxt("WD must be a double precision matrix");
    sr = mxGetPr(prhs[0]);
    ir = mxGetIr(prhs[0]);
    jc = mxGetJc(prhs[0]);
    nzmax = (int) mxGetNzmax(prhs[0]);
    W = (int) mxGetM(prhs[0]);
    D = (int) mxGetN(prhs[0]);

    // get the number of tokens
    ntokens = (int) 0;
    for  (i=0; i<nzmax; i++) ntokens += (int) sr[i];
    if (ntokens == 0) mexErrMsgTxt("word vector is empty");

    d = ivec(ntokens);
    w = ivec(ntokens);

    // copy over the word and document indices into internal format
    k = (int) 0;
    for (di=0; di<D; di++) {
        for (i=jc[di]; i<jc[di+1]; i++) {
            wi = (int) ir[i];
            ci = (int) sr[i];
            for (j=0; j<ci; j++) {
                d[k] = di;
                w[k] = wi;
                k++;
            }
        }
    }

    if (ntokens != k) mexErrMsgTxt("Fail to read data");

    n = ntokens;

    if ((mxIsSparse( prhs[ 1 ] ) != 1) || (mxIsDouble( prhs[ 1 ] ) != 1))
        mexErrMsgTxt("Input matrix must be a sparse double precision matrix");

    /* dealing with sparse array AD */
    srad = mxGetPr(prhs[1]);
    irad = mxGetIr(prhs[1]);
    jcad = mxGetJc(prhs[1]);
    nzmaxad = mxGetNzmax(prhs[1]);
    A = mxGetM( prhs[1] );
    if (mxGetN( prhs[1] ) != D) mexErrMsgTxt("The number of columns in WD and AD matrix must be equal" );

    // get phi
    srphi = mxGetPr(prhs[2]);
    irphi = mxGetIr(prhs[2]);
    jcphi = mxGetJc(prhs[2]);
    nzmaxphi = (int) mxGetNzmax(prhs[2]);
    if (W != (int) mxGetN(prhs[2])) mexErrMsgTxt("Vocabulary size mismatches"); ;
    J = (int) mxGetM(prhs[2]);
    phi  = (int *) mxCalloc(J*W , sizeof(int));
    phitot = (int *) mxCalloc(J, sizeof(int));
    for (wi=0; wi<W; wi++) {
        for (i=jcphi[wi]; i<jcphi[wi+1]; i++) {
            j = (int) irphi[i];
            phi[wi*J + j] = (int) srphi[i];
            phitot[j] += (int) srphi[i];
        }
    }

    NN    = (int) mxGetScalar(prhs[3]);
    if (NN<0) mexErrMsgTxt("Number of iterations must be greater than zero");

    ALPHA = (double) mxGetScalar(prhs[4]);
    if (ALPHA<=0) mexErrMsgTxt("ALPHA must be greater than zero");

    BETA = (double) mxGetScalar(prhs[5]);
    if (BETA<=0) mexErrMsgTxt("BETA must be greater than zero");

    SEED = (int) mxGetScalar(prhs[6]);
    // set the seed of the random number generator

    OUTPUT = (int) mxGetScalar(prhs[7]);

    if (startcond == 1) {
        ZIN = mxGetPr(prhs[8]);
        if (ntokens != (mxGetM(prhs[8])*mxGetN(prhs[8]))) mexErrMsgTxt("WS and ZIN vectors should have same number of entries");

        XIN = mxGetPr(prhs[9]);
        if (ntokens != (mxGetM(prhs[9])*mxGetN(prhs[9]))) mexErrMsgTxt("WS and XIN vectors should have same number of entries");
    }

    // seeding
    seedMT(1 + SEED*2); // seeding only works on uneven numbers

    // check entries of AD matrix
    for (i=0; i<nzmaxad; i++) {
        nt = (int) srad[i];
        if ((nt<0) || (nt>1)) mexErrMsgTxt("Entries in AD matrix can only be 0 or 1");
    }

    /* allocate memory */
    z = ivec(ntokens);
    x = ivec( ntokens);

    if (startcond == 1) {
        for (i=0; i<ntokens; i++) {
            z[i] = (int) ZIN[i] - 1;
            x[i] = (int) XIN[i] - 1;
        }
    }

    order = ivec(ntokens);
    theta = ivec(J*A);
    phitot = ivec(J);
    atot = ivec(A);
    probs = dvec(J*NAMAX);

    /* check that every document has some authors */
    for (j=0; j<D; j++) {
        i_start = jcad[j];
        i_end   = jcad[j+1];
        nauthors = i_end - i_start;
        if (nauthors == 0) mexErrMsgTxt("There are some documents without authors in AD matrix ");
        if (nauthors > NAMAX) mexErrMsgTxt("Too many authors in some documents ... reached the NAMAX limit");
    }

    /* run the model */
    ATMGS(ALPHA, BETA, W, J, D, NN, OUTPUT, n, z, d, w, x, phi, theta, phitot, atot, order, probs, irad, jcad, startcond);

    // create sparse matrix theta
    nzmaxtheta = 0;
    for (i=0; i<A; i++) {
        for (j=0; j<J; j++)
            nzmaxtheta += (int) ( *( theta + j + i*J )) > 0;
    }

    plhs[0] = mxCreateSparse(J, A, nzmaxtheta, mxREAL);
    srtheta  = mxGetPr(plhs[0]);
    irtheta = mxGetIr(plhs[0]);
    jctheta = mxGetJc(plhs[0]);

    n = 0;
    for (i=0; i<A; i++) {
        *( jctheta + i ) = n;
        for (j=0; j<J; j++) {
            c = (int) *( theta + i*J + j );
            if (c >0) {
                *( srtheta + n ) = c;
                *( irtheta + n ) = j;
                n++;
            }
        }
    }

    *(jctheta + A) = n;

    plhs[1] = mxCreateDoubleMatrix(1, ntokens, mxREAL);
    Z = mxGetPr(plhs[1]);
    for (i=0; i<ntokens; i++) Z[i] = (double) z[i] + 1;

    plhs[2] = mxCreateDoubleMatrix(1, ntokens, mxREAL);
    X = mxGetPr(plhs[2]);
    for (i=0; i<ntokens; i++) X[i] = (double) x[i] + 1;
}
Exemplo n.º 22
0
void mexFunction(int nlhs, mxArray *plhs[], int nrhs,
                 const mxArray *prhs[])
{
  double *srww_old, *srww_new,*srwp_old,*srdp_new, *probs, *DS_NEW, *WS_NEW, *WS_OLD, *SI_NEW, *WC_OLD, *PROBA, *C, *Z;
  double ALPHA,BETA, GAMMA0, GAMMA1, DELTA;
  int *irww_old, *irww_new, *jcww_old, *jcww_new, *irwp_old, *jcwp_old, *irdp_new, *jcdp_new;
  int *z,*d,*w, *s, *c, *sc, *order, *wp_old, *dp_new, *dp_new2, *ztot_old, *wtot_new, *wc_new;
  int W_NEW, W_OLD,T_NEW, T_OLD,D_NEW,NITER,SEED,OUTPUT, nzmax_new, nzmax_old, nzmaxdp_new, ntokens_new, ntokens_old;
  int i,j,cc,nt,n,wi,wipre,di, startindex, endindex, currentrow, count;
  int NSAMPLE, updatecols;
  
  /* Check for proper number of arguments. */
  if (nrhs != 19) {
    mexErrMsgTxt("19 input arguments required");
  } else if (nlhs != 4) {
    mexErrMsgTxt("4 output arguments required");
  }

// Syntax                                       0       1          2       3       4   5    6         7      8       9      10        11      12      13      14       15       16       17     18 
//   [ DP_NEW,C,Z,PROB ] = NewDocumentsLDACOL( WS_NEW , DS_NEW , SI_NEW , WW_NEW , T , N , NSAMPLE , ALPHA , BETA , GAMMA0, GAMMA1 , DELTA , SEED , OUTPUT , WW_OLD , WP_OLD , WC_OLD , WS_OLD, UPDATECOLS );
  
  
  /* process the input arguments */
  if (mxIsDouble( prhs[ 0 ] ) != 1)
      mexErrMsgTxt("WS_NEW must be double precision"); 
  WS_NEW = mxGetPr(prhs[0]);
  ntokens_new = mxGetM( prhs[ 0 ] ) * mxGetN( prhs[ 0 ] );
  
  if (mxIsDouble( prhs[ 1 ] ) != 1)
      mexErrMsgTxt("DS_NEW must be double precision"); 
  DS_NEW = mxGetPr(prhs[1]);
  
  if (mxIsDouble( prhs[ 2 ] ) != 1)
      mexErrMsgTxt("SI_NEW must be double precision");
  SI_NEW = mxGetPr(prhs[2]);
  
  if ((mxIsSparse( prhs[ 3 ] ) != 1) || (mxIsDouble( prhs[ 3 ] ) != 1))
      mexErrMsgTxt("WW_NEW collocation matrix must be a sparse double precision matrix");

  /* dealing with sparse array WW_NEW */
  srww_new  = mxGetPr(prhs[3]);
  irww_new  = mxGetIr(prhs[3]);
  jcww_new  = mxGetJc(prhs[3]);
  nzmax_new = mxGetNzmax(prhs[3]);  
  W_NEW     = mxGetM( prhs[3] );
  
  if (mxGetN( prhs[3] ) != W_NEW) mexErrMsgTxt("WW_NEW matrix should be square");
  
  D_NEW    = 0;
  for (i=0; i<ntokens_new; i++) {
      if (DS_NEW[ i ] > D_NEW) D_NEW = (int) DS_NEW[ i ];
      if (WS_NEW[ i ] > W_NEW) mexErrMsgTxt("Some word tokens in WS_NEW stream exceed number of word types in WW_NEW matrix");
  }
   
  T_NEW    = (int) mxGetScalar(prhs[4]);
  if (T_NEW<=0) mexErrMsgTxt("Number of topics must be greater than zero");
  
  NITER    = (int) mxGetScalar(prhs[5]);
  if (NITER<0) mexErrMsgTxt("Number of iterations must be greater than zero");
  
  NSAMPLE    = (int) mxGetScalar(prhs[6]);
  if (NSAMPLE<0) mexErrMsgTxt("Number of samples must be greater than zero");
  
  ALPHA = (double) mxGetScalar(prhs[7]);
  if (ALPHA<=0) mexErrMsgTxt("ALPHA must be greater than zero");
  
  BETA = (double) mxGetScalar(prhs[8]);
  if (BETA<=0) mexErrMsgTxt("BETA must be greater than zero");
  
  GAMMA0 = (double) mxGetScalar(prhs[9]);
  if (GAMMA0<=0) mexErrMsgTxt("GAMMA0 must be greater than zero");
  
  GAMMA1 = (double) mxGetScalar(prhs[10]);
  if (GAMMA1<=0) mexErrMsgTxt("GAMMA1 must be greater than zero");
  
  DELTA = (double) mxGetScalar(prhs[11]);
  if (DELTA<=0) mexErrMsgTxt("DELTA must be greater than zero");
  
  SEED = (int) mxGetScalar(prhs[12]);
  // set the seed of the random number generator
  
  OUTPUT = (int) mxGetScalar(prhs[13]);
  
  // dealing with sparse array WW_OLD 
  if ((mxIsSparse( prhs[ 14 ] ) != 1) || (mxIsDouble( prhs[ 14 ] ) != 1))
      mexErrMsgTxt("WW_OLD collocation frequency matrix must be a sparse double precision matrix");  
  srww_old  = mxGetPr(prhs[14]);
  irww_old  = mxGetIr(prhs[14]);
  jcww_old  = mxGetJc(prhs[14]);
  nzmax_old = mxGetNzmax(prhs[14]);
  W_OLD     = mxGetM( prhs[14] );
  if (W_OLD != W_NEW) mexErrMsgTxt("WW_OLD and WW_NEW matrices have different dimensions");
  if (mxGetN( prhs[14] ) != W_NEW) mexErrMsgTxt("WW_OLD matrix should be square");
  
  // dealing with sparse array WP_OLD 
  if ((mxIsSparse( prhs[ 15 ] ) != 1) || (mxIsDouble( prhs[ 15 ] ) != 1))
      mexErrMsgTxt("WP_OLD topic-word matrix must be a sparse double precision matrix");  
  srwp_old  = mxGetPr(prhs[15]);
  irwp_old  = mxGetIr(prhs[15]);
  jcwp_old  = mxGetJc(prhs[15]);
  if (mxGetM( prhs[15] ) != W_NEW) mexErrMsgTxt("Number of words in WP_OLD matrix does not match WW_NEW number of words");
  if (mxGetN( prhs[15] ) != T_NEW) mexErrMsgTxt("Number of topics in WP_OLD matrix does not match given number of topics");
 
  WC_OLD = mxGetPr( prhs[ 16 ]);
  if ((mxGetM( prhs[16] ) * mxGetN( prhs[16] )) != W_NEW ) mexErrMsgTxt("Number of words in WC_OLD matrix does not match WW_NEW number of words");   
  
  if (mxIsDouble( prhs[ 17 ] ) != 1) mexErrMsgTxt("WS_OLD must be double precision"); 
  WS_OLD = mxGetPr(prhs[ 17 ]);
  ntokens_old = mxGetM( prhs[ 17 ] ) * mxGetN( prhs[ 17 ] );
  
  updatecols = (int) mxGetScalar(prhs[18]);
  
  // seeding
  seedMT( 1 + SEED * 2 ); // seeding only works on uneven numbers
  
  /* allocate memory */
  z  = (int *) mxCalloc( ntokens_new , sizeof( int ));
  d  = (int *) mxCalloc( ntokens_new , sizeof( int ));
  w  = (int *) mxCalloc( ntokens_new , sizeof( int ));
  s  = (int *) mxCalloc( ntokens_new , sizeof( int ));
  c  = (int *) mxCalloc( ntokens_new , sizeof( int ));
  sc  = (int *) mxCalloc( ntokens_new , sizeof( int ));
 
  for (i=0; i<ntokens_new; i++) w[ i ] = (int) (WS_NEW[ i ] - 1); // Matlab indexing not zero based
  for (i=0; i<ntokens_new; i++) d[ i ] = (int) (DS_NEW[ i ] - 1); // Matlab indexing not zero based
  for (i=0; i<ntokens_new; i++) s[ i ] = (int) SI_NEW[ i ];
  
  order   = (int *) mxCalloc( ntokens_new , sizeof( int ));
  wp_old  = (int *) mxCalloc( T_NEW*W_NEW , sizeof( int ));
  dp_new  = (int *) mxCalloc( T_NEW*D_NEW , sizeof( int ));
  dp_new2  = (int *) mxCalloc( T_NEW*D_NEW , sizeof( int ));
  wc_new  = (int *) mxCalloc( W_NEW , sizeof( int ));
  ztot_old  = (int *) mxCalloc( T_NEW , sizeof( int ));
  wtot_new  = (int *) mxCalloc( W_NEW , sizeof( int ));
  probs  = (double *) mxCalloc( T_NEW , sizeof( double ));
  
  // FILL IN WTOT_NEW with OLD COUNTS
  for (i=0; i<ntokens_old; i++) {
     j = (int) WS_OLD[ i ] - 1;
     wtot_new[ j ]++; // calculate number of words of each type
  }
  
  // FILL IN WTOT_NEW with NEW COUNTS
  if (updatecols==1)
  for (i=0; i<ntokens_new; i++) {
     j = w[ i ];
     wtot_new[ j ]++; // calculate number of words of each type
  }
  
  // FILL IN WC WITH COUNTS FROM OLD SET
  for (i=0; i<W_NEW; i++) {
     wc_new[ i ] = (int) WC_OLD[ i ];   
  }
  
  // FILL IN wp_old and ztot_old
  for (j=0; j<T_NEW; j++) {
      startindex = *( jcwp_old + j );
      endindex   = *( jcwp_old + j + 1 ) - 1;
      
      for (i=startindex; i<=endindex; i++) {
          currentrow = *( irwp_old + i );
          count      = (int) *( srwp_old + i );
          wp_old[ j + currentrow*T_NEW ] = count;
          ztot_old[ j ] += count;
      }    
  }
  
  // FILL IN SC COUNTS 
  for (i=1; i<ntokens_new; i++) // start with second item
  {
      wi = w[ i ]; // word index
      
      count = 0;
      
      // what is the previous word?
      wipre = w[ i-1 ];
      
      // calculate how many times the current word follows the previous word IN THE OLD DOCUMENT SET    
      startindex = *( jcww_old + wipre ); // look up start and end index for column wipre in WW_OLD matrix
      endindex   = *( jcww_old + wipre + 1 ) - 1;   
      for (j=startindex; j<=endindex; j++) {
          currentrow = *( irww_old + j );
          if (currentrow == wi) count = (int) *( srww_old + j );
      }
      
      // calculate how many times the current word follows the previous word IN THE NEW DOCUMENT SET
      if (updatecols==1) {
          startindex = *( jcww_new + wipre ); // look up start and end index for column wipre in WW_NEW matrix
          endindex   = *( jcww_new + wipre + 1 ) - 1;
          for (j=startindex; j<=endindex; j++) {
              currentrow = *( irww_new + j );
              if (currentrow == wi) count += (int) *( srww_new + j ); // add the count to what was there previously !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
          }
      }
      
      sc[ i ] = count;
  }
  
  
  
  /*for (i=0; i<10; i++) {
     mexPrintf( "i=%4d w[i]=%3d d[i]=%d z[i]=%d s[i]=%d\n" , i , w[i] , d[i] , z[i] , s[i] );    
  }*/
  
  if (OUTPUT==2) {
      mexPrintf( "Running NEW DOCUMENT LDA COL Gibbs Sampler Version 1.0\n" );
      mexPrintf( "Arguments:\n" );
      mexPrintf( "\tNumber of words      W = %d\n" , W_NEW );
      mexPrintf( "\tNumber of docs       D = %d\n" , D_NEW );
      mexPrintf( "\tNumber of topics     T = %d\n" , T_NEW );
      mexPrintf( "\tNumber of iterations N = %d\n" , NITER );
      mexPrintf( "\tNumber of samples    S = %d\n" , NSAMPLE );
      mexPrintf( "\tHyperparameter   ALPHA = %4.4f\n" , ALPHA );
      mexPrintf( "\tHyperparameter    BETA = %4.4f\n" , BETA );
      mexPrintf( "\tHyperparameter  GAMMA0 = %4.4f\n" , GAMMA0 );
      mexPrintf( "\tHyperparameter  GAMMA1 = %4.4f\n" , GAMMA1 );
      mexPrintf( "\tHyperparameter  DELTA  = %4.4f\n" , DELTA );
      mexPrintf( "\tSeed number            = %d\n" , SEED );
      mexPrintf( "\tNumber of tokens       = %d\n" , ntokens_new );
      mexPrintf( "\tUpdating collocation counts with new documents? = %d\n" , updatecols );
  }
  
  /* ---------------------------------------------
     create the PROB vector
   -----------------------------------------------*/
  plhs[ 3 ] = mxCreateDoubleMatrix( ntokens_new , 1 , mxREAL );
  PROBA = mxGetPr( plhs[ 3 ] );
  
  /* run the model */  
  GibbsSamplerLDACOL( ALPHA, BETA, GAMMA0,GAMMA1,DELTA, W_NEW, T_NEW, D_NEW, NITER, NSAMPLE , OUTPUT, ntokens_new, z, d, w, s, c, sc, wp_old, dp_new, dp_new2, ztot_old, wtot_new, order, probs, wc_new , PROBA , updatecols );
  

  /* ---------------------------------------------
   convert the full DP matrix into a sparse matrix 
   -----------------------------------------------*/
  nzmaxdp_new = 0;
  for (i=0; i<D_NEW; i++) {
      for (j=0; j<T_NEW; j++)
          nzmaxdp_new += (int) ( *( dp_new2 + j + i*T_NEW )) > 0; // !!!!!!!!!!!!!!!!!! copy from dp_new2
  }
  
  if (OUTPUT==2) {
      mexPrintf( "Constructing sparse output matrix dp_new\n" );
      mexPrintf( "Number of nonzero entries for DP = %d\n" , nzmaxdp_new );
  }
  
  plhs[0] = mxCreateSparse( D_NEW,T_NEW,nzmaxdp_new,mxREAL);
  srdp_new  = mxGetPr(plhs[0]);
  irdp_new = mxGetIr(plhs[0]);
  jcdp_new = mxGetJc(plhs[0]);
  n = 0;
  for (j=0; j<T_NEW; j++) {
      *( jcdp_new + j ) = n;
      for (i=0; i<D_NEW; i++) {
          cc = (int) *( dp_new2 + i*T_NEW + j ); // !!!!!!!!!!!!!!!!!! copy from dp_new2
          if (cc >0) {
              *( srdp_new + n ) = cc;
              *( irdp_new + n ) = i;
              n++;
          }
      }
  } 
  *( jcdp_new + T_NEW ) = n;
  
  /* ---------------------------------------------
     create the C route vector
   -----------------------------------------------*/
  plhs[ 1 ] = mxCreateDoubleMatrix( ntokens_new , 1 , mxREAL );
  C = mxGetPr( plhs[ 1 ] );
  for (i=0; i<ntokens_new; i++) C[ i ] = (double) c[ i ];
  
  /* ---------------------------------------------
     create the topic assignment vector
   -----------------------------------------------*/
  plhs[ 2 ] = mxCreateDoubleMatrix( ntokens_new , 1 , mxREAL );
  Z = mxGetPr( plhs[ 2 ] );
  for (i=0; i<ntokens_new; i++) Z[ i ] = (double) z[ i ] + 1;

}
void mexFunction(int nlhs, mxArray *plhs[], int nrhs,
                 const mxArray *prhs[])
{
  double *srwp, *srdp, *probs, *Z, *X, *WS, *DS, *ZIN, *XIN;
  double ALPHA, BETA, GAMMA1, GAMMA0;
  int *irwp, *jcwp, *irdp, *jcdp;
  int *z,*d,*w, *x, *xcounts0, *xcounts1, *order, *wp, *dp, *sumdp, *ztot;
  int W,T,T2,D,NN,SEED,OUTPUT, nzmax, nzmaxwp, nzmaxdp, ntokens;
  int i,j,c,n,nt,wi,di, startcond;
  
  /* Check for proper number of arguments. */
  if (nrhs < 10) {
    mexErrMsgTxt("At least 10 input arguments required");
  } else if (nlhs < 4) {
    mexErrMsgTxt("4 output arguments required");
  }
  
  startcond = 0;
  if (nrhs >= 11) startcond = 1;
  
  /* process the input arguments */
  if (mxIsDouble( prhs[ 0 ] ) != 1) mexErrMsgTxt("WS input vector must be a double precision matrix");
  if (mxIsDouble( prhs[ 1 ] ) != 1) mexErrMsgTxt("DS input vector must be a double precision matrix");
  
  // pointer to word indices
  WS = mxGetPr( prhs[ 0 ] );
     
  // pointer to document indices
  DS = mxGetPr( prhs[ 1 ] );
  
  // get the number of tokens
  ntokens = mxGetM( prhs[ 0 ] ) * mxGetN( prhs[ 0 ] );
  
  
  if (ntokens == 0) mexErrMsgTxt("WS vector is empty"); 
  if (ntokens != ( mxGetM( prhs[ 1 ] ) * mxGetN( prhs[ 1 ] ))) mexErrMsgTxt("WS and DS vectors should have same number of entries");
  
  T    = (int) mxGetScalar(prhs[2]);
  if (T<=0) mexErrMsgTxt("Number of topics must be greater than zero");
  
  NN    = (int) mxGetScalar(prhs[3]);
  if (NN<0) mexErrMsgTxt("Number of iterations must be positive");
  
  ALPHA = (double) mxGetScalar(prhs[4]);
  if (ALPHA<=0) mexErrMsgTxt("ALPHA0 must be greater than zero");

  BETA = (double) mxGetScalar(prhs[5]);
  if (BETA<=0) mexErrMsgTxt("BETA must be greater than zero");

  GAMMA0 = (double) mxGetScalar(prhs[6]);
  if (GAMMA0<=0) mexErrMsgTxt("GAMMA0 must be greater than zero");

  GAMMA1 = (double) mxGetScalar(prhs[7]);
  if (GAMMA1<=0) mexErrMsgTxt("GAMMA1 must be greater than zero");

  SEED = (int) mxGetScalar(prhs[8]);
  
  OUTPUT = (int) mxGetScalar(prhs[9]);
  
  if (startcond == 1) {
      ZIN = mxGetPr( prhs[ 10 ] );
      if (ntokens != ( mxGetM( prhs[ 10 ] ) * mxGetN( prhs[ 10 ] ))) mexErrMsgTxt("WS and ZIN vectors should have same number of entries");
      
      XIN = mxGetPr( prhs[ 11 ] );
      if (ntokens != ( mxGetM( prhs[ 11 ] ) * mxGetN( prhs[ 11 ] ))) mexErrMsgTxt("WS and XIN vectors should have same number of entries");
  }
  
  // seeding
  seedMT( 1 + SEED * 2 ); // seeding only works on uneven numbers
  
   
  
  /* allocate memory */
  z  = (int *) mxCalloc( ntokens , sizeof( int ));
  x  = (int *) mxCalloc( ntokens , sizeof( int ));
  
  if (startcond == 1) {
     for (i=0; i<ntokens; i++) z[ i ] = (int) ZIN[ i ] - 1;
     for (i=0; i<ntokens; i++) x[ i ] = (int) XIN[ i ]; // 0 = normal topic assignment 1 = idiosyncratic topic
  }
  
  d  = (int *) mxCalloc( ntokens , sizeof( int ));
  w  = (int *) mxCalloc( ntokens , sizeof( int ));
  order  = (int *) mxCalloc( ntokens , sizeof( int ));  
  
  
  // copy over the word and document indices into internal format
  for (i=0; i<ntokens; i++) {
     w[ i ] = (int) WS[ i ] - 1;
     d[ i ] = (int) DS[ i ] - 1;
  }
  
  n = ntokens;
  
  W = 0;
  D = 0;
  for (i=0; i<n; i++) {
     if (w[ i ] > W) W = w[ i ];
     if (d[ i ] > D) D = d[ i ];
  }
  W = W + 1;
  D = D + 1;
  
  // NOTE: the wp matrix has T+D topics where the last D topics are idiosyncratic
  T2 = T + D;
  wp  = (int *) mxCalloc( T2*W , sizeof( int ));
  
  // NOTE: the last topic probability is the special topic probability
  dp  = (int *) mxCalloc( (T+1)*D , sizeof( int ));
  
  sumdp  = (int *) mxCalloc( D , sizeof( int ));
  
  ztot  = (int *) mxCalloc( T2 , sizeof( int ));
  probs  = (double *) mxCalloc( T+1 , sizeof( double ));
  xcounts0 = (int *) mxCalloc( D , sizeof( int ));
  xcounts1 = (int *) mxCalloc( D , sizeof( int ));
  
  //mexPrintf( "N=%d  T=%d W=%d D=%d\n" , ntokens , T , W , D );
  
  if (OUTPUT==2) {
      mexPrintf( "Running LDA Gibbs Sampler -- with special topics\n" );
      if (startcond==1) mexPrintf( "Starting from previous state ZIN\n" );
      mexPrintf( "Arguments:\n" );
      mexPrintf( "\tNumber of words      W = %d\n"    , W );
      mexPrintf( "\tNumber of docs       D = %d\n"    , D );
      mexPrintf( "\tNumber of topics     T = %d\n"    , T );
      mexPrintf( "\tNumber of iterations N = %d\n"    , NN );
      mexPrintf( "\tHyperparameter   ALPHA = %4.4f\n" , ALPHA );
      mexPrintf( "\tHyperparameter    BETA = %4.4f\n" , BETA );
      mexPrintf( "\tHyperparameter  GAMMA0 = %4.4f\n" , GAMMA0 );
      mexPrintf( "\tHyperparameter  GAMMA1 = %4.4f\n" , GAMMA1 );
      mexPrintf( "\tSeed number            = %d\n"    , SEED );
      mexPrintf( "\tNumber of tokens       = %d\n"    , ntokens );
      //mexPrintf( "Internal Memory Allocation\n" );
      //mexPrintf( "\tw,d,z,order indices combined = %d bytes\n" , 4 * sizeof( int) * ntokens );
      //mexPrintf( "\twp (full) matrix = %d bytes\n" , sizeof( int ) * W * T  );
      //mexPrintf( "\tdp (full) matrix = %d bytes\n" , sizeof( int ) * D * T  );
      //mexPrintf( "Checking: sizeof(int)=%d sizeof(long)=%d sizeof(double)=%d\n" , sizeof(int) , sizeof(long) , sizeof(double));
  }
  
  /* run the model */
  GibbsSampler( ALPHA, BETA, GAMMA0, GAMMA1, W, T, D, NN, OUTPUT, n, z, x, d, w, wp, dp, sumdp, ztot, xcounts0, xcounts1, order, probs, startcond );
  
  /* convert the full wp matrix into a sparse matrix */
  nzmaxwp = 0;
  for (i=0; i<W; i++) {
     for (j=0; j<T2; j++)
         nzmaxwp += (int) ( *( wp + j + i*T2 )) > 0;
  }  
  /*if (OUTPUT==2) {
      mexPrintf( "Constructing sparse output matrix wp\n" );
      mexPrintf( "Number of nonzero entries for WP = %d\n" , nzmaxwp );
  }*/
  
  // MAKE THE WP SPARSE MATRIX
  plhs[0] = mxCreateSparse( W,T2,nzmaxwp,mxREAL);
  srwp  = mxGetPr(plhs[0]);
  irwp = mxGetIr(plhs[0]);
  jcwp = mxGetJc(plhs[0]);  
  n = 0;
  for (j=0; j<T2; j++) {
      *( jcwp + j ) = n;
      for (i=0; i<W; i++) {
         c = (int) *( wp + i*T2 + j );
         if (c >0) {
             *( srwp + n ) = c;
             *( irwp + n ) = i;
             n++;
         }
      }    
  }  
  *( jcwp + T2 ) = n;    
   
  // MAKE THE DP SPARSE MATRIX
  nzmaxdp = 0;
  for (i=0; i<D; i++) {
      for (j=0; j<(T+1); j++)
          nzmaxdp += (int) ( *( dp + j + i*(T+1) )) > 0;
  }  
  /*if (OUTPUT==2) {
      mexPrintf( "Constructing sparse output matrix dp\n" );
      mexPrintf( "Number of nonzero entries for DP = %d\n" , nzmaxdp );
  } */ 
  plhs[1] = mxCreateSparse( D,T+1,nzmaxdp,mxREAL);
  srdp  = mxGetPr(plhs[1]);
  irdp = mxGetIr(plhs[1]);
  jcdp = mxGetJc(plhs[1]);
  n = 0;
  for (j=0; j<T+1; j++) {
      *( jcdp + j ) = n;
      for (i=0; i<D; i++) {
          c = (int) *( dp + i*(T+1) + j );
          if (c >0) {
              *( srdp + n ) = c;
              *( irdp + n ) = i;
              n++;
          }
      }
  }
  *( jcdp + (T+1) ) = n;
  
  plhs[ 2 ] = mxCreateDoubleMatrix( 1,ntokens , mxREAL );
  Z = mxGetPr( plhs[ 2 ] );
  for (i=0; i<ntokens; i++) Z[ i ] = (double) z[ i ] + 1;
  
  plhs[ 3 ] = mxCreateDoubleMatrix( 1,ntokens , mxREAL );
  X = mxGetPr( plhs[ 3 ] );
  for (i=0; i<ntokens; i++) X[ i ] = (double) x[ i ];
}
void mexFunction(int nlhs, mxArray *plhs[], int nrhs,
                 const mxArray *prhs[])
{
  double *srwp, *srdp, *srtd, *probs, *Z, *WS, *DS, *ZIN;
  double ALPHA,BETA;
  mwIndex *irwp, *jcwp, *irdp, *jcdp, *irtd, *jctd;
  int *z,*d,*w, *order, *wp, *ztot;
  int W,T,D,NN,SEED,OUTPUT, nzmax, nzmaxwp, nzmaxdp, ntokens;
  int i,j,c,n,nt,wi,di, startcond, n1,n2, TAVAIL, topic;
  
  /* Check for proper number of arguments. */
  if (nrhs < 8) {
    mexErrMsgTxt("At least 8 input arguments required");
  } else if (nlhs < 3) {
    mexErrMsgTxt("3 output arguments required");
  }
  
  startcond = 0;
  if (nrhs == 9) startcond = 1;
  
  /* process the input arguments */
  if (mxIsDouble( prhs[ 0 ] ) != 1) mexErrMsgTxt("WS input vector must be a double precision matrix");
  if (mxIsDouble( prhs[ 1 ] ) != 1) mexErrMsgTxt("DS input vector must be a double precision matrix");
  
  // pointer to word indices
  WS = mxGetPr( prhs[ 0 ] );
     
  // pointer to document indices
  DS = mxGetPr( prhs[ 1 ] );
  
  // get the number of tokens
  ntokens = (int) mxGetM( prhs[ 0 ] ) * (int) mxGetN( prhs[ 0 ] );
  
  
  if (ntokens == 0) mexErrMsgTxt("WS vector is empty"); 
  if (ntokens != ( mxGetM( prhs[ 1 ] ) * mxGetN( prhs[ 1 ] ))) mexErrMsgTxt("WS and DS vectors should have same number of entries");
  
  // Input Sparse-Tag-Document Matrix
  srtd  = mxGetPr(prhs[2]);
  irtd = mxGetIr(prhs[2]);
  jctd = mxGetJc(prhs[2]);
  nzmaxdp = (int) mxGetNzmax(prhs[2]); // number of nonzero entries in tag-document matrix
   
  NN    = (int) mxGetScalar(prhs[3]);
  if (NN<0) mexErrMsgTxt("Number of iterations must be positive");
  
  ALPHA = (double) mxGetScalar(prhs[4]);
  if (ALPHA<=0) mexErrMsgTxt("ALPHA must be greater than zero");
  
  BETA = (double) mxGetScalar(prhs[5]);
  if (BETA<=0) mexErrMsgTxt("BETA must be greater than zero");
  
  SEED = (int) mxGetScalar(prhs[6]);
  
  OUTPUT = (int) mxGetScalar(prhs[7]);
  
  if (startcond == 1) {
      ZIN = mxGetPr( prhs[ 8 ] );
      if (ntokens != ( mxGetM( prhs[ 8 ] ) * mxGetN( prhs[ 8 ] ))) mexErrMsgTxt("WS and ZIN vectors should have same number of entries");
  }
  
  // seeding
  seedMT( 1 + SEED * 2 ); // seeding only works on uneven numbers
    
  /* allocate memory */
  z  = (int *) mxCalloc( ntokens , sizeof( int ));
  
  if (startcond == 1) {
     for (i=0; i<ntokens; i++) z[ i ] = (int) ZIN[ i ] - 1;   
  }
  
  d  = (int *) mxCalloc( ntokens , sizeof( int ));
  w  = (int *) mxCalloc( ntokens , sizeof( int ));
  order  = (int *) mxCalloc( ntokens , sizeof( int ));  
  
  
  // copy over the word and document indices into internal format
  for (i=0; i<ntokens; i++) {
     w[ i ] = (int) WS[ i ] - 1;
     d[ i ] = (int) DS[ i ] - 1;
  }
  
  n = ntokens;
  
  W = 0;
  D = 0;
  for (i=0; i<n; i++) {
     if (w[ i ] > W) W = w[ i ];
     if (d[ i ] > D) D = d[ i ];
  }
  W = W + 1;
  D = D + 1;
   
  // Number of topics is based on number of tags in sparse tag-document matrix 
  T  = (int) mxGetM( prhs[ 2 ] );
  
  // check number of docs in sparse tag-document matrix
  if (D != (int) mxGetN( prhs[ 2 ])) mexErrMsgTxt("Mismatch in number of documents in DS vector TAGSET sparse matrix");
  
  ztot  = (int *) mxCalloc( T , sizeof( int ));
  probs  = (double *) mxCalloc( T , sizeof( double ));
  wp  = (int *) mxCalloc( T*W , sizeof( int ));
  

  // create sparse DP matrix that is T x D and not the other way around !!!!
  plhs[1] = mxCreateSparse( T,D,nzmaxdp,mxREAL);
  srdp  = mxGetPr(plhs[1]);
  irdp = mxGetIr(plhs[1]);
  jcdp = mxGetJc(plhs[1]);
  
  // now copy the structure from TD over to DP
  for (i=0; i<D; i++) {
     n1 = (int) *( jctd + i     );
     n2 = (int) *( jctd + i + 1 );
     
     // copy over the row-index start and end indices
     *( jcdp + i ) = (int) n1;
    
     // number of available topics for this document
     TAVAIL = (n2-n1);     
     for (j = 0; j < TAVAIL; j++) {
         topic = (int) *( irtd + n1 + j );
         *( irdp + n1 + j ) = topic;
         *( srdp + n1 + j ) = 0; // initialize DP counts with ZERO
     }        
  }
  // copy over final column indices
  n1 = (int) *( jctd + D     );
  *( jcdp + D ) = (int) n1;
  

  if (OUTPUT==2) {
      mexPrintf( "Running LDA Gibbs Sampler Version 1.0\n" );
      if (startcond==1) mexPrintf( "Starting from previous state ZIN\n" );
      mexPrintf( "Arguments:\n" );
      mexPrintf( "\tNumber of words      W = %d\n"    , W );
      mexPrintf( "\tNumber of docs       D = %d\n"    , D );
      mexPrintf( "\tNumber of tags       T = %d\n"    , T );
      mexPrintf( "\tNumber of iterations N = %d\n"    , NN );
      mexPrintf( "\tHyperparameter   ALPHA = %4.4f\n" , ALPHA );
      mexPrintf( "\tHyperparameter    BETA = %4.4f\n" , BETA );
      mexPrintf( "\tSeed number            = %d\n"    , SEED );
      mexPrintf( "\tNumber of tokens       = %d\n"    , ntokens );
      mexPrintf( "\tNumber of nonzeros in tag matrix  = %d\n"    , nzmaxdp );
  }
  
  /* run the model */
  GibbsSamplerLDA( ALPHA, BETA, W, T, D, NN, OUTPUT, n, z, d, w, wp, ztot, order, probs, startcond,
                   irtd, jctd, srdp, irdp, jcdp );
  
  /* convert the full wp matrix into a sparse matrix */
  nzmaxwp = 0;
  for (i=0; i<W; i++) {
     for (j=0; j<T; j++)
         nzmaxwp += (int) ( *( wp + j + i*T )) > 0;
  }  
  
  // MAKE THE WP SPARSE MATRIX
  plhs[0] = mxCreateSparse( W,T,nzmaxwp,mxREAL);
  srwp  = mxGetPr(plhs[0]);
  irwp = mxGetIr(plhs[0]);
  jcwp = mxGetJc(plhs[0]);  
  n = 0;
  for (j=0; j<T; j++) {
      *( jcwp + j ) = n;
      for (i=0; i<W; i++) {
         c = (int) *( wp + i*T + j );
         if (c >0) {
             *( srwp + n ) = c;
             *( irwp + n ) = i;
             n++;
         }
      }    
  }  
  *( jcwp + T ) = n;    
   
  plhs[ 2 ] = mxCreateDoubleMatrix( 1,ntokens , mxREAL );
  Z = mxGetPr( plhs[ 2 ] );
  for (i=0; i<ntokens; i++) Z[ i ] = (double) z[ i ] + 1;
}
Exemplo n.º 25
0
int main(void)
{
    int i;

    for (i=0; i < NUM_PEERS; i++)
        rakPeer[i]=RakNetworkFactory::GetRakPeerInterface();

    printf("This project tests and demonstrates the fully connected mesh plugin.\n");
    printf("No data is actually sent so it's mostly a sample of how to use a plugin.\n");
    printf("Difficulty: Beginner\n\n");

    int peerIndex;

    // Initialize the message handlers
    for (peerIndex=0; peerIndex < NUM_PEERS; peerIndex++)
    {
//		fullyConnectedMeshPlugin[peerIndex].Startup(0,0);
        rakPeer[peerIndex]->AttachPlugin(&fullyConnectedMeshPlugin[peerIndex]);
        // The fully connected mesh relies on the connection graph plugin also being attached
        rakPeer[peerIndex]->AttachPlugin(&connectionGraphPlugin[peerIndex]);
        rakPeer[peerIndex]->SetMaximumIncomingConnections(NUM_PEERS);
    }

    // Initialize the peers
    for (peerIndex=0; peerIndex < NUM_PEERS; peerIndex++)
    {
        SocketDescriptor socketDescriptor(60000+peerIndex,0);
        rakPeer[peerIndex]->Startup(NUM_PEERS, 0, &socketDescriptor, 1);
    }

    // Give the threads time to properly start
    RakSleep(200);

    printf("Peers initialized. ");
    printf("Connecting each peer to the prior peer\n");

    // Connect each peer to the prior peer
    for (peerIndex=1; peerIndex < NUM_PEERS; peerIndex++)
    {
        rakPeer[peerIndex]->Connect("127.0.0.1", 60000+peerIndex-1, 0, 0);
    }

    PrintConnections();

    // Close all connections
    for (peerIndex=0; peerIndex < NUM_PEERS; peerIndex++)
    {
        rakPeer[peerIndex]->Shutdown(100);
    }

    // Reinitialize the peers
    for (peerIndex=0; peerIndex < NUM_PEERS; peerIndex++)
    {
        SocketDescriptor socketDescriptor(60000+peerIndex,0);
        rakPeer[peerIndex]->Startup(NUM_PEERS, 0,&socketDescriptor, 1 );
    }

    printf("Connecting each peer to a central peer.\n");
    // Connect each peer to a central peer
    for (peerIndex=1; peerIndex < NUM_PEERS; peerIndex++)
    {
        rakPeer[peerIndex]->Connect("127.0.0.1", 60000, 0, 0);
    }

    PrintConnections();

    // Close all connections
    for (peerIndex=0; peerIndex < NUM_PEERS; peerIndex++)
    {
        rakPeer[peerIndex]->Shutdown(100);
    }

    // Reinitialize the peers
    for (peerIndex=0; peerIndex < NUM_PEERS; peerIndex++)
    {
        SocketDescriptor socketDescriptor(60000+peerIndex,0);
        rakPeer[peerIndex]->Startup(NUM_PEERS, 0, &socketDescriptor, 1);
    }

    printf("Cross connecting each pair of peers, then first and last peer.\n");
    // Connect each peer to a central peer
    for (peerIndex=0; peerIndex < NUM_PEERS; peerIndex++)
    {
        rakPeer[peerIndex]->Connect("127.0.0.1", 60000+peerIndex+(((peerIndex%2)==0) ? 1 : -1), 0, 0);
    }

    printf("Pairs Connected\n");
    PrintConnections();
    rakPeer[0]->Connect("127.0.0.1", 60000+NUM_PEERS-1, 0, 0);
    printf("First and last connected\n");
    PrintConnections();

    // Close all connections
    for (peerIndex=0; peerIndex < NUM_PEERS; peerIndex++)
    {
        rakPeer[peerIndex]->Shutdown(100);
    }

    // Reinitialize the peers
    for (peerIndex=0; peerIndex < NUM_PEERS; peerIndex++)
    {
        SocketDescriptor socketDescriptor(60000+peerIndex,0);
        rakPeer[peerIndex]->Startup(NUM_PEERS, 0, &socketDescriptor, 1);
    }


    unsigned int seed = (unsigned int) RakNet::GetTime();
    seedMT(seed);
    printf("Connecting each peer to a random peer with seed %u.\n", seed);
    int connectTo=0;
    // Connect each peer to a central peer
    for (peerIndex=0; peerIndex < NUM_PEERS; peerIndex++)
    {
        do
        {
            connectTo=randomMT() % NUM_PEERS;
        } while (connectTo==peerIndex);

        rakPeer[peerIndex]->Connect("127.0.0.1", 60000+connectTo, 0, 0);
    }

    PrintConnections();

    for (i=0; i < NUM_PEERS; i++)
        RakNetworkFactory::DestroyRakPeerInterface(rakPeer[i]);

    return 1;
}
/* Seeds the Mersenne Twister generator by forwarding `seed` to seedMT. */
void inicializar_seed(long seed)
{
    seedMT(seed);
}
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) 
{
	double *srwd, *srad, *MUZIN, *MUXIN, *theta, *phi, *thetad, *muz, *mux;
	double ALPHA,BETA;	
	int W, J, D, A, MA = 0, NN, SEED, OUTPUT, nzmaxwd, nzmaxad, i, j, a, startcond;
	mwIndex *irwd, *jcwd, *irad, *jcad;

	/* Check for proper number of arguments. */
	if (nrhs < 8) {
		mexErrMsgTxt("At least 8 input arguments required");
	} else if (nlhs < 1) {
		mexErrMsgTxt("At least 1 output arguments required");
	}

	startcond = 0;
	if (nrhs > 8) startcond = 1;

	/* dealing with sparse array WD */
	if (mxIsDouble(prhs[0]) != 1) mexErrMsgTxt("WD must be a double precision matrix");
	srwd = mxGetPr(prhs[0]);
	irwd = mxGetIr(prhs[0]);
	jcwd = mxGetJc(prhs[0]);
	nzmaxwd = (int) mxGetNzmax(prhs[0]);
	W = (int) mxGetM(prhs[0]);
	D = (int) mxGetN(prhs[0]);

	/* dealing with sparse array AD */
	if (mxIsDouble(prhs[1]) != 1) mexErrMsgTxt("AD must be a double precision matrix");
	srad = mxGetPr(prhs[1]);
	irad = mxGetIr(prhs[1]);
	jcad = mxGetJc(prhs[1]);
	nzmaxad = (int) mxGetNzmax(prhs[1]);
	A = (int) mxGetM(prhs[1]);
	if ((int) mxGetN(prhs[1]) != D) mexErrMsgTxt("WD and AD must have the same number of columns");

	/* check that every document has some authors */
	for (i=0; i<D; i++) {
		if ((jcad[i + 1] - jcad[i]) == 0) mexErrMsgTxt("there are some documents without authors in AD matrix ");
		if ((jcad[i + 1] - jcad[i]) > NAMAX) mexErrMsgTxt("Too many authors in some documents ... reached the NAMAX limit");
		if ((jcad[i + 1] - jcad[i]) > MA) MA = (int) (jcad[i + 1] - jcad[i]);
	}

	phi = mxGetPr(prhs[2]);
	J = (int) mxGetM(prhs[2]);
	if (J<=0) mexErrMsgTxt("Number of topics must be greater than zero");
	if ((int) mxGetN(prhs[2]) != W) mexErrMsgTxt("Vocabulary mismatches");

	NN = (int) mxGetScalar(prhs[3]);
	if (NN<0) mexErrMsgTxt("Number of iterations must be greater than zero");

	ALPHA = (double) mxGetScalar(prhs[4]);
	if (ALPHA<0) mexErrMsgTxt("ALPHA must be greater than zero");

	BETA = (double) mxGetScalar(prhs[5]);
	if (BETA<0) mexErrMsgTxt("BETA must be greater than zero");

	SEED = (int) mxGetScalar(prhs[6]);
	// set the seed of the random number generator

	OUTPUT = (int) mxGetScalar(prhs[7]);

	if (startcond == 1) {
		MUZIN = mxGetPr(prhs[8]);
		if (nzmaxwd != mxGetN(prhs[8])) mexErrMsgTxt("WD and MUZIN mismatch");
		if (J != mxGetM( prhs[ 8 ])) mexErrMsgTxt("J and MUZIN mismatch");
		MUXIN = mxGetPr(prhs[9]);
		if (nzmaxwd != mxGetN( prhs[9])) mexErrMsgTxt("WD and MUXIN mismatch");
		if (MA != mxGetM(prhs[9])) mexErrMsgTxt("MA and MUXIN mismatch");
	}

	// seeding
	seedMT( 1 + SEED * 2 ); // seeding only works on uneven numbers

	/* allocate memory */
	muz  = dvec(J*nzmaxwd);
	mux  = dvec(MA*nzmaxwd);

	if (startcond == 1) {
		for (i=0; i<J*nzmaxwd; i++) muz[i] = (double) MUZIN[i]; 
		for (a=0; a<MA*nzmaxwd; a++) mux[i] = (double) MUXIN[i];
	}

	theta = dvec(J*A);

	/* run the model */
	ATMBP( ALPHA, BETA, W, J, D, A, MA, NN, OUTPUT, irwd, jcwd, srwd, irad, jcad, muz, mux, phi, theta, startcond );

	/* output */
	plhs[0] = mxCreateDoubleMatrix(J, A, mxREAL);
	mxSetPr(plhs[0], theta);

	plhs[1] = mxCreateDoubleMatrix(J, nzmaxwd, mxREAL);
	mxSetPr(plhs[1], muz);

	plhs[2] = mxCreateDoubleMatrix(MA, nzmaxwd, mxREAL);
	mxSetPr(plhs[2], mux);
}
Exemplo n.º 28
0
/*
 * Randomized stress test for RakPeer.
 *
 * Starts NUM_PEERS peers on localhost ports 60000 .. 60000+NUM_PEERS-1 and
 * connects each one to a random port. Then, for about 10 seconds
 * (RakNet::GetTime() based), it repeatedly picks a random action weighted by
 * frandomMT(): restart+connect, connect, print connection list, send,
 * build an RPC payload, close a connection, offline ping, online ping,
 * toggle the frequency table, or print statistics. After each action it
 * drains one packet from every peer.
 *
 * params     Not used by this function.
 * isVerbose  When true, diagnostics are printed with printf.
 * noPauses   Passed to DebugTools::ShowError as (!noPauses && isVerbose).
 * Returns 0 when the timed loop completes, 1 if any Connect() call fails.
 *
 * NOTE(review): the peers created here are not destroyed in this function;
 * presumably the owning test class releases peers[] elsewhere - confirm.
 */
int ComprehensiveConvertTest::RunTest(DataStructures::List<RakNet::RakString> params,bool isVerbose,bool noPauses)
{
	static const int CONNECTIONS_PER_SYSTEM =4;

	SystemAddress currentSystem;

	//	DebugTools::ShowError("Note: The conversion of this is on hold until the original sample's problem is known.",!noPauses && isVerbose,__LINE__,__FILE__);
	//	return 55;

	//	AutoRPC autoRpcs[NUM_PEERS];

	// Peer the current random action operates on. Initialized so that no
	// branch can ever index peers[] with an indeterminate value.
	int peerIndex = 0;
	float nextAction;
	int i;
	int portAdd;

	// Scratch buffer for outgoing payloads and for statistics text.
	char data[8096];

	// Fixed seed: the sequence of random actions is identical on every run.
	int seed = 12345;
	if (isVerbose)
		printf("Using seed %i\n", seed);
	seedMT(seed);

	// Start every peer on its own localhost port (60000 + i).
	for (i=0; i < NUM_PEERS; i++)
	{
		//autoRpcs[i].RegisterFunction("RPC1", RPC1, false);
		//autoRpcs[i].RegisterFunction("RPC2", RPC2, false);
		//autoRpcs[i].RegisterFunction("RPC3", RPC3, false);
		//autoRpcs[i].RegisterFunction("RPC4", RPC4, false);
		peers[i]=RakNetworkFactory::GetRakPeerInterface();
		peers[i]->SetMaximumIncomingConnections(CONNECTIONS_PER_SYSTEM);
		SocketDescriptor socketDescriptor(60000+i, 0);
		peers[i]->Startup(NUM_PEERS, 0, &socketDescriptor, 1);
		peers[i]->SetOfflinePingResponse("Offline Ping Data", (int)strlen("Offline Ping Data")+1);
		peers[i]->ApplyNetworkSimulator(500,50,50);

		//		peers[i]->AttachPlugin(&autoRpc[i]);
	}

	// Connect every peer to one randomly chosen peer port (possibly its own).
	for (i=0; i < NUM_PEERS; i++)
	{
		portAdd=randomMT()%NUM_PEERS;

		currentSystem.SetBinaryAddress("127.0.0.1");
		currentSystem.port=60000+portAdd;
		if(!peers[i]->IsConnected (currentSystem,true,true) )//Are we connected or is there a pending operation ?
		{
			if (!peers[i]->Connect("127.0.0.1", 60000+portAdd, 0, 0))
			{
				DebugTools::ShowError("Problem while calling connect.\n",!noPauses && isVerbose,__LINE__,__FILE__);
				return 1;
			}
		}
	}

	RakNetTime endTime = RakNet::GetTime()+10000;
	while (RakNet::GetTime()<endTime)
	{
		// nextAction lies in [0,1); each threshold below selects one action.
		nextAction = frandomMT();

		if (nextAction < .04f)
		{
			// Initialize: (re)start a random peer and connect it to a random port.
			peerIndex=randomMT()%NUM_PEERS;
			SocketDescriptor socketDescriptor(60000+peerIndex, 0);
			peers[peerIndex]->Startup(NUM_PEERS, randomMT()%30, &socketDescriptor, 1);
			portAdd=randomMT()%NUM_PEERS;

			currentSystem.SetBinaryAddress("127.0.0.1");
			currentSystem.port=60000+portAdd;
			if(!peers[peerIndex]->IsConnected (currentSystem,true,true) )//Are we connected or is there a pending operation ?
			{
				if(!peers[peerIndex]->Connect("127.0.0.1", 60000+portAdd, 0, 0))
				{
					DebugTools::ShowError("Problem while calling connect.\n",!noPauses && isVerbose,__LINE__,__FILE__);
					return 1;
				}
			}
		}
		else if (nextAction < .09f)
		{
			// Connect a random peer to a random port.
			peerIndex=randomMT()%NUM_PEERS;
			portAdd=randomMT()%NUM_PEERS;

			currentSystem.SetBinaryAddress("127.0.0.1");
			currentSystem.port=60000+portAdd;
			if(!peers[peerIndex]->IsConnected (currentSystem,true,true) )//Are we connected or is there a pending operation ?
			{
				if (!peers[peerIndex]->Connect("127.0.0.1", 60000+portAdd, 0, 0))
				{
					DebugTools::ShowError("Problem while calling connect.\n",!noPauses && isVerbose,__LINE__,__FILE__);
					return 1;
				}
			}
		}
		else if (nextAction < .10f)
		{
			// Disconnect (currently disabled).
			peerIndex=randomMT()%NUM_PEERS;
			//	peers[peerIndex]->Shutdown(randomMT() % 100);
		}
		else if (nextAction < .12f)
		{
			// GetConnectionList: print "<peer port>: <remote port>: ..." for a random peer.
			peerIndex=randomMT()%NUM_PEERS;
			SystemAddress remoteSystems[NUM_PEERS];
			unsigned short numSystems=NUM_PEERS;
			peers[peerIndex]->GetConnectionList(remoteSystems, &numSystems);
			if (numSystems>0 && isVerbose)
			{
				// Label with the queried peer's own port (previously this printed
				// 60000+numSystems, mixing up the connection count with a port).
				printf("%i: ", 60000+peerIndex);
				for (unsigned short s=0; s < numSystems; s++)
				{
					printf("%i: ", remoteSystems[s].port);
				}
				printf("\n");
			}
		}
		else if (nextAction < .14f)
		{
			// Send: a random peer sends a random-length message with random
			// priority/reliability/channel to one of its connections, or to
			// UNASSIGNED_SYSTEM_ADDRESS.
			int dataLength;
			PacketPriority priority;
			PacketReliability reliability;
			unsigned char orderingChannel;
			SystemAddress target;
			bool broadcast;

			// Pick the sender first so the target comes from the sender's own
			// connection list (the old code used a stale peerIndex here).
			peerIndex=randomMT()%NUM_PEERS;

			//	data[0]=ID_RESERVED1+(randomMT()%10);
			data[0]=ID_USER_PACKET_ENUM;
			dataLength=3+(randomMT()%8000);
			//			dataLength=600+(randomMT()%7000);
			priority=(PacketPriority)(randomMT()%(int)NUMBER_OF_PRIORITIES);
			reliability=(PacketReliability)(randomMT()%((int)RELIABLE_SEQUENCED+1));
			orderingChannel=randomMT()%32;
			if ((randomMT()%NUM_PEERS)==0)
				target=UNASSIGNED_SYSTEM_ADDRESS;
			else
				target=peers[peerIndex]->GetSystemAddressFromIndex(randomMT()%NUM_PEERS);

			broadcast=(bool)(randomMT()%2);
#ifdef _VERIFY_RECIPIENTS
			broadcast=false; // Temporarily in so I can check recipients
#endif

			sprintf(data+3, "dataLength=%i priority=%i reliability=%i orderingChannel=%i target=%i broadcast=%i\n", dataLength, priority, reliability, orderingChannel, target.port, broadcast);
			//unsigned short localPort=60000+i;
#ifdef _VERIFY_RECIPIENTS
			memcpy((char*)data+1, (char*)&target.port, sizeof(unsigned short));
#endif
			// Terminate at the chosen length; the description text may be cut short.
			data[dataLength-1]=0;
			peers[peerIndex]->Send(data, dataLength, priority, reliability, orderingChannel, target, broadcast);
		}
		else if (nextAction < .18f)
		{
			// RPC: builds a payload and an RPC name; the actual call is disabled.
			int dataLength;
			PacketPriority priority;
			PacketReliability reliability;
			unsigned char orderingChannel;
			SystemAddress target;
			bool broadcast;
			char RPCName[10];

			data[0]=ID_USER_PACKET_ENUM+(randomMT()%10);
			dataLength=3+(randomMT()%8000);
			//			dataLength=600+(randomMT()%7000);
			priority=(PacketPriority)(randomMT()%(int)NUMBER_OF_PRIORITIES);
			reliability=(PacketReliability)(randomMT()%((int)RELIABLE_SEQUENCED+1));
			orderingChannel=randomMT()%32;
			peerIndex=randomMT()%NUM_PEERS;
			if ((randomMT()%NUM_PEERS)==0)
				target=UNASSIGNED_SYSTEM_ADDRESS;
			else
				target=peers[peerIndex]->GetSystemAddressFromIndex(randomMT()%NUM_PEERS);
			broadcast=(bool)(randomMT()%2);
#ifdef _VERIFY_RECIPIENTS
			broadcast=false; // Temporarily in so I can check recipients
#endif

			sprintf(data+3, "dataLength=%i priority=%i reliability=%i orderingChannel=%i target=%i broadcast=%i\n", dataLength, priority, reliability, orderingChannel, target.port, broadcast);
#ifdef _VERIFY_RECIPIENTS
			memcpy((char*)data, (char*)&target.port, sizeof(unsigned short));
#endif
			data[dataLength-1]=0;
			sprintf(RPCName, "RPC%i", (int)((randomMT()%4)+1));
			//				autoRpc[i]->Call(RPCName);
			//peers[peerIndex]->RPC(RPCName, data, dataLength*8, priority, reliability, orderingChannel, target, broadcast, 0, UNASSIGNED_NETWORK_ID,0);
		}
		else if (nextAction < .181f)
		{
			// CloseConnection: a random peer closes one of its connections.
			SystemAddress target;
			peerIndex=randomMT()%NUM_PEERS;
			target=peers[peerIndex]->GetSystemAddressFromIndex(randomMT()%NUM_PEERS);
			peers[peerIndex]->CloseConnection(target, (bool)(randomMT()%2), 0);
		}
		else if (nextAction < .20f)
		{
			// Offline Ping of a random peer port.
			peerIndex=randomMT()%NUM_PEERS;
			peers[peerIndex]->Ping("127.0.0.1", 60000+(randomMT()%NUM_PEERS), (bool)(randomMT()%2));
		}
		else if (nextAction < .21f)
		{
			// Online Ping: a random peer pings one of its own connections.
			SystemAddress target;
			peerIndex=randomMT()%NUM_PEERS;
			target=peers[peerIndex]->GetSystemAddressFromIndex(randomMT()%NUM_PEERS);
			peers[peerIndex]->Ping(target);
		}
		else if (nextAction < .24f)
		{
			// SetCompileFrequencyTable on a random peer.
			peerIndex=randomMT()%NUM_PEERS;
			peers[peerIndex]->SetCompileFrequencyTable(randomMT()%2);
		}
		else if (nextAction < .25f)
		{
			// GetStatistics: a random peer reports stats for its own internal
			// address and for one of its connections. The peer is chosen before
			// the addresses are read so both queries refer to that same peer.
			SystemAddress target, mySystemAddress;
			RakNetStatistics *rss;
			peerIndex=randomMT()%NUM_PEERS;
			mySystemAddress=peers[peerIndex]->GetInternalID();
			target=peers[peerIndex]->GetSystemAddressFromIndex(randomMT()%NUM_PEERS);
			rss=peers[peerIndex]->GetStatistics(mySystemAddress);
			if (rss)
			{
				StatisticsToString(rss, data, 0);
				if (isVerbose)
					printf("Statistics for local system %i:\n%s", mySystemAddress.port, data);
			}

			rss=peers[peerIndex]->GetStatistics(target);
			if (rss)
			{
				StatisticsToString(rss, data, 0);
				if (isVerbose)
					printf("Statistics for target system %i:\n%s", target.port, data);
			}
		}

		// Drain one packet from every peer so incoming queues keep moving.
		for (i=0; i < NUM_PEERS; i++)
			peers[i]->DeallocatePacket(peers[i]->Receive());

		RakSleep(0);
	}

	return 0;
}
Exemplo n.º 29
0
void classRF(double *x, int *dimx, int *cl, int *ncl, int *cat, int *maxcat,
        int *sampsize, int *strata, int *Options, int *ntree, int *nvar,
        int *ipi, double *classwt, double *cut, int *nodesize,
        int *outcl, int *counttr, double *prox,
        double *imprt, double *impsd, double *impmat, int *nrnodes,
        int *ndbigtree, int *nodestatus, int *bestvar, int *treemap,
        int *nodeclass, double *xbestsplit, double *errtr,
        int *testdat, double *xts, int *clts, int *nts, double *countts,
        int *outclts, int labelts, double *proxts, double *errts,
        int *inbag) {
    /******************************************************************
     *  C wrapper for random forests:  get input from R and drive
     *  the Fortran routines.
     *
     *  Input:
     *
     *  x:        matrix of predictors (transposed!)
     *  dimx:     two integers: number of variables and number of cases
     *  cl:       class labels of the data (1-based)
     *  ncl:      number of classes in the response
     *  cat:      integer vector of number of classes in the predictor;
     *            1=continuous
     * maxcat:    maximum of cat
     * sampsize:  cases drawn per tree (one entry per stratum when
     *            stratified, otherwise a single total)
     * strata:    1-based stratum id of each case (used when stratifying)
     * Options:   10 integers: (0=no, 1=yes)
     *     add a second class (for unsupervised RF)?
     *         1: sampling from product of marginals
     *         2: sampling from product of uniforms
     *     assess variable importance?
     *     calculate casewise (local) variable importance?
     *     calculate proximity?
     *     calculate proximity based on OOB predictions?
     *     how often to print output? (0 = never)
     *     keep the forest for future prediction?
     *     sample with replacement?
     *     stratified sampling?
     *     keep the in-bag indicators?
     *  ntree:    number of trees
     *  nvar:     number of predictors to use for each split
     *  ipi:      0=use class proportion as prob.; 1=use supplied priors
     *  pi:       double vector of class priors
     *  nodesize: minimum node size: no node with fewer than ndsize
     *            cases will be split
     *
     *  Output:
     *
     *  outcl:    class predicted by RF
     *  counttr:  matrix of votes (transposed!)
     *  imprt:    matrix of variable importance measures
     *  impmat:   matrix of local variable importance measures
     *  prox:     matrix of proximity (if iprox=1)
     *  inbag:    per-tree in-bag indicators (if keepInbag=1)
     ******************************************************************/
    
    int nsample0, mdim, nclass, addClass, mtry, ntest, nsample, ndsize,
            mimp, nimp, near, nuse, noutall, nrightall, nrightimpall,
            keepInbag, nstrata;
    int jb, j, n, m, k, idxByNnode, idxByNsample, imp, localImp, iprox,
            oobprox, keepf, replace, stratify, trace, *nright,
            *nrightimp, *nout, *nclts, Ntree;
    
    int *out, *bestsplitnext, *bestsplit, *nodepop, *jin, *nodex,
            *nodexts, *nodestart, *ta, *ncase, *jerr, *varUsed,
            *jtr, *classFreq, *idmove, *jvr,
            *at, *a, *b, *mind, *nind, *jts, *oobpair;
    int **strata_idx, *strata_size, last, ktmp, anyEmpty, ntry;
    
    double av=0.0;
    
    double *tgini, *tx, *wl, *classpop, *tclasscat, *tclasspop, *win,
            *tp, *wr;
    
    /* Buffers that are only allocated for some option combinations start
     * out NULL, so they are never passed to helpers (or freed) while
     * indeterminate. */
    oobpair = NULL;
    nclts = NULL;
    nind = NULL;
    strata_idx = NULL;
    strata_size = NULL;
    nstrata = 0;
    
    //Do initialization for COKUS's Random generator
    seedMT(2*rand()+1);  //works well with odd number so why don't use that
    
    addClass = Options[0];
    imp      = Options[1];
    localImp = Options[2];
    iprox    = Options[3];
    oobprox  = Options[4];
    trace    = Options[5];
    keepf    = Options[6];
    replace  = Options[7];
    stratify = Options[8];
    keepInbag = Options[9];
    mdim     = dimx[0];
    nsample0 = dimx[1];
    nclass   = (*ncl==1) ? 2 : *ncl;
    ndsize   = *nodesize;
    Ntree    = *ntree;
    mtry     = *nvar;
    ntest    = *nts;
    /* With addClass a synthetic second class of the same size is appended. */
    nsample = addClass ? (nsample0 + nsample0) : nsample0;
    mimp = imp ? mdim : 1;
    nimp = imp ? nsample : 1;
    near = iprox ? nsample0 : 1;
    /* trace == 0 means "never print": push the print interval past Ntree. */
    if (trace == 0) trace = Ntree + 1;
    
    /*printf("\nmdim %d, nclass %d, nrnodes %d, nsample %d, ntest %d\n", mdim, nclass, *nrnodes, nsample, ntest);
    printf("\noobprox %d, mdim %d, nsample0 %d, Ntree %d, mtry %d, mimp %d", oobprox, mdim, nsample0, Ntree, mtry, mimp);
    printf("\nstratify %d, replace %d",stratify,replace);
    printf("\n");*/
    tgini =      (double *) S_alloc_alt(mdim, sizeof(double));
    wl =         (double *) S_alloc_alt(nclass, sizeof(double));
    wr =         (double *) S_alloc_alt(nclass, sizeof(double));
    classpop =   (double *) S_alloc_alt(nclass* *nrnodes, sizeof(double));
    tclasscat =  (double *) S_alloc_alt(nclass*32, sizeof(double));
    tclasspop =  (double *) S_alloc_alt(nclass, sizeof(double));
    tx =         (double *) S_alloc_alt(nsample, sizeof(double));
    win =        (double *) S_alloc_alt(nsample, sizeof(double));
    tp =         (double *) S_alloc_alt(nsample, sizeof(double));
    
    out =           (int *) S_alloc_alt(nsample, sizeof(int));
    bestsplitnext = (int *) S_alloc_alt(*nrnodes, sizeof(int));
    bestsplit =     (int *) S_alloc_alt(*nrnodes, sizeof(int));
    nodepop =       (int *) S_alloc_alt(*nrnodes, sizeof(int));
    nodestart =     (int *) S_alloc_alt(*nrnodes, sizeof(int));
    jin =           (int *) S_alloc_alt(nsample, sizeof(int));
    nodex =         (int *) S_alloc_alt(nsample, sizeof(int));
    nodexts =       (int *) S_alloc_alt(ntest, sizeof(int));
    ta =            (int *) S_alloc_alt(nsample, sizeof(int));
    ncase =         (int *) S_alloc_alt(nsample, sizeof(int));
    jerr =          (int *) S_alloc_alt(nsample, sizeof(int));
    varUsed =       (int *) S_alloc_alt(mdim, sizeof(int));
    jtr =           (int *) S_alloc_alt(nsample, sizeof(int));
    jvr =           (int *) S_alloc_alt(nsample, sizeof(int));
    classFreq =     (int *) S_alloc_alt(nclass, sizeof(int));
    jts =           (int *) S_alloc_alt(ntest, sizeof(int));
    idmove =        (int *) S_alloc_alt(nsample, sizeof(int));
    at =            (int *) S_alloc_alt(mdim*nsample, sizeof(int));
    a =             (int *) S_alloc_alt(mdim*nsample, sizeof(int));
    b =             (int *) S_alloc_alt(mdim*nsample, sizeof(int));
    mind =          (int *) S_alloc_alt(mdim, sizeof(int));
    nright =        (int *) S_alloc_alt(nclass, sizeof(int));
    nrightimp =     (int *) S_alloc_alt(nclass, sizeof(int));
    nout =          (int *) S_alloc_alt(nclass, sizeof(int));
    if (oobprox) {
        oobpair = (int *) S_alloc_alt(near*near, sizeof(int));
    }
    //printf("nsample=%d\n", nsample);
    /* Count number of cases in each class. */
    zeroInt(classFreq, nclass);
    for (n = 0; n < nsample; ++n) classFreq[cl[n] - 1] ++;
    /* Normalize class weights. */
    //Rprintf("ipi %d ",*ipi);
    //for(n=0;n<nclass;n++) Rprintf("%d: %d, %f,",n,classFreq[n],classwt[n]);
    normClassWt(cl, nsample, nclass, *ipi, classwt, classFreq);
    //for(n=0;n<nclass;n++) Rprintf("%d: %d, %f,",n,classFreq[n],classwt[n]);
   
    if (stratify) {
        /* Count number of strata and frequency of each stratum. */
        nstrata = 0;
        for (n = 0; n < nsample0; ++n)
            if (strata[n] > nstrata) nstrata = strata[n];
        /* Create the array of pointers, each pointing to a vector
         * of indices of where data of each stratum is. */
        strata_size = (int  *) S_alloc_alt(nstrata, sizeof(int));
        for (n = 0; n < nsample0; ++n) {
            strata_size[strata[n] - 1] ++;
        }
        strata_idx =  (int **) S_alloc_alt(nstrata, sizeof(int *));
        for (n = 0; n < nstrata; ++n) {
            strata_idx[n] = (int *) S_alloc_alt(strata_size[n], sizeof(int));
        }
        zeroInt(strata_size, nstrata);
        for (n = 0; n < nsample0; ++n) {
            strata_size[strata[n] - 1] ++;
            strata_idx[strata[n] - 1][strata_size[strata[n] - 1] - 1] = n;
        }
    } else {
        /* Index scratch array is only needed for sampling w/o replacement. */
        nind = replace ? NULL : (int *) S_alloc_alt(nsample, sizeof(int));
    }
    
    /*    INITIALIZE FOR RUN */
    if (*testdat) zeroDouble(countts, ntest * nclass);
    zeroInt(counttr, nclass * nsample);
    zeroInt(out, nsample);
    zeroDouble(tgini, mdim);
    zeroDouble(errtr, (nclass + 1) * Ntree);
    
    if (labelts) {
        nclts  = (int *) S_alloc_alt(nclass, sizeof(int));
        for (n = 0; n < ntest; ++n) nclts[clts[n]-1]++;
        zeroDouble(errts, (nclass + 1) * Ntree);
    }
    //printf("labelts %d\n",labelts);fflush(stdout);
    if (imp) {
        zeroDouble(imprt, (nclass+2) * mdim);
        zeroDouble(impsd, (nclass+1) * mdim);
        if (localImp) zeroDouble(impmat, nsample * mdim);
    }
    if (iprox) {
        zeroDouble(prox, nsample0 * nsample0);
        if (*testdat) zeroDouble(proxts, ntest * (ntest + nsample0));
    }
    makeA(x, mdim, nsample, cat, at, b);
    
    //R_CheckUserInterrupt();
    
    
    /* Starting the main loop over number of trees. */
    GetRNGstate();
    if (trace <= Ntree) {
        /* Print header for running output. */
        Rprintf("ntree      OOB");
        for (n = 1; n <= nclass; ++n) Rprintf("%7i", n);
        if (labelts) {
            Rprintf("|    Test");
            for (n = 1; n <= nclass; ++n) Rprintf("%7i", n);
        }
        Rprintf("\n");
    }
    idxByNnode = 0;
    idxByNsample = 0;
    
    //Rprintf("addclass %d, ntree %d, cl[300]=%d", addClass,Ntree,cl[299]);
    for(jb = 0; jb < Ntree; jb++) {
        //Rprintf("addclass %d, ntree %d, cl[300]=%d", addClass,Ntree,cl[299]);
        //printf("jb=%d,\n",jb);
        /* Do we need to simulate data for the second class? */
        if (addClass) createClass(x, nsample0, nsample, mdim);
        do {
            zeroInt(nodestatus + idxByNnode, *nrnodes);
            zeroInt(treemap + 2*idxByNnode, 2 * *nrnodes);
            zeroDouble(xbestsplit + idxByNnode, *nrnodes);
            zeroInt(nodeclass + idxByNnode, *nrnodes);
            zeroInt(varUsed, mdim);
            /* TODO: Put all sampling code into a function. */
            /* drawSample(sampsize, nsample, ); */
            if (stratify) {  /* stratified sampling */
                zeroInt(jin, nsample);
                zeroDouble(tclasspop, nclass);
                zeroDouble(win, nsample);
                if (replace) {  /* with replacement */
                    for (n = 0; n < nstrata; ++n) {
                        for (j = 0; j < sampsize[n]; ++j) {
                            ktmp = (int) (unif_rand() * strata_size[n]);
                            k = strata_idx[n][ktmp];
                            tclasspop[cl[k] - 1] += classwt[cl[k] - 1];
                            win[k] += classwt[cl[k] - 1];
                            jin[k] = 1;
                        }
                    }
                } else { /* stratified sampling w/o replacement */
                    /* re-initialize the index array. strata and the
                     * strata_idx buckets only cover the nsample0 original
                     * cases, so iterate over nsample0 (not nsample, which is
                     * doubled when addClass is set and would overrun). */
                    zeroInt(strata_size, nstrata);
                    for (j = 0; j < nsample0; ++j) {
                        strata_size[strata[j] - 1] ++;
                        strata_idx[strata[j] - 1][strata_size[strata[j] - 1] - 1] = j;
                    }
                    /* sampling without replacement */
                    for (n = 0; n < nstrata; ++n) {
                        last = strata_size[n] - 1;
                        for (j = 0; j < sampsize[n]; ++j) {
                            ktmp = (int) (unif_rand() * (last+1));
                            k = strata_idx[n][ktmp];
                            swapInt(strata_idx[n][last], strata_idx[n][ktmp]);
                            last--;
                            tclasspop[cl[k] - 1] += classwt[cl[k]-1];
                            win[k] += classwt[cl[k]-1];
                            jin[k] = 1;
                        }
                    }
                }
            } else {  /* unstratified sampling */
                ntry = 0;
                do {
                    /* Reset per draw: a retry that contains every class must
                     * be able to end the loop. */
                    anyEmpty = 0;
                    zeroInt(jin, nsample);
                    zeroDouble(tclasspop, nclass);
                    zeroDouble(win, nsample);
                    if (replace) {
                        for (n = 0; n < *sampsize; ++n) {
                            k = unif_rand() * nsample;
                            tclasspop[cl[k] - 1] += classwt[cl[k]-1];
                            win[k] += classwt[cl[k]-1];
                            jin[k] = 1;
                        }
                    } else {
                        for (n = 0; n < nsample; ++n) nind[n] = n;
                        last = nsample - 1;
                        for (n = 0; n < *sampsize; ++n) {
                            ktmp = (int) (unif_rand() * (last+1));
                            k = nind[ktmp];
                            swapInt(nind[ktmp], nind[last]);
                            last--;
                            tclasspop[cl[k] - 1] += classwt[cl[k]-1];
                            win[k] += classwt[cl[k]-1];
                            jin[k] = 1;
                        }
                    }
                    /* check if any class is missing in the sample */
                    for (n = 0; n < nclass; ++n) {
                        if (tclasspop[n] == 0) anyEmpty = 1;
                    }
                    ntry++;
                } while (anyEmpty && ntry <= 10);
            }
            
            /* If need to keep indices of inbag data, do that here. */
            if (keepInbag) {
                for (n = 0; n < nsample0; ++n) {
                    inbag[n + idxByNsample] = jin[n];
                }
            }
            
            /* Copy the original a matrix back. */
            memcpy(a, at, sizeof(int) * mdim * nsample);
            modA(a, &nuse, nsample, mdim, cat, *maxcat, ncase, jin);
            
            #ifdef WIN64
            F77_CALL(_buildtree)
            #endif
                    
            #ifndef WIN64
            F77_CALL(buildtree)
            #endif        
            (a, b, cl, cat, maxcat, &mdim, &nsample,
                    &nclass,
                    treemap + 2*idxByNnode, bestvar + idxByNnode,
                    bestsplit, bestsplitnext, tgini,
                    nodestatus + idxByNnode, nodepop,
                    nodestart, classpop, tclasspop, tclasscat,
                    ta, nrnodes, idmove, &ndsize, ncase,
                    &mtry, varUsed, nodeclass + idxByNnode,
                    ndbigtree + jb, win, wr, wl, &mdim,
                    &nuse, mind);
            /* if the "tree" has only the root node, start over */
        } while (ndbigtree[jb] == 1);
        
        Xtranslate(x, mdim, *nrnodes, nsample, bestvar + idxByNnode,
                bestsplit, bestsplitnext, xbestsplit + idxByNnode,
                nodestatus + idxByNnode, cat, ndbigtree[jb]);
        
        /*  Get test set error */
        if (*testdat) {
            predictClassTree(xts, ntest, mdim, treemap + 2*idxByNnode,
                    nodestatus + idxByNnode, xbestsplit + idxByNnode,
                    bestvar + idxByNnode,
                    nodeclass + idxByNnode, ndbigtree[jb],
                    cat, nclass, jts, nodexts, *maxcat);
            TestSetError(countts, jts, clts, outclts, ntest, nclass, jb+1,
                    errts + jb*(nclass+1), labelts, nclts, cut);
        }
        
        /*  Get out-of-bag predictions and errors. */
        predictClassTree(x, nsample, mdim, treemap + 2*idxByNnode,
                nodestatus + idxByNnode, xbestsplit + idxByNnode,
                bestvar + idxByNnode,
                nodeclass + idxByNnode, ndbigtree[jb],
                cat, nclass, jtr, nodex, *maxcat);
        
        zeroInt(nout, nclass);
        noutall = 0;
        for (n = 0; n < nsample; ++n) {
            if (jin[n] == 0) {
                /* increment the OOB votes */
                counttr[n*nclass + jtr[n] - 1] ++;
                /* count number of times a case is OOB */
                out[n]++;
                /* count number of OOB cases in the current iteration.
                 * nout[n] is the number of OOB cases for the n-th class.
                 * noutall is the number of OOB cases overall. */
                nout[cl[n] - 1]++;
                noutall++;
            }
        }
        
        /* Compute out-of-bag error rate. */
        oob(nsample, nclass, jin, cl, jtr, jerr, counttr, out,
                errtr + jb*(nclass+1), outcl, cut);
        
        if ((jb+1) % trace == 0) {
            Rprintf("%5i: %6.2f%%", jb+1, 100.0*errtr[jb * (nclass+1)]);
            for (n = 1; n <= nclass; ++n) {
                Rprintf("%6.2f%%", 100.0 * errtr[n + jb * (nclass+1)]);
            }
            if (labelts) {
                Rprintf("| ");
                for (n = 0; n <= nclass; ++n) {
                    Rprintf("%6.2f%%", 100.0 * errts[n + jb * (nclass+1)]);
                }
            }
            Rprintf("\n");
            
            //R_CheckUserInterrupt();
        }
        
        /*  DO VARIABLE IMPORTANCE  */
        if (imp) {
            nrightall = 0;
            /* Count the number of correct prediction by the current tree
             * among the OOB samples, by class. */
            zeroInt(nright, nclass);
            for (n = 0; n < nsample; ++n) {
                /* out-of-bag and predicted correctly: */
                if (jin[n] == 0 && jtr[n] == cl[n]) {
                    nright[cl[n] - 1]++;
                    nrightall++;
                }
            }
            for (m = 0; m < mdim; ++m) {
                if (varUsed[m]) {
                    nrightimpall = 0;
                    zeroInt(nrightimp, nclass);
                    for (n = 0; n < nsample; ++n) tx[n] = x[m + n*mdim];
                    /* Permute the m-th variable. */
                    permuteOOB(m, x, jin, nsample, mdim);
                    /* Predict the modified data using the current tree. */
                    predictClassTree(x, nsample, mdim, treemap + 2*idxByNnode,
                            nodestatus + idxByNnode,
                            xbestsplit + idxByNnode,
                            bestvar + idxByNnode,
                            nodeclass + idxByNnode, ndbigtree[jb],
                            cat, nclass, jvr, nodex, *maxcat);
                    /* Count how often correct predictions are made with
                     * the modified data. */
                    for (n = 0; n < nsample; n++) {
                        if (jin[n] == 0) {
                            if (jvr[n] == cl[n]) {
                                nrightimp[cl[n] - 1]++;
                                nrightimpall++;
                            }
                            if (localImp && jvr[n] != jtr[n]) {
                                if (cl[n] == jvr[n]) {
                                    impmat[m + n*mdim] -= 1.0;
                                } else {
                                    impmat[m + n*mdim] += 1.0;
                                }
                            }
                        }
                        /* Restore the original data for that variable. */
                        x[m + n*mdim] = tx[n];
                    }
                    /* Accumulate decrease in proportions of correct
                     * predictions. */
                    for (n = 0; n < nclass; ++n) {
                        if (nout[n] > 0) {
                            imprt[m + n*mdim] +=
                                    ((double) (nright[n] - nrightimp[n])) /
                                    nout[n];
                            impsd[m + n*mdim] +=
                                    ((double) (nright[n] - nrightimp[n]) *
                                    (nright[n] - nrightimp[n])) / nout[n];
                        }
                    }
                    if (noutall > 0) {
                        imprt[m + nclass*mdim] +=
                                ((double)(nrightall - nrightimpall)) / noutall;
                        impsd[m + nclass*mdim] +=
                                ((double) (nrightall - nrightimpall) *
                                (nrightall - nrightimpall)) / noutall;
                    }
                }
            }
        }
        
        /*  DO PROXIMITIES */
        if (iprox) {
            computeProximity(prox, oobprox, nodex, jin, oobpair, near);
            /* proximity for test data */
            if (*testdat) {
                computeProximity(proxts, 0, nodexts, jin, oobpair, ntest);
                /* Compute proximity between testset and training set. */
                for (n = 0; n < ntest; ++n) {
                    for (k = 0; k < near; ++k) {
                        if (nodexts[n] == nodex[k])
                            proxts[n + ntest * (k+ntest)] += 1.0;
                    }
                }
            }
        }
        
        if (keepf) idxByNnode += *nrnodes;
        if (keepInbag) idxByNsample += nsample0;
    }
    PutRNGstate();
   
    
    /*  Final processing of variable importance. */
    for (m = 0; m < mdim; m++) tgini[m] /= Ntree;
      
    if (imp) {
        for (m = 0; m < mdim; ++m) {
            if (localImp) { /* casewise measures */
                for (n = 0; n < nsample; ++n) impmat[m + n*mdim] /= out[n];
            }
            /* class-specific measures */
            for (k = 0; k < nclass; ++k) {
                av = imprt[m + k*mdim] / Ntree;
                impsd[m + k*mdim] =
                        sqrt(((impsd[m + k*mdim] / Ntree) - av*av) / Ntree);
                imprt[m + k*mdim] = av;
                /* imprt[m + k*mdim] = (se <= 0.0) ? -1000.0 - av : av / se; */
            }
            /* overall measures */
            av = imprt[m + nclass*mdim] / Ntree;
            impsd[m + nclass*mdim] =
                    sqrt(((impsd[m + nclass*mdim] / Ntree) - av*av) / Ntree);
            imprt[m + nclass*mdim] = av;
            imprt[m + (nclass+1)*mdim] = tgini[m];
        }
    } else {
        for (m = 0; m < mdim; ++m) imprt[m] = tgini[m];
    }
   
    /*  PROXIMITY DATA ++++++++++++++++++++++++++++++++*/
    if (iprox) {
        for (n = 0; n < near; ++n) {
            for (k = n + 1; k < near; ++k) {
                prox[near*k + n] /= oobprox ?
                    (oobpair[near*k + n] > 0 ? oobpair[near*k + n] : 1) :
                        Ntree;
                prox[near*n + k] = prox[near*k + n];
            }
            prox[near*n + n] = 1.0;
        }
        if (*testdat) {
            /* proxts is ntest x (ntest + nsample0), matching the zeroing
             * above; iterating to ntest + nsample would overrun it when
             * addClass doubles nsample. */
            for (n = 0; n < ntest; ++n)
                for (k = 0; k < ntest + nsample0; ++k)
                    proxts[ntest*k + n] /= Ntree;
        }
    }
    if (trace <= Ntree){
        printf("\nmdim %d, nclass %d, nrnodes %d, nsample %d, ntest %d\n", mdim, nclass, *nrnodes, nsample, ntest);
        printf("\noobprox %d, mdim %d, nsample0 %d, Ntree %d, mtry %d, mimp %d", oobprox, mdim, nsample0, Ntree, mtry, mimp);
        printf("\nstratify %d, replace %d",stratify,replace);
        printf("\n");
    }
    
    //frees up the memory
    free(tgini);free(wl);free(wr);free(classpop);free(tclasscat);
    free(tclasspop);free(tx);free(win);free(tp);free(out);
    free(bestsplitnext);free(bestsplit);free(nodepop);free(nodestart);free(jin);
    free(nodex);free(nodexts);free(ta);free(ncase);free(jerr);
    free(varUsed);free(jtr);free(jvr);free(classFreq);free(jts);
    free(idmove);free(at);free(a);free(b);free(mind);
    free(nright);free(nrightimp);free(nout);
    
    if (oobprox) {
        free(oobpair);
    }
    
    if (stratify) {
        free(strata_size);
        for (n = 0; n < nstrata; ++n) {
            free(strata_idx[n]);
        }
        free(strata_idx);        
    } else {
        /* nind is only allocated when sampling WITHOUT replacement; the old
         * "if (replace)" check freed NULL and leaked the real buffer. */
        if (!replace)
            free(nind);
    }
    //printf("labelts %d\n",labelts);fflush(stdout);
    
    if (labelts) {
        free(nclts);        
    }
    //printf("stratify %d",stratify);fflush(stdout);
}
/*
 * slatkin_mc -- Monte Carlo test of an observed allele-count configuration
 * against random configurations produced by generate().
 *
 * Parameters:
 *   maxreps -- number of random configurations to generate. Must be > 0;
 *              with 0 the returned probability is 0.0 / 0.0.
 *   r_obs   -- observed configuration. This function reads it from index 1
 *              upward and stops at the first zero entry, so r_obs[1..k] are
 *              the counts and r_obs[k+1] must be 0. The whole array is also
 *              passed to F() and ewens_stat().
 *
 * Returns a slatkin_result with:
 *   probability    -- fraction of replicates whose ewens_stat() value is
 *                     <= the observed value.
 *   theta_estimate -- theta_est(k, n) for the observed k and n.
 *
 * The Mersenne Twister is seeded from the wall clock only on the first call.
 * Reseeding with time(NULL) on every call would make calls issued within the
 * same second restart the generator at the same state, so they would all
 * draw the identical random stream.
 */
slatkin_result slatkin_mc(int maxreps, int r_obs[]) {
	/* Becomes 1 after the first call; later calls continue the existing stream. */
	static int rng_seeded = 0;

	slatkin_result results;
	double theta_estimate;
	int i, j, k, n, repno, Ecount, Fcount;
	int *r_random, *r_random_to_free;
	double E_obs, F_obs;
	double *ranvec;
	double **b;

	if (!rng_seeded) {
		seedMT(time(NULL));
		rng_seeded = 1;
	}

	/*
	 * k = number of nonzero counts in r_obs (starting at index 1),
	 * n = their sum (presumably the number of allele classes and the
	 * sample size, respectively).
	 */
	k = 0;
	n = 0;
	while (r_obs[k+1]) {
		k++;
		n += r_obs[k];
	}

	/*
	 * Memory notes. The comments this replaces state that ivector() and
	 * vector() return an offset pointer rather than the allocation head.
	 * The frees at the end rely on that (assumes NR_END == 1 --
	 * NOTE(review): confirm against ivector()/vector()):
	 *   r_random -- ivector(0, ...) returns head + 1, so free r_random - 1
	 *   ranvec   -- vector(1, ...) returns the head itself, so free it directly
	 *   b        -- create2DDoubleArray(k+1, n+1); freed row by row, then the
	 *               row-pointer array (assumes each row is a separate
	 *               allocation -- NOTE(review): confirm)
	 */
	r_random = ivector(0, k+1);
	r_random_to_free = r_random - 1;
	r_random[0] = r_random[k+1] = 0;	/* zero sentinels around r_random[1..k] */

	ranvec = vector(1, k-1);	/* allocated once instead of in each replicate */

	/*
	 * Fill the table b used by generate(): row 1 is 1/j, and each row i >= 2
	 * is built from row i-1 by the recurrence below. Rows 1..k and
	 * columns 1..n are used.
	 */
	b = create2DDoubleArray(k+1, n+1);
	for (j = 1; j <= n; j++)
		b[1][j] = 1.0 / j;
	for (i = 2; i <= k; i++) {
		b[i][i] = 1.0;
		for (j = i; j < n; j++)
			b[i][j+1] = (i * b[i-1][j] + j * b[i][j]) / (j + 1.0);
	}

	F_obs = F(k, n, r_obs);
	E_obs = ewens_stat(r_obs);

	Ecount = 0;
	Fcount = 0;
	for (repno = 1; repno <= maxreps; repno++) {
		generate(k, n, r_random, ranvec, b);
		if (ewens_stat(r_random) <= E_obs)
			Ecount++;
		/* NOTE(review): Fcount is tallied but not reported in slatkin_result. */
		if (F(k, n, r_random) <= F_obs)
			Fcount++;
	}

	theta_estimate = theta_est(k, n);

	results.probability = (double) Ecount / (double) maxreps;
	results.theta_estimate = theta_estimate;

	/* Release every buffer allocated above (see the memory notes). */
	free(ranvec);
	free(r_random_to_free);
	for (i = 0; i < k+1; i++)
		free(b[i]);
	free(b);

	return results;
}