Example #1
0
/*
 * Tag one sentence with part-of-speech labels using the SENNA window network.
 *
 * For each token, a window of window_size consecutive feature frames is fed
 * through linear -> hardtanh -> linear, giving output_state_size scores per
 * token. SENNA_nn_viterbi then decodes the label sequence from those scores
 * together with the init/transition scores stored in pos.
 *
 * pos             network state; input_state, output_state and labels are
 *                 (re)allocated here on every call.
 * sentence_words  word indices, one per token.
 * sentence_caps   capitalisation feature indices, one per token.
 * sentence_suff   suffix feature indices, one per token.
 * sentence_size   number of tokens.
 *
 * Returns pos->labels (sentence_size ints). The buffer belongs to pos and is
 * reallocated by the next call, so do not keep the pointer across calls.
 */
int* SENNA_POS_forward(SENNA_POS *pos, const int *sentence_words, const int *sentence_caps, const int *sentence_suff, int sentence_size)
{
  /* one frame = word features | caps features | suffix features */
  int frame_size = pos->ll_word_size + pos->ll_caps_size + pos->ll_suff_size;
  /* padding count handed to the lookups; the buffer holds window_size-1 extra
     frames, presumably split evenly before and after the sentence */
  int half_window = (pos->window_size - 1) / 2;
  int n_frames = sentence_size + pos->window_size - 1;
  int tok;

  pos->input_state = SENNA_realloc(pos->input_state, sizeof(float), n_frames * frame_size);
  pos->output_state = SENNA_realloc(pos->output_state, sizeof(float), sentence_size * pos->output_state_size);

  /* The three lookups write into the same frames, each into its own slice
     (offset 0, word_size, word_size+caps_size); the frame_size argument is
     presumably the stride between consecutive frames. */
  SENNA_nn_lookup(pos->input_state,
                  frame_size, pos->ll_word_weight, pos->ll_word_size, pos->ll_word_max_idx,
                  sentence_words, sentence_size, pos->ll_word_padding_idx, half_window);
  SENNA_nn_lookup(pos->input_state + pos->ll_word_size,
                  frame_size, pos->ll_caps_weight, pos->ll_caps_size, pos->ll_caps_max_idx,
                  sentence_caps, sentence_size, pos->ll_caps_padding_idx, half_window);
  SENNA_nn_lookup(pos->input_state + pos->ll_word_size + pos->ll_caps_size,
                  frame_size, pos->ll_suff_weight, pos->ll_suff_size, pos->ll_suff_max_idx,
                  sentence_suff, sentence_size, pos->ll_suff_padding_idx, half_window);

  /* Token tok's window starts at frame tok and spans window_size frames. */
  for(tok = 0; tok < sentence_size; tok++)
  {
    int window_offset = tok * frame_size;
    int score_offset = tok * pos->output_state_size;

    SENNA_nn_linear(pos->hidden_state, pos->hidden_state_size, pos->l1_weight, pos->l1_bias,
                    pos->input_state + window_offset, pos->window_size * frame_size);
    SENNA_nn_hardtanh(pos->hidden_state, pos->hidden_state, pos->hidden_state_size);
    SENNA_nn_linear(pos->output_state + score_offset, pos->output_state_size, pos->l2_weight, pos->l2_bias,
                    pos->hidden_state, pos->hidden_state_size);
  }

  pos->labels = SENNA_realloc(pos->labels, sizeof(int), sentence_size);
  SENNA_nn_viterbi(pos->labels, pos->viterbi_score_init, pos->viterbi_score_trans,
                   pos->output_state, pos->output_state_size, sentence_size);

  return pos->labels;
}
Example #2
0
/*
 * Run the SENNA syntactic-parsing (PSG) network on one sentence, level by level.
 *
 * Each iteration ("level") does the following:
 *   1. Embed the words, caps, POS labels and the current per-token parse labels.
 *   2. Run convolution -> temporal max convolution -> convolution -> hardtanh
 *      -> convolution.
 *   3. Decode a label sequence with a treillis built from the scores and the
 *      current segmentation.
 *   4. Fold the non-zero labels back into the per-token history.
 *
 * Iteration stops after the first level in which every decoded label is 0,
 * or in which the whole sentence forms a single segment.
 *
 * psg            network state; its scratch buffers and labels are
 *                (re)allocated here.
 * sentence_words word indices, one per token.
 * sentence_caps  capitalisation feature indices, one per token.
 * sentence_posl  POS label indices, one per token.
 * sentence_size  number of tokens.
 * labels_        out: psg->labels, holding *n_level_ consecutive rows of
 *                sentence_size ints (row l = labels decoded at level l).
 *                The buffer is owned by psg and may be reallocated by the next
 *                call. For an empty sentence it is returned as-is and may be
 *                NULL.
 * n_level_       out: number of levels decoded (0 for an empty sentence).
 */
void SENNA_PSG_forward(SENNA_PSG *psg, const int *sentence_words, const int *sentence_caps, const int *sentence_posl, int sentence_size,
                       int **labels_, int *n_level_)
{
  int *sentence_psgl;
  int *sentence_segl;
  int *start_and_sentence_level_label;
  int t;
  int level;

  /* Empty sentence: nothing to decode. Without this guard the single-segment
     check below would read sentence_segl[0] and sentence_segl[-1] out of
     bounds. */
  if(sentence_size <= 0)
  {
    *labels_ = psg->labels;
    *n_level_ = 0;
    return;
  }

  sentence_psgl = SENNA_malloc(sizeof(int), sentence_size);
  sentence_segl = SENNA_malloc(sizeof(int), sentence_size);
  /* Only slots 1..sentence_size are read back as per-token labels. Slot 0 is
     presumably the decoder's start state -- TODO confirm. */
  start_and_sentence_level_label = SENNA_malloc(sizeof(int), sentence_size+1);

  /* no parse history and no segmentation before the first level */
  for(t = 0; t < sentence_size; t++)
  {
    sentence_psgl[t] = 0;
    sentence_segl[t] = 0;
  }

  psg->input_state = SENNA_realloc(psg->input_state, sizeof(float), sentence_size*psg->input_state_size);
  psg->l1_state = SENNA_realloc(psg->l1_state, sizeof(float), sentence_size*psg->l1_state_size);
  psg->l2_state = SENNA_realloc(psg->l2_state, sizeof(float), sentence_size*psg->l2_state_size);
  psg->l3_state = SENNA_realloc(psg->l3_state, sizeof(float), sentence_size*psg->l3_state_size);
  psg->l4_state = SENNA_realloc(psg->l4_state, sizeof(float), sentence_size*psg->l4_state_size);

  /* word, caps and POS features do not change between levels: embed them once */
  SENNA_nn_lookup(psg->input_state,                                                       psg->input_state_size, psg->ll_word_weight, psg->ll_word_size, psg->ll_word_max_idx, sentence_words, sentence_size, 0, 0);
  SENNA_nn_lookup(psg->input_state+psg->ll_word_size,                                     psg->input_state_size, psg->ll_caps_weight, psg->ll_caps_size, psg->ll_caps_max_idx, sentence_caps,  sentence_size, 0, 0);
  SENNA_nn_lookup(psg->input_state+psg->ll_word_size+psg->ll_caps_size,                   psg->input_state_size, psg->ll_posl_weight, psg->ll_posl_size, psg->ll_posl_max_idx, sentence_posl,  sentence_size, 0, 0);

  /* NOTE(review): there is no explicit cap on the number of levels;
     termination relies on the decoder eventually producing an all-zero level
     or a single spanning segment. */
  level = 0;
  while(1)
  {
    int all_tags_are_o;
    int all_in_one_segment;

    /* the parse-label history changes every level, so re-embed it each time */
    SENNA_nn_lookup(psg->input_state+psg->ll_word_size+psg->ll_caps_size+psg->ll_posl_size, psg->input_state_size, psg->ll_psgl_weight, psg->ll_psgl_size, psg->ll_psgl_max_idx, sentence_psgl,  sentence_size, 0, 0);

    SENNA_nn_temporal_convolution(psg->l1_state, psg->l1_state_size, psg->l1_weight, psg->l1_bias, psg->input_state, psg->input_state_size, sentence_size, 1);
    SENNA_nn_temporal_max_convolution(psg->l2_state, psg->l2_bias, psg->l1_state, psg->l1_state_size, sentence_size, psg->window_size);
    /* NOTE(review): l2_state is allocated with l2_state_size but read here
       with width l1_state_size; this assumes the two sizes are equal -- confirm. */
    SENNA_nn_temporal_convolution(psg->l3_state, psg->l3_state_size, psg->l3_weight, psg->l3_bias, psg->l2_state, psg->l1_state_size, sentence_size, 1);
    SENNA_nn_hardtanh(psg->l3_state, psg->l3_state, psg->l3_state_size*sentence_size);
    SENNA_nn_temporal_convolution(psg->l4_state, psg->l4_state_size, psg->l4_weight, psg->l4_bias, psg->l3_state, psg->l3_state_size, sentence_size, 1);

    SENNA_Treillis_buildfromscorewithsegmentation(psg->treillis, psg->l4_state, psg->viterbi_score_init, psg->viterbi_score_trans, sentence_segl, psg->l4_state_size, sentence_size);
    SENNA_Treillis_viterbi(psg->treillis, start_and_sentence_level_label);

    /* Update history and segmentation. A zero label leaves whatever an
       earlier level stored at that position untouched. */
    all_tags_are_o = 1;
    for(t = 0; t < sentence_size; t++)
    {
      if(start_and_sentence_level_label[t+1])
      {
        sentence_psgl[t] = start_and_sentence_level_label[t+1];
        /* Map a label into 1..4. This assumes labels come in groups of four
           segmentation tags (presumably SEG_B/I/E/S) -- TODO confirm. */
        sentence_segl[t] = (start_and_sentence_level_label[t+1]-1)%4+1;

        all_tags_are_o = 0;
      }
    }

    /* check whether the whole sentence is now one segment */
    if(sentence_size == 1)
      all_in_one_segment = (sentence_segl[0] == SEG_S);
    else
      all_in_one_segment = (sentence_segl[0] == SEG_B) && (sentence_segl[sentence_size-1] == SEG_E);

    for(t = 1; all_in_one_segment && (t < sentence_size-1); t++)
    {
      if(sentence_segl[t] != SEG_I)
        all_in_one_segment = 0;
    }

    level++;

    /* Append this level's labels as row (level-1) of psg->labels. The buffer
       holds ints (it is handed back through int **labels_), so size it with
       sizeof(int), not sizeof(float). */
    if(psg->max_labels_size < sentence_size*level)
    {
      psg->labels = SENNA_realloc(psg->labels, sizeof(int), sentence_size*level);
      psg->max_labels_size = sentence_size*level;
    }
    memcpy(psg->labels+(level-1)*sentence_size, start_and_sentence_level_label+1, sizeof(int)*sentence_size);

    if(all_in_one_segment || all_tags_are_o)
      break;
  }

  free(sentence_psgl);
  free(sentence_segl);
  free(start_and_sentence_level_label);

  *labels_ = psg->labels;
  *n_level_ = level;
}