void SarsaAgent::endEpisode(){ episodeNumber++; //This will not happen usually, but is a safety. if(lastAction == -1){ return; } else{ FA->setState(lastState); double oldQ = FA->computeQ(lastAction); FA->updateTraces(lastAction); double delta = lastReward - oldQ; FA->updateWeights(delta, learningRate); //Assume lambda is 0. this comment looks wrong. FA->decayTraces(0);//remains 0 } if(toSaveWeights && (episodeNumber + 1) % 5 == 0){ saveWeights(saveWeightsFile); std::cout << "Saving weights to " << saveWeightsFile << std::endl; } lastAction = -1; }
bool MLP::learn(realnumber ME, realnumber MT, realnumber LR, bool ALR, realnumber lambda, realnumber lambda1, realnumber lambda2) // learn permet de réaliser l'apprentissage du MLP { /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * A IMPLEMENTER * * * * normaliser les données d'entrainement * * erreur en dessous de laquelle un exemple n'est plus traité * * weight decay * * OK: variation du taux d'apprentissage (algo de Vogl) OU poids distinct pour chaque connexion (Sanossian & Evans) * * élagage * * injection de bruit * * ensemble de validation * * early stop * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ /* * ME = MAX_ERROR * MT = MAX_TIME * LR = LEARNING_RATE * ALR = ADAPTATIVELR (adaptative learning rate) */ if (isSet()) { integer index, compteur = 0; realnumber nextDisplayTime = 0, newMQE = MQE(lambda, lambda1, lambda2), oldMQE = newMQE; clock_t start = clock(); displayInfo(lambda, lambda1, lambda2); display("learning starting..."); // pour la suite: "index" est le numéro de l'exemple que l'on est en train de traiter // et "j" est le numéro de la couche while(newMQE > ME && (clock() - start) / (realnumber)CLOCKS_PER_SEC < MT) { // affiche "mqe" et "m_learningRate" si le dernier affichage date de plus d'une seconde displayMQE(start, nextDisplayTime, newMQE, LR); // présente un exemple au hasard pour l'apprendre index = rand()% m_input.cols(); // ATTENTION! A améliorer saveWeights(); weightDecay(lambda, lambda1, lambda2); modifyWeights(index, LR); // on vérifie s'ils sont meilleurs que les anciens, sinon on revient en arrière newMQE = MQE(lambda, lambda1, lambda2); modifyLearningRate(LR, ALR, oldMQE, newMQE); compteur++; } display("learning finished! 
\n"); display("Iterations: " + toStr(int(compteur)) + "; Temps en secondes : " + toStr ((clock() - start) / (realnumber)CLOCKS_PER_SEC) + ""); displayInfo(lambda, lambda1, lambda2); return (newMQE <= ME); } else return 0; }
/* Training routine of the BP (back-propagation) algorithm.
 *
 * Repeatedly sweeps the training set, accumulating weight deltas and
 * flushing them to the weight matrices every 50 instances (mini-batch
 * style). Stops when the per-epoch error sum falls to E_MIN or after
 * MAX_TIMES epochs. Periodically decays the learning rate, saves the
 * weights, and evaluates accuracy on the test set at `test_path`.
 *
 * Side effects: reseeds rand(), mutates wih/who/delta_wih/delta_who and
 * the unit activation arrays, writes weights to disk, prints progress. */
void BP::train(string test_path) {
    this->initial(); // initialize the BP model (weights etc.)
    int train_times;
    int i,j;
    double unit_in,temp;
    // error terms
    double error_out[OUT];
    double error_hid[HIDE+1];
    double error_sum;
    // start training
    int count_instance = 0;
    train_times = 0;
    list<Material>::iterator it;
    srand( (int)time(0) );
    do{
        error_sum = 0;
        // NOTE(review): deltas left over from an incomplete final batch of
        // the previous epoch are discarded here rather than applied — confirm
        // this is intended.
        memset((void*)(delta_who), 0, sizeof(double)*OUT*(HIDE+1));
        memset((void*)(delta_wih), 0, sizeof(double)*HIDE*(IN+1));
        for(it=train_data.begin(); it!=train_data.end(); it++, count_instance++){ // for each training example
            // Subsample the "correct == 1" class: keep only ~1 in 4 such
            // examples (presumably to rebalance the classes — verify).
            if (it->correct == 1){
                if (rand()%4 != 0)
                    continue;
            }
            // Load input and target output from the current instance.
            getInput_Output(*it);
            // Forward pass: compute each unit's output.
            // Hidden units (logistic sigmoid activation).
            //#pragma omp parallel for
            for(i=0; i<HIDE; i++){
                unit_in = 0;
                for(j=0; j<IN+1; j++){
                    unit_in += wih[i][j] * in_unit[j];
                }
                hid_unit[i] = 1.0/( 1.0 + exp(-unit_in) );
            }
            hid_unit[HIDE] = 1; // bias/threshold term
            // Output units.
            //#pragma omp parallel for
            for(i=0; i<OUT; i++){
                unit_in = 0;
                for(j=0; j<HIDE+1; j++){
                    unit_in += who[i][j] * hid_unit[j];
                }
                out_unit[i] = 1.0/( 1.0 + exp(-unit_in) );
            }
            // For each output unit, compute its error term
            // (sigmoid derivative times target-minus-output).
            for(i=0; i<OUT; i++){
                error_out[i] = out_unit[i]*(1-out_unit[i])*(target_out[i] - out_unit[i]);
                error_sum += fabs( error_out[i] );
            }
            // For each hidden unit, compute its error term by
            // back-propagating the output errors through who.
            // (For i == HIDE the derivative factor is 0 since
            // hid_unit[HIDE] == 1, so that entry stays 0.)
            for(i=0; i<HIDE+1; i++){
                error_hid[i] = hid_unit[i]*(1-hid_unit[i]);
                // contribution of this hidden unit to all output errors
                temp = 0;
                for(j=0; j<OUT; j++){
                    temp += who[j][i] * error_out[j];
                }
                error_hid[i] *= temp;
                error_sum += fabs( error_hid[i] );
            }
            // Accumulate weight deltas --------------------------------------
            // hidden -> output weights
            for(i=0; i<OUT; i++){
                for(j=0; j<HIDE+1; j++){
                    delta_who[i][j] += alpha * error_out[i] * hid_unit[j];
                }
            }
            // input -> hidden weights
            for(i=0; i<HIDE; i++){
                for(j=0; j<IN+1; j++){
                    delta_wih[i][j] += alpha * error_hid[i] * in_unit[j];
                }
            }
            // Flush the accumulated deltas every 50 instances (mini-batch).
            if (count_instance%50 == 0){
                // hidden -> output weights
                for(i=0; i<OUT; i++){
                    for(j=0; j<HIDE+1; j++){
                        who[i][j] += delta_who[i][j];
                    }
                }
                // input -> hidden weights
                for(i=0; i<HIDE; i++){
                    for(j=0; j<IN+1; j++){
                        wih[i][j] += delta_wih[i][j];
                    }
                }
                memset((void*)(delta_who), 0, sizeof(double)*OUT*(HIDE+1));
                memset((void*)(delta_wih), 0, sizeof(double)*HIDE*(IN+1));
            }
        }
        // Decay the learning rate every 500 epochs (also fires at epoch 0).
        if (train_times%500 == 0)
            alpha *= 0.9;
        train_times++;
        // Checkpoint the weights every 100 epochs.
        if (train_times%100 == 0){
            saveWeights();
        }
        // Every 20 epochs, evaluate on the test set and report progress.
        if (train_times%20 == 0){
            //readWeights();
            double accuracy = test(test_path);
            cout<<train_times<<"\t"<<alpha<<"\t"<<error_sum<<"\t"<<accuracy<<endl;
            cout<<"----------------------------------------------------------------\n";
            //log<<train_times<<"\t"<<alpha<<"\t"<<error_sum<<"\t"<<accuracy<<endl;
        }
        // Stopping criteria: epoch limit or error sum small enough.
        if(train_times > MAX_TIMES)
            break;
        if(error_sum <= E_MIN)
            break;
    }while(true);
    cout<<"训练完毕"<<endl;
    // Save the final weights.
    saveWeights();
}
// Trains the network with classic back-propagation on the CSV file `name`
// for `it` full passes (epochs) over its rows. Each row's last column is
// the expected output; the preceding columns are the input vector.
//
// Side effects: randomizes then mutates all layer weights/thetas, prints
// the accumulated global error each epoch, and saves weights every 50
// epochs and once more at the end.
void NeuralNet::backPropagationTraining(string name,int it)
{
    CsvHandler csv;
    csv.loadCsv(name);
    vector<double> aux;
    vector<double> x;
    vector<double> auxW; // scratch copy of a neuron's weights during adaptation
    double o;
    int k=0;
    double sum=0;
    double gError=0.0;
    aux.resize(csv.getNumCols());
    x.resize(csv.getNumCols()-1);
    // Step 1: set all weights and node offsets to small random values.
    setRandomWeights(csv.getNumCols()-1);
    // Step 2: present inputs and desired outputs.
    while(k!=it)
    {
        for(int j=0;j<csv.getNumRows();j++) // examples
        {
            aux=CsvHandler::toDouble(csv.getRow(j));
            copy(aux.begin(),aux.end()-1,x.begin());
            o=aux[aux.size()-1]; // expected value (last CSV column)
            // Step 3: compute actual outputs (forward pass).
            SolveNeuralNet(x);
            // Step 4: calculate the error terms, output layer first.
            for(unsigned int i=0;i<outLayer.size();i++)
            {
                gError=o - this->O[i];
                // delta = (target - output) * sigmoid'(output)
                outLayer[i].setDelta((gError)*this->O[i]*(1-this->O[i]));
                // NOTE(review): signed errors are summed, so positive and
                // negative errors can cancel in globalError — confirm intent.
                globalError+=gError;
                gError=0.0;
            }
            // Hidden layer: back-propagate deltas through the output weights.
            for(unsigned int i=0;i<hiddenLayer.size();i++)
            {
                sum=0;
                for(unsigned int j=0;j<outLayer.size();j++)
                {
                    sum+=outLayer[j].getDelta()*outLayer[j].getWeight(i);
                }
                hiddenLayer[i].setDelta(hiddenLayer[i].getO()*(1-hiddenLayer[i].getO())*sum);
            }
            // Input layer: back-propagate deltas through the hidden weights.
            for(unsigned int i=0;i<inLayer.size();i++)
            {
                sum=0;
                for(unsigned int j=0;j<hiddenLayer.size();j++)
                {
                    sum+=hiddenLayer[j].getDelta()*hiddenLayer[j].getWeight(i);
                }
                inLayer[i].setDelta(inLayer[i].getO()*(1-inLayer[i].getO())*sum);
            }
            // Step 5: adapt the weights, layer by layer.
            // Output layer.
            for(unsigned int i=0;i<outLayer.size();i++) // neuron
            {
                auxW.clear();
                auxW=outLayer[i].getWeights();
                for(unsigned int j=0;j<outLayer[i].getInputs().size();j++) // inputs
                {
                    auxW[j]+=ALPHA*outLayer[i].getDelta()*outLayer[i].getInput(j);
                }
                outLayer[i].setWeight(auxW); // store the newly computed weights in the neuron
                // Adapt bias.
                // NOTE(review): theta moves opposite in sign to the weight
                // update — consistent only if theta is subtracted in the
                // activation; confirm against SolveNeuralNet.
                outLayer[i].setTheta(outLayer[i].getTheta()-ALPHA*outLayer[i].getDelta());
            }
            // Hidden layer.
            for(unsigned int i=0;i<hiddenLayer.size();i++) // neuron
            {
                auxW.clear();
                auxW=hiddenLayer[i].getWeights();
                for(unsigned int j=0;j<hiddenLayer[i].getInputs().size();j++) // inputs
                {
                    auxW[j]+=ALPHA*hiddenLayer[i].getDelta()*hiddenLayer[i].getInput(j);
                }
                hiddenLayer[i].setWeight(auxW); // store the newly computed weights in the neuron
                // Adapt bias.
                hiddenLayer[i].setTheta(hiddenLayer[i].getTheta()-ALPHA*hiddenLayer[i].getDelta());
            }
            // Input layer.
            for(unsigned int i=0;i<inLayer.size();i++) // neuron
            {
                auxW.clear();
                auxW=inLayer[i].getWeights();
                for(unsigned int j=0;j<inLayer[i].getInputs().size();j++) // inputs
                {
                    auxW[j]+=ALPHA*inLayer[i].getDelta()*inLayer[i].getInput(j);
                }
                inLayer[i].setWeight(auxW); // store the newly computed weights in the neuron
                // Adapt bias.
                inLayer[i].setTheta(inLayer[i].getTheta()-ALPHA*inLayer[i].getDelta());
            }
        } // end examples
        // Step 6: report the global error accumulated over this epoch.
        // NOTE(review): this prints the same line O.size() times per epoch —
        // looks like it was meant to run once; confirm before changing.
        for(unsigned int i=0;i<this->O.size();i++)
        {
            cout<<k<<" "<<"Global Error:"<<globalError<<endl;
        }
        globalError=0.0;
        k++;
        // Checkpoint the weights every 50 epochs.
        if(k%50==0)
        {
            cout<<"checkPoint"<< endl;
            // save weights
            saveWeights();
        }
        // Repeat by going to step 2.
    }
    // Save the final weights.
    saveWeights();
}