/// Predict the belief vector that will result if the specified action is performed void TransitionModel::anticipateNextBeliefs(const GVec& beliefs, const GVec& actions, GVec& anticipatedBeliefs) { if(tutor) tutor->transition(beliefs, actions, anticipatedBeliefs); else { GAssert(beliefs.size() + actions.size() == model.layer(0).inputs()); buf.resize(beliefs.size() + actions.size()); buf.put(0, beliefs); buf.put(beliefs.size(), actions); model.forwardProp(buf); anticipatedBeliefs.copy(beliefs); anticipatedBeliefs.addScaled(2.0, model.outputLayer().activation()); anticipatedBeliefs.clip(-1.0, 1.0); } }
/// Refines the beliefs to correspond with actual observations void ObservationModel::calibrateBeliefs(GVec& beliefs, const GVec& observations) { if(tutor) tutor->observations_to_state(observations, beliefs); else { GNeuralNetLayer& layIn = encoder.outputLayer(); for(size_t i = 0; i < calibrationIters; i++) { decoder.forwardProp(beliefs); decoder.backpropagate(observations); decoder.layer(0).backPropError(&layIn); beliefs.addScaled(decoder.learningRate(), layIn.error()); beliefs.clip(-1.0, 1.0); } } }
// virtual void GWag::trainInner(const GMatrix& features, const GMatrix& labels) { GNeuralNetLearner* pTemp = NULL; std::unique_ptr<GNeuralNetLearner> hTemp; size_t weights = 0; GVec pWeightBuf; GVec pWeightBuf2; for(size_t i = 0; i < m_models; i++) { m_pNN->train(features, labels); if(pTemp) { // Average m_pNN with pTemp if(!m_noAlign) m_pNN->nn().align(pTemp->nn()); pTemp->nn().weightsToVector(pWeightBuf.data()); m_pNN->nn().weightsToVector(pWeightBuf2.data()); pWeightBuf *= (double(i) / (i + 1)); pWeightBuf.addScaled(1.0 / (i + 1), pWeightBuf2); pTemp->nn().vectorToWeights(pWeightBuf.data()); } else { // Copy the m_pNN GDom doc; GDomNode* pNode = m_pNN->serialize(&doc); GLearnerLoader ll; pTemp = new GNeuralNetLearner(pNode); hTemp.reset(pTemp); weights = pTemp->nn().weightCount(); pWeightBuf.resize(weights); pWeightBuf2.resize(weights); } } pTemp->nn().weightsToVector(pWeightBuf.data()); m_pNN->nn().vectorToWeights(pWeightBuf.data()); }