Пример #1
0
/*
 * inferece of HMM parameters based on variational Bayes method
 */
int TrainVBHMM::train(const Array2D<Real>& music_spec_td, const HyperparameterVBHMM& hyper_param){
	const int T = music_spec_td.rows(), D = music_spec_td.cols();
	
	Array2D<Array2D<PN> > prev_on_s_td_kj; // 4D-array for transition probability
	// prev_on_s_td_kj(k,j)(t,d): probability of s_{t,d} being j where s_{t-1,d} = k and k = {0,1}
	
	// p_on_s_td_(t,d): probability of s_{t,d} being 1. 1-p_on_s_td_(t,d) is the probability of s_{t,d}=0.
	
	// 1. initialize
	Real active_proportion;
	if(is_silent_)
		active_proportion = init_1st_time(music_spec_td, hyper_param.m0_, hyper_param.init_active_probability_, p_on_s_td_, prev_on_s_td_kj);
	else
		active_proportion = init(music_spec_td, p_on_s_td_, prev_on_s_td_kj);
	
	if(active_proportion < hyper_param.active_proportion_){ // no inference for now
		// the observation is too biased to silence
		return 1;
	}
	is_silent_ = false;
	
	// initialize the suffcient statistics
	std::vector<Real> suff_1_j(2,0), suff_x_j(2,0), suff_xx_j(2,0);
	update_suff_stat(music_spec_td, p_on_s_td_, suff_1_j, suff_x_j, suff_xx_j);
	update_normal_gamma(music_spec_td, hyper_param, suff_1_j, suff_x_j, suff_xx_j); // hat_*** variables are updated
	
	// 2. iteration
	const Real convergence_threshold = 0.001;
	Array2D<PN> before_p_on_s_td = p_on_s_td_;
	for(int iter = 0; iter < hyper_param.max_iteration_; ++iter){
		forward_backward(music_spec_td, p_on_s_td_, prev_on_s_td_kj);
		update_alpha(hyper_param, prev_on_s_td_kj);
		update_suff_stat(music_spec_td, p_on_s_td_, suff_1_j, suff_x_j, suff_xx_j);
		update_normal_gamma(music_spec_td, hyper_param, suff_1_j, suff_x_j, suff_xx_j);
		
		// check convergence
		if(iter == 0)
			continue;
		
		Real diff_p=0;
		for(int t = 0; t < T; ++t){
			for(int d = 0; d < D; ++d){
				const Real diff_td = static_cast<Real>(p_on_s_td_(t,d) - before_p_on_s_td(t,d));
				diff_p += diff_td > 0 ? diff_td : -diff_td;
			}
		}
		diff_p /= (T*D);
		//printf("iter (%02d): diff_p = %f\n",iter,diff_p);
		
		if(diff_p < convergence_threshold){
			break; // end the iteration after converged
		}
		
		before_p_on_s_td = p_on_s_td_;
	}
	
	return 0;
}
Пример #2
0
/*
 * Entry point: parses the command line, loads the HMM configuration and the
 * observation sequences, then runs the selected mode.
 *
 * Options:
 *   -c FILE  read the configuration from FILE (default: standard input)
 *   -n N     number of iterations performed in mode 3 (default: 3)
 *   -p MODE  1, 2 or 3 (default: 3)
 *              1: print the forward_backward() result of every sequence and the total
 *              2: run viterbi() on every sequence
 *              3: N rounds of forward_backward() + update_prob(), printing the
 *                 updated parameters after each round
 *   -h       print usage and exit
 *
 * Configuration layout (lines of at most one character and lines starting
 * with '#' are skipped and do not count):
 *   line 0                      number of states (nstates, must be > 0)
 *   line 1                      number of observation symbols (nobvs, must be > 0)
 *   line 2                      nstates initial state probabilities
 *   next nstates lines          one row of nstates transition probabilities each
 *   next nstates lines          one row of nobvs output probabilities each
 *   next line                   "nseq length" (both must be > 0)
 *   next nseq lines             one sequence of `length` symbols in [0, nobvs) each
 * Any further lines are ignored. Probabilities are stored as logf() of the
 * value read.
 *
 * Exits with EXIT_FAILURE on bad options, malformed or incomplete
 * configuration; returns 0 otherwise.
 */
int main(int argc, char *argv[])
{
  char *configfile = NULL;
  FILE *fin, *bin;

  char *linebuf = NULL;
  size_t buflen = 0;
  ssize_t linelen;   /* getline() returns ssize_t; keep it apart from the getopt int */

  int iterations = 3;
  int mode = 3;

  int c;
  float d;
  float *loglik;
  float p;
  int i, j, k;
  int ok;            /* cleared when the current configuration line is malformed */
  opterr = 0;


  while ((c = getopt(argc, argv, "c:n:hp:")) != -1) {
    switch (c) {
    case 'c':
      configfile = optarg;
      break;
    case 'h':
      usage();
      exit(EXIT_SUCCESS);
    case 'n':
      iterations = atoi(optarg);
      break;
    case 'p':
      mode = atoi(optarg);
      if (mode != 1 && mode != 2 && mode != 3) {
        fprintf(stderr, "illegal mode: %d\n", mode);
        exit(EXIT_FAILURE);
      }
      break;
    case '?':
      fprintf(stderr, "illegal options\n");
      exit(EXIT_FAILURE);
    default:
      abort();
    }
  }

  if (configfile == NULL) {
    fin = stdin;
  } else {
    fin = fopen(configfile, "r");
    if (fin == NULL) {
      handle_error("fopen");
    }
  }

  /* i counts the meaningful (non-skipped) configuration lines seen so far */
  i = 0;
  while ((linelen = getline(&linebuf, &buflen, fin)) != -1) {
    if (linelen <= 1 || linebuf[0] == '#')
      continue;

    ok = 1;

    if (i == 0) {
      /* number of states: must be positive, it sizes every table below */
      ok = (sscanf(linebuf, "%d", &nstates) == 1 && nstates > 0);
      if (ok) {
        prior = (float *) malloc(sizeof(float) * nstates);
        if (prior == NULL) handle_error("malloc");

        trans = (float *) malloc(sizeof(float) * nstates * nstates);
        if (trans == NULL) handle_error("malloc");

        xi = (float *) malloc(sizeof(float) * nstates * nstates);
        if (xi == NULL) handle_error("malloc");

        pi = (float *) malloc(sizeof(float) * nstates);
        if (pi == NULL) handle_error("malloc");
      }

    } else if (i == 1) {
      /* number of observation symbols: must be positive */
      ok = (sscanf(linebuf, "%d", &nobvs) == 1 && nobvs > 0);
      if (ok) {
        obvs = (float *) malloc(sizeof(float) * nstates * nobvs);
        if (obvs == NULL) handle_error("malloc");

        gmm = (float *) malloc(sizeof(float) * nstates * nobvs);
        if (gmm == NULL) handle_error("malloc");
      }

    } else if (i == 2) {
      /* read initial state probabilities */
      /* the stream is bounded by the line length: buflen is only the
         capacity of the getline() buffer and may cover stale bytes */
      bin = fmemopen(linebuf, (size_t) linelen, "r");
      if (bin == NULL) handle_error("fmemopen");
      for (j = 0; ok && j < nstates; j++) {
        if (fscanf(bin, "%f", &d) != 1)
          ok = 0;
        else
          prior[j] = logf(d);
      }
      fclose(bin);

    } else if (i <= 2 + nstates) {
      /* read state transition probabilities: row (i - 3) */
      bin = fmemopen(linebuf, (size_t) linelen, "r");
      if (bin == NULL) handle_error("fmemopen");
      for (j = 0; ok && j < nstates; j++) {
        if (fscanf(bin, "%f", &d) != 1)
          ok = 0;
        else
          trans[IDX((i - 3),j,nstates)] = logf(d);
      }
      fclose(bin);

    } else if (i <= 2 + nstates * 2) {
      /* read output probabilities: row (i - 3 - nstates) */
      bin = fmemopen(linebuf, (size_t) linelen, "r");
      if (bin == NULL) handle_error("fmemopen");
      for (j = 0; ok && j < nobvs; j++) {
        if (fscanf(bin, "%f", &d) != 1)
          ok = 0;
        else
          obvs[IDX((i - 3 - nstates),j,nobvs)] = logf(d);
      }
      fclose(bin);

    } else if (i == 3 + nstates * 2) {
      /* number of sequences and their common length: both must be positive */
      ok = (sscanf(linebuf, "%d %d", &nseq, &length) == 2
            && nseq > 0 && length > 0);
      if (ok) {
        data = (int *) malloc (sizeof(int) * nseq * length);
        if (data == NULL) handle_error("malloc");
      }

    } else if (i <= 3 + nstates * 2 + nseq) {
      /* read data: sequence (i - 4 - nstates * 2), symbols must lie in [0, nobvs) */
      bin = fmemopen(linebuf, (size_t) linelen, "r");
      if (bin == NULL) handle_error("fmemopen");
      for (j = 0; ok && j < length; j++) {
        if (fscanf(bin, "%d", &k) != 1 || k < 0 || k >= nobvs)
          ok = 0;
        else
          data[(i - 4 - nstates * 2) * length + j] = k;
      }
      fclose(bin);
    }

    /* single exit path for every malformed line */
    if (!ok) {
      fprintf(stderr, "config file format error: %d\n", i);
      free(linebuf);
      freeall();
      exit(EXIT_FAILURE);
    }

    i++;
  }
  fclose(fin);
  free(linebuf);

  /* every section, including all nseq data rows, must have been seen */
  if (i < 4 + nstates * 2 + nseq) {
    fprintf(stderr, "configuration incomplete.\n");
    freeall();
    exit(EXIT_FAILURE);
  }

  if (mode == 3) {
    /* iterative re-estimation: accumulate over all sequences, then update */
    loglik = (float *) malloc(sizeof(float) * nseq);
    if (loglik == NULL) handle_error("malloc");

    for (i = 0; i < iterations; i++) {
      init_count();
      for (j = 0; j < nseq; j++) {
        loglik[j] = forward_backward(data + length * j, length, 1);
      }
      p = sumf(loglik, nseq);

      update_prob();

      /* parameters are kept as logs, so exponentiate for display */
      printf("iteration %d log-likelihood: %.4f\n", i + 1, p);
      printf("updated parameters:\n");
      printf("# initial state probability\n");
      for (j = 0; j < nstates; j++) {
        printf(" %.4f", exp(prior[j]));
      }
      printf("\n");
      printf("# state transition probability\n");
      for (j = 0; j < nstates; j++) {
        for (k = 0; k < nstates; k++) {
          printf(" %.4f", exp(trans[IDX(j,k,nstates)]));
        }
        printf("\n");
      }
      printf("# state output probility\n");
      for (j = 0; j < nstates; j++) {
        for (k = 0; k < nobvs; k++) {
          printf(" %.4f", exp(obvs[IDX(j,k,nobvs)]));
        }
        printf("\n");
      }
      printf("\n");
    }
    free(loglik);
  } else if (mode == 2) {
    for (i = 0; i < nseq; i++) {
      viterbi(data + length * i, length);
    }
  } else if (mode == 1) {
    /* evaluation only: per-sequence value followed by the total */
    loglik = (float *) malloc(sizeof(float) * nseq);
    if (loglik == NULL) handle_error("malloc");
    for (i = 0; i < nseq; i++) {
      loglik[i] = forward_backward(data + length * i, length, 0);
    }
    p = sumf(loglik, nseq);

    for (i = 0; i < nseq; i++)
      printf("%.4f\n", loglik[i]);
    printf("total: %.4f\n", p);
    free(loglik);
  }

  freeall();
  return 0;
}