Ejemplo n.º 1
0
int NormalizeScale::normalizeWeights() {
   int status = PV_SUCCESS;

   assert(numConnections >= 1);

   // All connections in the group must have the same values of sharedWeights, numArbors, and numDataPatches
   HyPerConn * conn0 = connectionList[0];

#ifdef USE_SHMGET
#ifdef PV_USE_MPI
   if (conn->getShmgetFlag() && !callingConn->getShmgetOwner(0)) { // Assumes that all arbors are owned by the same process
      MPI_Barrier(conn->getParent()->icCommunicator()->communicator());
      return status;
   }
#endif // PV_USE_MPI
#endif // USE_SHMGET

   float scale_factor = strength;

   status = NormalizeMultiply::normalizeWeights(); // applies normalize_cutoff threshold and symmetrizeWeights

   int nxp = conn0->xPatchSize();
   int nyp = conn0->yPatchSize();
   int nfp = conn0->fPatchSize();
   int nxpShrunken = conn0->getNxpShrunken();
   int nypShrunken = conn0->getNypShrunken();
   int offsetShrunken = conn0->getOffsetShrunken();
   int xPatchStride = conn0->xPatchStride();
   int yPatchStride = conn0->yPatchStride();
   int weights_per_patch = nxp*nyp*nfp;
   int nArbors = conn0->numberOfAxonalArborLists();
   int numDataPatches = conn0->getNumDataPatches();

   for (int patchindex = 0; patchindex<numDataPatches; patchindex++) {
      for (int arborID = 0; arborID<nArbors; arborID++) {
         for (int c=0; c<numConnections; c++) {
            HyPerConn * conn = connectionList[c];
            pvwdata_t * dataStartPatch = conn->get_wDataStart(arborID)+patchindex*weights_per_patch;
            normalizePatch(dataStartPatch, weights_per_patch, scale_factor);
         }
      }
   }
#ifdef OBSOLETE // Marked obsolete Dec 9, 2014.
#ifdef USE_SHMGET
#ifdef PV_USE_MPI
   if (conn->getShmgetFlag()) {
      assert(conn->getShmgetOwner(0)); // Assumes that all arbors are owned by the same process
      MPI_Barrier(conn->getParent()->icCommunicator()->communicator());
   }
#endif // PV_USE_MPI
#endif // USE_SHMGET
#endif // OBSOLETE
   return status;
}
/**
 * Normalizes the connection group's weights to zero mean and a fixed spread, pooling statistics
 * across all connections in the group.
 *
 * Steps:
 * - Delegate to NormalizeBase::normalizeWeights(). Per its call-site comment, this applies the
 *   normalize_cutoff threshold and symmetrizeWeights.
 * - For each data patch index, compute the sum and sum of squares of that patch over every
 *   connection in the group. This is done per arbor when normalizeArborsIndividually is set,
 *   otherwise over all arbors together.
 * - Call subtractOffsetAndNormalize() on every contributing patch, with the pooled mean as the
 *   offset and sqrt(variance)/strength as the normalizer. The helper presumably computes
 *   (w - mean) / normalizer, giving standard deviation `strength`; confirm against it.
 *
 * Each connection contributes its own nxp*nyp*nfp weights per patch, and the pooled count is the
 * exact number of weights accumulated. Patches whose pooled sum is within minSumTolerated of zero
 * are left unchanged with a warning, and processing continues with the next patch.
 *
 * If the connections disagree on the number of arbors or data patches, the mismatches are
 * reported from column 0, all processes hit an MPI barrier, and the program exits with
 * EXIT_FAILURE.
 *
 * @return the status returned by NormalizeBase::normalizeWeights().
 */
int NormalizeContrastZeroMean::normalizeWeights() {
   int status = PV_SUCCESS;

   assert(numConnections >= 1);

   // Statistics are pooled over patch `patchindex` of every connection, so the connections must
   // agree on arbor and data-patch counts. Patch sizes may differ: each connection's own is used.
   HyPerConn * conn0 = connectionList[0];
   for (int c=1; c<numConnections; c++) {
      HyPerConn * conn = connectionList[c];
      if (conn->numberOfAxonalArborLists() != conn0->numberOfAxonalArborLists()) {
         if (parent->columnId() == 0) {
            pvErrorNoExit().printf("Normalizer %s: All connections in the normalization group must have the same number of arbors (Connection \"%s\" has %d; connection \"%s\" has %d).\n",
                  this->getName(), conn0->getName(), conn0->numberOfAxonalArborLists(), conn->getName(), conn->numberOfAxonalArborLists());
         }
         status = PV_FAILURE;
      }
      if (conn->getNumDataPatches() != conn0->getNumDataPatches()) {
         if (parent->columnId() == 0) {
            pvErrorNoExit().printf("Normalizer %s: All connections in the normalization group must have the same number of data patches (Connection \"%s\" has %d; connection \"%s\" has %d).\n",
                  this->getName(), conn0->getName(), conn0->getNumDataPatches(), conn->getName(), conn->getNumDataPatches());
         }
         status = PV_FAILURE;
      }
   }
   // Check once after the loop so every mismatch is reported before all processes exit together.
   if (status==PV_FAILURE) {
      MPI_Barrier(parent->icCommunicator()->communicator());
      exit(EXIT_FAILURE);
   }

   float scale_factor = strength;

   status = NormalizeBase::normalizeWeights(); // applies normalize_cutoff threshold and symmetrizeWeights

   int nArbors = conn0->numberOfAxonalArborLists();
   int numDataPatches = conn0->getNumDataPatches();
   if (normalizeArborsIndividually) {
      for (int arborID = 0; arborID<nArbors; arborID++) {
         for (int patchindex = 0; patchindex<numDataPatches; patchindex++) {
            double sum = 0.0;
            double sumsq = 0.0;
            int count = 0; // number of weights pooled across the group for this patch
            for (int c=0; c<numConnections; c++) {
               HyPerConn * conn = connectionList[c];
               int weights_per_patch = conn->xPatchSize()*conn->yPatchSize()*conn->fPatchSize();
               pvwdata_t * dataStartPatch = conn->get_wDataStart(arborID) + patchindex * weights_per_patch;
               accumulateSumAndSumSquared(dataStartPatch, weights_per_patch, &sum, &sumsq);
               count += weights_per_patch;
            }
            // NOTE(review): this guards on the pooled sum; the normalizer below vanishes when the
            // variance is ~0, which this test does not catch. Confirm the intended criterion.
            if (fabs(sum) <= minSumTolerated) {
               pvWarn().printf("for NormalizeContrastZeroMean \"%s\": sum of weights in patch %d of arbor %d is within minSumTolerated=%f of zero. Weights in this patch unchanged.\n", this->getName(), patchindex, arborID, minSumTolerated);
               continue; // skip only this patch, as the warning states
            }
            double mean = sum/count;
            double var = sumsq/count - mean*mean;
            for (int c=0; c<numConnections; c++) {
               HyPerConn * conn = connectionList[c];
               int weights_per_patch = conn->xPatchSize()*conn->yPatchSize()*conn->fPatchSize();
               pvwdata_t * dataStartPatch = conn->get_wDataStart(arborID) + patchindex * weights_per_patch;
               subtractOffsetAndNormalize(dataStartPatch, weights_per_patch, mean, sqrt(var)/scale_factor);
            }
         }
      }
   }
   else {
      for (int patchindex = 0; patchindex<numDataPatches; patchindex++) {
         double sum = 0.0;
         double sumsq = 0.0;
         int count = 0; // number of weights pooled across all arbors and connections for this patch
         for (int arborID = 0; arborID<nArbors; arborID++) {
            for (int c=0; c<numConnections; c++) {
               HyPerConn * conn = connectionList[c];
               int weights_per_patch = conn->xPatchSize()*conn->yPatchSize()*conn->fPatchSize();
               pvwdata_t * dataStartPatch = conn->get_wDataStart(arborID)+patchindex*weights_per_patch;
               accumulateSumAndSumSquared(dataStartPatch, weights_per_patch, &sum, &sumsq);
               count += weights_per_patch;
            }
         }
         // NOTE(review): see the matching guard in the per-arbor branch above this else.
         if (fabs(sum) <= minSumTolerated) {
            pvWarn().printf("for NormalizeContrastZeroMean \"%s\": sum of weights in patch %d is within minSumTolerated=%f of zero. Weights in this patch unchanged.\n", getName(), patchindex, minSumTolerated);
            continue; // skip only this patch, as the warning states
         }
         double mean = sum/count;
         double var = sumsq/count - mean*mean;
         for (int arborID = 0; arborID<nArbors; arborID++) {
            for (int c=0; c<numConnections; c++) {
               HyPerConn * conn = connectionList[c];
               int weights_per_patch = conn->xPatchSize()*conn->yPatchSize()*conn->fPatchSize();
               pvwdata_t * dataStartPatch = conn->get_wDataStart(arborID)+patchindex*weights_per_patch;
               subtractOffsetAndNormalize(dataStartPatch, weights_per_patch, mean, sqrt(var)/scale_factor);
            }
         }
      }
   }

   return status;
}