/**
 * Prepares `ccd` for the final fit at the hyperparameters selected by drive().
 *
 * Passes a null weight pointer to ccd.setWeights() (presumably dropping any
 * fold weights left over from cross-validation — confirm against
 * CyclicCoordinateDescent::setWeights). Installs the selected element-level
 * (maxPoint) and class-level (maxPointClass) variances. Resets the
 * coefficients so the final fit starts cold.
 *
 * @param ccd       solver to reconfigure
 * @param selector  unused here
 * @param arguments unused here
 */
void HierarchyAutoSearchCrossValidationDriver::resetForOptimal(
		CyclicCoordinateDescent& ccd,
		CrossValidationSelector& selector,
		const CCDArguments& arguments) {

	ccd.setWeights(nullptr); // type-safe null pointer instead of the NULL macro
	ccd.setHyperprior(maxPoint);
	ccd.setClassHyperprior(maxPointClass);
	ccd.resetBeta(); // Cold-start
}
/**
 * Grid-search cross-validation.
 *
 * For every grid index in [0, gridSize), the variance computeGridPoint(index) is
 * installed in `ccd` and the selector is reseeded. The cross-validated
 * estimate from doCrossValidationStep() is then divided by the fraction of
 * folds actually computed (foldToCompute / fold). Each (variance, value) pair is
 * appended to gridPoint / gridValue.
 *
 * @return the best grid point (as a one-element vector) and its value, as
 *         reported by findMax().
 */
MaxPoint GridSearchCrossValidationDriver::doCrossValidationLoop(
			CyclicCoordinateDescent& ccd,
			AbstractSelector& selector,
			const CCDArguments& allArguments,
			int nThreads,
			std::vector<CyclicCoordinateDescent*>& ccdPool,
			std::vector<AbstractSelector*>& selectorPool) {

	const auto& cvArguments = allArguments.crossValidation;

	for (int gridIndex = 0; gridIndex < gridSize; ++gridIndex) {
		const double variance = computeGridPoint(gridIndex);

		// Configure the solver and the fold selector for this grid point.
		ccd.setHyperprior(variance);
		selector.reseed();

		std::vector<double> foldPredictions;
		const double rawEstimate = doCrossValidationStep(ccd, selector, allArguments,
			gridIndex, nThreads, ccdPool, selectorPool, foldPredictions);

		// Rescale by the share of folds that were actually evaluated.
		const double computedFraction =
			static_cast<double>(cvArguments.foldToCompute) / static_cast<double>(cvArguments.fold);

		gridPoint.push_back(variance);
		gridValue.push_back(rawEstimate / computedFraction);
	}

	double bestPoint;
	double bestValue;
	findMax(&bestPoint, &bestValue);

	std::vector<double> bestPointVector(1, bestPoint);
	return MaxPoint{bestPointVector, bestValue};
}
// This is specific to auto-search
/**
 * Auto-search cross-validation over every hyperprior dimension of `ccd`.
 *
 * Starting point: every dimension starts at arguments.startingVariance when
 * it is positive, otherwise at modelData.getNormalBasedDefaultVar().
 *
 * Per-dimension search: dimensions are searched one at a time with a fresh
 * UniModalSearch, using at most maxSteps cross-validation evaluations per
 * dimension per cycle. The other dimensions are held at their current best
 * values, and the selector is reseeded before each evaluation.
 *
 * Termination: with one dimension, a single cycle is run. Otherwise cycles
 * repeat until the summed relative change of the optimum over a cycle is
 * below `tolerance`, or until maxSteps cycles have been run.
 *
 * @return the selected hyperprior value for each dimension (size nDim).
 */
std::vector<double> AutoSearchCrossValidationDriver::doCrossValidationLoop(
			CyclicCoordinateDescent& ccd,
			AbstractSelector& selector,
			const CCDArguments& allArguments,
			int nThreads,
			std::vector<CyclicCoordinateDescent*>& ccdPool,
			std::vector<AbstractSelector*>& selectorPool) {

    const auto& arguments = allArguments.crossValidation;

	// The default is used for ANY non-positive startingVariance. The
	// "(default)" tag below must therefore use this same test.
	const bool useDefaultStart = !(arguments.startingVariance > 0);
	double tryvalue = useDefaultStart ?
		modelData.getNormalBasedDefaultVar() :
		arguments.startingVariance;

	std::ostringstream startStream;
	startStream << "Starting var = " << tryvalue;
	if (useDefaultStart) {
	    startStream << " (default)";
	}
	logger->writeLine(startStream);

	const double tolerance = 1E-2; // TODO Make Cyclops argument
	// Bound on outer coordinate cycles, so a search that never meets
	// `tolerance` (oscillation, or a NaN difference) still terminates.
	const int maxCycles = maxSteps; // TODO Make Cyclops argument

	int nDim = ccd.getHyperprior().size();
	std::vector<double> currentOptimal(nDim, tryvalue);

	bool globalFinished = false;
	std::vector<double> savedOptimal;
	int cycle = 0;

	while (!globalFinished) {

	    if (nDim > 1) {
	        savedOptimal = currentOptimal; // make copy
	    }

	    for (int dim = 0; dim < nDim; ++dim) {

	        // Install the current best value of every dimension. The search over
	        // `dim` is then conditioned on currentOptimal, rather than on the last
	        // points tried by earlier searches (or on ccd's prior state).
	        for (int other = 0; other < nDim; ++other) {
	            ccd.setHyperprior(other, currentOptimal[other]);
	        }

	        // Local search
	        UniModalSearch searcher(10, 0.01, log(1.5));

	        int step = 0;
	        bool dimFinished = false;

	        while (!dimFinished) {

	            ccd.setHyperprior(dim, currentOptimal[dim]);
	            selector.reseed();

	            std::vector<double> predLogLikelihood;

	            double pointEstimate = doCrossValidationStep(ccd, selector, allArguments, step,
                                                          nThreads, ccdPool, selectorPool,
                                                          predLogLikelihood);

	            double stdDevEstimate = computeStDev(predLogLikelihood, pointEstimate);

	            std::ostringstream stepStream;
	            stepStream << "AvgPred = " << pointEstimate << " with stdev = " << stdDevEstimate << std::endl;
	            searcher.tried(currentOptimal[dim], pointEstimate, stdDevEstimate);
	            pair<bool,double> next = searcher.step();
	            stepStream << "Completed at " << currentOptimal[dim] << std::endl;
	            stepStream << "Next point at " << next.second << " and " << next.first;
	            logger->writeLine(stepStream);

	            // When next.first is false the search is done, and next.second
	            // is kept as this dimension's value.
	            currentOptimal[dim] = next.second;
	            if (!next.first) {
	                dimFinished = true;
	            }
	            std::ostringstream searcherStream;
	            searcherStream << searcher;
	            logger->writeLine(searcherStream);
	            step++;
	            if (step >= maxSteps) {
	                std::ostringstream maxStepsStream;
	                maxStepsStream << "Max steps reached!";
	                logger->writeLine(maxStepsStream);
	                dimFinished = true;
	            }
	        }
	    }

	    ++cycle;
	    if (nDim == 1) {
	        globalFinished = true;
	    } else {
	        // Sum over dimensions of |relative change| during this cycle.
	        // Assumes saved values are non-zero (variances, presumably positive).
	        double diff = 0.0;
	        for (int i = 0; i < nDim; ++i) {
	            diff += std::abs((currentOptimal[i] - savedOptimal[i]) / savedOptimal[i]);
	        }
	        std::ostringstream cycleStream;
	        cycleStream << "Sum of absolute relative differences in cycle: " << diff;
	        logger->writeLine(cycleStream);

	        globalFinished = (diff < tolerance);
	        if (!globalFinished && cycle >= maxCycles) {
	            std::ostringstream maxCyclesStream;
	            maxCyclesStream << "Max cycles reached!";
	            logger->writeLine(maxCyclesStream);
	            globalFinished = true;
	        }
	    }
	}
	return currentOptimal;
}
/**
 * Hierarchical auto-search over two variances.
 *
 * The element level (`tryvalue`) is set via ccd.setHyperprior, and the class
 * level (`tryvalueClass`) via ccd.setClassHyperprior. Each level has its own
 * UniModalSearch.
 *
 * Steps alternate between the levels, with the element level on even steps.
 * Once one level's search finishes, only the other is stepped. The loop ends
 * when both searches report completion or after maxSteps cross-validation
 * evaluations.
 *
 * The selected values are stored in maxPoint / maxPointClass and reported
 * through `logger`.
 */
void HierarchyAutoSearchCrossValidationDriver::drive(
		CyclicCoordinateDescent& ccd,
		AbstractSelector& selector,
		const CCDArguments& arguments) {

	// TODO Check that selector is type of CrossValidationSelector
	// NOTE(review): unlike the other auto-search loops, this one never calls
	// selector.reseed() between evaluations — confirm that is intended.

	// Start with the same variance at the class and element level.
	double tryvalue = modelData.getNormalBasedDefaultVar();
	double tryvalueClass = tryvalue;
	UniModalSearch searcher(10, 0.01, log(1.5));
	UniModalSearch searcherClass(10, 0.01, log(1.5)); // Need a better way to do this.

	std::ostringstream defaultStream;
	defaultStream << "Default var = " << tryvalue;
	logger->writeLine(defaultStream);

	bool finished = false;
	bool drugLevelFinished = false;
	bool classLevelFinished = false;

	int step = 0;
	while (!finished) {

		// More hierarchy logic
		ccd.setHyperprior(tryvalue);
		ccd.setClassHyperprior(tryvalueClass);

		std::vector<double> predLogLikelihood;
		double pointEstimate = doCrossValidation(ccd, selector, arguments, step, predLogLikelihood);
		double stdDevEstimate = computeStDev(predLogLikelihood, pointEstimate);

		std::ostringstream predStream;
		predStream << "AvgPred = " << pointEstimate << " with stdev = " << stdDevEstimate;
		logger->writeLine(predStream);

		// Alternate adapting the element and class level, unless one is finished.
		const bool elementStep =
			(step % 2 == 0 && !drugLevelFinished) || classLevelFinished;
		if (elementStep) {
			searcher.tried(tryvalue, pointEstimate, stdDevEstimate);
			pair<bool,double> next = searcher.step();
			tryvalue = next.second;
			std::ostringstream nextStream;
			nextStream << "Next point at " << next.second << " and " << next.first;
			logger->writeLine(nextStream);
			if (!next.first) {
				drugLevelFinished = true;
			}
		} else {
			searcherClass.tried(tryvalueClass, pointEstimate, stdDevEstimate);
			pair<bool,double> next = searcherClass.step();
			tryvalueClass = next.second;
			std::ostringstream nextStream;
			nextStream << "Next Class point at " << next.second << " and " << next.first;
			logger->writeLine(nextStream);
			if (!next.first) {
				classLevelFinished = true;
			}
		}

		// If everything is finished, end.
		if (drugLevelFinished && classLevelFinished) {
			finished = true;
		}

		// Log the state of the searcher that was actually updated this step.
		std::ostringstream searcherStream;
		if (elementStep) {
			searcherStream << searcher;
		} else {
			searcherStream << searcherClass;
		}
		logger->writeLine(searcherStream);

		step++;
		if (step >= maxSteps) {
			std::ostringstream maxStepsStream;
			maxStepsStream << "Max steps reached!";
			logger->writeLine(maxStepsStream);
			finished = true;
		}
	}

	maxPoint = tryvalue;
	maxPointClass = tryvalueClass;

	// Report results
	std::ostringstream resultStream;
	resultStream << std::endl;
	resultStream << "Maximum predicted log likelihood estimated at:" << std::endl;
	resultStream << "\t" << maxPoint << " (variance)" << std::endl;
	resultStream << "class level = " << maxPointClass;
	logger->writeLine(resultStream);

	if (!arguments.useNormalPrior) {
		double lambda = convertVarianceToHyperparameter(maxPoint);
		std::ostringstream lambdaStream;
		lambdaStream << "\t" << lambda << " (lambda)";
		logger->writeLine(lambdaStream);
	}

	// Trailing empty line.
	std::ostringstream blankStream;
	logger->writeLine(blankStream);
}