/*
 * Clone a sum module m1 into m2.
 *
 * The input, the output and every summed sub-state of m1 are mapped to their
 * counterparts through the pointer-state table t; each mapped state must have
 * the same size as the original. m2 receives a freshly allocated
 * tnn_module_sum whose sarray holds the mapped sub-state pointers, and an
 * empty (size 0) parameter state.
 *
 * m1  source module; must be of type TNN_MODULE_TYPE_SUM.
 * m2  destination module; its fields are overwritten.
 * p   not used by this module type (kept for the common clone signature).
 * t   table used to map m1's states to m2's states.
 *
 * Returns TNN_ERROR_SUCCESS on success, TNN_ERROR_MODULE_MISTYPE if m1 is not
 * a sum module, TNN_ERROR_STATE_INCOMP on a size mismatch, TNN_ERROR_ALLOC if
 * the constant block cannot be allocated, or the error returned by
 * tnn_pstable_find. If a failure happens after the constant block was
 * allocated, it is released again and m2->c is reset to NULL.
 */
tnn_error tnn_module_clone_sum(tnn_module *m1, tnn_module *m2, tnn_param *p, tnn_pstable *t){
  tnn_error ret;
  tnn_module_sum *c1;
  tnn_module_sum *c;
  tnn_state *s1, *s2, **s;
  UT_icd sarray_icd = {sizeof(tnn_state*), NULL, NULL, NULL};
  size_t i;

  //Routine check
  if(m1->t != TNN_MODULE_TYPE_SUM){
    return TNN_ERROR_MODULE_MISTYPE;
  }

  //Retrieve input and output
  TNN_MACRO_ERRORTEST(tnn_pstable_find(t, m1->input, &m2->input), ret);
  TNN_MACRO_ERRORTEST(tnn_pstable_find(t, m1->output, &m2->output), ret);
  if(m1->input->size != m2->input->size || m1->output->size != m2->output->size){
    return TNN_ERROR_STATE_INCOMP;
  }

  //Defined type
  m2->t = TNN_MODULE_TYPE_SUM;

  //Constant parameter is a new tnn_module_sum
  c = malloc(sizeof *c);
  if(c == NULL){
    m2->c = NULL;
    return TNN_ERROR_ALLOC;
  }
  m2->c = c;

  //The parameter state of a sum module is empty (size 0)
  tnn_state_init(&m2->w, 0L);

  //Map each sub-state of m1 to its counterpart in t
  c1 = (tnn_module_sum *)m1->c;
  utarray_new(c->sarray, &sarray_icd);
  for(i = 0; i < utarray_len(c1->sarray); i = i + 1){
    s = (tnn_state **)utarray_eltptr(c1->sarray, i);
    s1 = *s;
    //tnn_pstable_find writes the mapped state pointer into s2, so no
    //allocation is needed here (the original malloc'd s2 and leaked it)
    ret = tnn_pstable_find(t, s1, &s2);
    if(ret != TNN_ERROR_SUCCESS){
      goto fail;
    }
    if(s1->size != s2->size){
      ret = TNN_ERROR_STATE_INCOMP;
      goto fail;
    }
    utarray_push_back(c->sarray, &s2);
  }

  //Store the functions
  m2->bprop = &tnn_module_bprop_sum;
  m2->fprop = &tnn_module_fprop_sum;
  m2->randomize = &tnn_module_randomize_sum;
  m2->destroy = &tnn_module_destroy_sum;
  m2->debug = &tnn_module_debug_sum;
  m2->clone = &tnn_module_clone_sum;

  return TNN_ERROR_SUCCESS;

fail:
  //Undo the partial construction so nothing leaks and m2->c does not dangle
  utarray_free(c->sarray);
  free(c);
  m2->c = NULL;
  return ret;
}
/*
 * Clone a bias module m1 into m2.
 *
 * m2's input and output are looked up in t from m1's input and output and
 * must have the same sizes. m2 gets a parameter state of m2->input->size
 * elements, allocated via tnn_param_state_alloc on p, into which m1's
 * parameter values are copied with tnn_state_copy.
 *
 * Returns TNN_ERROR_SUCCESS, TNN_ERROR_MODULE_MISTYPE if m1 is not a bias
 * module, TNN_ERROR_STATE_INCOMP on a size mismatch, or an error propagated
 * through TNN_MACRO_ERRORTEST from tnn_pstable_find, tnn_param_state_alloc
 * or tnn_state_copy.
 */
tnn_error tnn_module_clone_bias(tnn_module *m1, tnn_module *m2, tnn_param *p, tnn_pstable *t){
  tnn_error ret;

  //Only bias modules are handled here
  if(m1->t != TNN_MODULE_TYPE_BIAS){
    return TNN_ERROR_MODULE_MISTYPE;
  }

  //Map m1's input/output onto the corresponding states recorded in t
  TNN_MACRO_ERRORTEST(tnn_pstable_find(t, m1->input, &m2->input), ret);
  TNN_MACRO_ERRORTEST(tnn_pstable_find(t, m1->output, &m2->output), ret);

  //The mapped states must be the same size as the originals
  if(m1->input->size != m2->input->size){
    return TNN_ERROR_STATE_INCOMP;
  }
  if(m1->output->size != m2->output->size){
    return TNN_ERROR_STATE_INCOMP;
  }

  //Module kind; a bias module carries no constant data
  m2->t = TNN_MODULE_TYPE_BIAS;
  m2->c = NULL;

  //Parameter state sized like the input, storage obtained through p
  tnn_state_init(&m2->w, m2->input->size);
  TNN_MACRO_ERRORTEST(tnn_param_state_alloc(p, &m2->w), ret);

  //Method table
  m2->fprop = &tnn_module_fprop_bias;
  m2->bprop = &tnn_module_bprop_bias;
  m2->randomize = &tnn_module_randomize_bias;
  m2->debug = &tnn_module_debug_bias;
  m2->clone = &tnn_module_clone_bias;
  m2->destroy = &tnn_module_destroy_bias;

  //Copy m1's parameter values into the new state
  TNN_MACRO_ERRORTEST(tnn_state_copy(&m1->w, &m2->w), ret);

  return TNN_ERROR_SUCCESS;
}
int main(){ int i; size_t l; tnn_state *key; tnn_state *s; tnn_param p,q; tnn_pstable t,tp; printf("Initializing t: %s\n", TEST_FUNC(tnn_pstable_init(&t))); //Adding dummy variables key = (tnn_state *)A; s = (tnn_state *)B; for(i = 0; i < 4; i = i + 1){ printf("Inserting key=%p, s=%p: %s\n", key, s, TEST_FUNC(tnn_pstable_add(&t, key, s))); key = key + 1; s = s - 1; } key = key - 1; s = s + 1; printf("Inserting key=%p, s=%p: %s\n", key, s, TEST_FUNC(tnn_pstable_add(&t, key, s))); //Debug printf("Debugging t: %s\n", TEST_FUNC(tnn_pstable_debug(&t))); //Delete a key key = (tnn_state *)A; key = key + C; printf("Deleting key=%p: %s\n", key, TEST_FUNC(tnn_pstable_delete(&t, key))); printf("Debugging t: %s\n", TEST_FUNC(tnn_pstable_debug(&t))); key = (tnn_state *)A; key = key - D; printf("Deleting key=%p: %s\n", key, TEST_FUNC(tnn_pstable_delete(&t, key))); printf("Debugging t: %s\n", TEST_FUNC(tnn_pstable_debug(&t))); //Find a key key = (tnn_state *)A; key = key + E; printf("Finding key=%p, s=%p: %s\n", key, s, TEST_FUNC(tnn_pstable_find(&t, key, &s))); key = (tnn_state *)A; key = key + C; printf("Finding key=%p, s=%p: %s\n", key, s, TEST_FUNC(tnn_pstable_find(&t, key, &s))); //Get the length printf("Getting the length=%ld: %s\n", l, TEST_FUNC(tnn_pstable_get_length(&t, &l))); //Destroy the table printf("Destroying t: %s\n", TEST_FUNC(tnn_pstable_destroy(&t))); printf("Debugging t: %s\n", TEST_FUNC(tnn_pstable_debug(&t))); //Initialize the two paramters printf("Initializing p: %s\n", TEST_FUNC(tnn_param_init(&p))); printf("Initializing q: %s\n", TEST_FUNC(tnn_param_init(&q))); //Allocating states in q for(i = 0; i < N; i = i + 1){ s = (tnn_state *) malloc(sizeof(tnn_state)); printf("Initializing s %d: %s\n", i, TEST_FUNC(tnn_state_init(s,i+F))); printf("Allocating s in p: %s\n", TEST_FUNC(tnn_param_state_calloc(&p,s))); } printf("Debugging p: %s\n", TEST_FUNC(tnn_param_debug(&p))); printf("Initialize q: %s\n", TEST_FUNC(tnn_param_init(&q))); printf("Initialize tp: %s\n", 
TEST_FUNC(tnn_pstable_init(&tp))); printf("Copying p to q: %s\n", TEST_FUNC(tnn_pstable_param_alloc(&tp,&p,&q))); printf("Debugging tp: %s\n", TEST_FUNC(tnn_pstable_debug(&tp))); printf("Debugging q: %s\n", TEST_FUNC(tnn_param_debug(&q))); return 0; }