/*
 * Performs one training step on `net`.
 *
 * When built with GPU support and gpu_index >= 0, the whole step is delegated
 * to train_network_datum_gpu(). Otherwise the CPU path:
 *   1. advances the shared sample counter (*net.seen) by net.batch,
 *   2. marks this (by-value) copy of the network as training,
 *   3. runs forward_network() then backward_network(),
 *   4. calls update_network() only when the number of batches seen so far is
 *      a multiple of net.subdivisions.
 *
 * NOTE(review): no input is passed in, so this presumably trains on data
 * already loaded into `net` by the caller — confirm at call sites.
 *
 * Returns the value of *net.cost read after the forward/backward passes.
 */
float train_network_datum(network net)
{
#ifdef GPU
    if (gpu_index >= 0) {
        return train_network_datum_gpu(net);
    }
#endif
    /* net is a copy, but seen is a pointer, so this increment is visible
       to the caller. */
    *net.seen += net.batch;
    net.train = 1;

    forward_network(net);
    backward_network(net);

    float loss = *net.cost;

    /* Skip the weight update unless a full set of subdivisions is done. */
    if (((*net.seen) / net.batch) % net.subdivisions != 0) {
        return loss;
    }
    update_network(net);
    return loss;
}
/*
 * Performs one training step on `net` using caller-supplied buffers.
 *
 * Parameters:
 *   net - network to train (passed by value; only its fields are read here).
 *   x   - input buffer, stored as state.input for the forward pass.
 *   y   - target buffer, stored as state.truth.
 *
 * When built with GPU support and gpu_index >= 0, the step is delegated to
 * train_network_datum_gpu(). Otherwise it runs forward_network() and
 * backward_network() with a training-mode network_state. It then calls
 * update_network() only when (net.seen / net.batch) is a multiple of
 * net.subdivisions.
 *
 * NOTE(review): net.seen is read but never advanced here. Because `net` is a
 * copy, any counter update must happen in the caller — confirm.
 *
 * Returns the cost reported by get_network_cost(net) after the passes.
 */
float train_network_datum(network net, float *x, float *y)
{
#ifdef GPU
    if (gpu_index >= 0) {
        return train_network_datum_gpu(net, x, y);
    }
#endif
    /* Zero-initialize so that any network_state member not assigned below
       has a defined value, instead of indeterminate stack contents, if
       forward_network/backward_network read it. */
    network_state state = {0};
    state.input = x;
    state.truth = y;
    state.train = 1;

    forward_network(net, state);
    backward_network(net, state);

    float error = get_network_cost(net);

    /* Apply the weight update only once per `subdivisions` batches. */
    if ((net.seen / net.batch) % net.subdivisions == 0) {
        update_network(net);
    }
    return error;
}