コード例 #1
0
/*
 * Clone the TNN_MODULE_TYPE_SUM module m1 into m2.
 *
 * m2->input, m2->output and every sub state listed in m1's sarray are obtained
 * by looking up the corresponding m1 state in the table t; each looked-up state
 * must have the same size as its m1 counterpart. m2 receives a freshly
 * allocated tnn_module_sum constant whose sarray holds the looked-up sub-state
 * pointers (the states themselves are not owned by the array), a parameter
 * state w initialized with size 0, and the sum-module function pointers.
 * p is not used.
 *
 * Returns TNN_ERROR_SUCCESS; TNN_ERROR_MODULE_MISTYPE if m1 is not a sum
 * module; TNN_ERROR_STATE_INCOMP on a size mismatch; TNN_ERROR_ALLOC if the
 * constant cannot be allocated; or the error reported by tnn_pstable_find /
 * tnn_state_init. If it fails after allocating the constant, the constant is
 * released again and m2->c is reset to NULL.
 */
tnn_error tnn_module_clone_sum(tnn_module *m1, tnn_module *m2, tnn_param *p, tnn_pstable *t){
  tnn_error ret;
  tnn_module_sum *c;
  tnn_module_sum *src;
  tnn_state *s1, *s2;
  UT_icd sarray_icd = {sizeof(tnn_state*), NULL, NULL, NULL};
  size_t i;

  //Routine check
  if(m1->t != TNN_MODULE_TYPE_SUM){
    return TNN_ERROR_MODULE_MISTYPE;
  }
  src = (tnn_module_sum *)m1->c;

  //Retrieve input and output
  TNN_MACRO_ERRORTEST(tnn_pstable_find(t, m1->input, &m2->input), ret);
  TNN_MACRO_ERRORTEST(tnn_pstable_find(t, m1->output, &m2->output), ret);
  if(m1->input->size != m2->input->size || m1->output->size != m2->output->size){
    return TNN_ERROR_STATE_INCOMP;
  }

  //Defined type
  m2->t = TNN_MODULE_TYPE_SUM;

  //Constant parameter is a new tnn_module_sum
  c = malloc(sizeof *c);
  if(c == NULL){
    m2->c = NULL;
    return TNN_ERROR_ALLOC;
  }
  m2->c = c;

  //Initialize the parameter state with size 0
  ret = tnn_state_init(&m2->w, 0L);
  if(ret != TNN_ERROR_SUCCESS){
    goto fail_free_c;
  }

  //Map every sub state of m1 to its counterpart in t.
  //tnn_pstable_find writes the counterpart pointer into s2, so no state is
  //allocated here; the array only stores pointers (NULL dtor in sarray_icd).
  utarray_new(c->sarray, &sarray_icd);
  for(i = 0; i < utarray_len(src->sarray); i = i + 1){
    s1 = *(tnn_state **)utarray_eltptr(src->sarray, i);
    ret = tnn_pstable_find(t, s1, &s2);
    if(ret != TNN_ERROR_SUCCESS){
      goto fail_free_sarray;
    }
    if(s1->size != s2->size){
      ret = TNN_ERROR_STATE_INCOMP;
      goto fail_free_sarray;
    }
    utarray_push_back(c->sarray, &s2);
  }

  //Store the functions
  m2->bprop = &tnn_module_bprop_sum;
  m2->fprop = &tnn_module_fprop_sum;
  m2->randomize = &tnn_module_randomize_sum;
  m2->destroy = &tnn_module_destroy_sum;
  m2->debug = &tnn_module_debug_sum;
  m2->clone = &tnn_module_clone_sum;

  return TNN_ERROR_SUCCESS;

  //Error paths: release the partially built constant so it does not leak
fail_free_sarray:
  utarray_free(c->sarray);
fail_free_c:
  free(c);
  m2->c = NULL;
  return ret;
}
コード例 #2
0
/*
 * Clone the TNN_MODULE_TYPE_BIAS module m1 into m2.
 *
 * m2->input and m2->output are obtained by looking up m1's input/output states
 * in the table t, and must match them in size. The bias module has no
 * constant (m2->c = NULL). Its parameter state w is initialized with the
 * input size, registered through tnn_param_state_alloc(p, ...), and then
 * filled by copying m1->w.
 *
 * Returns TNN_ERROR_SUCCESS; TNN_ERROR_MODULE_MISTYPE if m1 is not a bias
 * module; TNN_ERROR_STATE_INCOMP on an input/output size mismatch; or the
 * error reported by tnn_pstable_find, tnn_state_init, tnn_param_state_alloc
 * or tnn_state_copy.
 */
tnn_error tnn_module_clone_bias(tnn_module *m1, tnn_module *m2, tnn_param *p, tnn_pstable *t){
  tnn_error ret;

  //Routine check
  if(m1->t != TNN_MODULE_TYPE_BIAS){
    return TNN_ERROR_MODULE_MISTYPE;
  }

  //Retrieve input and output
  TNN_MACRO_ERRORTEST(tnn_pstable_find(t, m1->input, &m2->input), ret);
  TNN_MACRO_ERRORTEST(tnn_pstable_find(t, m1->output, &m2->output), ret);
  if(m1->input->size != m2->input->size || m1->output->size != m2->output->size){
    return TNN_ERROR_STATE_INCOMP;
  }

  //Defined type
  m2->t = TNN_MODULE_TYPE_BIAS;

  //No constant parameters
  m2->c = NULL;

  //Allocate the parameter states (sized like the input); a failed init must
  //not be passed on to tnn_param_state_alloc
  TNN_MACRO_ERRORTEST(tnn_state_init(&m2->w, m2->input->size), ret);
  TNN_MACRO_ERRORTEST(tnn_param_state_alloc(p,&m2->w), ret);

  //Store the functions
  m2->bprop = &tnn_module_bprop_bias;
  m2->fprop = &tnn_module_fprop_bias;
  m2->randomize = &tnn_module_randomize_bias;
  m2->destroy = &tnn_module_destroy_bias;
  m2->debug = &tnn_module_debug_bias;
  m2->clone = &tnn_module_clone_bias;

  //Copy the bias values of m1 into the clone
  TNN_MACRO_ERRORTEST(tnn_state_copy(&m1->w, &m2->w), ret);

  return TNN_ERROR_SUCCESS;
}
コード例 #3
0
/*
 * Smoke test for tnn_pstable and tnn_param: adds, deletes, finds and counts
 * entries in a pstable keyed by fake state addresses (derived from the macros
 * A..E), destroys it, then fills a tnn_param with N states and copies it into
 * a second parameter through tnn_pstable_param_alloc. Results are printed via
 * TEST_FUNC.
 */
int main(void){
  int i;
  size_t l = 0;
  tnn_state *key;
  tnn_state *s;
  tnn_param p,q;
  tnn_pstable t,tp;
  const char *res;

  printf("Initializing t: %s\n", TEST_FUNC(tnn_pstable_init(&t)));

  //Adding dummy variables: fake addresses used as opaque keys/values
  key = (tnn_state *)A;
  s = (tnn_state *)B;
  for(i = 0; i < 4; i = i + 1){
    printf("Inserting key=%p, s=%p: %s\n", (void *)key, (void *)s, TEST_FUNC(tnn_pstable_add(&t, key, s)));
    key = key + 1;
    s = s - 1;
  }
  //Insert the last (key, s) pair from the loop a second time
  key = key - 1;
  s = s + 1;
  printf("Inserting key=%p, s=%p: %s\n", (void *)key, (void *)s, TEST_FUNC(tnn_pstable_add(&t, key, s)));

  //Debug
  printf("Debugging t: %s\n", TEST_FUNC(tnn_pstable_debug(&t)));

  //Delete a key
  key = (tnn_state *)A;
  key = key + C;
  printf("Deleting key=%p: %s\n", (void *)key, TEST_FUNC(tnn_pstable_delete(&t, key)));
  printf("Debugging t: %s\n", TEST_FUNC(tnn_pstable_debug(&t)));
  key = (tnn_state *)A;
  key = key - D;
  printf("Deleting key=%p: %s\n", (void *)key, TEST_FUNC(tnn_pstable_delete(&t, key)));
  printf("Debugging t: %s\n", TEST_FUNC(tnn_pstable_debug(&t)));

  //Find a key. The lookup writes s, so it must run before s is printed
  //(argument evaluation order inside one printf call is unspecified).
  key = (tnn_state *)A;
  key = key + E;
  res = TEST_FUNC(tnn_pstable_find(&t, key, &s));
  printf("Finding key=%p, s=%p: %s\n", (void *)key, (void *)s, res);
  key = (tnn_state *)A;
  key = key + C;
  res = TEST_FUNC(tnn_pstable_find(&t, key, &s));
  printf("Finding key=%p, s=%p: %s\n", (void *)key, (void *)s, res);

  //Get the length (query first so l is set before it is printed)
  res = TEST_FUNC(tnn_pstable_get_length(&t, &l));
  printf("Getting the length=%zu: %s\n", l, res);

  //Destroy the table, then debug the destroyed table
  printf("Destroying t: %s\n", TEST_FUNC(tnn_pstable_destroy(&t)));
  printf("Debugging t: %s\n", TEST_FUNC(tnn_pstable_debug(&t)));

  //Initialize the two parameters
  printf("Initializing p: %s\n", TEST_FUNC(tnn_param_init(&p)));
  printf("Initializing q: %s\n", TEST_FUNC(tnn_param_init(&q)));

  //Allocating states in p
  for(i = 0; i < N; i = i + 1){
    s = (tnn_state *) malloc(sizeof(tnn_state));
    if(s == NULL){
      fprintf(stderr, "Allocating s %d failed\n", i);
      return 1;
    }
    printf("Initializing s %d: %s\n", i, TEST_FUNC(tnn_state_init(s,i+F)));
    printf("Allocating s in p: %s\n", TEST_FUNC(tnn_param_state_calloc(&p,s)));
  }
  printf("Debugging p: %s\n", TEST_FUNC(tnn_param_debug(&p)));

  //Copy p into q (q was initialized above) through the table tp
  printf("Initialize tp: %s\n", TEST_FUNC(tnn_pstable_init(&tp)));
  printf("Copying p to q: %s\n", TEST_FUNC(tnn_pstable_param_alloc(&tp,&p,&q)));
  printf("Debugging tp: %s\n", TEST_FUNC(tnn_pstable_debug(&tp)));
  printf("Debugging q: %s\n", TEST_FUNC(tnn_param_debug(&q)));

  return 0;
}