예제 #1
0
int customexit(HyPerCol * hc, int argc, char ** argv) {
   BaseConnection * baseConn;
   baseConn = hc->getConnFromName("initializeFromInitWeights");
   HyPerConn * initializeFromInitWeightsConn = dynamic_cast<HyPerConn *>(baseConn);
   // There must be a connection named initializeFromInitWeights. It should have a single weight with value 1
   assert(initializeFromInitWeightsConn);
   assert(initializeFromInitWeightsConn->xPatchSize()==1);
   assert(initializeFromInitWeightsConn->yPatchSize()==1);
   assert(initializeFromInitWeightsConn->fPatchSize()==1);
   assert(initializeFromInitWeightsConn->numberOfAxonalArborLists()==1);
   assert(initializeFromInitWeightsConn->get_wData(0,0)[0] == 1.0f);

   // There must be a connection named initializeFromCheckpoint.  It should have a single weight with value 2
   baseConn = hc->getConnFromName("initializeFromCheckpoint");
   HyPerConn * initializeFromCheckpointConn = dynamic_cast<HyPerConn *>(baseConn);
   assert(initializeFromCheckpointConn);
   assert(initializeFromCheckpointConn->xPatchSize()==1);
   assert(initializeFromCheckpointConn->yPatchSize()==1);
   assert(initializeFromCheckpointConn->fPatchSize()==1);
   assert(initializeFromCheckpointConn->numberOfAxonalArborLists()==1);
   assert(initializeFromCheckpointConn->get_wData(0,0)[0] == 2.0f);
   return PV_SUCCESS;
}
/**
 * Normalizes the weights of every connection in the normalization group to zero mean.
 * Each pool of weights is rescaled by the standard deviation of that pool, divided by `strength`.
 *
 * Steps:
 *   1. Check that all connections in the group have the same number of arbors and data patches.
 *      If not, print the mismatches on rank 0, hit an MPI barrier, and exit with EXIT_FAILURE.
 *   2. Call NormalizeBase::normalizeWeights().
 *   3. Pool statistics:
 *        - if normalizeArborsIndividually is true, per (arbor, data patch) pair across all
 *          connections;
 *        - otherwise, per data patch across all arbors and all connections.
 *      For each pool, subtractOffsetAndNormalize() is called on every patch in the pool. The
 *      offset is the pooled mean and the normalizer is sqrt(pooled variance)/strength.
 *      A pool whose weight sum is within minSumTolerated of zero is skipped with a warning.
 *
 * Data patch p of arbor a of a connection is addressed at get_wDataStart(a) + p * patchSize,
 * where patchSize = nxp*nyp*nfp of *that* connection. Patch sizes may therefore differ between
 * connections; only the arbor count and data-patch count must agree.
 *
 * @return the status returned by NormalizeBase::normalizeWeights().
 */
int NormalizeContrastZeroMean::normalizeWeights() {
   int status = PV_SUCCESS;

   assert(numConnections >= 1);

   HyPerConn * conn0 = connectionList[0];
   for (int c=1; c<numConnections; c++) {
      HyPerConn * conn = connectionList[c];
      if (conn->numberOfAxonalArborLists() != conn0->numberOfAxonalArborLists()) {
         if (parent->columnId() == 0) {
            pvErrorNoExit().printf("Normalizer %s: All connections in the normalization group must have the same number of arbors (Connection \"%s\" has %d; connection \"%s\" has %d).\n",
                  this->getName(), conn0->getName(), conn0->numberOfAxonalArborLists(), conn->getName(), conn->numberOfAxonalArborLists());
         }
         status = PV_FAILURE;
      }
      if (conn->getNumDataPatches() != conn0->getNumDataPatches()) {
         if (parent->columnId() == 0) {
            pvErrorNoExit().printf("Normalizer %s: All connections in the normalization group must have the same number of data patches (Connection \"%s\" has %d; connection \"%s\" has %d).\n",
                  this->getName(), conn0->getName(), conn0->getNumDataPatches(), conn->getName(), conn->getNumDataPatches());
         }
         status = PV_FAILURE;
      }
   }
   // Checked once after the loop so every mismatched connection is reported before exiting.
   if (status==PV_FAILURE) {
      MPI_Barrier(parent->icCommunicator()->communicator());
      exit(EXIT_FAILURE);
   }

   float scale_factor = strength;

   status = NormalizeBase::normalizeWeights(); // applies normalize_cutoff threshold and symmetrizeWeights

   int nArbors = conn0->numberOfAxonalArborLists();
   int numDataPatches = conn0->getNumDataPatches();

   // NOTE(review): var can come out slightly negative through rounding when a pool's weights are
   // (nearly) constant, making sqrt(var) NaN. A zero var gives a zero normalizer. Confirm how
   // subtractOffsetAndNormalize handles these cases.
   if (normalizeArborsIndividually) {
      for (int arborID = 0; arborID<nArbors; arborID++) {
         for (int patchindex = 0; patchindex<numDataPatches; patchindex++) {
            double sum = 0.0;
            double sumsq = 0.0;
            int count = 0; // number of weights pooled over all connections
            for (int c=0; c<numConnections; c++) {
               HyPerConn * conn = connectionList[c];
               int const weights_per_patch = conn->xPatchSize()*conn->yPatchSize()*conn->fPatchSize();
               pvwdata_t * dataStartPatch = conn->get_wDataStart(arborID) + patchindex * weights_per_patch;
               accumulateSumAndSumSquared(dataStartPatch, weights_per_patch, &sum, &sumsq);
               count += weights_per_patch;
            }
            if (fabs(sum) <= minSumTolerated) {
               pvWarn().printf("for NormalizeContrastZeroMean \"%s\": sum of weights in patch %d of arbor %d is within minSumTolerated=%f of zero. Weights in this patch unchanged.\n", this->getName(), patchindex, arborID, minSumTolerated);
               continue; // skip only this patch, as the warning states
            }
            double const mean = sum/count;
            double const var = sumsq/count - mean*mean;
            for (int c=0; c<numConnections; c++) {
               HyPerConn * conn = connectionList[c];
               int const weights_per_patch = conn->xPatchSize()*conn->yPatchSize()*conn->fPatchSize();
               pvwdata_t * dataStartPatch = conn->get_wDataStart(arborID) + patchindex * weights_per_patch;
               subtractOffsetAndNormalize(dataStartPatch, weights_per_patch, mean, sqrt(var)/scale_factor);
            }
         }
      }
   }
   else {
      for (int patchindex = 0; patchindex<numDataPatches; patchindex++) {
         double sum = 0.0;
         double sumsq = 0.0;
         int count = 0; // number of weights pooled over all arbors and all connections
         for (int arborID = 0; arborID<nArbors; arborID++) {
            for (int c=0; c<numConnections; c++) {
               HyPerConn * conn = connectionList[c];
               int const weights_per_patch = conn->xPatchSize()*conn->yPatchSize()*conn->fPatchSize();
               pvwdata_t * dataStartPatch = conn->get_wDataStart(arborID)+patchindex*weights_per_patch;
               accumulateSumAndSumSquared(dataStartPatch, weights_per_patch, &sum, &sumsq);
               count += weights_per_patch;
            }
         }
         if (fabs(sum) <= minSumTolerated) {
            pvWarn().printf("for NormalizeContrastZeroMean \"%s\": sum of weights in patch %d is within minSumTolerated=%f of zero. Weights in this patch unchanged.\n", getName(), patchindex, minSumTolerated);
            continue; // skip only this patch, as the warning states
         }
         double const mean = sum/count;
         double const var = sumsq/count - mean*mean;
         for (int arborID = 0; arborID<nArbors; arborID++) {
            for (int c=0; c<numConnections; c++) {
               HyPerConn * conn = connectionList[c];
               int const weights_per_patch = conn->xPatchSize()*conn->yPatchSize()*conn->fPatchSize();
               pvwdata_t * dataStartPatch = conn->get_wDataStart(arborID)+patchindex*weights_per_patch;
               subtractOffsetAndNormalize(dataStartPatch, weights_per_patch, mean, sqrt(var)/scale_factor);
            }
         }
      }
   }

   return status;
}