!==========================================================================
!
! Routines:
!
! (1) mtxel_vxc()                          Last Modified: 5/12/2008 (JRD)
!
!     Calculates matrix elements of the DFT exchange-correlation potential
!     sig%vxc and puts the results into alda:
!
!     alda(in,1) = <nk|Vxc(r)|nk>  with  n = sig%diag(in)
!
!==========================================================================

#include "f_defs.h"

! mtxel_vxc
!
! Computes matrix elements of the exchange-correlation potential between the
! wavefunctions stored in wfnk, using real-space integration on the FFT grid:
! the potential and the wavefunctions are placed in FFT boxes, transformed,
! multiplied point by point and summed.
!
! Arguments:
!   kp        (in)    only kp%nspinor is read here (passed to set_jspinor).
!   gvec      (in)    supplies ng, components and FFTgrid for the FFT boxes.
!   sig       (in)    supplies the potential (vxc or vxc2), the band lists
!                     diag/off1/off2/offmap, spin_index and the sizes
!                     ndiag, noffdiag, nspin.
!   wfnk      (inout) wavefunction coefficients zk, their count nkpt and the
!                     index array isrtk; only read in this routine.
!   wfnkoff   (inout) wfnkoff%zk is used as scratch space for the pair of
!                     bands of one off-diagonal element. It is allocated and
!                     deallocated inside this routine, and only on processes
!                     with mod(inode, npes/npools) == 0.
!   alda      (out)   alda(1:ndiag,ispin) holds the diagonal elements and
!                     alda(ndiag+1:ndiag+noffdiag,ispin) the off-diagonal
!                     ones. All elements are scaled by scale*ryd.
!   vxc2_flag (in)    if true, sig%vxc2 is used instead of sig%vxc
!                     (non-spinor branch only).
!
! NOTE(review): ryd comes from outside this file; presumably it is the
! Rydberg to eV conversion factor -- confirm against global_m.
subroutine mtxel_vxc(kp,gvec,sig,wfnk,wfnkoff,alda,vxc2_flag)

  use global_m
  use fftw_m
  use misc_m
  implicit none

  type (kpoints), intent(in) :: kp
  type (gspace), intent(in) :: gvec
  type (siginfo), intent(in) :: sig
  type (wfnkstates), intent(inout) :: wfnk, wfnkoff
  SCALAR, intent(out) :: alda(sig%ndiag+sig%noffdiag,sig%nspin)
  logical, intent(in) :: vxc2_flag

  integer :: in,ioff,ispin
  integer :: spinordim,jspinor,jspinormin,jspinormax
  integer :: ig,iso,jso,ik,ii,jj,ij,js,source,dest,tag
  character (len=100) :: tmpstr
  SCALAR, allocatable :: alda2(:,:)

  ! fftbox1: potential (one box per spinor component of the 2x2 matrix)
  ! fftbox2: wavefunction (diagonal) or conjugated first band (off-diagonal)
  ! fftbox3: second band of an off-diagonal pair
  complex(DPC), allocatable :: &
    fftbox1(:,:,:,:),fftbox2(:,:,:,:),fftbox3(:,:,:,:)
  integer, dimension(:), allocatable :: gvecindex
  integer, dimension(3) :: Nfft
  integer :: ix, iy, iz
  real(DP) :: scale
  SCALAR, allocatable :: vxctemp(:,:)
  SCALAR :: temp, tempxy


!-------------------- Begin Routine --------------------------------------------


  PUSH_SUB(mtxel_vxc)

! Nullify output. Elements not assigned on this process must stay zero
! because the results are summed over all processes at the end.

  alda=0.0d0

! Using FFT to compute matrix elements

! Determine the number of spinor components handled here

  call set_jspinor(jspinormin,jspinormax,sig%spin_index(sig%nspin),kp%nspinor)
  spinordim=jspinormax-jspinormin+1

! Allocate temporary wavefunction array (holds two bands back to back)

  if (sig%noffdiag.gt.0) then
    if (mod(peinf%inode,peinf%npes/peinf%npools).eq.0)  then
      SAFE_ALLOCATE(wfnkoff%zk, (2*wfnk%nkpt,spinordim))
    endif
  endif

! Get FFT box sizes and scale factor

  call setup_FFT_sizes(gvec%FFTgrid,Nfft,scale)

! Allocate FFT boxes

  SAFE_ALLOCATE(fftbox1, (Nfft(1),Nfft(2),Nfft(3),spinordim**2))
  SAFE_ALLOCATE(fftbox2, (Nfft(1),Nfft(2),Nfft(3),spinordim))
  SAFE_ALLOCATE(fftbox3, (Nfft(1),Nfft(2),Nfft(3),spinordim))
  SAFE_ALLOCATE(vxctemp, (gvec%ng,sig%nspin*spinordim**2))

! Put the vxc data into fftbox 1 and FFT it to real space.
! The potential is stored in the same order as gvec, so the index map
! handed to put_into_fftbox is the identity.

  SAFE_ALLOCATE(gvecindex, (gvec%ng))
  do ig=1,gvec%ng
    gvecindex(ig)=ig
  enddo

  do ispin=1,sig%nspin

    if (spinordim.eq.2) then
      ! BAB: components 1 and 2 are upper and lower diagonal of VXC 2x2 matrix
      !      component 3 is the upper off-diagonal, lower off-diag is complex conjugate
      vxctemp(:,1) = sig%vxc(:,1) +   sig%vxc(:,4)
      vxctemp(:,2) = sig%vxc(:,1) -   sig%vxc(:,4)
      vxctemp(:,3) = sig%vxc(:,2) - CMPLX(0.0d0,1.0d0)*sig%vxc(:,3)
      vxctemp(:,4) = sig%vxc(:,2) + CMPLX(0.0d0,1.0d0)*sig%vxc(:,3)
    else
      if (.not. vxc2_flag) then
        vxctemp(:,ispin) = sig%vxc(:,sig%spin_index(ispin))
      else
        vxctemp(:,ispin) = sig%vxc2(:,sig%spin_index(ispin))
      endif
    endif

    do jspinor=1,spinordim**2
      call put_into_fftbox(gvec%ng,vxctemp(:,ispin*jspinor),gvec%components,gvecindex,fftbox1(:,:,:,jspinor),Nfft)
      call do_FFT(fftbox1(:,:,:,jspinor),Nfft,1)
    enddo

! Loop over the bands for which we need the matrix elements
! For each one, put the band into fftbox2, FFT to real space,
! and then integrate in real space vxc(r)*|psi(r)|^2
! Store result into alda(in,ispin).

    do in=1,peinf%ndiag_max
      temp=ZERO
      tempxy=ZERO
      write(tmpstr,'("Computing <n|Vxc|n> for n=",i4)') &
        sig%diag(peinf%index_diag(in))
      call logit(tmpstr)

      ! Only processes selected by this test do the FFT work
      if (mod(peinf%inode,peinf%npes/peinf%npools).eq.0) then

        do jspinor=1,spinordim
          ij = ispin*jspinor ! for shortening lengthy call to put_into_fftbox
          call put_into_fftbox(wfnk%nkpt,wfnk%zk((in-1)*wfnk%nkpt+1:,ij),gvec%components,wfnk%isrtk,fftbox2(:,:,:,jspinor),Nfft)
          call do_FFT(fftbox2(:,:,:,jspinor),Nfft,1)
        enddo

        do jspinor=1,spinordim
          do iz=1,Nfft(3)
            do iy=1,Nfft(2)
              do ix=1,Nfft(1)
                temp = temp + fftbox1(ix,iy,iz,jspinor) * abs(fftbox2(ix,iy,iz,jspinor))**2
                ! Cross term between the two spinor components
                if (jspinor.eq.2) then
                  tempxy = tempxy + fftbox1(ix,iy,iz,2*spinordim-1) * &
                    dble(conjg(fftbox2(ix,iy,iz,jspinor-1))) * dble(fftbox2(ix,iy,iz,jspinor))
                endif
              enddo
            enddo
          enddo
        enddo ! jspinor
      endif

      ! Scale by ryd as well as scale, so that the diagonal elements carry
      ! the same factor as the off-diagonal elements stored below.
      if (peinf%flag_diag(in)) then
        alda(peinf%index_diag(in),ispin)=(temp + 2.0d0*dble(tempxy))*scale*ryd
      endif

    enddo ! in

    do ioff=1,peinf%noffdiag_max

      if (spinordim .eq. 2) then
        call die('Off diagonal bands not correctly implemented yet for spinors')
      endif

      write(tmpstr,'("Computing offdiag <n|Vxc|m> for n,m=",2i4)') &
        sig%off1(peinf%index_offdiag(ioff)), &
        sig%off2(peinf%index_offdiag(ioff))
      call logit(tmpstr)

      temp=ZERO
      tempxy=ZERO
      do jspinor=1,spinordim

! (gsm) begin gathering wavefunctions over pools

! $$$ inefficient communication, this should be rewritten $$$

        ! For every destination process (one per pool) and for each of the
        ! two bands of the pair, copy or send the band from the process that
        ! holds it (source) into wfnkoff%zk on the destination.
        do jj=1,peinf%npools
          dest=(jj-1)*(peinf%npes/peinf%npools)
          do ii=1,2
            iso=sig%offmap(peinf%index_offdiag(ioff),ii)
#ifdef MPI
            ! All processes adopt the band index requested by dest
            call MPI_Bcast(iso,1,MPI_INTEGER,dest,MPI_COMM_WORLD,mpierr)
#endif
            jso=(iso-1)/peinf%npools+1
            source=mod(iso-1,peinf%npools)*(peinf%npes/peinf%npools)
            if (peinf%inode.eq.source.and.peinf%inode.eq.dest) then
              ! Band is already local: plain copy
              do ik=1,wfnk%nkpt
                wfnkoff%zk((ii-1)*wfnk%nkpt+ik,jspinor)= &
                  wfnk%zk((jso-1)*wfnk%nkpt+ik,ispin*jspinor)
              enddo
            else
#ifdef MPI
              tag=1024
              if (peinf%inode.eq.source) call MPI_Send &
                (wfnk%zk((jso-1)*wfnk%nkpt+1,ispin*jspinor),wfnk%nkpt, &
                MPI_SCALAR,dest,tag,MPI_COMM_WORLD,mpierr)
              if (peinf%inode.eq.dest) call MPI_Recv &
                (wfnkoff%zk((ii-1)*wfnk%nkpt+1,jspinor),wfnk%nkpt, &
                MPI_SCALAR,source,tag,MPI_COMM_WORLD,mpistatus,mpierr)
#else
              do ik=1,wfnk%nkpt
                wfnkoff%zk((ii-1)*wfnk%nkpt+ik,jspinor)= &
                  wfnk%zk((jso-1)*wfnk%nkpt+ik,ispin*jspinor)
              enddo
#endif
            endif
          enddo
        enddo
      enddo ! jspinor

! (gsm) end gathering wavefunctions over pools

      if (mod(peinf%inode,peinf%npes/peinf%npools).eq.0) then

        ! First band goes to fftbox2 and is conjugated in real space,
        ! second band goes to fftbox3.
        do jspinor=1,spinordim
          js = jspinor ! for shortening lengthy call to put_into_fftbox
          call put_into_fftbox(wfnk%nkpt,wfnkoff%zk(1:,jspinor),gvec%components,wfnk%isrtk,fftbox2(:,:,:,jspinor),Nfft)
          call do_FFT(fftbox2(:,:,:,jspinor),Nfft,1)
          call conjg_fftbox(fftbox2(:,:,:,jspinor),Nfft)
          call put_into_fftbox(wfnk%nkpt,wfnkoff%zk(wfnk%nkpt+1:,js),gvec%components,wfnk%isrtk,fftbox3(:,:,:,js),Nfft)
          call do_FFT(fftbox3(:,:,:,jspinor),Nfft,1)
        enddo

        do jspinor=1,spinordim
          do iz=1,Nfft(3)
            do iy=1,Nfft(2)
              do ix=1,Nfft(1)
                temp = temp + fftbox1(ix,iy,iz,jspinor)*fftbox2(ix,iy,iz,jspinor)*fftbox3(ix,iy,iz,jspinor)
                if(jspinor.eq.2) then
                  tempxy = tempxy + fftbox1(ix,iy,iz,4)*dble(fftbox2(ix,iy,iz,2))*dble(fftbox3(ix,iy,iz,1))
                  tempxy = tempxy + fftbox1(ix,iy,iz,3)*dble(fftbox2(ix,iy,iz,1))*dble(fftbox3(ix,iy,iz,2))
                endif
              enddo
            enddo
          enddo
        enddo

      endif
      ! Off-diagonal elements are stored after the ndiag diagonal ones
      if (peinf%flag_offdiag(ioff)) then
        alda(peinf%index_offdiag(ioff)+sig%ndiag,ispin)=(temp + tempxy)*scale*ryd
      endif
    enddo ! ioff

  enddo ! ispin

! Deallocate temporary wavefunction array

  if (sig%noffdiag.gt.0) then
    if (mod(peinf%inode,peinf%npes/peinf%npools).eq.0) then
      SAFE_DEALLOCATE_P(wfnkoff%zk)
    end if
  endif

! Deallocate FFT boxes and the other work arrays

  SAFE_DEALLOCATE(gvecindex)
  SAFE_DEALLOCATE(fftbox1)
  SAFE_DEALLOCATE(fftbox2)
  SAFE_DEALLOCATE(fftbox3)
  SAFE_DEALLOCATE(vxctemp)


! If MPI add up all the work done in parallel


#ifdef MPI
  SAFE_ALLOCATE(alda2, (sig%ndiag+sig%noffdiag,sig%nspin))
  alda2=0.0d0
  call MPI_Allreduce(alda(1,1),alda2(1,1),(sig%ndiag+sig%noffdiag)*sig%nspin, &
    MPI_SCALAR,MPI_SUM,MPI_COMM_WORLD,mpierr)
  alda(:,:)=alda2(:,:)
  SAFE_DEALLOCATE(alda2)
#endif

  POP_SUB(mtxel_vxc)

  return
end subroutine mtxel_vxc
