// Beispiel #1
// 0
/**  Start training/learning a model w.r.t. the loss object (and the data supplied to it).
 *
 *   Standard BMRM main loop:
 *     1. Compute the first linearization (loss and gradient) at w_0.
 *     2. Repeat: minimize the piecewise-linear lower bound of the regularized
 *        risk (inner solver), recompute loss/gradient at the new iterate,
 *        update the best-so-far objective, and check termination.
 *   Under PARALLEL_BMRM, loss/gradient are summed across MPI ranks and the
 *   updated w and the exit flag are broadcast from the root, so every rank
 *   executes the same number of loop iterations (required for the MPI
 *   collectives to match up across ranks).
 */
void CBMRM::Train()
{
   CTimer totalTime;             // total runtime of the training
   CTimer innerSolverTime;       // time for inner optimization (e.g., QP or LP)
   CTimer lossAndGradientTime;   // time for loss and gradient computation
   
   unsigned int iter = 0;        // iteration count
   double loss = 0.0;            // loss function value
   double exactObjVal = 0.0;     // (exact) objective function value
   double approxObjVal = -1e99;  // convex lower-bound (approximate) of objective function value
   double minExactObjVal = 1e99; // minimum of all previously evaluated (exact) objective function value
   double regVal = 0.0;          // value of the regularizer term e.g., 0.5*w'*w
   double epsilon = 0.0;         // := minExactObjVal - approxObjVal       
   double gamma = 0.0;           // := exactObjVal - approxObjVal
   double prevEpsilon = 0.0;
   double innerSolverTol = 1.0;  // optimization tolerance for inner solver
   int exitFlag = 0;
   
   unsigned int row = 0; 
   unsigned int col = 0;
   TheMatrix &w = _model->GetW();
   
   w.Shape(row, col);   
   TheMatrix a(row, col, SML::DENSE);   // gradient vector
   TheMatrix w_best(row,col,SML::DENSE);  // w_t at which pobj is the smallest
   
#ifdef PARALLEL_BMRM
   double someloss = 0.0;                     // temporary loss incurred on each sub-dataset
   double *tmpaw = new double[row*col];       // temporary array for gradient/w
   double *tmpfinalaw = new double[row*col];  // temporary final array for reducing/broadcasting gradient/w
#endif

   // start training
   totalTime.Start();

   // Initialize piecewise linear lower bound of empirical risk
   {
      iter = 1;

      lossAndGradientTime.Start();
      _loss->ComputeLossAndGradient(loss, a);
      lossAndGradientTime.Stop();

#ifdef PARALLEL_BMRM
      MASTER(procID)
#endif
      {
         if(verbosity)
         {
            printf("Initial iteration: computing first linearization at w_0... loss(w_0)=%.6e\n",loss);
            fflush(stdout);
         }
      }


#ifdef PARALLEL_BMRM      
      // Aggregate computed loss value and gradient:
      // serialize local gradient rows into tmpaw, sum loss and gradient
      // over all ranks into the root process.
      for(unsigned int rowidx=0; rowidx<row; rowidx++)
      {
         unsigned int rowlen = 0;
         a.GetRow(rowidx, rowlen, &tmpaw[rowidx*col]);
         assert(rowlen == col);
      }
      memset(tmpfinalaw,0,sizeof(double)*row*col);
      someloss = loss;
      loss = 0.0;
      MPI_Barrier(MPI_COMM_WORLD);
      MPI_Reduce(&someloss, &loss, 1, MPI_DOUBLE, MPI_SUM, ROOT_PROC_ID, MPI_COMM_WORLD);
      MPI_Reduce(tmpaw, tmpfinalaw, row*col, MPI_DOUBLE, MPI_SUM, ROOT_PROC_ID, MPI_COMM_WORLD);   
      for(unsigned int rowidx=0; rowidx<row; rowidx++)
         a.SetRow(rowidx, col, &tmpfinalaw[rowidx*col]);
#endif
   }
   
   do
   {
      iter++;

#ifdef PARALLEL_BMRM
      MASTER(procID)
#endif
      {
         // Minimize piecewise linear lower bound R_t
         innerSolverTime.Start();
         innerSolver->Solve(w, a, loss, approxObjVal);
         innerSolverTime.Stop();
      }

#ifdef PARALLEL_BMRM
      // Broadcast updated w: the root serializes its w into tmpaw (the
      // serialization loop runs on all ranks, but only the root's values
      // survive the broadcast), then every rank deserializes back into w.
      for(unsigned int rowidx=0; rowidx<row; rowidx++)
      {
         unsigned int rowlen = 0;
         w.GetRow(rowidx, rowlen, &tmpaw[rowidx*col]);
         assert(rowlen == col);
      }
      // NOTE(review): removed a dead memset of tmpfinalaw here -- that buffer
      // is not used in this broadcast section.
      MPI_Bcast(tmpaw, row*col, MPI_DOUBLE, ROOT_PROC_ID, MPI_COMM_WORLD);
      for (unsigned int rowidx=0; rowidx<row; rowidx++)
         w.SetRow(rowidx, col, &tmpaw[rowidx*col]);
#endif

      // Compute new linearization with updated w
      lossAndGradientTime.Start();
      _loss->ComputeLossAndGradient(loss, a);
#ifdef LINESEARCH_BMRM
      // With line search, evaluate the objective at the best point found
      // along the search direction rather than at w itself.
      loss = _loss->GetLossOfWbest();
      _loss->GetWbest(w_best);
#endif
      lossAndGradientTime.Stop();

#ifdef PARALLEL_BMRM      
      // Aggregate computed loss value and gradient (same reduce pattern as
      // in the initialization step above)
      for(unsigned int rowidx=0; rowidx<row; rowidx++)
      {
         unsigned int rowlen = 0;
         a.GetRow(rowidx, rowlen, &tmpaw[rowidx*col]);
         assert(rowlen == col);
      }
      memset(tmpfinalaw,0,sizeof(double)*row*col);
      someloss = loss;
      loss = 0.0;
      MPI_Barrier(MPI_COMM_WORLD);
      MPI_Reduce(&someloss, &loss, 1, MPI_DOUBLE, MPI_SUM, ROOT_PROC_ID, MPI_COMM_WORLD);
      MPI_Reduce(tmpaw, tmpfinalaw, row*col, MPI_DOUBLE, MPI_SUM, ROOT_PROC_ID, MPI_COMM_WORLD);   
      for(unsigned int rowidx=0; rowidx<row; rowidx++)
         a.SetRow(rowidx, col, &tmpfinalaw[rowidx*col]);
#endif
      
#ifdef PARALLEL_BMRM
      MASTER(procID)
#endif
      {
         // Update iteration details and keep best minimizer
#ifdef LINESEARCH_BMRM
         regVal = innerSolver->ComputeRegularizerValue(w_best);
         exactObjVal = loss + regVal;
         minExactObjVal = min(minExactObjVal,exactObjVal);
#else
         regVal = innerSolver->ComputeRegularizerValue(w);
         exactObjVal = loss + regVal;
         if(exactObjVal < minExactObjVal)
         {
            minExactObjVal = exactObjVal;
            w_best.Assign(w);
         }
#endif
         prevEpsilon = epsilon;
         gamma = exactObjVal - approxObjVal;     // duality-gap at current iterate
         epsilon = minExactObjVal - approxObjVal; // gap w.r.t. best iterate so far
      
         // Optional: Adjust inner solver optimization tolerance
         //   This reduces the number of iteration most of the time.
         //   This in some sense mimics the proximity control in proximal BM
         AdjustInnerSolverOptTol(innerSolverTol, prevEpsilon, epsilon);

         // Display details of each iteration
         DisplayIterationInfo(iter,exactObjVal,approxObjVal,epsilon,gamma,
                              loss,regVal,totalTime.CurrentCPUTotal());
            
         // Save model obtained in previous iteration
         SaveCheckpointModel(iter);
      
         // Check if termination criteria satisfied
         exitFlag = CheckTermination(iter,epsilon,gamma,minExactObjVal,exactObjVal);
      }

#ifdef PARALLEL_BMRM
      // Broadcast loop termination flag so all ranks leave the loop together
      MPI_Bcast(&exitFlag, 1, MPI_INT, ROOT_PROC_ID, MPI_COMM_WORLD);
#endif
   } while(!exitFlag);

   totalTime.Stop();
   
#ifdef PARALLEL_BMRM
   MASTER(procID)
#endif
   {
      // Display after-training details
      DisplayAfterTrainingInfo(iter,minExactObjVal,approxObjVal,loss,
                               w_best,lossAndGradientTime,innerSolverTime,totalTime);
   }

#ifdef PARALLEL_BMRM
   // FIX: the scratch buffers were allocated with new[] but never released,
   // leaking row*col*2 doubles on every call to Train().
   delete[] tmpaw;
   delete[] tmpfinalaw;
#endif
}