// Raycast against a capsule geometry.
//
// Builds the posed capsule (segment + radius) from 'pose' and the capsule geometry, intersects the ray with it,
// and on a hit at a distance in [0, maxDist] writes a single hit into hits[0].
//
// Returns 1 if a hit was written, 0 otherwise. Only hits[0] is ever written.
// Hit fields written:
// - position / distance: always computed; the output flags always contain eDISTANCE|ePOSITION.
// - faceIndex: 0xffffffff; u / v: 0.0f.
// - normal: computed only when PxHitFlag::eNORMAL is requested in 'hitFlags', zero vector otherwise.
//   When the hit is at t==0 (e.g. the ray starts touching/inside the capsule) the normal is -rayDir.
PxU32 raycast_capsule(GU_RAY_FUNC_PARAMS)
{
    PX_UNUSED(maxHits);	// maxHits is otherwise only referenced by the assert below
    PX_ASSERT(geom.getType() == PxGeometryType::eCAPSULE);
    PX_ASSERT(maxHits && hits);

    const PxCapsuleGeometry& capsuleGeom = static_cast<const PxCapsuleGeometry&>(geom);

    // TODO: PT: could we simplify this ?
    Capsule capsule;
    getCapsuleSegment(pose, capsuleGeom, capsule);
    capsule.radius = capsuleGeom.radius;

    // Initialized so 't' never holds an indeterminate value, regardless of what intersectRayCapsule writes.
    PxReal t = 0.0f;
    if(!intersectRayCapsule(rayOrigin, rayDir, capsule, t))
        return 0;

    // A ray is a half-line: reject intersections behind the origin (t<0) as well as those beyond maxDist.
    if(t<0.0f || t>maxDist)
        return 0;

    // PT: we can't avoid computing the position here since it's needed to compute the normal anyway
    hits->position	= rayOrigin + rayDir*t;	// PT: will be rayOrigin for t=0.0f (i.e. what the spec wants)
    hits->distance	= t;
    hits->faceIndex	= 0xffffffff;
    hits->u			= 0.0f;
    hits->v			= 0.0f;

    // Compute additional information if needed
    PxHitFlags outFlags = PxHitFlag::eDISTANCE|PxHitFlag::ePOSITION;
    if(hitFlags & PxHitFlag::eNORMAL)
    {
        outFlags |= PxHitFlag::eNORMAL;

        if(t==0.0f)
        {
            // Hit at the ray origin: there is no meaningful surface point to derive a normal from,
            // so report the direction opposing the ray.
            hits->normal = -rayDir;
        }
        else
        {
            // The surface normal of a capsule at a point is the direction from the closest point on its
            // axis segment to that point.
            PxReal capsuleT = 0.0f;
            distancePointSegmentSquared(capsule, hits->position, &capsuleT);
            PxVec3 closestOnAxis;
            capsule.computePoint(closestOnAxis, capsuleT);
            hits->normal = hits->position - closestOnAxis;	// expected to have a magnitude of about the capsule radius, never zero
            hits->normal.normalize();
        }
    }
    else
    {
        hits->normal = PxVec3(0.0f);
    }
    hits->flags = outFlags;

    return 1;
}