/*=========================================================================

  Program:   Insight Segmentation & Registration Toolkit
  Module:    itkWarpImageFilterTest.cxx
  Language:  C++
  Date:      $Date$
  Version:   $Revision$

  Copyright (c) Insight Software Consortium. All rights reserved.
  See ITKCopyright.txt or http://www.itk.org/HTML/Copyright.htm for details.

     This software is distributed WITHOUT ANY WARRANTY; without even 
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
#if defined(_MSC_VER)
#pragma warning ( disable : 4786 )
#endif

#include <iostream>

#include "itkVector.h"
#include "itkIndex.h"
#include "itkImage.h"
#include "itkImageRegionIteratorWithIndex.h"
#include "itkWarpImageFilter.h"
#include "itkVectorCastImageFilter.h"
#include "itkStreamingImageFilter.h"
#include "itkCommand.h"
#include "vnl/vnl_math.h"

// ImagePattern evaluates the linear ramp
//     offset + sum_j coeff[j] * index[j]
// over a VDimension-dimensional index. The test uses it both to fill the
// input image and to predict the values expected in the warped output.
template <int VDimension>
class ImagePattern
{
public:
  typedef itk::Index<VDimension>             IndexType;
  typedef typename IndexType::IndexValueType IndexValueType;
  typedef itk::Size<VDimension>              SizeType;
  typedef float                              PixelType;

  // Starts out as the all-zero pattern: no offset, every slope zero.
  ImagePattern()
    : offset( 0.0 )
    {
    for( int d = 0; d < VDimension; ++d )
      {
      coeff[d] = 0.0;
      }
    }

#ifdef ITK_USE_CENTERED_PIXEL_COORDINATES_CONSISTENTLY
  // Pattern value at `index`, emulating border behaviour:
  //  - if any component index[d] >= size[d], the result is padValue;
  //  - a component in [clampSize[d], size[d]) is evaluated as if it were
  //    clampSize[d]-1 (interpolators behave this way in the half-pixel band
  //    at the image perimeter).
  double Evaluate( const IndexType& index, const SizeType& size,
                   const SizeType& clampSize, const PixelType& padValue )
    {
    double result = offset;
    for( int d = 0; d < VDimension; ++d )
      {
      const IndexValueType sizeLimit =
        static_cast<IndexValueType>( size[d] );
      if( index[d] >= sizeLimit )
        {
        // Outside the valid extent along this axis: the whole pixel is padding.
        return padValue;
        }

      const IndexValueType clampLimit =
        static_cast<IndexValueType>( clampSize[d] );
      const double position = ( index[d] >= clampLimit )
        ? static_cast<double>( clampSize[d] - 1 )
        : static_cast<double>( index[d] );
      result += coeff[d] * position;
      }
    return result;
    }
#else
  // Pattern value at `index`; the size, clamp size and pad value arguments
  // are accepted for interface compatibility but not used in this build.
  double Evaluate( const IndexType& index, const SizeType&,
                   const SizeType&, const PixelType& )
    {
    double result = offset;
    for( int d = 0; d < VDimension; ++d )
      {
      result += coeff[d] * static_cast<double>( index[d] );
      }
    return result;
    }
#endif

  // Per-axis slope of the ramp.
  double coeff[VDimension];
  // Constant term of the ramp.
  double offset;

};

// The following class is used to support progress callbacks on the
// filter in the pipeline that follows later (via itk::SimpleMemberCommand).
class ShowProgressObject
{
public:
  // `o` is the process object whose progress is reported; it is stored in
  // m_Process. explicit: construction must be deliberate, never an implicit
  // conversion from a ProcessObject pointer.
  explicit ShowProgressObject(itk::ProcessObject* o)
    : m_Process(o)
    {}

  // Prints the watched object's current progress value to std::cout.
  void ShowProgress()
    {std::cout << "Progress " << m_Process->GetProgress() << std::endl;}

  itk::ProcessObject::Pointer m_Process;
};



int itkWarpImageFilterTest(int, char* [] )
{
  // Regression test for itk::WarpImageFilter.
  //
  // A 64x64 linear ramp is warped with a deformation field whose displacement
  // at index i along axis d is i * (1/factors[d] - 1). The output therefore
  // samples the input at i / factors[d], i.e. an expanded ramp. The output is
  // checked pixel by pixel against the expected ramp (or the edge padding value
  // outside the valid region). The same warp is then rerun through a
  // StreamingImageFilter and compared with the standalone result. Finally the
  // filter is expected to throw when its interpolator is NULL.
  //
  // Returns EXIT_SUCCESS when every check passes, EXIT_FAILURE otherwise.

  typedef float PixelType;
  enum { ImageDimension = 2 };
  typedef itk::Image<PixelType,ImageDimension> ImageType;

  typedef itk::Vector<float,ImageDimension> VectorType;
  typedef itk::Image<VectorType,ImageDimension> FieldType;

  bool testPassed = true;


  //=============================================================

  std::cout << "Create the input image pattern." << std::endl;
  ImageType::RegionType region;
  ImageType::SizeType size = {{64, 64}};
  region.SetSize( size );

  ImageType::Pointer input = ImageType::New();
  input->SetLargestPossibleRegion( region );
  input->SetBufferedRegion( region );
  input->Allocate();

  // Edge padding value given to the warper; output pixels outside the
  // valid region are expected to equal it.
  ImageType::PixelType padValue = 4.0;

  int j;
  ImagePattern<ImageDimension> pattern;

  // Input ramp: 64 + sum of index components.
  pattern.offset = 64;
  for( j = 0; j < ImageDimension; j++ )
    {
    pattern.coeff[j] = 1.0;
    }

  typedef itk::ImageRegionIteratorWithIndex<ImageType> Iterator;
  Iterator inIter( input, region );

  for( ; !inIter.IsAtEnd(); ++inIter )
    {
    inIter.Set( pattern.Evaluate( inIter.GetIndex(), size, size, padValue ) );
    }

  //=============================================================

  std::cout << "Create the input deformation field." << std::endl;

  //Tested with { 2, 4 } and { 2, 5 } as well...
  unsigned int factors[ImageDimension] = { 2, 3 };

  // The field is larger than the expanded image so that part of the
  // output falls outside the input and must be padded.
  ImageType::RegionType fieldRegion;
  ImageType::SizeType fieldSize;
  for( j = 0; j < ImageDimension; j++ )
    {
    fieldSize[j] = size[j] * factors[j] + 5;
    }
  fieldRegion.SetSize( fieldSize );

  FieldType::Pointer field = FieldType::New();
  field->SetLargestPossibleRegion( fieldRegion );
  field->SetBufferedRegion( fieldRegion );
  field->Allocate();

  typedef itk::ImageRegionIteratorWithIndex<FieldType> FieldIterator;
  FieldIterator fieldIter( field, fieldRegion );

  for( ; !fieldIter.IsAtEnd(); ++fieldIter )
    {
    ImageType::IndexType index = fieldIter.GetIndex();
    VectorType displacement;
    for( j = 0; j < ImageDimension; j++ )
      {
      // index + displacement == index / factor along each axis.
      displacement[j] = (float) index[j] * ( (1.0 / factors[j]) - 1.0 );
      }
    fieldIter.Set( displacement );
    }

  //=============================================================

  std::cout << "Run WarpImageFilter in standalone mode with progress.";
  std::cout << std::endl;
  typedef itk::WarpImageFilter<ImageType,ImageType,FieldType> WarperType;
  WarperType::Pointer warper = WarperType::New();

  warper->SetInput( input );
  warper->SetDeformationField( field );
  warper->SetEdgePaddingValue( padValue );

  ShowProgressObject progressWatch(warper);
  itk::SimpleMemberCommand<ShowProgressObject>::Pointer command;
  command = itk::SimpleMemberCommand<ShowProgressObject>::New();
  command->SetCallbackFunction(&progressWatch,
                               &ShowProgressObject::ShowProgress);
  warper->AddObserver(itk::ProgressEvent(), command);

  warper->Print( std::cout );

  // exercise Get methods
  std::cout << "Interpolator: " << warper->GetInterpolator() << std::endl;
  std::cout << "DeformationField: " << warper->GetDeformationField() << std::endl;
  std::cout << "EdgePaddingValue: " << warper->GetEdgePaddingValue() << std::endl;

  // exercise Set methods; the final values (spacing 1, origin 0) are the
  // ones the expected-value computation below relies on.
  itk::FixedArray<double,ImageDimension> array;
  array.Fill( 2.0 );
  warper->SetOutputSpacing( array.GetDataPointer() );
  array.Fill( 1.0 );
  warper->SetOutputSpacing( array.GetDataPointer() );

  array.Fill( -10.0 );
  warper->SetOutputOrigin( array.GetDataPointer() );
  array.Fill( 0.0 );
  warper->SetOutputOrigin( array.GetDataPointer() );

  // Update the filter
  warper->Update();

  //=============================================================

  std::cout << "Checking the output against expected." << std::endl;
  Iterator outIter( warper->GetOutput(),
    warper->GetOutput()->GetBufferedRegion() );

  // compute non-padded output region
  ImageType::RegionType validRegion;
  ImageType::SizeType validSize = validRegion.GetSize();

#ifndef ITK_USE_CENTERED_PIXEL_COORDINATES_CONSISTENTLY
  for( j = 0; j < ImageDimension; j++ )
    {
    validSize[j] = size[j] * factors[j] - (factors[j] - 1);
    }
#else
  //Needed to deal with incompatibility of various IsInside()s &
  //nearest-neighbour type interpolation on half-band at perimeter of
  //image. Evaluate() now has logic for this outer half-band.
  ImageType::SizeType decrementForScaling;
  ImageType::SizeType clampSizeDecrement;
  ImageType::SizeType clampSize;
  for( j = 0; j < ImageDimension; j++ )
    {
    validSize[j] = size[j] * factors[j];

    //Consider as inside anything < 1/2 pixel of (size[j]-1)*factors[j]
    //(0-63) map to (0,126), with 127 exactly at 1/2 pixel, therefore
    //edged out; or to (0,190), with 190 just beyond 189 by 1/3 pixel;
    //or to (0,253), with 254 exactly at 1/2 pixel, therefore out
    //also; or (0, 317), with 317 at 2/5 pixel beyond 315. And so on.

    decrementForScaling[j] =   factors[j] / 2 ;

    validSize[j] -= decrementForScaling[j];

    //This part of logic determines what is inside, but in outer
    //1/2 pixel band, which has to be clamped to that nearest outer
    //pixel scaled by factor: (0,63) maps to (0,190) as inside, but
    //pixel 190 is outside of (0,189), and must be clamped to it.
    //If factor is 2 or less, this decrement has no effect.

    if( factors[j] < 1+decrementForScaling[j])
      {
      clampSizeDecrement[j] = 0;
      }
    else
      {
      clampSizeDecrement[j]  =  (factors[j] - 1 - decrementForScaling[j]) ;
      }
    clampSize[j]= validSize[j] - clampSizeDecrement[j];
    }
#endif

  validRegion.SetSize( validSize );

  // adjust the pattern coefficients to match
  for( j = 0; j < ImageDimension; j++ )
    {
    pattern.coeff[j] /= (double) factors[j];
    }

  for( ; !outIter.IsAtEnd(); ++outIter )
    {
    ImageType::IndexType index = outIter.GetIndex();

    double value = outIter.Get();

    if( validRegion.IsInside( index ) )
      {
#ifdef ITK_USE_CENTERED_PIXEL_COORDINATES_CONSISTENTLY
      double trueValue = pattern.Evaluate( outIter.GetIndex(), validSize, clampSize, padValue );
#else
      double trueValue = pattern.Evaluate( outIter.GetIndex(), validSize, validSize, padValue );
#endif

      if( vnl_math_abs( trueValue - value ) > 1e-4 )
        {
        testPassed = false;
        std::cout << "Error at Index: " << index << " ";
        std::cout << "Expected: " << trueValue << " ";
        std::cout << "Actual: " << value << std::endl;
        }
      }
    else
      {
      if( value != padValue )
        {
        testPassed = false;
        std::cout << "Error at Index: " << index << " ";
        std::cout << "Expected: " << padValue << " ";
        std::cout << "Actual: " << value << std::endl;
        }
      }

    }

  //=============================================================

  std::cout << "Run WarpImageFilter with streamer";
  std::cout << std::endl;

  // Feed the same field through a VectorCastImageFilter so the second
  // warper receives it as the output of an upstream filter.
  typedef itk::VectorCastImageFilter<FieldType,FieldType> VectorCasterType;
  VectorCasterType::Pointer vcaster = VectorCasterType::New();

  vcaster->SetInput( warper->GetDeformationField() );

  WarperType::Pointer warper2 = WarperType::New();

  warper2->SetInput( warper->GetInput() );
  warper2->SetDeformationField( vcaster->GetOutput() );
  warper2->SetEdgePaddingValue( warper->GetEdgePaddingValue() );

  typedef itk::StreamingImageFilter<ImageType,ImageType> StreamerType;
  StreamerType::Pointer streamer = StreamerType::New();
  streamer->SetInput( warper2->GetOutput() );
  streamer->SetNumberOfStreamDivisions( 3 );
  streamer->Update();

  //=============================================================
  std::cout << "Compare standalone and streamed outputs" << std::endl;

  Iterator streamIter( streamer->GetOutput(),
    streamer->GetOutput()->GetBufferedRegion() );

  outIter.GoToBegin();
  streamIter.GoToBegin();

  // Walk both outputs in lockstep; stop as soon as either is exhausted so
  // that a region-size mismatch cannot drive an iterator past its end.
  while( !outIter.IsAtEnd() && !streamIter.IsAtEnd() )
    {
    if( outIter.Get() != streamIter.Get() )
      {
      std::cout << "Error C at Index: " << outIter.GetIndex() << " ";
      std::cout << "Expected: " << outIter.Get() << " ";
      std::cout << "Actual: " << streamIter.Get() << std::endl;
      testPassed = false;
      }
    ++outIter;
    ++streamIter;
    }

  if( !outIter.IsAtEnd() || !streamIter.IsAtEnd() )
    {
    std::cout << "Error: standalone and streamed outputs differ in size." << std::endl;
    testPassed = false;
    }

  if ( !testPassed )
    {
    std::cout << "Test failed." << std::endl;
    return EXIT_FAILURE;
    }

  // Exercise error handling: updating with a NULL interpolator must throw.

  typedef WarperType::InterpolatorType InterpolatorType;
  InterpolatorType::Pointer interp = warper->GetInterpolator();

  try
    {
    std::cout << "Setting interpolator to NULL" << std::endl;
    testPassed = false;
    warper->SetInterpolator( NULL );
    warper->Update();
    }
  catch( itk::ExceptionObject& err )
    {
    std::cout << err << std::endl;
    testPassed = true;
    warper->ResetPipeline();
    warper->SetInterpolator( interp );
    }

  if ( !testPassed )
    {
    std::cout << "Test failed" << std::endl;
    return EXIT_FAILURE;
    }

  std::cout << "Test passed." << std::endl;
  return EXIT_SUCCESS;

}
