// This does two sampling steps for an RBM: hidden from visible, then
// visible from the freshly sampled hidden.
// TODO: Do one sampling step less please.
rbm get_next(rbm* r) {
    // first sample the hidden layer from the visible one
    std::vector<int> hidden = sample_hidden(r, r->visible);

    // make a new rbm and sample its visible layer from that hidden sample
    rbm r2 = *r;
    std::vector<int> new_visible = sample_visible(r, hidden);
    r2.visible = new_visible;

    // you don't have to sample the hidden one
    //std::vector<int> new_hidden = sample_hidden(&r2, r2.visible);
    //r2.hidden = new_hidden;

    return r2;
}
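// Hypothetical usage sketch (not part of the original code): chains get_next
// to run k alternating Gibbs steps. It assumes an initialized `rbm` value
// whose `visible` field holds the starting configuration; the name
// `gibbs_chain` is a placeholder.
rbm gibbs_chain(const rbm& start, int k) {
    rbm current = start;
    for (int i = 0; i < k; ++i) {
        // each call resamples hidden from visible, then visible from hidden
        current = get_next(&current);
    }
    return current;
}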
void train_crbm(dataset_blas *train_set, int nvisible, int nhidden,
                int nlabel, int epoch, double lr, int minibatch,
                double momentum, char *model_file){
    crbm m;
    int i, j, k;
    int nepoch;
    double *v1, *y1;
    uint8_t *l;
    int batch_size, niter;
    int wc;
    int err;
    double lik;
    double delta;
    int *idx;
    time_t start_t, end_t;

    init_crbm(&m, nvisible, nhidden, nlabel);

    //wc = (int)cblas_dasum(train_set->N * train_set->n_feature, train_set->input, 1);
    niter = (train_set->N - 1) / minibatch + 1;
    delta = lr / (1.0 * minibatch);   /* learning rate scaled by minibatch size */

    /*
     * shuffle training data
    v1 = (double*)malloc(minibatch * m.nvisible * sizeof(double));
    y1 = (double*)malloc(minibatch * m.ncat * sizeof(double));
    l = (uint8_t*)malloc(minibatch * sizeof(uint8_t));
    idx = (int*)malloc(train_set->N * sizeof(int));
    for(i = 0; i < train_set->N; i++)
        idx[i] = i;
    */

    bzero(w_u, sizeof(w_u));
    bzero(u_u, sizeof(u_u));
    bzero(bh_u, sizeof(bh_u));
    bzero(bv_u, sizeof(bv_u));
    bzero(by_u, sizeof(by_u));

    //shuffle(idx, train_set->N);

    for(nepoch = 0; nepoch < epoch; nepoch++){

        lik = 0;
        err = 0;
        start_t = time(NULL);

        for(k = 0; k < niter; k++){
#ifdef DEBUG
            if((k+1) % 200 == 0){
                printf("batch %d\n", k+1);
            }
#endif
            /* the last minibatch may be smaller than the others */
            if(k == niter - 1){
                batch_size = train_set->N - minibatch * (niter-1);
            }else{
                batch_size = minibatch;
            }

            v1 = train_set->input + train_set->n_feature * minibatch * k;
            y1 = train_set->label + train_set->nlabel * minibatch * k;
            l = train_set->output + minibatch * k;

            /*
             * shuffle training data
            for(i = 0; i < batch_size; i++){
                cblas_dcopy(m.nvisible, train_set->input + m.nvisible * idx[k*minibatch+i], 1,
                            v1 + m.nvisible * i, 1);
                cblas_dcopy(m.ncat, train_set->label + m.ncat * idx[k*minibatch+i], 1,
                            y1 + m.ncat * i, 1);
                l[i] = train_set->output[idx[k*minibatch+i]];
            }
            */

            /* positive phase: hidden units driven by the data (v1, y1) */
            get_hidden(&m, v1, y1, ph1, batch_size);
            sample_hidden(&m, ph1, h1, batch_size);

            /* negative phase: one Gibbs step (CD-1) reconstructing visible
             * units, class units, and then the hidden units again */
            get_visible(&m, h1, pv, batch_size);
            sample_visible(&m, pv, v2, batch_size);
            get_class(&m, h1, py, batch_size);
            sample_class(&m, py, y2, batch_size);
            get_hidden(&m, v2, y2, ph2, batch_size);
            sample_hidden(&m, ph2, h2, batch_size);

            //lik += get_likelihood(&m, v1, pv, batch_size);
            err += get_error(&m, y2, l, batch_size);

            /* each update below computes a = positive term - negative term
             * (the second dgemm uses beta = -1 to subtract the first result),
             * then adds the momentum-scaled previous update */

            //update w_u
            cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
                        m.nhidden, m.nvisible, batch_size,
                        1.0, ph2, m.nhidden, v2, m.nvisible,
                        0, a, m.nvisible);
            cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
                        m.nhidden, m.nvisible, batch_size,
                        1.0, ph1, m.nhidden, v1, m.nvisible,
                        -1, a, m.nvisible);
            cblas_daxpy(m.nvisible * m.nhidden, momentum, w_u, 1, a, 1);
            cblas_dcopy(m.nvisible * m.nhidden, a, 1, w_u, 1);

            //update u_u
            cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
                        m.nhidden, m.ncat, batch_size,
                        1.0, ph2, m.nhidden, y2, m.ncat,
                        0, a, m.ncat);
            cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
                        m.nhidden, m.ncat, batch_size,
                        1.0, ph1, m.nhidden, y1, m.ncat,
                        -1, a, m.ncat);
            cblas_daxpy(m.ncat * m.nhidden, momentum, u_u, 1, a, 1);
            cblas_dcopy(m.ncat * m.nhidden, a, 1, u_u, 1);

            //update bv_u
            cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
                        m.nvisible, 1, batch_size,
                        1.0, v2, m.nvisible, I, 1,
                        0, a, 1);
            cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
                        m.nvisible, 1, batch_size,
                        1.0, v1, m.nvisible, I, 1,
                        -1, a, 1);
            cblas_daxpy(m.nvisible, momentum, bv_u, 1, a, 1);
            cblas_dcopy(m.nvisible, a, 1, bv_u, 1);

            //update by_u
            cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
                        m.ncat, 1, batch_size,
                        1.0, y2, m.ncat, I, 1,
                        0, a, 1);
            cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
                        m.ncat, 1, batch_size,
                        1.0, y1, m.ncat, I, 1,
                        -1, a, 1);
            cblas_daxpy(m.ncat, momentum, by_u, 1, a, 1);
            cblas_dcopy(m.ncat, a, 1, by_u, 1);
            //update bh_u
            cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
                        m.nhidden, 1, batch_size,
                        1.0, ph2, m.nhidden, I, 1,
                        0, a, 1);
            cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
                        m.nhidden, 1, batch_size,
                        1.0, ph1, m.nhidden, I, 1,
                        -1, a, 1);
            cblas_daxpy(m.nhidden, momentum, bh_u, 1, a, 1);
            cblas_dcopy(m.nhidden, a, 1, bh_u, 1);

            //change parameter
            cblas_daxpy(m.nvisible * m.nhidden, delta, w_u, 1, m.w, 1);
            cblas_daxpy(m.ncat * m.nhidden, delta, u_u, 1, m.u, 1);
            cblas_daxpy(m.nvisible, delta, bv_u, 1, m.bv, 1);
            cblas_daxpy(m.ncat, delta, by_u, 1, m.by, 1);
            cblas_daxpy(m.nhidden, delta, bh_u, 1, m.bh, 1);
        }

        end_t = time(NULL);
        printf("[epoch %d] error:%.5lf%%\ttime:%.2fmin\n", nepoch + 1,
               err * 100.0 / train_set->N, (end_t - start_t) / 60.0);
    }

    dump_model(&m, model_file);
    //print_prob(&m, train_set, "../data/rsm/test.prob");

    /*
     * shuffle training data
    free(v1);
    free(y1);
    free(l);
    */

    free_crbm(&m);
}
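/*
 * Hypothetical driver sketch (not part of the original code) showing how
 * train_crbm might be invoked. `load_dataset` and `free_dataset` are
 * placeholder names for whatever routine fills and releases a dataset_blas
 * in this repository, and the hyperparameter values are illustrative only.
 */
#if 0
int main(void){
    dataset_blas train_set;
    load_dataset(&train_set, "train.bin");   /* placeholder loader */

    train_crbm(&train_set,
               784,            /* nvisible  */
               500,            /* nhidden   */
               10,             /* nlabel    */
               20,             /* epoch     */
               0.05,           /* lr        */
               10,             /* minibatch */
               0.9,            /* momentum  */
               "crbm.model");

    free_dataset(&train_set);                /* placeholder cleanup */
    return 0;
}
#endif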