/*
 * Initialize `m` as a sum module connecting `input` to `output`.
 *
 * `input` is split into input->size / output->size consecutive slices of
 * output->size elements each. For every slice a tnn_state is heap-allocated,
 * bound to that range of `input` through tnn_param_state_sub(io, ...), and
 * its pointer is stored in the module's sarray (a utarray of tnn_state*).
 *
 * Returns:
 *   TNN_ERROR_STATE_INCOMP  if output->size is 0 or does not divide input->size
 *   TNN_ERROR_STATE_INVALID if input->valid is not true
 *   TNN_ERROR_ALLOC         if an allocation fails
 *   any error returned by tnn_param_state_sub
 *   TNN_ERROR_SUCCESS       otherwise
 *
 * NOTE(review): on a failure inside the slicing loop, slices created by
 * earlier iterations stay in c->sarray (and possibly registered in `io`);
 * they are not released here because `io` may still reference them.
 */
tnn_error tnn_module_init_sum(tnn_module *m, tnn_state *input, tnn_state *output, tnn_param *io){
  tnn_error ret;
  tnn_module_sum *c;
  size_t i;
  UT_icd sarray_icd = {sizeof(tnn_state*), NULL, NULL, NULL};
  tnn_state *t;

  //Check the sizes and validness. A zero output size would be a division
  //by zero in the modulo below and a zero step in the slicing loop.
  if(output->size == 0 || input->size % output->size != 0){
    return TNN_ERROR_STATE_INCOMP;
  }
  if(input->valid != true){
    return TNN_ERROR_STATE_INVALID;
  }

  //Defined type
  m->t = TNN_MODULE_TYPE_SUM;

  //Constant parameter is a new tnn_module_sum
  c = malloc(sizeof *c);
  if(c == NULL){
    return TNN_ERROR_ALLOC;
  }
  m->c = c;

  //Allocate sub-states
  utarray_new(c->sarray, &sarray_icd);
  if(c->sarray == NULL){
    //Previously this branch fell through and used the NULL array below.
    free(c);
    m->c = NULL;
    return TNN_ERROR_ALLOC;
  }
  for(i = 0; i < input->size; i = i + output->size){
    //Alloc and initialize the state
    t = malloc(sizeof *t);
    if(t == NULL){
      return TNN_ERROR_ALLOC;
    }
    t->size = output->size;

    //Get the substate and store it. Assumes a failed tnn_param_state_sub
    //does not retain `t`, so it is released here instead of leaking.
    ret = tnn_param_state_sub(io, input, t, i);
    if(ret != TNN_ERROR_SUCCESS){
      free(t);
      return ret;
    }
    utarray_push_back(c->sarray, &t);
  }

  //Init the state (the sum module has no weights, hence size 0)
  tnn_state_init(&m->w, 0L);

  //Link the inputs and outputs
  m->input = input;
  m->output = output;

  //Store the functions
  m->bprop = &tnn_module_bprop_sum;
  m->fprop = &tnn_module_fprop_sum;
  m->randomize = &tnn_module_randomize_sum;
  m->destroy = &tnn_module_destroy_sum;
  m->clone = &tnn_module_clone_sum;
  m->debug = &tnn_module_debug_sum;

  return TNN_ERROR_SUCCESS;
}
Example #2
0
/*
 * Test driver for tnn_state / tnn_param.
 *
 * Initializes N states (sizes i + M) and allocates them inside a parameter p.
 * It fills p.x / p.dx with i + A / i + B, then exercises calloc, sub-state
 * and copy on states t, u and v. Next it destroys the parameter and prints
 * the resulting fields. Finally it builds a second parameter from
 * heap-allocated states and frees it with tnn_param_free.
 */
int main(){
  tnn_state s[N];
  tnn_state t, u, v;
  tnn_param p;
  tnn_state *sp;
  int i;

  //Initialize all the states
  for(i = 0; i < N; i = i + 1){
    printf("Initializing state %d: %s\n", i, TEST_FUNC(tnn_state_init(&s[i], i + M)));
    printf("Initialized state %d to be size %lu, valid %d.\n", i, (unsigned long)s[i].size, s[i].valid);
  }

  //Initialize the parameter
  printf("Initializing parameter: %s\n", TEST_FUNC(tnn_param_init(&p)));
  tnn_param_debug(&p);

  //Allocate states in the parameter
  for(i = 0; i < N; i = i + 1){
    printf("Allocating state %d: %s\n", i, TEST_FUNC(tnn_param_state_alloc(&p, &s[i])));
    printf("Allocated state %d to be valid %d, vector owner %d.\n", i, s[i].valid, s[i].x.owner);
    tnn_param_debug(&p);
  }

  //Initialize the values in the vector (gsl_vector::size is a size_t)
  for(i = 0; (size_t)i < p.x->size; i = i + 1){
    gsl_vector_set(p.x, i, i + A);
    gsl_vector_set(p.dx, i, i + B);
  }
  tnn_param_debug(&p);

  //Initialize values for t using calloc
  printf("Initializing state t: %s\n", TEST_FUNC(tnn_state_init(&t, I)));
  printf("Initialized state t to be size %lu, valid %d.\n", (unsigned long)t.size, t.valid);
  printf("Allocating state t: %s\n", TEST_FUNC(tnn_param_state_calloc(&p, &t)));
  printf("Allocated state t to be valid %d, vector owner %d.\n", t.valid, t.x.owner);
  tnn_param_debug(&p);

  printf("Initializing state u: %s\n", TEST_FUNC(tnn_state_init(&u, J)));
  printf("Getting subvector u: %s\n", TEST_FUNC(tnn_param_state_sub(&p, &s[N-1], &u, 1)));
  tnn_param_debug(&p);

  printf("Initializing state v: %s\n", TEST_FUNC(tnn_state_init(&v, s[N-1].size)));
  tnn_state_debug(&v);
  printf("Allocating state v: %s\n", TEST_FUNC(tnn_param_state_calloc(&p, &v)));
  printf("Copying s to v: %s\n", TEST_FUNC(tnn_state_copy(&s[N-1],&v)));
  tnn_state_debug(&v);
  tnn_param_debug(&p);

  //Destroy the parameter
  printf("Destroying the paramter: %s.\n", TEST_FUNC(tnn_param_destroy(&p)));
  //Pointers must be printed with %p (as void *); %d on a pointer is undefined.
  printf("Destroyed paramter x = %p, dx = %p, states = %p, size = %lu.\n",
         (void *)p.x, (void *)p.dx, (void *)p.states, (unsigned long)p.size);
  for(i = 0; i < N; i = i + 1){
    printf("Destroyed state %d, valid: %d\n", i, s[i].valid);
  }
  printf("Destroyed state t, valid: %d\n", t.valid);
  printf("Destroyed state u, valid: %d\n", u.valid);
  tnn_param_debug(&p);

  //Delete parameter experiment: states are heap-allocated and handed to p.
  //NOTE(review): presumably tnn_param_free releases them — confirm.
  printf("Initializing paramter: %s\n", TEST_FUNC(tnn_param_init(&p)));
  for(i = 0; i < N; i = i + 1){
    sp = malloc(sizeof *sp);
    if(sp == NULL){
      printf("Allocating sp failed.\n");
      break;
    }
    printf("Initializing sp %p: %s\n", (void *)sp, TEST_FUNC(tnn_state_init(sp,A+i)));
    printf("Allocating sp %p: %s\n", (void *)sp, TEST_FUNC(tnn_param_state_calloc(&p, sp)));
  }
  printf("Debugging paramter: %s\n", TEST_FUNC(tnn_param_debug(&p)));
  printf("Free paramter: %s\n", TEST_FUNC(tnn_param_free(&p)));
  printf("Debugging paramter: %s\n", TEST_FUNC(tnn_param_debug(&p)));

  return 0;
}