// Parses a network that begins with 'C'. Network* NetworkBuilder::ParseC(const StaticShape& input_shape, char** str) { NetworkType type = NonLinearity((*str)[1]); if (type == NT_NONE) { tprintf("Invalid nonlinearity on C-spec!: %s\n", *str); return nullptr; } int y = 0, x = 0, d = 0; if ((y = strtol(*str + 2, str, 10)) <= 0 || **str != ',' || (x = strtol(*str + 1, str, 10)) <= 0 || **str != ',' || (d = strtol(*str + 1, str, 10)) <= 0) { tprintf("Invalid C spec!:%s\n", *str); return nullptr; } if (x == 1 && y == 1) { // No actual convolution. Just a FullyConnected on the current depth, to // be slid over all batch,y,x. return new FullyConnected("Conv1x1", input_shape.depth(), d, type); } Series* series = new Series("ConvSeries"); Convolve* convolve = new Convolve("Convolve", input_shape.depth(), x / 2, y / 2); series->AddToStack(convolve); StaticShape fc_input = convolve->OutputShape(input_shape); series->AddToStack(new FullyConnected("ConvNL", fc_input.depth(), d, type)); return series; }
// Parses an Output spec. Network* NetworkBuilder::ParseOutput(const StaticShape& input_shape, char** str) { char dims_ch = (*str)[1]; if (dims_ch != '0' && dims_ch != '1' && dims_ch != '2') { tprintf("Invalid dims (2|1|0) in output spec!:%s\n", *str); return nullptr; } char type_ch = (*str)[2]; if (type_ch != 'l' && type_ch != 's' && type_ch != 'c') { tprintf("Invalid output type (l|s|c) in output spec!:%s\n", *str); return nullptr; } int depth = strtol(*str + 3, str, 10); if (depth != num_softmax_outputs_) { tprintf("Warning: given outputs %d not equal to unicharset of %d.\n", depth, num_softmax_outputs_); depth = num_softmax_outputs_; } NetworkType type = NT_SOFTMAX; if (type_ch == 'l') type = NT_LOGISTIC; else if (type_ch == 's') type = NT_SOFTMAX_NO_CTC; if (dims_ch == '0') { // Same as standard fully connected. return BuildFullyConnected(input_shape, type, "Output", depth); } else if (dims_ch == '2') { // We don't care if x and/or y are variable. return new FullyConnected("Output2d", input_shape.depth(), depth, type); } // For 1-d y has to be fixed, and if not 1, moved to depth. if (input_shape.height() == 0) { tprintf("Fully connected requires fixed height!\n"); return nullptr; } int input_size = input_shape.height(); int input_depth = input_size * input_shape.depth(); Network* fc = new FullyConnected("Output", input_depth, depth, type); if (input_size > 1) { Series* series = new Series("FCSeries"); series->AddToStack(new Reconfig("FCReconfig", input_shape.depth(), 1, input_shape.height())); series->AddToStack(fc); fc = series; } return fc; }
// Helper builds a truly (0-d) fully connected layer of the given type. static Network* BuildFullyConnected(const StaticShape& input_shape, NetworkType type, const STRING& name, int depth) { if (input_shape.height() == 0 || input_shape.width() == 0) { tprintf("Fully connected requires positive height and width, had %d,%d\n", input_shape.height(), input_shape.width()); return nullptr; } int input_size = input_shape.height() * input_shape.width(); int input_depth = input_size * input_shape.depth(); Network* fc = new FullyConnected(name, input_depth, depth, type); if (input_size > 1) { Series* series = new Series("FCSeries"); series->AddToStack(new Reconfig("FCReconfig", input_shape.depth(), input_shape.width(), input_shape.height())); series->AddToStack(fc); fc = series; } return fc; }
// Parses a sequential series of networks, defined by [<net><net>...]. Network* NetworkBuilder::ParseSeries(const StaticShape& input_shape, Input* input_layer, char** str) { StaticShape shape = input_shape; Series* series = new Series("Series"); ++*str; if (input_layer != nullptr) { series->AddToStack(input_layer); shape = input_layer->OutputShape(shape); } Network* network = NULL; while (**str != '\0' && **str != ']' && (network = BuildFromString(shape, str)) != NULL) { shape = network->OutputShape(shape); series->AddToStack(network); } if (**str != ']') { tprintf("Missing ] at end of [Series]!\n"); delete series; return NULL; } ++*str; return series; }