!---------------------------------------------------------------------
!
! plait.f90:
! Elaborated on original ideas from Ch. Nieten 
!   see https://articles.adsabs.harvard.edu/pdf/2000ASPC..217...72N
! and the NOD2 package, namely the "Basket Weaving" technique described
! in Emerson & Graeve 1988
!   see https://articles.adsabs.harvard.edu/pdf/1988A%26A...190..353E.
!
! P. Hily-Blant feb-2005
! Documented and Cleaned    Stéphane Guilloteau May-2023
!---------------------------------------------------------------------
!
module plait_common
  !---------------------------------------------------------------------
  ! Constants shared by the internal procedures of PLAIT_COMMAND
  !---------------------------------------------------------------------
  implicit none
  ! Real kind numbers used throughout (single and double precision)
  integer(kind=4),  parameter :: sp = 4
  integer(kind=4),  parameter :: dp = 8
  ! Routine name used as prefix in messages
  character(len=*), parameter :: rname = 'PLAIT'
  ! Lower floor of the Fourier-plane weighting function (see PLAIT_WFUNC)
  real(kind=sp),    parameter :: wfunc_thres = 1.e-3_sp
end module plait_common
!
subroutine plait_command(line,comm,error)
  use image_def
  use gkernel_interfaces
  use gbl_message
  use imager_interfaces, only : map_message
  !---------------------------------------------------------------------
  ! Support routine for command
  !   PLAIT Output In1 Angle1 Length1 In2 Angle2 Length2 [...]
  !
  ! Combine 2 to MMAP (=5) input cubes of identical dimensions, each one
  ! given with an angle (degrees) and a length, into the cube Output,
  ! by weighted averaging in the Fourier plane (see PLAIT_ROUTINE).
  !
  ! The SIC logical variable DUMP, when it exists and is .true., makes
  ! the Fourier-plane weighting images be written as 'weight-<i>'.
  !---------------------------------------------------------------------
  character(len=*), intent(in)    :: line   ! Command line
  character(len=*), intent(in)    :: comm   ! Command name, for messages
  logical,          intent(inout) :: error  ! Error flag
  !
  integer(kind=4), parameter :: mmap=5      ! Maximum number of input cubes
  type(gildas) :: cin(mmap)
  type(gildas) :: cout
  real :: thres(2)
  real :: theta(mmap)
  real :: length(mmap)
  logical :: dump
  !
  integer :: i, nmap, narg, n
  !
  ! Arguments are Output followed by triplets (In, Angle, Length)
  narg = sic_narg(0)
  if (mod(narg,3).ne.1) then
    call map_message(seve%e,comm,'Invalid number of arguments')
    error = .true.
    return
  endif
  if (narg.lt.7) then
    call map_message(seve%e,comm,'Insufficient number of arguments')
    error = .true.
    return
  endif
  if (narg.gt.3*mmap+1) then
    call map_message(seve%e,comm,'Too many arguments')
    error = .true.
    return
  endif
  !
  nmap = (narg-1)/3
  !
  call gildas_null(cout)
  !
  call sic_ch(line,0,1,cout%file,n,.true.,error)
  if (error) return
  !
  do i=1,nmap
    call gildas_null(cin(i))
    call sic_ch(line,0,3*i-1,cin(i)%file,n,.true.,error)
    if (error) return
    call sic_r4(line,0,3*i,theta(i),.true.,error)
    if (error) return
    call sic_r4(line,0,3*i+1,length(i),.true.,error)
    if (error) return
  enddo
  !
  ! Both thresholds 0 disables the pixel cutoff (see PLAIT_CUTOFF)
  thres = 0.0 ! No cutoff ?
  !
  ! DUMP is an optional debugging variable: if it cannot be read,
  ! do not dump rather than propagating the error to PLAIT_ROUTINE.
  dump = .false.
  call sic_get_logi('DUMP',dump,error)
  if (error) then
    dump = .false.
    error = .false.
  endif
  call plait_routine(nmap,cin,cout,theta,length,thres,dump,error)
contains
!
subroutine plait_check(next,prev,error)
  use plait_common
  use image_def
  implicit none
  !---------------------------------------------------------------------
  ! CHECK INPUT CUBES HAVE IDENTICAL DIMENSIONS
  !   ERROR is .true. if the number of dimensions or any element of
  !   %gil%dim differs between NEXT and PREV, .false. otherwise.
  !---------------------------------------------------------------------
  type(gildas), intent(in)  :: prev,next
  logical,      intent(out) :: error
  !
  error = (next%gil%ndim.ne.prev%gil%ndim)
  if (error) return
  error = .not.all(next%gil%dim.eq.prev%gil%dim)
end subroutine plait_check
!
pure function plait_wfunc(r,width)
  use plait_common
  use phys_const
  implicit none
  !---------------------------------------------------------------------
  ! WEIGHTING FUNCTION
  !   1                                       if WIDTH = 0 or |R*WIDTH| >= 1
  !   max(WFUNC_THRES, sin(pi/2*R*WIDTH)**2)  otherwise
  ! The result is therefore never smaller than WFUNC_THRES (> 0).
  !---------------------------------------------------------------------
  real(kind=sp) :: plait_wfunc
  real(kind=sp), intent(in) :: r      ! Current distance in Fourier plane
  real(kind=sp), intent(in) :: width  ! Width up to which to down-scale weight
  ! Local
  real(kind=sp) :: zsc
  !
  plait_wfunc = 1.0_sp
  if (width.eq.0.0) return
  zsc = r*width
  if (abs(zsc).ge.1.0_sp) return
  zsc = zsc*pi*0.5_sp
  plait_wfunc = max(wfunc_thres,(sin(zsc))**2)
end function plait_wfunc
!
subroutine plait_weight(w,theta,scale,head)
  use plait_common
  use image_def
  use phys_const
  implicit none
  !---------------------------------------------------------------------
  ! COMPUTES WEIGHTING IMAGE in the Fourier plane
  !   Pixel (i,j) of the FFT has spatial frequency
  !     (u_i,v_j) = (k_i/(n*xinc), l_j/(m*yinc))
  !   where k_i, l_j are the signed FFT frequency indices (0,1,...,
  !   then negative values for the upper half of the axis).
  !   W(i,j) = PLAIT_WFUNC(u_i*cos(theta) + v_j*sin(theta), SCALE),
  !   i.e. the weight depends only on the projection of the spatial
  !   frequency onto the unit vector (cos(theta),sin(theta)).
  !   SCALE multiplies a frequency in units of 1/increment, so it must
  !   be in the unit of the header increments (NOTE: presumably radians
  !   for GILDAS spatial axes - confirm).
  !---------------------------------------------------------------------
  real(kind=4),  intent(inout) :: w(:,:)  ! Weights, (dim(1),dim(2)) of HEAD
  real(kind=sp), intent(in)    :: theta   ! Scanning angle (degrees)
  real(kind=sp), intent(in)    :: scale   ! Correlation Scale
  type(gildas),  intent(in)    :: head    ! Data header
  ! Local
  integer(kind=4) :: i,j,k,m,n
  real(kind=sp) :: phi,cphi,sphi
  real(kind=sp) :: pu(head%gil%dim(1)),pv(head%gil%dim(2))
  real(kind=sp) :: du,dv,xinc,yinc
  !
  n = head%gil%dim(1)
  m = head%gil%dim(2)
  xinc = head%gil%convert(3,1)
  yinc = head%gil%convert(3,2)
  ! orientation of weighting in FT space
  phi  = theta * pi / 180._sp
  cphi = cos(phi)
  sphi = sin(phi)
  ! frequency step along each axis
  du = 1.0_sp / (xinc*n)
  dv = 1.0_sp / (yinc*m)
  !
  ! Signed frequency index: 0..(n+1)/2-1, then -(n/2)..-1.
  ! This is correct for both even and odd axis lengths (including 1).
  do i=1,n
    k = i-1
    if (k.ge.(n+1)/2) k = k-n
    pu(i) = real(k,kind=sp) * du * cphi
  enddo
  do j=1,m
    k = j-1
    if (k.ge.(m+1)/2) k = k-m
    pv(j) = real(k,kind=sp) * dv * sphi
  enddo
  !
  ! pu(i)+pv(j) is the projection of (u_i,v_j) onto (cphi,sphi)
  do j=1,m
    do i=1,n
      w(i,j) = plait_wfunc(pu(i)+pv(j),scale)
    enddo
  enddo
end subroutine plait_weight
!
subroutine plait_cutoff(a,lower,upper,bval)
  use plait_common
  implicit none
  !---------------------------------------------------------------------
  ! BLANK PIXELS OUTSIDE [LOWER:UPPER] INTERVAL
  !   Nothing is done when LOWER and UPPER are both 0.
  !---------------------------------------------------------------------
  real(kind=sp), intent(in)    :: lower,upper,bval
  real(kind=sp), intent(inout) :: a(:,:,:)
  !
  if ((lower.eq.0._sp).and.(upper.eq.0._sp)) return
  where (a.lt.lower.or.a.gt.upper) a = bval
  !
end subroutine plait_cutoff
!
subroutine plait_replace(a,b,bval,eval)
  use plait_common
  implicit none
  !---------------------------------------------------------------------
  ! WHERE A IS BLANKED REPLACE WITH CORRESPONDING B VALUE
  !   A pixel is blanked when |A-BVAL| <= EVAL (never when EVAL < 0).
  !---------------------------------------------------------------------
  real(kind=4),  intent(in)    :: b(:,:,:)
  real(kind=4),  intent(inout) :: a(:,:,:)
  real(kind=sp), intent(in)    :: bval,eval
  !
  where (abs(a-bval).le.eval) a = b
  !
end subroutine plait_replace
!
subroutine plait_norm(a,w)
  use plait_common
  implicit none
  !---------------------------------------------------------------------
  ! NORMALIZE THE OUTPUT OF PLAITED FFTs
  !   Divide every plane of the accumulated Fourier values A by the
  !   accumulated weights W, wherever W is non-zero. A holds Fourier
  !   coefficients, so no image blanking test applies here.
  !---------------------------------------------------------------------
  complex(kind=4), intent(inout) :: a(:,:,:)    ! Data
  real(kind=sp),   intent(in)    :: w(:,:)      ! Weight array
  ! Local
  integer(kind=4) :: k
  !
  do k=1,size(a,3)
    where (w.ne.0._sp) a(:,:,k) = a(:,:,k) / w
  enddo
end subroutine plait_norm
!
subroutine plait_add(aa,ww,head,a1)
  use plait_common
  use image_def
  implicit none
  !---------------------------------------------------------------------
  ! ACCUMULATE A DATA CUBE TAKING INTO ACCOUNT BLANKING VALUE
  !   Non-blanked pixels of A1 (|A1-bval| > eval, with the blanking of
  !   HEAD) are added to AA and counted in WW. Blanked pixels (and NaNs)
  !   are not added at all.
  !---------------------------------------------------------------------
  real(kind=4),    intent(inout) :: aa(:,:,:)
  integer(kind=4), intent(inout) :: ww(:,:,:)
  type(gildas),    intent(in)    :: head
  real(kind=4),    intent(in)    :: a1(:,:,:)
  ! Local
  real(kind=sp) :: bval , eval
  integer(kind=4) :: i
  !
  bval = head%gil%bval
  eval = head%gil%eval
  do i=1,head%gil%dim(3)
    where (abs(a1(:,:,i)-bval).gt.eval)
      ww(:,:,i) = ww(:,:,i) + 1
      aa(:,:,i) = aa(:,:,i) + a1(:,:,i)
    end where
  enddo
  !
end subroutine plait_add
!
subroutine plait_mean(w,a)
  use gkernel_interfaces
  use plait_common
  implicit none
  !---------------------------------------------------------------------
  ! COMPUTES THE MEAN OF INPUT CUBES
  !   A holds the sum of valid values and W their count. A becomes the
  !   mean where W > 0, and 0 where every input was blanked (a warning
  !   is issued in that case).
  !---------------------------------------------------------------------
  integer(kind=4), intent(in)    :: w(:,:,:)
  real(kind=sp),   intent(inout) :: a(:,:,:)
  !
  if (any(w.eq.0)) then
    call gagout('W-'//rname//',  BLANKED PIXEL(S) IN SUM OF IMAGES')
  endif
  where (w.gt.0)
    a = a / real(w,kind=sp)
  elsewhere
    a = 0.0_sp
  end where
  !
end subroutine plait_mean
!
subroutine plait_close(nopen,cin,error)
  use image_def
  use gkernel_interfaces
  implicit none
  !---------------------------------------------------------------------
  ! Close the first NOPEN input images. All of them are attempted;
  ! ERROR becomes .true. if any close fails, and is otherwise unchanged.
  !---------------------------------------------------------------------
  integer(kind=4), intent(in)    :: nopen
  type(gildas),    intent(inout) :: cin(:)
  logical,         intent(inout) :: error
  ! Local
  integer(kind=4) :: i
  logical :: lerr
  !
  do i=1,nopen
    lerr = .false.
    call gdf_close_image(cin(i),lerr)
    error = error .or. lerr
  enddo
end subroutine plait_close
!
subroutine plait_routine(nmap,cin,cout,theta,length,thres,dump,error)
  use image_def
  use gkernel_interfaces
  use plait_common
  !---------------------------------------------------------------------
  ! Combine the NMAP input cubes CIN into the output cube COUT.
  !  1) Read all headers and check that all cubes have the same
  !     dimensions. The blanking of the first cube defining one
  !     (%gil%eval >= 0) is adopted as the common blanking BLAN.
  !  2) Compute the mean cube SOM of the inputs, ignoring blanked pixels.
  !  3) For each input: apply the optional [THRES(1):THRES(2)] cutoff,
  !     replace blanked pixels by the mean, Fourier transform each plane,
  !     multiply it by the weighting image from THETA(i) and LENGTH(i)
  !     (see PLAIT_WEIGHT), and accumulate transforms and weights.
  !  4) Divide by the accumulated weights, transform back, and keep the
  !     real part scaled by 1/(dim(1)*dim(2)).
  !  5) Write the result to COUT with the header of the first input.
  ! Input images are closed before returning, also on error.
  !---------------------------------------------------------------------
  integer, intent(in) :: nmap
  type(gildas), intent(inout) :: cin(nmap)
  type(gildas), intent(inout) :: cout
  real(kind=sp), intent(in) :: length(nmap)
  real(kind=sp), intent(in) :: theta(nmap)
  real(kind=sp), intent(in) :: thres(2)
  logical, intent(in) :: dump
  logical, intent(out) :: error
  !
  real(kind=sp) :: blan(2)
  !
  integer(kind=4) :: i,j,ier,nopen
  integer(kind=4) :: dim(3),maxdim
  character(len=256) :: message
  complex(kind=4), allocatable :: b(:,:,:),res(:,:,:)
  real(kind=4),    allocatable :: a(:,:,:),som(:,:,:)
  integer(kind=4), allocatable :: w3d(:,:,:)
  real(kind=sp),   allocatable :: w2d(:,:),norm(:,:)
  real(kind=4),    allocatable :: wfft(:)
  type(gildas) :: hdump
  !
  error = .false.
  nopen = 0
  ! Default: no blanking (negative tolerance), with a defined value
  blan(1) = 0.0_sp
  blan(2) = -1.0_sp
  !
  ! Check that input files exist
  do i=1,nmap
    ier = gag_inquire(cin(i)%file,lenc(cin(i)%file))
    if (ier.ne.0) then
      message = 'Input cube '//char(48+i)//' ' &
        &        //cin(i)%file(1:len_trim(cin(i)%file))//' not found'
      call gagout('F-'//rname//',  '//message)
      error = .true.
      return
    endif
  enddo
  !
  ! Read headers, check that input cubes match, and set up a default
  ! blanking from the first cube (including cube 1) that defines one
  do i=1,nmap
    call gdf_read_header(cin(i),error)
    if (gildas_error(cin(i),rname,error)) then
      call plait_close(nopen,cin,error)
      return
    endif
    nopen = i
    if (i.gt.1) then
      call plait_check(cin(i),cin(1),error)
      if (error) then
        call gagout('F-'//rname//',  Input cubes dimensions do not match')
        call plait_close(nopen,cin,error)
        return
      endif
    endif
    if (blan(2).lt.0 .and. cin(i)%gil%eval.ge.0) then
      blan(1) = cin(i)%gil%bval
      blan(2) = cin(i)%gil%eval
    endif
  enddo
  dim(1:3) = cin(1)%gil%dim(1:3)
  maxdim = 2*maxval(dim)
  !
  ! Allocating memory
  allocate(&
       w2d (dim(1),dim(2)),&
       norm(dim(1),dim(2)),&
       w3d (dim(1),dim(2),dim(3)),&
       a   (dim(1),dim(2),dim(3)), &
       b   (dim(1),dim(2),dim(3)), &
       res (dim(1),dim(2),dim(3)),&
       som (dim(1),dim(2),dim(3)),&
       wfft(maxdim),&
       stat=ier)
  if (failed_allocate(rname,'plait buffers',ier,error)) then
    call plait_close(nopen,cin,error)
    return
  endif
  if (dump) then
    call gildas_null(hdump)
    call gdf_copy_header(cin(1),hdump,error)
    ! The dumped weight image is 2-D: make the header consistent
    hdump%gil%ndim = 2
    hdump%gil%dim(3) = 1
  endif
  !
  ! Compute the mean of input cubes
  !
  w3d = 0
  som = 0.
  do i=1,nmap
    call gdf_read_data(cin(i),a,error)
    if (gildas_error(cin(i),rname,error)) then
      call plait_close(nopen,cin,error)
      return
    endif
    call plait_cutoff(a,thres(1),thres(2),blan(1))
    call plait_add(som,w3d,cin(i),a)
  enddo
  call plait_mean(w3d,som)
  !
  ! Perform the plait algorithm
  !
  res  = cmplx(0.,0.,kind(0.))
  norm = 0.0_sp
  do i=1,nmap
    wfft = 0.0
    call gdf_read_data(cin(i),a,error)
    if (gildas_error(cin(i),rname,error)) then
      call plait_close(nopen,cin,error)
      return
    endif
    !
    ! Why should we cut with a Threshold ?
    !   It seems more natural to cut if the data deviates from the mean
    ! by more than the threshold ?
    call plait_cutoff(a,thres(1),thres(2),blan(1))
    ! Replace by Mean if Thresholded (or blanked)
    call plait_replace(a,som,blan(1),blan(2))
    b(:,:,:) = cmplx(a,0.,kind(a))
    ! Compute the weighting function in Fourier plane
    call plait_weight(w2d,theta(i),length(i),cin(i))
    if (dump) then
      write(hdump%file,'(A,i0)') 'weight-',i
      call gdf_write_image(hdump,w2d,error)
      if (gildas_error(hdump,rname,error)) then
        call plait_close(nopen,cin,error)
        return
      endif
    endif
    ! Transform each plane and add its weighted Fourier values
    do j=1,dim(3)
      call fourt(b(:,:,j),dim(1:2),2,1,1,wfft)
      res(:,:,j) = res(:,:,j) + b(:,:,j)*w2d
    enddo
    ! Increment the norm
    norm = norm + w2d
  enddo
  ! Normalize the Fourier values
  call plait_norm(res,norm)
  wfft = 0.0
  ! Transform back to Image plane
  do j=1,dim(3)
    call fourt(res(:,:,j),dim(1:2),2,-1,1,wfft)
  enddo
  a(:,:,:) = real(res)
  a = a / (dim(1)*dim(2))
  !
  ! Write result in output image
  !
  call gdf_copy_header(cin(1),cout,error)
  if (gildas_error(cout,rname,error)) then
    call plait_close(nopen,cin,error)
    return
  endif
  call gdf_write_image(cout,a,error)
  if (gildas_error(cout,rname,error)) then
    call plait_close(nopen,cin,error)
    return
  endif
  deallocate(w2d,norm,w3d,a,b,res,som,wfft,stat=ier)
  if (ier.ne.0) then
    call gagout('F-'//rname//',  Deallocation error')
  endif
  call plait_close(nopen,cin,error)
end subroutine plait_routine
!
end subroutine plait_command
