Пример #1
0
// Convenience factory for an Adagrad trainer.
//
// Heap-allocates a new Adagrad bound to the given EasyCL context and applies
// `learningRate` to it before handing it back.
//
// @param cl            EasyCL context passed straight to the Adagrad constructor.
// @param learningRate  Value forwarded to setLearningRate() on the new trainer.
// @return A pointer obtained from `new`; it is not retained here, so the
//         caller is responsible for eventually deleting it.
STATIC Adagrad *Adagrad::instance(EasyCL *cl, float learningRate) {
    Adagrad *const result = new Adagrad(cl);
    result->setLearningRate(learningRate);
    return result;
}
Пример #2
0
void go(Config config) {
    Timer timer;

    int Ntrain;
    int Ntest;
    int numPlanes;
    int imageSize;

    float *trainData = 0;
    float *testData = 0;
    int *trainLabels = 0;
    int *testLabels = 0;

    int trainAllocateN = 0;
    int testAllocateN = 0;

//    int totalLinearSize;
    GenericLoaderv2 trainLoader( config.dataDir + "/" + config.trainFile );
    Ntrain = trainLoader.getN();
    numPlanes = trainLoader.getPlanes();
    imageSize = trainLoader.getImageSize();
    // GenericLoader::getDimensions( , &Ntrain, &numPlanes, &imageSize );
    Ntrain = config.numTrain == -1 ? Ntrain : config.numTrain;
//    long allocateSize = (long)Ntrain * numPlanes * imageSize * imageSize;
    cout << "Ntrain " << Ntrain << " numPlanes " << numPlanes << " imageSize " << imageSize << endl;
    if( config.loadOnDemand ) {
        trainAllocateN = config.batchSize; // can improve this later
    } else {
        trainAllocateN = Ntrain;
    }
    trainData = new float[ (long)trainAllocateN * numPlanes * imageSize * imageSize ];
    trainLabels = new int[trainAllocateN];
    if( !config.loadOnDemand && Ntrain > 0 ) {
        trainLoader.load( trainData, trainLabels, 0, Ntrain );
    }

    GenericLoaderv2 testLoader( config.dataDir + "/" + config.validateFile );
    Ntest = testLoader.getN();
    numPlanes = testLoader.getPlanes();
    imageSize = testLoader.getImageSize();
    Ntest = config.numTest == -1 ? Ntest : config.numTest;
    if( config.loadOnDemand ) {
        testAllocateN = config.batchSize; // can improve this later
    } else {
        testAllocateN = Ntest;
    }
    testData = new float[ (long)testAllocateN * numPlanes * imageSize * imageSize ];
    testLabels = new int[testAllocateN]; 
    if( !config.loadOnDemand && Ntest > 0 ) {
        testLoader.load( testData, testLabels, 0, Ntest );
    }
    cout << "Ntest " << Ntest << " Ntest" << endl;
    
    timer.timeCheck("after load images");

    const int inputCubeSize = numPlanes * imageSize * imageSize;
    float translate;
    float scale;
    int normalizationExamples = config.normalizationExamples > Ntrain ? Ntrain : config.normalizationExamples;
    if( !config.loadOnDemand ) {
        if( config.normalization == "stddev" ) {
            float mean, stdDev;
            NormalizationHelper::getMeanAndStdDev( trainData, normalizationExamples * inputCubeSize, &mean, &stdDev );
            cout << " image stats mean " << mean << " stdDev " << stdDev << endl;
            translate = - mean;
            scale = 1.0f / stdDev / config.normalizationNumStds;
        } else if( config.normalization == "maxmin" ) {
            float mean, stdDev;
            NormalizationHelper::getMinMax( trainData, normalizationExamples * inputCubeSize, &mean, &stdDev );
            translate = - mean;
            scale = 1.0f / stdDev;
        } else {
            cout << "Error: Unknown normalization: " << config.normalization << endl;
            return;
        }
    } else {
        if( config.normalization == "stddev" ) {
            float mean, stdDev;
            NormalizeGetStdDev normalizeGetStdDev( trainData, trainLabels ); 
            BatchProcessv2::run( &trainLoader, 0, config.batchSize, normalizationExamples, inputCubeSize, &normalizeGetStdDev );
            normalizeGetStdDev.calcMeanStdDev( &mean, &stdDev );
            cout << " image stats mean " << mean << " stdDev " << stdDev << endl;
            translate = - mean;
            scale = 1.0f / stdDev / config.normalizationNumStds;
        } else if( config.normalization == "maxmin" ) {
            NormalizeGetMinMax normalizeGetMinMax( trainData, trainLabels );
            BatchProcessv2::run( &trainLoader, 0, config.batchSize, normalizationExamples, inputCubeSize, &normalizeGetMinMax );
            normalizeGetMinMax.calcMinMaxTransform( &translate, &scale );
        } else {
            cout << "Error: Unknown normalization: " << config.normalization << endl;
            return;
        }
    }
    cout << " image norm translate " << translate << " scale " << scale << endl;
    timer.timeCheck("after getting stats");

//    const int numToTrain = Ntrain;
//    const int batchSize = config.batchSize;

    EasyCL *cl = 0;
    if( config.gpuIndex >= 0 ) {
        cl = EasyCL::createForIndexedGpu( config.gpuIndex );
    } else {
        cl = EasyCL::createForFirstGpuOtherwiseCpu();
    }

    NeuralNet *net;
    net = new NeuralNet(cl);

    WeightsInitializer *weightsInitializer = 0;
    if( toLower( config.weightsInitializer ) == "original" ) {
        weightsInitializer = new OriginalInitializer();
    } else if( toLower( config.weightsInitializer ) == "uniform" ) {
        weightsInitializer = new UniformInitializer( config.initialWeights );
    } else {
        cout << "Unknown weights initializer " << config.weightsInitializer << endl;
        return;
    }

//    net->inputMaker<unsigned char>()->numPlanes(numPlanes)->imageSize(imageSize)->insert();
    net->addLayer( InputLayerMaker::instance()->numPlanes(numPlanes)->imageSize(imageSize) );
    net->addLayer( NormalizationLayerMaker::instance()->translate(translate)->scale(scale) );
    if( !NetdefToNet::createNetFromNetdef( net, config.netDef, weightsInitializer ) ) {
        return;
    }
    // apply the trainer
    Trainer *trainer = 0;
    if( toLower( config.trainer ) == "sgd" ) {
        SGD *sgd = new SGD( cl );
        sgd->setLearningRate( config.learningRate );
        sgd->setMomentum( config.momentum );
        sgd->setWeightDecay( config.weightDecay );
        trainer = sgd;
    } else if( toLower( config.trainer ) == "anneal" ) {
        Annealer *annealer = new Annealer( cl );
        annealer->setLearningRate( config.learningRate );
        annealer->setAnneal( config.anneal );
        trainer = annealer;
    } else if( toLower( config.trainer ) == "nesterov" ) {
        Nesterov *nesterov = new Nesterov( cl );
        nesterov->setLearningRate( config.learningRate );
        nesterov->setMomentum( config.momentum );
        trainer = nesterov;
    } else if( toLower( config.trainer ) == "adagrad" ) {
        Adagrad *adagrad = new Adagrad( cl );
        adagrad->setLearningRate( config.learningRate );
        trainer = adagrad;
    } else if( toLower( config.trainer ) == "rmsprop" ) {
        Rmsprop *rmsprop = new Rmsprop( cl );
        rmsprop->setLearningRate( config.learningRate );
        trainer = rmsprop;
    } else if( toLower( config.trainer ) == "adadelta" ) {
        Adadelta *adadelta = new Adadelta( cl, config.rho );
        trainer = adadelta;
    } else {
        cout << "trainer " << config.trainer << " unknown." << endl;
        return;
    }
    cout << "Using trainer " << trainer->asString() << endl;
//    trainer->bindTo( net );
//    net->setTrainer( trainer );
    net->setBatchSize( config.batchSize );
    net->print();

    bool afterRestart = false;
    int restartEpoch = 0;
    int restartBatch = 0;
    float restartAnnealedLearningRate = 0;
    int restartNumRight = 0;
    float restartLoss = 0;
    if( config.loadWeights && config.weightsFile != "" ) {
        cout << "loadingweights" << endl;
        afterRestart = WeightsPersister::loadWeights( config.weightsFile, config.getTrainingString(), net, &restartEpoch, &restartBatch, &restartAnnealedLearningRate, &restartNumRight, &restartLoss );
        if( !afterRestart && FileHelper::exists( config.weightsFile ) ) {
            // try old trainingstring
            afterRestart = WeightsPersister::loadWeights( config.weightsFile, config.getOldTrainingString(), net, &restartEpoch, &restartBatch, &restartAnnealedLearningRate, &restartNumRight, &restartLoss );
        }
        if( !afterRestart && FileHelper::exists( config.weightsFile ) ) {
            cout << "Weights file " << config.weightsFile << " exists, but doesnt match training options provided." << endl;
            cout << "Continue loading anyway (might crash, or weights might be completely inappropriate)? (y/n)" << endl;
            string response;
            cin >> response;
            if( response != "y" ) {
                cout << "Please either check the training options, or choose a weights file that doesnt exist yet" << endl;
                return;
            }
        }