示例#1
0
/*
 * runBSplineTransform()
 *
 * Warps the points XI with a multi-level B-spline displacement field
 * fitted to the landmark correspondences X -> Y, and writes the warped
 * points to the Matlab output YI.
 *
 * matlabImport: import filter with the registered inputs "X" (source
 *               points), "Y" (target points), "XI" (points to be
 *               warped), "ORDER" (spline order, default 3) and "LEVELS"
 *               (number of levels, default 5).
 * matlabExport: export filter where the output "YI" is registered and
 *               allocated. YI has the same dimensions as XI.
 *
 * X, Y and XI are indexed as column-major matrices with one point per
 * row and Dimension columns.
 *
 * NOTE(review): TScalarType and Dimension are not declared in this
 * block; they are presumably template parameters declared just above
 * the function -- confirm. This function assumes that the caller has
 * already checked that X and Y have the same size, that all three
 * matrices have at least Dimension columns, and that their Matlab class
 * matches TScalarType; none of that is validated here.
 *
 * Errors: calls mexErrMsgTxt() if the data pointer of X, Y or XI is
 * NULL.
 */
void runBSplineTransform(MatlabImportFilter::Pointer matlabImport,
			 MatlabExportFilter::Pointer matlabExport) {

  // retrieve pointers to the inputs that we are going to need here
  typedef MatlabImportFilter::MatlabInputPointer MatlabInputPointer; 
  MatlabInputPointer inX         = matlabImport->GetRegisteredInput("X");
  MatlabInputPointer inY         = matlabImport->GetRegisteredInput("Y");
  MatlabInputPointer inXI        = matlabImport->GetRegisteredInput("XI");
  MatlabInputPointer inORDER     = matlabImport->GetRegisteredInput("ORDER");
  MatlabInputPointer inLEVELS    = matlabImport->GetRegisteredInput("LEVELS");

  // register the output for this function at the export filter
  typedef MatlabExportFilter::MatlabOutputPointer MatlabOutputPointer;
  MatlabOutputPointer outYI = matlabExport->RegisterOutput(OUT_YI, "YI");

  // spline order (input argument): default or user-provided
  unsigned int splineOrder = matlabImport->ReadScalarFromMatlab<unsigned int>(inORDER, 3);

  // number of levels (input argument): default or user-provided
  unsigned int numOfLevels = matlabImport->ReadScalarFromMatlab<unsigned int>(inLEVELS, 5);

  // get size of input arguments
  mwSize Mx = mxGetM(inX->pm); // number of source points
  mwSize Mxi = mxGetM(inXI->pm); // number of points to be warped

  // pointers to input matrices
  TScalarType *x 
    = (TScalarType *)mxGetData(inX->pm); // source points
  TScalarType *y 
    = (TScalarType *)mxGetData(inY->pm); // target points
  TScalarType *xi 
    = (TScalarType *)mxGetData(inXI->pm); // points to be warped
  if (x == NULL) {
    mexErrMsgTxt("Cannot get a pointer to input X");
  }
  if (y == NULL) {
    mexErrMsgTxt("Cannot get a pointer to input Y");
  }
  if (xi == NULL) {
    mexErrMsgTxt("Cannot get a pointer to input XI");
  }

  // type definitions for the BSPline transform
  typedef itk::Vector<TScalarType, Dimension> DataType;
  typedef itk::PointSet<DataType, Dimension> PointSetType;
  typedef itk::Image<DataType, Dimension> ImageType;
  typedef typename 
    itk::BSplineScatteredDataPointSetToImageFilter<PointSetType, 
						   ImageType> TransformType;

  // variables to store the input points
  typename PointSetType::Pointer pointSet = PointSetType::New();
  typename PointSetType::PointType xParam;
  DataType v; // v = y-x, i.e. displacement vector between source and
              // target landmark

  // init variables to contain the limits of a bounding box that
  // contains all the points
  //
  // the upper corner has to start at the most negative representable
  // value. Note that for floating point types numeric_limits::min()
  // is the smallest *positive* value, so using it here would give a
  // wrong (too large) bounding box whenever all the coordinates along
  // an axis are negative. We use -max() instead
  typename ImageType::PointType orig, term;
  for (mwSize col=0; col < (mwSize)Dimension; ++col) {
    orig[CAST2MWSIZE(col)] = std::numeric_limits<TScalarType>::max();
    term[CAST2MWSIZE(col)] = -std::numeric_limits<TScalarType>::max();
  }

  // find bounding box limits: the box has to contain the source
  // points, the target points...
  for (mwSize row=0; row < Mx; ++row) {
    for (mwSize col=0; col < (mwSize)Dimension; ++col) {
      // Matlab matrices are stored column-major
      const mwSize idx = Mx * col + row;
      orig[CAST2MWSIZE(col)] = std::min((TScalarType)orig[CAST2MWSIZE(col)], x[idx]);
      term[CAST2MWSIZE(col)] = std::max((TScalarType)term[CAST2MWSIZE(col)], x[idx]);
      orig[CAST2MWSIZE(col)] = std::min((TScalarType)orig[CAST2MWSIZE(col)], y[idx]);
      term[CAST2MWSIZE(col)] = std::max((TScalarType)term[CAST2MWSIZE(col)], y[idx]);
    }
  }
  // ... and the points that will be warped
  for (mwSize row=0; row < Mxi; ++row) {
    for (mwSize col=0; col < (mwSize)Dimension; ++col) {
      const mwSize idx = Mxi * col + row;
      orig[CAST2MWSIZE(col)] = std::min((TScalarType)orig[CAST2MWSIZE(col)], xi[idx]);
      term[CAST2MWSIZE(col)] = std::max((TScalarType)term[CAST2MWSIZE(col)], xi[idx]);
    }
  }

  // compute length of each side of the bounding box, and keep the
  // longest one
  //
  // lenmax is deliberately initialised to numeric_limits::min(): for
  // floating point types this is a tiny positive value, so if the
  // bounding box is degenerate (all points coincide) we divide by a
  // positive number below instead of computing 0/0
  DataType len = term - orig;
  TScalarType lenmax = std::numeric_limits<TScalarType>::min();
  for (mwSize col=0; col < (mwSize)Dimension; ++col) {
    lenmax = std::max(lenmax, len[CAST2MWSIZE(col)]);
  }

  // duplicate the input x and y matrices to PointSet format so that
  // we can pass it to the ITK function
  //
  // we also translate and scale all points so that the bounding box
  // fits within the domain [0, 1] x [0, 1] x [0,1]. We need to do
  // this because the BSpline function requires the parametric domain
  // to be within [0, 1] x [0, 1] x [0,1]
  //
  // the same scaling factor (lenmax) is used for all the axes and for
  // the displacements, so the scaling is isotropic and can be undone
  // by multiplying the sampled displacement by lenmax
  for (mwSize row=0; row < Mx; ++row) {
    for (mwSize col=0; col < (mwSize)Dimension; ++col) {
      const mwSize idx = Mx * col + row;
      v[CAST2MWSIZE(col)] = (y[idx] - x[idx]) / lenmax;
      xParam[CAST2MWSIZE(col)] = (x[idx] - orig[CAST2MWSIZE(col)]) / lenmax;
    }
    pointSet->SetPoint(row, xParam);
    pointSet->SetPointData(row, v);
  }

  // instantiate and set-up transform
  typename TransformType::Pointer transform = TransformType::New();
  transform->SetGenerateOutputImage(false);
  transform->SetInput(pointSet);
  transform->SetSplineOrder(splineOrder);
  typename TransformType::ArrayType ncps ;
  ncps.Fill(splineOrder + 1);
  transform->SetNumberOfControlPoints(ncps);
  transform->SetNumberOfLevels(numOfLevels);

  // note that closedim, spacing, sz and orig are all referred to the
  // parametric domain, i.e. the domain of x and xi
  typename TransformType::ArrayType closedim;
  typename ImageType::SpacingType spacing;
  typename ImageType::SizeType sz;

  // the parametric domain is not periodic in any dimension
  closedim.Fill(0);

  // as we are not creating the image, we don't need to provide a
  // sensible number of voxels. But size has to be at least 2 voxels
  // to avoid a run-time error
  sz.Fill(2);

  // because the parameterization is in [0, 1] x [0, 1] x [0,1], and
  // we have only size = 2 voxels in every dimension, the spacing will
  // be 1.0 / (2 - 1) = 1.0
  spacing.Fill(1.0);

  // because of the reparameterization, the origin we have to pass to
  // the transform is not the origin of the real points, but the
  // origin of the [0, 1] x [0, 1] x [0,1] bounding box
  typename ImageType::PointType origZero;
  origZero.Fill(0.0);
  
  transform->SetCloseDimension(closedim);
  transform->SetSize(sz);
  transform->SetSpacing(spacing);
  transform->SetOrigin(origZero);

  // run transform
  transform->Update();

  // create output vector and pointer to populate it. The output has
  // the same dimensions as the input XI
  mwSize ndimxi = mxGetNumberOfDimensions(inXI->pm); 
  const mwSize *dimsxi = mxGetDimensions(inXI->pm);
  std::vector<mwSize> size(dimsxi, dimsxi + ndimxi);
  TScalarType *yi 
    = matlabExport->AllocateNDArrayInMatlab<TScalarType>(outYI, size);

  // from ITK v4.x, we need to instantiate a function to evaluate
  // points of the B-spline, as the Evaluate() method has been removed
  // from the TransformType
#if ITK_VERSION_MAJOR>=4
  // Note: in the following, we have to use TCoordRep=double, because
  // ITK gives a compilation error of an abstract class not having
  // been implemented. Otherwise, we would use
  // TCoordRep=TScalar=float, as in the rest of this program
  typedef typename 
    itk::BSplineControlPointImageFunction<ImageType, double> EvalFunctionType;
  typename EvalFunctionType::Pointer function = EvalFunctionType::New();

  function->SetSplineOrder(splineOrder);
  function->SetOrigin(origZero);
  function->SetSpacing(spacing);
  function->SetSize(sz);
  function->SetInputImage(transform->GetPhiLattice());
#endif

  // sample the warp field
  DataType vi; // warp field sample
  typename PointSetType::PointType xiParam; // sampling coordinates
  for (mwSize row=0; row < Mxi; ++row) {
    // map the point to the parametric domain with the same
    // translation and scaling that was applied to the landmarks
    for (mwSize col=0; col < (mwSize)Dimension; ++col) {
      xiParam[CAST2MWSIZE(col)] = (xi[Mxi * col + row] - orig[CAST2MWSIZE(col)]) / lenmax;
    }
#if ITK_VERSION_MAJOR<4
    transform->Evaluate(xiParam, vi);
#else
    vi = function->Evaluate(xiParam);
#endif
    // undo the scaling of the displacement and add it to the original
    // point to obtain the warped point
    for (mwSize col=0; col < (mwSize)Dimension; ++col) {
      yi[Mxi * col + row] = xi[Mxi * col + row] + vi[CAST2MWSIZE(col)] * lenmax;
    }
  }

  // exit function
  return;

}