/*
 * Tag one sentence with the POS network and return the tag sequence.
 *
 * pos            : tagger state; its scratch buffers (input_state,
 *                  output_state, labels) are resized for this sentence.
 * sentence_words : word feature indices, one per word.
 * sentence_caps  : capitalization feature indices, one per word.
 * sentence_suff  : suffix feature indices, one per word.
 * sentence_size  : number of words.
 *
 * Returns pos->labels: sentence_size ints chosen by SENNA_nn_viterbi. The
 * buffer is owned by `pos` and is passed to SENNA_realloc again on the next
 * call, so callers that keep the result must copy it.
 */
int* SENNA_POS_forward(SENNA_POS *pos, const int *sentence_words, const int *sentence_caps, const int *sentence_suff, int sentence_size)
{
  /* floats per word frame: word, caps and suffix embeddings side by side
     (@AureDi: "feature length") */
  const int frame_size = pos->ll_word_size + pos->ll_caps_size + pos->ll_suff_size;
  /* padding amount handed to SENNA_nn_lookup; presumably the number of
     padding frames on each side so edge words get a full window -- confirm */
  const int n_padding = (pos->window_size - 1) / 2;
  int idx;

  /* sentence_size + window_size - 1 frames: the words plus the padding the
     window needs around them (@AureDi: "broad convolution") */
  pos->input_state = SENNA_realloc(pos->input_state, sizeof(float), (sentence_size + pos->window_size - 1) * frame_size);
  pos->output_state = SENNA_realloc(pos->output_state, sizeof(float), sentence_size * pos->output_state_size);

  /* fill every frame: word embedding at offset 0, caps after it, suffix last */
  SENNA_nn_lookup(pos->input_state, frame_size, pos->ll_word_weight, pos->ll_word_size, pos->ll_word_max_idx, sentence_words, sentence_size, pos->ll_word_padding_idx, n_padding);
  SENNA_nn_lookup(pos->input_state + pos->ll_word_size, frame_size, pos->ll_caps_weight, pos->ll_caps_size, pos->ll_caps_max_idx, sentence_caps, sentence_size, pos->ll_caps_padding_idx, n_padding);
  SENNA_nn_lookup(pos->input_state + pos->ll_word_size + pos->ll_caps_size, frame_size, pos->ll_suff_weight, pos->ll_suff_size, pos->ll_suff_max_idx, sentence_suff, sentence_size, pos->ll_suff_padding_idx, n_padding);

  /* score word idx from the window_size consecutive frames starting at frame idx:
     linear -> hardtanh -> linear, one output row per word */
  for(idx = 0; idx < sentence_size; idx++)
  {
    SENNA_nn_linear(pos->hidden_state, pos->hidden_state_size, pos->l1_weight, pos->l1_bias, pos->input_state + idx * frame_size, pos->window_size * frame_size);
    SENNA_nn_hardtanh(pos->hidden_state, pos->hidden_state, pos->hidden_state_size);
    SENNA_nn_linear(pos->output_state + idx * pos->output_state_size, pos->output_state_size, pos->l2_weight, pos->l2_bias, pos->hidden_state, pos->hidden_state_size);
  }

  /* decode the best tag sequence from the per-word scores */
  pos->labels = SENNA_realloc(pos->labels, sizeof(int), sentence_size);
  SENNA_nn_viterbi(pos->labels, pos->viterbi_score_init, pos->viterbi_score_trans, pos->output_state, pos->output_state_size, sentence_size);

  return pos->labels;
}
void SENNA_PSG_forward(SENNA_PSG *psg, const int *sentence_words, const int *sentence_caps, const int *sentence_posl, int sentence_size, int **labels_, int *n_level_) { int *sentence_psgl = SENNA_malloc(sizeof(int), sentence_size); int *sentence_segl = SENNA_malloc(sizeof(int), sentence_size); int *start_and_sentence_level_label = SENNA_malloc(sizeof(int), sentence_size+1); int t; int level; for(t = 0; t < sentence_size; t++) { sentence_psgl[t] = 0; sentence_segl[t] = 0; } psg->input_state = SENNA_realloc(psg->input_state, sizeof(float), sentence_size*psg->input_state_size); psg->l1_state = SENNA_realloc(psg->l1_state, sizeof(float), sentence_size*psg->l1_state_size); psg->l2_state = SENNA_realloc(psg->l2_state, sizeof(float), sentence_size*psg->l2_state_size); psg->l3_state = SENNA_realloc(psg->l3_state, sizeof(float), sentence_size*psg->l3_state_size); psg->l4_state = SENNA_realloc(psg->l4_state, sizeof(float), sentence_size*psg->l4_state_size); SENNA_nn_lookup(psg->input_state, psg->input_state_size, psg->ll_word_weight, psg->ll_word_size, psg->ll_word_max_idx, sentence_words, sentence_size, 0, 0); SENNA_nn_lookup(psg->input_state+psg->ll_word_size, psg->input_state_size, psg->ll_caps_weight, psg->ll_caps_size, psg->ll_caps_max_idx, sentence_caps, sentence_size, 0, 0); SENNA_nn_lookup(psg->input_state+psg->ll_word_size+psg->ll_caps_size, psg->input_state_size, psg->ll_posl_weight, psg->ll_posl_size, psg->ll_posl_max_idx, sentence_posl, sentence_size, 0, 0); level = 0; while(1) { int all_tags_are_o; int all_in_one_segment; SENNA_nn_lookup(psg->input_state+psg->ll_word_size+psg->ll_caps_size+psg->ll_posl_size, psg->input_state_size, psg->ll_psgl_weight, psg->ll_psgl_size, psg->ll_psgl_max_idx, sentence_psgl, sentence_size, 0, 0); SENNA_nn_temporal_convolution(psg->l1_state, psg->l1_state_size, psg->l1_weight, psg->l1_bias, psg->input_state, psg->input_state_size, sentence_size, 1); SENNA_nn_temporal_max_convolution(psg->l2_state, psg->l2_bias, psg->l1_state, 
psg->l1_state_size, sentence_size, psg->window_size); SENNA_nn_temporal_convolution(psg->l3_state, psg->l3_state_size, psg->l3_weight, psg->l3_bias, psg->l2_state, psg->l1_state_size, sentence_size, 1); SENNA_nn_hardtanh(psg->l3_state, psg->l3_state, psg->l3_state_size*sentence_size); SENNA_nn_temporal_convolution(psg->l4_state, psg->l4_state_size, psg->l4_weight, psg->l4_bias, psg->l3_state, psg->l3_state_size, sentence_size, 1); SENNA_Treillis_buildfromscorewithsegmentation(psg->treillis, psg->l4_state, psg->viterbi_score_init, psg->viterbi_score_trans, sentence_segl, psg->l4_state_size, sentence_size); SENNA_Treillis_viterbi(psg->treillis, start_and_sentence_level_label); /* update history and segmentation */ all_tags_are_o = 1; for(t = 0; t < sentence_size; t++) { if(start_and_sentence_level_label[t+1]) { sentence_psgl[t] = start_and_sentence_level_label[t+1]; /* note we always keep if something was there */ sentence_segl[t] = (start_and_sentence_level_label[t+1]-1)%4+1; all_tags_are_o = 0; } } /* check if only one big segment */ if(sentence_size == 1) all_in_one_segment = (sentence_segl[0] == SEG_S); else all_in_one_segment = (sentence_segl[0] == SEG_B) && (sentence_segl[sentence_size-1] == SEG_E); for(t = 1; all_in_one_segment && (t < sentence_size-1); t++) { if(sentence_segl[t] != SEG_I) all_in_one_segment = 0; } level++; if(psg->max_labels_size < sentence_size*level) { psg->labels = SENNA_realloc(psg->labels, sizeof(float), sentence_size*level); psg->max_labels_size = sentence_size*level; } memcpy(psg->labels+(level-1)*sentence_size, start_and_sentence_level_label+1, sizeof(float)*sentence_size); if(all_in_one_segment || all_tags_are_o) break; } free(sentence_psgl); free(sentence_segl); free(start_and_sentence_level_label); *labels_ = psg->labels; *n_level_ = level; }