Code example #1 (score: 0)
File: net.cpp — project: khunglongs73/ConvNet
// Trains the network with mini-batch gradient descent.
//
// mx_data   - MATLAB array of training samples (copied into data_ by ReadData).
// mx_labels - MATLAB array of training labels (copied into labels_ by ReadLabels).
//
// Runs params_.numepochs_ epochs; each epoch optionally shuffles the sample
// order, then processes the data in batches of params_.batchsize_, recording
// the per-batch loss into trainerror_(epoch, batch).
void Net::Train(const mxArray *mx_data, const mxArray *mx_labels) {

  //mexPrintMsg("Start training...");
  ReadData(mx_data);
  ReadLabels(mx_labels);
  InitNorm();

  // Seed the C RNG so runs are reproducible for a fixed params_.seed_
  // (the shuffle below draws from std::rand()).
  std::srand(params_.seed_);

  size_t train_num = labels_.size1();
  // Last batch may be partial, hence the ceil.
  size_t numbatches = (size_t) ceil((ftype) train_num / params_.batchsize_);
  trainerror_.resize(params_.numepochs_, numbatches);
  for (size_t epoch = 0; epoch < params_.numepochs_; ++epoch) {
    // Identity permutation of sample indices, shuffled per epoch if requested.
    std::vector<size_t> randind(train_num);
    for (size_t i = 0; i < train_num; ++i) {
      randind[i] = i;
    }
    if (params_.shuffle_) {
      // Fisher-Yates shuffle driven by std::rand(). std::random_shuffle was
      // deprecated in C++14 and removed in C++17, so it cannot be used here.
      // TODO: switch to std::shuffle with a std::mt19937 seeded from
      // params_.seed_ once <random> is added to the include block.
      for (size_t i = train_num; i > 1; --i) {
        std::swap(randind[i - 1], randind[(size_t) std::rand() % i]);
      }
    }
    std::vector<size_t>::const_iterator iter = randind.begin();
    for (size_t batch = 0; batch < numbatches; ++batch) {
      // Clamp the final batch to the number of remaining indices.
      size_t batchsize = std::min(params_.batchsize_, (size_t)(randind.end() - iter));
      std::vector<size_t> batch_ind(iter, iter + batchsize);
      iter += batchsize;
      Mat data_batch = SubMat(data_, batch_ind, 1);
      Mat labels_batch = SubMat(labels_, batch_ind, 1);
      UpdateWeights(epoch, false);  // pre-batch weight step (e.g. momentum prep)
      InitActiv(data_batch);
      Mat pred_batch;
      Forward(pred_batch, 1);
      // InitDeriv writes the batch loss into trainerror_(epoch, batch).
      InitDeriv(labels_batch, trainerror_(epoch, batch));
      Backward();
      CalcWeights();
      UpdateWeights(epoch, true);   // apply the computed gradients
      if (params_.verbose_ == 2) {
        std::string info = std::string("Epoch: ") + std::to_string(epoch+1) +
                           std::string(", batch: ") + std::to_string(batch+1);
        mexPrintMsg(info);
      }
    } // batch
    if (params_.verbose_ == 1) {
      std::string info = std::string("Epoch: ") + std::to_string(epoch+1);
      mexPrintMsg(info);
    }
  } // epoch
  //mexPrintMsg("Training finished");
}
Code example #2 (score: 0)
File: net.cpp — project: Geekrick88/ConvNet
void Net::Train(const mxArray *mx_data, const mxArray *mx_labels) {  

  //mexPrintMsg("Start training...");  
  ReadData(mx_data);
  ReadLabels(mx_labels);
  InitNorm();
  
  size_t train_num = data_.size1();
  size_t numbatches = DIVUP(train_num, params_.batchsize_);
  trainerrors_.resize(params_.epochs_, 2);
  trainerrors_.assign(0);
  for (size_t epoch = 0; epoch < params_.epochs_; ++epoch) {    
    if (params_.shuffle_) {
      Shuffle(data_, labels_);      
    }
    StartTimer();    
    size_t offset = 0;
    Mat data_batch, labels_batch, pred_batch;      
    for (size_t batch = 0; batch < numbatches; ++batch) {
      size_t batchsize = MIN(train_num - offset, params_.batchsize_);      
      UpdateWeights(epoch, false);
      data_batch.resize(batchsize, data_.size2());
      labels_batch.resize(batchsize, labels_.size2());      
      SubSet(data_, data_batch, offset, true);      
      SubSet(labels_, labels_batch, offset, true);
      ftype error1;      
      InitActiv(data_batch);
      Forward(pred_batch, 1);            
      InitDeriv(labels_batch, error1);
      trainerrors_(epoch, 0) += error1;
      Backward();
      UpdateWeights(epoch, true); 
      offset += batchsize;
      if (params_.verbose_ == 2) {
        mexPrintInt("Epoch", (int) epoch + 1);
        mexPrintInt("Batch", (int) batch + 1);
      }
    } // batch  
    MeasureTime("totaltime");
    if (params_.verbose_ == 1) {
      mexPrintInt("Epoch", (int) epoch + 1);
    }        
  } // epoch  
  trainerrors_ /= (ftype) numbatches;
  //mexPrintMsg("Training finished");
}