// Parses a network that begins with 'R'. Network* NetworkBuilder::ParseR(const StaticShape& input_shape, char** str) { char dir = (*str)[1]; if (dir == 'x' || dir == 'y') { STRING name = "Reverse"; name += dir; *str += 2; Network* network = BuildFromString(input_shape, str); if (network == nullptr) return nullptr; Reversed* rev = new Reversed(name, dir == 'y' ? NT_YREVERSED : NT_XREVERSED); rev->SetNetwork(network); return rev; } int replicas = strtol(*str + 1, str, 10); if (replicas <= 0) { tprintf("Invalid R spec!:%s\n", *str); return nullptr; } Parallel* parallel = new Parallel("Replicated", NT_REPLICATED); char* str_copy = *str; for (int i = 0; i < replicas; ++i) { str_copy = *str; Network* network = BuildFromString(input_shape, &str_copy); if (network == NULL) { tprintf("Invalid replicated network!\n"); delete parallel; return nullptr; } parallel->AddToStack(network); } *str = str_copy; return parallel; }
// Builds a set of 4 lstms with x and y reversal, running in true parallel. Network* NetworkBuilder::BuildLSTMXYQuad(int num_inputs, int num_states) { Parallel* parallel = new Parallel("2DLSTMQuad", NT_PAR_2D_LSTM); parallel->AddToStack(new LSTM("L2DLTRDown", num_inputs, num_states, num_states, true, NT_LSTM)); Reversed* rev = new Reversed("L2DLTRXRev", NT_XREVERSED); rev->SetNetwork(new LSTM("L2DRTLDown", num_inputs, num_states, num_states, true, NT_LSTM)); parallel->AddToStack(rev); rev = new Reversed("L2DRTLYRev", NT_YREVERSED); rev->SetNetwork( new LSTM("L2DRTLUp", num_inputs, num_states, num_states, true, NT_LSTM)); Reversed* rev2 = new Reversed("L2DXRevU", NT_XREVERSED); rev2->SetNetwork(rev); parallel->AddToStack(rev2); rev = new Reversed("L2DXRevY", NT_YREVERSED); rev->SetNetwork(new LSTM("L2DLTRDown", num_inputs, num_states, num_states, true, NT_LSTM)); parallel->AddToStack(rev); return parallel; }
// Parses a parallel set of networks, defined by (<net><net>...). Network* NetworkBuilder::ParseParallel(const StaticShape& input_shape, char** str) { Parallel* parallel = new Parallel("Parallel", NT_PARALLEL); ++*str; Network* network = NULL; while (**str != '\0' && **str != ')' && (network = BuildFromString(input_shape, str)) != NULL) { parallel->AddToStack(network); } if (**str != ')') { tprintf("Missing ) at end of (Parallel)!\n"); delete parallel; return nullptr; } ++*str; return parallel; }
// Parses an LSTM network, either individual, bi- or quad-directional. Network* NetworkBuilder::ParseLSTM(const StaticShape& input_shape, char** str) { bool two_d = false; NetworkType type = NT_LSTM; char* spec_start = *str; int chars_consumed = 1; int num_outputs = 0; char key = (*str)[chars_consumed], dir = 'f', dim = 'x'; if (key == 'S') { type = NT_LSTM_SOFTMAX; num_outputs = num_softmax_outputs_; ++chars_consumed; } else if (key == 'E') { type = NT_LSTM_SOFTMAX_ENCODED; num_outputs = num_softmax_outputs_; ++chars_consumed; } else if (key == '2' && (((*str)[2] == 'x' && (*str)[3] == 'y') || ((*str)[2] == 'y' && (*str)[3] == 'x'))) { chars_consumed = 4; dim = (*str)[3]; two_d = true; } else if (key == 'f' || key == 'r' || key == 'b') { dir = key; dim = (*str)[2]; if (dim != 'x' && dim != 'y') { tprintf("Invalid dimension (x|y) in L Spec!:%s\n", *str); return nullptr; } chars_consumed = 3; if ((*str)[chars_consumed] == 's') { ++chars_consumed; type = NT_LSTM_SUMMARY; } } else { tprintf("Invalid direction (f|r|b) in L Spec!:%s\n", *str); return nullptr; } int num_states = strtol(*str + chars_consumed, str, 10); if (num_states <= 0) { tprintf("Invalid number of states in L Spec!:%s\n", *str); return nullptr; } Network* lstm = nullptr; if (two_d) { lstm = BuildLSTMXYQuad(input_shape.depth(), num_states); } else { if (num_outputs == 0) num_outputs = num_states; STRING name(spec_start, *str - spec_start); lstm = new LSTM(name, input_shape.depth(), num_states, num_outputs, false, type); if (dir != 'f') { Reversed* rev = new Reversed("RevLSTM", NT_XREVERSED); rev->SetNetwork(lstm); lstm = rev; } if (dir == 'b') { name += "LTR"; Parallel* parallel = new Parallel("BidiLSTM", NT_PAR_RL_LSTM); parallel->AddToStack(new LSTM(name, input_shape.depth(), num_states, num_outputs, false, type)); parallel->AddToStack(lstm); lstm = parallel; } } if (dim == 'y') { Reversed* rev = new Reversed("XYTransLSTM", NT_XYTRANSPOSE); rev->SetNetwork(lstm); lstm = rev; } return lstm; }