void Net::Train(const mxArray *mx_data, const mxArray *mx_labels) { //mexPrintMsg("Start training..."); ReadData(mx_data); ReadLabels(mx_labels); InitNorm(); std::srand(params_.seed_); size_t train_num = labels_.size1(); size_t numbatches = (size_t) ceil((ftype) train_num/params_.batchsize_); trainerror_.resize(params_.numepochs_, numbatches); for (size_t epoch = 0; epoch < params_.numepochs_; ++epoch) { std::vector<size_t> randind(train_num); for (size_t i = 0; i < train_num; ++i) { randind[i] = i; } if (params_.shuffle_) { std::random_shuffle(randind.begin(), randind.end()); } std::vector<size_t>::const_iterator iter = randind.begin(); for (size_t batch = 0; batch < numbatches; ++batch) { size_t batchsize = std::min(params_.batchsize_, (size_t)(randind.end() - iter)); std::vector<size_t> batch_ind = std::vector<size_t>(iter, iter + batchsize); iter = iter + batchsize; Mat data_batch = SubMat(data_, batch_ind, 1); Mat labels_batch = SubMat(labels_, batch_ind, 1); UpdateWeights(epoch, false); InitActiv(data_batch); Mat pred_batch; Forward(pred_batch, 1); InitDeriv(labels_batch, trainerror_(epoch, batch)); Backward(); CalcWeights(); UpdateWeights(epoch, true); if (params_.verbose_ == 2) { std::string info = std::string("Epoch: ") + std::to_string(epoch+1) + std::string(", batch: ") + std::to_string(batch+1); mexPrintMsg(info); } } // batch if (params_.verbose_ == 1) { std::string info = std::string("Epoch: ") + std::to_string(epoch+1); mexPrintMsg(info); } } // epoch //mexPrintMsg("Training finished"); }
void Net::Train(const mxArray *mx_data, const mxArray *mx_labels) { //mexPrintMsg("Start training..."); ReadData(mx_data); ReadLabels(mx_labels); InitNorm(); size_t train_num = data_.size1(); size_t numbatches = DIVUP(train_num, params_.batchsize_); trainerrors_.resize(params_.epochs_, 2); trainerrors_.assign(0); for (size_t epoch = 0; epoch < params_.epochs_; ++epoch) { if (params_.shuffle_) { Shuffle(data_, labels_); } StartTimer(); size_t offset = 0; Mat data_batch, labels_batch, pred_batch; for (size_t batch = 0; batch < numbatches; ++batch) { size_t batchsize = MIN(train_num - offset, params_.batchsize_); UpdateWeights(epoch, false); data_batch.resize(batchsize, data_.size2()); labels_batch.resize(batchsize, labels_.size2()); SubSet(data_, data_batch, offset, true); SubSet(labels_, labels_batch, offset, true); ftype error1; InitActiv(data_batch); Forward(pred_batch, 1); InitDeriv(labels_batch, error1); trainerrors_(epoch, 0) += error1; Backward(); UpdateWeights(epoch, true); offset += batchsize; if (params_.verbose_ == 2) { mexPrintInt("Epoch", (int) epoch + 1); mexPrintInt("Batch", (int) batch + 1); } } // batch MeasureTime("totaltime"); if (params_.verbose_ == 1) { mexPrintInt("Epoch", (int) epoch + 1); } } // epoch trainerrors_ /= (ftype) numbatches; //mexPrintMsg("Training finished"); }