/**
 * Processes a single batch on the singleton graph.
 *
 * Sequence of operations:
 *  1. Build the cost node for `batch` on `graph_` via `builder_`.
 *  2. Run the forward pass, read the scalar cost, run the backward pass.
 *  3. Call the optimizer update; when `scaleLearningRate_` is set, the ratio
 *     of the batch's target word count to `avgBatchWords_` is passed as the
 *     second argument.
 *  4. When `mvAvg_` is set, create `graphAvg_` as a parameter copy of `graph_`
 *     on first use, otherwise fold the current parameters into it with
 *     `updateAvgParams`.
 *  5. When a scheduler is present, report the cost to it and run validation
 *     and saving if the scheduler asks for them.
 *
 * @param batch  The batch to process.
 */
void SingletonGraph::execute(Ptr<data::Batch> batch) {
  auto costExpr = builder_->build(graph_, batch);

  graph_->forward();
  float costValue = costExpr->scalar();
  graph_->backward();

  // Target-side word count of this batch; used only for the scaled update.
  size_t targetWords = batch->wordsTrg();

  if(!scaleLearningRate_) {
    opt_->update(graph_);
  } else {
    opt_->update(graph_, targetWords / avgBatchWords_);
  }

  if(mvAvg_) {
    // The averaging update below needs the scheduler's batch counter.
    ABORT_IF(!scheduler_, "Scheduler is required for exponential smoothing");

    if(graphAvg_) {
      updateAvgParams(graphAvg_->params()->vals(),
                      graph_->params()->vals(),
                      scheduler_->numberOfBatches());
    } else {
      // First averaged step: start from a copy of the current parameters
      // placed on the same device as the training graph.
      graphAvg_ = New<ExpressionGraph>();
      graphAvg_->setDevice(graph_->getDeviceId());
      graphAvg_->copyParams(graph_);
    }
  }

  // Everything below is scheduler-driven bookkeeping.
  if(!scheduler_) {
    return;
  }

  scheduler_->update(costValue, batch);

  if(scheduler_->validating()) {
    if(mvAvg_) {
      // Validate on the averaged parameters, reusing the training graph's
      // workspace.
      graphAvg_->reuseWorkspace(graph_);
      scheduler_->validate({graphAvg_});
    } else {
      scheduler_->validate({graph_});
    }
  }

  if(scheduler_->saving()) {
    this->save();
  }
}
/**
 * Processes a single batch on the singleton graph.
 *
 * Sequence of operations:
 *  1. Build the loss node for `batch` on `graph_` via `builder_`.
 *  2. Run the forward and backward passes.
 *  3. Call the optimizer update on `graph_`.
 *  4. When `mvAvg_` is set, create `graphAvg_` as a parameter copy of `graph_`
 *     on first use, otherwise fold the current parameters into it with
 *     `updateAvgParams`.
 *  5. When a scheduler is present, hand it the dereferenced loss node and the
 *     batch, then run validation and saving if the scheduler asks for them.
 *
 * @param batch  The batch to process.
 */
void SingletonGraph::execute(Ptr<data::Batch> batch) {
  auto lossExpr = builder_->build(graph_, batch);

  graph_->forward();
  graph_->backward();

  opt_->update(graph_);

  if(mvAvg_) {
    // The averaging update below needs the scheduler's batch counter.
    ABORT_IF(!scheduler_, "Scheduler is required for exponential smoothing");

    if(graphAvg_) {
      updateAvgParams(graphAvg_->params()->vals(),
                      graph_->params()->vals(),
                      scheduler_->numberOfBatches());
    } else {
      // First averaged step: start from a copy of the current parameters
      // placed on the same device as the training graph.
      graphAvg_ = New<ExpressionGraph>();
      graphAvg_->setDevice(graph_->getDeviceId());
      graphAvg_->copyParams(graph_);
    }
  }

  // Everything below is scheduler-driven bookkeeping.
  if(!scheduler_) {
    return;
  }

  scheduler_->update(*lossExpr, batch);

  if(scheduler_->validating()) {
    if(mvAvg_) {
      // Validate on the averaged parameters, reusing the training graph's
      // workspace.
      graphAvg_->reuseWorkspace(graph_);
      scheduler_->validate({graphAvg_});
    } else {
      scheduler_->validate({graph_});
    }
  }

  if(scheduler_->saving()) {
    this->save();
  }
}