示例#1
0
// Neural Network -------------------------------------------------------------
/*
 * Build the GTSRB network and load its pre-trained weights.
 *
 * Layer stack (index: kind). Every layer after the first takes its input
 * geometry (sx, sy, depth) from the previous layer's output:
 *    0 conv    1 relu    2 max-pool
 *    3 conv    4 relu    5 max-pool
 *    6 conv    7 relu    8 max-pool
 *    9 fc     10 fc     11 softmax
 *
 * Weights come from the conv1/conv2/conv3/ip1/ip2 *_params and *_data symbols
 * defined elsewhere in the project.
 *
 * Returns the pointer produced by make_network(); it is not NULL-checked here.
 */

/* Expands to the output (sx, sy, depth) argument triple of layer `i` of `n`. */
#define GTSRB_OUT_DIMS(n, i) \
  (n)->layers[(i)]->out_sx, (n)->layers[(i)]->out_sy, (n)->layers[(i)]->out_depth

Network* construct_gtsrb_net() {
  enum { kLayerCount = 12, kInputSide = 48, kInputDepth = 3, kClasses = 43 };

  fprintf(stderr, "Constructing GTSRB Network \n");
  Network* net = make_network(kLayerCount);

  /* Feature extractor: three conv -> relu -> max-pool stages. */
  network_add(net, make_conv_layer(kInputSide, kInputSide, kInputDepth, 3, 100, 1, 0));
  network_add(net, make_relu_layer(GTSRB_OUT_DIMS(net, 0)));
  network_add(net, make_max_pool_layer(GTSRB_OUT_DIMS(net, 1), 2, 2));

  network_add(net, make_conv_layer(GTSRB_OUT_DIMS(net, 2), 4, 150, 1, 0));
  network_add(net, make_relu_layer(GTSRB_OUT_DIMS(net, 3)));
  network_add(net, make_max_pool_layer(GTSRB_OUT_DIMS(net, 4), 2, 2));

  network_add(net, make_conv_layer(GTSRB_OUT_DIMS(net, 5), 3, 250, 1, 0));
  network_add(net, make_relu_layer(GTSRB_OUT_DIMS(net, 6)));
  network_add(net, make_max_pool_layer(GTSRB_OUT_DIMS(net, 7), 2, 2));

  /* Classifier head. */
  network_add(net, make_fc_layer(GTSRB_OUT_DIMS(net, 8), 200));
  /* NOTE(review): no activation sits between the two fc layers, so together
   * they compose to a single linear map -- confirm this matches how the
   * ip1/ip2 weights were trained. */
  network_add(net, make_fc_layer(GTSRB_OUT_DIMS(net, 9), kClasses));
  network_add(net, make_softmax_layer(GTSRB_OUT_DIMS(net, 10)));

  /* Only the conv and fc layers receive pre-trained weights. */
  conv_load(net->layers[0], conv1_params, conv1_data);
  conv_load(net->layers[3], conv2_params, conv2_data);
  conv_load(net->layers[6], conv3_params, conv3_data);
  fc_load(net->layers[9], ip1_params, ip1_data);
  fc_load(net->layers[10], ip2_params, ip2_data);

  return net;
}

#undef GTSRB_OUT_DIMS
示例#2
0
/*
 * Create the network with make_network() and fill its weighted layers from
 * the text snapshot files under "../data/snapshot/".
 *
 * The snapshot file names are 1-based (layer1, layer4, ...), whereas the
 * network fields are 0-based (l0, l3, ...).
 *
 * NOTE(review): the paths are relative; presumably the loaders resolve them
 * against the process's working directory -- confirm before moving binaries.
 *
 * Returns the network from make_network(); neither it nor the loads are
 * error-checked here.
 */
network_t* load_cnn_snapshot() {
  network_t* snapshot = make_network();

  /* Convolution layers. */
  conv_load(snapshot->l0, "../data/snapshot/layer1_conv.txt");
  conv_load(snapshot->l3, "../data/snapshot/layer4_conv.txt");
  conv_load(snapshot->l6, "../data/snapshot/layer7_conv.txt");

  /* Fully connected layer. */
  fc_load(snapshot->l9, "../data/snapshot/layer10_fc.txt");

  return snapshot;
}