Пример #1
0
 /**
  * Rescale the table so that the probabilities sum to one, i.e.
  * sum over all (asg1, asg2) of exp(logP(asg1, asg2)) == 1.
  *
  * The work is done in log space: the largest entry is subtracted first
  * so that the exponentials cannot overflow, then the log of the
  * resulting partition sum is subtracted from every entry.
  * Requires both arities to be positive and the table to contain at
  * least one finite entry (checked with assert).
  */
 inline void normalize() {
   assert(arity1() > 0);
   assert(arity2() > 0);

   // Pass 1: locate the largest log-value in the table.
   double peak = logP(0,0);
   for(uint16_t row = 0; row < arity1(); ++row) {
     for(uint16_t col = 0; col < arity2(); ++col) {
       const double entry = logP(row, col);
       if(peak < entry) peak = entry;
     }
   }
   assert( !std::isinf(peak) );
   assert( !std::isnan(peak) );

   // Pass 2: shift every entry by the peak (in place) and accumulate
   // the partition sum of the shifted table.
   double partition = 0.0;
   for(uint16_t row = 0; row < arity1(); ++row) {
     for(uint16_t col = 0; col < arity2(); ++col) {
       logP(row, col) -= peak;
       partition += std::exp(logP(row, col));
     }
   }
   assert( !std::isinf(partition) );
   assert( !std::isnan(partition) );
   assert( partition > 0.0);

   // Pass 3: subtract the log partition sum so the table is normalized.
   const double log_partition = std::log(partition);
   for(uint16_t row = 0; row < arity1(); ++row) {
     for(uint16_t col = 0; col < arity2(); ++col) {
       logP(row, col) -= log_partition;
     }
   }
 } // End of normalize
Пример #2
0
 /**
  * Fill the table with an agreement potential: entries on the diagonal
  * (both variables take the same value) get log-value 0, every other
  * entry gets log-value -lambda.
  */
 void set_as_agreement(double lambda) {
   for(uint16_t row = 0; row < arity1(); ++row) {
     for(uint16_t col = 0; col < arity2(); ++col) {
       logP(row, col) = (row == col) ? 0.0 : -lambda;
     }
   }
 } // end of set_as_agreement
Пример #3
0
 /**
  * Fill the table with a Laplace-style potential: the log-value of
  * entry (i, j) is -|i - j| * lambda.
  */
 void set_as_laplace(double lambda) {
   for(uint16_t row = 0; row < arity1(); ++row) {
     for(uint16_t col = 0; col < arity2(); ++col) {
       // Convert to double before subtracting so the unsigned indices
       // cannot wrap around.
       const double distance =
         std::abs(static_cast<double>(row) - static_cast<double>(col));
       logP(row, col) = -distance * lambda;
     }
   }
 } // end of set_as_laplace
Пример #4
0
 /**
  * Compute the Mooij-Kappen message derivative.
  *
  * Returns the maximum over all index pairs (a,b), (x,y) with a != x
  * and b != y of
  *   tanh( (logP(a,b) + logP(x,y) - logP(x,b) - logP(a,y)) / 4 ).
  * If no such pair exists (an arity below two) the initial value
  * -std::numeric_limits<double>::max() is returned unchanged.
  */
 double mk_derivative() const {
   double best = -std::numeric_limits<double>::max();
   for(uint16_t a = 0; a < arity1(); ++a) {
     for(uint16_t b = 0; b < arity2(); ++b) {
       for(uint16_t x = 0; x < arity1(); ++x) {
         if(x == a) continue;   // both coordinates must differ
         for(uint16_t y = 0; y < arity2(); ++y) {
           if(y == b) continue;
           const double numerator =
             logP(a,b) + logP(x,y) - (logP(x,b) + logP(a,y));
           const double candidate = std::tanh(numerator / 4.0);
           if(best < candidate) best = candidate;
         }
       }
     }
   }
   return best;
 }
Пример #5
0
 /**
  * Print the factor description followed by the probability table.
  *
  * The first line names both variables and their arities; each
  * following line holds one row of the table (fixed first assignment),
  * with the stored log-values exponentiated back to probabilities.
  *
  * @param out stream that receives the text
  */
 void printP(std::ostream& out) const {
   // The separator between the two variables used to be emitted twice
   // ("}, " followed by ", v_ "), which printed "}, , v_ <id>".
   out << "Binary Factor(v_" << var1() << " in {1..."
       << arity1() << "}, "
       << "v_" << var2() << " in {1..."
       << arity2() << "})" << std::endl;
   for(uint16_t i = 0; i < arity1(); ++i) {
     for(uint16_t j = 0; j < arity2(); ++j) {
       out << std::exp(logP(i,j)) << " ";
     }
     out << std::endl;
   }
 }
Пример #6
0
/*
 * Emulator entry point.
 *
 * Usage: emulator -p <port> -f <filename>
 *
 * Parses the command line, binds a UDP socket on the local host at the
 * given port, hands the file name and the bound address to initTable(),
 * and then (see the review note below) would enter a receive/forward loop
 * that periodically floods link-state packets and rebuilds routes.
 *
 * Exits with a non-zero status on bad arguments or socket setup failure.
 */
int main(int argc, char **argv) {
    // ------------------------------------------------------------------------
    // Handle commandline arguments
    if (argc < 5) {
        printf("usage: emulator -p <port> -f <filename>\n");
        exit(1);
    }

    char *portStr        = NULL;
    const char *filename = NULL;

    int cmd;
    while ((cmd = getopt(argc, argv, "p:f:")) != -1) {
        switch(cmd) {
            case 'p': portStr  = optarg; break;
            case 'f': filename = optarg; break;
            case '?':
                if (optopt == 'p' || optopt == 'f')
                    fprintf(stderr, "Option -%c requires an argument.\n", optopt);
                else if (isprint(optopt))
                    fprintf(stderr, "Unknown option -%c.\n", optopt);
                else
                    fprintf(stderr, "Unknown option character '\\x%x'.\n", optopt);
                exit(EXIT_FAILURE);
                break;
            default:
                printf("Unhandled argument: %d\n", cmd);
                exit(EXIT_FAILURE);
        }
    }

    // argc >= 5 does not guarantee that both options were supplied
    // (e.g. "-p 1 -p 2"), so reject a missing -p / -f before the strings
    // are handed to printf("%s") and atoi().
    if (portStr == NULL || filename == NULL) {
        fprintf(stderr, "usage: emulator -p <port> -f <filename>\n");
        exit(EXIT_FAILURE);
    }

    printf("Port        : %s\n", portStr);
    printf("File Name   : %s\n", filename);

    // Convert program args to values
    int emulPort = atoi(portStr);
    int maxTime  = 1500;   // upper bound of the refresh delay, in ms
    int minTime  = 500;    // lower bound of the refresh delay, in ms

    // Validate the argument values
    if (emulPort <= 1024 || emulPort >= 65536)
        ferrorExit("Invalid emul port");
    puts("");

    srand(time(NULL));
    initLog("log_file.txt");

    // ------------------------------------------------------------------------
    // Setup emul address info
    struct addrinfo ehints;
    bzero(&ehints, sizeof(struct addrinfo));
    ehints.ai_family   = AF_INET;
    ehints.ai_socktype = SOCK_DGRAM;
    ehints.ai_flags    = AI_PASSIVE;

    char localhost[80];
    if (gethostname(localhost, sizeof(localhost)) != 0)
        perrorExit("gethostname");
    // gethostname() may leave the buffer unterminated when the name is truncated
    localhost[sizeof(localhost) - 1] = '\0';

    // Get the emul's address info
    struct addrinfo *emulinfo;
    int errcode = getaddrinfo(localhost, portStr, &ehints, &emulinfo);
    if (errcode != 0) {
        fprintf(stderr, "emul getaddrinfo: %s\n", gai_strerror(errcode));
        exit(EXIT_FAILURE);
    }

    // Loop through all the results of getaddrinfo and try to create a socket for emul
    int sockfd;
    struct addrinfo *sp;
    for (sp = emulinfo; sp != NULL; sp = sp->ai_next) {
        // Try to create a new socket
        sockfd = socket(sp->ai_family, sp->ai_socktype, sp->ai_protocol);
        if (sockfd == -1) {
            perror("Socket error");
            continue;
        }

        // Try to bind the socket
        if (bind(sockfd, sp->ai_addr, sp->ai_addrlen) == -1) {
            perror("Bind error");
            close(sockfd);
            continue;
        }

        break;
    }
    if (sp == NULL) perrorExit("Send socket creation failed");
    else            printf("emul socket created.\n");

    // tmp points into the getaddrinfo() result list, which therefore must
    // stay allocated for as long as tmp is in use.
    struct sockaddr_in *tmp = (struct sockaddr_in *)sp->ai_addr;
    unsigned long eIpAddr = tmp->sin_addr.s_addr;
    initTable(filename, tmp);

    // NOTE(review): this unconditional exit makes everything below
    // unreachable. It looks like a debugging leftover, but it is kept so the
    // program's observable behaviour is unchanged - confirm and remove it to
    // enable the forwarding loop.
    exit(0);
    // ------------------------------------------------------------------------
    // The Big Loop of DOOM

    int shouldForward;
    fd_set fds;

    // Timeout handed to pselect(); zero on the first pass so that the
    // forwarding table is refreshed immediately.
    struct timespec tv;
    tv.tv_sec  = 0;
    tv.tv_nsec = 0;

    int retval = 0;
    int numRecv = 0;      // never incremented; see the timeout update below
    int routesMade = 0;
    unsigned long long start;
    struct packet *dpkt;
    struct ip_packet *pkt = malloc(sizeof(struct ip_packet));
    void *msg = malloc(sizeof(struct ip_packet));
    if (pkt == NULL || msg == NULL)
        perrorExit("malloc");

    for (;;) {
        FD_ZERO(&fds);
        FD_SET(sockfd, &fds);
        start = getTimeMS();
        retval = pselect(sockfd + 1, &fds, NULL, NULL, &tv, NULL);

        // Report a failed pselect() right away. Previously the error was only
        // reached through the trailing else, which also fired after every
        // successful receive once routes existed and was skipped entirely
        // while routesMade was still 0.
        if (retval < 0)
            perrorExit("Select()");

        // ------------------------------------------------------------------------
        // receiving half

        if (retval > 0 /*&& routesMade == 1*/) {
            // Receive and forward packet
            printf("retval > 0\n");
            bzero(msg, sizeof(struct ip_packet));
            // recvfrom() returns ssize_t; an unsigned type hides the -1 error value
            ssize_t bytesRecvd = recvfrom(sockfd, msg, sizeof(struct ip_packet), 0, NULL, NULL);
            if (bytesRecvd == -1) {
                perror("Recvfrom error");
                fprintf(stderr, "Failed/incomplete receive: ignoring\n");
                continue;
            }

            // Deserialize the message into a packet
            bzero(pkt, sizeof(struct ip_packet));
            deserializeIpPacket(msg, pkt);
            dpkt = (struct packet *)pkt->payload;
            printIpPacketInfo(pkt, NULL);

            // Destination of the forwarded packet. Kept on the stack: the
            // heap allocation used before leaked whenever nothing was sent.
            struct sockaddr_in nextSock;
            bzero(&nextSock, sizeof(struct sockaddr_in));

            // Check packet type to see if any action needs to be taken
            if (dpkt->type == 'T') {
                if (dpkt->len == 0) {
                    // Counter exhausted: send the packet back to its source,
                    // stamped with this emulator's address and port.
                    shouldForward = 1;
                    nextSock.sin_family = AF_INET;
                    nextSock.sin_addr.s_addr = htonl(pkt->src);
                    nextSock.sin_port = htons(pkt->srcPort);
                    pkt->src = eIpAddr;
                    pkt->srcPort = emulPort;
                }
                else {
                    dpkt->len--;
                    shouldForward = nextHop(pkt, &nextSock);
                }
            }
            else if (dpkt->type == 'S') {
                // Link-state packet: re-flood it when updateLSP() reports a
                // change, and force the routes to be rebuilt below.
                if (updateLSP(pkt)) {
                    floodLSP(sockfd, tmp, pkt);
                }
                shouldForward = 0;
                routesMade = 0;
            }
            else {
                shouldForward = nextHop(pkt, &nextSock);
            }

            // Forward the packet if there is an entry for it
            if (shouldForward) {
                printf("send packet\n");
                sendIpPacketTo(sockfd, pkt, (struct sockaddr*)&nextSock);
            }
            else {
                logP(pkt, "No forwarding entry found");
            }

            // update timespec: subtract the time spent handling this packet
            long sec = tv.tv_sec - (long)((getTimeMS() - start) / 1000);
            long nsec = tv.tv_nsec - (long)(1000000 * ((getTimeMS() - start) % 1000));
            if (nsec < 0) {
                nsec = 1000000 * 1000 + nsec;
                sec--;
            }
            // numRecv is always 0 here, so the timeout is reset to zero after
            // every received packet (behaviour kept from the original).
            if (sec < 0 || !numRecv) {
                sec = 0;
                nsec = 0;
            }
            tv.tv_sec = sec;
            tv.tv_nsec = nsec;
        }

        if (retval == 0 || routesMade == 0) {
            // ------------------------------------------------------------------------
            // refresh forward table
            printf("retval == 0\n");
            if (retval == 0) {
                floodLSP(sockfd, tmp, NULL);
            }
            routesMade = createRoutes(tmp);

            // Wait a random delay in [minTime, maxTime) ms before the next refresh
            int delay = minTime + (rand() % (maxTime - minTime));
            tv.tv_sec = (long)delay / 1000;
            tv.tv_nsec = (long)(delay % 1000) * 1000000;
        }
    }
    // Cleanup packets (not reached: the loop above never terminates)
    free(pkt);
    free(msg);
}
Пример #7
0
double TrinomialRand::P(const IntPair &point) const
{
    int x = point.first, y = point.second;
    return (x < 0 || x > n || y < 0 || y > n) ? 0.0 : std::exp(logP(point));
}
Пример #8
0
int main()
{
  // LOAD DATA
  arma::mat X;

  // A) Toy Data
  //  char inputFile[] = "../data_files/toyclusters/toyclusters.dat";
  
  // B) X4.dat
  //char inputFile[] = "./X4.dat";
  
  // C) fisher data
  //char inputFile[] = "./fisher.dat";

  // D) MNIST data
  //char inputFile[] = "../data_files/MNIST/MNIST.dat";
  
  // E) Reduced MNIST (5000x400)
  //char inputFile[] = "../data_files/MNIST/MNISTreduced.dat";
  
  // F) Reduced MNIST (0 and 1) (1000x400)
  //char inputFile[] = "../data_files/MNIST/MNISTlittle.dat";

  // G) Girl.png (512x768, RGB, already unrolled)
  //char inputFile[] = "girl.dat";

  // H) Pool.png (383x512, RGB, already unrolled)
  //char inputFile[] = "pool.dat";

  // I) Cat.png (733x490, RGB, unrolled)
  //char inputFile[] = "cat.dat";
  
  // J) Airplane.png (512x512, RGB, unrolled)
  //char inputFile[] = "airplane.dat";

  // K) Monarch.png (512x768, RGB, unrolled)
  //char inputFile[] = "monarch.dat";

  // L) tulips.png (512x768 ,RGB, unrolled)
  //char inputFile[] = "tulips.dat";
  
  // M) demo.dat (2d data)
  char inputFile[] = "demo.dat";

  // INITIALIZE PARAMETERS
  X.load(inputFile);
  const arma::uword N = X.n_rows;
  const arma::uword D = X.n_cols;

  arma::umat ids(N,1);	// needed to shuffle indices later
  arma::umat shuffled_ids(N,1);
  for (arma::uword i = 0; i < N; ++i) {
    ids(i,0) = i;
  }
  
  arma::arma_rng::set_seed_random(); // set arma rng
  // int seed = time(NULL);	// set RNG seed to current time
  // srand(seed);

  arma::uword initial_K = 32;   // initial number of clusters
  arma::uword K = initial_K;
  
  arma::umat clusters(N,1); // contains cluster assignments for each data point
  for (arma::uword i = 0; i < N; ++i) {
    clusters(i, 0) = i%K;		// initialize as [0,1,...,K-1,0,1,...,K-1,0,...]
  }

  arma::umat cluster_sizes(N,1,arma::fill::zeros); // contains num data points in cluster k
  for (arma::uword i = 0; i < N; ++i) {
    cluster_sizes(clusters(i,0), 0) += 1;
  }

  arma::mat mu(N, D, arma::fill::zeros); // contains cluster mean parameters
  arma::mat filler(D,D,arma::fill::eye);
  std::vector<arma::mat> sigma(N,filler); // contains cluster covariance parameters

  if (K < N) {			// set parameters not belonging to any cluster to -999
    mu.rows(K,N-1).fill(-999);
    for (arma::uword k = K; k < N; ++k) {
      sigma[k].fill(-999);
    }
  }

  arma::umat uword_dummy(1,1);	// dummy 1x1 matrix;
  // for (arma::uword i = 0; i <N; ++i) {
  //   std::cout << sigma[i] << std::endl;
  // }
  // std::cout << X << std::endl
  // 	    << N << std::endl
  // 	    << D << std::endl
  // 	    << K << std::endl
  // 	    << clusters << std::endl
  // 	    << cluster_sizes << std::endl
  //	    << ids << std::endl;

  // INITIALIZE HYPER PARAMETERS
  // Dirichlet Process concentration parameter is alpha:
  double alpha = 1;
  // Dirichlet Process base distribution (i.e. prior) is 
  // H(mu,sigma) = NIW(mu,Sigma|m_0,k_0,S_0,nu_0) = N(mu|m_0,Sigma/k_0)IW(Sigma|S_0,nu_0)
  arma::mat perturbation(D,D,arma::fill::eye);
  perturbation *= 0.000001;
  //const arma::mat S_0 = arma::cov(X,X,1) + perturbation; // S_xbar / N
  const arma::mat S_0(D,D,arma::fill::eye);
  const double nu_0 = D + 2;
  const arma::mat m_0 = mean(X).t();
  const double k_0 = 0.01;
  // std::cout << "S_0" << S_0 << std::endl;
  // std::cout << S_0 << std::endl
  // 	    << nu_0 << std::endl
  // 	    << m_0 << std::endl
  // 	    << k_0 << std::endl;


  // INITIALIZE SAMPLING PARAMETERS
  arma::uword NUM_SWEEPS = 250;	// number of Gibbs iterations
  bool SAVE_CHAIN = false;	// save output of each Gibbs iteration?
  // arma::uword BURN_IN = NUM_SWEEPS - 10;
  // arma::uword CHAINSIZE = NUM_SWEEPS - BURN_IN;
  // std::vector<arma::uword> chain_K(CHAINSIZE, K); // Initialize chain variable to initial parameters for convinience
  // std::vector<arma::umat> chain_clusters(CHAINSIZE, clusters);
  // std::vector<arma::umat> chain_clusterSizes(CHAINSIZE, cluster_sizes);
  // std::vector<arma::mat> chain_mu(CHAINSIZE, mu);
  // std::vector<std::vector<arma::mat> > chain_sigma(CHAINSIZE, sigma);
  
  // for (arma::uword sweep = 0; sweep < CHAINSIZE; ++sweep) {
  //   std::cout << sweep << " K\n" << chain_K[sweep] << std::endl
  // 		<< sweep << " clusters\n" << chain_clusters[sweep] << std::endl
  // 		<< sweep << " sizes\n" << chain_clusterSizes[sweep] << std::endl
  // 		<< sweep << " mu\n" << chain_mu[sweep] << std::endl;
  //   for (arma::uword i = 0; i < N; ++i) { 
  // 	std::cout << sweep << " " << i << " sigma\n" << chain_sigma[sweep][i] << std::endl;
  //   }
  // }
  

  // START CHAIN
  std::cout << "Starting Algorithm with K = " << K << std::endl;
  for (arma::uword sweep = 0; sweep < NUM_SWEEPS; ++sweep) { 
    // shuffle indices
    shuffled_ids = shuffle(ids);
    // std::cout << shuffled_ids << std::endl;

    // SAMPLE CLUSTERS
    for (arma::uword j = 0; j < N; ++j){
      //      std::cout << "j = " << j << std::endl;
      arma::uword i = shuffled_ids(j);
      arma::mat x = X.row(i).t(); // current data point
      
      // Remove i's statistics and any empty clusters
      arma::uword c = clusters(i,0); // current cluster
      cluster_sizes(c,0) -= 1;
      //std::cout << "old c = " << c << std::endl;

      if (cluster_sizes(c,0) == 0) { // remove empty cluster
      	cluster_sizes(c,0) = cluster_sizes(K-1,0); // move entries for K onto position c
	mu.row(c) = mu.row(K-1);
      	sigma[c] = sigma[K-1];
      	uword_dummy(0,0) = c;
	arma::uvec idx = find(clusters == K - 1);
      	clusters.each_row(idx) = uword_dummy;
	cluster_sizes(K-1,0) = 0;
	mu.row(K-1).fill(-999);
	sigma[K-1].fill(-999);
	--K;
      }

      // quick test of logMvnPdf:
      // arma::mat m_(2,1);
      // arma::mat s_(2,2);
      // arma::mat t_(2,1);
      // m_ << 1 << arma::endr << 2;
      // s_ << 3 << -0.2 << arma::endr << -0.2 << 1;
      // t_ << -3 << arma::endr << -3;
      // double lpdf = logMvnPdf(t_, m_, s_);
      // std::cout << lpdf << std::endl; // śhould be -19.1034 (works)


      
      // Find categorical distribution over clusters (tested)
      arma::mat logP(K+1, 1, arma::fill::zeros);
      
      // quick test of logInvWishPdf
      // arma::mat si_(2,2), s_(2,2);
      // double nu_ = 4;
      // si_ << 1 << 0.5 << arma::endr << 0.5 << 4;
      // s_ << 3 << -0.2 << arma::endr << -0.2 << 1;
      // double lpdf = logInvWishPdf(si_, s_, nu_);
      // std::cout << lpdf << std::endl; // should be -7.4399 (it is)

      // quick test for logNormInvWishPdf
      // arma::mat si_(2,2), s_(2,2);
      // arma :: mat mu_(2,1), m_(2,1);
      // double k_ = 0.5, nu_ = 4;
      // si_ << 1 << 0.5 << arma::endr << 0.5 << 4;
      // s_ << 3 << -0.2 << arma::endr << -0.2 << 1;
      // mu_ << -3 << arma::endr << -3;
      // m_ << 1 << arma::endr << 2;
      // double lp = logNormInvWishPdf(mu_,si_,m_,k_,s_,nu_);
      // std::cout << lp << std::endl; // should equal -15.2318 (it is)
      
      // p(existing clusters) (tested)
      for (arma::uword k = 0; k < K; ++k) {
	arma::mat m_ = mu.row(k).t();
	arma::mat s_ = sigma[k];
	logP(k,0) = log(cluster_sizes(k,0)) - log(N-1+alpha) + logMvnPdf(x,m_,s_);
      }

      // p(new cluster): find partition function (tested)
      // arma::mat dummy_mu(D, 1, arma::fill::zeros);
      // arma::mat dummy_sigma(D, D, arma::fill::eye);
      // double logPrior, logLklihd, logPstr, logPartition; 
      // posterior hyperparameters (tested)
      arma::mat m_1(D,1), S_1(D,D);
      double k_1, nu_1;
      k_1 = k_0 + 1;
      nu_1 = nu_0 + 1;
      m_1 = (k_0*m_0 + x) / k_1;
      S_1 = S_0 + x * x.t() + k_0 * (m_0 * m_0.t()) - k_1 * (m_1 * m_1.t());
      // std::cout << k_1 << std::endl
      // 		<< nu_1 << std::endl
      // 		<< m_1 << std::endl
      // 		<< S_1 << std::endl;
      
      // // partition = likelihood*prior/posterior (tested)
      // // (perhaps a direct computation of the partition function would be better)
      // logPrior = logNormInvWishPdf(dummy_mu, dummy_sigma, m_0, k_0, S_0, nu_0);
      // logLklihd = logMvnPdf(x, dummy_mu, dummy_sigma);
      // logPstr = logNormInvWishPdf(dummy_mu, dummy_sigma, m_1, k_1, S_1, nu_1);
      // logPartition = logPrior + logLklihd - logPstr;      
      // std::cout << "log  Prior = " << logPrior << std::endl
      // 		<< "log Likelihood = " << logLklihd << std::endl
      // 		<< "log Posterior = " << logPstr << std::endl
      // 		<< "log Partition = " << logPartition << std::endl;
      
      // Computing partition directly
      double logS0,signS0,logS1,signS1;
      arma::log_det(logS0,signS0,S_0);
      arma::log_det(logS1,signS1,S_1);
      /*std::cout << "log(det(S_0)) = " << logS0 << std::endl
	<< "log(det(S_1)) = " << logS1 << std::endl;*/
      double term1 = 0.5*D*(log(k_0)-log(k_1));
      double term2 = -0.5*D*log(arma::datum::pi);
      double term3 = 0.5*(nu_0*logS0 - nu_1*logS1);
      double term4 = lgamma(0.5*nu_1);
      double term5 = -lgamma(0.5*(nu_1-D));
      double logPartition = term1+term2+term3+term4+term5;
      /*double logPartition = 0.5*D*(log(k_0)-log(k_1)-log(arma::datum::pi)) \
	/+0.5*(nu_0*logS0 - nu_1*logS1) + lgamma(0.5*nu_1) - lgamma(0.5*(nu_1-D));*/
      
      /*      std::cout << "term1 = " << term1 << std::endl
		<< "term2 = " << term2 << std::endl
		<< "term3 = " << term3 << std::endl
		<< "term4 = " << term4 << std::endl
		<< "term5 = " << term5 << std::endl;*/
      
      //std::cout << "logP = " << logPartition << std::endl;

      // p(new cluster): (tested)
      logP(K,0) = log(alpha) - log(N - 1 + alpha) + logPartition;
      //      std::cout << "logP(new cluster) = " << logP(K,0) << std::endl;
      //if(i == 49)
      //assert(false);
      // sample cluster for i
      
      

      arma::uword c_ = logCatRnd(logP);
      clusters(i,0) = c_;
      
      //if (j % 10 == 0){
      //std::cout << "New c = " << c_ << std::endl;
      //std::cout << "logP = \n" << logP << std::endl;
      //}
      // quick test for mvnRnd
      // arma::mat mu, si;
      // mu << 1 << arma::endr << 2;
      // si << 1 << 0.9 << arma::endr << 0.9 << 1;
      // arma::mat m = mvnRnd(mu, si);
      // std::cout << m << std::endl;

      // quick test for invWishRnd
      // double df = 4;
      // arma::mat si(2,2);
      // si << 1 << 1 << arma::endr << 1 << 1;
      // arma::mat S = invWishRnd(si,df);
      // std::cout << S << std::endl;

      if (c_ == K) {	// Sample parameters for any new-born clusters from posterior
	cluster_sizes(K, 0) = 1;
	arma::mat si_ = invWishRnd(S_1, nu_1);
	//arma::mat si_ = S_1;
	arma::mat mu_ = mvnRnd(m_1, si_/k_1);
	//arma::mat mu_ = m_1;
	mu.row(K) = mu_.t();
	sigma[K] = si_;
	K += 1;
      } else {
	cluster_sizes(c_,0) += 1;
      }
      // if (sweep == 0)
      // 	std::cout << " K = " << K << std::endl;
      // // if (j == N-1) {
      // // 	std::cout << logP << std::endl;
      // // 	std::cout << K << std::endl;
      // // 	assert(false);
      // // }
      // std::cout << "K = " << K << "\n" << std::endl;
    }
    
    // sample CLUSTER PARAMETERS FROM POSTERIOR
    for (arma::uword k = 0; k < K; ++k) {
      // std::cout << "k = " << k << std::endl;
      // cluster data
      arma::mat Xk = X.rows(find(clusters == k));
      arma::uword Nk = cluster_sizes(k,0);

      // posterior hyperparameters
      arma::mat m_Nk(D,1), S_Nk(D,D);
      double k_Nk, nu_Nk;
      
      arma::mat sum_k = sum(Xk,0).t();
      arma::mat cov_k(D, D, arma::fill::zeros);
      for (arma::uword l = 0; l < Nk; ++l) { 
	cov_k += Xk.row(l).t() * Xk.row(l);
      }
      
      k_Nk = k_0 + Nk;
      nu_Nk = nu_0 + Nk;
      m_Nk = (k_0 * m_0 + sum_k) / k_Nk;
      S_Nk = S_0 + cov_k + k_0 * (m_0 * m_0.t()) - k_Nk * (m_Nk * m_Nk.t());
      
      // sample fresh parameters
      arma::mat si_ = invWishRnd(S_Nk, nu_Nk);
      //arma::mat si_ = S_Nk;
      arma::mat mu_ = mvnRnd(m_Nk, si_/k_Nk);
      //arma::mat mu_ = m_Nk;
      mu.row(k) = mu_.t();
      sigma[k] = si_;
    }
    std::cout << "Iteration " << sweep + 1 << "/" << NUM_SWEEPS<< " done. K = " << K << std::endl;
        
    // // STORE CHAIN
    // if (SAVE_CHAIN) {
    //   if (sweep >= BURN_IN) { 
    // 	chain_K[sweep - BURN_IN] = K;
    // 	chain_clusters[sweep - BURN_IN] = clusters;
    // 	chain_clusterSizes[sweep - BURN_IN] = cluster_sizes;
    // 	chain_mu[sweep - BURN_IN] = mu;
    // 	chain_sigma[sweep - BURN_IN] = sigma;
    //   }
    // }
	
  }
  std::cout << "Final cluster sizes: " << std::endl
	    << cluster_sizes.rows(0, K-1) << std::endl;






  
  // WRITE OUPUT DATA TO FILE
  arma::mat MU = mu.rows(0,K-1);
  arma::mat SIGMA(D*K,D);
  for (arma::uword k = 0; k < K; ++k) { 
    SIGMA.rows(k*D,(k+1)*D-1) = sigma[k];
  }
  arma::umat IDX = clusters;

  // A) toycluster data
  // char MuFile[] = "../data_files/toyclusters/dpmMU.out";
  // char SigmaFile[] = "../data_files/toyclusters/dpmSIGMA.out";
  // char IdxFile[] = "../data_files/toyclusters/dpmIDX.out";

  // B) X4.dat
  char MuFile[] = "dpmMU.out";
  char SigmaFile[] = "dpmSIGMA.out";
  char IdxFile[] = "dpmIDX.out";
  std::ofstream KFile("K.out");
  
  MU.save(MuFile, arma::raw_ascii);
  SIGMA.save(SigmaFile, arma::raw_ascii);
  IDX.save(IdxFile, arma::raw_ascii);
  KFile << "K = " << K << std::endl;
  
  if (SAVE_CHAIN) {}
  //   std::ofstream chainKFile("chainK.out");
  //   std::ofstream chainClustersFile("chainClusters.out");
  //   std::ofstream chainClusterSizesFile("chainClusterSizes.out");
  //   std::ofstream chainMuFile("chainMu.out");
  //   std::ofstream chainSigmaFile("chainSigma.out");
    
  //   chainKFile << "Dirichlet Process Mixture Model.\nInput: " << inputFile << std::endl
  // 	       << "Number of iterations of Gibbs Sampler: " << NUM_SWEEPS << std::endl
  // 	       << "Burn-In: " << BURN_IN << std::endl
  // 	       << "Initial number of clusters: " << initial_K << std::endl
  // 	       << "Output: Number of cluster (K)\n" << std::endl; 
  //   chainClustersFile << "Dirichlet Process Mixture Model.\nInput: " << inputFile << std::endl
  // 		      << "Number of iterations of Gibbs Sampler: " << NUM_SWEEPS << std::endl
  // 		      << "Burn-In: " << BURN_IN << std::endl
  // 		      << "Initial number of clusters: " << initial_K << std::endl
  // 		      << "Output: Cluster identities (clusters)\n" << std::endl; 
  //   chainClusterSizesFile << "Dirichlet Process Mixture Model.\nInput: " << inputFile << std::endl
  // 			  << "Number of iterations of Gibbs Sampler: " << NUM_SWEEPS << std::endl
  // 			  << "Burn-In: " << BURN_IN << std::endl
  // 			  << "Initial number of clusters: " << initial_K << std::endl
  // 			  << "Output: Size of clusters (cluster_sizes)\n" << std::endl; 
  //   chainMuFile << "Dirichlet Process Mixture Model.\nInput: " << inputFile << std::endl
  // 		<< "Number of iterations of Gibbs Sampler: " << NUM_SWEEPS << std::endl
  // 		<< "Burn-In: " << BURN_IN << std::endl
  // 		<< "Initial number of clusters: " << initial_K << std::endl
  // 		<< "Output: Samples for cluster mean parameters (mu. Note: means stored in rows)\n" << std::endl; 
  //   chainSigmaFile << "Dirichlet Process Mixture Model.\nInput: " << inputFile << std::endl
  // 		   << "Number of iterations of Gibbs Sampler: " << NUM_SWEEPS << std::endl
  // 		   << "Burn-In " << BURN_IN << std::endl
  // 		   << "Initial number of clusters: " << initial_K << std::endl
  // 		   << "Output: Samples for cluster covariances (sigma)\n" << std::endl; 


  //   for (arma::uword sweep = 0; sweep < CHAINSIZE; ++sweep) {
  //     arma::uword K = chain_K[sweep];
  //     chainKFile << "Sweep #" << BURN_IN + sweep + 1 << "\n" << chain_K[sweep] << std::endl;
  //     chainClustersFile << "Sweep #" << BURN_IN + sweep + 1 << "\n" << chain_clusters[sweep]  << std::endl;
  //     chainClusterSizesFile << "Sweep #" << BURN_IN + sweep + 1 << "\n" << chain_clusterSizes[sweep].rows(0, K - 1) << std::endl;
  //     chainMuFile << "Sweep #" << BURN_IN + sweep + 1 << "\n" << chain_mu[sweep].rows(0, K - 1) << std::endl;
  //     chainSigmaFile << "Sweep #" << BURN_IN + sweep + 1<< "\n";
  //     for (arma::uword i = 0; i < K; ++i) { 
  // 	chainSigmaFile << chain_sigma[sweep][i] << std::endl;
  //     }
  //     chainSigmaFile << std::endl;
  //   }
  // }

  return 0;
}
double NegativeBinomialDistribution<T>::P(const int & k) const
{
    // Negative counts carry no probability mass.
    if (k < 0) {
        return 0.0;
    }
    // Otherwise exponentiate the log-probability.
    return std::exp(logP(k));
}