int customexit(HyPerCol * hc, int argc, char ** argv) {
   BaseConnection * baseConn;

   // There must be a connection named initializeFromInitWeights.
   // It should have a single weight with value 1.
   baseConn = hc->getConnFromName("initializeFromInitWeights");
   HyPerConn * initializeFromInitWeightsConn = dynamic_cast<HyPerConn *>(baseConn);
   assert(initializeFromInitWeightsConn);
   assert(initializeFromInitWeightsConn->xPatchSize()==1);
   assert(initializeFromInitWeightsConn->yPatchSize()==1);
   assert(initializeFromInitWeightsConn->fPatchSize()==1);
   assert(initializeFromInitWeightsConn->numberOfAxonalArborLists()==1);
   assert(initializeFromInitWeightsConn->get_wData(0,0)[0] == 1.0f);

   // There must be a connection named initializeFromCheckpoint.
   // It should have a single weight with value 2.
   baseConn = hc->getConnFromName("initializeFromCheckpoint");
   HyPerConn * initializeFromCheckpointConn = dynamic_cast<HyPerConn *>(baseConn);
   assert(initializeFromCheckpointConn);
   assert(initializeFromCheckpointConn->xPatchSize()==1);
   assert(initializeFromCheckpointConn->yPatchSize()==1);
   assert(initializeFromCheckpointConn->fPatchSize()==1);
   assert(initializeFromCheckpointConn->numberOfAxonalArborLists()==1);
   assert(initializeFromCheckpointConn->get_wData(0,0)[0] == 2.0f);

   return PV_SUCCESS;
}
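A hook with this signature is normally handed to buildandrun() as the custom exit function, so the assertions run after the HyPerCol finishes its run. The sketch below shows that wiring; the exact buildandrun() overload and the <columns/buildandrun.hpp> header follow the pattern of older PetaVision system tests and should be treated as assumptions, not as part of this file.

// Hypothetical test driver, assuming the classic buildandrun() overload that
// takes custominit and customexit hooks.
#include <columns/buildandrun.hpp>
#include <cstdlib>

int customexit(HyPerCol * hc, int argc, char ** argv); // defined above

int main(int argc, char * argv[]) {
   // NULL custominit: no setup hook; customexit() runs after the run completes.
   int status = buildandrun(argc, argv, NULL, &customexit);
   return status==PV_SUCCESS ? EXIT_SUCCESS : EXIT_FAILURE;
}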
int NormalizeContrastZeroMean::normalizeWeights() {
   int status = PV_SUCCESS;

   assert(numConnections >= 1);

   // All connections in the normalization group must agree on the number of
   // arbors and data patches; these are verified below.
   // TODO: also verify that nxp, nyp, and nfp agree across connections.
   HyPerConn * conn0 = connectionList[0];
   for (int c=1; c<numConnections; c++) {
      HyPerConn * conn = connectionList[c];
      if (conn->numberOfAxonalArborLists() != conn0->numberOfAxonalArborLists()) {
         if (parent->columnId() == 0) {
            pvErrorNoExit().printf("Normalizer %s: All connections in the normalization group must have the same number of arbors (Connection \"%s\" has %d; connection \"%s\" has %d).\n",
                  this->getName(), conn0->getName(), conn0->numberOfAxonalArborLists(), conn->getName(), conn->numberOfAxonalArborLists());
         }
         status = PV_FAILURE;
      }
      if (conn->getNumDataPatches() != conn0->getNumDataPatches()) {
         if (parent->columnId() == 0) {
            pvErrorNoExit().printf("Normalizer %s: All connections in the normalization group must have the same number of data patches (Connection \"%s\" has %d; connection \"%s\" has %d).\n",
                  this->getName(), conn0->getName(), conn0->getNumDataPatches(), conn->getName(), conn->getNumDataPatches());
         }
         status = PV_FAILURE;
      }
      if (status==PV_FAILURE) {
         MPI_Barrier(parent->icCommunicator()->communicator());
         exit(EXIT_FAILURE);
      }
   }

   float scale_factor = strength;

   status = NormalizeBase::normalizeWeights(); // applies normalize_cutoff threshold and symmetrizeWeights

   int nArbors = conn0->numberOfAxonalArborLists();
   int numDataPatches = conn0->getNumDataPatches();

   // Total number of weights in one patch, summed over all connections in the
   // group; each connection contributes nxp*nyp*nfp weights.  Previously this
   // total was accumulated inside the per-connection loops and reused as the
   // patch offset and length, which over-counted whenever numConnections > 1
   // (or, in the joint branch, nArbors > 1); per-connection patch sizes are
   // now used for addressing instead.
   int weights_per_patch = 0;
   for (int c=0; c<numConnections; c++) {
      HyPerConn * conn = connectionList[c];
      weights_per_patch += conn->xPatchSize()*conn->yPatchSize()*conn->fPatchSize();
   }

   if (normalizeArborsIndividually) {
      for (int arborID = 0; arborID<nArbors; arborID++) {
         for (int patchindex = 0; patchindex<numDataPatches; patchindex++) {
            double sum = 0.0;
            double sumsq = 0.0;
            for (int c=0; c<numConnections; c++) {
               HyPerConn * conn = connectionList[c];
               int patch_size = conn->xPatchSize()*conn->yPatchSize()*conn->fPatchSize();
               // Patches are stored contiguously within an arbor, patch_size weights each.
               pvwdata_t * dataStartPatch = conn->get_wDataStart(arborID) + patchindex*patch_size;
               accumulateSumAndSumSquared(dataStartPatch, patch_size, &sum, &sumsq);
            }
            if (fabs(sum) <= minSumTolerated) {
               pvWarn().printf("for NormalizeContrastZeroMean \"%s\": sum of weights in patch %d of arbor %d is within minSumTolerated=%f of zero. Weights in this patch unchanged.\n",
                     this->getName(), patchindex, arborID, minSumTolerated);
               continue; // Skip only this patch, as the warning promises; remaining patches are still normalized.
            }
            float mean = sum/weights_per_patch;
            float var = sumsq/weights_per_patch - mean*mean;
            for (int c=0; c<numConnections; c++) {
               HyPerConn * conn = connectionList[c];
               int patch_size = conn->xPatchSize()*conn->yPatchSize()*conn->fPatchSize();
               pvwdata_t * dataStartPatch = conn->get_wDataStart(arborID) + patchindex*patch_size;
               subtractOffsetAndNormalize(dataStartPatch, patch_size, mean, sqrt(var)/scale_factor);
            }
         }
      }
   }
   else {
      for (int patchindex = 0; patchindex<numDataPatches; patchindex++) {
         double sum = 0.0;
         double sumsq = 0.0;
         for (int arborID = 0; arborID<nArbors; arborID++) {
            for (int c=0; c<numConnections; c++) {
               HyPerConn * conn = connectionList[c];
               int patch_size = conn->xPatchSize()*conn->yPatchSize()*conn->fPatchSize();
               pvwdata_t * dataStartPatch = conn->get_wDataStart(arborID) + patchindex*patch_size;
               accumulateSumAndSumSquared(dataStartPatch, patch_size, &sum, &sumsq);
            }
         }
         if (fabs(sum) <= minSumTolerated) {
            pvWarn().printf("for NormalizeContrastZeroMean \"%s\": sum of weights in patch %d is within minSumTolerated=%f of zero. Weights in this patch unchanged.\n",
                  getName(), patchindex, minSumTolerated);
            continue; // Skip only this patch, as the warning promises; remaining patches are still normalized.
         }
         int count = weights_per_patch*nArbors;
         float mean = sum/count;
         float var = sumsq/count - mean*mean;
         for (int arborID = 0; arborID<nArbors; arborID++) {
            for (int c=0; c<numConnections; c++) {
               HyPerConn * conn = connectionList[c];
               int patch_size = conn->xPatchSize()*conn->yPatchSize()*conn->fPatchSize();
               pvwdata_t * dataStartPatch = conn->get_wDataStart(arborID) + patchindex*patch_size;
               subtractOffsetAndNormalize(dataStartPatch, patch_size, mean, sqrt(var)/scale_factor);
            }
         }
      }
   }
   return status;
}
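The helpers accumulateSumAndSumSquared() and subtractOffsetAndNormalize() are defined elsewhere (in NormalizeBase). The self-contained sketch below shows the arithmetic this function relies on them for, under the assumption that they accumulate sum(w) and sum(w*w) and compute w[i] = (w[i] - offset)/normalizer respectively; with normalizer = sigma/strength, each patch ends with zero mean and standard deviation equal to strength. The normalizePatch() helper and plain float buffers are illustrative stand-ins, not PetaVision API.

#include <cmath>
#include <cstdio>

// Illustrative stand-in for the per-patch work done above.  Caller must guard
// against near-zero variance, as normalizeWeights() does via minSumTolerated.
static void normalizePatch(float * w, int count, float strength) {
   double sum = 0.0, sumsq = 0.0;
   for (int i = 0; i < count; i++) { sum += w[i]; sumsq += (double)w[i]*w[i]; }
   double mean = sum/count;
   double var = sumsq/count - mean*mean;         // E[w^2] - (E[w])^2
   double normalizer = std::sqrt(var)/strength;  // sigma / scale_factor
   for (int i = 0; i < count; i++) {
      w[i] = (float)((w[i] - mean)/normalizer);  // zero mean, std = strength
   }
}

int main() {
   float w[4] = {1.0f, 2.0f, 3.0f, 6.0f};
   normalizePatch(w, 4, 1.0f /*strength*/);
   double check_sum = 0.0, check_sumsq = 0.0;
   for (float v : w) { check_sum += v; check_sumsq += (double)v*v; }
   std::printf("mean=%g std=%g\n", check_sum/4, std::sqrt(check_sumsq/4)); // ~0 and 1
   return 0;
}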