tnn_error tnn_module_bprop_linear(tnn_module *m){
  tnn_error ret;
  gsl_matrix w;
  gsl_matrix dw;

  //Routine check
  if(m->t != TNN_MODULE_TYPE_LINEAR){
    return TNN_ERROR_MODULE_MISTYPE;
  }
  if(m->input->valid != true || m->output->valid != true || m->w.valid != true){
    return TNN_ERROR_STATE_INVALID;
  }

  //Transform the matrix
  TNN_MACRO_ERRORTEST(tnn_numeric_v2m(&m->w.x, &w, m->output->size, m->input->size), ret);
  TNN_MACRO_ERRORTEST(tnn_numeric_v2m(&m->w.dx, &dw, m->output->size, m->input->size), ret);

  //bprop to input
  TNN_MACRO_GSLTEST(gsl_blas_dgemv(CblasTrans, 1.0, &w, &m->output->dx, 0.0, &m->input->dx));

  //bprop to dw
  gsl_matrix_set_zero(&dw);
  TNN_MACRO_GSLTEST(gsl_blas_dger(1.0, &m->output->dx, &m->input->x, &dw));

  return TNN_ERROR_SUCCESS;
}
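/*
  Gradient notes for the linear module, derived from the calls above: with
  y = W x, backpropagation gives

    dL/dx = W^T (dL/dy)   -- the dgemv call with CblasTrans
    dL/dW = (dL/dy) x^T   -- the rank-1 update performed by dger

  dw is zeroed before the dger call, so the stored gradient is that of the
  current sample only, not an accumulation across samples.
*/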
tnn_error tnn_module_clone_sum(tnn_module *m1, tnn_module *m2, tnn_param *p, tnn_pstable *t){
  tnn_error ret;
  tnn_module_sum *c;
  tnn_state *s1, *s2, **s;
  UT_icd sarray_icd = {sizeof(tnn_state*), NULL, NULL, NULL};
  size_t i;

  //Routine check
  if(m1->t != TNN_MODULE_TYPE_SUM){
    return TNN_ERROR_MODULE_MISTYPE;
  }

  //Retrieve input and output
  TNN_MACRO_ERRORTEST(tnn_pstable_find(t, m1->input, &m2->input), ret);
  TNN_MACRO_ERRORTEST(tnn_pstable_find(t, m1->output, &m2->output), ret);
  if(m1->input->size != m2->input->size || m1->output->size != m2->output->size){
    return TNN_ERROR_STATE_INCOMP;
  }

  //Define the type
  m2->t = TNN_MODULE_TYPE_SUM;

  //Constant parameter is a new tnn_module_sum
  c = (tnn_module_sum *) malloc(sizeof(tnn_module_sum));
  if(c == NULL){
    return TNN_ERROR_ALLOC;
  }
  m2->c = c;

  //Allocate the state
  tnn_state_init(&m2->w, 0L);

  //Find the sub states: tnn_pstable_find writes the cloned state's address
  //into s2, so no new allocation is needed here
  utarray_new(c->sarray, &sarray_icd);
  for(i = 0; i < utarray_len(((tnn_module_sum*)m1->c)->sarray); i = i + 1){
    s = (tnn_state **) utarray_eltptr(((tnn_module_sum*)m1->c)->sarray, i);
    s1 = *s;
    TNN_MACRO_ERRORTEST(tnn_pstable_find(t, s1, &s2), ret);
    if(s1->size != s2->size){
      return TNN_ERROR_STATE_INCOMP;
    }
    utarray_push_back(c->sarray, &s2);
  }

  //Store the functions
  m2->bprop = &tnn_module_bprop_sum;
  m2->fprop = &tnn_module_fprop_sum;
  m2->randomize = &tnn_module_randomize_sum;
  m2->destroy = &tnn_module_destroy_sum;
  m2->debug = &tnn_module_debug_sum;
  m2->clone = &tnn_module_clone_sum;

  return TNN_ERROR_SUCCESS;
}
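/*
  Cloning note: the tnn_pstable t maps each state of the original machine to
  its counterpart in the clone, which is why tnn_pstable_find is used to look
  up the already-cloned input, output, and sub-states rather than allocating
  fresh ones. The same pattern appears in tnn_module_clone_bias below.
*/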
//Initialize a trainer to be an nsgd trainer. The lset is managed by the trainer.
tnn_error tnn_trainer_class_init_nsgd(tnn_trainer_class *t, size_t ninput, size_t noutput, gsl_matrix *lset,
                                      double lambda, double eta, double epsilon, size_t eiter, size_t niter){
  tnn_error ret;

  //Check the parameters
  if(lambda < 0 || eta < 0 || epsilon < 0){
    return TNN_ERROR_TRAINER_CLASS_NVALIDP;
  }
  if(eiter < 1){
    eiter = 1;
  }
  if(niter < 1 && epsilon == 0){
    return TNN_ERROR_TRAINER_CLASS_NVALIDP;
  }

  //Define the type
  t->t = TNN_TRAINER_CLASS_TYPE_NSGD;

  //Constant parameters
  t->c = (tnn_trainer_class_nsgd *) malloc(sizeof(tnn_trainer_class_nsgd));
  if(t->c == NULL){
    return TNN_ERROR_ALLOC;
  }
  ((tnn_trainer_class_nsgd*)t->c)->eta = eta;
  ((tnn_trainer_class_nsgd*)t->c)->epsilon = epsilon;
  ((tnn_trainer_class_nsgd*)t->c)->eiter = eiter;
  ((tnn_trainer_class_nsgd*)t->c)->niter = niter;
  ((tnn_trainer_class_nsgd*)t->c)->titer = 0;

  //lset
  t->lset = lset;

  //Losses
  t->losses = gsl_vector_alloc(t->lset->size1);
  if(t->losses == NULL){
    return TNN_ERROR_ALLOC;
  }

  //Initialize the machine
  TNN_MACRO_ERRORTEST(tnn_machine_init(&t->m, ninput, noutput), ret);

  //Initialize the label
  t->label = (tnn_state *) malloc(sizeof(tnn_state));
  if(t->label == NULL){
    return TNN_ERROR_ALLOC;
  }
  TNN_MACRO_ERRORTEST(tnn_state_init(t->label, noutput), ret);
  TNN_MACRO_ERRORTEST(tnn_machine_state_alloc(&t->m, t->label), ret);

  //Initialize the regularization parameter
  t->lambda = lambda;

  //Initialize methods
  t->learn = tnn_trainer_class_learn_nsgd;
  t->train = tnn_trainer_class_train_nsgd;
  t->debug = tnn_trainer_class_debug_nsgd;
  t->destroy = tnn_trainer_class_destroy_nsgd;

  return TNN_ERROR_SUCCESS;
}
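/*
  Usage sketch (illustrative, not part of the library): build a label set
  whose rows are the target output vectors, one row per class, and hand it to
  the trainer. Since the lset is managed by the trainer, the caller does not
  free it. The one-hot encoding and all sizes below are assumptions chosen for
  illustration only.

  tnn_trainer_class trainer;
  gsl_matrix *lset = gsl_matrix_calloc(10, 10);  //10 classes, 10 outputs
  gsl_matrix_set_identity(lset);                 //one-hot target rows (assumed encoding)
  //784 inputs, 10 outputs; lambda = 1e-4, eta = 0.01, epsilon = 1e-6,
  //eiter = 100 samples per epoch, niter = 100000 max iterations
  if(tnn_trainer_class_init_nsgd(&trainer, 784, 10, lset, 1e-4, 0.01, 1e-6, 100, 100000)
     != TNN_ERROR_SUCCESS){
    //handle initialization failure
  }
*/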
tnn_error tnn_module_init_bias(tnn_module *m, tnn_state *input, tnn_state *output, tnn_param *p){
  tnn_error ret;

  //Check whether the output state has the same size as the input
  if(input->size != output->size){
    return TNN_ERROR_STATE_INCOMP;
  }

  //Define the type
  m->t = TNN_MODULE_TYPE_BIAS;

  //No constant parameters
  m->c = NULL;

  //Allocate the parameter states
  tnn_state_init(&m->w, input->size);
  TNN_MACRO_ERRORTEST(tnn_param_state_alloc(p, &m->w), ret);

  //Link the inputs and outputs
  m->input = input;
  m->output = output;

  //Store the functions
  m->bprop = &tnn_module_bprop_bias;
  m->fprop = &tnn_module_fprop_bias;
  m->randomize = &tnn_module_randomize_bias;
  m->destroy = &tnn_module_destroy_bias;
  m->clone = &tnn_module_clone_bias;
  m->debug = &tnn_module_debug_bias;

  return TNN_ERROR_SUCCESS;
}
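/*
  The bias module owns one parameter vector b of the same size as its input;
  judging from the size check above and the module's name, its forward pass is
  presumably y = x + b, which is why input and output sizes must match and why
  the gradients db = dy and dx = dy need no reshaping.
*/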
tnn_error tnn_module_init_linear(tnn_module *m, tnn_state *input, tnn_state *output, tnn_param *p){
  tnn_error ret;

  //Define the type
  m->t = TNN_MODULE_TYPE_LINEAR;

  //No constant parameters
  m->c = NULL;

  //Allocate the parameter states
  tnn_state_init(&m->w, input->size*output->size);
  TNN_MACRO_ERRORTEST(tnn_param_state_alloc(p, &m->w), ret);

  //Link the inputs and outputs
  m->input = input;
  m->output = output;

  //Store the functions
  m->bprop = &tnn_module_bprop_linear;
  m->fprop = &tnn_module_fprop_linear;
  m->randomize = &tnn_module_randomize_linear;
  m->destroy = &tnn_module_destroy_linear;
  m->debug = &tnn_module_debug_linear;
  m->clone = &tnn_module_clone_linear;

  return TNN_ERROR_SUCCESS;
}
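/*
  Usage sketch (illustrative; the machine/state setup mirrors calls used
  elsewhere in this file, but the sizes are arbitrary assumptions):

  tnn_machine mac;
  tnn_param *p;
  tnn_state in, out;
  tnn_module lin;

  tnn_machine_init(&mac, 4, 2);
  tnn_state_init(&in, 4);
  tnn_state_init(&out, 2);
  tnn_machine_state_alloc(&mac, &in);
  tnn_machine_state_alloc(&mac, &out);
  tnn_machine_get_param(&mac, &p);
  tnn_module_init_linear(&lin, &in, &out, p);  //allocates a 2x4 weight matrix, stored flat in lin.w
*/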
tnn_error tnn_module_init_sum(tnn_module *m, tnn_state *input, tnn_state *output, tnn_param *io){
  tnn_error ret;
  tnn_module_sum *c;
  size_t i;
  UT_icd sarray_icd = {sizeof(tnn_state*), NULL, NULL, NULL};
  tnn_state *t;

  //Check the sizes and validity
  if(input->size % output->size != 0){
    return TNN_ERROR_STATE_INCOMP;
  }
  if(input->valid != true){
    return TNN_ERROR_STATE_INVALID;
  }

  //Define the type
  m->t = TNN_MODULE_TYPE_SUM;

  //Constant parameter is a new tnn_module_sum
  c = (tnn_module_sum *) malloc(sizeof(tnn_module_sum));
  if(c == NULL){
    return TNN_ERROR_ALLOC;
  }
  m->c = c;

  //Allocate sub-states
  utarray_new(c->sarray, &sarray_icd);
  if(c->sarray == NULL){
    return TNN_ERROR_ALLOC;
  }
  for(i = 0; i < input->size; i = i + output->size){
    //Alloc and initialize the state
    t = (tnn_state *) malloc(sizeof(tnn_state));
    if(t == NULL){
      return TNN_ERROR_ALLOC;
    }
    t->size = output->size;
    //Get the substate and store it
    TNN_MACRO_ERRORTEST(tnn_param_state_sub(io, input, t, i), ret);
    utarray_push_back(c->sarray, &t);
  }

  //Init the state
  tnn_state_init(&m->w, 0L);

  //Link the inputs and outputs
  m->input = input;
  m->output = output;

  //Store the functions
  m->bprop = &tnn_module_bprop_sum;
  m->fprop = &tnn_module_fprop_sum;
  m->randomize = &tnn_module_randomize_sum;
  m->destroy = &tnn_module_destroy_sum;
  m->clone = &tnn_module_clone_sum;
  m->debug = &tnn_module_debug_sum;

  return TNN_ERROR_SUCCESS;
}
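/*
  Worked example of the splitting above: with input->size = 6 and
  output->size = 2, the loop creates three sub-states of size 2 at offsets
  0, 2, and 4 of the input via tnn_param_state_sub. The module's name suggests
  that fprop then sums these sub-states element-wise into the output.
*/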
//Add the loss of the regularizer to the value l
tnn_error tnn_reg_addl(tnn_reg *r, gsl_vector *w, double *l){
  tnn_error ret;
  double regl;

  if(r->l != NULL){
    //Check whether the execution is successful
    TNN_MACRO_ERRORTEST((*r->l)(r, w, &regl), ret);
    *l = *l + regl;
    return TNN_ERROR_SUCCESS;
  }
  return TNN_ERROR_REG_FUNCNDEF;
}
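/*
  Sketch of a regularizer loss callback compatible with the
  (*r->l)(r, w, &regl) call above. The L2 form is an assumption for
  illustration; the library's own regularizers may differ. The lambda scaling
  is applied elsewhere (see tnn_reg_addd in the nsgd trainer), so this
  callback returns only the raw penalty.

  static tnn_error example_reg_l_l2(tnn_reg *r, gsl_vector *w, double *l){
    double nrm = gsl_blas_dnrm2(w);  //Euclidean norm of the weights
    *l = nrm*nrm;                    //L2 loss: sum of squared weights
    return TNN_ERROR_SUCCESS;
  }
*/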
tnn_error tnn_module_clone_bias(tnn_module *m1, tnn_module *m2, tnn_param *p, tnn_pstable *t){
  tnn_error ret;

  //Routine check
  if(m1->t != TNN_MODULE_TYPE_BIAS){
    return TNN_ERROR_MODULE_MISTYPE;
  }

  //Retrieve input and output
  TNN_MACRO_ERRORTEST(tnn_pstable_find(t, m1->input, &m2->input), ret);
  TNN_MACRO_ERRORTEST(tnn_pstable_find(t, m1->output, &m2->output), ret);
  if(m1->input->size != m2->input->size || m1->output->size != m2->output->size){
    return TNN_ERROR_STATE_INCOMP;
  }

  //Define the type
  m2->t = TNN_MODULE_TYPE_BIAS;

  //No constant parameters
  m2->c = NULL;

  //Allocate the parameter states
  tnn_state_init(&m2->w, m2->input->size);
  TNN_MACRO_ERRORTEST(tnn_param_state_alloc(p, &m2->w), ret);

  //Store the functions
  m2->bprop = &tnn_module_bprop_bias;
  m2->fprop = &tnn_module_fprop_bias;
  m2->randomize = &tnn_module_randomize_bias;
  m2->destroy = &tnn_module_destroy_bias;
  m2->debug = &tnn_module_debug_bias;
  m2->clone = &tnn_module_clone_bias;

  //Copy the state
  TNN_MACRO_ERRORTEST(tnn_state_copy(&m1->w, &m2->w), ret);

  return TNN_ERROR_SUCCESS;
}
tnn_error tnn_module_fprop_linear(tnn_module *m){
  tnn_error ret;
  gsl_matrix w;

  //Routine check
  if(m->t != TNN_MODULE_TYPE_LINEAR){
    return TNN_ERROR_MODULE_MISTYPE;
  }
  if(m->input->valid != true || m->output->valid != true || m->w.valid != true){
    return TNN_ERROR_STATE_INVALID;
  }

  //Transform the matrix
  TNN_MACRO_ERRORTEST(tnn_numeric_v2m(&m->w.x, &w, m->output->size, m->input->size), ret);

  //Compute the result using BLAS
  TNN_MACRO_GSLTEST(gsl_blas_dgemv(CblasNoTrans, 1.0, &w, &m->input->x, 0.0, &m->output->x));

  return TNN_ERROR_SUCCESS;
}
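/*
  Forward-pass note: the flat parameter vector w is reinterpreted as an
  (output->size x input->size) matrix W by tnn_numeric_v2m, and the dgemv call
  computes y = W x into the output state.
*/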
//Learn one sample using naive stochastic gradient descent
tnn_error tnn_trainer_class_learn_nsgd(tnn_trainer_class *t, gsl_vector *input, size_t label){
  tnn_error ret;
  tnn_state *sin;
  tnn_param *p;
  gsl_vector_view lb;

  //Routine check
  if(t->t != TNN_TRAINER_CLASS_TYPE_NSGD){
    return TNN_ERROR_TRAINER_CLASS_MISTYPE;
  }

  //Check the input and label
  TNN_MACRO_ERRORTEST(tnn_machine_get_sin(&t->m, &sin), ret);
  if(label >= t->lset->size1 || input->size != sin->size){
    return TNN_ERROR_STATE_INCOMP;
  }
  lb = gsl_matrix_row(t->lset, label);

  //Set the loss output dx to be 1
  gsl_vector_set(&t->l.output->dx, 0, 1.0);

  //Copy the data into the input/label and do forward and backward propagation
  TNN_MACRO_GSLTEST(gsl_blas_dcopy(input, &sin->x));
  TNN_MACRO_GSLTEST(gsl_blas_dcopy(&lb.vector, &t->label->x));
  TNN_MACRO_ERRORTEST(tnn_machine_fprop(&t->m), ret);
  TNN_MACRO_ERRORTEST(tnn_loss_fprop(&t->l), ret);
  TNN_MACRO_ERRORTEST(tnn_loss_bprop(&t->l), ret);
  TNN_MACRO_ERRORTEST(tnn_machine_bprop(&t->m), ret);

  //Add the regularization gradient to the parameter gradient
  TNN_MACRO_ERRORTEST(tnn_machine_get_param(&t->m, &p), ret);
  TNN_MACRO_ERRORTEST(tnn_reg_addd(&t->r, p->x, p->dx, t->lambda), ret);

  //Compute the parameter update
  TNN_MACRO_GSLTEST(gsl_blas_daxpy(-((tnn_trainer_class_nsgd*)t->c)->eta, p->dx, p->x));

  //Set the titer parameter
  ((tnn_trainer_class_nsgd*)t->c)->titer = 1;

  return TNN_ERROR_SUCCESS;
}
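/*
  Update rule performed above, with eta the learning rate and lambda the
  regularization weight:

    w <- w - eta * (dL/dw + lambda * dR/dw)

  where dL/dw comes from backpropagation through the loss and the machine, and
  dR/dw is the regularizer gradient accumulated by tnn_reg_addd.
*/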
//Train all the samples using naive stochastic gradient descent
tnn_error tnn_trainer_class_train_nsgd(tnn_trainer_class *t, gsl_matrix *inputs, size_t *labels){
  tnn_error ret;
  tnn_state *sin;
  tnn_param *p;
  gsl_vector *rd;
  gsl_vector *pw;
  gsl_vector_view in;
  gsl_vector_view lb;
  double eps;
  size_t i, j;

  //Routine check
  if(t->t != TNN_TRAINER_CLASS_TYPE_NSGD){
    return TNN_ERROR_TRAINER_CLASS_MISTYPE;
  }

  //Check the input
  TNN_MACRO_ERRORTEST(tnn_machine_get_sin(&t->m, &sin), ret);
  if(inputs->size2 != sin->size){
    return TNN_ERROR_STATE_INCOMP;
  }

  //Set the loss output dx to be 1
  gsl_vector_set(&t->l.output->dx, 0, 1.0);

  //Get the parameter and allocate rd and pw
  TNN_MACRO_ERRORTEST(tnn_machine_get_param(&t->m, &p), ret);
  rd = gsl_vector_alloc(p->size);
  pw = gsl_vector_alloc(p->size);
  if(rd == NULL || pw == NULL){
    return TNN_ERROR_GSL;
  }

  //Into the main loop
  for(eps = DBL_MAX, ((tnn_trainer_class_nsgd*)t->c)->titer = 0;
      eps > ((tnn_trainer_class_nsgd*)t->c)->epsilon
        && ((tnn_trainer_class_nsgd*)t->c)->titer < ((tnn_trainer_class_nsgd*)t->c)->niter;
      ((tnn_trainer_class_nsgd*)t->c)->titer =
        ((tnn_trainer_class_nsgd*)t->c)->titer + ((tnn_trainer_class_nsgd*)t->c)->eiter){

    //Copy the previous parameters into pw
    TNN_MACRO_GSLTEST(gsl_blas_dcopy(p->x, pw));

    for(i = 0; i < ((tnn_trainer_class_nsgd*)t->c)->eiter; i = i + 1){
      j = (((tnn_trainer_class_nsgd*)t->c)->titer + i)%inputs->size1;

      //Check the label
      if(labels[j] >= t->lset->size1){
        gsl_vector_free(rd);
        gsl_vector_free(pw);
        return TNN_ERROR_STATE_INCOMP;
      }

      //Get the input and label vectors
      lb = gsl_matrix_row(t->lset, labels[j]);
      in = gsl_matrix_row(inputs, j);

      //Copy the data into the input/label and do forward and backward propagation
      TNN_MACRO_GSLTEST(gsl_blas_dcopy(&in.vector, &sin->x));
      TNN_MACRO_GSLTEST(gsl_blas_dcopy(&lb.vector, &t->label->x));
      TNN_MACRO_ERRORTEST(tnn_machine_fprop(&t->m), ret);
      TNN_MACRO_ERRORTEST(tnn_loss_fprop(&t->l), ret);
      TNN_MACRO_ERRORTEST(tnn_loss_bprop(&t->l), ret);
      TNN_MACRO_ERRORTEST(tnn_machine_bprop(&t->m), ret);

      //Add the regularization gradient to the parameter gradient
      TNN_MACRO_ERRORTEST(tnn_reg_d(&t->r, p->x, rd), ret);
      TNN_MACRO_GSLTEST(gsl_blas_daxpy(t->lambda, rd, p->dx));

      //Compute the parameter update
      TNN_MACRO_GSLTEST(gsl_blas_daxpy(-((tnn_trainer_class_nsgd*)t->c)->eta, p->dx, p->x));
    }

    //Compute the 2-norm of the parameter difference as eps
    TNN_MACRO_GSLTEST(gsl_blas_daxpy(-1.0, p->x, pw));
    eps = gsl_blas_dnrm2(pw);
  }

  //Free the working vectors
  gsl_vector_free(rd);
  gsl_vector_free(pw);

  return TNN_ERROR_SUCCESS;
}
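/*
  Convergence notes: each outer iteration runs eiter single-sample updates over
  the data in round-robin order (j = (titer + i) mod nsamples), then measures
  eps = ||w_after - w_before||_2 across that epoch. Training stops when
  eps <= epsilon or when titer reaches niter.

  Usage sketch (sizes and data are assumptions for illustration; trainer is an
  initialized tnn_trainer_class as in the init sketch above):

  gsl_matrix *inputs = gsl_matrix_alloc(nsamples, 784);  //one sample per row
  size_t *labels = malloc(nsamples * sizeof(size_t));    //row indices into trainer.lset
  //... fill inputs and labels ...
  ret = trainer.train(&trainer, inputs, labels);
*/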