// Runs backward propagation of errors on the deltas line. // See NetworkCpp for a detailed discussion of the arguments. bool Parallel::Backward(bool debug, const NetworkIO& fwd_deltas, NetworkScratch* scratch, NetworkIO* back_deltas) { // If this parallel is a replicator of convolvers, or holds a 1-d LSTM pair, // or a 2-d LSTM quad, do debug locally, and don't pass the flag on. if (debug && type_ != NT_PARALLEL) { DisplayBackward(fwd_deltas); debug = false; } int stack_size = stack_.size(); if (type_ == NT_PAR_2D_LSTM) { // Special case, run parallel in parallel. GenericVector<NetworkScratch::IO> in_deltas, out_deltas; in_deltas.init_to_size(stack_size, NetworkScratch::IO()); out_deltas.init_to_size(stack_size, NetworkScratch::IO()); // Split the forward deltas for each stack element. int feature_offset = 0; for (int i = 0; i < stack_.size(); ++i) { int num_features = stack_[i]->NumOutputs(); in_deltas[i].Resize(fwd_deltas, num_features, scratch); out_deltas[i].Resize(fwd_deltas, stack_[i]->NumInputs(), scratch); in_deltas[i]->CopyUnpacking(fwd_deltas, feature_offset, num_features); feature_offset += num_features; } #ifdef _OPENMP #pragma omp parallel for num_threads(stack_size) #endif for (int i = 0; i < stack_size; ++i) { stack_[i]->Backward(debug, *in_deltas[i], scratch, i == 0 ? back_deltas : out_deltas[i]); } if (needs_to_backprop_) { for (int i = 1; i < stack_size; ++i) { back_deltas->AddAllToFloat(*out_deltas[i]); } } } else { // Revolving partial deltas. NetworkScratch::IO in_deltas(fwd_deltas, scratch); // The sum of deltas from different sources, which will eventually go into // back_deltas. 
NetworkScratch::IO out_deltas; int feature_offset = 0; for (int i = 0; i < stack_.size(); ++i) { int num_features = stack_[i]->NumOutputs(); in_deltas->CopyUnpacking(fwd_deltas, feature_offset, num_features); feature_offset += num_features; if (stack_[i]->Backward(debug, *in_deltas, scratch, back_deltas)) { if (i == 0) { out_deltas.ResizeFloat(*back_deltas, back_deltas->NumFeatures(), scratch); out_deltas->CopyAll(*back_deltas); } else if (back_deltas->NumFeatures() == out_deltas->NumFeatures()) { // Widths are allowed to be different going back, as we may have // input nets, so only accumulate the deltas if the widths are the // same. out_deltas->AddAllToFloat(*back_deltas); } } } if (needs_to_backprop_) back_deltas->CopyAll(*out_deltas); } if (needs_to_backprop_) back_deltas->ScaleFloatBy(1.0f / stack_size); return needs_to_backprop_; }
// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
//
// Overview of what this function does, as far as the code here shows:
//  - Resizes back_deltas with ResizeToMap (fwd_deltas' int mode, input_map_,
//    ni_) and writes curr_sourceerr into it at every timestep.
//  - Walks input_map_ from its last index to its first (InitToLast followed
//    by Decrement), computing per-gate errors at each timestep t.
//  - Stores the gate errors for all timesteps in gate_errors_t and finally
//    passes them, with the transposed source_, to
//    gate_weights_[w].SumOuterTransposed.
//  - If softmax_ is set, the output errors come from
//    softmax_->BackwardTimeStep and softmax_->FinishBackward is called at
//    the end.
// Arguments:
//  debug:       if true, DisplayBackward(fwd_deltas) is called first.
//  fwd_deltas:  the incoming deltas, read one timestep at a time.
//  scratch:     supplies every temporary buffer used below.
//  back_deltas: output; resized and filled in by this function.
// Returns needs_to_backprop_.
bool LSTM::Backward(bool debug, const NetworkIO& fwd_deltas,
                    NetworkScratch* scratch,
                    NetworkIO* back_deltas) {
  if (debug) DisplayBackward(fwd_deltas);
  back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_, ni_);
  // ======Scratch space.======
  // Output errors from deltas with recurrence from sourceerr.
  NetworkScratch::FloatVec outputerr;
  outputerr.Init(ns_, scratch);
  // Recurrent error in the state/source. Both start out zeroed.
  NetworkScratch::FloatVec curr_stateerr, curr_sourceerr;
  curr_stateerr.Init(ns_, scratch);
  curr_sourceerr.Init(na_, scratch);
  ZeroVector<double>(ns_, curr_stateerr);
  ZeroVector<double>(na_, curr_sourceerr);
  // Errors in the gates: one vector of size ns_ per weight type.
  NetworkScratch::FloatVec gate_errors[WT_COUNT];
  for (int g = 0; g < WT_COUNT; ++g) gate_errors[g].Init(ns_, scratch);
  // Rotating buffers of width buf_width allow storage of the recurrent time-
  // steps used only for true 2-D. Stores one full strip of the major direction.
  // When not 2-D, buf_width is 1 and the two vectors below stay empty.
  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
  GenericVector<NetworkScratch::FloatVec> stateerr, sourceerr;
  if (Is2D()) {
    stateerr.init_to_size(buf_width, NetworkScratch::FloatVec());
    sourceerr.init_to_size(buf_width, NetworkScratch::FloatVec());
    for (int t = 0; t < buf_width; ++t) {
      stateerr[t].Init(ns_, scratch);
      sourceerr[t].Init(na_, scratch);
      ZeroVector<double>(ns_, stateerr[t]);
      ZeroVector<double>(na_, sourceerr[t]);
    }
  }
  // Parallel-generated sourceerr from each of the gates. They are summed into
  // curr_sourceerr by SumVectors at the end of every timestep.
  NetworkScratch::FloatVec sourceerr_temps[WT_COUNT];
  for (int w = 0; w < WT_COUNT; ++w)
    sourceerr_temps[w].Init(na_, scratch);
  int width = input_width_;
  // Transposed gate errors stored over all timesteps for sum outer.
  NetworkScratch::GradientStore gate_errors_t[WT_COUNT];
  for (int w = 0; w < WT_COUNT; ++w) {
    gate_errors_t[w].Init(ns_, width, scratch);
  }
  // Used only if softmax_ != NULL.
  NetworkScratch::FloatVec softmax_errors;
  NetworkScratch::GradientStore softmax_errors_t;
  if (softmax_ != NULL) {
    softmax_errors.Init(no_, scratch);
    softmax_errors_t.Init(no_, width, scratch);
  }
  // Clipping bound applied to curr_stateerr at every timestep: 9.0 for 2-D,
  // 4.0 otherwise.
  double state_clip = Is2D() ? 9.0 : 4.0;
#if DEBUG_DETAIL > 1
  tprintf("fwd_deltas:%s\n", name_.string());
  fwd_deltas.Print(10);
#endif
  // dest_index walks input_map_ backwards, starting from its last position.
  StrideMap::Index dest_index(input_map_);
  dest_index.InitToLast();
  // Used only by NT_LSTM_SUMMARY. It walks fwd_deltas' own stride map and is
  // only decremented when a timestep is actually read from fwd_deltas.
  StrideMap::Index src_index(fwd_deltas.stride_map());
  src_index.InitToLast();
  do {
    int t = dest_index.t();
    bool at_last_x = dest_index.IsLast(FD_WIDTH);
    // up_pos is the 2-D back step, down_pos is the 2-D fwd step, and are only
    // valid if >= 0, which is true if 2d and not on the top/bottom.
    int up_pos = -1;
    int down_pos = -1;
    if (Is2D()) {
      if (dest_index.index(FD_HEIGHT) > 0) {
        StrideMap::Index up_index(dest_index);
        if (up_index.AddOffset(-1, FD_HEIGHT)) up_pos = up_index.t();
      }
      if (!dest_index.IsLast(FD_HEIGHT)) {
        StrideMap::Index down_index(dest_index);
        if (down_index.AddOffset(1, FD_HEIGHT)) down_pos = down_index.t();
      }
    }
    // Index of the 2-D revolving buffers (sourceerr, stateerr).
    int mod_t = Modulo(t, buf_width);      // Current timestep.
    // Zero the state in the major direction only at the end of every row.
    if (at_last_x) {
      ZeroVector<double>(na_, curr_sourceerr);
      ZeroVector<double>(ns_, curr_stateerr);
    }
    // Setup the outputerr. There are three sources:
    //  - NT_LSTM_SUMMARY: read from fwd_deltas only at the last position in
    //    width (the same test as at_last_x); zero everywhere else.
    //  - no softmax_: read timestep t directly from fwd_deltas.
    //  - softmax_ present: softmax_->BackwardTimeStep fills it in.
    if (type_ == NT_LSTM_SUMMARY) {
      if (dest_index.IsLast(FD_WIDTH)) {
        fwd_deltas.ReadTimeStep(src_index.t(), outputerr);
        src_index.Decrement();
      } else {
        ZeroVector<double>(ns_, outputerr);
      }
    } else if (softmax_ == NULL) {
      fwd_deltas.ReadTimeStep(t, outputerr);
    } else {
      softmax_->BackwardTimeStep(fwd_deltas, t, softmax_errors,
                                 softmax_errors_t.get(), outputerr);
    }
    // Combine the recurrent source error with outputerr: ns_ values taken
    // from offset ni_ + nf_ of the 1-D source error and, when there is a 2-D
    // forward step, from offset ni_ + nf_ + ns_ of the saved strip entry.
    // NOTE(review): presumably AccumulateVector adds its middle argument into
    // its last one - confirm against its definition.
    if (!at_last_x)
      AccumulateVector(ns_, curr_sourceerr + ni_ + nf_, outputerr);
    if (down_pos >= 0)
      AccumulateVector(ns_, sourceerr[mod_t] + ni_ + nf_ + ns_, outputerr);
    // Apply the 1-d forget gates: scale the carried state error by the GF1
    // node values of timestep t + 1.
    if (!at_last_x) {
      const float* next_node_gf1 = node_values_[GF1].f(t + 1);
      for (int i = 0; i < ns_; ++i) {
        curr_stateerr[i] *= next_node_gf1[i];
      }
    }
    if (Is2D() && t + 1 < width) {
      // Drop the carried state error wherever which_fg_ at t + 1 is not 1.
      for (int i = 0; i < ns_; ++i) {
        if (which_fg_[t + 1][i] != 1) curr_stateerr[i] = 0.0;
      }
      if (down_pos >= 0) {
        // Wherever which_fg_ at down_pos is 2, add in the saved state error
        // scaled by the GFS node values at down_pos. Despite the "right_"
        // prefix, both pointers refer to the down_pos / mod_t entries.
        const float* right_node_gfs = node_values_[GFS].f(down_pos);
        const double* right_stateerr = stateerr[mod_t];
        for (int i = 0; i < ns_; ++i) {
          if (which_fg_[down_pos][i] == 2) {
            curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
          }
        }
      }
    }
    state_.FuncMultiply3Add<HPrime>(node_values_[GO], t, outputerr,
                                    curr_stateerr);
    // Clip curr_stateerr to a sane range.
    ClipVector<double>(ns_, -state_clip, state_clip, curr_stateerr);
#if DEBUG_DETAIL > 1
    if (t + 10 > width) {
      tprintf("t=%d, stateerr=", t);
      for (int i = 0; i < ns_; ++i)
        tprintf(" %g,%g,%g", curr_stateerr[i], outputerr[i],
                curr_sourceerr[ni_ + nf_ + i]);
      tprintf("\n");
    }
#endif
    // Matrix multiply to get the source errors.
    // Each group below computes one gate's error, clips it to +/-kErrClip,
    // multiplies it through that gate's weights into its sourceerr_temps slot
    // and stores it at timestep t in gate_errors_t.
    // NOTE(review): PARALLEL_IF_OPENMP / SECTION_IF_OPENMP /
    // END_PARALLEL_IF_OPENMP are defined elsewhere; presumably they wrap the
    // groups in OpenMP parallel sections - confirm. The statement grouping
    // between the macros must be preserved.
    PARALLEL_IF_OPENMP(GFS)

    // Cell inputs.
    node_values_[CI].FuncMultiply3<GPrime>(t, node_values_[GI], t,
                                           curr_stateerr, gate_errors[CI]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[CI].get());
    gate_weights_[CI].VectorDotMatrix(gate_errors[CI], sourceerr_temps[CI]);
    gate_errors_t[CI].get()->WriteStrided(t, gate_errors[CI]);

    SECTION_IF_OPENMP
    // Input Gates.
    node_values_[GI].FuncMultiply3<FPrime>(t, node_values_[CI], t,
                                           curr_stateerr, gate_errors[GI]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GI].get());
    gate_weights_[GI].VectorDotMatrix(gate_errors[GI], sourceerr_temps[GI]);
    gate_errors_t[GI].get()->WriteStrided(t, gate_errors[GI]);

    SECTION_IF_OPENMP
    // 1-D forget Gates. These need state_ at t - 1, so at t == 0 the error
    // and its source contribution are zeroed instead.
    if (t > 0) {
      node_values_[GF1].FuncMultiply3<FPrime>(t, state_, t - 1, curr_stateerr,
                                              gate_errors[GF1]);
      ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GF1].get());
      gate_weights_[GF1].VectorDotMatrix(gate_errors[GF1],
                                         sourceerr_temps[GF1]);
    } else {
      memset(gate_errors[GF1], 0, ns_ * sizeof(gate_errors[GF1][0]));
      memset(sourceerr_temps[GF1], 0, na_ * sizeof(*sourceerr_temps[GF1]));
    }
    gate_errors_t[GF1].get()->WriteStrided(t, gate_errors[GF1]);

    // 2-D forget Gates. These need state_ at up_pos; without a valid up_pos
    // (always the case when not 2-D) the error and its source contribution
    // are zeroed, so the SumVectors call below is unaffected by them.
    if (up_pos >= 0) {
      node_values_[GFS].FuncMultiply3<FPrime>(t, state_, up_pos, curr_stateerr,
                                              gate_errors[GFS]);
      ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GFS].get());
      gate_weights_[GFS].VectorDotMatrix(gate_errors[GFS],
                                         sourceerr_temps[GFS]);
    } else {
      memset(gate_errors[GFS], 0, ns_ * sizeof(gate_errors[GFS][0]));
      memset(sourceerr_temps[GFS], 0, na_ * sizeof(*sourceerr_temps[GFS]));
    }
    // The GFS errors are only recorded for 2-D; the weight update loop at the
    // end of the function skips GFS when not 2-D.
    if (Is2D()) gate_errors_t[GFS].get()->WriteStrided(t, gate_errors[GFS]);

    SECTION_IF_OPENMP
    // Output gates.
    state_.Func2Multiply3<HFunc, FPrime>(node_values_[GO], t, outputerr,
                                         gate_errors[GO]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GO].get());
    gate_weights_[GO].VectorDotMatrix(gate_errors[GO], sourceerr_temps[GO]);
    gate_errors_t[GO].get()->WriteStrided(t, gate_errors[GO]);
    END_PARALLEL_IF_OPENMP

    // Sum the five per-gate source errors into curr_sourceerr and emit it as
    // the back delta for this timestep.
    SumVectors(na_, sourceerr_temps[CI], sourceerr_temps[GI],
               sourceerr_temps[GF1], sourceerr_temps[GO], sourceerr_temps[GFS],
               curr_sourceerr);
    back_deltas->WriteTimeStep(t, curr_sourceerr);
    // Save states for use by the 2nd dimension only if needed.
    if (Is2D()) {
      CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
      CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
    }
  } while (dest_index.Decrement());
#if DEBUG_DETAIL > 2
  for (int w = 0; w < WT_COUNT; ++w) {
    tprintf("%s gate errors[%d]\n", name_.string(), w);
    gate_errors_t[w].get()->PrintUnTransposed(10);
  }
#endif
  // Transposed source_ used to speed-up SumOuter.
  // NOTE(review): state_t is initialized and filled by state_.Transpose but
  // is not read again in this function - confirm whether it is still needed.
  NetworkScratch::GradientStore source_t, state_t;
  source_t.Init(na_, width, scratch);
  source_.Transpose(source_t.get());
  state_t.Init(ns_, width, scratch);
  state_.Transpose(state_t.get());
  // Accumulate the weight gradients for every gate from the stored gate
  // errors and the transposed source. The OpenMP if clause makes the loop
  // parallel only when not 2-D, in which case the GFS gate is skipped.
#ifdef _OPENMP
#pragma omp parallel for num_threads(GFS) if (!Is2D())
#endif
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    gate_weights_[w].SumOuterTransposed(*gate_errors_t[w], *source_t, false);
  }
  if (softmax_ != NULL) {
    softmax_->FinishBackward(*softmax_errors_t);
  }
  return needs_to_backprop_;
}