Exemple #1
0
void CGraphUT::testLin()
{
  CNode *n0, *n1, *n2, *n3;
  CGraph cgraph;
  int error = 0;

  VariablePtr v0 = (VariablePtr) new Variable(0, 0, 0.0, 10.0, Continuous, "x0");
  VariablePtr v1 = (VariablePtr) new Variable(1, 1, 0.0, 10.0, Continuous, "x1");
  VariablePtr v2 = (VariablePtr) new Variable(2, 2, 0.0, 10.0, Continuous, "x2");
  VariablePtr v3 = (VariablePtr) new Variable(3, 3, 0.0, 10.0, Continuous, "x3");

  double x[4] = {1.0, 2.0, 5.0, 7.0};
  double g[4] = {0.0, 0.0, 0.0, 0.0};
  double gexp[4] = {1.0, 1.0, 5.4, -1.0};


  n0 = cgraph.newNode(v0);
  n1 = cgraph.newNode(v0);
  CPPUNIT_ASSERT(n0 == n1);

  n1 = cgraph.newNode(v1);
  n2 = cgraph.newNode(v1);
  CPPUNIT_ASSERT(n2 == n1);
  CPPUNIT_ASSERT(n0 != n1);

  n2 = cgraph.newNode(OpPlus, n0, n1);
  n3 = cgraph.newNode(5.4);
  n0 = cgraph.newNode(v2);
  n3 = cgraph.newNode(OpMult, n3, n0);
  n3 = cgraph.newNode(OpPlus, n2, n3); // n3 = x0 + x1 + 5.4*x2
  n0 = cgraph.newNode(v3);
  n0 = cgraph.newNode(OpMinus, n3, n0); // n0 = x0 + x1 + 5.4*x2 - x3

  cgraph.setOut(n0);
  cgraph.finalize();

  CPPUNIT_ASSERT(cgraph.numVars() == 4);
  CPPUNIT_ASSERT(cgraph.getType() == Linear);
  CPPUNIT_ASSERT(fabs(cgraph.eval(x, &error) - 23.0)<1e-10);
  CPPUNIT_ASSERT(0==error);
  cgraph.evalGradient(x, g, &error);
  CPPUNIT_ASSERT(0==error);
  for (UInt i=0; i<4; ++i) {
    CPPUNIT_ASSERT(fabs(g[i]-gexp[i])<1e-10);
  }


  x[0] = 0.0; x[1] = 2.0; x[2] = 0.0; x[3] = 7.0; 
  g[0] = 0.0; g[1] = 0.0; g[2] = 0.0; g[3] = 0.0; 
  cgraph.evalGradient(x, g, &error);
  for (UInt i=0; i<4; ++i) {
    CPPUNIT_ASSERT(fabs(g[i]-gexp[i])<1e-10);
  }
  CPPUNIT_ASSERT(0==error);
  CPPUNIT_ASSERT(fabs(n0->getVal()+5.0)<1e-10);
  CPPUNIT_ASSERT(fabs(cgraph.eval(x,&error)+5.0)<1e-10);
}