THNETWORK *THLoadNetwork(const char *path) { char tmppath[255]; int i; THNETWORK *net = calloc(1, sizeof(*net)); sprintf(tmppath, "%s/model.net", path); net->netobj = malloc(sizeof(*net->netobj)); lasterror = loadtorch(tmppath, net->netobj, longsize); if(lasterror) { free(net->netobj); free(net); return 0; } //printobject(net->netobj, 0); net->net = Object2Network(net->netobj); if(!net->net) { lasterror = ERR_WRONGOBJECT; freeobject(net->netobj); free(net->netobj); free(net); return 0; } sprintf(tmppath, "%s/stat.t7", path); net->statobj = malloc(sizeof(*net->statobj)); lasterror = loadtorch(tmppath, net->statobj, longsize); if(lasterror) { free(net->statobj); freenetwork(net->net); freeobject(net->netobj); free(net->netobj); free(net); return 0; } if(net->statobj->type != TYPE_TABLE || net->statobj->table->nelem != 2) { lasterror = ERR_WRONGOBJECT; freenetwork(net->net); freeobject(net->netobj); free(net->netobj); freeobject(net->statobj); free(net->statobj); free(net); } net->std[0] = net->std[1] = net->std[2] = 1; net->mean[0] = net->mean[1] = net->mean[2] = 0; for(i = 0; i < net->statobj->table->nelem; i++) if(net->statobj->table->records[i].name.type == TYPE_STRING) { if(!strcmp(net->statobj->table->records[i].name.string.data, "mean")) memcpy(net->mean, net->statobj->table->records[i].value.tensor->storage->data, sizeof(net->mean)); else if(!strcmp(net->statobj->table->records[i].name.string.data, "std")) memcpy(net->std, net->statobj->table->records[i].value.tensor->storage->data, sizeof(net->std)); } return net; }
/*
 * Release a network and everything it owns: the network structure itself,
 * the loaded model and statistics objects (when present), and the output
 * tensor (when present).
 *
 * network: the network to free; NULL is accepted and ignored, mirroring free().
 *          The pointer must not be used after this call.
 */
void THFreeNetwork(THNETWORK *network)
{
	if(!network)
		return;

	freenetwork(network->net);
	if(network->netobj)
	{
		/* Release the object's contents first, then the container itself */
		freeobject(network->netobj);
		free(network->netobj);
	}
	if(network->statobj)
	{
		freeobject(network->statobj);
		free(network->statobj);
	}
	if(network->out)
		THFloatTensor_free(network->out);
	free(network);
}