/** * Ordinary feed forward pass of a neural network, evaluating the function * f(x) by propagating the activity forward through f. * * @param inputActivation Input data used for evaluating the specified * activity function. * @param outputActivation Datatype to store the resulting output activation. */ void FeedForward(const VecType& inputActivation, VecType& outputActivation) { if (inGate.n_cols < seqLen) { inGate = arma::zeros<MatType>(layerSize, seqLen); inGateAct = arma::zeros<MatType>(layerSize, seqLen); inGateError = arma::zeros<MatType>(layerSize, seqLen); outGate = arma::zeros<MatType>(layerSize, seqLen); outGateAct = arma::zeros<MatType>(layerSize, seqLen); outGateError = arma::zeros<MatType>(layerSize, seqLen); forgetGate = arma::zeros<MatType>(layerSize, seqLen); forgetGateAct = arma::zeros<MatType>(layerSize, seqLen); forgetGateError = arma::zeros<MatType>(layerSize, seqLen); state = arma::zeros<MatType>(layerSize, seqLen); stateError = arma::zeros<MatType>(layerSize, seqLen); cellAct = arma::zeros<MatType>(layerSize, seqLen); } // Split up the inputactivation into the 3 parts (inGate, forgetGate, // outGate). inGate.col(offset) = inputActivation.subvec(0, layerSize - 1); forgetGate.col(offset) = inputActivation.subvec( layerSize, (layerSize * 2) - 1); outGate.col(offset) = inputActivation.subvec( layerSize * 3, (layerSize * 4) - 1); if (peepholes && offset > 0) { inGate.col(offset) += inGatePeepholeWeights % state.col(offset - 1); forgetGate.col(offset) += forgetGatePeepholeWeights % state.col(offset - 1); } VecType inGateActivation = inGateAct.unsafe_col(offset); GateActivationFunction::fn(inGate.unsafe_col(offset), inGateActivation); VecType forgetGateActivation = forgetGateAct.unsafe_col(offset); GateActivationFunction::fn(forgetGate.unsafe_col(offset), forgetGateActivation); VecType cellActivation = cellAct.unsafe_col(offset); StateActivationFunction::fn(inputActivation.subvec(layerSize * 2, (layerSize * 3) - 1), cellActivation); state.col(offset) = inGateAct.col(offset) % cellActivation; if (offset > 0) state.col(offset) += forgetGateAct.col(offset) % state.col(offset - 1); if (peepholes) outGate.col(offset) += outGatePeepholeWeights % state.col(offset); VecType outGateActivation = outGateAct.unsafe_col(offset); GateActivationFunction::fn(outGate.unsafe_col(offset), outGateActivation); OutputActivationFunction::fn(state.unsafe_col(offset), outputActivation); outputActivation = outGateAct.col(offset) % outputActivation; offset = (offset + 1) % seqLen; }
/** * Ordinary feed backward pass of a neural network, calculating the function * f(x) by propagating x backwards trough f. Using the results from the feed * forward pass. * * @param inputActivation Input data used for calculating the function f(x). * @param error The backpropagated error. * @param delta The calculating delta using the partial derivative of the * error with respect to a weight. */ void FeedBackward(const VecType& /* unused */, const VecType& error, VecType& delta) { size_t queryOffset = seqLen - offset - 1; VecType outGateDerivative; GateActivationFunction::deriv(outGateAct.unsafe_col(queryOffset), outGateDerivative); VecType stateActivation; StateActivationFunction::fn(state.unsafe_col(queryOffset), stateActivation); outGateError.col(queryOffset) = outGateDerivative % error % stateActivation; VecType stateDerivative; StateActivationFunction::deriv(stateActivation, stateDerivative); stateError.col(queryOffset) = error % outGateAct.col(queryOffset) % stateDerivative; if (queryOffset < (seqLen - 1)) { stateError.col(queryOffset) += stateError.col(queryOffset + 1) % forgetGateAct.col(queryOffset + 1); if (peepholes) { stateError.col(queryOffset) += inGateError.col(queryOffset + 1) % inGatePeepholeWeights; stateError.col(queryOffset) += forgetGateError.col(queryOffset + 1) % forgetGatePeepholeWeights; } } if (peepholes) { stateError.col(queryOffset) += outGateError.col(queryOffset) % outGatePeepholeWeights; } VecType cellDerivative; StateActivationFunction::deriv(cellAct.col(queryOffset), cellDerivative); VecType cellError = inGateAct.col(queryOffset) % cellDerivative % stateError.col(queryOffset); if (queryOffset > 0) { VecType forgetGateDerivative; GateActivationFunction::deriv(forgetGateAct.col(queryOffset), forgetGateDerivative); forgetGateError.col(queryOffset) = forgetGateDerivative % stateError.col(queryOffset) % state.col(queryOffset - 1); } VecType inGateDerivative; GateActivationFunction::deriv(inGateAct.col(queryOffset), inGateDerivative); inGateError.col(queryOffset) = inGateDerivative % stateError.col(queryOffset) % cellAct.col(queryOffset); if (peepholes) { outGateDerivative += outGateError.col(queryOffset) % state.col(queryOffset); if (queryOffset > 0) { inGatePeepholeDerivatives += inGateError.col(queryOffset) % state.col(queryOffset - 1); forgetGatePeepholeDerivatives += forgetGateError.col(queryOffset) % state.col(queryOffset - 1); } } delta = arma::zeros<VecType>(layerSize * 4); delta.subvec(0, layerSize - 1) = inGateError.col(queryOffset); delta.subvec(layerSize, (layerSize * 2) - 1) = forgetGateError.col(queryOffset); delta.subvec(layerSize * 2, (layerSize * 3) - 1) = cellError; delta.subvec(layerSize * 3, (layerSize * 4) - 1) = outGateError.col(queryOffset); offset = (offset + 1) % seqLen; if (peepholes && offset == 0) { inGatePeepholeGradient = (inGatePeepholeWeights.t() * (inGateError.col(queryOffset) % inGatePeepholeDerivatives)) * inGate.col(queryOffset).t(); forgetGatePeepholeGradient = (forgetGatePeepholeWeights.t() * (forgetGateError.col(queryOffset) % forgetGatePeepholeDerivatives)) * forgetGate.col(queryOffset).t(); outGatePeepholeGradient = (outGatePeepholeWeights.t() * (outGateError.col(queryOffset) % outGatePeepholeDerivatives)) * outGate.col(queryOffset).t(); inGatePeepholeOptimizer->UpdateWeights(inGatePeepholeWeights, inGatePeepholeGradient.t(), 0); forgetGatePeepholeOptimizer->UpdateWeights(forgetGatePeepholeWeights, forgetGatePeepholeGradient.t(), 0); outGatePeepholeOptimizer->UpdateWeights(outGatePeepholeWeights, outGatePeepholeGradient.t(), 0); inGatePeepholeDerivatives.zeros(); forgetGatePeepholeDerivatives.zeros(); outGatePeepholeDerivatives.zeros(); } }