Пример #1
0
/*
 * Load a network from a directory.
 *
 * The directory is expected to contain two files:
 *   - "model.net": the serialized network object, converted with
 *     Object2Network();
 *   - "stat.t7":   a table with exactly two entries; entries named "mean"
 *     and "std" (string keys) are copied into net->mean / net->std.
 *     Before that copy, mean defaults to 0 and std to 1 for three channels.
 *
 * path: directory holding the two files.
 *
 * Returns a newly allocated THNETWORK (release with THFreeNetwork()), or 0 on
 * failure. loadtorch() failures are reported through lasterror as returned by
 * loadtorch(); unusable objects set lasterror = ERR_WRONGOBJECT.
 * NOTE(review): allocation failures and paths too long for the internal
 * buffer also return 0 but leave lasterror untouched, because no dedicated
 * error code is visible here — consider adding one.
 */
THNETWORK *THLoadNetwork(const char *path)
{
	char tmppath[255];
	int i, len;
	THNETWORK *net = calloc(1, sizeof(*net));

	if(!net)
		return 0;

	/* Bounded formatting: an over-long path is rejected instead of overflowing tmppath. */
	len = snprintf(tmppath, sizeof(tmppath), "%s/model.net", path);
	if(len < 0 || (size_t)len >= sizeof(tmppath))
		goto fail;
	net->netobj = malloc(sizeof(*net->netobj));
	if(!net->netobj)
		goto fail;
	lasterror = loadtorch(tmppath, net->netobj, longsize);
	if(lasterror)
		goto fail_netobj_mem;	/* presumably loadtorch releases partial contents itself — confirm */
	net->net = Object2Network(net->netobj);
	if(!net->net)
	{
		lasterror = ERR_WRONGOBJECT;
		goto fail_netobj;
	}

	len = snprintf(tmppath, sizeof(tmppath), "%s/stat.t7", path);
	if(len < 0 || (size_t)len >= sizeof(tmppath))
		goto fail_net;
	net->statobj = malloc(sizeof(*net->statobj));
	if(!net->statobj)
		goto fail_net;
	lasterror = loadtorch(tmppath, net->statobj, longsize);
	if(lasterror)
		goto fail_statobj_mem;
	if(net->statobj->type != TYPE_TABLE || net->statobj->table->nelem != 2)
	{
		lasterror = ERR_WRONGOBJECT;
		/* Must stop here: everything below dereferences net->statobj. */
		goto fail_statobj;
	}

	net->std[0] = net->std[1] = net->std[2] = 1;
	net->mean[0] = net->mean[1] = net->mean[2] = 0;
	/* NOTE(review): assumes each "mean"/"std" value is a tensor whose storage
	   holds at least sizeof(net->mean) / sizeof(net->std) bytes — not validated. */
	for(i = 0; i < net->statobj->table->nelem; i++)
	{
		if(net->statobj->table->records[i].name.type != TYPE_STRING)
			continue;
		if(!strcmp(net->statobj->table->records[i].name.string.data, "mean"))
			memcpy(net->mean, net->statobj->table->records[i].value.tensor->storage->data, sizeof(net->mean));
		else if(!strcmp(net->statobj->table->records[i].name.string.data, "std"))
			memcpy(net->std, net->statobj->table->records[i].value.tensor->storage->data, sizeof(net->std));
	}
	return net;

	/* Cleanup ladder: each label releases what was acquired before the failing step. */
fail_statobj:
	freeobject(net->statobj);
fail_statobj_mem:
	free(net->statobj);
fail_net:
	freenetwork(net->net);
fail_netobj:
	freeobject(net->netobj);
fail_netobj_mem:
	free(net->netobj);
fail:
	free(net);
	return 0;
}
Пример #2
0
/*
 * Release a network and every object it owns:
 *   - the converted network (network->net);
 *   - the raw network and statistics objects (netobj / statobj), when set;
 *   - the output tensor (network->out), when set.
 * Finally the THNETWORK itself is freed.
 *
 * Passing 0 is a no-op, mirroring free(). THLoadNetwork() returns 0 on
 * failure, so its result can be handed here unconditionally.
 */
void THFreeNetwork(THNETWORK *network)
{
	if(!network)
		return;
	freenetwork(network->net);
	if(network->netobj)
	{
		freeobject(network->netobj);
		free(network->netobj);
	}
	if(network->statobj)
	{
		freeobject(network->statobj);
		free(network->statobj);
	}
	/* Guarded: THFloatTensor_free is not known to accept NULL. */
	if(network->out)
		THFloatTensor_free(network->out);
	free(network);
}