!===========================================================================
!
! Routines()
!
! (1) mtxel()   Originally By (?)               Last Modified 5/6/2012 (FHJ)
!
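!     Compute matrix elements (gme) for valence state iv with all
!     conduction bands owned by this processor and for all G-vectors.
!
!                <c,k,ispin|exp(-i(q+G).r)|v,k+q,ispin>
!
!     On exit,
!       pol%eden(iv,band,spin) = occ_diff/(e_val-e_cond) = energy denominators
!                                (static case, freq_dep = 0)
!       pol%edenDyn(...)       = (e_val-e_cond)/occ_diff   (freq_dep = 2 or 3)
!       pol%gme(g-vector,band,iv,spin,k-point,rank) = plane wave matrix elements
!     where occ_diff = occ_val - occ_cond, guessed from pol%efermi.
!
!     Inputs used for ordering:
!       pol%isrtx   orders the |G(i)|^2   i=1,pol%nmtx
!       vwfn%isort  orders |qk+g|^2    (in vwfn type)
!
!       Energies are apparently assumed to be in Rydbergs.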
!
!===========================================================================

#include "f_defs.h"

module mtxel_m

  use global_m
  use fftw_m
  use misc_m
  use lin_denominator_m
  implicit none

  private

  public :: mtxel_init_FFT_cond, mtxel_free_FFT_cond, mtxel

contains

!> Precomputes the FFT of every conduction band owned by this processor.
!!
!! Band j (j = 1..peinf%ncownactual) is placed into an FFT box and
!! transformed once; the result is kept in cwfn%wfn_fft(:,:,:,j) so that
!! mtxel can reuse it for every valence band. Only a single spinor
!! component (kp%nspinor == 1) is supported; the routine dies otherwise.
!! The buffer must be released with mtxel_free_FFT_cond.
subroutine mtxel_init_FFT_cond(gvec,pol,cwfn,kp)
  type (gspace), intent(in) :: gvec
  type (polarizability), intent(inout) :: pol
  type (conduction_wfns), intent(inout) :: cwfn
  type (kpoints), intent(inout) :: kp

  integer, parameter :: ispinor = 1   ! only one spinor component is handled
  integer :: Nfft(3)
  integer :: icb, offset
  real(DP) :: fft_scale
  logical :: is_root

  PUSH_SUB(mtxel_init_FFT_cond)

  if (kp%nspinor > 1) call die('Epsilon on Steroids only supports one spin at the moment.')

  is_root = (peinf%inode == 0)
  if (is_root) then
    call timacc(40,1)
    call timacc(43,1)
  endif

  call setup_FFT_sizes(pol%FFTgrid,Nfft,fft_scale)
  SAFE_ALLOCATE(cwfn%wfn_fft, (Nfft(1),Nfft(2),Nfft(3),peinf%ncownactual))

  ! The coefficients of owned band icb start at offset+1 in cwfn%zc,
  ! with cwfn%ngc coefficients per band.
  do icb = 1, peinf%ncownactual
    offset = cwfn%ngc*(icb-1)
    call put_into_fftbox(cwfn%ngc, cwfn%zc(offset+1:,ispinor), gvec%components, &
      cwfn%isort, cwfn%wfn_fft(:,:,:,icb), Nfft)
    call do_FFT(cwfn%wfn_fft(:,:,:,icb), Nfft, 1)
  enddo

  if (is_root) then
    call timacc(43,2)
    call timacc(40,2)
  endif

  POP_SUB(mtxel_init_FFT_cond)
end subroutine mtxel_init_FFT_cond

!> Releases the conduction-band FFT buffer (cwfn%wfn_fft) that
!! mtxel_init_FFT_cond allocates.
subroutine mtxel_free_FFT_cond(cwfn)
  type (conduction_wfns), intent(inout) :: cwfn

  PUSH_SUB(mtxel_free_FFT_cond)
  SAFE_DEALLOCATE_P(cwfn%wfn_fft)
  POP_SUB(mtxel_free_FFT_cond)
end subroutine mtxel_free_FFT_cond

!> Computes, for valence band iv, the energy denominators and the plane-wave
!! matrix elements
!!
!!     <c,k,ispin|exp(-i(q+G).r)|v,k+q,ispin>
!!
!! against every conduction band owned by this processor and the first
!! pol%nmtx G-vectors (ordered by pol%isrtx).
!!
!! Energy denominators (occ_diff = occ_v - occ_c, occupations guessed from
!! pol%efermi by guess_occupation below):
!!  - freq_dep == 0:      pol%eden(iv,iband,ispin) = occ_diff/(e_v - e_c),
!!                        or 0 for pairs that do not contribute.
!!  - freq_dep == 2 or 3: (e_v - e_c)/occ_diff is stored in pol%edenDyn for
!!                        the owned conduction bands. If this processor does
!!                        not own iv, the routine returns immediately (after
!!                        stopping timer 25) without touching pol.
!!
!! The matrix elements are summed over spinor components into pol%gme. In
!! the static case (freq_dep == 0) they are then multiplied by sqrt(-eden).
!! The last index of pol%gme / pol%edenDyn is 1 when frequencies are split
!! across processors without G-communication (os_para_freqs > 1 and
!! gcomm == 0), and rank_mtxel+1 otherwise.
!!
!! pol%os_opt_ffts selects how the real-space products are formed:
!!   2       : valence and conduction FFTs precomputed (vwfn%wfn_fft, cwfn%wfn_fft);
!!   1       : conduction FFTs precomputed (see mtxel_init_FFT_cond);
!!   other   : every FFT is computed here.
!!
!! kpq is currently not referenced.
subroutine mtxel(iv,gvec,vwfn,cwfn,pol,ispin,irk,kp,kpq,rank_mtxel)
  integer, intent(in) :: iv
  type (gspace), intent(in) :: gvec
  type (valence_wfns), intent(in) :: vwfn
  type (conduction_wfns), intent(in) :: cwfn
  type (polarizability), intent(inout) :: pol
  type (kpoints), intent(inout) :: kp,kpq
  integer, intent(in) :: ispin,irk
  integer, intent(in) :: rank_mtxel

  integer :: j,jspinor,jspinormin,jspinormax,iband,jcad,irank
  real(DP), allocatable :: edenTemp(:)
  real(DP) :: eval,econd,occ_v,occ_c,occ_diff
  integer, dimension(3) :: Nfft
  real(DP) :: scale
  complex(DPC), dimension(:,:,:), allocatable :: fftbox1,fftbox2
  SCALAR, dimension(:), allocatable :: tmparray
  logical :: dynamic

  PUSH_SUB(mtxel)

  dynamic = (pol%freq_dep .eq. 2 .or. pol%freq_dep .eq. 3)

  ! Slot of the last dimension of pol%gme / pol%edenDyn written by this call.
  ! Computed once so that the assignment of the matrix elements and their
  ! sqrt(-eden) scaling below always address the same slot.
  if (pol%os_para_freqs .gt. 1 .and. pol%gcomm .eq. 0) then
    irank = 1
  else
    irank = rank_mtxel + 1
  endif

!-----------------------
! Compute energy denominators; eden depends on which iv we treat

  if(peinf%inode.eq.0) call timacc(25,1)

  ! Full-frequency runs only handle valence bands owned by this processor.
  ! Check once, before any work, so timer 25 is stopped and no work array is
  ! left allocated when bailing out.
  if (dynamic .and. .not. peinf%doiownv(iv)) then
    if(peinf%inode.eq.0) call timacc(25,2)
    POP_SUB(mtxel)
    return
  endif

  if (dynamic) then
    SAFE_ALLOCATE(edenTemp, (cwfn%nband-vwfn%nband))
    edenTemp(:) = 0D0
  endif

  eval = vwfn%ev(iv,ispin)
  occ_v = guess_occupation(eval)   ! independent of the conduction band

  ! FHJ: the denominators are computed for all conduction bands, including
  ! those this processor does not own; this is cheap.
  do j=vwfn%nband+1,cwfn%nband
    iband = j - vwfn%nband
    econd = cwfn%ec(j,ispin)
    occ_c = guess_occupation(econd)
    occ_diff = occ_v - occ_c

    if (pol%freq_dep .eq. 0) then
      ! Avoid dividing by zero or making eden > 0.
      if(eval - econd < TOL_Degeneracy .and. occ_diff > TOL_Zero) then
        pol%eden(iv,iband,ispin) = occ_diff / (eval - econd)
      else
        pol%eden(iv,iband,ispin) = 0d0
      endif
    else if (dynamic) then
      if(eval - econd < TOL_Degeneracy .and. occ_diff > TOL_Zero) then
        edenTemp(iband) = (eval - econd) / occ_diff
      else
        edenTemp(iband) = 0.0d0
      endif
    endif
  enddo

  if (dynamic) then
    ! iv is owned here (checked above); keep only the owned conduction bands.
    do j = 1, peinf%ncownactual
      pol%edenDyn(peinf%indexv(iv),j,ispin,irk,irank) = edenTemp(peinf%invindexc(j))
    enddo
    SAFE_DEALLOCATE(edenTemp)
  endif

  if(peinf%inode.eq.0) call timacc(25,2)

  if(peinf%inode.eq.0) call timacc(26,1)

!-------------------- Calculate Matrix Elements -------------------------------
! The matrix elements are obtained with FFTs:
! 1. get the conduction wave function into box 2 and FFT it (or reuse a
!    precomputed FFT),
! 2. multiply by the conjugated valence FFT (box 1),
! 3. FFT again and extract the G-components into pol%gme.
! We conjugate the final result since we really want <c|e^(-ig.r)|v>
! but we have calculated <v|e^(ig.r)|c>.

  call setup_FFT_sizes(pol%FFTgrid,Nfft,scale)
  SAFE_ALLOCATE(fftbox2, (Nfft(1),Nfft(2),Nfft(3)))

  call set_jspinor(jspinormin,jspinormax,ispin,kp%nspinor)

  if (pol%os_opt_ffts/=2) then
    SAFE_ALLOCATE(fftbox1, (Nfft(1),Nfft(2),Nfft(3)))
  endif

  pol%gme(1:pol%nmtx,1:peinf%ncownactual,peinf%indexv(iv),ispin,irk,irank) = ZERO

  SAFE_ALLOCATE(tmparray, (pol%nmtx))

  do jspinor=jspinormin,jspinormax

    if (pol%os_opt_ffts/=2) then
      call put_into_fftbox(vwfn%ngv,vwfn%zv(:,jspinor),gvec%components,vwfn%isort,fftbox1,Nfft)
      call do_FFT(fftbox1,Nfft,1)
      ! We need the complex conjugate of the |ivk> band actually
      call conjg_fftbox(fftbox1,Nfft)
    endif

    do j=1,peinf%ncownactual
      iband = peinf%invindexc(j)
      jcad = (j-1)*cwfn%ngc

      select case (pol%os_opt_ffts)
      case (2)
        ! FHJ: optimization level 2 precomputed all the FFTs
        fftbox2(:,:,:) = vwfn%wfn_fft(:,:,:,peinf%indexv(iv)) * cwfn%wfn_fft(:,:,:,j)
      case (1)
        ! FHJ: optimization level 1 precalculated at least these cond. FFTs
        fftbox2(:,:,:) = fftbox1(:,:,:) * cwfn%wfn_fft(:,:,:,j)
      case default
        call put_into_fftbox(cwfn%ngc,cwfn%zc(jcad+1:,jspinor),gvec%components,cwfn%isort,fftbox2,Nfft)
        call do_FFT(fftbox2,Nfft,1)
        call multiply_fftboxes(fftbox1,fftbox2,Nfft)
      end select

      call do_FFT(fftbox2,Nfft,1)
      call get_from_fftbox(pol%nmtx,tmparray,gvec%components,pol%isrtx,fftbox2,Nfft,scale)

      pol%gme(1:pol%nmtx,j,peinf%indexv(iv),ispin,irk,irank) = &
        pol%gme(1:pol%nmtx,j,peinf%indexv(iv),ispin,irk,irank) + MYCONJG(tmparray)

      ! Static case: once the last spinor component has been accumulated,
      ! weight the matrix elements by sqrt(-eden) (eden <= 0 here).
      if ((kp%nspinor.eq.1 .or. jspinor.eq.2) .and. pol%freq_dep .eq. 0) then
        pol%gme(1:pol%nmtx,j,peinf%indexv(iv),ispin,irk,irank) = &
          pol%gme(1:pol%nmtx,j,peinf%indexv(iv),ispin,irk,irank) * &
          sqrt(-1D0*pol%eden(iv,iband,ispin))
      endif

    enddo

  enddo

  SAFE_DEALLOCATE(tmparray)

! We are done, so deallocate FFT boxes
  if (pol%os_opt_ffts.ne.2) then
    SAFE_DEALLOCATE(fftbox1)
  endif
  SAFE_DEALLOCATE(fftbox2)

  ! NOTE(review): no q->0 renormalization of the G=0 elements by q0norm is
  ! performed in this routine; presumably it is handled elsewhere.

  if(peinf%inode.eq.0) call timacc(26,2)

  POP_SUB(mtxel)

contains

  !> Guesses the occupation of a state from its energy, compared (after
  !! multiplying by ryd) with pol%efermi: 0 above E_F + TOL_Degeneracy,
  !! 1 below E_F - TOL_Degeneracy, and 1/2 (the Fermi-Dirac value at E_F)
  !! in between. Eventually this should be replaced by the use of kp%occ.
  real(DP) function guess_occupation(energy) result(occ)
    real(DP), intent(in) :: energy

    if (energy*ryd > pol%efermi + TOL_Degeneracy) then
      occ = 0d0
    else if (energy*ryd < pol%efermi - TOL_Degeneracy) then
      occ = 1d0
    else
      occ = 0.5d0
    endif
  end function guess_occupation

end subroutine mtxel

end module mtxel_m
