//Computes a contact normal for the point p, given in m's local frame, lying on triangle t of m.
//The returned direction is the one in which mesh m (triangle t) should move to
//separate from the other object.  The normal is chosen by the feature of
//triangle t that p lies on, as classified by FeatureType of its barycentric coordinates:
//  vertex -> negated VertexNormal, edge -> negated EdgeNormal,
//  face   -> negated triangle normal rotated by m.currentTransform.R.
//Any other feature type (e.g. degenerate triangle) prints a throttled warning
//and returns the zero vector.  closestPt is currently unused.
Vector3 ContactNormal(const CollisionMesh& m,const Vector3& p,int t,const Vector3& closestPt)
{
  Triangle3D triangle;
  m.GetTriangle(t,triangle);
  const Vector3 bary = triangle.barycentricCoords(p);
  const int feature = FeatureType(bary);

  if(feature == 1) {
    //vertex feature: move against the vertex normal
    Vector3 vn = VertexNormal(m,t,VertexIndex(bary));
    vn.inplaceNegative();
    return vn;
  }
  if(feature == 2) {
    //edge feature: move against the edge normal
    Vector3 en = EdgeNormal(m,t,EdgeIndex(bary));
    en.inplaceNegative();
    return en;
  }
  if(feature == 3) {
    //face feature: move against the face normal, expressed in world frame
    return m.currentTransform.R*(-triangle.normal());
  }

  //unexpected feature type: warn only once every 10000 occurrences
  static int numWarnings = 0;
  if(numWarnings % 10000 == 0)
    printf("ODECustomMesh: Warning, degenerate triangle, types %d\n",feature);
  numWarnings++;
  return Vector3(Zero);
}
//Generates contacts between mesh m1 and a point or sphere primitive g2 (placed
//in the world by T2).  Any other primitive type prints an error and yields no
//contacts.  Each triangle of m1 within outerMargin1+outerMargin2 of the
//primitive produces at most one contact.  Writes at most maxcontacts entries
//into contact and returns the number written.
int MeshPrimitiveCollide(CollisionMesh& m1,Real outerMargin1,GeometricPrimitive3D& g2,const RigidTransform& T2,Real outerMargin2,dContactGeom* contact,int maxcontacts)
{
  if(maxcontacts <= 0) return 0;  //no room for any contact
  GeometricPrimitive3D gworld=g2;
  gworld.Transform(T2);
  Sphere3D s;
  if(gworld.type != GeometricPrimitive3D::Point && gworld.type != GeometricPrimitive3D::Sphere) {
    fprintf(stderr,"Distance computations between Triangles and %s not supported\n",gworld.TypeName());
    return 0;
  }
  //treat a point as a zero-radius sphere
  if(gworld.type == GeometricPrimitive3D::Point) {
    s.center = *AnyCast<Point3D>(&gworld.data);
    s.radius = 0;
  }
  else {
    s = *AnyCast<Sphere3D>(&gworld.data);
  }

  Real tol = outerMargin1 + outerMargin2;
  Triangle3D tri;
  vector<int> tris;
  int k=0;
  NearbyTriangles(m1,gworld,tol,tris,maxcontacts);
  for(size_t j=0;j<tris.size();j++) {
    m1.GetTriangle(tris[j],tri);
    tri.a = m1.currentTransform*tri.a;
    tri.b = m1.currentTransform*tri.b;
    tri.c = m1.currentTransform*tri.c;

    Vector3 cp = tri.closestPoint(s.center);
    Vector3 n = cp - s.center;
    Real nlen = n.length();
    Real d = nlen-s.radius;
    Vector3 pw = s.center;
    //adjust pw to the sphere surface on the side facing the triangle.  If the
    //center lies on the triangle (nlen == 0) that direction is undefined, so
    //it is taken from the contact normal below instead of dividing by zero.
    if(s.radius > 0 && nlen > 0)
      pw += n*(s.radius/nlen);
    if(d < gNormalFromGeometryTolerance) {  //compute normal from the geometry
      Vector3 plocal;
      m1.currentTransform.mulInverse(cp,plocal);
      n = ContactNormal(m1,plocal,tris[j],pw);
    }
    else if(d > tol) {  //farther apart than the combined margins: not a contact
      continue;
    }
    else n /= nlen;
    //reject invalid normals (e.g. zero vector from a degenerate triangle)
    Real nnorm = n.length();
    if(nnorm < gZeroNormalTolerance || !IsFinite(nnorm)) continue;
    if(s.radius > 0 && !(nlen > 0))
      pw += n*s.radius;
    //migrate the contact point to the center of the overlap region
    CopyVector(contact[k].pos,0.5*(cp+pw) + ((outerMargin2 - outerMargin1)*0.5)*n);
    CopyVector(contact[k].normal,n);
    contact[k].depth = tol - d;
    k++;
    if(k == maxcontacts) break;
  }
  return k;
}
//Generates contacts between mesh m1 and point cloud pc2.  For every point of
//pc2 (in world frame), it queries the triangles of m1 within
//outerMargin1+outerMargin2 of that point.  Writes at most maxcontacts entries
//into contact and returns the number written.
int MeshPointCloudCollide(CollisionMesh& m1,Real outerMargin1,CollisionPointCloud& pc2,Real outerMargin2,dContactGeom* contact,int maxcontacts)
{
  if(maxcontacts <= 0) return 0;  //no room for any contact
  Real tol = outerMargin1 + outerMargin2;
  int k=0;
  vector<int> tris;
  Triangle3D tri,triw;
  for(size_t i=0;i<pc2.points.size();i++) {
    Vector3 pw = pc2.currentTransform*pc2.points[i];
    //only request as many triangles as there are free contact slots
    NearbyTriangles(m1,pw,tol,tris,maxcontacts-k);
    for(size_t j=0;j<tris.size();j++) {
      m1.GetTriangle(tris[j],tri);
      triw.a = m1.currentTransform*tri.a;
      triw.b = m1.currentTransform*tri.b;
      triw.c = m1.currentTransform*tri.c;
      Vector3 cp = triw.closestPoint(pw);
      Vector3 n = cp - pw;
      Real d = n.length();
      if(d < gNormalFromGeometryTolerance) {  //compute normal from the geometry
        Vector3 plocal;
        m1.currentTransform.mulInverse(cp,plocal);
        n = ContactNormal(m1,plocal,tris[j],pw);
      }
      else if(d > tol) {  //farther apart than the combined margins: not a contact
        continue;
      }
      else n /= d;
      //reject invalid normals (e.g. zero vector from a degenerate triangle)
      Real nnorm = n.length();
      if(nnorm < gZeroNormalTolerance || !IsFinite(nnorm)) continue;
      //migrate the contact point to the center of the overlap region
      CopyVector(contact[k].pos,0.5*(cp+pw) + ((outerMargin2 - outerMargin1)*0.5)*n);
      CopyVector(contact[k].normal,n);
      contact[k].depth = tol - d;
      k++;
      //the contact buffer is full: stop scanning the remaining points as well,
      //otherwise later points would write past contact[maxcontacts-1]
      if(k >= maxcontacts) return k;
    }
  }
  return k;
}
//Generates contacts between mesh m1 and point cloud pc2.  It uses the
//point cloud / mesh Collides() query to find (point, triangle) pairs within
//outerMargin1+outerMargin2.  Writes at most maxcontacts entries into contact
//and returns the number written.
int MeshPointCloudCollide(CollisionMesh& m1,Real outerMargin1,CollisionPointCloud& pc2,Real outerMargin2,dContactGeom* contact,int maxcontacts)
{
  if(maxcontacts <= 0) return 0;  //no room for any contact
  Real tol=outerMargin1+outerMargin2;
  vector<int> points;
  vector<int> tris;
  if(!Collides(pc2,tol,m1,points,tris,maxcontacts)) return 0;
  //points[i] (index into pc2.points) pairs with tris[i] (triangle of m1)
  Assert(points.size()==tris.size());
  Triangle3D tri,triw;
  int k=0;
  for(size_t i=0;i<points.size();i++) {
    Vector3 pw = pc2.currentTransform*pc2.points[points[i]];
    m1.GetTriangle(tris[i],tri);
    triw.a = m1.currentTransform*tri.a;
    triw.b = m1.currentTransform*tri.b;
    triw.c = m1.currentTransform*tri.c;
    Vector3 cp = triw.closestPoint(pw);
    Vector3 n = cp - pw;
    Real d = n.length();
    if(d < gNormalFromGeometryTolerance) {  //compute normal from the geometry
      Vector3 plocal;
      m1.currentTransform.mulInverse(cp,plocal);
      n = ContactNormal(m1,plocal,tris[i],pw);
    }
    else if(d > tol) {  //farther apart than the combined margins: not a contact
      continue;
    }
    else n /= d;
    //reject invalid normals (e.g. zero vector from a degenerate triangle)
    Real nnorm = n.length();
    if(nnorm < gZeroNormalTolerance || !IsFinite(nnorm)) continue;
    //migrate the contact point to the center of the overlap region
    CopyVector(contact[k].pos,0.5*(cp+pw) + ((outerMargin2 - outerMargin1)*0.5)*n);
    CopyVector(contact[k].normal,n);
    contact[k].depth = tol - d;
    k++;
    if(k == maxcontacts) break;
  }
  return k;
}
//Generates contacts between meshes m1 and m2.  Triangle pairs within
//outerMargin1+outerMargin2 are found with CollisionMeshQuery.  If
//gDoTriangleTriangleCollisionDetection is set, extra candidate points are added
//from triangle vertices that also lie within tolerance of the paired triangle.
//Writes at most maxcontacts entries into contact and returns the number written.
int MeshMeshCollide(CollisionMesh& m1,Real outerMargin1,CollisionMesh& m2,Real outerMargin2,dContactGeom* contact,int maxcontacts)
{
  if(maxcontacts <= 0) return 0;  //no room for any contact
  CollisionMeshQuery q(m1,m2);
  bool res=q.WithinDistanceAll(outerMargin1+outerMargin2);
  if(!res) {
    return 0;
  }

  //t1/t2: paired triangle indices; cp1/cp2: closest points in each mesh's local frame
  vector<int> t1,t2;
  vector<Vector3> cp1,cp2;
  q.TolerancePairs(t1,t2);
  q.TolerancePoints(cp1,cp2);
  const RigidTransform& T1 = m1.currentTransform;
  const RigidTransform& T2 = m2.currentTransform;
  RigidTransform T21; T21.mulInverseA(T1,T2);  //m2 local -> m1 local
  RigidTransform T12; T12.mulInverseA(T2,T1);  //m1 local -> m2 local
  Real tol = outerMargin1+outerMargin2;
  Real tol2 = Sqr(tol);

  size_t imax=t1.size();
  Triangle3D tri1,tri2,tri1loc,tri2loc;
  if(gDoTriangleTriangleCollisionDetection) {
    //test if more triangle vertices are closer than tolerance
    for(size_t i=0;i<imax;i++) {
      m1.GetTriangle(t1[i],tri1);
      m2.GetTriangle(t2[i],tri2);

      tri1loc.a = T12*tri1.a;
      tri1loc.b = T12*tri1.b;
      tri1loc.c = T12*tri1.c;
      tri2loc.a = T21*tri2.a;
      tri2loc.b = T21*tri2.b;
      tri2loc.c = T21*tri2.c;
      bool usecpa,usecpb,usecpc,usecpa2,usecpb2,usecpc2;
      //cpa..cpc: points on tri1 (m1 frame) closest to tri2's vertices
      Vector3 cpa = tri1.closestPoint(tri2loc.a);
      Vector3 cpb = tri1.closestPoint(tri2loc.b);
      Vector3 cpc = tri1.closestPoint(tri2loc.c);
      //cpa2..cpc2: points on tri2 (m2 frame) closest to tri1's vertices
      Vector3 cpa2 = tri2.closestPoint(tri1loc.a);
      Vector3 cpb2 = tri2.closestPoint(tri1loc.b);
      Vector3 cpc2 = tri2.closestPoint(tri1loc.c);
      usecpa = (cpa.distanceSquared(tri2loc.a) < tol2);
      usecpb = (cpb.distanceSquared(tri2loc.b) < tol2);
      usecpc = (cpc.distanceSquared(tri2loc.c) < tol2);
      usecpa2 = (cpa2.distanceSquared(tri1loc.a) < tol2);
      usecpb2 = (cpb2.distanceSquared(tri1loc.b) < tol2);
      usecpc2 = (cpc2.distanceSquared(tri1loc.c) < tol2);
      //if already existing, disable it
      if(usecpa && cpa.isEqual(cp1[i],cptol)) usecpa=false;
      if(usecpb && cpb.isEqual(cp1[i],cptol)) usecpb=false;
      if(usecpc && cpc.isEqual(cp1[i],cptol)) usecpc=false;
      if(usecpa2 && cpa2.isEqual(cp2[i],cptol)) usecpa2=false;
      if(usecpb2 && cpb2.isEqual(cp2[i],cptol)) usecpb2=false;
      if(usecpc2 && cpc2.isEqual(cp2[i],cptol)) usecpc2=false;

      //remove duplicates within each set.  Only points in the same frame are
      //compared: cpa..cpc are in m1's frame, cpa2..cpc2 in m2's frame.
      if(usecpa) {
        if(usecpb && cpb.isEqual(cpa,cptol)) usecpb=false;
        if(usecpc && cpc.isEqual(cpa,cptol)) usecpc=false;
      }
      if(usecpb) {
        if(usecpc && cpc.isEqual(cpb,cptol)) usecpc=false;
      }
      if(usecpa2) {
        if(usecpb2 && cpb2.isEqual(cpa2,cptol)) usecpb2=false;
        if(usecpc2 && cpc2.isEqual(cpa2,cptol)) usecpc2=false;
      }
      if(usecpb2) {
        if(usecpc2 && cpc2.isEqual(cpb2,cptol)) usecpc2=false;
      }

      if(usecpa) {
        t1.push_back(t1[i]);
        t2.push_back(t2[i]);
        cp1.push_back(cpa);
        cp2.push_back(tri2.a);
      }
      if(usecpb) {
        t1.push_back(t1[i]);
        t2.push_back(t2[i]);
        cp1.push_back(cpb);
        cp2.push_back(tri2.b);
      }
      if(usecpc) {
        t1.push_back(t1[i]);
        t2.push_back(t2[i]);
        cp1.push_back(cpc);
        cp2.push_back(tri2.c);
      }
      if(usecpa2) {
        t1.push_back(t1[i]);
        t2.push_back(t2[i]);
        cp1.push_back(tri1.a);
        cp2.push_back(cpa2);
      }
      if(usecpb2) {
        t1.push_back(t1[i]);
        t2.push_back(t2[i]);
        cp1.push_back(tri1.b);
        cp2.push_back(cpb2);
      }
      if(usecpc2) {
        t1.push_back(t1[i]);
        t2.push_back(t2[i]);
        cp1.push_back(tri1.c);
        cp2.push_back(cpc2);
      }
    }
  }

  //diagnostic pass: warn (throttled) when paired triangles actually intersect.
  //Removing such pairs is currently disabled, so imax stays equal to t1.size().
  imax = t1.size();
  static int warnedCount = 0;
  for(size_t i=0;i<imax;i++) {
    m1.GetTriangle(t1[i],tri1);
    m2.GetTriangle(t2[i],tri2);

    tri1loc.a = T12*tri1.a;
    tri1loc.b = T12*tri1.b;
    tri1loc.c = T12*tri1.c;
    if(tri1loc.intersects(tri2)) {
      if(warnedCount % 1000 == 0) {
        printf("ODECustomMesh: Triangles penetrate margin %g+%g: can't trust contact detector\n",outerMargin1,outerMargin2);
      }
      warnedCount++;
    }
  }
  if(t1.size() != imax) {
    printf("ODECustomMesh: %d candidate points were removed due to mesh collision\n",(int)(t1.size()-imax));
    t1.resize(imax);
    t2.resize(imax);
    cp1.resize(imax);
    cp2.resize(imax);
  }

  int k=0;  //count the # of contact points added
  for(size_t i=0;i<cp1.size();i++) {
    Vector3 p1 = T1*cp1[i];
    Vector3 p2 = T2*cp2[i];
    Vector3 n=p1-p2;
    Real d = n.norm();
    if(d < gNormalFromGeometryTolerance) {  //compute normal from the geometry
      n = ContactNormal(m1,m2,cp1[i],cp2[i],t1[i],t2[i]);
    }
    else if(d > tol) {  //farther apart than the combined margins: not a contact
      continue;
    }
    else n /= d;
    //check for invalid normals
    Real len=n.length();
    if(len < gZeroNormalTolerance || !IsFinite(len)) continue;
    //migrate the contact point to the center of the overlap region
    CopyVector(contact[k].pos,0.5*(p1+p2) + ((outerMargin2 - outerMargin1)*0.5)*n);
    CopyVector(contact[k].normal,n);
    contact[k].depth = tol - d;
    if(contact[k].depth < 0) contact[k].depth = 0;
    k++;
    if(k == maxcontacts) break;
  }
  return k;
}
///Compute normal from mesh geometry: returns the normal direction needed for
///triangle t1 on m1 to get out of triangle t2 on m2.
///p1 and p2 are given in the local coordinates of m1 and m2 respectively.
///The face cases rotate by currentTransform.R, so the result is a world-frame
///direction.  VertexNormal/EdgeNormal are presumably world-frame as well, since
///they are combined with rotated edge vectors (TODO confirm).
///Returns the zero vector (with a throttled warning) for unexpected feature types.
Vector3 ContactNormal(const CollisionMesh& m1,const CollisionMesh& m2,const Vector3& p1,const Vector3& p2,int t1,int t2)
{
  Triangle3D tri1,tri2;
  m1.GetTriangle(t1,tri1);
  m2.GetTriangle(t2,tri2);
  Vector3 b1=tri1.barycentricCoords(p1);
  Vector3 b2=tri2.barycentricCoords(p2);
  int type1=FeatureType(b1),type2=FeatureType(b2);
  switch(type1) {
  case 1:  //pt
    switch(type2) {
    case 1:  //pt
      {
        //point-point: difference of the vertex normals
        Vector3 n1 = VertexNormal(m1,t1,VertexIndex(b1));
        Vector3 n2 = VertexNormal(m2,t2,VertexIndex(b2));
        n2 -= n1;
        n2.inplaceNormalize();
        return n2;
      }
    case 2:  //edge
      {
        //point-edge: edge normal minus the component of the vertex normal
        //orthogonal to the edge
        Vector3 n1 = VertexNormal(m1,t1,VertexIndex(b1));
        int e = EdgeIndex(b2);
        Segment3D s = tri2.edge(e);
        Vector3 ev = m2.currentTransform.R*(s.b-s.a);
        Vector3 n2 = EdgeNormal(m2,t2,e);
        n2-=(n1-ev*ev.dot(n1)/ev.dot(ev)); //project onto normal
        n2.inplaceNormalize();
        return n2;
      }
    case 3:  //face
      return m2.currentTransform.R*tri2.normal();
    }
    break;
  case 2:  //edge
    switch(type2) {
    case 1:  //pt
      {
        //edge-point: mirror of the point-edge case
        Vector3 n2 = VertexNormal(m2,t2,VertexIndex(b2));
        int e = EdgeIndex(b1);
        Segment3D s = tri1.edge(e);
        Vector3 ev = m1.currentTransform.R*(s.b-s.a);
        Vector3 n1 = EdgeNormal(m1,t1,e);
        n2 = (n2-ev*ev.dot(n2)/ev.dot(ev))-n1; //project onto normal
        n2.inplaceNormalize();
        return n2;
      }
    case 2:  //edge
      {
        //edge-edge: normal is perpendicular to both edge directions
        int e1 = EdgeIndex(b1);
        Segment3D s1 = tri1.edge(e1);
        Vector3 ev1 = m1.currentTransform.R*(s1.b-s1.a);
        ev1.inplaceNormalize();
        int e2 = EdgeIndex(b2);
        Segment3D s2 = tri2.edge(e2);
        Vector3 ev2 = m2.currentTransform.R*(s2.b-s2.a);
        ev2.inplaceNormalize();
        Vector3 n;
        n.setCross(ev1,ev2);
        Real len = n.length();
        if(len < gZeroNormalTolerance) {
          //edges are (nearly) parallel: the cross product has no reliable
          //direction, and dividing by len would yield garbage or NaN.  Fall
          //back to the difference of the edge normals, consistent with the
          //point-edge cases above.
          n = EdgeNormal(m2,t2,e2) - EdgeNormal(m1,t1,e1);
          n.inplaceNormalize();
          return n;
        }
        n /= len;
        //make sure the normal direction points into m1 and out of m2
        //NOTE(review): when the edges nearly touch, this separation test is
        //close to zero and may be numerically unreliable.
        if(n.dot(m1.currentTransform*s1.a) < n.dot(m2.currentTransform*s2.a))
          n.inplaceNegative();
        return n;
      }
    case 3:  //face
      return m2.currentTransform.R*tri2.normal();
    }
    break;
  case 3:  //face
    if(type2 == 3)
      printf("ODECustomMesh: Warning, face-face contact?\n");
    return m1.currentTransform.R*(-tri1.normal());
  }
  //unexpected feature types: warn only once every 10000 occurrences
  static int warnedCount = 0;
  if(warnedCount % 10000 == 0)
    printf("ODECustomMesh: Warning, degenerate triangle, types %d %d\n",type1,type2);
  warnedCount++;
  return Vector3(Zero);
}