// Neural Network ------------------------------------------------------------- // Load the snapshot of the CNN we are going to run. Network* construct_gtsrb_net() { fprintf(stderr, "Constructing GTSRB Network \n"); Network* net = make_network(12); network_add(net, make_conv_layer(48, 48, 3, 3, 100, 1, 0)); network_add(net, make_relu_layer(net->layers[0]->out_sx, net->layers[0]->out_sy, net->layers[0]->out_depth)); network_add(net, make_max_pool_layer(net->layers[1]->out_sx, net->layers[1]->out_sy, net->layers[1]->out_depth, 2, 2)); network_add(net, make_conv_layer(net->layers[2]->out_sx, net->layers[2]->out_sy, net->layers[2]->out_depth, 4, 150, 1, 0)); network_add(net, make_relu_layer(net->layers[3]->out_sx, net->layers[3]->out_sy, net->layers[3]->out_depth)); network_add(net, make_max_pool_layer(net->layers[4]->out_sx, net->layers[4]->out_sy, net->layers[4]->out_depth, 2, 2)); network_add(net, make_conv_layer(net->layers[5]->out_sx, net->layers[5]->out_sy, net->layers[5]->out_depth, 3, 250, 1, 0)); network_add(net, make_relu_layer(net->layers[6]->out_sx, net->layers[6]->out_sy, net->layers[6]->out_depth)); network_add(net, make_max_pool_layer(net->layers[7]->out_sx, net->layers[7]->out_sy, net->layers[7]->out_depth, 2, 2)); network_add(net, make_fc_layer(net->layers[8]->out_sx, net->layers[8]->out_sy, net->layers[8]->out_depth, 200)); network_add(net, make_fc_layer(net->layers[9]->out_sx, net->layers[9]->out_sy, net->layers[9]->out_depth, 43)); network_add(net, make_softmax_layer(net->layers[10]->out_sx, net->layers[10]->out_sy, net->layers[10]->out_depth)); // load pre-trained weights conv_load(net->layers[0], conv1_params, conv1_data); conv_load(net->layers[3], conv2_params, conv2_data); conv_load(net->layers[6], conv3_params, conv3_data); fc_load(net->layers[9], ip1_params, ip1_data); fc_load(net->layers[10], ip2_params, ip2_data); return net; }
// Load the snapshot of the CNN we are going to run. network_t* load_cnn_snapshot() { network_t* net = make_network(); conv_load(net->l0, "../data/snapshot/layer1_conv.txt"); conv_load(net->l3, "../data/snapshot/layer4_conv.txt"); conv_load(net->l6, "../data/snapshot/layer7_conv.txt"); fc_load(net->l9, "../data/snapshot/layer10_fc.txt"); return net; }