typename TImage::Pointer modelBasedImageToImageRegistration(std::string referenceFilename,
																					std::string targetFilename, 
																					typename TStatisticalModelType::Pointer model,
																					std::string outputDfFilename,
																					unsigned numberOfIterations){

	typedef itk::ImageFileReader<TImage> ImageReaderType;
	typedef itk::InterpolatingStatisticalDeformationModelTransform<TRepresenter, double, VImageDimension> TransformType;
	typedef itk::LBFGSOptimizer OptimizerType;
	typedef itk::ImageRegistrationMethod<TImage, TImage> RegistrationFilterType;
	typedef itk::WarpImageFilter< TImage, TImage, TVectorImage > WarperType;
	typedef itk::LinearInterpolateImageFunction< TImage, double > InterpolatorType;
	
	typename ImageReaderType::Pointer referenceReader = ImageReaderType::New();
	referenceReader->SetFileName(referenceFilename.c_str());
	referenceReader->Update();
	typename TImage::Pointer referenceImage = referenceReader->GetOutput();
	referenceImage->Update();

	typename ImageReaderType::Pointer targetReader = ImageReaderType::New();
	targetReader->SetFileName(targetFilename.c_str());
	targetReader->Update();
	typename TImage::Pointer targetImage = targetReader->GetOutput();
	targetImage->Update();

	// do the fitting
	typename TransformType::Pointer transform = TransformType::New();
	transform->SetStatisticalModel(model);
	transform->SetIdentity();

	// Setting up the fitting
	OptimizerType::Pointer optimizer = OptimizerType::New();
	optimizer->MinimizeOn();
	optimizer->SetMaximumNumberOfFunctionEvaluations(numberOfIterations);

	typedef  IterationStatusObserver ObserverType;
	ObserverType::Pointer observer = ObserverType::New();
	optimizer->AddObserver( itk::IterationEvent(), observer );

	typename TMetricType::Pointer metric = TMetricType::New();
	typename InterpolatorType::Pointer interpolator = InterpolatorType::New();

	typename RegistrationFilterType::Pointer registration = RegistrationFilterType::New();
	registration->SetInitialTransformParameters(transform->GetParameters());
	registration->SetMetric(metric);
	registration->SetOptimizer(   optimizer   );
	registration->SetTransform(   transform );
	registration->SetInterpolator( interpolator );
	registration->SetFixedImage( targetImage );
	registration->SetFixedImageRegion(targetImage->GetBufferedRegion() );
	registration->SetMovingImage( referenceImage );

	try {
		std::cout << "Performing registration... " << std::flush;
		registration->Update();
		std::cout << "[done]" << std::endl;

	} catch ( itk::ExceptionObject& o ) {
		std::cout << "caught exception " << o << std::endl;
	}

	typename TVectorImage::Pointer df = model->DrawSample(transform->GetCoefficients());

	// write deformation field
	if(outputDfFilename.size()>0){
		typename itk::ImageFileWriter<TVectorImage>::Pointer df_writer = itk::ImageFileWriter<TVectorImage>::New();
		df_writer->SetFileName(outputDfFilename);
		df_writer->SetInput(df);
		df_writer->Update();
	}

	
	// warp reference
	std::cout << "Warping reference... " << std::flush;
	typename WarperType::Pointer warper = WarperType::New();
	warper->SetInput(referenceImage  );
	warper->SetInterpolator( interpolator );
	warper->SetOutputSpacing( targetImage->GetSpacing() );
	warper->SetOutputOrigin( targetImage->GetOrigin() );
	warper->SetOutputDirection( targetImage->GetDirection() );
	warper->SetDisplacementField( df );
	warper->Update();
	std::cout << "[done]" << std::endl;

	return warper->GetOutput();
}
Ejemplo n.º 2
0
int main(int argc, char **argv) {
    
    std::cout<<"Usage: GetHOG3D image.vtk mesh.vtp label_array_name output_hog.csv output_label.csv radius"<<std::endl;
    
//    const char* inputImageFileName = "/home/costa/Dropbox/code/HOG3D/bin/01D-LGE-320.vtk";
//    const char* inputMeshFileName = "/home/costa/Dropbox/code/HOG3D/bin/01D_320_autolabels.vtp";
//    const char* outputHogFileName = "/home/costa/Dropbox/code/HOG3D/bin/01D_320_hog.csv";
//    const char* outputLabelsFileName = "/home/costa/Dropbox/code/HOG3D/bin/01D_320_lbl.csv";
//    const char* labelArrayName = "autolabels";

    
    

    if(argc<6) return EXIT_FAILURE;
    
    int c=1;
    const char* inputImageFileName = argv[c++];
    const char* inputMeshFileName =  argv[c++];
    const char* labelArrayName =  argv[c++];
    const char* outputHogFileName =  argv[c++];
    const char* outputLabelsFileName =  argv[c++];
    float radius = atof(argv[c++]);
    
    

    
   
    const int ngrad_bins = 8; //number of gradient orientation bins
    
    //define some region of interest around the point
    arma::vec radii;
    radii << radius; //mm


    //----------------------------------------------------
    //
    //
    //  read the image
    //
    //
    typedef unsigned int PixelType;
    typedef itk::Image< PixelType, 3 > ImageType;
    typedef itk::ImageFileReader< ImageType > ImageReaderType;
    typename ImageReaderType::Pointer imageReader = ImageReaderType::New();
    imageReader->SetFileName(inputImageFileName);

    try {
        imageReader->Update();
    } catch (itk::ExceptionObject& e) {
        std::cerr << e.what() << std::endl;
        return EXIT_FAILURE;
    }


    //
    //
    //----------------------------------------------------
    
    
    //----------------------------------------------------
    //
    //   read the mesh
    //
    //
    //    
    vtkSmartPointer<vtkXMLPolyDataReader> rd = vtkSmartPointer<vtkXMLPolyDataReader>::New();
    rd->SetFileName(inputMeshFileName);
    rd->Update();
    vtkPolyData* mesh = rd->GetOutput();
    
    //calculate normals
    vtkSmartPointer<vtkPolyDataNormals> ng = vtkSmartPointer<vtkPolyDataNormals>::New();
    ng->SetInputData(mesh);
    ng->SplittingOff();
    ng->Update();
    
    vtkDataArray* pnormals = ng->GetOutput()->GetPointData()->GetNormals();
    //
    //
    //----------------------------------------------------
        



    //----------------------------------------------------
    //
    //   save the labels
    //
    //
    //    
    std::cout<<"Saving the labels"<<std::endl;
    vtkDataArray* mesh_labels = mesh->GetPointData()->GetArray(labelArrayName);
    arma::uvec labels(mesh_labels->GetNumberOfTuples());
    for(int i=0; i<mesh_labels->GetNumberOfTuples(); i++)
    {
        labels(i) = mesh_labels->GetTuple1(i);
    }
    
    labels.save(outputLabelsFileName, arma::raw_ascii);
    

    //
    //
    //----------------------------------------------------
 
    

    //----------------------------------------------------
    //
    //   for each point of the mesh, take normal and
    //   for each scale (radius+sigma) get the HOG
    //
    //
    //
    
    //create the matrix to store everything
    //every row - HOGS of one point
    
    
    
    std::cout<<"Calculating HOGs"<<std::endl;
    
    //sample something to know the size of the neighborhood
    double pt[3];
    mesh->GetPoint(0, pt);    
    arma::ivec samples;
    SampleGradHistogram<ImageType>(pt, radii[0], imageReader->GetOutput(), samples);
    
    
    arma::imat sample_matrix(mesh->GetNumberOfPoints(), samples.size());
    const int nsamples = samples.size();
    //
    
    
 
    for(int j=0; j<radii.size(); j++)
    {

        std::cout<<"Radius "<<radii[j]<<std::endl;

   
    
        for(int i=0; i<mesh->GetNumberOfPoints(); i++)
         {
            if(i%1000==0)
                std::cout<<"Point "<<i<<"/"<<mesh->GetNumberOfPoints()<<"\r"<<std::flush;
        
            double pt[3];
            mesh->GetPoint(i, pt);



            SampleGradHistogram<ImageType>(pt, radii[j], imageReader->GetOutput(), samples);

            if(nsamples!=samples.size())
                std::cout<<std::endl<<"Nonconstant number of samples (vertex "<<i<<"): "<<nsamples<<" vs "<<samples.size()<<std::endl<<std::flush;
            
            //copy the histogram into the matrix
            for(int k=0; k<samples.size(); k++)
            {
                sample_matrix(i,k) = samples(k);
            }
        }
        
        std::cout<<std::endl;
    }
    //
    //
    //----------------------------------------------------
    
    std::cout<<std::endl<<"Saving"<<std::endl;
             
    //save the matrix
    sample_matrix.save(outputHogFileName, arma::arma_binary);
    
    //
    
    return EXIT_SUCCESS;
}
Ejemplo n.º 3
0
int main(int argc, char **argv) {
    
    std::cout<<"Usage: SampleLocalStatistics image.vtk mesh.vtp label_array_name output_features.csv output_label.csv"<<std::endl;
    
//    const char* inputImageFileName = "/home/costa/Dropbox/code/HOG3D/bin/01D-LGE-320.vtk";
//    const char* inputMeshFileName = "/home/costa/Dropbox/code/HOG3D/bin/01D_320_autolabels.vtp";
//    const char* outputHogFileName = "/home/costa/Dropbox/code/HOG3D/bin/01D_320_hog.csv";
//    const char* outputLabelsFileName = "/home/costa/Dropbox/code/HOG3D/bin/01D_320_lbl.csv";
//    const char* labelArrayName = "autolabels";

    
    

    if(argc<6) return EXIT_FAILURE;
    
    int c=1;
    const char* inputImageFileName = argv[c++];
    const char* inputMeshFileName =  argv[c++];
    const char* labelArrayName =  argv[c++];
    const char* outputFeaturesFileName =  argv[c++];
    const char* outputLabelsFileName =  argv[c++];
    //float radius = atof(argv[c++]);
    
    

    
   
    
    //define some region of interest around the point
    arma::vec radii; //does not support more than 1 value for now
    radii << 0.3 << 0.7 <<1.0 <<1.6<<3.5<<5.0; //mm


    //----------------------------------------------------
    //
    //
    //  read the image
    //
    //
    typedef unsigned int PixelType;
    typedef itk::Image< PixelType, 3 > ImageType;
    typedef itk::ImageFileReader< ImageType > ImageReaderType;
    typename ImageReaderType::Pointer imageReader = ImageReaderType::New();
    imageReader->SetFileName(inputImageFileName);

    try {
        imageReader->Update();
    } catch (itk::ExceptionObject& e) {
        std::cerr << e.what() << std::endl;
        return EXIT_FAILURE;
    }


    //
    //
    //----------------------------------------------------
    
    
    //----------------------------------------------------
    //
    //   read the mesh
    //
    //
    //    
    vtkSmartPointer<vtkXMLPolyDataReader> rd = vtkSmartPointer<vtkXMLPolyDataReader>::New();
    rd->SetFileName(inputMeshFileName);
    rd->Update();
    vtkPolyData* mesh = rd->GetOutput();
    
    
    //
    //
    //----------------------------------------------------
        



    //----------------------------------------------------
    //
    //   save the labels
    //
    //
    //    
    std::cout<<"Saving the labels"<<std::endl;
    vtkDataArray* mesh_labels = mesh->GetPointData()->GetArray(labelArrayName);
    arma::uvec labels(mesh_labels->GetNumberOfTuples());
    for(int i=0; i<mesh_labels->GetNumberOfTuples(); i++)
    {
        labels(i) = mesh_labels->GetTuple1(i);
    }
    
    labels.save(outputLabelsFileName, arma::raw_ascii);
    

    //
    //
    //----------------------------------------------------
 
    

    //----------------------------------------------------
    //
    //   for each point of the mesh, take normal and
    //   for each scale (radius+sigma) get the HOG
    //
    //
    //
    
    //create the matrix to store everything
    //every row - HOGS of one point
    
    
    
    std::cout<<"Calculating Features"<<std::endl;
    
    //sample something to know the size of the neighborhood
    double pt[3];
    mesh->GetPoint(0, pt);    
    arma::vec features;
    
    
    const int nfeatures = 4; //4 statistical features
    arma::mat feature_matrix(mesh->GetNumberOfPoints(), nfeatures*radii.size()); 
    //
    
    
 

   
    for(int j=0; j<radii.size(); j++)
    {

        std::cout<<"Radius "<<radii[j]<<std::endl;
    
        for(int i=0; i<mesh->GetNumberOfPoints(); i++)
        {
            if(i%1000==0)
                std::cout<<"Point "<<i<<"/"<<mesh->GetNumberOfPoints()<<"\r"<<std::flush;


        
            double pt[3];
            mesh->GetPoint(i, pt);



            SampleStatistics<ImageType>(pt, radii[j], imageReader->GetOutput(), features);

            if(nfeatures!=features.size())
                std::cout<<std::endl<<"Nonconstant number of samples (vertex "<<i<<"): "<<nfeatures<<" vs "<<features.size()<<std::endl<<std::flush;
            
            //copy the histogram into the matrix
            for(int k=0; k<features.size(); k++)
            {
                feature_matrix(i,k+j*nfeatures) = features(k);
            }
        }
        
        std::cout<<std::endl;
    }
    //
    //
    //----------------------------------------------------
    
    std::cout<<std::endl<<"Saving"<<std::endl;
             
    //save the matrix
    feature_matrix.save(outputFeaturesFileName, arma::csv_ascii);
    
    //
    
    return EXIT_SUCCESS;
}