/// Extract the sub-matrix of `weights` that belongs to entry `index`.
///
/// The [begin, end) bounds are computed by prnn::rnn::getWeightsRange from
/// the handle, the weights' precision and `index`. The result is whatever
/// `slice(weights, begin, end)` returns for those bounds.
/// NOTE(review): `index` is presumably a layer index, judging by the
/// function name. The exact layout is defined by getWeightsRange; confirm
/// there.
///
/// @param weights Packed weight matrix to slice.
/// @param handle  Recurrent-op configuration passed to getWeightsRange.
/// @param index   Selector forwarded to getWeightsRange.
/// @return The slice of `weights` covering the computed range.
matrix::Matrix sliceLayerWeights(const matrix::Matrix& weights,
    const RecurrentOpsHandle& handle, size_t index)
{
    // getWeightsRange fills both bounds through its out-parameters.
    matrix::Dimension rangeBegin;
    matrix::Dimension rangeEnd;

    prnn::rnn::getWeightsRange(rangeBegin, rangeEnd, handle,
        weights.precision(), index);

    return slice(weights, rangeBegin, rangeEnd);
}
/// Run the recurrent forward pass, overwriting `activations` with the output.
///
/// Before the call, the current contents of `activations` are snapshotted
/// with copy(). The snapshot is passed as the read-only input, while a
/// mutable view of `activations` is passed as the output. The input is
/// therefore never read through the same storage that is being written.
/// NOTE(review): presumably the prnn kernel does not support aliased
/// input/output buffers — confirm against prnn::rnn::forwardPropRecurrent.
///
/// @param activations In: layer input. Out: layer output (written in place).
/// @param reserve     Buffer handed to prnn as a mutable view. Presumably
///                    state saved for the backward pass — TODO confirm.
/// @param weights     Weights passed as a read-only view.
/// @param handle      Recurrent-op configuration.
void forwardPropRecurrent(matrix::Matrix& activations,
    matrix::Matrix& reserve,
    const matrix::Matrix& weights,
    const RecurrentOpsHandle& handle)
{
    // Workspace sized by prnn for this handle, at the activations' precision.
    auto workspace = prnn::rnn::getForwardPropScratch(handle,
        activations.precision());

    // Snapshot of the input; it stays alive until the call below returns.
    auto inputSnapshot = copy(activations);

    prnn::rnn::forwardPropRecurrent(
        matrix::DynamicView(activations),         // output (in place)
        matrix::ConstDynamicView(inputSnapshot),  // input
        matrix::ConstDynamicView(weights),
        matrix::DynamicView(workspace),
        matrix::DynamicView(reserve),
        handle);
}
/// Compute recurrent weight gradients into `dWeights`.
///
/// Allocates the back-prop workspace that prnn requires for `handle`, at the
/// precision of `activations`. It then forwards all buffers to
/// prnn::rnn::backPropGradientsRecurrent, which writes into `dWeights`.
///
/// @param dWeights          Destination for the weight gradients (mutable view).
/// @param activations       Read-only view passed as the layer input.
/// @param outputActivations Read-only view passed as the layer output.
/// @param reserve           Read-only view. Presumably the reserve buffer
///                          filled by the forward pass — TODO confirm.
/// @param handle            Recurrent-op configuration.
void backPropGradientsRecurrent(matrix::Matrix& dWeights,
    const matrix::Matrix& activations,
    const matrix::Matrix& outputActivations,
    const matrix::Matrix& reserve,
    const RecurrentOpsHandle& handle)
{
    auto workspace = prnn::rnn::getBackPropGradientsScratch(handle,
        activations.precision());

    prnn::rnn::backPropGradientsRecurrent(
        matrix::DynamicView(dWeights),
        matrix::ConstDynamicView(activations),
        matrix::ConstDynamicView(outputActivations),
        matrix::DynamicView(workspace),
        matrix::ConstDynamicView(reserve),
        handle);
}