!! Copyright (C) 2011-2014 M. Oliveira
!! Copyright (C) 2012 T. Cerqueira
!! Copyright (C) 2014 P. Borlido
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

#include "global.h"

module states_batch_m
  use global_m
  use oct_parser_m
  use messages_m
  use gsl_interface_m
  use units_m
  use io_m
  use output_m
  use quantum_numbers_m
  use mesh_m
  use ps_io_m
  use potentials_m
  use wave_equations_m
  use eigensolver_m
  use states_m
  implicit none


                    !---Interfaces---!

  ! Overloads assignment between two states_batch_t objects: "a = b" is
  ! carried out by states_batch_copy, which rebuilds the pointer list of "a"
  ! so that it references the same states as "b".
  interface assignment (=)
    module procedure states_batch_copy
  end interface


                    !---Derived Data Types---!

  !> A batch of states: a list of pointers to state_t objects together with
  !> the number of entries in the list. The components are private, so the
  !> batch can only be manipulated through the procedures of this module.
  type states_batch_t
    private
    type(state_ptr), pointer :: states(:) !< Pointers to the states
    integer :: n_states                   !< Number of states in the batch
  end type states_batch_t

  !> Wrapper around a pointer to a single state. It exists so that an array
  !> of pointers to states can be declared (see states_batch_t).
  type state_ptr
    type(state_t), pointer :: ptr !< Pointer to the state
  end type state_ptr


                    !---Global Variables---!

  ! Sorting criteria understood by states_batch_sort:
  ! SORT_EV orders the states using their eigenvalues, SORT_QN using their
  ! quantum numbers.
  integer, parameter :: SORT_EV = 1, &
                        SORT_QN = 2

  ! Schemes selectable through the smearing_function argument of
  ! states_batch_smearing to distribute the electrons among the states.
  integer, parameter :: OCC_FIXED           = 0, &
                        OCC_SEMICONDUCTING  = 1, &
                        OCC_AVERILL_PAINTER = 2

  ! Theory levels. In this module they decide how states are grouped into
  ! folds: for HARTREE_FOCK and HYBRID every state is a fold of its own.
  integer, parameter :: INDEPENDENT_PARTICLES = 1, &
                        DFT                   = 2, &
                        HARTREE_FOCK          = 3, &
                        HYBRID                = 4


                    !---Public/Private Statements---!

  ! Everything in the module is private by default; only the names listed
  ! in the public statement below are visible to other modules.
  private
  public :: states_batch_t, &
            states_batch_null, &
            states_batch_init, &
            states_batch_end, &
            states_batch_deallocate, &
            assignment(=), &
            states_batch_size, &
            states_batch_add, &
            states_batch_get, &
            states_batch_number_of_folds, &
            states_batch_eigenvalues, &
            states_batch_split_folds, &
            states_batch_density, &
            states_batch_density_grad, &
            states_batch_density_lapl, &
            states_batch_charge_density, &
            states_batch_magnetization_density, &
            states_batch_tau, &
            states_batch_density_moment, &
            states_batch_charge, &
            states_batch_max_charge, &
            states_batch_ip, &
            states_batch_eigensolve, &
            states_batch_sort, &
            states_batch_orthogonalize, &
            states_batch_smearing, &
            states_batch_psp_generation, &
            states_batch_ld_test, &
            states_batch_output_configuration, &
            states_batch_output_eigenvalues, &
            states_batch_output_density, &
            states_batch_ps_io_set, &
            states_batch_calc_hf_terms, &
            states_batch_exchange_energy, &
            OCC_FIXED, OCC_SEMICONDUCTING, OCC_AVERILL_PAINTER, &
            SORT_EV, SORT_QN, &
            INDEPENDENT_PARTICLES, DFT, HARTREE_FOCK, HYBRID


contains

  !-----------------------------------------------------------------------
  !> Puts a batch in a well-defined empty condition: the list of states is
  !> disassociated and the number of states is zero. No memory is freed.
  !-----------------------------------------------------------------------
  subroutine states_batch_null(batch)
    type(states_batch_t), intent(out) :: batch !< the batch to be reset

    call push_sub("states_batch_null")

    batch%states => null()
    batch%n_states = 0

    call pop_sub()
  end subroutine states_batch_null

  !-----------------------------------------------------------------------
  !> Builds a batch that references the first n_states elements of an
  !> array of states. The states themselves are not copied.
  !-----------------------------------------------------------------------
  subroutine states_batch_init(batch, n_states, states)
    type(states_batch_t),  intent(inout) :: batch     !< the batch to be initialized
    integer,               intent(in)    :: n_states  !< the number of states
    type(state_t), target, intent(in)    :: states(:) !< the array of states

    integer :: ist

    call push_sub("states_batch_init")

    ! Storage for the list of pointers
    allocate(batch%states(n_states))
    batch%n_states = n_states

    ! Entry ist of the batch points to element ist of the input array
    do ist = 1, n_states
      batch%states(ist)%ptr => states(ist)
    end do

    call pop_sub()
  end subroutine states_batch_init

  !-----------------------------------------------------------------------
  !> Releases the list of pointers held by a batch and marks the batch as
  !> empty. The states that were referenced by the batch are not touched.
  !-----------------------------------------------------------------------
  subroutine states_batch_end(batch)
    type(states_batch_t), intent(inout) :: batch !< the batch to be released

    call push_sub("states_batch_end")

    if (associated(batch%states)) then
      deallocate(batch%states)
    end if
    batch%n_states = 0

    call pop_sub()
  end subroutine states_batch_end

  !-----------------------------------------------------------------------
  !> Calls state_end on every state referenced by the batch and then
  !> releases the list of pointers of the batch itself. On exit the batch
  !> holds zero states.
  !-----------------------------------------------------------------------
  subroutine states_batch_deallocate(batch)
    type(states_batch_t), intent(inout) :: batch !< the batch to be destroyed

    integer :: ist

    call push_sub("states_batch_deallocate")

    release: if (associated(batch%states)) then
      ! First the states, in the order they are stored ...
      do ist = 1, batch%n_states
        call state_end(batch%states(ist)%ptr)
      end do
      ! ... then the list that referenced them
      deallocate(batch%states)
    end if release

    batch%n_states = 0

    call pop_sub()
  end subroutine states_batch_deallocate

  !-----------------------------------------------------------------------
  !> Makes batch_a reference the same states as batch_b. Any list
  !> previously held by batch_a is released first; the states themselves
  !> are shared between the two batches, not duplicated.
  !-----------------------------------------------------------------------
  subroutine states_batch_copy(batch_a, batch_b)
    type(states_batch_t),  intent(inout) :: batch_a !< destination batch
    type(states_batch_t),  intent(in)    :: batch_b !< source batch

    integer :: ist, n_copy

    call push_sub("states_batch_copy")

    ! Discard whatever the destination was holding
    call states_batch_end(batch_a)

    ! New list of the same length as the source
    n_copy = batch_b%n_states
    allocate(batch_a%states(n_copy))
    batch_a%n_states = n_copy

    ! Pointer-by-pointer copy
    do ist = 1, n_copy
      batch_a%states(ist)%ptr => batch_b%states(ist)%ptr
    end do

    call pop_sub()
  end subroutine states_batch_copy

  !-----------------------------------------------------------------------
  !> Number of states currently referenced by the batch.
  !-----------------------------------------------------------------------
  function states_batch_size(batch) result(n)
    type(states_batch_t), intent(in) :: batch !< the batch to be queried
    integer :: n                              !< value of the batch counter

    call push_sub("states_batch_size")

    n = batch%n_states

    call pop_sub()
  end function states_batch_size

  !-----------------------------------------------------------------------
  !> Appends a state to a batch. The list of pointers is rebuilt with one
  !> extra entry; the existing entries keep their order and the new state
  !> becomes the last one.
  !-----------------------------------------------------------------------
  subroutine states_batch_add(batch, state)
    type(states_batch_t),  intent(inout) :: batch !< the batch of states
    type(state_t), target, intent(in)    :: state !< the state to be added to the batch

    integer :: ist, n_old
    type(state_ptr), pointer :: grown(:)

    call push_sub("states_batch_add")

    n_old = batch%n_states

    ! New list with room for one more pointer
    allocate(grown(n_old + 1))

    ! Carry over the existing pointers and drop the old list. A batch with
    ! zero states has nothing to carry over and its list is left alone.
    if (n_old /= 0) then
      do ist = 1, n_old
        grown(ist)%ptr => batch%states(ist)%ptr
      end do
      deallocate(batch%states)
    end if

    ! The new state goes at the end
    grown(n_old + 1)%ptr => state

    batch%states => grown
    batch%n_states = n_old + 1

    call pop_sub()
  end subroutine states_batch_add

  !-----------------------------------------------------------------------
  !> Gives access to the i-th state of the batch through a pointer.
  !> The index must lie between 1 and the number of states in the batch.
  !-----------------------------------------------------------------------
  function states_batch_get(batch, i) result(state)
    type(states_batch_t), intent(in) :: batch !< the batch of states
    integer,              intent(in) :: i     !< position of the state in the batch
    type(state_t), pointer :: state           !< pointer to the requested state

    call push_sub("states_batch_get")

    ASSERT(i > 0 .and. i <= batch%n_states)
    state => batch%states(i)%ptr

    call pop_sub()
  end function states_batch_get

  !-----------------------------------------------------------------------
  !> Collects the eigenvalues of all the states of the batch in an array,
  !> in the same order as the states are stored in the batch.
  !-----------------------------------------------------------------------
  function states_batch_eigenvalues(batch) result(ev)
    type(states_batch_t), intent(in) :: batch !< the batch of states
    real(R8) :: ev(batch%n_states)            !< one eigenvalue per state

    integer :: ist

    call push_sub("states_batch_eigenvalues")

    do ist = 1, batch%n_states
      ev(ist) = state_eigenvalue(batch%states(ist)%ptr)
    end do

    call pop_sub()
  end function states_batch_eigenvalues

  !-----------------------------------------------------------------------
  !> Given a batch, return the number of quantum number folds that exist
  !> in the batch (see quantum_numbers_m for the definition of a fold).
  !>
  !> For the HARTREE_FOCK and HYBRID theory levels every state counts as
  !> its own fold. For any other theory level the states are grouped with
  !> qn_equal_fold. An empty batch has zero folds.
  !-----------------------------------------------------------------------
  function states_batch_number_of_folds(batch, theory_level, polarized)
    type(states_batch_t), intent(in) :: batch        !< batch of states
    integer,              intent(in) :: theory_level !< theory level (HARTREE_FOCK, HYBRID, ...)
    logical,              intent(in) :: polarized    !< passed on to qn_equal_fold
    integer :: states_batch_number_of_folds

    integer :: i, j
    type(qn_t) :: qn
    logical, allocatable :: checked(:)

    call push_sub("states_batch_number_of_folds")

    states_batch_number_of_folds = 0

    ! The empty batch is handled by skipping the work rather than by an
    ! early return, so that push_sub is always matched by pop_sub.
    if (batch%n_states > 0) then

      select case (theory_level)
      case (HARTREE_FOCK, HYBRID)
        ! In Hartree-Fock, the potential felt by each state is different, so
        ! each state belongs to its own fold
        states_batch_number_of_folds = batch%n_states

      case default
        allocate(checked(batch%n_states))
        checked = .false.

        do i = 1, batch%n_states
          if (checked(i)) cycle

          ! State i is the first one not yet assigned to a fold: it opens a
          ! new fold and provides the reference quantum numbers
          states_batch_number_of_folds = states_batch_number_of_folds + 1
          qn = state_qn(batch%states(i)%ptr)

          ! All states before i are already assigned, so the search for
          ! members of this fold can start at i
          do j = i, batch%n_states
            if (checked(j)) cycle
            if (qn_equal_fold(qn, state_qn(batch%states(j)%ptr), polarized)) checked(j) = .true.
          end do
        end do

        ! Fails only if some state is not in the same fold as itself
        ASSERT(all(checked))
        deallocate(checked)

      end select

    end if

    call pop_sub()
  end function states_batch_number_of_folds

  !-----------------------------------------------------------------------
  !> Electronic density of a batch: the sum over the states of the batch
  !> of the density returned by state_density for each of them.
  !-----------------------------------------------------------------------
  function states_batch_density(batch, nspin, mesh) result(rho)
    type(states_batch_t), intent(in) :: batch !< batch of states
    integer,              intent(in) :: nspin !< second dimension of the result
    type(mesh_t),         intent(in) :: mesh  !< mesh; mesh%np is the first dimension of the result
    real(R8) :: rho(mesh%np, nspin)

    integer :: ist

    call push_sub("states_batch_density")

    rho = M_ZERO
    do ist = 1, batch%n_states
      rho = rho + state_density(nspin, batch%states(ist)%ptr)
    end do

    call pop_sub()
  end function states_batch_density

  !-----------------------------------------------------------------------
  !> Gradient of the electronic density of a batch: the sum over the
  !> states of the batch of the contributions given by state_density_grad.
  !-----------------------------------------------------------------------
  function states_batch_density_grad(batch, nspin, mesh) result(grad_rho)
    type(states_batch_t), intent(in) :: batch !< batch of states
    integer,              intent(in) :: nspin !< second dimension of the result
    type(mesh_t),         intent(in) :: mesh  !< mesh; mesh%np is the first dimension of the result
    real(R8) :: grad_rho(mesh%np, nspin)

    integer :: ist

    call push_sub("states_batch_density_grad")

    grad_rho = M_ZERO
    do ist = 1, batch%n_states
      grad_rho = grad_rho + state_density_grad(nspin, batch%states(ist)%ptr, mesh)
    end do

    call pop_sub()
  end function states_batch_density_grad

  !-----------------------------------------------------------------------
  !> Laplacian of the electronic density of a batch: the sum over the
  !> states of the batch of the contributions given by state_density_lapl.
  !-----------------------------------------------------------------------
  function states_batch_density_lapl(batch, nspin, mesh) result(lapl_rho)
    type(states_batch_t), intent(in) :: batch !< batch of states
    integer,              intent(in) :: nspin !< second dimension of the result
    type(mesh_t),         intent(in) :: mesh  !< mesh; mesh%np is the first dimension of the result
    real(R8) :: lapl_rho(mesh%np, nspin)

    integer :: ist

    call push_sub("states_batch_density_lapl")

    lapl_rho = M_ZERO
    do ist = 1, batch%n_states
      lapl_rho = lapl_rho + state_density_lapl(nspin, batch%states(ist)%ptr, mesh)
    end do

    call pop_sub()
  end function states_batch_density_lapl

  !-----------------------------------------------------------------------
  !> Charge density of a batch: the sum over the states of the batch of
  !> the charge density returned by state_charge_density for each state.
  !-----------------------------------------------------------------------
  function states_batch_charge_density(batch, mesh) result(charge_rho)
    type(states_batch_t), intent(in) :: batch !< batch of states
    type(mesh_t),         intent(in) :: mesh  !< mesh; mesh%np is the size of the result
    real(R8) :: charge_rho(mesh%np)

    integer :: ist

    call push_sub("states_batch_charge_density")

    charge_rho = M_ZERO
    do ist = 1, batch%n_states
      charge_rho = charge_rho + state_charge_density(batch%states(ist)%ptr)
    end do

    call pop_sub()
  end function states_batch_charge_density

  !-----------------------------------------------------------------------
  !> Magnetization density of a batch: the sum over the states of the
  !> batch of the values returned by state_magnetization_density.
  !-----------------------------------------------------------------------
  function states_batch_magnetization_density(batch, mesh) result(magn_rho)
    type(states_batch_t), intent(in) :: batch !< batch of states
    type(mesh_t),         intent(in) :: mesh  !< mesh; mesh%np is the size of the result
    real(R8) :: magn_rho(mesh%np)

    integer :: ist

    call push_sub("states_batch_magnetization_density")

    magn_rho = M_ZERO
    do ist = 1, batch%n_states
      magn_rho = magn_rho + state_magnetization_density(batch%states(ist)%ptr)
    end do

    call pop_sub()
  end function states_batch_magnetization_density

  !-----------------------------------------------------------------------
  !> Kinetic energy density of a batch: the sum over the states of the
  !> batch of the contributions given by state_tau.
  !-----------------------------------------------------------------------
  function states_batch_tau(batch, nspin, mesh) result(tau)
    type(states_batch_t), intent(in) :: batch !< batch of states
    integer,              intent(in) :: nspin !< second dimension of the result
    type(mesh_t),         intent(in) :: mesh  !< mesh; mesh%np is the first dimension of the result
    real(R8) :: tau(mesh%np, nspin)

    integer :: ist

    call push_sub("states_batch_tau")

    tau = M_ZERO
    do ist = 1, batch%n_states
      tau = tau + state_tau(nspin, batch%states(ist)%ptr, mesh)
    end do

    call pop_sub()
  end function states_batch_tau

  !-----------------------------------------------------------------------
  !> Computes the density moment:
  !>
  !>  \f$\sum_i \left< R_i(r) | r^{n} | R_i(r) \right> \f$
  !>
  !> of a given batch of states, as the sum over the states of the batch
  !> of the values returned by state_density_moment.
  !-----------------------------------------------------------------------
  function states_batch_density_moment(batch, mesh, order)
    type(states_batch_t), intent(in) :: batch !< batch of states
    type(mesh_t),         intent(in) :: mesh  !< mesh, passed on to state_density_moment
    integer,              intent(in) :: order !< order of the moment (the exponent n above)
    real(R8) :: states_batch_density_moment

    integer :: i

    call push_sub("states_batch_density_moment")

    states_batch_density_moment = M_ZERO
    do i = 1, batch%n_states
      states_batch_density_moment = states_batch_density_moment + &
                             state_density_moment(batch%states(i)%ptr, mesh, order)
    end do

    ! Matches the push_sub above; it was missing, which left the
    ! push_sub/pop_sub pairing unbalanced after every call
    call pop_sub()
  end function states_batch_density_moment

  !-----------------------------------------------------------------------
  !> Given a batch, splits the states into folds (see quantum_numbers_m
  !> for the definition of a fold).
  !>
  !> Each fold that gets used is first emptied with states_batch_end and
  !> then filled with states_batch_add, so the elements of folds must be
  !> in a defined condition on entry (e.g. after states_batch_null) and
  !> there must be at least as many of them as there are folds.
  !> Folds are numbered in the order in which their first state appears
  !> in the batch, and inside a fold the states keep their batch order.
  !-----------------------------------------------------------------------
  subroutine states_batch_split_folds(batch, folds, theory_level, polarized)
    type(states_batch_t), intent(in)    :: batch        !< batch to be split (must not be empty)
    type(states_batch_t), intent(inout) :: folds(:)     !< on exit, one batch per fold
    integer,              intent(in)    :: theory_level !< theory level (HARTREE_FOCK, HYBRID, ...)
    logical,              intent(in)    :: polarized    !< passed on to qn_equal_fold

    integer :: i, j, n
    type(qn_t) :: qn
    logical, allocatable :: checked(:)

    call push_sub("states_batch_split_folds")

    ASSERT(batch%n_states > 0)

    select case (theory_level)
    case (HARTREE_FOCK, HYBRID)
      ! In Hartree-Fock, the potential felt by each state is different, so
      ! each state belongs to its own fold
      ASSERT(size(folds) >= batch%n_states)
      do i = 1, batch%n_states
        ! Reset the fold first, as done in the default branch, so that
        ! states left over from a previous call do not pile up in it
        call states_batch_end(folds(i))
        call states_batch_add(folds(i), batch%states(i)%ptr)
      end do

    case default
      n = 0
      allocate(checked(batch%n_states))
      checked = .false.

      do i = 1, batch%n_states
        if (checked(i)) cycle

        ! State i is the first one not yet assigned: it opens fold number n
        n = n + 1
        ASSERT(n <= size(folds))
        call states_batch_end(folds(n))
        qn = state_qn(batch%states(i)%ptr)

        ! All states before i are already assigned, so the search for the
        ! members of this fold can start at i
        do j = i, batch%n_states
          if (checked(j)) cycle
          if (qn_equal_fold(qn, state_qn(batch%states(j)%ptr), polarized)) then
            checked(j) = .true.
            call states_batch_add(folds(n), batch%states(j)%ptr)
          end if
        end do
      end do

      ! Fails only if some state is not in the same fold as itself
      ASSERT(all(checked))
      deallocate(checked)

    end select

    call pop_sub()
  end subroutine states_batch_split_folds

  !-----------------------------------------------------------------------
  !> Returns the number of electrons in a batch, i.e. the sum of the
  !> charges of its states as given by state_charge.
  !-----------------------------------------------------------------------
  function states_batch_charge(batch) result(charge)
    type(states_batch_t), intent(in) :: batch !< batch of states
    real(R8) :: charge

    integer :: ist

    call push_sub("states_batch_charge")

    charge = M_ZERO
    do ist = 1, batch%n_states
      charge = charge + state_charge(batch%states(ist)%ptr)
    end do

    call pop_sub()
  end function states_batch_charge

  !-----------------------------------------------------------------------
  !> Returns the maximum number of electrons the batch can hold, i.e. the
  !> sum over its states of qn_max_occ applied to their quantum numbers.
  !-----------------------------------------------------------------------
  function states_batch_max_charge(batch) result(max_charge)
    type(states_batch_t), intent(in) :: batch !< batch of states
    real(R8) :: max_charge

    integer :: ist

    call push_sub("states_batch_max_charge")

    max_charge = M_ZERO
    do ist = 1, batch%n_states
      max_charge = max_charge + qn_max_occ(state_qn(batch%states(ist)%ptr))
    end do

    call pop_sub()
  end function states_batch_max_charge

  !-----------------------------------------------------------------------
  !> Ionization potential of a batch for each spin channel, taken as minus
  !> the highest eigenvalue found among the states of that channel that
  !> carry charge. A channel without charged states is left at zero.
  !-----------------------------------------------------------------------
  function states_batch_ip(batch, nspin) result(ip)
    type(states_batch_t), intent(in) :: batch !< batch of states
    integer,              intent(in) :: nspin !< number of spin channels (size of the result)
    real(R8) :: ip(nspin)

    integer :: ist, channel
    real(R8) :: eigenvalue, occupation
    type(qn_t) :: qn

    call push_sub("states_batch_ip")

    ip = M_ZERO

    do ist = 1, batch%n_states
      qn = state_qn(batch%states(ist)%ptr)
      eigenvalue = state_eigenvalue(batch%states(ist)%ptr)
      occupation = state_charge(batch%states(ist)%ptr)

      ! States with s = 1/2 go to the second channel, all others to the first
      if (qn%s == M_HALF) then
        channel = 2
      else
        channel = 1
      end if

      ! Empty states do not contribute
      if (occupation == M_ZERO) cycle

      ! A value of zero means "nothing stored yet" for this channel;
      ! otherwise keep the highest eigenvalue seen so far
      if (ip(channel) == M_ZERO .or. -ip(channel) < eigenvalue) then
        ip(channel) = - eigenvalue
      end if
    end do

    call pop_sub()
  end function states_batch_ip

  !-----------------------------------------------------------------------
  !> Get the eigenvalues and eigenfunctions of a batch of states for a
  !> given potential.
  !>
  !> The batch is sorted by quantum numbers as a side effect. States that
  !> cannot be bound are given a zero eigenvalue if they are empty; an
  !> occupied state that is unbound, or whose eigenvalue cannot be
  !> bracketed, triggers a fatal error. The batch must not be empty.
  !-----------------------------------------------------------------------
  subroutine states_batch_eigensolve(batch, mesh, wave_eq, potential, integrator_dp, integrator_sp, eigensolver)
    type(states_batch_t), intent(inout) :: batch         !< batch of states
    type(mesh_t),         intent(in)    :: mesh          !< mesh
    integer,              intent(in)    :: wave_eq       !< wave-equation to solve
    type(potential_t),    intent(in)    :: potential     !< potential to use in the wave-equation
    type(integrator_t),   intent(inout) :: integrator_dp !< integrator used for the final eigenvalues and state updates
                                                         !< (presumably the double-precision one - confirm)
    type(integrator_t),   intent(inout) :: integrator_sp !< integrator used for counting and bracketing
                                                         !< (presumably the single-precision one - confirm)
    type(eigensolver_t),  intent(in)    :: eigensolver   !< information about the eigensolver

    integer :: i, max_n_bound
    real(R8) :: ev
    type(qn_t) :: qn
    type(states_batch_t) :: bound
    logical, allocatable :: bracketed(:)
    real(R8), allocatable :: brackets(:,:)
    type(qn_t), allocatable :: qns(:)

    call push_sub("states_batch_eigensolve")

    !The first state of the batch is accessed unconditionally below
    ASSERT(batch%n_states > 0)

    !Batch should be ordered by increasing quantum numbers
    call states_batch_sort(batch, SORT_QN)

    !Check the maximum number of bound states we can have
    !(evaluated with the quantum numbers of the first state only)
    qn = state_qn(batch%states(1)%ptr)
    max_n_bound = wavefunctions_n_bound_states(qn, wave_eq, mesh, &
                                           potential, integrator_sp)

    !Remove states that we are sure are unbound from the list
    call states_batch_null(bound)
    do i = 1, batch%n_states
      qn = state_qn(batch%states(i)%ptr)

      if (qn%n - qn%l <= max_n_bound) then
        call states_batch_add(bound, batch%states(i)%ptr)
      else
        if (state_charge(batch%states(i)%ptr) == M_ZERO) then
          !Unbound state
          call state_update(batch%states(i)%ptr, mesh, wave_eq, potential, integrator_dp, ev=M_ZERO)
        else
          !There is something wrong
          write(message(1),'("State: ",A," is unbound")') &
               trim(state_label(batch%states(i)%ptr, .true.))
          call write_fatal(1)
        end if
      end if
    end do

    !Bracket the eigenvalues
    allocate(brackets(2, bound%n_states), qns(bound%n_states), bracketed(bound%n_states))
    do i = 1, bound%n_states
      qns(i) = state_qn(bound%states(i)%ptr)
    end do
    call eigensolver_bracket(bound%n_states, qns, wave_eq, eigensolver, &
                             potential, integrator_sp, brackets, bracketed)

    !Find eigenvalues
    do i = 1, bound%n_states
      if (bracketed(i)) then
        call eigensolver_find_ev(qns(i), wave_eq, eigensolver, potential, &
                                 integrator_dp, brackets(:,i), ev)
      else if (state_charge(bound%states(i)%ptr) == M_ZERO) then
        !This is probably an unbound state after all
        ev = M_ZERO
      else
        !An occupied state that could not be bracketed: there is something wrong
        message(1) = "Unable to bracket eigenvalues for state:"
        write(message(2),'(2X,A)') trim(qn_label(qns(i), .true.))
        call write_fatal(2)
      end if

      !Update the state with the new eigenvalue
      call state_update(bound%states(i)%ptr, mesh, wave_eq, potential, integrator_dp, ev)
    end do

    !Free memory
    deallocate(brackets, bracketed, qns)
    call states_batch_end(bound)

    call pop_sub()
  end subroutine states_batch_eigensolve

  !-----------------------------------------------------------------------
  !> Sort the states in the batch.
  !>
  !> Only the order of the pointers stored in the batch is changed; the
  !> states themselves are not modified. The new order is built by
  !> repeatedly selecting one of the states not yet selected (tracked by
  !> "mask") according to the chosen criteria.
  !>
  !> NOTE(review): there is no default case, so for a criteria other than
  !> SORT_EV or SORT_QN the array "order" is used without being set.
  !-----------------------------------------------------------------------
  subroutine states_batch_sort(batch, criteria)
    type(states_batch_t), intent(inout) :: batch    !< batch of states
    integer,              intent(in)    :: criteria !< criteria used to sort the states

    integer :: i, dum(1)
    logical, allocatable :: mask(:), qn_mask(:)
    integer, allocatable :: order(:)
    real(R8), allocatable :: ev(:)
    type(qn_t), allocatable :: qn(:)
    type(state_ptr), allocatable :: tmp(:)

    call push_sub("states_batch_sort")

    allocate(mask(batch%n_states), order(batch%n_states), tmp(batch%n_states))

    !sort states
    !mask(i) is true while state i has not been placed in the new order
    mask = .true.
    select case (criteria)
    case (SORT_EV)
      allocate(ev(batch%n_states))
      do i = 1, batch%n_states
        ev(i) = state_eigenvalue(batch%states(i)%ptr)
      end do

      !Each pass takes the remaining state with the largest absolute
      !eigenvalue, so the result is ordered by decreasing |eigenvalue|.
      !NOTE(review): this is increasing eigenvalue only as long as no
      !eigenvalue is positive - confirm this is what callers expect.
      do i = 1, batch%n_states
        dum = maxloc(abs(ev), mask=mask)
        order(i) = dum(1)
        mask(order(i)) = .false.
      end do
      deallocate(ev)

    case (SORT_QN)
      !Quantum numbers by increasing n,l,j,m,s and decreasing sg
      !(the algorithm is not very efficient, but we should never have to deal
      ! with a very large number of states)
      allocate(qn(batch%n_states))
      do i = 1, batch%n_states
        qn(i) = state_qn(batch%states(i)%ptr)
      end do

      allocate(qn_mask(batch%n_states))
      do i = 1, batch%n_states
        !Narrow down the candidates one quantum number at a time: among the
        !remaining states keep those with the smallest n, then among those
        !the smallest l, and so on for j, m and s
        qn_mask = qn%n == minval(qn%n, mask)
        qn_mask = qn_mask .and. qn%l == minval(qn%l, mask .and. qn_mask)
        qn_mask = qn_mask .and. qn%j == minval(qn%j, mask .and. qn_mask)
        qn_mask = qn_mask .and. qn%m == minval(qn%m, mask .and. qn_mask)
        qn_mask = qn_mask .and. qn%s == minval(qn%s, mask .and. qn_mask)
        !Of the surviving candidates pick the one with the largest sg
        dum = maxloc(qn%sg, mask = mask .and. qn_mask)
        order(i) = dum(1)
        mask(order(i)) = .false.
      end do
      deallocate(qn, qn_mask)

    end select

    !Rearrange the pointers: position i receives the state selected in pass i
    tmp = batch%states
    do i = 1, batch%n_states
      batch%states(i) = tmp(order(i))
    end do

    deallocate(mask, order, tmp)

    call pop_sub()
  end subroutine states_batch_sort

  !-----------------------------------------------------------------------
  !> Orthonormalizes the states of a batch, in the order in which they are
  !> stored: each state is normalized using state_dot_product and its
  !> component is then subtracted from all the states that come after it.
  !> The wfp component of every state is transformed exactly like wf, so
  !> that the two stay consistent with each other.
  !>
  !> NOTE(review): a state with zero norm leads to a division by zero;
  !> this is not checked here.
  !-----------------------------------------------------------------------
  subroutine states_batch_orthogonalize(mesh, batch)
    type(mesh_t),         intent(in)    :: mesh  !< mesh, passed on to state_dot_product
    type(states_batch_t), intent(inout) :: batch !< batch of states

    integer :: ia, ib
    real(R8) :: norm, overlap
    type(state_t), pointer :: state_a, state_b

    call push_sub("states_batch_orthogonalize")

    do ia = 1, batch%n_states
      state_a => states_batch_get(batch, ia)

      ! Normalize state a. The norm is evaluated once, before wf is
      ! modified, and applied to both wf and wfp: wfp enters the update
      ! of the following states below and must carry the same scaling.
      norm = sqrt(state_dot_product(mesh, state_a, state_a))
      state_a%wf  = state_a%wf/norm
      state_a%wfp = state_a%wfp/norm

      ! Remove the component along state a from all the following states
      do ib = ia+1, batch%n_states
        state_b => states_batch_get(batch, ib)

        overlap = state_dot_product(mesh, state_a, state_b)
        state_b%wf  = state_b%wf  - overlap*state_a%wf
        state_b%wfp = state_b%wfp - overlap*state_a%wfp
      end do
    end do

    call pop_sub()
  end subroutine states_batch_orthogonalize

  !-----------------------------------------------------------------------
  !> Distributes the electrons among the states.                          
  !-----------------------------------------------------------------------
  subroutine states_batch_smearing(batch, smearing_function, tol, new_charge)
    type(states_batch_t), intent(inout) :: batch             !< batch of states
    integer,              intent(in)    :: smearing_function !< how should the electrons be distributed
    real(R8),             intent(in)    :: tol               !< 
    real(R8), optional,   intent(in)    :: new_charge        !< 

    integer :: n_states, is, i, qi, qf, k
    real(R8) :: charge_diff, total_charge, lambda, lambda_in, sum_evdiff
    integer,  allocatable :: order(:)
    real(R8), allocatable :: ev(:), occ(:), new_occ(:), max_occ(:), gamma(:)

    call push_sub("states_batch_smearing")

    n_states = batch%n_states

    !Sort states by increasing eigenvalue
    call states_batch_sort(batch, SORT_EV)

    !Allocate working arrays
    allocate(ev(batch%n_states), order(n_states))
    allocate(occ(n_states), max_occ(n_states), new_occ(n_states))

    !eigenvalues, occupancies, and maximum occupancies for all the states in a single vector
    do i = 1, n_states
      ev(i) = state_eigenvalue(batch%states(i)%ptr)
      occ(i) = state_charge(batch%states(i)%ptr)
      max_occ(i) = qn_max_occ(state_qn(batch%states(i)%ptr))
      order(i) = i
    end do

    !Order all arrays by decreasing occupancy
    if (n_states /= 1) then
      i = 1
      do
        i = i + 1
        if (ev(i-1) == ev(i) .and. occ(order(i-1)) < occ(order(i))) then
          is = order(i)
          order(i) = order(i-1)
          order(i-1) = is
          i = 1
        end if
        if (i == n_states) exit
      end do
    end if

    do i = 1, n_states
      new_occ(i) = occ(order(i))
    end do
    occ = new_occ
    do i = 1, n_states
      new_occ(i) = max_occ(order(i))
    end do
    max_occ = new_occ

    !
    select case (smearing_function)
    case (OCC_FIXED)
      ! Nothing to do unless the total charge changed
      if (present(new_charge)) then
        charge_diff = new_charge - states_batch_charge(batch)

        if (charge_diff > M_ZERO) then
          ! Add charge
          do i = 1, n_states
            if (charge_diff == M_ZERO) then
              exit
            else if (occ(i) + charge_diff > max_occ(i)) then
              charge_diff = charge_diff - (max_occ(i) - occ(i))
              occ(i) = max_occ(i)
            else
              occ(i) = occ(i) + charge_diff
              charge_diff = M_ZERO
            end if
          end do
        end if

        if (charge_diff < M_ZERO) then
          ! Remove charge
          do i = n_states, 1, -1
            if (charge_diff == M_ZERO) then
              exit
            else if (occ(i) + charge_diff < M_ZERO) then
              charge_diff = charge_diff + occ(i)
              occ(i) = M_ZERO
            else
              occ(i) = occ(i) + charge_diff
              charge_diff = M_ZERO
            end if
          end do
        end if

      end if

    case (OCC_SEMICONDUCTING)

      if (present(new_charge)) then
        total_charge = new_charge
      else
        total_charge = states_batch_charge(batch)
      end if

      do i = 1, n_states
        if (total_charge == M_ZERO) then
          occ(i) = M_ZERO
        elseif (total_charge - max_occ(i) < M_ZERO) then
          occ(i) = total_charge
        else
          occ(i) = max_occ(i)
        end if
        total_charge = total_charge - occ(i)
      end do

    case (OCC_AVERILL_PAINTER)

      !Allocate memory and initialize some quantities
      allocate(gamma(n_states))
      lambda_in = M_HALF
      gamma = M_ONE
      total_charge = states_batch_charge(batch)

      !Remove states that will not participate in the charge transfer
      do i = 1, n_states
        if (occ(i) /= max_occ(i)) then
          qi = i
          exit
        end if
      end do
      qf = qi
      do i = n_states, qi, -1
        if (occ(i) /= M_ZERO) then
          qf = i
          exit
        end if
      end do

      !Main loop
      do
        !Exit if we have less than two states to perform the charge transfer
        if (count(gamma(qi:qf) /= M_ZERO) < 2) exit

        !By default we dont change the occupancies
        new_occ = occ

        !Choose the reference state
        if (occ(qi)/max_occ(qi) < M_HALF) then
          k = qi
        else
          k = qf
        end if

        !Compute gamma
        sum_evdiff = M_ZERO
        do i = qi, qf
          if (i == k) cycle
          if (abs(ev(i) - ev(k)) < tol) then
            gamma(i) = M_ZERO
          else if (ev(i) - ev(k) > M_ZERO) then
            gamma(i) = occ(i)/max_occ(i)/(ev(i) - ev(k))
          elseif (ev(i) - ev(k) < M_ZERO) then
            gamma(i) = (occ(i)/max_occ(i) - M_ONE)/(ev(i) - ev(k))
          end if
          if (gamma(i) /= M_ZERO) sum_evdiff = sum_evdiff + (ev(i) - ev(k))*max_occ(i)
        end do

        if (sum_evdiff > M_ZERO) then
          gamma(k) = (M_ONE - occ(k)/max_occ(k))/sum_evdiff
        elseif (sum_evdiff < M_ZERO) then
          gamma(k) = -occ(k)/max_occ(k)/sum_evdiff
        else
          gamma(k) = M_ZERO
        end if

        !If gamma(k) is zero remove this state from the list and go back to the
        !beginning of the cycle
        if (abs(gamma(k)) == M_ZERO) then
          if (k == qi) then
            qi = qi + 1
          else
            qf = qf - 1
          end if
          cycle
        end if

        !Get lambda (ignore the states that have gamma = 0)
        lambda = min(lambda_in, minval(gamma(qi:qf), mask=gamma(qi:qf) /= M_ZERO))

        !Get the new occupancies (states with gamma = 0 are ignored)
        sum_evdiff = M_ZERO
        do i = qi, qf
          if (i /= k .and. gamma(i) /= M_ZERO) then

            new_occ(i) = occ(i) - lambda*(ev(i) - ev(k))*max_occ(i)

            if (new_occ(i) < M_EPSILON) then
              new_occ(i) = M_ZERO
              sum_evdiff = sum_evdiff + occ(i)/lambda

            else if ( abs(max_occ(i) - new_occ(i)) < M_EPSILON ) then
              new_occ(i) = max_occ(i)
              sum_evdiff = sum_evdiff + (occ(i) - max_occ(i))/lambda

            else
              sum_evdiff = sum_evdiff + (ev(i) - ev(k))*max_occ(i)

            end if

          end if
        end do

        new_occ(k) = occ(k) + lambda*sum_evdiff
        if (new_occ(k) < M_EPSILON) then
          new_occ(k) = M_ZERO
        else if (max_occ(k) - new_occ(k) < M_EPSILON) then
          new_occ(k) = max_occ(k)
        end if

        occ = new_occ

        !Check if we have reached convergence
        if (lambda == lambda_in) exit

        lambda_in = lambda
      end do

      deallocate(gamma)
    end select

    !Get the final occupancies in the correct order
    do i = 1, n_states 
      call state_update_charge(batch%states(order(i))%ptr, occ(i))
    end do

    !Deallocate memory
    deallocate(ev, order, occ, max_occ, new_occ)

    call pop_sub()
  end subroutine states_batch_smearing

  !-----------------------------------------------------------------------
  !> Generates the pseudopotential from the leading state(s) of the batch
  !> and then solves the wave-equation with that pseudopotential for all
  !> the remaining states of the batch.
  !>
  !> The first state is always used in the generation. For the MRPP and
  !> RMRPP schemes the second state is used as well.
  !-----------------------------------------------------------------------
  subroutine states_batch_psp_generation(batch, mesh, scheme, wave_eq, tol, ae_potential, &
       integrator_sp, integrator_dp, eigensolver, rc, ps_v)
    type(states_batch_t), intent(inout) :: batch         !< batch of states
    type(mesh_t),         intent(in)    :: mesh          !< mesh
    integer,              intent(in)    :: scheme        !< generation scheme (compared against MRPP and RMRPP)
    integer,              intent(in)    :: wave_eq       !< wave-equation selector, forwarded unchanged
    real(R8),             intent(in)    :: tol           !< tolerance, forwarded to state_psp_generation
    type(potential_t),    intent(in)    :: ae_potential  !< all-electron potential
    type(integrator_t),   intent(inout) :: integrator_sp !< integrator object, forwarded unchanged
    type(integrator_t),   intent(inout) :: integrator_dp !< integrator object, forwarded unchanged
    type(eigensolver_t),  intent(in)    :: eigensolver   !< eigensolver used for the remaining states
    real(R8),             intent(in)    :: rc            !< forwarded to state_psp_generation
    real(R8),             intent(out)   :: ps_v(mesh%np) !< pseudopotential returned by state_psp_generation

    integer :: n_ref, n_total, ist
    logical :: two_ref_states
    type(state_t), pointer :: ref_state, ref_state2, other
    type(potential_t) :: ps_potential
    type(states_batch_t) :: other_states

    call push_sub("states_batch_psp_generation")

    !Generate the pseudopotential from the reference state(s)
    two_ref_states = (scheme == MRPP .or. scheme == RMRPP)
    ref_state => states_batch_get(batch, 1)
    if (two_ref_states) then
      n_ref = 2
      ref_state2 => states_batch_get(batch, 2)
      call state_psp_generation(mesh, scheme, wave_eq, tol, ae_potential, &
           integrator_sp, integrator_dp, ps_v, ref_state, rc, ref_state2)
    else
      n_ref = 1
      call state_psp_generation(mesh, scheme, wave_eq, tol, ae_potential, &
           integrator_sp, integrator_dp, ps_v, ref_state, rc)
    end if

    !The states that did not take part in the generation have to be
    !recomputed by solving the wave-equation with the new pseudopotential
    n_total = states_batch_size(batch)
    if (n_total > n_ref) then
      call states_batch_null(other_states)
      call potential_null(ps_potential)
      call potential_init(ps_potential, mesh, ps_v)

      !Collect the remaining states; the position of each state after the
      !reference ones is passed to state_update_number_of_nodes
      do ist = n_ref + 1, n_total
        other => states_batch_get(batch, ist)
        call state_update_number_of_nodes(other, ist - n_ref)
        call states_batch_add(other_states, other)
      end do

      call states_batch_eigensolve(other_states, mesh, wave_eq, ps_potential, integrator_dp, integrator_sp, eigensolver)

      call potential_end(ps_potential)
      call states_batch_end(other_states)
    end if

    call pop_sub()
  end subroutine states_batch_psp_generation

  !-----------------------------------------------------------------------
  !> Computes the logarithmic derivatives for a given fold of states.
  !>
  !> The all-electron and pseudopotential logarithmic derivatives of the
  !> first state of the batch are evaluated at radius r on a set of
  !> energies going from the lower to the upper bound of the interval and
  !> written to the file tests/ld-XXXX, where XXXX are characters 2 to 5
  !> of the label of the first state.
  !>
  !> When a bound is not given it defaults to the lowest (highest)
  !> eigenvalue of the batch minus (plus) one. If de is zero the energy
  !> step is adapted at each point to max(|0.1/dldde|, 0.01), where dldde
  !> is returned by state_ld for the pseudopotential. A negative de or an
  !> empty interval (emin > emax) is a fatal error.
  !-----------------------------------------------------------------------
  subroutine states_batch_ld_test(batch, mesh, ae_potential, ps_potential, integrator, r, de, emin, emax)
    type(states_batch_t),           intent(in)    :: batch         !< batch of states (they should belong to the same fold)
    type(mesh_t),                   intent(in)    :: mesh          !< mesh
    type(potential_t),              intent(in)    :: ae_potential  !< all-electron potential
    type(potential_t),              intent(in)    :: ps_potential  !< pseudopotential
    type(integrator_t),             intent(inout) :: integrator    !< integrator object
    real(R8),                       intent(in)    :: r             !< diagnostic radius at which to do the test
    real(R8),                       intent(in)    :: de            !< energy step (zero selects the adaptive step)
    real(R8),             optional, intent(in)    :: emin          !< lower bound of the energy interval
    real(R8),             optional, intent(in)    :: emax          !< upper bound of the energy interval

    integer :: unit
    real(R8) :: e, e1, e2, ae_ld, ps_ld, step, dldde
    character(len=10) :: label

    call push_sub("states_batch_ld_test")

    ASSERT(batch%n_states > 0)

    !Set energy interval
    if (present(emin)) then
      e1 = emin
    else
      e1 = minval(states_batch_eigenvalues(batch)) - M_ONE
    end if
    if (present(emax)) then
      e2 = emax
    else
      e2 = maxval(states_batch_eigenvalues(batch)) + M_ONE
    end if
    if (e1 > e2) then
      message(1) = "The minimum energy at which to calculate the logarithmic"
      message(2) = "derivatives must be smaller than the maximum energy."
      call write_fatal(2)
    end if

    !A negative step would move the energy away from the upper bound and the
    !loop below would never reach its exit condition
    if (de < M_ZERO) then
      message(1) = "The energy step used to calculate the logarithmic"
      message(2) = "derivatives must not be negative."
      call write_fatal(2)
    end if

    !Open file
    label = state_label(batch%states(1)%ptr)
    call io_open(unit, 'tests/ld-'//trim(label(2:5)))

    !Write some information
    write(message(1), '(2X,"Computing logarithmic derivative for states:",1X,A)') &
         trim(label(2:5))
    write(message(2), '(4X,"Minimum energy: ",F6.3,1X,A)') &
         e1/units_out%energy%factor, trim(units_out%energy%abbrev)
    write(message(3), '(4X,"Maximum energy: ",F6.3,1X,A)') &
         e2/units_out%energy%factor, trim(units_out%energy%abbrev)
    call write_info(3)
    call write_info(3,unit=info_unit("tests"))

    !Write header
    write(unit,'("# Logarithmic derivatives")')
    write(unit,'("# ")')
    write(unit,'("# Energy units: ",A)') trim(units_out%energy%name)
    write(unit,'("# Length units: ",A)') trim(units_out%length%name)
    write(unit,'("#")')
    write(unit,'("# ",53("-"))')
    write(unit,'("# |",7X,"e",7X,"|",5X,"ld_ae(e)",4X,"|",5X,"ld_pp(e)",4X,"|")')
    write(unit,'("# ",53("-"))')

    !Compute the logarithmic derivatives
    e = e1
    do
      ae_ld = state_ld(batch%states(1)%ptr, e, r, integrator, ae_potential, mesh)
      !The energy derivative dldde is only requested when it is needed to
      !choose the adaptive step
      if (de == M_ZERO) then
        ps_ld = state_ld(batch%states(1)%ptr, e, r, integrator, ps_potential, mesh, dldde)
      else
        ps_ld = state_ld(batch%states(1)%ptr, e, r, integrator, ps_potential, mesh)
      end if

      write(unit,'(3X,ES14.7E2,3X,ES15.8E2,3X,ES15.8E2)') e/units_out%energy%factor, ae_ld, ps_ld

      !The energy is clamped to e2 below, so the upper bound is hit exactly
      if (e == e2) exit
      if (de /= M_ZERO) then
        step = de
      else
        step = max(abs(0.1_r8/dldde), 0.01_r8)
      end if
      e = min(e + step, e2)
    end do

    close(unit)

    call pop_sub()
  end subroutine states_batch_ld_test

  !-----------------------------------------------------------------------
  !> Outputs the label and occupancy of every state of the batch in a
  !> nice readable format: one header line followed by one line per state.
  !-----------------------------------------------------------------------
  subroutine states_batch_output_configuration(batch, nspin, unit, verbose_limit)
    type(states_batch_t),           intent(in) :: batch         !< batch of states
    integer,                        intent(in) :: nspin         !< selects the layout (1 and 2 are handled)
    integer,              optional, intent(in) :: unit          !< forwarded to write_info
    integer,              optional, intent(in) :: verbose_limit !< forwarded to write_info

    integer :: ist, n_lines
    character(len=4) :: spin_tag
    type(qn_t) :: qn
    type(state_t), pointer :: state

    call push_sub("states_batch_output_configuration")

    !Header line. With nspin = 2 the quantum numbers of the first state
    !decide between the Spin and the Sigma layouts
    if (nspin == 1) then
      write(message(1),'(2X,"Configuration    : State   Occupation")')
    else if (nspin == 2) then
      qn = state_qn(batch%states(1)%ptr)
      if (qn%m == M_ZERO) then
        write(message(1),'(2X,"Configuration    : State   Spin   Occupation")')
      else
        write(message(1),'(2X,"Configuration    :  State   Sigma   Occupation")')
      end if
    end if

    !One line per state, stored right after the header
    n_lines = 1
    do ist = 1, batch%n_states
      n_lines = ist + 1
      state => batch%states(ist)%ptr
      qn = state_qn(state)

      if (nspin == 1) then
        !The indentation depends on whether j is zero or not
        if (qn%j == M_ZERO) then
          write(message(n_lines),'(23X,A,6X,F5.2)') trim(state_label(state)), state_charge(state)
        else
          write(message(n_lines),'(21X,A,5X,F5.2)') trim(state_label(state)), state_charge(state)
        end if

      else if (nspin == 2) then
        if (qn%m == M_ZERO) then
          !Anything other than s = -1/2 is printed as +1/2
          spin_tag = merge("-1/2", "+1/2", qn%s == -M_HALF)
          write(message(n_lines),'(21X,A,4X,A,5X,F5.2)') trim(state_label(state)), &
                                                         spin_tag, state_charge(state)
        else
          if (qn%sg == -M_HALF) then
            spin_tag = 'dn'
          else if (qn%sg == M_HALF) then
            spin_tag = 'up'
          else
            spin_tag = '--'
          end if
          write(message(n_lines),'(21X,A,5X,A,4X,F5.2)') trim(state_label(state)), &
                                                         spin_tag, state_charge(state)
        end if
      end if
    end do

    call write_info(n_lines, verbose_limit=verbose_limit, unit=unit)

    call pop_sub()
  end subroutine states_batch_output_configuration

  !-----------------------------------------------------------------------
  !> Writes the label, occupancy, and eigenvalue of every state of the
  !> batch through write_info in a nice readable format. The eigenvalues
  !> are converted to the output energy units.
  !-----------------------------------------------------------------------
  subroutine states_batch_output_eigenvalues(batch, nspin, unit, verbose_limit)
    type(states_batch_t),           intent(in) :: batch         !< batch of states
    integer,                        intent(in) :: nspin         !< selects the layout (1 and 2 are handled)
    integer,              optional, intent(in) :: unit          !< forwarded to write_info
    integer,              optional, intent(in) :: verbose_limit !< forwarded to write_info

    integer  :: ist, n_lines
    real(R8) :: e_factor, occ, ev_out
    type(qn_t) :: qn
    character(len=4) :: spin_tag
    type(state_t), pointer :: state

    call push_sub("states_batch_output_eigenvalues")

    !Title and column headers. With nspin = 2 the quantum numbers of the
    !first state decide between the Spin and the Sigma layouts
    write(message(1),'(4X,"Eigenvalues [",A,"]")') trim(units_out%energy%abbrev)
    if (nspin == 1) then
      write(message(2),'(5X,"State   Occupation   Eigenvalue")')
    else if (nspin == 2) then
      qn = state_qn(batch%states(1)%ptr)
      if (qn%m == M_ZERO) then
        write(message(2),'(5X,"State   Spin   Occupation   Eigenvalue")')
      else
        write(message(2),'(5X," State   Sigma   Occupation   Eigenvalue")')
      end if
    end if

    !One line per state, stored after the two header lines
    e_factor = units_out%energy%factor
    n_lines = 2
    do ist = 1, batch%n_states
      n_lines = ist + 2
      state => batch%states(ist)%ptr
      qn = state_qn(state)
      occ = state_charge(state)
      ev_out = state_eigenvalue(state)/e_factor

      if (nspin == 1) then
        !The indentation depends on whether j is zero or not
        if (qn%j == M_ZERO) then
          write(message(n_lines),'(7X,A,6X,F5.2,4X,F12.5)') trim(state_label(state)), occ, ev_out
        else
          write(message(n_lines),'(5X,A,5X,F5.2,4X,F12.5)') trim(state_label(state)), occ, ev_out
        end if

      else if (nspin == 2) then
        if (qn%m == M_ZERO) then
          !Anything other than s = -1/2 is printed as +1/2
          spin_tag = merge("-1/2", "+1/2", qn%s == -M_HALF)
          write(message(n_lines),'(7X,A,4X,A,5X,F5.2,4X,F12.5)') trim(state_label(state)), spin_tag, &
                                                                 occ, ev_out
        else
          if (qn%sg == -M_HALF) then
            spin_tag = 'dn'
          else if (qn%sg == M_HALF) then
            spin_tag = 'up'
          else
            spin_tag = '--'
          end if
          write(message(n_lines),'(5X,A,4X,A,5X,F5.2,4X,F12.5)') trim(state_label(state)), spin_tag, &
                                                                 occ, ev_out
        end if
      end if
    end do

    call write_info(n_lines, verbose_limit=verbose_limit, unit=unit)

    call pop_sub()
  end subroutine states_batch_output_eigenvalues

  !-----------------------------------------------------------------------
  !> Writes the electronic density, its gradient and laplacian, and the
  !> kinetic energy density to files in the dir directory, in a format
  !> suitable for plotting.
  !>
  !> With nspin = 1 the files are dir/density and dir/tau. Otherwise the
  !> first spin channel goes to dir/density_dn and dir/tau_dn and the
  !> second one to dir/density_up and dir/tau_up.
  !-----------------------------------------------------------------------
  subroutine states_batch_output_density(batch, nspin, mesh, dir)
    type(states_batch_t), intent(in) :: batch !< batch of states
    integer,              intent(in) :: nspin !< number of spin channels
    type(mesh_t),         intent(in) :: mesh  !< mesh
    character(len=*),     intent(in) :: dir   !< directory where the files are written

    integer  :: i, p, unit
    real(R8) :: u
    !The longest name appended to dir is "/density_dn" (11 characters), so a
    !length of len(dir)+11 can never truncate the path, however long dir is
    character(len=len(dir)+11) :: filename
    real(R8), allocatable :: rho(:,:), rho_grad(:,:), rho_lapl(:,:), tau(:,:)

    call push_sub("states_batch_output_density")

    !Get the density and its derivatives
    allocate(rho(mesh%np, nspin), rho_grad(mesh%np, nspin))
    allocate(rho_lapl(mesh%np, nspin), tau(mesh%np, nspin))
    rho      = states_batch_density     (batch, nspin, mesh)
    rho_grad = states_batch_density_grad(batch, nspin, mesh)
    rho_lapl = states_batch_density_lapl(batch, nspin, mesh)
    tau      = states_batch_tau         (batch, nspin, mesh)

    u = units_out%length%factor

    !Density, gradient, and laplacian
    do p = 1, nspin
      if (nspin == 1) then
        filename = trim(dir)//"/density"
      else
        if (p == 1) then
          filename = trim(dir)//"/density_dn"
        elseif (p == 2) then
          filename = trim(dir)//"/density_up"
        end if
      end if
      call io_open(unit, file=trim(filename))

      write(unit,'("#")')
      write(unit,'("# Radial density.")')
      write(unit,'("# Length units: ",A)') trim(units_out%length%name)
      write(unit,'("#")')

      write(unit,'("# ",71("-"))')
      write(unit,'("# |",7X,"r",7X,"|",6X,"n(r)",7X,"|",4X,"grad_n(r)",4X,"|",4X,"lapl_n(r)",4X,"|")')
      write(unit,'("# ",71("-"))')
      do i = 1, mesh%np
        write(unit,'(3X,ES14.8E2,3X,ES15.8E2,3X,ES15.8E2,3X,ES15.8E2)') &
             mesh%r(i)/u, rho(i, p)*u, rho_grad(i, p)*u**2, rho_lapl(i, p)*u**3
      end do

      close(unit)
    end do

    !Print kinetic energy density
    do p = 1, nspin
      if (nspin == 1) then
        filename = trim(dir)//"/tau"
      else
        if (p == 1) then
          filename = trim(dir)//"/tau_dn"
        elseif (p == 2) then
          filename = trim(dir)//"/tau_up"
        end if
      end if
      call io_open(unit, file=trim(filename))

      write(unit,'("#")')
      write(unit,'("# Radial kinetic energy density.")')
      write(unit,'("# Length units: ",A)') trim(units_out%length%name)
      write(unit,'("# Energy units: ",A)') trim(units_out%energy%name)
      write(unit,'("#")')

      write(unit,'("# ",35("-"))')
      write(unit,'("# |",7X,"r",7X,"|",5X,"tau(r)",6X,"|")')
      write(unit,'("# ",35("-"))')
      do i = 1, mesh%np
        write(unit,'(3X,ES14.8E2,3X,ES15.8E2)') mesh%r(i)/u, &
                                              tau(i, p)*u/units_out%energy%factor
      end do

      close(unit)
    end do

    !Free memory
    deallocate(rho, rho_grad, rho_lapl, tau)

    call pop_sub()
  end subroutine states_batch_output_density

  !-----------------------------------------------------------------------
  !> Pass the information about the wavefunctions to the ps_io module.
  !>
  !> For each state of the batch the quantum numbers, occupancy,
  !> eigenvalue, and first wavefunction component are gathered and handed
  !> to ps_io_set_wfs, together with the density of the whole batch.
  !-----------------------------------------------------------------------
  subroutine states_batch_ps_io_set(batch, mesh, rc)
    type(states_batch_t), intent(in) :: batch              !< batch of states
    type(mesh_t),         intent(in) :: mesh               !< mesh
    real(R8),             intent(in) :: rc(batch%n_states) !< one value per state, forwarded to ps_io_set_wfs

    integer :: ist, n_st
    character(len=10) :: label
    type(qn_t) :: qn
    integer, allocatable :: qn_n(:), qn_l(:)
    real(R8), allocatable :: qn_j(:), occ(:), ev(:), wfs(:,:), rho_val(:,:)
    type(state_t), pointer :: state

    call push_sub("states_batch_ps_io_set")

    !Allocate memory
    n_st = batch%n_states
    allocate(qn_n(n_st), qn_l(n_st), qn_j(n_st), occ(n_st), ev(n_st))
    allocate(wfs(mesh%np, n_st), rho_val(mesh%np, 1))

    !Gather the data of each state
    do ist = 1, n_st
      state => batch%states(ist)%ptr

      !The first character of the label is read as the integer n
      label = state_label(state)
      read(label,'(I1)') qn_n(ist)

      qn = state_qn(state)
      qn_l(ist) = qn%l
      qn_j(ist) = qn%j

      occ(ist) = state_charge(state)
      ev(ist) = state_eigenvalue(state)
      wfs(:, ist) = state%wf(:,1)
    end do

    !Density of the batch, requested with a single spin channel
    rho_val = states_batch_density(batch, 1, mesh)

    call ps_io_set_wfs(mesh%np, n_st, qn_n, qn_l, qn_j, occ, ev, rc, wfs, &
                       sum(rho_val, dim=2))

    !Free memory
    deallocate(qn_n, qn_l, qn_j, occ, ev, wfs, rho_val)

    call pop_sub()
  end subroutine states_batch_ps_io_set

  !-----------------------------------------------------------------------
  !> Returns the Hartree-Fock potential and non homogeneous part.
  !>
  !> For each state of the batch this returns its quantum numbers, the
  !> potential v = y/r, with y given by ypot_hf_calc, and the non
  !> homogeneous term n = -x/r + s, with x given by xnon_hf_calc and s by
  !> sumnon_hf_calc.
  !-----------------------------------------------------------------------
  subroutine states_batch_calc_hf_terms(batch, mesh, z, qns, v, n)
    type(states_batch_t), intent(in)  :: batch                     !< batch of states
    type(mesh_t),         intent(in)  :: mesh                      !< mesh
    real(R8),             intent(in)  :: z                         !< forwarded to sumnon_hf_calc
    type(qn_t),           intent(out) :: qns(batch%n_states)       !< quantum numbers of each state
    real(R8),             intent(out) :: v(mesh%np, batch%n_states) !< potential part, one column per state
    real(R8),             intent(out) :: n(mesh%np, batch%n_states) !< non homogeneous part, one column per state

    integer :: ist
    real(R8), allocatable :: y(:), x(:), s(:)
    type(state_t), pointer :: state

    call push_sub("states_batch_calc_hf_terms")

    !The work arrays have the same size for all the states, so they are
    !allocated only once. Each of them is completely overwritten by the
    !routine that fills it.
    allocate(y(mesh%np), x(mesh%np), s(mesh%np))

    do ist = 1, batch%n_states
      state => states_batch_get(batch, ist)

      qns(ist) = state_qn(state)

      !Calculate the total potential part of HF
      call ypot_hf_calc(mesh, batch, state, y)
      v(:, ist) = y/mesh%r

      !Calculate the full non homogeneous part of the HF ode
      call xnon_hf_calc(mesh, batch, state, x)
      call sumnon_hf_calc(mesh, z, batch, state, s)
      n(:, ist) = -x/mesh%r + s
    end do

    deallocate(y, x, s)

    call pop_sub()
  end subroutine states_batch_calc_hf_terms

  !-----------------------------------------------------------------------
  !> Calculates the exchange energy by explicitly evaluating all the
  !> exchange integrals between all the states of the batch.
  !>
  !> For every ordered pair of states (a, b), including a = b, and every
  !> l from |l_a - l_b| to l_a + l_b, the integral returned by
  !> state_r_integral is weighted by state_exchange_coefficients and
  !> accumulated. The total is multiplied by -1/2.
  !-----------------------------------------------------------------------
  subroutine states_batch_exchange_energy(mesh, batch, ex)
    type(mesh_t),         intent(in) :: mesh  !< mesh
    type(states_batch_t), intent(in) :: batch !< batch of states
    real(R8),             intent(out):: ex    !< exchange energy

    integer :: n_states, ia, ib, l_a, l_b, l
    real(R8) :: r
    type(state_t), pointer :: state_a, state_b

    call push_sub("states_batch_exchange_energy")

    ex = M_ZERO
    n_states = batch%n_states

    do ia = 1, n_states
      !State a does not depend on the inner loop
      state_a => states_batch_get(batch, ia)
      l_a = state_a%qn%l

      do ib = 1, n_states
        state_b => states_batch_get(batch, ib)
        l_b = state_b%qn%l

        do l = abs(l_a - l_b), l_a + l_b
          r = state_r_integral(mesh, state_a, state_b, state_b, state_a, l)
          ex = ex + r*state_exchange_coefficients(state_a, state_b, l)
        end do
      end do
    end do
    ex = -M_HALF*ex

    call pop_sub()
  end subroutine states_batch_exchange_energy

  !-----------------------------------------------------------------------
  !> Returns the Y part of the Hartree-Fock semilocal potential.
  !>
  !> The function is the sum of two contributions:
  !>  - a sum over k = 0, ..., 2*la of the functions returned by
  !>    state_hf_yk for the pair (state_a, state_a), each multiplied by
  !>    the charge of state_a minus one and by a weight that is one for
  !>    k = 0 and -(2*la + 1)/(4*la + 1) times the square of the 3j
  !>    symbol (la k la; 0 0 0) otherwise;
  !>  - a sum over all the other states b of the batch of the k = 0
  !>    function returned by state_hf_yk for the pair (b, b), multiplied
  !>    by the charge of b.
  !-----------------------------------------------------------------------
  subroutine ypot_hf_calc(mesh, batch, state_a, f)
    type(mesh_t),         intent(in)  :: mesh       !< mesh
    type(states_batch_t), intent(in)  :: batch      !< batch of states
    type(state_t),        intent(in)  :: state_a    !< state
    real(R8),             intent(out) :: f(mesh%np) !< function

    integer :: la, ib, kk
    real(R8) :: f_weight
    real(R8), allocatable :: yk_hf(:)
    type(state_t), pointer :: state_b

    call push_sub("ypot_hf_calc")

    allocate(yk_hf(mesh%np))
    f = M_ZERO
    la = state_a%qn%l

    !First part: summation over k
    do kk = 0, 2*la
      call state_hf_yk(mesh, state_a, state_a, kk, yk_hf)

      if (kk == 0) then
        f_weight = M_ONE
      else
        !The 3j symbol takes its arguments multiplied by two. It may be
        !negative, so it has to be squared with an integer power: raising
        !a negative real to a real power is not allowed by the standard.
        f_weight = -(M_TWO*la + M_ONE)/(M_FOUR*la + M_ONE)*&
             (gsl_sf_coupling_3j(2*la, 2*kk, 2*la, 0, 0, 0))**2
      end if

      f = f + (state_charge(state_a) - M_ONE)*f_weight*yk_hf
    end do

    !Do second part: summation over states
    do ib = 1, batch%n_states
      state_b => states_batch_get(batch, ib)

      if (state_a == state_b) cycle

      call state_hf_yk(mesh, state_b, state_b, 0, yk_hf)
      f = f + state_charge(state_b)*yk_hf
    end do

    deallocate(yk_hf)

    call pop_sub()
  end subroutine ypot_hf_calc

  !-----------------------------------------------------------------------
  !> Returns the X function of the non homogeneous part of the Hartree-Fock
  !> equations.
  !>
  !> The function is a sum over all the states b of the batch different
  !> from state_a and over k = |la - lb|, ..., la + lb of the function
  !> returned by state_hf_yk for the pair (b, state_a), multiplied by the
  !> first wavefunction component of b, by the charge of b, and by minus
  !> one half of the square of the 3j symbol (la k lb; 0 0 0).
  !-----------------------------------------------------------------------
  subroutine xnon_hf_calc(mesh, batch, state_a, f)
    type(mesh_t),         intent(in)  :: mesh       !< mesh
    type(states_batch_t), intent(in)  :: batch      !< batch of states
    type(state_t),        intent(in)  :: state_a    !< state for which the function is wanted
    real(R8),             intent(out) :: f(mesh%np) !< function

    integer :: la, lb, ib, kk
    real(R8) :: g_weight
    real(R8), allocatable :: yk_hf(:)
    type(state_t), pointer :: state_b

    call push_sub("xnon_hf_calc")

    !L number of wanted function
    la = state_a%qn%l

    !Calculate X
    f = M_ZERO
    allocate(yk_hf(mesh%np))
    do ib = 1, batch%n_states
      state_b => states_batch_get(batch, ib)

      !The state itself does not contribute
      if (state_a == state_b) cycle

      lb = state_b%qn%l

      !Summing over all k
      do kk = abs(la - lb), la + lb
        call state_hf_yk(mesh, state_b, state_a, kk, yk_hf)

        !The 3j symbol takes its arguments multiplied by two. It may be
        !negative, so it has to be squared with an integer power: raising
        !a negative real to a real power is not allowed by the standard.
        g_weight = -M_HALF*(gsl_sf_coupling_3j(2*la, 2*kk, 2*lb, 0, 0, 0))**2

        f = f + state_charge(state_b)*g_weight*yk_hf*state_b%wf(:,1)
      end do
    end do
    deallocate(yk_hf)

    call pop_sub()
  end subroutine xnon_hf_calc

  !-----------------------------------------------------------------------
  !> Returns the off diagonal energy parameter e_ab for a pair of
  !> different states of the batch.
  !>
  !> Two different expressions are used, depending on whether the two
  !> states have exactly the same occupation or not. Both are built from
  !> the functions returned by ypot_hf_calc and xnon_hf_calc for each of
  !> the two states and from the first component of their wavefunctions.
  !-----------------------------------------------------------------------
  function off_diagonal_energy(mesh, z, batch, state_a, state_b)
    type(mesh_t),          intent(in) :: mesh       !< mesh
    real(R8),              intent(in) :: z          !< multiplies the wf/r terms below; presumably the nuclear charge (TODO confirm)
    type(states_batch_t),  intent(in) :: batch      !< batch of states
    type(state_t),         intent(in) :: state_a    !< first state of the pair
    type(state_t),         intent(in) :: state_b    !< second state of the pair (must differ from state_a)
    real(R8) :: off_diagonal_energy

    real(R8) :: occ_a, occ_b, ioc
    integer  :: la, lb, ip
    real(R8), allocatable :: xn_a(:), xn_b(:), yp_a(:), yp_b(:), integrand(:), &
                              dwf_a(:), dwf_b(:), d2wf_a(:), d2wf_b(:), rhs_a(:), rhs_b(:)

    call push_sub("off_diagonal_energy")

    !The parameter is only defined for two different states
    ASSERT(.not. (state_a == state_b))

    !Occupations and angular momenta of the two states
    occ_a = state_charge(state_a)
    la  = state_a%qn%l

    occ_b = state_charge(state_b)
    lb  = state_b%qn%l

    ! Two cases: equal occupation and different occupation
    ! For numerical safety a threshold value should be used
    ! instead of numerical equality
    ! Check later.
    ! NOTE(review): with the exact comparison used below, two occupations
    ! that differ only by rounding take the first branch, where the
    ! prefactor occ_b/(occ_b - occ_a) becomes very large.
    allocate(yp_a(mesh%np), yp_b(mesh%np), xn_a(mesh%np), xn_b(mesh%np), integrand(mesh%np))
    yp_a = M_ZERO
    yp_b = M_ZERO
    xn_a = M_ZERO
    xn_b = M_ZERO
    integrand = M_ZERO

    !Y and X functions of both states
    call ypot_hf_calc(mesh, batch, state_a, yp_a)
    call ypot_hf_calc(mesh, batch, state_b, yp_b)
    call xnon_hf_calc(mesh, batch, state_a, xn_a)
    call xnon_hf_calc(mesh, batch, state_b, xn_b)
    
    !Case different occupations
    if (occ_a /= occ_b) then
      !Prefactor of the integral
      ioc = occ_b/(occ_b - occ_a)
      
      integrand = state_b%wf(:,1)*xn_a - state_a%wf(:,1)*xn_b + &
                  state_a%wf(:,1)*(yp_a-yp_b)*state_b%wf(:,1)

      !The integration is done with mesh%r passed as the dv argument
      off_diagonal_energy = ioc*mesh_integrate(mesh, integrand, dv=mesh%r)

      !Case same occupations
    else if (occ_a == occ_b) then

      allocate( dwf_a(mesh%np), dwf_b(mesh%np), d2wf_a(mesh%np), d2wf_b(mesh%np), &
                  rhs_a(mesh%np), rhs_b(mesh%np))

      !First and second derivatives of both wavefunctions
      dwf_a = mesh_derivative(mesh, state_a%wf(:,1))
      dwf_b = mesh_derivative(mesh, state_b%wf(:,1))
      
      d2wf_a = mesh_derivative2(mesh, state_a%wf(:,1))
      d2wf_b = mesh_derivative2(mesh, state_b%wf(:,1))

      do ip=1, mesh%np
        !Average of the X and Y terms of the two states
        integrand(ip) = M_HALF*( (state_b%wf(ip,1)*xn_a(ip)+state_a%wf(ip,1)*xn_b(ip))/mesh%r(ip) +&
                        state_b%wf(ip,1)*state_a%wf(ip,1)*(yp_a(ip)+yp_b(ip))*M_HALF/mesh%r(ip) )

        !Wavefunction of a times the following terms applied to b: second
        !derivative, first derivative over r, -z/r, and lb*(lb+1)/(2*r**2)
        rhs_a(ip) = state_a%wf(ip,1)*(-M_HALF*d2wf_b(ip)- &
                                         dwf_b(ip)/mesh%r(ip)- &
                                         z*state_b%wf(ip,1)/mesh%r(ip) + &
                                         real(lb,R8)*( real(lb,R8) + M_ONE)*M_HALF*state_b%wf(ip,1)/mesh%r(ip)**2)

        !Same as above with the roles of a and b exchanged
        rhs_b(ip) = state_b%wf(ip,1)*(-M_HALF*d2wf_a(ip)- &
                                         dwf_a(ip)/mesh%r(ip)- &
                                         z*state_a%wf(ip,1)/mesh%r(ip) + &
                                         real(la,R8)*( real(la,R8) + M_ONE)*M_HALF*state_a%wf(ip,1)/mesh%r(ip)**2)
      end do
        
      !Average of the two terms above plus the X and Y contribution
      off_diagonal_energy = M_HALF*(mesh_integrate(mesh, rhs_a) + mesh_integrate(mesh, rhs_b) ) + mesh_integrate(mesh, integrand)
        
      deallocate(rhs_a, rhs_b, dwf_a, dwf_b, d2wf_a, d2wf_b)
    end if

    deallocate(integrand, yp_a, yp_b, xn_a, xn_b)

    call pop_sub()
  end function off_diagonal_energy

  !-----------------------------------------------------------------------
  !> Returns the summation over non diagonal energy elements and states.
  !>
  !> The function is the sum, over all the states b of the batch that are
  !> different from state_a and have the same angular momentum l, of the
  !> first wavefunction component of b multiplied by the value returned
  !> by off_diagonal_energy for the pair (state_a, b).
  !-----------------------------------------------------------------------
  subroutine sumnon_hf_calc(mesh, z, batch, state_a, f)
    type(mesh_t),         intent(in)  :: mesh       !< mesh
    real(R8),             intent(in) :: z           !< forwarded to off_diagonal_energy
    type(states_batch_t), intent(in)  :: batch      !< batch of states
    type(state_t),        intent(in)  :: state_a    !< state for which the function is wanted
    real(R8),             intent(out) :: f(mesh%np) !< function

    integer  :: ib
    real(R8) :: e_ab
    type(state_t), pointer :: other

    call push_sub("sumnon_hf_calc")

    f = M_ZERO
    do ib = 1, batch%n_states
      other => states_batch_get(batch, ib)

      !Only the other states with the same angular momentum contribute
      if (other%qn%l /= state_a%qn%l) cycle
      if (state_a == other) cycle

      e_ab = off_diagonal_energy(mesh, z, batch, state_a, other)
      f = f + e_ab*other%wf(:,1)
    end do

    call pop_sub()
  end subroutine sumnon_hf_calc

end module states_batch_m
