Beispiel #1
0
/*
 * Smoke test: minimize an SPDMean problem on the manifold of n-by-n symmetric
 * positive definite matrices with LRBFGS (Armijo line search), starting from
 * the identity matrix.
 *
 * The problem data are `num` random lower-triangular matrices L_i, each the
 * lower Cholesky factor of G_i * G_i^T for a random matrix G_i filled by
 * genrandnormal(). (Presumably SPDMean treats L_i * L_i^T as the data points;
 * confirm against the SPDMean class.)
 */
void testSPDMean(void)
{
	// The seed is fixed at 0 so runs are reproducible. For a different random
	// instance on every run, use (unsigned)time(NULL) instead of 0.
	unsigned tt = 0;
	genrandseed(tt);

	/*Initial iterate: the n-by-n identity matrix (column-major) on the SPD manifold*/
	integer n = 100, num = 4;
	SPDVariable SPDX(n);
	double *initialX = SPDX.ObtainWriteEntireData();
	for (integer i = 0; i < n; i++)
	{
		for (integer j = 0; j < n; j++)
		{
			initialX[i + j * n] = 0;
		}
		initialX[i + i * n] = 1;
	}

	// Define the manifold
	SPDManifold Domain(n);
	Domain.SetHasHHR(true); /*set whether the manifold uses the idea in [HGA2015, Section 4.3] or not*/

	// Ls stores `num` column-major n-by-n matrices back to back, followed by one
	// extra n-by-n scratch matrix (tmp) used only while generating them.
	// It is released at the end, after the solver has finished with Prob.
	double *Ls = new double[n * n * num + n * n];
	double *tmp = Ls + n * n * num;
	integer info;
	for (integer i = 0; i < num; i++)
	{
		double *Li = Ls + i * n * n;

		// Fill tmp with random samples, then Li = tmp * tmp^T (symmetric,
		// positive definite whenever tmp is nonsingular).
		for (integer j = 0; j < n * n; j++)
			tmp[j] = genrandnormal();
		dgemm_(GLOBAL::N, GLOBAL::T, &n, &n, &n, &GLOBAL::DONE, tmp, &n, tmp, &n, &GLOBAL::DZERO, Li, &n);

		// Overwrite the lower triangle of Li with its Cholesky factor.
		dpotrf_(GLOBAL::L, &n, Li, &n, &info);
		if (info != 0)
		{
			std::cout << "Warning: TestSPDMean Cholesky decomposition fails with info:" << info << "!" << std::endl;
		}
		// dpotrf_ does not touch the strict upper triangle; zero it so that Li
		// is exactly lower triangular.
		for (integer j = 0; j < n; j++)
			for (integer k = j + 1; k < n; k++)
				Li[j + k * n] = 0;
	}

	// Define the problem
	SPDMean Prob(Ls, n, num);
	/*The domain of the problem is a SPD manifold*/
	Prob.SetDomain(&Domain);

	//Prob.CheckGradHessian(&SPDX);

	/*Output the parameters of the domain manifold*/
	Domain.CheckParams();

	/*Check the correctness of the manifold operations*/
	//Domain.CheckIntrExtr(&SPDX);
	//Domain.CheckRetraction(&SPDX);
	//Domain.CheckDiffRetraction(&SPDX);
	//Domain.CheckLockingCondition(&SPDX);
	//Domain.CheckcoTangentVector(&SPDX);
	//Domain.CheckIsometryofVectorTransport(&SPDX);
	//Domain.CheckIsometryofInvVectorTransport(&SPDX);
	//Domain.CheckVecTranComposeInverseVecTran(&SPDX);
	//Domain.CheckTranHInvTran(&SPDX);
	//Domain.CheckHaddScaledRank1OPE(&SPDX);

	// test LRBFGS
	std::cout << "********************************Test Geometric mean in LRBFGS*************************************" << std::endl;
	{
		// Scoped solver: it is destroyed at the closing brace (also if Run()
		// throws) and still before Ls is released below.
		LRBFGS LRBFGSsolver(&Prob, &SPDX);
		LRBFGSsolver.LineSearch_LS = ARMIJO;
		LRBFGSsolver.Debug = ITERRESULT; //ITERRESULT;//
		LRBFGSsolver.Max_Iteration = 20;
		LRBFGSsolver.Tolerance = 1e-10;
		LRBFGSsolver.Accuracy = 1e-4;
		LRBFGSsolver.Finalstepsize = 1;
		LRBFGSsolver.CheckParams();
		LRBFGSsolver.Run();
	}

	delete[] Ls;
}
void testEucQuadratic(double *M, integer dim, double *X, double *Xopt)
{
	// choose a random seed
	unsigned tt = (unsigned)time(NULL);
	std::cout << "tt:" << tt << std::endl;
	tt = 0;
	init_genrand(tt);

	// Obtain an initial iterate
	EucVariable EucX(dim, 1);
	if (X == nullptr)
	{
		EucX.RandInManifold();
	}
	else
	{
		double *EucXptr = EucX.ObtainWriteEntireData();
		for (integer i = 0; i < dim; i++)
			EucXptr[i] = X[i];
	}

	// Define the manifold
	Euclidean Domain(dim);

	// Define the problem
	EucQuadratic Prob(M, dim);
	Prob.SetDomain(&Domain);

	// test RSD
	std::cout << "********************************Check all line search algorithm in RSD*****************************************" << std::endl;
	for (integer i = 0; i < LSALGOLENGTH; i++)
	{
		RSD *RSDsolver = new RSD(&Prob, &EucX);
		RSDsolver->LineSearch_LS = static_cast<LSAlgo> (i);
		RSDsolver->DEBUG = FINALRESULT;
		RSDsolver->CheckParams();
		RSDsolver->Run();
		delete RSDsolver;
	}
	// test RNewton
	std::cout << "********************************Check all line search algorithm in RNewton*************************************" << std::endl;
	for (integer i = 0; i < LSALGOLENGTH; i++)
	{
		RNewton *RNewtonsolver = new RNewton(&Prob, &EucX);
		RNewtonsolver->LineSearch_LS = static_cast<LSAlgo> (i);
		RNewtonsolver->DEBUG = FINALRESULT;
		RNewtonsolver->CheckParams();
		RNewtonsolver->Run();
		delete RNewtonsolver;
	}

	// test RCG
	std::cout << "********************************Check all Formulas in RCG*************************************" << std::endl;
	for (integer i = 0; i < RCGMETHODSLENGTH; i++)
	{
		RCG *RCGsolver = new RCG(&Prob, &EucX);
		RCGsolver->RCGmethod = static_cast<RCGmethods> (i);
		RCGsolver->LineSearch_LS = STRONGWOLFE;
		RCGsolver->LS_beta = 0.1;
		RCGsolver->DEBUG = FINALRESULT;
		RCGsolver->CheckParams();
		RCGsolver->Run();
		delete RCGsolver;
	}

	// test RBroydenFamily
	std::cout << "********************************Check all line search algorithm in RBroydenFamily*************************************" << std::endl;
	for (integer i = 0; i < LSALGOLENGTH; i++)
	{
		RBroydenFamily *RBroydenFamilysolver = new RBroydenFamily(&Prob, &EucX);
		RBroydenFamilysolver->LineSearch_LS = static_cast<LSAlgo> (i);
		RBroydenFamilysolver->DEBUG = FINALRESULT;
		RBroydenFamilysolver->CheckParams();
		RBroydenFamilysolver->Run();
		delete RBroydenFamilysolver;
	}

	// test RWRBFGS
	std::cout << "********************************Check all line search algorithm in RWRBFGS*************************************" << std::endl;
	for (integer i = 0; i < LSALGOLENGTH; i++)
	{
		RWRBFGS *RWRBFGSsolver = new RWRBFGS(&Prob, &EucX);
		RWRBFGSsolver->LineSearch_LS = static_cast<LSAlgo> (i);
		RWRBFGSsolver->DEBUG = FINALRESULT;
		RWRBFGSsolver->CheckParams();
		RWRBFGSsolver->Run();
		delete RWRBFGSsolver;
	}

	// test RBFGS
	std::cout << "********************************Check all line search algorithm in RBFGS*************************************" << std::endl;
	for (integer i = 0; i < LSALGOLENGTH; i++)
	{
		RBFGS *RBFGSsolver = new RBFGS(&Prob, &EucX);
		RBFGSsolver->LineSearch_LS = static_cast<LSAlgo> (i);
		RBFGSsolver->DEBUG = FINALRESULT;
		RBFGSsolver->CheckParams();
		RBFGSsolver->Run();
		delete RBFGSsolver;
	}

	// test LRBFGS
	std::cout << "********************************Check all line search algorithm in LRBFGS*************************************" << std::endl;
	for (integer i = 0; i < LSALGOLENGTH; i++)
	{
		LRBFGS *LRBFGSsolver = new LRBFGS(&Prob, &EucX);
		LRBFGSsolver->LineSearch_LS = static_cast<LSAlgo> (i);
		LRBFGSsolver->DEBUG = FINALRESULT;
		LRBFGSsolver->CheckParams();
		LRBFGSsolver->Run();
		delete LRBFGSsolver;
	}

	std::cout << "********************************Check RTRSD*************************************" << std::endl;
	RTRSD RTRSDsolver(&Prob, &EucX);
	std::cout << std::endl;
	RTRSDsolver.DEBUG = FINALRESULT;
	RTRSDsolver.CheckParams();
	RTRSDsolver.Run();

	std::cout << "********************************Check RTRNewton*************************************" << std::endl;
	RTRNewton RTRNewtonsolver(&Prob, &EucX);
	std::cout << std::endl;
	RTRNewtonsolver.DEBUG = FINALRESULT;
	RTRNewtonsolver.CheckParams();
	RTRNewtonsolver.Run();

	std::cout << "********************************Check RTRSR1*************************************" << std::endl;
	RTRSR1 RTRSR1solver(&Prob, &EucX);
	std::cout << std::endl;
	RTRSR1solver.DEBUG = FINALRESULT;
	RTRSR1solver.CheckParams();
	RTRSR1solver.Run();

	std::cout << "********************************Check LRTRSR1*************************************" << std::endl;
	LRTRSR1 LRTRSR1solver(&Prob, &EucX);
	std::cout << std::endl;
	LRTRSR1solver.DEBUG = FINALRESULT;
	LRTRSR1solver.CheckParams();
	LRTRSR1solver.Run();

	// Check gradient and Hessian
	Prob.CheckGradHessian(&EucX);
	const Variable *xopt = RTRNewtonsolver.GetXopt();
	Prob.CheckGradHessian(xopt);
    
	if (Xopt != nullptr)
	{
		const double *xoptptr = xopt->ObtainReadData();
		for (integer i = 0; i < dim; i++)
			Xopt[i] = xoptptr[i];
	}
};